mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
151 Commits
readme-ref
...
worktree-b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
028c7cb484 | ||
|
|
3a3cd04958 | ||
|
|
d21f55ed78 | ||
|
|
b2bd91f594 | ||
|
|
35ebf0c7c7 | ||
|
|
dea8a3e7aa | ||
|
|
439648a649 | ||
|
|
2d858829c7 | ||
|
|
6673d1d935 | ||
|
|
65f3cceaa1 | ||
|
|
cfe27e8001 | ||
|
|
9594d41e21 | ||
|
|
a2ce18063b | ||
|
|
f925431ad5 | ||
|
|
b6e5a71383 | ||
|
|
3a20266785 | ||
|
|
33ff774d62 | ||
|
|
ea04149691 | ||
|
|
aaeefeee8c | ||
|
|
0b917abd03 | ||
|
|
d9a5fcfe9f | ||
|
|
cf4d88bf48 | ||
|
|
98b9b8ac54 | ||
|
|
64eb2641fd | ||
|
|
c0f3970feb | ||
|
|
dbdb31523c | ||
|
|
da84f1a5a3 | ||
|
|
a5ab33a680 | ||
|
|
7235a98a43 | ||
|
|
6f291c4b9a | ||
|
|
b739a21d3b | ||
|
|
88bcd12a96 | ||
|
|
8bdcae291c | ||
|
|
322b85fd95 | ||
|
|
a590942274 | ||
|
|
45ae09b1c2 | ||
|
|
8f3f2a3048 | ||
|
|
6a7cefd3b2 | ||
|
|
f94f7ca43d | ||
|
|
86800211ff | ||
|
|
08c06d440e | ||
|
|
50733ea85c | ||
|
|
cfbdef2569 | ||
|
|
de2e820f48 | ||
|
|
30f067fa94 | ||
|
|
5f14b1e84f | ||
|
|
b5d6daf08e | ||
|
|
cf9c27aca9 | ||
|
|
1e3dff6ee7 | ||
|
|
e3968edb1a | ||
|
|
04b407560b | ||
|
|
ee0456d5bc | ||
|
|
b6403ec1be | ||
|
|
c2e12b666f | ||
|
|
89238d4b24 | ||
|
|
16c7345e5a | ||
|
|
bfbefc2fe1 | ||
|
|
2724466a3f | ||
|
|
4d1ff217be | ||
|
|
44b293bee0 | ||
|
|
f9b9657c1c | ||
|
|
6db0f716d5 | ||
|
|
d03ab816d8 | ||
|
|
61904fbc76 | ||
|
|
f461fca3da | ||
|
|
5f199e94c6 | ||
|
|
93fb02c495 | ||
|
|
16de9638fc | ||
|
|
f08d24e73f | ||
|
|
aba9627563 | ||
|
|
7d68b62aa8 | ||
|
|
13c870de86 | ||
|
|
f8b742d718 | ||
|
|
3555d169bd | ||
|
|
be74153c12 | ||
|
|
75535c93f0 | ||
|
|
84f13cae00 | ||
|
|
0e2ea24e46 | ||
|
|
703c2d9ea4 | ||
|
|
d03a41ec96 | ||
|
|
8aa9f14741 | ||
|
|
44324f1c2d | ||
|
|
f6845011d8 | ||
|
|
6e7ee5581d | ||
|
|
2e3158c48e | ||
|
|
8af22776aa | ||
|
|
cd8c01f620 | ||
|
|
461b746937 | ||
|
|
38e467aa6c | ||
|
|
7429ac163b | ||
|
|
07c151dd70 | ||
|
|
c0f7f1f054 | ||
|
|
df96fe5110 | ||
|
|
1460e6a3ee | ||
|
|
18a550dd15 | ||
|
|
254680001d | ||
|
|
2920011897 | ||
|
|
d879376697 | ||
|
|
2be30c18cd | ||
|
|
48f921d2a1 | ||
|
|
f55e7e0589 | ||
|
|
db2027d345 | ||
|
|
9a5032bfc9 | ||
|
|
c665b01c4e | ||
|
|
883508e682 | ||
|
|
080b99b69e | ||
|
|
0bd19289ea | ||
|
|
a138db0236 | ||
|
|
6a17670244 | ||
|
|
a3b7f6ecc1 | ||
|
|
438ae460bf | ||
|
|
da440fdef0 | ||
|
|
586365be4d | ||
|
|
3c962a9df8 | ||
|
|
1a460bac96 | ||
|
|
ce06a901cc | ||
|
|
c97288cdae | ||
|
|
d66b3f2643 | ||
|
|
66b0807462 | ||
|
|
c24ea4a7a5 | ||
|
|
c309d9b4ed | ||
|
|
745c071ee5 | ||
|
|
56ffe8bbb3 | ||
|
|
13dbdcb53b | ||
|
|
c8ad5f8b75 | ||
|
|
51c6596f6a | ||
|
|
aef4c68537 | ||
|
|
1ac423c36c | ||
|
|
59c38b3c88 | ||
|
|
9b3b2f5244 | ||
|
|
aed7b86aad | ||
|
|
e3c6d98f36 | ||
|
|
10971d7d05 | ||
|
|
4b0bfa5669 | ||
|
|
2c0c3bb988 | ||
|
|
ca6fac8f78 | ||
|
|
900fee4d67 | ||
|
|
59901c8b12 | ||
|
|
a860a2cb6b | ||
|
|
52b2a45c62 | ||
|
|
0af1c186fd | ||
|
|
e6d13a3979 | ||
|
|
86b2784b51 | ||
|
|
773935b91b | ||
|
|
fb23b80a01 | ||
|
|
d6a3171b7b | ||
|
|
59edd0b179 | ||
|
|
53c58576fc | ||
|
|
64e4eedcc6 | ||
|
|
63afb602b0 | ||
|
|
985e7752aa |
6
.gitignore
vendored
6
.gitignore
vendored
@@ -37,3 +37,9 @@ __pycache__/
|
||||
dist/
|
||||
build/
|
||||
uv.lock
|
||||
|
||||
# TTFT benchmark SQLite database (per-machine state)
|
||||
benchmarks/ttft/bench.db
|
||||
benchmarks/ttft/bench.db-journal
|
||||
benchmarks/ttft/bench.db-wal
|
||||
benchmarks/ttft/bench.db-shm
|
||||
|
||||
@@ -32,6 +32,7 @@ pretty-duration = "0.1.1"
|
||||
anyhow = "1.0"
|
||||
graphviz-rust = { version = "0.9", default-features = false}
|
||||
lru = "0.16.2"
|
||||
rayon = "1.10"
|
||||
|
||||
[workspace.package]
|
||||
edition = "2024"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
<img href="luminal.com" alt="Screenshot 2025-08-14 at 9 18 54 PM" src="https://github.com/user-attachments/assets/c5832634-55d5-45b7-ba65-6efe36afce4a" />
|
||||
<img href="luminal.com" alt="Screenshot 2025-08-14 at 9 18 54 PM" src="https://github.com/luminal-ai/luminal/blob/main/docs/logo/inference_at_the_speed_of_light.png" />
|
||||
|
||||
<h3 align="center">
|
||||
Luminal is a high-performance general-purpose inference compiler.
|
||||
|
||||
117
benchmarks/ttft/bench_python_baseline.py
Normal file
117
benchmarks/ttft/bench_python_baseline.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""Pure HuggingFace/PyTorch TTFT + TPOT bench. Prints a JSON line on stdout.
|
||||
|
||||
Measures:
|
||||
TTFT — sum of single-token forward-pass durations over the prompt, using
|
||||
a StaticCache. Methodology matches bench_python_luminal.py and the
|
||||
rust path so the cross-path comparison is apples-to-apples.
|
||||
TPOT — average time per output token during KV-cache greedy decode.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import statistics
|
||||
import time
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.cache_utils import StaticCache
|
||||
|
||||
from bench_utils import encode_prompt, measure_tpot, static_cache_config
|
||||
|
||||
DEFAULT_MODEL = "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
DEFAULT_PROMPT = "Explain what a neural network is in a paragraph."
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", default=DEFAULT_MODEL)
|
||||
ap.add_argument("--prompt", default=DEFAULT_PROMPT)
|
||||
ap.add_argument("--warmups", type=int, default=1)
|
||||
ap.add_argument("--iters", type=int, default=3)
|
||||
ap.add_argument("--dtype", default="float32", choices=["float32", "bfloat16", "float16"])
|
||||
ap.add_argument("--decode-tokens", type=int, default=50,
|
||||
help="Number of tokens to generate for TPOT measurement (0 = skip).")
|
||||
ap.add_argument("--max-cache-len", type=int, default=256,
|
||||
help="StaticCache max sequence length.")
|
||||
args = ap.parse_args()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[args.dtype]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
input_ids = encode_prompt(tokenizer, args.prompt, device)
|
||||
prompt_tokens = int(input_ids.shape[-1])
|
||||
|
||||
config = AutoConfig.from_pretrained(args.model)
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
model = (
|
||||
AutoModelForCausalLM.from_pretrained(args.model, config=config, torch_dtype=dtype)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
|
||||
single_token = torch.zeros(1, 1, dtype=torch.long, device=device)
|
||||
|
||||
cache_config = static_cache_config(config)
|
||||
|
||||
def make_cache():
|
||||
return StaticCache(
|
||||
config=cache_config,
|
||||
max_batch_size=1,
|
||||
max_cache_len=args.max_cache_len,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def measure_ttft() -> float:
|
||||
"""Sum of per-token forward-pass durations over prompt_tokens steps."""
|
||||
kv = make_cache()
|
||||
# Eager init at position 0 to satisfy StaticCache.lazy_initialization.
|
||||
with torch.no_grad():
|
||||
model(single_token, past_key_values=kv,
|
||||
cache_position=torch.tensor([0], device=device))
|
||||
total_ms = 0.0
|
||||
for pos in range(1, prompt_tokens):
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
model(single_token, past_key_values=kv,
|
||||
cache_position=torch.tensor([pos], device=device))
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
total_ms += (time.perf_counter() - t0) * 1000.0
|
||||
return total_ms
|
||||
|
||||
for _ in range(args.warmups):
|
||||
measure_ttft()
|
||||
|
||||
ttft_samples_ms = [measure_ttft() for _ in range(args.iters)]
|
||||
|
||||
result = {
|
||||
"path": "python_baseline",
|
||||
"model": args.model,
|
||||
"device": str(device),
|
||||
"dtype": args.dtype,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"iters": args.iters,
|
||||
"ttft_ms": statistics.median(ttft_samples_ms),
|
||||
"ttft_ms_mean": sum(ttft_samples_ms) / len(ttft_samples_ms),
|
||||
"ttft_ms_samples": ttft_samples_ms,
|
||||
"note": "sequential per-token, StaticCache KV cache",
|
||||
}
|
||||
|
||||
if args.decode_tokens > 0:
|
||||
tpot_samples_ms = measure_tpot(model, input_ids, device, args.decode_tokens)
|
||||
tpot_ms = sum(tpot_samples_ms) / len(tpot_samples_ms)
|
||||
result["decode_tokens"] = args.decode_tokens
|
||||
result["tpot_ms"] = tpot_ms
|
||||
result["tpot_ms_samples"] = tpot_samples_ms
|
||||
result["throughput_tps"] = 1000.0 / tpot_ms
|
||||
|
||||
print("BENCH_RESULT " + json.dumps(result))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
196
benchmarks/ttft/bench_python_luminal.py
Normal file
196
benchmarks/ttft/bench_python_luminal.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""Python -> Luminal TTFT + TPOT bench via torch.compile(backend=luminal_backend).
|
||||
|
||||
Methodology mirrors examples/llama (the Rust path):
|
||||
- One eager prefill step initialises the StaticCache (required by transformers'
|
||||
StaticCache.lazy_initialization) before compilation.
|
||||
- TTFT: run one forward pass per prompt token sequentially, each advancing
|
||||
cache_position by 1; sum durations.
|
||||
- TPOT: run --decode-tokens more single-token passes; average durations.
|
||||
- StaticCache pre-allocates K/V buffers up to max_cache_len; no growing allocation.
|
||||
|
||||
Prints a BENCH_RESULT JSON line on stdout.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import statistics
|
||||
import time
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.cache_utils import StaticCache
|
||||
|
||||
from bench_utils import encode_prompt, static_cache_config
|
||||
from luminal import luminal_backend
|
||||
|
||||
DEFAULT_MODEL = "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
DEFAULT_PROMPT = "Explain what a neural network is in a paragraph."
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", default=DEFAULT_MODEL)
|
||||
ap.add_argument("--prompt", default=DEFAULT_PROMPT)
|
||||
ap.add_argument("--warmups", type=int, default=1)
|
||||
ap.add_argument("--iters", type=int, default=3)
|
||||
ap.add_argument(
|
||||
"--search-iters",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Egraph search iterations (matches examples/llama default of 500).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--decode-tokens",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Tokens to generate for TPOT measurement (0 = skip TPOT).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--max-cache-len",
|
||||
type=int,
|
||||
default=256,
|
||||
help="StaticCache max sequence length.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--dtype",
|
||||
default="float32",
|
||||
choices=["float32", "bfloat16", "float16"],
|
||||
help="Torch dtype for model + StaticCache.",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[args.dtype]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
input_ids = encode_prompt(tokenizer, args.prompt, device)
|
||||
prompt_tokens = int(input_ids.shape[-1])
|
||||
|
||||
config = AutoConfig.from_pretrained(args.model)
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
model = (
|
||||
AutoModelForCausalLM.from_pretrained(args.model, config=config, torch_dtype=dtype)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
|
||||
single_token = torch.zeros(1, 1, dtype=torch.long, device=device)
|
||||
|
||||
cache_config = static_cache_config(config)
|
||||
|
||||
def make_cache():
|
||||
return StaticCache(
|
||||
config=cache_config,
|
||||
max_batch_size=1,
|
||||
max_cache_len=args.max_cache_len,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Step 0: run ONE eager prefill to initialise the cache tensors and call
|
||||
# mark_static_address (required by transformers' StaticCache before compile).
|
||||
cache = make_cache()
|
||||
with torch.no_grad():
|
||||
model(single_token, past_key_values=cache, cache_position=torch.tensor([0], device=device))
|
||||
|
||||
# Compile for a single-token input — same graph is reused for every step.
|
||||
# Compilation happens on the first call after the eager init above.
|
||||
t0 = time.perf_counter()
|
||||
compiled = torch.compile(
|
||||
model,
|
||||
backend=luminal_backend,
|
||||
options={"search_iterations": args.search_iters},
|
||||
)
|
||||
cache_position = torch.tensor([1], dtype=torch.long, device=device)
|
||||
with torch.no_grad():
|
||||
compiled(single_token, past_key_values=cache, cache_position=cache_position)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
compile_ms = (time.perf_counter() - t0) * 1000.0
|
||||
|
||||
gc.collect()
|
||||
if device.type == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def one_step(pos: int, kv_cache):
|
||||
cache_pos = torch.tensor([pos], dtype=torch.long, device=device)
|
||||
with torch.no_grad():
|
||||
compiled(single_token, past_key_values=kv_cache, cache_position=cache_pos)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def measure_ttft():
|
||||
"""Sum of per-token forward-pass durations over prompt_tokens steps.
|
||||
|
||||
Uses a fresh cache so each TTFT measurement is independent.
|
||||
"""
|
||||
kv = make_cache()
|
||||
# Eager init for this fresh cache (required before compiled can run on it).
|
||||
with torch.no_grad():
|
||||
model(single_token, past_key_values=kv, cache_position=torch.tensor([0], device=device))
|
||||
total_ms = 0.0
|
||||
# Step 0 was the eager init above; measure from step 1 to prompt_tokens.
|
||||
for pos in range(1, prompt_tokens):
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
one_step(pos, kv)
|
||||
total_ms += (time.perf_counter() - t0) * 1000.0
|
||||
return total_ms
|
||||
|
||||
def measure_tpot(n, start_pos: int):
|
||||
"""Average single-token forward-pass duration over n decode steps."""
|
||||
kv = make_cache()
|
||||
# Eager init
|
||||
with torch.no_grad():
|
||||
model(single_token, past_key_values=kv, cache_position=torch.tensor([0], device=device))
|
||||
# One warmup step.
|
||||
one_step(1, kv)
|
||||
step_times_ms = []
|
||||
for i in range(n):
|
||||
pos = start_pos + i
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
one_step(pos, kv)
|
||||
step_times_ms.append((time.perf_counter() - t0) * 1000.0)
|
||||
return step_times_ms
|
||||
|
||||
# Warmups before timing TTFT (all run after compilation is complete).
|
||||
for _ in range(args.warmups):
|
||||
measure_ttft()
|
||||
|
||||
ttft_samples_ms = [measure_ttft() for _ in range(args.iters)]
|
||||
|
||||
tpot_ms_samples = []
|
||||
if args.decode_tokens > 0:
|
||||
tpot_ms_samples = measure_tpot(args.decode_tokens, start_pos=prompt_tokens)
|
||||
|
||||
tpot_ms = sum(tpot_ms_samples) / len(tpot_ms_samples) if tpot_ms_samples else None
|
||||
throughput_tps = (1000.0 / tpot_ms) if tpot_ms else None
|
||||
|
||||
result = {
|
||||
"path": "python_luminal",
|
||||
"model": args.model,
|
||||
"device": str(device),
|
||||
"dtype": args.dtype,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"iters": args.iters,
|
||||
"ttft_ms": statistics.median(ttft_samples_ms),
|
||||
"ttft_ms_mean": sum(ttft_samples_ms) / len(ttft_samples_ms),
|
||||
"ttft_ms_samples": ttft_samples_ms,
|
||||
"compile_ms": compile_ms,
|
||||
"search_iters": args.search_iters,
|
||||
"decode_tokens": args.decode_tokens if args.decode_tokens > 0 else None,
|
||||
"tpot_ms": tpot_ms,
|
||||
"tpot_ms_samples": tpot_ms_samples,
|
||||
"throughput_tps": throughput_tps,
|
||||
"note": "sequential per-token, StaticCache KV cache",
|
||||
}
|
||||
print("BENCH_RESULT " + json.dumps(result))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
138
benchmarks/ttft/bench_python_torch_compile.py
Normal file
138
benchmarks/ttft/bench_python_torch_compile.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Vanilla torch.compile TTFT + TPOT bench. Prints a JSON line on stdout.
|
||||
|
||||
Uses the default inductor backend (torch.compile without a custom backend).
|
||||
TTFT uses sequential per-token prefill with a StaticCache so the methodology
|
||||
matches bench_python_baseline.py, bench_python_luminal.py, and the rust path.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import statistics
|
||||
import time
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.cache_utils import StaticCache
|
||||
|
||||
from bench_utils import encode_prompt, measure_tpot, static_cache_config
|
||||
|
||||
DEFAULT_MODEL = "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
DEFAULT_PROMPT = "Explain what a neural network is in a paragraph."
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", default=DEFAULT_MODEL)
|
||||
ap.add_argument("--prompt", default=DEFAULT_PROMPT)
|
||||
ap.add_argument("--warmups", type=int, default=1)
|
||||
ap.add_argument("--iters", type=int, default=3)
|
||||
ap.add_argument("--dtype", default="float32", choices=["float32", "bfloat16", "float16"])
|
||||
ap.add_argument(
|
||||
"--decode-tokens", type=int, default=50,
|
||||
help="Number of tokens to generate for TPOT measurement (0 = skip).",
|
||||
)
|
||||
ap.add_argument("--max-cache-len", type=int, default=256,
|
||||
help="StaticCache max sequence length.")
|
||||
args = ap.parse_args()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[args.dtype]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
input_ids = encode_prompt(tokenizer, args.prompt, device)
|
||||
prompt_tokens = int(input_ids.shape[-1])
|
||||
|
||||
config = AutoConfig.from_pretrained(args.model)
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
model = (
|
||||
AutoModelForCausalLM.from_pretrained(args.model, config=config, torch_dtype=dtype)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
|
||||
single_token = torch.zeros(1, 1, dtype=torch.long, device=device)
|
||||
|
||||
cache_config = static_cache_config(config)
|
||||
|
||||
def make_cache():
|
||||
return StaticCache(
|
||||
config=cache_config,
|
||||
max_batch_size=1,
|
||||
max_cache_len=args.max_cache_len,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Eager init on the uncompiled model so the StaticCache buffers get
|
||||
# registered (mark_static_address) before torch.compile traces them.
|
||||
init_cache = make_cache()
|
||||
with torch.no_grad():
|
||||
model(single_token, past_key_values=init_cache,
|
||||
cache_position=torch.tensor([0], device=device))
|
||||
|
||||
compiled = torch.compile(model)
|
||||
|
||||
# First compiled call triggers JIT compilation; time it as compile_ms.
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
compiled(single_token, past_key_values=init_cache,
|
||||
cache_position=torch.tensor([1], device=device))
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
compile_ms = (time.perf_counter() - t0) * 1000.0
|
||||
|
||||
def measure_ttft() -> float:
|
||||
"""Sum of per-token compiled-forward durations over prompt_tokens steps."""
|
||||
kv = make_cache()
|
||||
# Fresh cache needs eager init via the uncompiled model first.
|
||||
with torch.no_grad():
|
||||
model(single_token, past_key_values=kv,
|
||||
cache_position=torch.tensor([0], device=device))
|
||||
total_ms = 0.0
|
||||
for pos in range(1, prompt_tokens):
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
compiled(single_token, past_key_values=kv,
|
||||
cache_position=torch.tensor([pos], device=device))
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
total_ms += (time.perf_counter() - t0) * 1000.0
|
||||
return total_ms
|
||||
|
||||
for _ in range(args.warmups):
|
||||
measure_ttft()
|
||||
|
||||
ttft_samples_ms = [measure_ttft() for _ in range(args.iters)]
|
||||
|
||||
result = {
|
||||
"path": "python_torch_compile",
|
||||
"model": args.model,
|
||||
"device": str(device),
|
||||
"dtype": args.dtype,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"iters": args.iters,
|
||||
"ttft_ms": statistics.median(ttft_samples_ms),
|
||||
"ttft_ms_mean": sum(ttft_samples_ms) / len(ttft_samples_ms),
|
||||
"ttft_ms_samples": ttft_samples_ms,
|
||||
"compile_ms": compile_ms,
|
||||
"note": "sequential per-token, StaticCache KV cache (torch.compile inductor)",
|
||||
}
|
||||
|
||||
if args.decode_tokens > 0:
|
||||
tpot_samples_ms = measure_tpot(compiled, input_ids, device, args.decode_tokens)
|
||||
tpot_ms = sum(tpot_samples_ms) / len(tpot_samples_ms)
|
||||
result["decode_tokens"] = args.decode_tokens
|
||||
result["tpot_ms"] = tpot_ms
|
||||
result["tpot_ms_samples"] = tpot_samples_ms
|
||||
result["throughput_tps"] = 1000.0 / tpot_ms
|
||||
|
||||
print("BENCH_RESULT " + json.dumps(result))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
94
benchmarks/ttft/bench_utils.py
Normal file
94
benchmarks/ttft/bench_utils.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Shared helpers for the Python benchmark scripts."""
|
||||
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class _CfgWithoutKvShared:
|
||||
"""Wrapper that hides `num_kv_shared_layers` from a HF config.
|
||||
|
||||
transformers 5.6 has a bug in StaticCache.__init__:
|
||||
if hasattr(config, "num_kv_shared_layers"):
|
||||
layer_types = layer_types[: -config.num_kv_shared_layers]
|
||||
For configs where the attribute is 0 (e.g. Gemma-4), `[:-0]` returns an
|
||||
empty list, leaving StaticCache with zero layer slots, and the LM's
|
||||
first `past_key_values.update(..., layer_idx=0)` raises IndexError.
|
||||
|
||||
This wrapper makes `hasattr(...)` return False so the bad branch never
|
||||
fires. Used via `static_cache_config(config)` below.
|
||||
"""
|
||||
__slots__ = ("_inner",)
|
||||
|
||||
def __init__(self, inner):
|
||||
object.__setattr__(self, "_inner", inner)
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name == "num_kv_shared_layers":
|
||||
raise AttributeError(name)
|
||||
return getattr(self._inner, name)
|
||||
|
||||
def get_text_config(self, *args, **kwargs):
|
||||
return _CfgWithoutKvShared(self._inner.get_text_config(*args, **kwargs))
|
||||
|
||||
|
||||
def static_cache_config(config):
|
||||
"""Return a config suitable for `StaticCache(config=..., ...)`.
|
||||
|
||||
Two normalizations:
|
||||
1. Multimodal wrappers (Gemma4ForConditionalGeneration, ...) nest the
|
||||
actual LM config under `.text_config`. Pass that, not the wrapper,
|
||||
so layer/head counts match the inner LM.
|
||||
2. If the resulting config has `num_kv_shared_layers == 0`, wrap it to
|
||||
hide the attribute (works around the transformers 5.6 slice bug).
|
||||
"""
|
||||
cfg = getattr(config, "text_config", config)
|
||||
if getattr(cfg, "num_kv_shared_layers", None) == 0:
|
||||
cfg = _CfgWithoutKvShared(cfg)
|
||||
return cfg
|
||||
|
||||
|
||||
def encode_prompt(tokenizer, prompt: str, device):
|
||||
"""Tokenize prompt using chat template if available, falling back to raw tokenization."""
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
try:
|
||||
encoded = tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True, return_tensors="pt"
|
||||
)
|
||||
except (ValueError, AttributeError):
|
||||
encoded = tokenizer(prompt, return_tensors="pt")
|
||||
if hasattr(encoded, "input_ids"):
|
||||
return encoded.input_ids.to(device)
|
||||
if isinstance(encoded, dict):
|
||||
return encoded["input_ids"].to(device)
|
||||
return encoded.to(device)
|
||||
|
||||
|
||||
def measure_tpot(model, input_ids, device, decode_tokens: int) -> list[float]:
|
||||
"""Prefill once with KV cache, then time each subsequent single-token decode step."""
|
||||
with torch.no_grad():
|
||||
out = model(input_ids, use_cache=True)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
past = out.past_key_values
|
||||
next_id = out.logits[:, -1:].argmax(-1)
|
||||
|
||||
out = model(next_id, past_key_values=past, use_cache=True)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
past = out.past_key_values
|
||||
next_id = out.logits[:, -1:].argmax(-1)
|
||||
|
||||
step_times_ms = []
|
||||
for _ in range(decode_tokens):
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
out = model(next_id, past_key_values=past, use_cache=True)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
step_times_ms.append((time.perf_counter() - t0) * 1000.0)
|
||||
past = out.past_key_values
|
||||
next_id = out.logits[:, -1:].argmax(-1)
|
||||
|
||||
return step_times_ms
|
||||
92
benchmarks/ttft/benchmarks.toml
Normal file
92
benchmarks/ttft/benchmarks.toml
Normal file
@@ -0,0 +1,92 @@
|
||||
[ur_test]
|
||||
models = ["llama-8b", "qwen3-4b", "gemma3-4b", "gemma4-moe", "qwen3-moe"]
|
||||
# 3-point sweep (low/mid/high). The previous list [5, 10, 20, 50, 100, 500]
|
||||
# spent ~62 extra minutes on s=5/s=20/s=50 with little additional information.
|
||||
search_sweep_iters = [10, 100, 500]
|
||||
|
||||
[configs.llama-8b]
|
||||
model = "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
rust_package = "llama"
|
||||
search_iters = 500
|
||||
iters = 10
|
||||
warmups = 2
|
||||
decode_tokens = 50
|
||||
# On-disk weights are bf16-majority. fp32 upcast doubled python_luminal's
|
||||
# egglog Search peak past the 525 GB unified pool and triggered SIGKILLs on
|
||||
# gemma3-4b (and same risk here). bf16 matches rust's load path.
|
||||
dtype = "bfloat16"
|
||||
|
||||
[configs.as_fast_as_possible]
|
||||
prompt = "The"
|
||||
search_iters = 1
|
||||
iters = 1
|
||||
warmups = 0
|
||||
decode_tokens = 5
|
||||
|
||||
[configs.qwen3-4b]
|
||||
model = "Qwen/Qwen3-4B"
|
||||
rust_package = "qwen"
|
||||
search_iters = 50
|
||||
iters = 10
|
||||
warmups = 2
|
||||
decode_tokens = 20
|
||||
# bf16-majority on-disk; see llama-8b note.
|
||||
dtype = "bfloat16"
|
||||
|
||||
[configs.gemma3-4b]
|
||||
model = "unsloth/gemma-3-4b-it"
|
||||
rust_package = "gemma"
|
||||
search_iters = 50
|
||||
iters = 10
|
||||
warmups = 2
|
||||
decode_tokens = 20
|
||||
# bf16-majority on-disk; see llama-8b note.
|
||||
dtype = "bfloat16"
|
||||
|
||||
[configs.gemma4-moe]
|
||||
model = "google/gemma-4-26B-A4B"
|
||||
rust_package = "gemma4_moe"
|
||||
search_iters = 50
|
||||
iters = 10
|
||||
warmups = 2
|
||||
decode_tokens = 20
|
||||
# 26B params at fp32 = 104 GB → OOM on a 94 GB GPU. Use bf16 (matches the
|
||||
# on-disk safetensors dtype) so the python paths can actually load.
|
||||
dtype = "bfloat16"
|
||||
|
||||
[configs.qwen3-moe]
|
||||
model = "Qwen/Qwen3-30B-A3B"
|
||||
rust_package = "qwen3_moe"
|
||||
search_iters = 50
|
||||
iters = 10
|
||||
warmups = 2
|
||||
decode_tokens = 20
|
||||
# 30B params at fp32 = 120 GB → OOM. See gemma4-moe note.
|
||||
dtype = "bfloat16"
|
||||
|
||||
[configs.llama-8b-const]
|
||||
model = "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
rust_package = "llama"
|
||||
prompt = "We the People of the United States, in Order to form a more perfect Union, establish Justice, insure domestic Tranquility, provide for the common defence, promote the general Welfare, and secure the Blessings of Liberty to ourselves and our Posterity, do ordain and establish this Constitution for the United States of America."
|
||||
search_iters = 500
|
||||
iters = 10
|
||||
warmups = 2
|
||||
decode_tokens = 20
|
||||
|
||||
[configs.qwen3-4b-const]
|
||||
model = "Qwen/Qwen3-4B"
|
||||
rust_package = "qwen"
|
||||
prompt = "We the People of the United States, in Order to form a more perfect Union, establish Justice, insure domestic Tranquility, provide for the common defence, promote the general Welfare, and secure the Blessings of Liberty to ourselves and our Posterity, do ordain and establish this Constitution for the United States of America."
|
||||
search_iters = 50
|
||||
iters = 10
|
||||
warmups = 2
|
||||
decode_tokens = 20
|
||||
|
||||
[configs.gemma3-4b-const]
|
||||
model = "unsloth/gemma-3-4b-it"
|
||||
rust_package = "gemma"
|
||||
prompt = "We the People of the United States, in Order to form a more perfect Union, establish Justice, insure domestic Tranquility, provide for the common defence, promote the general Welfare, and secure the Blessings of Liberty to ourselves and our Posterity, do ordain and establish this Constitution for the United States of America."
|
||||
search_iters = 50
|
||||
iters = 10
|
||||
warmups = 2
|
||||
decode_tokens = 20
|
||||
610
benchmarks/ttft/dashboard.html
Normal file
610
benchmarks/ttft/dashboard.html
Normal file
@@ -0,0 +1,610 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>Luminal · Benchmark Dashboard</title>
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
||||
<link href="https://fonts.googleapis.com/css2?family=Geist:wght@300;400;500;600&family=Geist+Mono:wght@300;400;500&display=swap" rel="stylesheet">
|
||||
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
|
||||
<style>
|
||||
*, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
|
||||
html { -webkit-font-smoothing: antialiased; scroll-behavior: smooth; }
|
||||
|
||||
body {
|
||||
font-family: 'Geist', system-ui, sans-serif;
|
||||
background: #030712;
|
||||
color: #d7d8d9;
|
||||
min-height: 100vh;
|
||||
line-height: 1.5;
|
||||
}
|
||||
|
||||
/* ── NAV ── */
|
||||
nav {
|
||||
position: sticky;
|
||||
top: 0;
|
||||
z-index: 50;
|
||||
height: 56px;
|
||||
background: rgba(8, 15, 17, 0.92);
|
||||
backdrop-filter: blur(8px);
|
||||
-webkit-backdrop-filter: blur(8px);
|
||||
border-bottom: 1px solid #2d3335;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
padding: 0 24px;
|
||||
gap: 0;
|
||||
}
|
||||
.nav-brand {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
letter-spacing: 0.05em;
|
||||
color: #2faa6e;
|
||||
text-decoration: none;
|
||||
}
|
||||
.nav-dot {
|
||||
width: 6px;
|
||||
height: 6px;
|
||||
background: #2faa6e;
|
||||
border-radius: 50%;
|
||||
flex-shrink: 0;
|
||||
animation: pulse-glow 2s ease-in-out infinite;
|
||||
}
|
||||
.nav-sep {
|
||||
color: #2d3335;
|
||||
margin: 0 14px;
|
||||
font-size: 18px;
|
||||
font-weight: 300;
|
||||
}
|
||||
.nav-page {
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.1em;
|
||||
text-transform: uppercase;
|
||||
color: #7e8385;
|
||||
}
|
||||
|
||||
@keyframes pulse-glow {
|
||||
0%, 100% { opacity: 1; }
|
||||
50% { opacity: 0.35; }
|
||||
}
|
||||
|
||||
/* ── MAIN ── */
|
||||
main {
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
padding: 40px 24px 80px;
|
||||
}
|
||||
|
||||
/* ── PAGE HEADER ── */
|
||||
.page-header {
|
||||
margin-bottom: 40px;
|
||||
padding-bottom: 32px;
|
||||
border-bottom: 1px solid #1c2225;
|
||||
}
|
||||
.page-eyebrow {
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.1em;
|
||||
text-transform: uppercase;
|
||||
color: #2faa6e;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
.page-title {
|
||||
font-size: 30px;
|
||||
font-weight: 500;
|
||||
letter-spacing: -0.025em;
|
||||
color: #d7d8d9;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
.page-meta {
|
||||
font-size: 14px;
|
||||
color: #7e8385;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.meta-sep {
|
||||
font-family: 'Geist Mono', monospace;
|
||||
color: #2d3335;
|
||||
margin: 0 10px;
|
||||
}
|
||||
.meta-val {
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 13px;
|
||||
color: #5b5f61;
|
||||
}
|
||||
|
||||
/* ── LEGEND STRIP ── */
|
||||
.legend-strip {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 6px;
|
||||
margin-bottom: 32px;
|
||||
}
|
||||
.legend-pill {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.04em;
|
||||
color: #a1a4a5;
|
||||
background: #141b1d;
|
||||
border: 1px solid #2d3335;
|
||||
border-radius: 2px;
|
||||
padding: 4px 10px;
|
||||
}
|
||||
.legend-swatch {
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border-radius: 50%;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
/* ── SECTIONS ── */
|
||||
section { margin-bottom: 48px; }
|
||||
.section-header {
|
||||
display: flex;
|
||||
align-items: baseline;
|
||||
gap: 10px;
|
||||
margin-bottom: 16px;
|
||||
padding-bottom: 12px;
|
||||
border-bottom: 1px solid #1c2225;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.section-eyebrow {
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.1em;
|
||||
text-transform: uppercase;
|
||||
color: #404647;
|
||||
}
|
||||
.section-title {
|
||||
font-size: 18px;
|
||||
font-weight: 500;
|
||||
color: #d7d8d9;
|
||||
letter-spacing: -0.01em;
|
||||
}
|
||||
.section-title .unit {
|
||||
color: #7e8385;
|
||||
font-weight: 400;
|
||||
}
|
||||
.section-tag {
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.04em;
|
||||
text-transform: uppercase;
|
||||
color: #2faa6e;
|
||||
background: #162322;
|
||||
border: 1px solid #1c372e;
|
||||
padding: 2px 8px;
|
||||
border-radius: 2px;
|
||||
margin-left: auto;
|
||||
}
|
||||
|
||||
/* ── CHART GRID ── */
|
||||
.chart-grid {
|
||||
display: grid;
|
||||
gap: 10px;
|
||||
}
|
||||
.chart-card {
|
||||
background: #141b1d;
|
||||
border: 1px solid #2d3335;
|
||||
border-radius: 2px;
|
||||
overflow: hidden;
|
||||
transition: border-color 150ms;
|
||||
min-width: 0;
|
||||
}
|
||||
.chart-card:hover { border-color: #404647; }
|
||||
.chart-card-header {
|
||||
padding: 10px 14px 0;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
.model-tag {
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.06em;
|
||||
text-transform: uppercase;
|
||||
color: #7e8385;
|
||||
}
|
||||
|
||||
/* ── FOOTER ── */
|
||||
footer {
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
padding: 20px 24px;
|
||||
border-top: 1px solid #1c2225;
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.04em;
|
||||
color: #404647;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
flex-wrap: wrap;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.section-divider {
|
||||
border: none;
|
||||
border-top: 1px solid #1c2225;
|
||||
margin: 8px 0 40px;
|
||||
}
|
||||
.sweep-hint {
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.04em;
|
||||
color: #404647;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.chart-grid { grid-template-columns: 1fr !important; }
|
||||
.page-title { font-size: 22px; }
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
|
||||
<nav>
|
||||
<a class="nav-brand" href="https://luminal.com">
|
||||
<span class="nav-dot"></span>luminal
|
||||
</a>
|
||||
<span class="nav-sep">/</span>
|
||||
<span class="nav-page">benchmarks</span>
|
||||
</nav>
|
||||
|
||||
<main>
|
||||
|
||||
<header class="page-header">
|
||||
<p class="page-eyebrow">performance · time-series</p>
|
||||
<h1 class="page-title">Benchmark Dashboard</h1>
|
||||
<div class="page-meta">
|
||||
<span>Last updated</span>
|
||||
<span class="meta-sep">·</span>
|
||||
<span class="meta-val">May 01, 2026 · 18:56</span>
|
||||
<span class="meta-sep">·</span>
|
||||
<span class="meta-val">1 run in history</span>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<div class="legend-strip">
|
||||
<div class="legend-pill"><span class="legend-swatch" style="background:#5b5f61"></span>HF Baseline</div><div class="legend-pill"><span class="legend-swatch" style="background:#3b82f6"></span>torch.compile</div><div class="legend-pill"><span class="legend-swatch" style="background:#a855f7"></span>luminal backend</div><div class="legend-pill"><span class="legend-swatch" style="background:#e8855a"></span>Rust (luminal)</div>
|
||||
</div>
|
||||
|
||||
|
||||
<section>
|
||||
<div class="section-header">
|
||||
<span class="section-eyebrow">metric</span>
|
||||
<h2 class="section-title">TTFT <span class="unit">over time</span></h2>
|
||||
<span class="section-tag">Time to first token (ms)</span>
|
||||
</div>
|
||||
<div class="chart-grid" style="grid-template-columns: repeat(4, 1fr)">
|
||||
<div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">llama-8b</span>
|
||||
</div>
|
||||
<div id="c_ttft_ms_llama_8b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_ttft_ms_llama_8b", [{"x": ["2026-05-01T18-56-26-996695"], "y": [705.9654394979589], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "HF Baseline", "line": {"color": "#5b5f61", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>HF Baseline</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [307.66548847896047], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [461.48114453535527], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "luminal backend", "line": {"color": "#a855f7", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>luminal backend</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [1026.86], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 48, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " ms", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-4b</span>
|
||||
</div>
|
||||
<div id="c_ttft_ms_qwen3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_ttft_ms_qwen3_4b", [{"x": ["2026-05-01T18-56-26-996695"], "y": [869.2860195587855], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "HF Baseline", "line": {"color": "#5b5f61", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>HF Baseline</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [298.27259748708457], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [485.3892414830625], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "luminal backend", "line": {"color": "#a855f7", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>luminal backend</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [398.58], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " ms", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">gemma3-4b</span>
|
||||
</div>
|
||||
<div id="c_ttft_ms_gemma3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_ttft_ms_gemma3_4b", [{"x": ["2026-05-01T18-56-26-996695"], "y": [951.1196144158021], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "HF Baseline", "line": {"color": "#5b5f61", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>HF Baseline</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [300.9451600664761], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [404.43], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " ms", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">gemma4-moe</span>
|
||||
</div>
|
||||
<div id="c_ttft_ms_gemma4_moe"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_ttft_ms_gemma4_moe", [{"x": ["2026-05-01T18-56-26-996695"], "y": [837.3980740143452], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "HF Baseline", "line": {"color": "#5b5f61", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>HF Baseline</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [245.510076492792], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " ms", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-moe</span>
|
||||
</div>
|
||||
<div id="c_ttft_ms_qwen3_moe"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_ttft_ms_qwen3_moe", [{"x": ["2026-05-01T18-56-26-996695"], "y": [1565.540504961973], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "HF Baseline", "line": {"color": "#5b5f61", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>HF Baseline</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [460.077923577046], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [21002.791983017232], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "luminal backend", "line": {"color": "#a855f7", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>luminal backend</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [662.07], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " ms", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
<section>
|
||||
<div class="section-header">
|
||||
<span class="section-eyebrow">metric</span>
|
||||
<h2 class="section-title">TPOT <span class="unit">over time</span></h2>
|
||||
<span class="section-tag">Time per output token (ms)</span>
|
||||
</div>
|
||||
<div class="chart-grid" style="grid-template-columns: repeat(4, 1fr)">
|
||||
<div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">llama-8b</span>
|
||||
</div>
|
||||
<div id="c_tpot_ms_llama_8b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_tpot_ms_llama_8b", [{"x": ["2026-05-01T18-56-26-996695"], "y": [34.15271903970279], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "HF Baseline", "line": {"color": "#5b5f61", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>HF Baseline</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [171.7862353892997], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [23.078908618772402], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "luminal backend", "line": {"color": "#a855f7", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>luminal backend</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [51.64], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 48, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " ms", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-4b</span>
|
||||
</div>
|
||||
<div id="c_tpot_ms_qwen3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_tpot_ms_qwen3_4b", [{"x": ["2026-05-01T18-56-26-996695"], "y": [47.71483448566869], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "HF Baseline", "line": {"color": "#5b5f61", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>HF Baseline</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [468.56868775503244], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [26.90318431414198], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "luminal backend", "line": {"color": "#a855f7", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>luminal backend</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [40.62], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " ms", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">gemma3-4b</span>
|
||||
</div>
|
||||
<div id="c_tpot_ms_gemma3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_tpot_ms_gemma3_4b", [{"x": ["2026-05-01T18-56-26-996695"], "y": [52.498737201676704], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "HF Baseline", "line": {"color": "#5b5f61", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>HF Baseline</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [2197.426627812092], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [38.99], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " ms", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">gemma4-moe</span>
|
||||
</div>
|
||||
<div id="c_tpot_ms_gemma4_moe"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_tpot_ms_gemma4_moe", [{"x": ["2026-05-01T18-56-26-996695"], "y": [83.64427039632574], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "HF Baseline", "line": {"color": "#5b5f61", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>HF Baseline</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [654.9649795080768], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " ms", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-moe</span>
|
||||
</div>
|
||||
<div id="c_tpot_ms_qwen3_moe"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_tpot_ms_qwen3_moe", [{"x": ["2026-05-01T18-56-26-996695"], "y": [84.527321747737], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "HF Baseline", "line": {"color": "#5b5f61", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>HF Baseline</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [753.0061075551203], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [1166.8824461026816], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "luminal backend", "line": {"color": "#a855f7", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>luminal backend</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [60.08], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} ms<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " ms", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
<section>
|
||||
<div class="section-header">
|
||||
<span class="section-eyebrow">metric</span>
|
||||
<h2 class="section-title">Time to Search <span class="unit">over time</span></h2>
|
||||
<span class="section-tag">Search time (sec)</span>
|
||||
</div>
|
||||
<div class="chart-grid" style="grid-template-columns: repeat(4, 1fr)">
|
||||
<div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">llama-8b</span>
|
||||
</div>
|
||||
<div id="c_compile_ms_llama_8b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_compile_ms_llama_8b", [{"x": ["2026-05-01T18-56-26-996695"], "y": [18.760145067994017], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [95.96263545705006], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "luminal backend", "line": {"color": "#a855f7", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>luminal backend</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [84.45343], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": true, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 48, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " sec", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-4b</span>
|
||||
</div>
|
||||
<div id="c_compile_ms_qwen3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_compile_ms_qwen3_4b", [{"x": ["2026-05-01T18-56-26-996695"], "y": [4.680963660997804], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [45.345814052037895], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "luminal backend", "line": {"color": "#a855f7", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>luminal backend</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [19.92977], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " sec", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">gemma3-4b</span>
|
||||
</div>
|
||||
<div id="c_compile_ms_gemma3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_compile_ms_gemma3_4b", [{"x": ["2026-05-01T18-56-26-996695"], "y": [26.649526304972824], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [156.84164], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " sec", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">gemma4-moe</span>
|
||||
</div>
|
||||
<div id="c_compile_ms_gemma4_moe"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_compile_ms_gemma4_moe", [{"x": ["2026-05-01T18-56-26-996695"], "y": [38.81582092499593], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " sec", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-moe</span>
|
||||
</div>
|
||||
<div id="c_compile_ms_qwen3_moe"></div>
|
||||
<script>
|
||||
Plotly.newPlot("c_compile_ms_qwen3_moe", [{"x": ["2026-05-01T18-56-26-996695"], "y": [8.341281775035895], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "torch.compile", "line": {"color": "#3b82f6", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>torch.compile</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [111.70731823903043], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "luminal backend", "line": {"color": "#a855f7", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>luminal backend</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}, {"x": ["2026-05-01T18-56-26-996695"], "y": [80.83241000000001], "customdata": [["b2bd91f5", "2026-05-01T18:56:26.996695"]], "type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "line": {"color": "#e8855a", "width": 2}, "marker": {"size": 7, "symbol": "circle"}, "connectgaps": false, "showlegend": false, "hovertemplate": "<b>Rust (luminal)</b><br>%{customdata[1]}<br>%{y:.1f} sec<br><span style='color:#7e8385'>commit %{customdata[0]}</span><extra></extra>"}], {"plot_bgcolor": "#0d1416", "paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"}, "margin": {"t": 16, "b": 16, "l": 52, "r": 12}, "height": 280, "xaxis": {"type": "category", "categoryorder": "array", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "tickangle": -30, "automargin": true, "zeroline": false, "categoryarray": ["2026-05-01T18-56-26-996695"], "tickvals": ["2026-05-01T18-56-26-996695"], "ticktext": ["May 01 \u00b7 18:56"]}, "yaxis": {"rangemode": "tozero", "color": "#5b5f61", "gridcolor": "#1c2225", "linecolor": "#2d3335", "tickfont": {"size": 11, "family": "Geist Mono, monospace"}, "ticksuffix": " sec", "zeroline": false}, "legend": {"orientation": "h", "y": -0.28, "x": 0, "font": {"size": 11, "color": "#a1a4a5"}, "bgcolor": "rgba(0,0,0,0)", "visible": false}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}},
|
||||
{responsive: true, displayModeBar: false});
|
||||
</script>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
<hr class='section-divider'>
|
||||
<section>
|
||||
<div class="section-header">
|
||||
<span class="section-eyebrow">sweep · 3d</span>
|
||||
<h2 class="section-title">TTFT <span class="unit">vs search budget · over time</span></h2>
|
||||
<span class="section-tag">1 run</span>
|
||||
</div>
|
||||
<p class="sweep-hint">Drag to rotate · scroll to zoom · each curve = one run</p>
|
||||
<div class="chart-grid" style="grid-template-columns: repeat(4, 1fr)">
|
||||
<div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">llama-8b</span>
|
||||
</div>
|
||||
<div id="sw_ttft_ms_llama_8b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_ttft_ms_llama_8b", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [470.7036415056791, 460.72837291285396, 472.43661794345826], "name": "luminal backend", "legendgroup": "python_luminal", "showlegend": true, "line": {"color": "#a855f7", "width": 5}, "marker": {"color": "#a855f7", "size": 4}, "hovertemplate": "<b>luminal backend</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}, {"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [751.03, 1038.34, 453.16], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "ms", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " ms", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-4b</span>
|
||||
</div>
|
||||
<div id="sw_ttft_ms_qwen3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_ttft_ms_qwen3_4b", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [465.02652901108377, 465.9317950136028, 495.75577257201076], "name": "luminal backend", "legendgroup": "python_luminal", "showlegend": true, "line": {"color": "#a855f7", "width": 5}, "marker": {"color": "#a855f7", "size": 4}, "hovertemplate": "<b>luminal backend</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}, {"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [398.44, 390.08, 559.29], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "ms", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " ms", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">gemma3-4b</span>
|
||||
</div>
|
||||
<div id="sw_ttft_ms_gemma3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_ttft_ms_gemma3_4b", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [388.19, 436.49, 386.13], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "ms", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " ms", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-moe</span>
|
||||
</div>
|
||||
<div id="sw_ttft_ms_qwen3_moe"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_ttft_ms_qwen3_moe", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [21002.663500519702, 21018.686580006033, 21034.366824431345], "name": "luminal backend", "legendgroup": "python_luminal", "showlegend": true, "line": {"color": "#a855f7", "width": 5}, "marker": {"color": "#a855f7", "size": 4}, "hovertemplate": "<b>luminal backend</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}, {"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [656.7, 540.37, 542.34], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "ms", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " ms", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
<section>
|
||||
<div class="section-header">
|
||||
<span class="section-eyebrow">sweep · 3d</span>
|
||||
<h2 class="section-title">TPOT <span class="unit">vs search budget · over time</span></h2>
|
||||
<span class="section-tag">1 run</span>
|
||||
</div>
|
||||
<p class="sweep-hint">Drag to rotate · scroll to zoom · each curve = one run</p>
|
||||
<div class="chart-grid" style="grid-template-columns: repeat(4, 1fr)">
|
||||
<div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">llama-8b</span>
|
||||
</div>
|
||||
<div id="sw_tpot_ms_llama_8b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_tpot_ms_llama_8b", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [23.540849717101082, 23.101884137140587, 23.610779400914907], "name": "luminal backend", "legendgroup": "python_luminal", "showlegend": true, "line": {"color": "#a855f7", "width": 5}, "marker": {"color": "#a855f7", "size": 4}, "hovertemplate": "<b>luminal backend</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}, {"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [38.2, 51.92, 24.09], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "ms", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " ms", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-4b</span>
|
||||
</div>
|
||||
<div id="sw_tpot_ms_qwen3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_tpot_ms_qwen3_4b", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [25.875402649398893, 25.884080055402592, 27.492373346467502], "name": "luminal backend", "legendgroup": "python_luminal", "showlegend": true, "line": {"color": "#a855f7", "width": 5}, "marker": {"color": "#a855f7", "size": 4}, "hovertemplate": "<b>luminal backend</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}, {"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [40.64, 39.98, 55.37], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "ms", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " ms", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">gemma3-4b</span>
|
||||
</div>
|
||||
<div id="sw_tpot_ms_gemma3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_tpot_ms_gemma3_4b", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [37.47, 41.95, 37.25], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "ms", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " ms", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-moe</span>
|
||||
</div>
|
||||
<div id="sw_tpot_ms_qwen3_moe"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_tpot_ms_qwen3_moe", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [1166.6714247548953, 1167.2746865515364, 1168.7990181031637], "name": "luminal backend", "legendgroup": "python_luminal", "showlegend": true, "line": {"color": "#a855f7", "width": 5}, "marker": {"color": "#a855f7", "size": 4}, "hovertemplate": "<b>luminal backend</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}, {"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [59.6, 48.79, 48.88], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} ms<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "ms", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " ms", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
<section>
|
||||
<div class="section-header">
|
||||
<span class="section-eyebrow">sweep · 3d</span>
|
||||
<h2 class="section-title">Time to Search <span class="unit">vs search budget · over time</span></h2>
|
||||
<span class="section-tag">1 run</span>
|
||||
</div>
|
||||
<p class="sweep-hint">Drag to rotate · scroll to zoom · each curve = one run</p>
|
||||
<div class="chart-grid" style="grid-template-columns: repeat(4, 1fr)">
|
||||
<div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">llama-8b</span>
|
||||
</div>
|
||||
<div id="sw_compile_ms_llama_8b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_compile_ms_llama_8b", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [28.428826077957638, 43.57440591201885, 95.52432684396626], "name": "luminal backend", "legendgroup": "python_luminal", "showlegend": true, "line": {"color": "#a855f7", "width": 5}, "marker": {"color": "#a855f7", "size": 4}, "hovertemplate": "<b>luminal backend</b><br>s=%{x} iters<br>%{z:.1f} sec<br>May 01 \u00b7 b2bd91f5<extra></extra>"}, {"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [15.14307, 30.12727, 84.87889], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} sec<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "sec", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " sec", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-4b</span>
|
||||
</div>
|
||||
<div id="sw_compile_ms_qwen3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_compile_ms_qwen3_4b", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [37.92102829599753, 54.08867314597592, 118.29659596900456], "name": "luminal backend", "legendgroup": "python_luminal", "showlegend": true, "line": {"color": "#a855f7", "width": 5}, "marker": {"color": "#a855f7", "size": 4}, "hovertemplate": "<b>luminal backend</b><br>s=%{x} iters<br>%{z:.1f} sec<br>May 01 \u00b7 b2bd91f5<extra></extra>"}, {"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [12.448030000000001, 27.06796, 81.89342], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} sec<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "sec", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " sec", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">gemma3-4b</span>
|
||||
</div>
|
||||
<div id="sw_compile_ms_gemma3_4b"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_compile_ms_gemma3_4b", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [102.18644, 186.34269, 498.48983000000004], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} sec<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "sec", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " sec", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div><div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">qwen3-moe</span>
|
||||
</div>
|
||||
<div id="sw_compile_ms_qwen3_moe"></div>
|
||||
<script>
|
||||
Plotly.newPlot("sw_compile_ms_qwen3_moe", [{"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [93.47603664599592, 132.266081985028, 298.05094401398674], "name": "luminal backend", "legendgroup": "python_luminal", "showlegend": true, "line": {"color": "#a855f7", "width": 5}, "marker": {"color": "#a855f7", "size": 4}, "hovertemplate": "<b>luminal backend</b><br>s=%{x} iters<br>%{z:.1f} sec<br>May 01 \u00b7 b2bd91f5<extra></extra>"}, {"type": "scatter3d", "mode": "lines+markers", "x": [10, 100, 500], "y": ["May 01", "May 01", "May 01"], "z": [25.48138, 47.5342, 134.79345], "name": "Rust (luminal)", "legendgroup": "rust", "showlegend": true, "line": {"color": "#e8855a", "width": 5}, "marker": {"color": "#e8855a", "size": 4}, "hovertemplate": "<b>Rust (luminal)</b><br>s=%{x} iters<br>%{z:.1f} sec<br>May 01 \u00b7 b2bd91f5<extra></extra>"}], {"paper_bgcolor": "#141b1d", "font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11}, "height": 420, "margin": {"t": 20, "b": 0, "l": 0, "r": 0}, "legend": {"orientation": "h", "y": -0.05, "x": 0, "font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"}, "bgcolor": "rgba(0,0,0,0)"}, "hoverlabel": {"bgcolor": "#1c2225", "bordercolor": "#2d3335", "font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"}}, "scene": {"bgcolor": "#0d1416", "xaxis": {"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}}, "type": "log", "tickvals": [5, 10, 20, 50, 100, 500], "ticktext": ["5", "10", "20", "50", "100", "500"], "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335", "zerolinecolor": "#2d3335"}, "yaxis": {"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}}, "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "zaxis": {"title": {"text": "sec", "font": {"size": 10, "color": "#7e8385"}}, "rangemode": "tozero", "tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"}, "ticksuffix": " sec", "gridcolor": "#1c2225", "linecolor": "#2d3335"}, "camera": {"eye": {"x": 1.6, "y": -1.6, "z": 0.9}}}},
|
||||
{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]});
|
||||
</script>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
</main>
|
||||
|
||||
<footer>
|
||||
<span>luminal · benchmark dashboard</span>
|
||||
<span>generated May 01, 2026 · 18:56</span>
|
||||
</footer>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
242
benchmarks/ttft/db.py
Normal file
242
benchmarks/ttft/db.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""SQLite persistence for TTFT/TPOT benchmark runs.
|
||||
|
||||
Two tables:
|
||||
runs — one row per orchestrator invocation
|
||||
results — many rows per run, one per (path, config) combination
|
||||
|
||||
`results` carries every field that today's BENCH_RESULT JSON record carries.
|
||||
Per-iteration sample arrays (`ttft_ms_samples`, `tpot_ms_samples`) are kept as
|
||||
JSON TEXT — they're archival, no consumer aggregates over them.
|
||||
|
||||
The default DB path is benchmarks/ttft/bench.db (gitignored). Schema is
|
||||
created lazily on first connect.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable
|
||||
|
||||
BENCH_DIR = Path(__file__).resolve().parent
|
||||
DEFAULT_DB_PATH = BENCH_DIR / "bench.db"
|
||||
|
||||
|
||||
_SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS runs (
|
||||
run_id TEXT PRIMARY KEY,
|
||||
timestamp TEXT NOT NULL,
|
||||
git_commit TEXT,
|
||||
git_branch TEXT,
|
||||
gpu_name TEXT,
|
||||
gpu_driver TEXT,
|
||||
gpu_vram_mb INTEGER,
|
||||
cuda_version TEXT,
|
||||
mode TEXT NOT NULL -- 'single' | 'all-configs' | 'search-sweep' | 'ur-test' | 'ur-test-fast'
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS results (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
run_id TEXT NOT NULL REFERENCES runs(run_id) ON DELETE CASCADE,
|
||||
path TEXT NOT NULL,
|
||||
model TEXT NOT NULL,
|
||||
model_key TEXT,
|
||||
config TEXT NOT NULL,
|
||||
device TEXT,
|
||||
dtype TEXT,
|
||||
prompt_tokens INTEGER,
|
||||
iters INTEGER,
|
||||
decode_tokens INTEGER,
|
||||
search_iters INTEGER,
|
||||
ttft_ms REAL,
|
||||
ttft_ms_mean REAL,
|
||||
tpot_ms REAL,
|
||||
throughput_tps REAL,
|
||||
compile_ms REAL,
|
||||
note TEXT,
|
||||
error TEXT,
|
||||
ttft_ms_samples TEXT,
|
||||
tpot_ms_samples TEXT,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_results_run ON results(run_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_results_path ON results(path);
|
||||
CREATE INDEX IF NOT EXISTS idx_results_config ON results(config);
|
||||
CREATE INDEX IF NOT EXISTS idx_results_modelk ON results(model_key);
|
||||
"""
|
||||
|
||||
|
||||
# Columns that map 1:1 from a BENCH_RESULT record dict into `results`.
|
||||
_SCALAR_RESULT_COLS = (
|
||||
"path", "model", "model_key", "config",
|
||||
"device", "dtype",
|
||||
"prompt_tokens", "iters", "decode_tokens", "search_iters",
|
||||
"ttft_ms", "ttft_ms_mean", "tpot_ms", "throughput_tps", "compile_ms",
|
||||
"note", "error",
|
||||
)
|
||||
_SAMPLE_COLS = ("ttft_ms_samples", "tpot_ms_samples")
|
||||
_ALL_RESULT_COLS = ("run_id",) + _SCALAR_RESULT_COLS + _SAMPLE_COLS
|
||||
|
||||
|
||||
def connect(path: str | Path = DEFAULT_DB_PATH) -> sqlite3.Connection:
|
||||
"""Open (or create) the bench DB and ensure the schema exists."""
|
||||
p = Path(path)
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(p)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
conn.executescript(_SCHEMA)
|
||||
return conn
|
||||
|
||||
|
||||
def insert_run(
|
||||
conn: sqlite3.Connection,
|
||||
*,
|
||||
run_id: str,
|
||||
timestamp: str,
|
||||
mode: str,
|
||||
git_commit: str | None = None,
|
||||
git_branch: str | None = None,
|
||||
gpu_name: str | None = None,
|
||||
gpu_driver: str | None = None,
|
||||
gpu_vram_mb: int | None = None,
|
||||
cuda_version: str | None = None,
|
||||
if_exists: str = "ignore",
|
||||
) -> str:
|
||||
"""Insert a run row. if_exists='ignore' (default) leaves an existing
|
||||
row untouched; 'replace' overwrites."""
|
||||
verb = {"ignore": "INSERT OR IGNORE", "replace": "INSERT OR REPLACE"}[if_exists]
|
||||
conn.execute(
|
||||
f"""{verb} INTO runs
|
||||
(run_id, timestamp, git_commit, git_branch,
|
||||
gpu_name, gpu_driver, gpu_vram_mb, cuda_version, mode)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(run_id, timestamp, git_commit, git_branch,
|
||||
gpu_name, gpu_driver, gpu_vram_mb, cuda_version, mode),
|
||||
)
|
||||
return run_id
|
||||
|
||||
|
||||
def insert_result(conn: sqlite3.Connection, run_id: str, record: dict[str, Any]) -> int:
|
||||
"""Insert one BENCH_RESULT-shaped record under the given run_id."""
|
||||
values = [run_id]
|
||||
for col in _SCALAR_RESULT_COLS:
|
||||
values.append(record.get(col))
|
||||
for col in _SAMPLE_COLS:
|
||||
v = record.get(col)
|
||||
values.append(json.dumps(v) if v is not None else None)
|
||||
placeholders = ", ".join(["?"] * len(_ALL_RESULT_COLS))
|
||||
cols = ", ".join(_ALL_RESULT_COLS)
|
||||
cur = conn.execute(
|
||||
f"INSERT INTO results ({cols}) VALUES ({placeholders})",
|
||||
values,
|
||||
)
|
||||
return cur.lastrowid
|
||||
|
||||
|
||||
def insert_results(conn: sqlite3.Connection, run_id: str, records: Iterable[dict[str, Any]]) -> int:
|
||||
"""Bulk-insert; returns count."""
|
||||
n = 0
|
||||
for r in records:
|
||||
insert_result(conn, run_id, r)
|
||||
n += 1
|
||||
return n
|
||||
|
||||
|
||||
def latest_run_id(conn: sqlite3.Connection) -> str | None:
|
||||
row = conn.execute(
|
||||
"SELECT run_id FROM runs ORDER BY timestamp DESC, run_id DESC LIMIT 1"
|
||||
).fetchone()
|
||||
return row["run_id"] if row else None
|
||||
|
||||
|
||||
def load_run(conn: sqlite3.Connection, run_id: str) -> dict[str, Any] | None:
|
||||
row = conn.execute("SELECT * FROM runs WHERE run_id = ?", (run_id,)).fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
|
||||
def load_runs(conn: sqlite3.Connection) -> list[dict[str, Any]]:
|
||||
"""All runs, oldest → newest."""
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM runs ORDER BY timestamp ASC, run_id ASC"
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
def _row_to_record(row: sqlite3.Row) -> dict[str, Any]:
|
||||
"""Convert a results row into a BENCH_RESULT-shaped dict, stripping NULLs
|
||||
so consumers see the same shape they did with JSON."""
|
||||
out: dict[str, Any] = {}
|
||||
for col in _SCALAR_RESULT_COLS:
|
||||
v = row[col]
|
||||
if v is not None:
|
||||
out[col] = v
|
||||
for col in _SAMPLE_COLS:
|
||||
v = row[col]
|
||||
if v is not None:
|
||||
out[col] = json.loads(v)
|
||||
return out
|
||||
|
||||
|
||||
def load_results(conn: sqlite3.Connection, run_id: str) -> list[dict[str, Any]]:
|
||||
"""All results for one run, in insertion order."""
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM results WHERE run_id = ? ORDER BY id ASC", (run_id,)
|
||||
).fetchall()
|
||||
return [_row_to_record(r) for r in rows]
|
||||
|
||||
|
||||
def load_history(conn: sqlite3.Connection) -> list[dict[str, Any]]:
|
||||
"""Mirror the legacy gen_dashboard.load_history() shape:
|
||||
[{"meta": {...}, "results": [...], "sweep": [...]}], sorted oldest→newest.
|
||||
Splits results vs sweep by config-startswith('s=')."""
|
||||
out = []
|
||||
for run in load_runs(conn):
|
||||
run_id = run["run_id"]
|
||||
meta = {
|
||||
"run_id": run_id,
|
||||
"timestamp": run["timestamp"],
|
||||
"git_commit": run["git_commit"] or "?",
|
||||
"git_branch": run["git_branch"] or "?",
|
||||
}
|
||||
if run["gpu_name"] is not None:
|
||||
meta["gpu_name"] = run["gpu_name"]
|
||||
if run["gpu_driver"] is not None:
|
||||
meta["gpu_driver"] = run["gpu_driver"]
|
||||
if run["gpu_vram_mb"] is not None:
|
||||
meta["gpu_vram_mb"] = run["gpu_vram_mb"]
|
||||
if run["cuda_version"] is not None:
|
||||
meta["cuda_version"] = run["cuda_version"]
|
||||
|
||||
records = load_results(conn, run_id)
|
||||
comparison, sweep = [], []
|
||||
for r in records:
|
||||
(sweep if r.get("config", "").startswith("s=") else comparison).append(r)
|
||||
out.append({"meta": meta, "results": comparison, "sweep": sweep})
|
||||
return out
|
||||
|
||||
|
||||
# ── self-test ────────────────────────────────────────────────────────────────
|
||||
|
||||
if __name__ == "__main__":
|
||||
# In-memory smoke test: round-trip one record.
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.executescript(_SCHEMA)
|
||||
insert_run(conn, run_id="test", timestamp="2026-04-27T00:00:00", mode="single")
|
||||
insert_result(conn, "test", {
|
||||
"path": "rust",
|
||||
"model": "test-model",
|
||||
"config": "default",
|
||||
"ttft_ms": 12.34,
|
||||
"ttft_ms_samples": [12.0, 12.5, 12.3],
|
||||
"search_iters": 500,
|
||||
})
|
||||
[row] = load_results(conn, "test")
|
||||
assert row["path"] == "rust", row
|
||||
assert row["ttft_ms"] == 12.34, row
|
||||
assert row["ttft_ms_samples"] == [12.0, 12.5, 12.3], row
|
||||
assert latest_run_id(conn) == "test"
|
||||
print("db.py smoke test ok")
|
||||
832
benchmarks/ttft/gen_dashboard.py
Normal file
832
benchmarks/ttft/gen_dashboard.py
Normal file
@@ -0,0 +1,832 @@
|
||||
"""Time-series benchmark dashboard generator.
|
||||
|
||||
Reads every run from the SQLite DB (benchmarks/ttft/bench.db) and produces a
|
||||
single standalone HTML file with Plotly.js charts styled to match luminal.com.
|
||||
|
||||
Layout:
|
||||
TTFT over time → one chart per model, lines = execution paths
|
||||
TPOT over time → same
|
||||
|
||||
Usage:
|
||||
python3 benchmarks/ttft/gen_dashboard.py [--db PATH] [--out FILE]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import db
|
||||
|
||||
BENCH_DIR = Path(__file__).resolve().parent
|
||||
|
||||
# Path colours – kept distinct against the dark green Luminal accent
|
||||
PATH_COLORS = {
|
||||
"python_baseline": "#5b5f61", # muted slate
|
||||
"python_torch_compile": "#3b82f6", # blue (luminal accent palette)
|
||||
"python_luminal": "#a855f7", # purple (luminal accent palette)
|
||||
"rust": "#e8855a", # warm orange – Rust brand feel
|
||||
}
|
||||
PATH_LABELS = {
|
||||
"python_baseline": "HF Baseline",
|
||||
"python_torch_compile": "torch.compile",
|
||||
"python_luminal": "luminal backend",
|
||||
"rust": "Rust (luminal)",
|
||||
}
|
||||
PATH_ORDER = ["python_baseline", "python_torch_compile", "python_luminal", "rust"]
|
||||
|
||||
# (key, short label, y-axis label, scale, axis ticksuffix)
|
||||
# scale is applied to raw value before plotting (e.g. ms → sec via 0.001).
|
||||
METRICS = [
|
||||
("ttft_ms", "TTFT", "Time to first token (ms)", 1.0, " ms"),
|
||||
("tpot_ms", "TPOT", "Time per output token (ms)", 1.0, " ms"),
|
||||
("compile_ms", "Time to Search", "Search time (sec)", 0.001, " sec"),
|
||||
]
|
||||
|
||||
|
||||
# ── data loading ─────────────────────────────────────────────────────────────
|
||||
|
||||
def load_history(db_path: Path) -> list[dict]:
|
||||
"""Return [{"meta", "results", "sweep"}, …] from the bench DB,
|
||||
oldest→newest. Same shape the legacy JSON loader returned."""
|
||||
if not Path(db_path).exists():
|
||||
return []
|
||||
conn = db.connect(db_path)
|
||||
return db.load_history(conn)
|
||||
|
||||
|
||||
def build_series(runs: list[dict]) -> tuple[dict, list[str], list[str]]:
|
||||
"""Returns (data, run_ids, run_labels).
|
||||
|
||||
- data[model][path][metric] = [(run_id, value, commit, ts), ...]
|
||||
`run_id` is the categorical x value; `ts` is kept for tooltip formatting.
|
||||
- run_ids: chronological list of every run that appears in the comparison data.
|
||||
- run_labels: parallel to run_ids; "MMM DD · HH:MM" for nice axis ticks.
|
||||
|
||||
The categorical x-axis (one column per run_id) replaces the previous
|
||||
`type: date` axis. With multiple runs on the same day, the date axis
|
||||
silently stacked them on one column; the category axis spaces them
|
||||
evenly so each run is visually distinct.
|
||||
"""
|
||||
data: dict = {}
|
||||
seen_run_ids: list[str] = []
|
||||
seen_ts: dict[str, str] = {}
|
||||
|
||||
for run in runs:
|
||||
run_id = run["meta"]["run_id"]
|
||||
ts = run["meta"]["timestamp"]
|
||||
commit = run["meta"].get("git_commit", "?")
|
||||
had_data = False
|
||||
for r in run["results"]:
|
||||
if r.get("error") or r.get("ttft_ms") is None:
|
||||
continue
|
||||
model = r.get("config", r.get("model", "unknown"))
|
||||
path = r.get("path", "unknown")
|
||||
data.setdefault(model, {}).setdefault(path, {})
|
||||
for metric, _, _, scale, _ in METRICS:
|
||||
val = r.get(metric)
|
||||
if val is not None:
|
||||
data[model][path].setdefault(metric, []).append(
|
||||
(run_id, val * scale, commit, ts)
|
||||
)
|
||||
had_data = True
|
||||
if had_data and run_id not in seen_ts:
|
||||
seen_run_ids.append(run_id)
|
||||
seen_ts[run_id] = ts
|
||||
|
||||
run_ids = sorted(seen_run_ids, key=lambda rid: seen_ts.get(rid, rid))
|
||||
run_labels = []
|
||||
for rid in run_ids:
|
||||
ts = seen_ts.get(rid, rid)
|
||||
try:
|
||||
run_labels.append(datetime.fromisoformat(ts).strftime("%b %d · %H:%M"))
|
||||
except ValueError:
|
||||
run_labels.append(rid[:16].replace("T", " "))
|
||||
return data, run_ids, run_labels
|
||||
|
||||
|
||||
def build_sweep_series(runs: list[dict]) -> tuple[dict, list[str]]:
|
||||
"""Collect sweep records from ALL runs for 3D charting.
|
||||
|
||||
Returns:
|
||||
data[model_key][path][metric][run_id] = {
|
||||
"label": str, # short date label for Y axis
|
||||
"commit": str,
|
||||
"points": [(iters, ms), …] # sorted by iters
|
||||
}
|
||||
run_ids: list[str] in chronological order (oldest → newest)
|
||||
"""
|
||||
data: dict = {}
|
||||
run_ids: list[str] = []
|
||||
|
||||
for run in runs:
|
||||
if not run.get("sweep"):
|
||||
continue
|
||||
run_id = run["meta"]["run_id"]
|
||||
commit = run["meta"].get("git_commit", "?")
|
||||
try:
|
||||
label = datetime.fromisoformat(run["meta"]["timestamp"]).strftime("%b %d")
|
||||
except ValueError:
|
||||
label = run_id[:10]
|
||||
if run_id not in run_ids:
|
||||
run_ids.append(run_id)
|
||||
|
||||
for r in run["sweep"]:
|
||||
if r.get("error"):
|
||||
continue
|
||||
n = r.get("search_iters")
|
||||
if n is None:
|
||||
cfg = r.get("config", "")
|
||||
if cfg.startswith("s="):
|
||||
try:
|
||||
n = int(cfg[2:])
|
||||
except ValueError:
|
||||
continue
|
||||
if n is None:
|
||||
continue
|
||||
model_key = r.get("model_key", "unknown")
|
||||
path = r.get("path", "unknown")
|
||||
for metric, _, _, scale, _ in METRICS:
|
||||
val = r.get(metric)
|
||||
if val is None:
|
||||
continue
|
||||
(data
|
||||
.setdefault(model_key, {})
|
||||
.setdefault(path, {})
|
||||
.setdefault(metric, {})
|
||||
.setdefault(run_id, {"label": label, "commit": commit, "points": []})
|
||||
["points"].append((n, val * scale)))
|
||||
|
||||
# Sort points within each run by search_iters
|
||||
for mk in data:
|
||||
for path in data[mk]:
|
||||
for metric in data[mk][path]:
|
||||
for run_id in data[mk][path][metric]:
|
||||
data[mk][path][metric][run_id]["points"].sort(key=lambda x: x[0])
|
||||
|
||||
return data, run_ids
|
||||
|
||||
|
||||
# ── chart building ────────────────────────────────────────────────────────────
|
||||
|
||||
def _traces_json(path_data: dict, metric: str, show_legend: bool, unit: str = " ms") -> str:
|
||||
traces = []
|
||||
for path in PATH_ORDER:
|
||||
if path not in path_data or metric not in path_data[path]:
|
||||
continue
|
||||
pts = path_data[path][metric]
|
||||
# pts: list of (run_id, val, commit, ts)
|
||||
trace = {
|
||||
"x": [p[0] for p in pts],
|
||||
"y": [p[1] for p in pts],
|
||||
"customdata": [[p[2], p[3]] for p in pts],
|
||||
"type": "scatter",
|
||||
"mode": "lines+markers",
|
||||
"name": PATH_LABELS.get(path, path),
|
||||
"line": {"color": PATH_COLORS.get(path, "#aaa"), "width": 2},
|
||||
"marker": {"size": 7, "symbol": "circle"},
|
||||
"connectgaps": False,
|
||||
"showlegend": show_legend,
|
||||
"hovertemplate": (
|
||||
f"<b>{PATH_LABELS.get(path, path)}</b><br>"
|
||||
"%{customdata[1]}<br>"
|
||||
f"%{{y:.1f}}{unit}<br>"
|
||||
"<span style='color:#7e8385'>commit %{customdata[0]}</span>"
|
||||
"<extra></extra>"
|
||||
),
|
||||
}
|
||||
traces.append(trace)
|
||||
return json.dumps(traces)
|
||||
|
||||
|
||||
_CHART_LAYOUT = {
|
||||
"plot_bgcolor": "#0d1416",
|
||||
"paper_bgcolor": "#141b1d",
|
||||
"font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9"},
|
||||
"margin": {"t": 16, "b": 48, "l": 52, "r": 12},
|
||||
"height": 280,
|
||||
"xaxis": {
|
||||
# Categorical: one column per run, evenly spaced. Same-day runs
|
||||
# used to collapse on a date axis; this keeps every run distinct.
|
||||
"type": "category",
|
||||
"categoryorder": "array", # categoryarray injected per chart
|
||||
"color": "#5b5f61",
|
||||
"gridcolor": "#1c2225",
|
||||
"linecolor": "#2d3335",
|
||||
"tickfont": {"size": 11, "family": "Geist Mono, monospace"},
|
||||
"tickangle": -30,
|
||||
"automargin": True,
|
||||
"zeroline": False,
|
||||
},
|
||||
"yaxis": {
|
||||
"rangemode": "tozero",
|
||||
"color": "#5b5f61",
|
||||
"gridcolor": "#1c2225",
|
||||
"linecolor": "#2d3335",
|
||||
"tickfont": {"size": 11, "family": "Geist Mono, monospace"},
|
||||
"ticksuffix": " ms",
|
||||
"zeroline": False,
|
||||
},
|
||||
"legend": {
|
||||
"orientation": "h",
|
||||
"y": -0.28,
|
||||
"x": 0,
|
||||
"font": {"size": 11, "color": "#a1a4a5"},
|
||||
"bgcolor": "rgba(0,0,0,0)",
|
||||
},
|
||||
"hoverlabel": {
|
||||
"bgcolor": "#1c2225",
|
||||
"bordercolor":"#2d3335",
|
||||
"font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _chart_card(div_id: str, model: str, traces_json: str, show_legend: bool,
|
||||
run_ids: list[str], run_labels: list[str], unit: str = " ms") -> str:
|
||||
layout = dict(_CHART_LAYOUT)
|
||||
xaxis = {
|
||||
**layout["xaxis"],
|
||||
"categoryarray": run_ids,
|
||||
"tickvals": run_ids,
|
||||
"ticktext": run_labels,
|
||||
}
|
||||
layout = {**layout,
|
||||
"xaxis": xaxis,
|
||||
"yaxis": {**layout["yaxis"], "ticksuffix": unit}}
|
||||
if not show_legend:
|
||||
layout = {**layout, "legend": {**layout["legend"], "visible": False},
|
||||
"margin": {**layout["margin"], "b": 16}}
|
||||
return f"""<div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">{model}</span>
|
||||
</div>
|
||||
<div id="{div_id}"></div>
|
||||
<script>
|
||||
Plotly.newPlot("{div_id}", {traces_json}, {json.dumps(layout)},
|
||||
{{responsive: true, displayModeBar: false}});
|
||||
</script>
|
||||
</div>"""
|
||||
|
||||
|
||||
def _sweep_3d_traces_json(model_data: dict, metric: str, run_ids: list[str], unit: str = " ms") -> str:
|
||||
"""One scatter3d trace per (path, run) — same colour per path, stacked by run on Y."""
|
||||
traces = []
|
||||
path_legend_shown: set[str] = set()
|
||||
|
||||
for run_id in run_ids:
|
||||
for path in PATH_ORDER:
|
||||
run_map = model_data.get(path, {}).get(metric, {})
|
||||
if run_id not in run_map:
|
||||
continue
|
||||
entry = run_map[run_id]
|
||||
pts = entry["points"]
|
||||
label = entry["label"]
|
||||
commit = entry["commit"]
|
||||
color = PATH_COLORS.get(path, "#aaa")
|
||||
show_legend = path not in path_legend_shown
|
||||
path_legend_shown.add(path)
|
||||
|
||||
traces.append({
|
||||
"type": "scatter3d",
|
||||
"mode": "lines+markers",
|
||||
"x": [p[0] for p in pts], # search iters
|
||||
"y": [label] * len(pts), # run label (categorical)
|
||||
"z": [p[1] for p in pts], # value (already scaled by build_sweep_series)
|
||||
"name": PATH_LABELS.get(path, path),
|
||||
"legendgroup": path,
|
||||
"showlegend": show_legend,
|
||||
"line": {"color": color, "width": 5},
|
||||
"marker": {"color": color, "size": 4},
|
||||
"hovertemplate": (
|
||||
f"<b>{PATH_LABELS.get(path, path)}</b><br>"
|
||||
f"s=%{{x}} iters<br>%{{z:.1f}}{unit}<br>"
|
||||
f"{label} · {commit}"
|
||||
"<extra></extra>"
|
||||
),
|
||||
})
|
||||
|
||||
# Cross-run wire lines: for each path, connect same-budget points across
|
||||
# runs. Makes regressions at a fixed search budget visible as a kink in the
|
||||
# wireframe. Dashed + thinner than the per-run curves; legendgroup matches
|
||||
# the path so toggling one toggles both.
|
||||
for path in PATH_ORDER:
|
||||
metric_runs = model_data.get(path, {}).get(metric, {})
|
||||
if len(metric_runs) < 2:
|
||||
continue
|
||||
color = PATH_COLORS.get(path, "#aaa")
|
||||
# by_budget[iters] -> list of (run_label, value) in chronological order
|
||||
by_budget: dict = {}
|
||||
for run_id in run_ids:
|
||||
if run_id not in metric_runs:
|
||||
continue
|
||||
entry = metric_runs[run_id]
|
||||
for iters, val in entry["points"]:
|
||||
by_budget.setdefault(iters, []).append((entry["label"], val))
|
||||
for budget, items in sorted(by_budget.items()):
|
||||
if len(items) < 2:
|
||||
continue
|
||||
traces.append({
|
||||
"type": "scatter3d",
|
||||
"mode": "lines",
|
||||
"x": [budget] * len(items),
|
||||
"y": [it[0] for it in items],
|
||||
"z": [it[1] for it in items],
|
||||
"legendgroup": path,
|
||||
"showlegend": False,
|
||||
"line": {"color": color, "width": 2, "dash": "dash"},
|
||||
"hovertemplate": (
|
||||
f"<b>{PATH_LABELS.get(path, path)} @ s={budget}</b><br>"
|
||||
f"%{{y}}: %{{z:.1f}}{unit}"
|
||||
"<extra></extra>"
|
||||
),
|
||||
})
|
||||
return json.dumps(traces)
|
||||
|
||||
|
||||
_SWEEP_3D_LAYOUT = {
|
||||
"paper_bgcolor": "#141b1d",
|
||||
"font": {"family": "Geist, system-ui, sans-serif", "color": "#d7d8d9", "size": 11},
|
||||
"height": 420,
|
||||
"margin": {"t": 20, "b": 0, "l": 0, "r": 0},
|
||||
"legend": {
|
||||
"orientation": "h",
|
||||
"y": -0.05,
|
||||
"x": 0,
|
||||
"font": {"size": 11, "color": "#a1a4a5", "family": "Geist Mono, monospace"},
|
||||
"bgcolor": "rgba(0,0,0,0)",
|
||||
},
|
||||
"hoverlabel": {
|
||||
"bgcolor": "#1c2225",
|
||||
"bordercolor": "#2d3335",
|
||||
"font": {"size": 12, "color": "#d7d8d9", "family": "Geist Mono, monospace"},
|
||||
},
|
||||
"scene": {
|
||||
"bgcolor": "#0d1416",
|
||||
"xaxis": {
|
||||
"title": {"text": "search iters", "font": {"size": 10, "color": "#7e8385"}},
|
||||
"type": "log",
|
||||
"tickvals": [5, 10, 20, 50, 100, 500],
|
||||
"ticktext": ["5", "10", "20", "50", "100", "500"],
|
||||
"tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"},
|
||||
"gridcolor": "#1c2225",
|
||||
"linecolor": "#2d3335",
|
||||
"zerolinecolor": "#2d3335",
|
||||
},
|
||||
"yaxis": {
|
||||
"title": {"text": "run", "font": {"size": 10, "color": "#7e8385"}},
|
||||
"tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"},
|
||||
"gridcolor": "#1c2225",
|
||||
"linecolor": "#2d3335",
|
||||
},
|
||||
"zaxis": {
|
||||
"title": {"text": "ms", "font": {"size": 10, "color": "#7e8385"}},
|
||||
"rangemode": "tozero",
|
||||
"tickfont": {"size": 10, "family": "Geist Mono, monospace", "color": "#5b5f61"},
|
||||
"ticksuffix": " ms",
|
||||
"gridcolor": "#1c2225",
|
||||
"linecolor": "#2d3335",
|
||||
},
|
||||
"camera": {
|
||||
"eye": {"x": 1.6, "y": -1.6, "z": 0.9},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _sweep_3d_card(div_id: str, model: str, traces_json: str, unit: str = " ms") -> str:
|
||||
layout = {**_SWEEP_3D_LAYOUT,
|
||||
"scene": {**_SWEEP_3D_LAYOUT["scene"],
|
||||
"zaxis": {**_SWEEP_3D_LAYOUT["scene"]["zaxis"],
|
||||
"title": {**_SWEEP_3D_LAYOUT["scene"]["zaxis"]["title"],
|
||||
"text": unit.strip()},
|
||||
"ticksuffix": unit}}}
|
||||
return f"""<div class="chart-card">
|
||||
<div class="chart-card-header">
|
||||
<span class="model-tag">{model}</span>
|
||||
</div>
|
||||
<div id="{div_id}"></div>
|
||||
<script>
|
||||
Plotly.newPlot("{div_id}", {traces_json}, {json.dumps(layout)},
|
||||
{{responsive: true, displayModeBar: true, displaylogo: false,
|
||||
modeBarButtonsToRemove: ["toImage","sendDataToCloud"]}});
|
||||
</script>
|
||||
</div>"""
|
||||
|
||||
|
||||
# ── HTML assembly ─────────────────────────────────────────────────────────────
|
||||
|
||||
def build_html(runs: list[dict], data: dict,
|
||||
run_ids: list[str], run_labels: list[str],
|
||||
sweep_data: dict | None = None,
|
||||
sweep_run_ids: list[str] | None = None) -> str:
|
||||
# Preserve insertion order of models as seen across runs
|
||||
models = list(dict.fromkeys(
|
||||
r["config"]
|
||||
for run in runs
|
||||
for r in run["results"]
|
||||
if not r.get("config", "").startswith("s=") and not r.get("error")
|
||||
))
|
||||
|
||||
last_ts = ""
|
||||
if runs:
|
||||
raw = runs[-1]["meta"]["timestamp"]
|
||||
try:
|
||||
last_ts = datetime.fromisoformat(raw).strftime("%b %d, %Y · %H:%M")
|
||||
except ValueError:
|
||||
last_ts = raw[:16].replace("T", " ")
|
||||
|
||||
n_runs = len(runs)
|
||||
|
||||
sections_html = ""
|
||||
for metric_key, metric_label, ylabel, _scale, unit in METRICS:
|
||||
active_models = [
|
||||
m for m in models
|
||||
if any(metric_key in data.get(m, {}).get(p, {}) for p in PATH_ORDER)
|
||||
]
|
||||
if not active_models:
|
||||
continue
|
||||
|
||||
cards_html = ""
|
||||
first = True
|
||||
for model in active_models:
|
||||
path_data = data.get(model, {})
|
||||
div_id = f"c_{metric_key}_{model.replace('-','_').replace('.','_')}"
|
||||
traces = _traces_json(path_data, metric_key, show_legend=first, unit=unit)
|
||||
cards_html += _chart_card(div_id, model, traces, show_legend=first,
|
||||
run_ids=run_ids, run_labels=run_labels, unit=unit)
|
||||
first = False
|
||||
|
||||
n = len(active_models)
|
||||
# Clamp columns so charts don't get too narrow; wrap at 4
|
||||
cols = min(n, 4)
|
||||
sections_html += f"""
|
||||
<section>
|
||||
<div class="section-header">
|
||||
<span class="section-eyebrow">metric</span>
|
||||
<h2 class="section-title">{metric_label} <span class="unit">over time</span></h2>
|
||||
<span class="section-tag">{ylabel}</span>
|
||||
</div>
|
||||
<div class="chart-grid" style="grid-template-columns: repeat({cols}, 1fr)">
|
||||
{cards_html}
|
||||
</div>
|
||||
</section>"""
|
||||
|
||||
# ── sweep sections (3D) ──────────────────────────────────────────────────
|
||||
sweep_sections_html = ""
|
||||
if sweep_data and sweep_run_ids:
|
||||
sweep_models = list(sweep_data.keys())
|
||||
for metric_key, metric_label, ylabel, _scale, unit in METRICS:
|
||||
active = [
|
||||
m for m in sweep_models
|
||||
if any(
|
||||
run_id in sweep_data[m].get(p, {}).get(metric_key, {})
|
||||
for p in PATH_ORDER
|
||||
for run_id in sweep_run_ids
|
||||
)
|
||||
]
|
||||
if not active:
|
||||
continue
|
||||
cards_html = ""
|
||||
for model in active:
|
||||
div_id = f"sw_{metric_key}_{model.replace('-','_').replace('.','_')}"
|
||||
traces = _sweep_3d_traces_json(sweep_data[model], metric_key, sweep_run_ids, unit=unit)
|
||||
cards_html += _sweep_3d_card(div_id, model, traces, unit=unit)
|
||||
cols = min(len(active), 4)
|
||||
run_count = len(sweep_run_ids)
|
||||
sweep_sections_html += f"""
|
||||
<section>
|
||||
<div class="section-header">
|
||||
<span class="section-eyebrow">sweep · 3d</span>
|
||||
<h2 class="section-title">{metric_label} <span class="unit">vs search budget · over time</span></h2>
|
||||
<span class="section-tag">{run_count} run{"s" if run_count != 1 else ""}</span>
|
||||
</div>
|
||||
<p class="sweep-hint">Drag to rotate · scroll to zoom · each curve = one run</p>
|
||||
<div class="chart-grid" style="grid-template-columns: repeat({cols}, 1fr)">
|
||||
{cards_html}
|
||||
</div>
|
||||
</section>"""
|
||||
|
||||
return f"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>Luminal · Benchmark Dashboard</title>
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
||||
<link href="https://fonts.googleapis.com/css2?family=Geist:wght@300;400;500;600&family=Geist+Mono:wght@300;400;500&display=swap" rel="stylesheet">
|
||||
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
|
||||
<style>
|
||||
*, *::before, *::after {{ box-sizing: border-box; margin: 0; padding: 0; }}
|
||||
html {{ -webkit-font-smoothing: antialiased; scroll-behavior: smooth; }}
|
||||
|
||||
body {{
|
||||
font-family: 'Geist', system-ui, sans-serif;
|
||||
background: #030712;
|
||||
color: #d7d8d9;
|
||||
min-height: 100vh;
|
||||
line-height: 1.5;
|
||||
}}
|
||||
|
||||
/* ── NAV ── */
|
||||
nav {{
|
||||
position: sticky;
|
||||
top: 0;
|
||||
z-index: 50;
|
||||
height: 56px;
|
||||
background: rgba(8, 15, 17, 0.92);
|
||||
backdrop-filter: blur(8px);
|
||||
-webkit-backdrop-filter: blur(8px);
|
||||
border-bottom: 1px solid #2d3335;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
padding: 0 24px;
|
||||
gap: 0;
|
||||
}}
|
||||
.nav-brand {{
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
letter-spacing: 0.05em;
|
||||
color: #2faa6e;
|
||||
text-decoration: none;
|
||||
}}
|
||||
.nav-dot {{
|
||||
width: 6px;
|
||||
height: 6px;
|
||||
background: #2faa6e;
|
||||
border-radius: 50%;
|
||||
flex-shrink: 0;
|
||||
animation: pulse-glow 2s ease-in-out infinite;
|
||||
}}
|
||||
.nav-sep {{
|
||||
color: #2d3335;
|
||||
margin: 0 14px;
|
||||
font-size: 18px;
|
||||
font-weight: 300;
|
||||
}}
|
||||
.nav-page {{
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.1em;
|
||||
text-transform: uppercase;
|
||||
color: #7e8385;
|
||||
}}
|
||||
|
||||
@keyframes pulse-glow {{
|
||||
0%, 100% {{ opacity: 1; }}
|
||||
50% {{ opacity: 0.35; }}
|
||||
}}
|
||||
|
||||
/* ── MAIN ── */
|
||||
main {{
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
padding: 40px 24px 80px;
|
||||
}}
|
||||
|
||||
/* ── PAGE HEADER ── */
|
||||
.page-header {{
|
||||
margin-bottom: 40px;
|
||||
padding-bottom: 32px;
|
||||
border-bottom: 1px solid #1c2225;
|
||||
}}
|
||||
.page-eyebrow {{
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.1em;
|
||||
text-transform: uppercase;
|
||||
color: #2faa6e;
|
||||
margin-bottom: 10px;
|
||||
}}
|
||||
.page-title {{
|
||||
font-size: 30px;
|
||||
font-weight: 500;
|
||||
letter-spacing: -0.025em;
|
||||
color: #d7d8d9;
|
||||
margin-bottom: 10px;
|
||||
}}
|
||||
.page-meta {{
|
||||
font-size: 14px;
|
||||
color: #7e8385;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0;
|
||||
flex-wrap: wrap;
|
||||
}}
|
||||
.meta-sep {{
|
||||
font-family: 'Geist Mono', monospace;
|
||||
color: #2d3335;
|
||||
margin: 0 10px;
|
||||
}}
|
||||
.meta-val {{
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 13px;
|
||||
color: #5b5f61;
|
||||
}}
|
||||
|
||||
/* ── LEGEND STRIP ── */
|
||||
.legend-strip {{
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 6px;
|
||||
margin-bottom: 32px;
|
||||
}}
|
||||
.legend-pill {{
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.04em;
|
||||
color: #a1a4a5;
|
||||
background: #141b1d;
|
||||
border: 1px solid #2d3335;
|
||||
border-radius: 2px;
|
||||
padding: 4px 10px;
|
||||
}}
|
||||
.legend-swatch {{
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border-radius: 50%;
|
||||
flex-shrink: 0;
|
||||
}}
|
||||
|
||||
/* ── SECTIONS ── */
|
||||
section {{ margin-bottom: 48px; }}
|
||||
.section-header {{
|
||||
display: flex;
|
||||
align-items: baseline;
|
||||
gap: 10px;
|
||||
margin-bottom: 16px;
|
||||
padding-bottom: 12px;
|
||||
border-bottom: 1px solid #1c2225;
|
||||
flex-wrap: wrap;
|
||||
}}
|
||||
.section-eyebrow {{
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.1em;
|
||||
text-transform: uppercase;
|
||||
color: #404647;
|
||||
}}
|
||||
.section-title {{
|
||||
font-size: 18px;
|
||||
font-weight: 500;
|
||||
color: #d7d8d9;
|
||||
letter-spacing: -0.01em;
|
||||
}}
|
||||
.section-title .unit {{
|
||||
color: #7e8385;
|
||||
font-weight: 400;
|
||||
}}
|
||||
.section-tag {{
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.04em;
|
||||
text-transform: uppercase;
|
||||
color: #2faa6e;
|
||||
background: #162322;
|
||||
border: 1px solid #1c372e;
|
||||
padding: 2px 8px;
|
||||
border-radius: 2px;
|
||||
margin-left: auto;
|
||||
}}
|
||||
|
||||
/* ── CHART GRID ── */
|
||||
.chart-grid {{
|
||||
display: grid;
|
||||
gap: 10px;
|
||||
}}
|
||||
.chart-card {{
|
||||
background: #141b1d;
|
||||
border: 1px solid #2d3335;
|
||||
border-radius: 2px;
|
||||
overflow: hidden;
|
||||
transition: border-color 150ms;
|
||||
min-width: 0;
|
||||
}}
|
||||
.chart-card:hover {{ border-color: #404647; }}
|
||||
.chart-card-header {{
|
||||
padding: 10px 14px 0;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}}
|
||||
.model-tag {{
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.06em;
|
||||
text-transform: uppercase;
|
||||
color: #7e8385;
|
||||
}}
|
||||
|
||||
/* ── FOOTER ── */
|
||||
footer {{
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
padding: 20px 24px;
|
||||
border-top: 1px solid #1c2225;
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.04em;
|
||||
color: #404647;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
flex-wrap: wrap;
|
||||
gap: 8px;
|
||||
}}
|
||||
|
||||
.section-divider {{
|
||||
border: none;
|
||||
border-top: 1px solid #1c2225;
|
||||
margin: 8px 0 40px;
|
||||
}}
|
||||
.sweep-hint {{
|
||||
font-family: 'Geist Mono', monospace;
|
||||
font-size: 11px;
|
||||
letter-spacing: 0.04em;
|
||||
color: #404647;
|
||||
margin-bottom: 12px;
|
||||
}}
|
||||
|
||||
@media (max-width: 768px) {{
|
||||
.chart-grid {{ grid-template-columns: 1fr !important; }}
|
||||
.page-title {{ font-size: 22px; }}
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
|
||||
<nav>
|
||||
<a class="nav-brand" href="https://luminal.com">
|
||||
<span class="nav-dot"></span>luminal
|
||||
</a>
|
||||
<span class="nav-sep">/</span>
|
||||
<span class="nav-page">benchmarks</span>
|
||||
</nav>
|
||||
|
||||
<main>
|
||||
|
||||
<header class="page-header">
|
||||
<p class="page-eyebrow">performance · time-series</p>
|
||||
<h1 class="page-title">Benchmark Dashboard</h1>
|
||||
<div class="page-meta">
|
||||
<span>Last updated</span>
|
||||
<span class="meta-sep">·</span>
|
||||
<span class="meta-val">{last_ts}</span>
|
||||
<span class="meta-sep">·</span>
|
||||
<span class="meta-val">{n_runs} run{"s" if n_runs != 1 else ""} in history</span>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<div class="legend-strip">
|
||||
{"".join(
|
||||
f'<div class="legend-pill"><span class="legend-swatch" style="background:{PATH_COLORS[p]}"></span>{PATH_LABELS[p]}</div>'
|
||||
for p in PATH_ORDER
|
||||
)}
|
||||
</div>
|
||||
|
||||
{sections_html}
|
||||
{"<hr class='section-divider'>" + sweep_sections_html if sweep_sections_html else ""}
|
||||
|
||||
</main>
|
||||
|
||||
<footer>
|
||||
<span>luminal · benchmark dashboard</span>
|
||||
<span>generated {last_ts}</span>
|
||||
</footer>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
# ── entry point ───────────────────────────────────────────────────────────────
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--db", default=str(db.DEFAULT_DB_PATH),
|
||||
help=f"SQLite bench DB (default: {db.DEFAULT_DB_PATH})")
|
||||
ap.add_argument("--out", default=str(BENCH_DIR / "dashboard.html"),
|
||||
help="Output HTML file")
|
||||
args = ap.parse_args()
|
||||
|
||||
runs = load_history(Path(args.db))
|
||||
if not runs:
|
||||
print(f"No runs found in {args.db}. Run --ur-test (or backfill) first.")
|
||||
return
|
||||
|
||||
data, run_ids, run_labels = build_series(runs)
|
||||
sweep_data, sweep_run_ids = build_sweep_series(runs)
|
||||
html = build_html(runs, data, run_ids, run_labels, sweep_data, sweep_run_ids)
|
||||
Path(args.out).write_text(html)
|
||||
|
||||
print(f"wrote {args.out} ({len(runs)} runs, {sum(len(v) for v in data.values())} model×path series)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
349
benchmarks/ttft/gen_report.py
Normal file
349
benchmarks/ttft/gen_report.py
Normal file
@@ -0,0 +1,349 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate a standalone HTML benchmark report from a single benchmark run.
|
||||
|
||||
Usage:
|
||||
python3 gen_report.py [--db PATH] [--run RUN_ID] [--out report.html] [--title "..."]
|
||||
|
||||
Sections are split out of a single run automatically:
|
||||
- per-model_key, "comparison" (configs not matching s=N) → grouped bar chart
|
||||
- per-model_key, "sweep" (configs matching s=N) → line chart (log X)
|
||||
For runs without model_key (e.g. single-config runs), one section per detected
|
||||
shape is produced instead.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import db
|
||||
|
||||
PATH_ORDER = ["python_baseline", "python_torch_compile", "python_luminal", "rust"]
|
||||
PATH_LABELS = {
|
||||
"python_baseline": "HF Baseline",
|
||||
"python_torch_compile": "torch.compile",
|
||||
"python_luminal": "luminal backend",
|
||||
"rust": "Rust (luminal)",
|
||||
}
|
||||
PATH_COLORS = {
|
||||
"python_baseline": "#888888",
|
||||
"python_torch_compile": "#5ab552",
|
||||
"python_luminal": "#4c9ed9",
|
||||
"rust": "#d97a4c",
|
||||
}
|
||||
|
||||
|
||||
# ── helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def _fmt(v, decimals=1, suffix=""):
|
||||
return f"{v:.{decimals}f}{suffix}" if v is not None else "—"
|
||||
|
||||
def _section_title(path: Path) -> str:
|
||||
stem = path.stem.replace("_", " ").replace("-", " ")
|
||||
return stem.title()
|
||||
|
||||
def _is_sweep(configs: list[str]) -> bool:
|
||||
return bool(configs) and all(re.fullmatch(r"s=\d+", c) for c in configs)
|
||||
|
||||
def _group_by_config(results: list[dict]) -> dict[str, dict[str, dict]]:
|
||||
"""Return {config: {path: result_dict}}."""
|
||||
out: dict[str, dict[str, dict]] = {}
|
||||
for r in results:
|
||||
cfg = r.get("config", "default")
|
||||
out.setdefault(cfg, {})[r["path"]] = r
|
||||
return out
|
||||
|
||||
|
||||
# ── chart builders (return Plotly figure dicts) ───────────────────────────────
|
||||
|
||||
def _bar_figure(by_config: dict, metric: str, title: str,
|
||||
scale: float = 1.0, unit: str = "ms") -> dict:
|
||||
configs = list(by_config.keys())
|
||||
traces = []
|
||||
for path in PATH_ORDER:
|
||||
ys, texts = [], []
|
||||
for cfg in configs:
|
||||
r = by_config[cfg].get(path)
|
||||
raw = r.get(metric) if r and not r.get("error") else None
|
||||
v = raw * scale if raw is not None else None
|
||||
ys.append(v if v is not None else 0)
|
||||
texts.append(f"{v:.1f} {unit}" if v is not None else "n/a")
|
||||
if any(y > 0 for y in ys):
|
||||
traces.append({
|
||||
"type": "bar",
|
||||
"name": PATH_LABELS.get(path, path),
|
||||
"x": configs,
|
||||
"y": ys,
|
||||
"text": texts,
|
||||
"textposition": "outside",
|
||||
"marker": {"color": PATH_COLORS.get(path, "#aaaaaa")},
|
||||
"hovertemplate": "%{x}<br>" + PATH_LABELS.get(path, path)
|
||||
+ f": %{{y:.1f}} {unit}<extra></extra>",
|
||||
})
|
||||
return {
|
||||
"data": traces,
|
||||
"layout": {
|
||||
"title": title,
|
||||
"yaxis": {"title": unit, "rangemode": "tozero"},
|
||||
"barmode": "group",
|
||||
"legend": {"orientation": "h", "y": -0.2},
|
||||
"margin": {"t": 50, "b": 80},
|
||||
"plot_bgcolor": "#fafafa",
|
||||
"paper_bgcolor": "#ffffff",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _line_figure(by_config: dict, metric: str, title: str,
|
||||
scale: float = 1.0, unit: str = "ms") -> dict:
|
||||
"""Line chart for sweep data. Config names are 's=N'; X = N (log scale)."""
|
||||
def _iter(cfg):
|
||||
m = re.fullmatch(r"s=(\d+)", cfg)
|
||||
return int(m.group(1)) if m else 0
|
||||
|
||||
configs_sorted = sorted(by_config.keys(), key=_iter)
|
||||
xs = [_iter(c) for c in configs_sorted]
|
||||
|
||||
paths_present = {p for cfg in by_config.values() for p in cfg}
|
||||
traces = []
|
||||
for path in PATH_ORDER:
|
||||
if path not in paths_present:
|
||||
continue
|
||||
ys = []
|
||||
for cfg in configs_sorted:
|
||||
r = by_config[cfg].get(path)
|
||||
raw = r.get(metric) if r and not r.get("error") else None
|
||||
ys.append(raw * scale if raw is not None else None)
|
||||
if any(y is not None for y in ys):
|
||||
traces.append({
|
||||
"type": "scatter",
|
||||
"mode": "lines+markers",
|
||||
"name": PATH_LABELS.get(path, path),
|
||||
"x": xs,
|
||||
"y": ys,
|
||||
"marker": {"size": 8, "color": PATH_COLORS.get(path, "#aaaaaa")},
|
||||
"line": {"color": PATH_COLORS.get(path, "#aaaaaa"), "width": 2},
|
||||
"hovertemplate": "iters=%{x}<br>" + PATH_LABELS.get(path, path)
|
||||
+ f": %{{y:.1f}} {unit}<extra></extra>",
|
||||
})
|
||||
return {
|
||||
"data": traces,
|
||||
"layout": {
|
||||
"title": title,
|
||||
"xaxis": {"title": "Search iterations", "type": "log",
|
||||
"tickvals": xs, "ticktext": [str(x) for x in xs]},
|
||||
"yaxis": {"title": unit, "rangemode": "tozero"},
|
||||
"legend": {"orientation": "h", "y": -0.25},
|
||||
"margin": {"t": 50, "b": 90},
|
||||
"plot_bgcolor": "#fafafa",
|
||||
"paper_bgcolor": "#ffffff",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ── table builder ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _table_html(results: list[dict]) -> str:
|
||||
rows = []
|
||||
for r in sorted(results, key=lambda r: (r.get("config", ""), PATH_ORDER.index(r["path"]) if r["path"] in PATH_ORDER else 99)):
|
||||
error = r.get("error")
|
||||
style = ' style="background:#fff0f0"' if error else ""
|
||||
path_label = PATH_LABELS.get(r["path"], r["path"])
|
||||
cfg = r.get("config", "—")
|
||||
ttft = _fmt(r.get("ttft_ms"), 1, " ms")
|
||||
tpot = _fmt(r.get("tpot_ms"), 1, " ms")
|
||||
tput = _fmt(r.get("throughput_tps"), 1, " tok/s")
|
||||
comp = _fmt(r.get("compile_ms"), 0, " ms") if r.get("compile_ms") else "—"
|
||||
ptok = str(r.get("prompt_tokens", "—"))
|
||||
note = (r.get("error") or r.get("note") or "")[:90]
|
||||
note_style = ' style="color:#c00"' if error else ' style="color:#777"'
|
||||
rows.append(
|
||||
f'<tr{style}>'
|
||||
f'<td>{path_label}</td><td>{cfg}</td>'
|
||||
f'<td>{ttft}</td><td>{tpot}</td><td>{tput}</td>'
|
||||
f'<td>{comp}</td><td>{ptok}</td>'
|
||||
f'<td{note_style}>{note}</td>'
|
||||
f'</tr>'
|
||||
)
|
||||
return (
|
||||
'<table>'
|
||||
'<thead><tr>'
|
||||
'<th>Path</th><th>Config</th>'
|
||||
'<th>TTFT</th><th>TPOT</th><th>Throughput</th>'
|
||||
'<th>Compile</th><th>Prompt tokens</th><th>Note</th>'
|
||||
'</tr></thead>'
|
||||
'<tbody>' + "\n".join(rows) + '</tbody>'
|
||||
'</table>'
|
||||
)
|
||||
|
||||
|
||||
# ── section builder ───────────────────────────────────────────────────────────
|
||||
|
||||
def _section_html(sec_id: str, title: str, results: list[dict], fig_counter: list) -> str:
|
||||
by_config = _group_by_config(results)
|
||||
configs = list(by_config.keys())
|
||||
sweep = _is_sweep(configs)
|
||||
|
||||
models = list(dict.fromkeys(r.get("model", "") for r in results if r.get("model")))
|
||||
model_str = ", ".join(models) if models else "—"
|
||||
prompt_tokens = list(dict.fromkeys(r.get("prompt_tokens") for r in results if r.get("prompt_tokens")))
|
||||
tok_str = "/".join(str(t) for t in prompt_tokens) + " prompt tokens" if prompt_tokens else ""
|
||||
|
||||
builder = _line_figure if sweep else _bar_figure
|
||||
ttft_fig = builder(by_config, "ttft_ms", "TTFT")
|
||||
has_tpot = any(r.get("tpot_ms") is not None for r in results if not r.get("error"))
|
||||
tpot_fig = builder(by_config, "tpot_ms", "TPOT") if has_tpot else None
|
||||
has_compile = any(r.get("compile_ms") is not None and r.get("compile_ms") > 0
|
||||
for r in results if not r.get("error"))
|
||||
compile_fig = (builder(by_config, "compile_ms", "Time to Search",
|
||||
scale=0.001, unit="sec")
|
||||
if has_compile else None)
|
||||
|
||||
def chart_div(fig):
|
||||
n = fig_counter[0]
|
||||
fig_counter[0] += 1
|
||||
return (
|
||||
f'<div id="fig{n}" class="chart"></div>'
|
||||
f'<script>Plotly.newPlot("fig{n}", {json.dumps(fig["data"])}, {json.dumps(fig["layout"])}, {{responsive:true}});</script>'
|
||||
)
|
||||
|
||||
charts_html = f'<div class="charts-row">{chart_div(ttft_fig)}'
|
||||
if tpot_fig:
|
||||
charts_html += chart_div(tpot_fig)
|
||||
if compile_fig:
|
||||
charts_html += chart_div(compile_fig)
|
||||
charts_html += '</div>'
|
||||
|
||||
return f"""
|
||||
<section id="{sec_id}">
|
||||
<h2>{title}</h2>
|
||||
<p class="meta">{model_str}{" · " + tok_str if tok_str else ""} · {len(results)} results</p>
|
||||
{charts_html}
|
||||
{_table_html(results)}
|
||||
</section>
|
||||
"""
|
||||
|
||||
|
||||
# ── full page ─────────────────────────────────────────────────────────────────
|
||||
|
||||
CSS = """
|
||||
* { box-sizing: border-box; margin: 0; padding: 0; }
|
||||
body { font-family: system-ui, sans-serif; background: #f0f2f5; color: #222; }
|
||||
header { background: #1a1a2e; color: #fff; padding: 1rem 2rem;
|
||||
position: sticky; top: 0; z-index: 100; display: flex;
|
||||
align-items: center; gap: 2rem; }
|
||||
header h1 { font-size: 1.2rem; white-space: nowrap; }
|
||||
nav a { color: #a0c4ff; text-decoration: none; font-size: 0.9rem;
|
||||
padding: 0.3rem 0.7rem; border-radius: 4px; white-space: nowrap; }
|
||||
nav a:hover { background: rgba(255,255,255,0.15); }
|
||||
main { max-width: 1400px; margin: 0 auto; padding: 2rem; display: flex;
|
||||
flex-direction: column; gap: 2.5rem; }
|
||||
section { background: #fff; border-radius: 8px; padding: 1.5rem 2rem;
|
||||
box-shadow: 0 1px 4px rgba(0,0,0,.08); }
|
||||
h2 { font-size: 1.3rem; margin-bottom: 0.4rem; }
|
||||
.meta { color: #666; font-size: 0.85rem; margin-bottom: 1.2rem; }
|
||||
.charts-row { display: flex; gap: 1.5rem; flex-wrap: wrap; margin-bottom: 1.5rem; }
|
||||
.chart { flex: 1; min-width: 340px; height: 360px; }
|
||||
table { width: 100%; border-collapse: collapse; font-size: 0.82rem; }
|
||||
thead tr { background: #f5f5f5; }
|
||||
th, td { padding: 0.45rem 0.7rem; text-align: left;
|
||||
border-bottom: 1px solid #e8e8e8; }
|
||||
th { font-weight: 600; white-space: nowrap; }
|
||||
tr:last-child td { border-bottom: none; }
|
||||
tr:hover { background: #fafafa; }
|
||||
"""
|
||||
|
||||
def _build_html(sections: list[tuple[str, str, list[dict]]], title: str) -> str:
|
||||
nav_links = "".join(f'<a href="#{sid}">{stitle}</a>' for sid, stitle, _ in sections)
|
||||
fig_counter = [0]
|
||||
body = "".join(_section_html(sid, stitle, results, fig_counter)
|
||||
for sid, stitle, results in sections)
|
||||
return f"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>{title}</title>
|
||||
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
|
||||
<style>{CSS}</style>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>{title}</h1>
|
||||
<nav>{nav_links}</nav>
|
||||
</header>
|
||||
<main>{body}</main>
|
||||
</body>
|
||||
</html>"""
|
||||
|
||||
|
||||
# ── CLI ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _sections_for_run(results: list[dict]) -> list[tuple[str, str, list[dict]]]:
|
||||
"""Split a single run's results into (sec_id, title, records) sections.
|
||||
|
||||
Splits first by model_key (NULL → 'results'), then within each by
|
||||
sweep-vs-comparison based on config 's=N' shape."""
|
||||
by_key: dict[str | None, list[dict]] = {}
|
||||
for r in results:
|
||||
by_key.setdefault(r.get("model_key"), []).append(r)
|
||||
|
||||
sections: list[tuple[str, str, list[dict]]] = []
|
||||
for key, recs in by_key.items():
|
||||
comp, sweep = [], []
|
||||
for r in recs:
|
||||
(sweep if str(r.get("config", "")).startswith("s=") else comp).append(r)
|
||||
prefix = (key or "results").replace("-", "_").replace(".", "_")
|
||||
title_prefix = key or "Results"
|
||||
if comp:
|
||||
sections.append((f"{prefix}_comparison",
|
||||
f"{title_prefix} comparison".strip().title(),
|
||||
comp))
|
||||
if sweep:
|
||||
sections.append((f"{prefix}_sweep",
|
||||
f"{title_prefix} sweep".strip().title(),
|
||||
sweep))
|
||||
return sections
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(description=__doc__)
|
||||
ap.add_argument("--db", default=str(db.DEFAULT_DB_PATH),
|
||||
help=f"SQLite bench DB (default: {db.DEFAULT_DB_PATH})")
|
||||
ap.add_argument("--run", default=None,
|
||||
help="Run ID to render (default: latest run in DB)")
|
||||
ap.add_argument("--out", default=None,
|
||||
help="Output HTML path (default: report.html in benchmarks/ttft/)")
|
||||
ap.add_argument("--title", default="Luminal TTFT Benchmark Report",
|
||||
help="Page title and heading")
|
||||
args = ap.parse_args()
|
||||
|
||||
if not Path(args.db).exists():
|
||||
print(f"DB not found: {args.db}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
conn = db.connect(args.db)
|
||||
run_id = args.run or db.latest_run_id(conn)
|
||||
if run_id is None:
|
||||
print(f"No runs in {args.db}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
results = db.load_results(conn, run_id)
|
||||
if not results:
|
||||
print(f"No results for run {run_id}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
sections = _sections_for_run(results)
|
||||
if not sections:
|
||||
print(f"No section data for run {run_id}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
out = Path(args.out) if args.out else Path(__file__).parent / "report.html"
|
||||
html = _build_html(sections, f"{args.title} — {run_id}")
|
||||
out.write_text(html)
|
||||
print(f"wrote {out} (run {run_id}, {len(sections)} sections, {len(results)} results)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
148
benchmarks/ttft/report.html
Normal file
148
benchmarks/ttft/report.html
Normal file
@@ -0,0 +1,148 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>Luminal TTFT Benchmark Report — 2026-05-01T18-56-26-996695</title>
|
||||
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
|
||||
<style>
|
||||
* { box-sizing: border-box; margin: 0; padding: 0; }
|
||||
body { font-family: system-ui, sans-serif; background: #f0f2f5; color: #222; }
|
||||
header { background: #1a1a2e; color: #fff; padding: 1rem 2rem;
|
||||
position: sticky; top: 0; z-index: 100; display: flex;
|
||||
align-items: center; gap: 2rem; }
|
||||
header h1 { font-size: 1.2rem; white-space: nowrap; }
|
||||
nav a { color: #a0c4ff; text-decoration: none; font-size: 0.9rem;
|
||||
padding: 0.3rem 0.7rem; border-radius: 4px; white-space: nowrap; }
|
||||
nav a:hover { background: rgba(255,255,255,0.15); }
|
||||
main { max-width: 1400px; margin: 0 auto; padding: 2rem; display: flex;
|
||||
flex-direction: column; gap: 2.5rem; }
|
||||
section { background: #fff; border-radius: 8px; padding: 1.5rem 2rem;
|
||||
box-shadow: 0 1px 4px rgba(0,0,0,.08); }
|
||||
h2 { font-size: 1.3rem; margin-bottom: 0.4rem; }
|
||||
.meta { color: #666; font-size: 0.85rem; margin-bottom: 1.2rem; }
|
||||
.charts-row { display: flex; gap: 1.5rem; flex-wrap: wrap; margin-bottom: 1.5rem; }
|
||||
.chart { flex: 1; min-width: 340px; height: 360px; }
|
||||
table { width: 100%; border-collapse: collapse; font-size: 0.82rem; }
|
||||
thead tr { background: #f5f5f5; }
|
||||
th, td { padding: 0.45rem 0.7rem; text-align: left;
|
||||
border-bottom: 1px solid #e8e8e8; }
|
||||
th { font-weight: 600; white-space: nowrap; }
|
||||
tr:last-child td { border-bottom: none; }
|
||||
tr:hover { background: #fafafa; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>Luminal TTFT Benchmark Report — 2026-05-01T18-56-26-996695</h1>
|
||||
<nav><a href="#llama_8b_comparison">Llama-8B Comparison</a><a href="#llama_8b_sweep">Llama-8B Sweep</a><a href="#qwen3_4b_comparison">Qwen3-4B Comparison</a><a href="#qwen3_4b_sweep">Qwen3-4B Sweep</a><a href="#gemma3_4b_comparison">Gemma3-4B Comparison</a><a href="#gemma3_4b_sweep">Gemma3-4B Sweep</a><a href="#gemma4_moe_comparison">Gemma4-Moe Comparison</a><a href="#gemma4_moe_sweep">Gemma4-Moe Sweep</a><a href="#qwen3_moe_comparison">Qwen3-Moe Comparison</a><a href="#qwen3_moe_sweep">Qwen3-Moe Sweep</a></nav>
|
||||
</header>
|
||||
<main>
|
||||
<section id="llama_8b_comparison">
|
||||
<h2>Llama-8B Comparison</h2>
|
||||
<p class="meta">NousResearch/Meta-Llama-3-8B-Instruct · 21 prompt tokens · 4 results</p>
|
||||
<div class="charts-row"><div id="fig0" class="chart"></div><script>Plotly.newPlot("fig0", [{"type": "bar", "name": "HF Baseline", "x": ["llama-8b"], "y": [705.9654394979589], "text": ["706.0 ms"], "textposition": "outside", "marker": {"color": "#888888"}, "hovertemplate": "%{x}<br>HF Baseline: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "torch.compile", "x": ["llama-8b"], "y": [307.66548847896047], "text": ["307.7 ms"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "luminal backend", "x": ["llama-8b"], "y": [461.48114453535527], "text": ["461.5 ms"], "textposition": "outside", "marker": {"color": "#4c9ed9"}, "hovertemplate": "%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["llama-8b"], "y": [1026.86], "text": ["1026.9 ms"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TTFT", "yaxis": {"title": "ms", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig1" class="chart"></div><script>Plotly.newPlot("fig1", [{"type": "bar", "name": "HF Baseline", "x": ["llama-8b"], "y": [34.15271903970279], "text": ["34.2 ms"], "textposition": "outside", "marker": {"color": "#888888"}, "hovertemplate": "%{x}<br>HF Baseline: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "torch.compile", "x": ["llama-8b"], "y": [171.7862353892997], "text": ["171.8 ms"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "luminal backend", "x": ["llama-8b"], "y": [23.078908618772402], "text": ["23.1 ms"], "textposition": "outside", "marker": {"color": "#4c9ed9"}, "hovertemplate": "%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["llama-8b"], "y": [51.64], "text": ["51.6 ms"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TPOT", "yaxis": {"title": "ms", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig2" class="chart"></div><script>Plotly.newPlot("fig2", [{"type": "bar", "name": "torch.compile", "x": ["llama-8b"], "y": [18.760145067994017], "text": ["18.8 sec"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} sec<extra></extra>"}, {"type": "bar", "name": "luminal backend", "x": ["llama-8b"], "y": [95.96263545705006], "text": ["96.0 sec"], "textposition": "outside", "marker": {"color": "#4c9ed9"}, "hovertemplate": "%{x}<br>luminal backend: %{y:.1f} sec<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["llama-8b"], "y": [84.45343], "text": ["84.5 sec"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} sec<extra></extra>"}], {"title": "Time to Search", "yaxis": {"title": "sec", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script></div>
|
||||
<table><thead><tr><th>Path</th><th>Config</th><th>TTFT</th><th>TPOT</th><th>Throughput</th><th>Compile</th><th>Prompt tokens</th><th>Note</th></tr></thead><tbody><tr><td>HF Baseline</td><td>llama-8b</td><td>706.0 ms</td><td>34.2 ms</td><td>29.3 tok/s</td><td>—</td><td>21</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>torch.compile</td><td>llama-8b</td><td>307.7 ms</td><td>171.8 ms</td><td>5.8 tok/s</td><td>18760 ms</td><td>21</td><td style="color:#777">sequential per-token, StaticCache KV cache (torch.compile inductor)</td></tr>
|
||||
<tr><td>luminal backend</td><td>llama-8b</td><td>461.5 ms</td><td>23.1 ms</td><td>43.3 tok/s</td><td>95963 ms</td><td>21</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>llama-8b</td><td>1026.9 ms</td><td>51.6 ms</td><td>19.4 tok/s</td><td>84453 ms</td><td>21</td><td style="color:#777">sum of per-token prefill durations</td></tr></tbody></table>
|
||||
</section>
|
||||
|
||||
<section id="llama_8b_sweep">
|
||||
<h2>Llama-8B Sweep</h2>
|
||||
<p class="meta">NousResearch/Meta-Llama-3-8B-Instruct · 21 prompt tokens · 6 results</p>
|
||||
<div class="charts-row"><div id="fig3" class="chart"></div><script>Plotly.newPlot("fig3", [{"type": "scatter", "mode": "lines+markers", "name": "luminal backend", "x": [10, 100, 500], "y": [470.7036415056791, 460.72837291285396, 472.43661794345826], "marker": {"size": 8, "color": "#4c9ed9"}, "line": {"color": "#4c9ed9", "width": 2}, "hovertemplate": "iters=%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [751.03, 1038.34, 453.16], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TTFT", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "ms", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig4" class="chart"></div><script>Plotly.newPlot("fig4", [{"type": "scatter", "mode": "lines+markers", "name": "luminal backend", "x": [10, 100, 500], "y": [23.540849717101082, 23.101884137140587, 23.610779400914907], "marker": {"size": 8, "color": "#4c9ed9"}, "line": {"color": "#4c9ed9", "width": 2}, "hovertemplate": "iters=%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [38.2, 51.92, 24.09], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TPOT", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "ms", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig5" class="chart"></div><script>Plotly.newPlot("fig5", [{"type": "scatter", "mode": "lines+markers", "name": "luminal backend", "x": [10, 100, 500], "y": [28.428826077957638, 43.57440591201885, 95.52432684396626], "marker": {"size": 8, "color": "#4c9ed9"}, "line": {"color": "#4c9ed9", "width": 2}, "hovertemplate": "iters=%{x}<br>luminal backend: %{y:.1f} sec<extra></extra>"}, {"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [15.14307, 30.12727, 84.87889], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} sec<extra></extra>"}], {"title": "Time to Search", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "sec", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script></div>
|
||||
<table><thead><tr><th>Path</th><th>Config</th><th>TTFT</th><th>TPOT</th><th>Throughput</th><th>Compile</th><th>Prompt tokens</th><th>Note</th></tr></thead><tbody><tr><td>luminal backend</td><td>s=10</td><td>470.7 ms</td><td>23.5 ms</td><td>42.5 tok/s</td><td>28429 ms</td><td>21</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=10</td><td>751.0 ms</td><td>38.2 ms</td><td>26.2 tok/s</td><td>15143 ms</td><td>21</td><td style="color:#777">sum of per-token prefill durations</td></tr>
|
||||
<tr><td>luminal backend</td><td>s=100</td><td>460.7 ms</td><td>23.1 ms</td><td>43.3 tok/s</td><td>43574 ms</td><td>21</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=100</td><td>1038.3 ms</td><td>51.9 ms</td><td>19.3 tok/s</td><td>30127 ms</td><td>21</td><td style="color:#777">sum of per-token prefill durations</td></tr>
|
||||
<tr><td>luminal backend</td><td>s=500</td><td>472.4 ms</td><td>23.6 ms</td><td>42.4 tok/s</td><td>95524 ms</td><td>21</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=500</td><td>453.2 ms</td><td>24.1 ms</td><td>41.5 tok/s</td><td>84879 ms</td><td>21</td><td style="color:#777">sum of per-token prefill durations</td></tr></tbody></table>
|
||||
</section>
|
||||
|
||||
<section id="qwen3_4b_comparison">
|
||||
<h2>Qwen3-4B Comparison</h2>
|
||||
<p class="meta">Qwen/Qwen3-4B · 19/11 prompt tokens · 4 results</p>
|
||||
<div class="charts-row"><div id="fig6" class="chart"></div><script>Plotly.newPlot("fig6", [{"type": "bar", "name": "HF Baseline", "x": ["qwen3-4b"], "y": [869.2860195587855], "text": ["869.3 ms"], "textposition": "outside", "marker": {"color": "#888888"}, "hovertemplate": "%{x}<br>HF Baseline: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "torch.compile", "x": ["qwen3-4b"], "y": [298.27259748708457], "text": ["298.3 ms"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "luminal backend", "x": ["qwen3-4b"], "y": [485.3892414830625], "text": ["485.4 ms"], "textposition": "outside", "marker": {"color": "#4c9ed9"}, "hovertemplate": "%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["qwen3-4b"], "y": [398.58], "text": ["398.6 ms"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TTFT", "yaxis": {"title": "ms", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig7" class="chart"></div><script>Plotly.newPlot("fig7", [{"type": "bar", "name": "HF Baseline", "x": ["qwen3-4b"], "y": [47.71483448566869], "text": ["47.7 ms"], "textposition": "outside", "marker": {"color": "#888888"}, "hovertemplate": "%{x}<br>HF Baseline: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "torch.compile", "x": ["qwen3-4b"], "y": [468.56868775503244], "text": ["468.6 ms"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "luminal backend", "x": ["qwen3-4b"], "y": [26.90318431414198], "text": ["26.9 ms"], "textposition": "outside", "marker": {"color": "#4c9ed9"}, "hovertemplate": "%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["qwen3-4b"], "y": [40.62], "text": ["40.6 ms"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TPOT", "yaxis": {"title": "ms", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig8" class="chart"></div><script>Plotly.newPlot("fig8", [{"type": "bar", "name": "torch.compile", "x": ["qwen3-4b"], "y": [4.680963660997804], "text": ["4.7 sec"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} sec<extra></extra>"}, {"type": "bar", "name": "luminal backend", "x": ["qwen3-4b"], "y": [45.345814052037895], "text": ["45.3 sec"], "textposition": "outside", "marker": {"color": "#4c9ed9"}, "hovertemplate": "%{x}<br>luminal backend: %{y:.1f} sec<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["qwen3-4b"], "y": [19.92977], "text": ["19.9 sec"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} sec<extra></extra>"}], {"title": "Time to Search", "yaxis": {"title": "sec", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script></div>
|
||||
<table><thead><tr><th>Path</th><th>Config</th><th>TTFT</th><th>TPOT</th><th>Throughput</th><th>Compile</th><th>Prompt tokens</th><th>Note</th></tr></thead><tbody><tr><td>HF Baseline</td><td>qwen3-4b</td><td>869.3 ms</td><td>47.7 ms</td><td>21.0 tok/s</td><td>—</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>torch.compile</td><td>qwen3-4b</td><td>298.3 ms</td><td>468.6 ms</td><td>2.1 tok/s</td><td>4681 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache (torch.compile inductor)</td></tr>
|
||||
<tr><td>luminal backend</td><td>qwen3-4b</td><td>485.4 ms</td><td>26.9 ms</td><td>37.2 tok/s</td><td>45346 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>qwen3-4b</td><td>398.6 ms</td><td>40.6 ms</td><td>24.6 tok/s</td><td>19930 ms</td><td>11</td><td style="color:#777">sum of per-token prefill durations</td></tr></tbody></table>
|
||||
</section>
|
||||
|
||||
<section id="qwen3_4b_sweep">
|
||||
<h2>Qwen3-4B Sweep</h2>
|
||||
<p class="meta">Qwen/Qwen3-4B · 19/11 prompt tokens · 6 results</p>
|
||||
<div class="charts-row"><div id="fig9" class="chart"></div><script>Plotly.newPlot("fig9", [{"type": "scatter", "mode": "lines+markers", "name": "luminal backend", "x": [10, 100, 500], "y": [465.02652901108377, 465.9317950136028, 495.75577257201076], "marker": {"size": 8, "color": "#4c9ed9"}, "line": {"color": "#4c9ed9", "width": 2}, "hovertemplate": "iters=%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [398.44, 390.08, 559.29], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TTFT", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "ms", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig10" class="chart"></div><script>Plotly.newPlot("fig10", [{"type": "scatter", "mode": "lines+markers", "name": "luminal backend", "x": [10, 100, 500], "y": [25.875402649398893, 25.884080055402592, 27.492373346467502], "marker": {"size": 8, "color": "#4c9ed9"}, "line": {"color": "#4c9ed9", "width": 2}, "hovertemplate": "iters=%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [40.64, 39.98, 55.37], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TPOT", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "ms", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig11" class="chart"></div><script>Plotly.newPlot("fig11", [{"type": "scatter", "mode": "lines+markers", "name": "luminal backend", "x": [10, 100, 500], "y": [37.92102829599753, 54.08867314597592, 118.29659596900456], "marker": {"size": 8, "color": "#4c9ed9"}, "line": {"color": "#4c9ed9", "width": 2}, "hovertemplate": "iters=%{x}<br>luminal backend: %{y:.1f} sec<extra></extra>"}, {"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [12.448030000000001, 27.06796, 81.89342], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} sec<extra></extra>"}], {"title": "Time to Search", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "sec", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script></div>
|
||||
<table><thead><tr><th>Path</th><th>Config</th><th>TTFT</th><th>TPOT</th><th>Throughput</th><th>Compile</th><th>Prompt tokens</th><th>Note</th></tr></thead><tbody><tr><td>luminal backend</td><td>s=10</td><td>465.0 ms</td><td>25.9 ms</td><td>38.6 tok/s</td><td>37921 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=10</td><td>398.4 ms</td><td>40.6 ms</td><td>24.6 tok/s</td><td>12448 ms</td><td>11</td><td style="color:#777">sum of per-token prefill durations</td></tr>
|
||||
<tr><td>luminal backend</td><td>s=100</td><td>465.9 ms</td><td>25.9 ms</td><td>38.6 tok/s</td><td>54089 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=100</td><td>390.1 ms</td><td>40.0 ms</td><td>25.0 tok/s</td><td>27068 ms</td><td>11</td><td style="color:#777">sum of per-token prefill durations</td></tr>
|
||||
<tr><td>luminal backend</td><td>s=500</td><td>495.8 ms</td><td>27.5 ms</td><td>36.4 tok/s</td><td>118297 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=500</td><td>559.3 ms</td><td>55.4 ms</td><td>18.1 tok/s</td><td>81893 ms</td><td>11</td><td style="color:#777">sum of per-token prefill durations</td></tr></tbody></table>
|
||||
</section>
|
||||
|
||||
<section id="gemma3_4b_comparison">
|
||||
<h2>Gemma3-4B Comparison</h2>
|
||||
<p class="meta">unsloth/gemma-3-4b-it · 19/11 prompt tokens · 4 results</p>
|
||||
<div class="charts-row"><div id="fig12" class="chart"></div><script>Plotly.newPlot("fig12", [{"type": "bar", "name": "HF Baseline", "x": ["gemma3-4b"], "y": [951.1196144158021], "text": ["951.1 ms"], "textposition": "outside", "marker": {"color": "#888888"}, "hovertemplate": "%{x}<br>HF Baseline: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "torch.compile", "x": ["gemma3-4b"], "y": [300.9451600664761], "text": ["300.9 ms"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["gemma3-4b"], "y": [404.43], "text": ["404.4 ms"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TTFT", "yaxis": {"title": "ms", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig13" class="chart"></div><script>Plotly.newPlot("fig13", [{"type": "bar", "name": "HF Baseline", "x": ["gemma3-4b"], "y": [52.498737201676704], "text": ["52.5 ms"], "textposition": "outside", "marker": {"color": "#888888"}, "hovertemplate": "%{x}<br>HF Baseline: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "torch.compile", "x": ["gemma3-4b"], "y": [2197.426627812092], "text": ["2197.4 ms"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["gemma3-4b"], "y": [38.99], "text": ["39.0 ms"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TPOT", "yaxis": {"title": "ms", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig14" class="chart"></div><script>Plotly.newPlot("fig14", [{"type": "bar", "name": "torch.compile", "x": ["gemma3-4b"], "y": [26.649526304972824], "text": ["26.6 sec"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} sec<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["gemma3-4b"], "y": [156.84164], "text": ["156.8 sec"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} sec<extra></extra>"}], {"title": "Time to Search", "yaxis": {"title": "sec", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script></div>
|
||||
<table><thead><tr><th>Path</th><th>Config</th><th>TTFT</th><th>TPOT</th><th>Throughput</th><th>Compile</th><th>Prompt tokens</th><th>Note</th></tr></thead><tbody><tr><td>HF Baseline</td><td>gemma3-4b</td><td>951.1 ms</td><td>52.5 ms</td><td>19.0 tok/s</td><td>—</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>torch.compile</td><td>gemma3-4b</td><td>300.9 ms</td><td>2197.4 ms</td><td>0.5 tok/s</td><td>26650 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache (torch.compile inductor)</td></tr>
|
||||
<tr style="background:#fff0f0"><td>luminal backend</td><td>gemma3-4b</td><td>—</td><td>—</td><td>—</td><td>—</td><td>—</td><td style="color:#c00">bench_python_luminal.py failed with code 1</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>gemma3-4b</td><td>404.4 ms</td><td>39.0 ms</td><td>25.6 tok/s</td><td>156842 ms</td><td>11</td><td style="color:#777">sum of per-token prefill durations</td></tr></tbody></table>
|
||||
</section>
|
||||
|
||||
<section id="gemma3_4b_sweep">
|
||||
<h2>Gemma3-4B Sweep</h2>
|
||||
<p class="meta">unsloth/gemma-3-4b-it · 11 prompt tokens · 6 results</p>
|
||||
<div class="charts-row"><div id="fig15" class="chart"></div><script>Plotly.newPlot("fig15", [{"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [388.19, 436.49, 386.13], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TTFT", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "ms", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig16" class="chart"></div><script>Plotly.newPlot("fig16", [{"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [37.47, 41.95, 37.25], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TPOT", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "ms", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig17" class="chart"></div><script>Plotly.newPlot("fig17", [{"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [102.18644, 186.34269, 498.48983000000004], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} sec<extra></extra>"}], {"title": "Time to Search", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "sec", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script></div>
|
||||
<table><thead><tr><th>Path</th><th>Config</th><th>TTFT</th><th>TPOT</th><th>Throughput</th><th>Compile</th><th>Prompt tokens</th><th>Note</th></tr></thead><tbody><tr style="background:#fff0f0"><td>luminal backend</td><td>s=10</td><td>—</td><td>—</td><td>—</td><td>—</td><td>—</td><td style="color:#c00">bench_python_luminal.py failed with code 1</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=10</td><td>388.2 ms</td><td>37.5 ms</td><td>26.7 tok/s</td><td>102186 ms</td><td>11</td><td style="color:#777">sum of per-token prefill durations</td></tr>
|
||||
<tr style="background:#fff0f0"><td>luminal backend</td><td>s=100</td><td>—</td><td>—</td><td>—</td><td>—</td><td>—</td><td style="color:#c00">bench_python_luminal.py failed with code 1</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=100</td><td>436.5 ms</td><td>42.0 ms</td><td>23.8 tok/s</td><td>186343 ms</td><td>11</td><td style="color:#777">sum of per-token prefill durations</td></tr>
|
||||
<tr style="background:#fff0f0"><td>luminal backend</td><td>s=500</td><td>—</td><td>—</td><td>—</td><td>—</td><td>—</td><td style="color:#c00">bench_python_luminal.py failed with code 1</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=500</td><td>386.1 ms</td><td>37.2 ms</td><td>26.8 tok/s</td><td>498490 ms</td><td>11</td><td style="color:#777">sum of per-token prefill durations</td></tr></tbody></table>
|
||||
</section>
|
||||
|
||||
<section id="gemma4_moe_comparison">
|
||||
<h2>Gemma4-Moe Comparison</h2>
|
||||
<p class="meta">google/gemma-4-26B-A4B · 11 prompt tokens · 4 results</p>
|
||||
<div class="charts-row"><div id="fig18" class="chart"></div><script>Plotly.newPlot("fig18", [{"type": "bar", "name": "HF Baseline", "x": ["gemma4-moe"], "y": [837.3980740143452], "text": ["837.4 ms"], "textposition": "outside", "marker": {"color": "#888888"}, "hovertemplate": "%{x}<br>HF Baseline: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "torch.compile", "x": ["gemma4-moe"], "y": [245.510076492792], "text": ["245.5 ms"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} ms<extra></extra>"}], {"title": "TTFT", "yaxis": {"title": "ms", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig19" class="chart"></div><script>Plotly.newPlot("fig19", [{"type": "bar", "name": "HF Baseline", "x": ["gemma4-moe"], "y": [83.64427039632574], "text": ["83.6 ms"], "textposition": "outside", "marker": {"color": "#888888"}, "hovertemplate": "%{x}<br>HF Baseline: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "torch.compile", "x": ["gemma4-moe"], "y": [654.9649795080768], "text": ["655.0 ms"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} ms<extra></extra>"}], {"title": "TPOT", "yaxis": {"title": "ms", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig20" class="chart"></div><script>Plotly.newPlot("fig20", [{"type": "bar", "name": "torch.compile", "x": ["gemma4-moe"], "y": [38.81582092499593], "text": ["38.8 sec"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} sec<extra></extra>"}], {"title": "Time to Search", "yaxis": {"title": "sec", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script></div>
|
||||
<table><thead><tr><th>Path</th><th>Config</th><th>TTFT</th><th>TPOT</th><th>Throughput</th><th>Compile</th><th>Prompt tokens</th><th>Note</th></tr></thead><tbody><tr><td>HF Baseline</td><td>gemma4-moe</td><td>837.4 ms</td><td>83.6 ms</td><td>12.0 tok/s</td><td>—</td><td>11</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>torch.compile</td><td>gemma4-moe</td><td>245.5 ms</td><td>655.0 ms</td><td>1.5 tok/s</td><td>38816 ms</td><td>11</td><td style="color:#777">sequential per-token, StaticCache KV cache (torch.compile inductor)</td></tr>
|
||||
<tr style="background:#fff0f0"><td>luminal backend</td><td>gemma4-moe</td><td>—</td><td>—</td><td>—</td><td>—</td><td>—</td><td style="color:#c00">bench_python_luminal.py failed with code -9</td></tr>
|
||||
<tr style="background:#fff0f0"><td>Rust (luminal)</td><td>gemma4-moe</td><td>—</td><td>—</td><td>—</td><td>—</td><td>—</td><td style="color:#c00">rust bench failed with code -9</td></tr></tbody></table>
|
||||
</section>
|
||||
|
||||
<section id="gemma4_moe_sweep">
|
||||
<h2>Gemma4-Moe Sweep</h2>
|
||||
<p class="meta">google/gemma-4-26B-A4B · 2 results</p>
|
||||
<div class="charts-row"><div id="fig21" class="chart"></div><script>Plotly.newPlot("fig21", [], {"title": "TTFT", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10], "ticktext": ["10"]}, "yaxis": {"title": "ms", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script></div>
|
||||
<table><thead><tr><th>Path</th><th>Config</th><th>TTFT</th><th>TPOT</th><th>Throughput</th><th>Compile</th><th>Prompt tokens</th><th>Note</th></tr></thead><tbody><tr style="background:#fff0f0"><td>luminal backend</td><td>s=10</td><td>—</td><td>—</td><td>—</td><td>—</td><td>—</td><td style="color:#c00">bench_python_luminal.py failed with code -9</td></tr>
|
||||
<tr style="background:#fff0f0"><td>Rust (luminal)</td><td>s=10</td><td>—</td><td>—</td><td>—</td><td>—</td><td>—</td><td style="color:#c00">rust bench failed with code -9</td></tr></tbody></table>
|
||||
</section>
|
||||
|
||||
<section id="qwen3_moe_comparison">
|
||||
<h2>Qwen3-Moe Comparison</h2>
|
||||
<p class="meta">Qwen/Qwen3-30B-A3B · 19 prompt tokens · 4 results</p>
|
||||
<div class="charts-row"><div id="fig22" class="chart"></div><script>Plotly.newPlot("fig22", [{"type": "bar", "name": "HF Baseline", "x": ["qwen3-moe"], "y": [1565.540504961973], "text": ["1565.5 ms"], "textposition": "outside", "marker": {"color": "#888888"}, "hovertemplate": "%{x}<br>HF Baseline: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "torch.compile", "x": ["qwen3-moe"], "y": [460.077923577046], "text": ["460.1 ms"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "luminal backend", "x": ["qwen3-moe"], "y": [21002.791983017232], "text": ["21002.8 ms"], "textposition": "outside", "marker": {"color": "#4c9ed9"}, "hovertemplate": "%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["qwen3-moe"], "y": [662.07], "text": ["662.1 ms"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TTFT", "yaxis": {"title": "ms", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig23" class="chart"></div><script>Plotly.newPlot("fig23", [{"type": "bar", "name": "HF Baseline", "x": ["qwen3-moe"], "y": [84.527321747737], "text": ["84.5 ms"], "textposition": "outside", "marker": {"color": "#888888"}, "hovertemplate": "%{x}<br>HF Baseline: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "torch.compile", "x": ["qwen3-moe"], "y": [753.0061075551203], "text": ["753.0 ms"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "luminal backend", "x": ["qwen3-moe"], "y": [1166.8824461026816], "text": ["1166.9 ms"], "textposition": "outside", "marker": {"color": "#4c9ed9"}, "hovertemplate": "%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["qwen3-moe"], "y": [60.08], "text": ["60.1 ms"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TPOT", "yaxis": {"title": "ms", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig24" class="chart"></div><script>Plotly.newPlot("fig24", [{"type": "bar", "name": "torch.compile", "x": ["qwen3-moe"], "y": [8.341281775035895], "text": ["8.3 sec"], "textposition": "outside", "marker": {"color": "#5ab552"}, "hovertemplate": "%{x}<br>torch.compile: %{y:.1f} sec<extra></extra>"}, {"type": "bar", "name": "luminal backend", "x": ["qwen3-moe"], "y": [111.70731823903043], "text": ["111.7 sec"], "textposition": "outside", "marker": {"color": "#4c9ed9"}, "hovertemplate": "%{x}<br>luminal backend: %{y:.1f} sec<extra></extra>"}, {"type": "bar", "name": "Rust (luminal)", "x": ["qwen3-moe"], "y": [80.83241000000001], "text": ["80.8 sec"], "textposition": "outside", "marker": {"color": "#d97a4c"}, "hovertemplate": "%{x}<br>Rust (luminal): %{y:.1f} sec<extra></extra>"}], {"title": "Time to Search", "yaxis": {"title": "sec", "rangemode": "tozero"}, "barmode": "group", "legend": {"orientation": "h", "y": -0.2}, "margin": {"t": 50, "b": 80}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script></div>
|
||||
<table><thead><tr><th>Path</th><th>Config</th><th>TTFT</th><th>TPOT</th><th>Throughput</th><th>Compile</th><th>Prompt tokens</th><th>Note</th></tr></thead><tbody><tr><td>HF Baseline</td><td>qwen3-moe</td><td>1565.5 ms</td><td>84.5 ms</td><td>11.8 tok/s</td><td>—</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>torch.compile</td><td>qwen3-moe</td><td>460.1 ms</td><td>753.0 ms</td><td>1.3 tok/s</td><td>8341 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache (torch.compile inductor)</td></tr>
|
||||
<tr><td>luminal backend</td><td>qwen3-moe</td><td>21002.8 ms</td><td>1166.9 ms</td><td>0.9 tok/s</td><td>111707 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>qwen3-moe</td><td>662.1 ms</td><td>60.1 ms</td><td>16.6 tok/s</td><td>80832 ms</td><td>—</td><td style="color:#777">sum of per-token prefill durations</td></tr></tbody></table>
|
||||
</section>
|
||||
|
||||
<section id="qwen3_moe_sweep">
|
||||
<h2>Qwen3-Moe Sweep</h2>
|
||||
<p class="meta">Qwen/Qwen3-30B-A3B · 19 prompt tokens · 6 results</p>
|
||||
<div class="charts-row"><div id="fig25" class="chart"></div><script>Plotly.newPlot("fig25", [{"type": "scatter", "mode": "lines+markers", "name": "luminal backend", "x": [10, 100, 500], "y": [21002.663500519702, 21018.686580006033, 21034.366824431345], "marker": {"size": 8, "color": "#4c9ed9"}, "line": {"color": "#4c9ed9", "width": 2}, "hovertemplate": "iters=%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [656.7, 540.37, 542.34], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TTFT", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "ms", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig26" class="chart"></div><script>Plotly.newPlot("fig26", [{"type": "scatter", "mode": "lines+markers", "name": "luminal backend", "x": [10, 100, 500], "y": [1166.6714247548953, 1167.2746865515364, 1168.7990181031637], "marker": {"size": 8, "color": "#4c9ed9"}, "line": {"color": "#4c9ed9", "width": 2}, "hovertemplate": "iters=%{x}<br>luminal backend: %{y:.1f} ms<extra></extra>"}, {"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [59.6, 48.79, 48.88], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} ms<extra></extra>"}], {"title": "TPOT", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "ms", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script><div id="fig27" class="chart"></div><script>Plotly.newPlot("fig27", [{"type": "scatter", "mode": "lines+markers", "name": "luminal backend", "x": [10, 100, 500], "y": [93.47603664599592, 132.266081985028, 298.05094401398674], "marker": {"size": 8, "color": "#4c9ed9"}, "line": {"color": "#4c9ed9", "width": 2}, "hovertemplate": "iters=%{x}<br>luminal backend: %{y:.1f} sec<extra></extra>"}, {"type": "scatter", "mode": "lines+markers", "name": "Rust (luminal)", "x": [10, 100, 500], "y": [25.48138, 47.5342, 134.79345], "marker": {"size": 8, "color": "#d97a4c"}, "line": {"color": "#d97a4c", "width": 2}, "hovertemplate": "iters=%{x}<br>Rust (luminal): %{y:.1f} sec<extra></extra>"}], {"title": "Time to Search", "xaxis": {"title": "Search iterations", "type": "log", "tickvals": [10, 100, 500], "ticktext": ["10", "100", "500"]}, "yaxis": {"title": "sec", "rangemode": "tozero"}, "legend": {"orientation": "h", "y": -0.25}, "margin": {"t": 50, "b": 90}, "plot_bgcolor": "#fafafa", "paper_bgcolor": "#ffffff"}, {responsive:true});</script></div>
|
||||
<table><thead><tr><th>Path</th><th>Config</th><th>TTFT</th><th>TPOT</th><th>Throughput</th><th>Compile</th><th>Prompt tokens</th><th>Note</th></tr></thead><tbody><tr><td>luminal backend</td><td>s=10</td><td>21002.7 ms</td><td>1166.7 ms</td><td>0.9 tok/s</td><td>93476 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=10</td><td>656.7 ms</td><td>59.6 ms</td><td>16.8 tok/s</td><td>25481 ms</td><td>—</td><td style="color:#777">sum of per-token prefill durations</td></tr>
|
||||
<tr><td>luminal backend</td><td>s=100</td><td>21018.7 ms</td><td>1167.3 ms</td><td>0.9 tok/s</td><td>132266 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=100</td><td>540.4 ms</td><td>48.8 ms</td><td>20.5 tok/s</td><td>47534 ms</td><td>—</td><td style="color:#777">sum of per-token prefill durations</td></tr>
|
||||
<tr><td>luminal backend</td><td>s=500</td><td>21034.4 ms</td><td>1168.8 ms</td><td>0.9 tok/s</td><td>298051 ms</td><td>19</td><td style="color:#777">sequential per-token, StaticCache KV cache</td></tr>
|
||||
<tr><td>Rust (luminal)</td><td>s=500</td><td>542.3 ms</td><td>48.9 ms</td><td>20.5 tok/s</td><td>134793 ms</td><td>—</td><td style="color:#777">sum of per-token prefill durations</td></tr></tbody></table>
|
||||
</section>
|
||||
</main>
|
||||
</body>
|
||||
</html>
|
||||
683
benchmarks/ttft/run.py
Normal file
683
benchmarks/ttft/run.py
Normal file
@@ -0,0 +1,683 @@
|
||||
"""TTFT + TPOT benchmark orchestrator.
|
||||
|
||||
Runs four paths in isolated subprocesses:
|
||||
1. python_baseline — HuggingFace / PyTorch eager on CUDA
|
||||
2. python_torch_compile — torch.compile(model) inductor backend
|
||||
3. python_luminal — torch.compile(model, backend=luminal_backend)
|
||||
4. rust — examples/<package> binary (luminal_cuda_lite)
|
||||
|
||||
Use --config to select a named configuration, or --all-configs to run every
|
||||
entry in CONFIGS. All output is written to the SQLite bench DB
|
||||
(benchmarks/ttft/bench.db); the TUI / dashboard / report read from there.
|
||||
|
||||
Notes on comparability:
|
||||
- python_baseline: single chunked forward for TTFT; KV-cache decode for TPOT.
|
||||
- python_torch_compile: inductor, same chunked prefill as baseline; first
|
||||
call triggers JIT compilation (recorded separately as compile_ms).
|
||||
- python_luminal: sequential per-token prefill with StaticCache; TPOT via
|
||||
autoregressive decode steps.
|
||||
- rust: sequential per-token prefill; TTFT = sum of prefill step durations.
|
||||
Steady-state execution only — compile / egraph-search time excluded from TTFT but
|
||||
recorded separately as compile_ms for all paths that support it.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
import tomllib
|
||||
except ImportError:
|
||||
try:
|
||||
import tomli as tomllib # type: ignore[no-redef]
|
||||
except ImportError:
|
||||
raise ImportError("Python 3.11+ or 'pip install tomli' required to load benchmarks.toml")
|
||||
|
||||
import db
|
||||
|
||||
BENCH_DIR = Path(__file__).resolve().parent
|
||||
REPO_ROOT = BENCH_DIR.parent.parent
|
||||
|
||||
DEFAULT_PROMPT = "Explain what a neural network is in a paragraph."
|
||||
DEFAULT_MODEL = "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
|
||||
_CONFIG_PATH = BENCH_DIR / "benchmarks.toml"
|
||||
with open(_CONFIG_PATH, "rb") as _f:
|
||||
_BENCH_CONFIG = tomllib.load(_f)
|
||||
|
||||
# Named benchmark configurations. Each entry overrides any subset of the
|
||||
# CLI defaults; explicit CLI flags always take precedence over the config.
|
||||
CONFIGS: dict = _BENCH_CONFIG["configs"]
|
||||
UR_TEST_MODELS: list = _BENCH_CONFIG["ur_test"]["models"]
|
||||
SEARCH_SWEEP_ITERS: list = _BENCH_CONFIG["ur_test"]["search_sweep_iters"]
|
||||
|
||||
SWEEP_CONFIG_PREFIX = "s="
|
||||
|
||||
BENCH_LINE = re.compile(r"^BENCH_RESULT (.*)$", re.MULTILINE)
|
||||
RUST_TTFT_LINE = re.compile(r"TTFT:\s*([0-9]+\.?[0-9]*)\s*ms")
|
||||
RUST_TPOT_LINE = re.compile(r"TPOT:\s*([0-9]+\.?[0-9]*)\s*ms")
|
||||
RUST_COMPILE_LINE = re.compile(r"COMPILE:\s*([0-9]+\.?[0-9]*)\s*ms")
|
||||
RUST_PROMPT_LINE = re.compile(r"Prompt:\s*(\d+)\s*tokens")
|
||||
|
||||
|
||||
def _stream(proc, tee_prefix):
|
||||
"""Drain subprocess stdout, tee-ing to our stdout line-by-line. Returns full stdout."""
|
||||
buf = []
|
||||
assert proc.stdout is not None
|
||||
for line in proc.stdout:
|
||||
buf.append(line)
|
||||
sys.stdout.write(f"[{tee_prefix}] {line}")
|
||||
sys.stdout.flush()
|
||||
proc.wait()
|
||||
return "".join(buf)
|
||||
|
||||
|
||||
_MEM_LOG_PATH = os.environ.get("BENCH_MEM_LOG", "/tmp/bench_mem_snapshots.log")
|
||||
|
||||
|
||||
def _snapshot_memory(label: str) -> None:
|
||||
"""Append a host+GPU memory snapshot to BENCH_MEM_LOG. Cheap, never raises."""
|
||||
try:
|
||||
ts = datetime.datetime.now().isoformat(timespec="seconds")
|
||||
meminfo_keys = ("MemTotal", "MemFree", "MemAvailable", "Cached", "Slab", "SReclaimable")
|
||||
meminfo = {}
|
||||
with open("/proc/meminfo") as f:
|
||||
for line in f:
|
||||
k, _, rest = line.partition(":")
|
||||
if k in meminfo_keys:
|
||||
meminfo[k] = rest.strip().split()[0] # kB
|
||||
try:
|
||||
gpu = subprocess.check_output(
|
||||
["nvidia-smi", "--query-gpu=memory.used,memory.free,memory.total",
|
||||
"--format=csv,noheader,nounits"],
|
||||
stderr=subprocess.DEVNULL, text=True, timeout=5,
|
||||
).strip().splitlines()[0]
|
||||
except Exception:
|
||||
gpu = "n/a"
|
||||
parent_rss = "?"
|
||||
try:
|
||||
with open(f"/proc/{os.getpid()}/status") as f:
|
||||
for line in f:
|
||||
if line.startswith("VmRSS:"):
|
||||
parent_rss = line.split()[1]
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
host_str = " ".join(f"{k}={meminfo.get(k, '?')}kB" for k in meminfo_keys)
|
||||
with open(_MEM_LOG_PATH, "a") as f:
|
||||
f.write(f"{ts} [{label}] parent_rss={parent_rss}kB {host_str} gpu(used,free,total MiB)={gpu}\n")
|
||||
except Exception as e:
|
||||
sys.stderr.write(f"[mem-snapshot warn] {e}\n")
|
||||
|
||||
|
||||
def _cargo_env():
|
||||
"""Return env dict with ~/.cargo/bin prepended to PATH."""
|
||||
cargo_bin = str(Path.home() / ".cargo" / "bin")
|
||||
path = os.environ.get("PATH", "")
|
||||
if cargo_bin not in path:
|
||||
path = f"{cargo_bin}:{path}"
|
||||
return {**os.environ, "PATH": path}
|
||||
|
||||
|
||||
def run_rust(_prompt, package="llama", env_vars=None):
|
||||
print(f"\n=== Running: rust (examples/{package}) ===", flush=True)
|
||||
cmd = ["cargo", "run", "--release", "-p", package]
|
||||
env = _cargo_env()
|
||||
if env_vars:
|
||||
env.update(env_vars)
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
cwd=REPO_ROOT,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
env=env,
|
||||
)
|
||||
output = _stream(proc, "rust")
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError(f"rust bench failed with code {proc.returncode}")
|
||||
m = RUST_TTFT_LINE.search(output)
|
||||
if not m:
|
||||
raise RuntimeError("could not find 'TTFT: X ms' in rust stdout")
|
||||
ttft_ms = float(m.group(1))
|
||||
result = {
|
||||
"path": "rust",
|
||||
"model": DEFAULT_MODEL,
|
||||
"ttft_ms": ttft_ms,
|
||||
"note": "sum of per-token prefill durations",
|
||||
}
|
||||
m_compile = RUST_COMPILE_LINE.search(output)
|
||||
if m_compile:
|
||||
result["compile_ms"] = float(m_compile.group(1))
|
||||
m_tpot = RUST_TPOT_LINE.search(output)
|
||||
if m_tpot:
|
||||
tpot_ms = float(m_tpot.group(1))
|
||||
result["tpot_ms"] = tpot_ms
|
||||
result["throughput_tps"] = 1000.0 / tpot_ms
|
||||
m_prompt = RUST_PROMPT_LINE.search(output)
|
||||
if m_prompt:
|
||||
result["prompt_tokens"] = int(m_prompt.group(1))
|
||||
return result
|
||||
|
||||
|
||||
def run_python_script(name, extra_args):
|
||||
script = BENCH_DIR / name
|
||||
print(f"\n=== Running: {script.name} ===", flush=True)
|
||||
cmd = [sys.executable, str(script), *extra_args]
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
cwd=REPO_ROOT,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
env={**os.environ},
|
||||
)
|
||||
output = _stream(proc, script.stem)
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError(f"{script.name} failed with code {proc.returncode}")
|
||||
m = BENCH_LINE.search(output)
|
||||
if not m:
|
||||
raise RuntimeError(f"no BENCH_RESULT line in {script.name} output")
|
||||
return json.loads(m.group(1))
|
||||
|
||||
|
||||
PATH_ORDER = ["python_baseline", "python_torch_compile", "python_luminal", "rust"]
|
||||
PATH_LABELS = {
|
||||
"python_baseline": "Python\n(HF baseline)",
|
||||
"python_torch_compile": "Python\n(torch.compile)",
|
||||
"python_luminal": "Python → Rust\n(luminal_backend)",
|
||||
"rust": "Rust\n(examples/llama)",
|
||||
}
|
||||
PATH_COLORS = {
|
||||
"python_baseline": "#888888",
|
||||
"python_torch_compile": "#5ab552",
|
||||
"python_luminal": "#4c9ed9",
|
||||
"rust": "#d97a4c",
|
||||
}
|
||||
|
||||
|
||||
def run_one_config(config_name, settings, global_skip, inter_path_cooldown=0):
|
||||
"""Run all four paths for one config. Returns list of result dicts tagged with 'config'."""
|
||||
model = settings["model"]
|
||||
rust_package = settings["rust_package"]
|
||||
prompt = settings["prompt"]
|
||||
iters = settings["iters"]
|
||||
warmups = settings["warmups"]
|
||||
decode_tokens = settings["decode_tokens"]
|
||||
search_iters = settings["search_iters"]
|
||||
dtype = settings.get("dtype", "float32")
|
||||
skip = set(global_skip) | set(settings.get("skip", []))
|
||||
|
||||
common_py = [
|
||||
"--model", model,
|
||||
"--prompt", prompt,
|
||||
"--iters", str(iters),
|
||||
"--warmups", str(warmups),
|
||||
"--decode-tokens", str(decode_tokens),
|
||||
"--dtype", dtype,
|
||||
]
|
||||
luminal_py = common_py + ["--search-iters", str(search_iters)]
|
||||
|
||||
rust_env = {"SEARCH_GRAPHS": str(search_iters), "PROMPT": prompt, "ITERS": str(iters)}
|
||||
|
||||
results = []
|
||||
first_path = True
|
||||
for path, fn in [
|
||||
("python_baseline", lambda: run_python_script("bench_python_baseline.py", common_py)),
|
||||
("python_torch_compile", lambda: run_python_script("bench_python_torch_compile.py", common_py)),
|
||||
("python_luminal", lambda: run_python_script("bench_python_luminal.py", luminal_py)),
|
||||
("rust", lambda: run_rust(prompt, package=rust_package, env_vars=rust_env)),
|
||||
]:
|
||||
if path in skip:
|
||||
continue
|
||||
if not first_path and inter_path_cooldown > 0:
|
||||
print(f" [cooldown {inter_path_cooldown}s]", flush=True)
|
||||
time.sleep(inter_path_cooldown)
|
||||
first_path = False
|
||||
_snapshot_memory(f"{config_name}/{path} BEFORE")
|
||||
try:
|
||||
r = fn()
|
||||
r["config"] = config_name
|
||||
r["model"] = model # ensure correct model is always tagged
|
||||
if path in ("python_luminal", "rust"):
|
||||
r["search_iters"] = search_iters
|
||||
results.append(r)
|
||||
except Exception as e:
|
||||
print(f"\n[WARN] {config_name}/{path} failed: {e}", flush=True)
|
||||
results.append({
|
||||
"path": path,
|
||||
"config": config_name,
|
||||
"model": model,
|
||||
"error": str(e),
|
||||
"ttft_ms": None,
|
||||
})
|
||||
_snapshot_memory(f"{config_name}/{path} AFTER")
|
||||
return results
|
||||
|
||||
|
||||
def plot(results, out_path):
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Group by config so each config gets its own subplot column.
|
||||
configs_seen: list[str] = []
|
||||
by_config: dict[str, dict] = {}
|
||||
for r in results:
|
||||
cfg = r.get("config", "default")
|
||||
if cfg not in by_config:
|
||||
configs_seen.append(cfg)
|
||||
by_config[cfg] = {}
|
||||
by_config[cfg][r["path"]] = r
|
||||
|
||||
has_tpot = any(
|
||||
r.get("tpot_ms") is not None
|
||||
for r in results
|
||||
if not r.get("error")
|
||||
)
|
||||
nrows = 2 if has_tpot else 1
|
||||
ncols = len(configs_seen)
|
||||
fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 4.5 * nrows), squeeze=False)
|
||||
|
||||
for col, cfg in enumerate(configs_seen):
|
||||
by_path = by_config[cfg]
|
||||
present = [p for p in PATH_ORDER if p in by_path]
|
||||
|
||||
def _bar(ax, title, ylabel, key):
|
||||
raw = [by_path[p].get(key) for p in present]
|
||||
ys = [v if v is not None else 0.0 for v in raw]
|
||||
cs = [PATH_COLORS.get(p, "#aaaaaa") if raw[i] is not None else "#cccccc"
|
||||
for i, p in enumerate(present)]
|
||||
xs = [PATH_LABELS.get(p, p) for p in present]
|
||||
bars = ax.bar(xs, ys, color=cs)
|
||||
ax.set_ylabel(ylabel)
|
||||
ax.set_title(f"{title} — {cfg}")
|
||||
ax.grid(axis="y", alpha=0.3)
|
||||
for b, v in zip(bars, raw):
|
||||
if v is not None:
|
||||
ax.text(b.get_x() + b.get_width() / 2, v, f"{v:.0f} ms",
|
||||
ha="center", va="bottom", fontsize=9)
|
||||
|
||||
_bar(axes[0][col], "TTFT", "Time to first token (ms)", "ttft_ms")
|
||||
if has_tpot:
|
||||
_bar(axes[1][col], "TPOT", "Time per output token (ms)", "tpot_ms")
|
||||
|
||||
fig.tight_layout()
|
||||
fig.savefig(out_path, dpi=150)
|
||||
print(f"wrote {out_path}")
|
||||
|
||||
|
||||
def run_ur_test(args, conn, run_id):
|
||||
"""The ur-test: all 4 paths at default budget + full search sweep, for each model.
|
||||
|
||||
Inserts each result into the DB as it is produced so a mid-run crash still
|
||||
leaves partial data behind.
|
||||
"""
|
||||
all_results = []
|
||||
|
||||
for model_idx, model_key in enumerate(UR_TEST_MODELS):
|
||||
s = _settings_for_config(model_key, args)
|
||||
|
||||
if model_idx > 0:
|
||||
print(f"\n [cooldown 30s between models]", flush=True)
|
||||
time.sleep(30)
|
||||
|
||||
# ── Phase 1: comparison — all 4 paths at the model's default search budget ──
|
||||
print(f"\n{'='*60}\nUR-TEST COMPARISON: {model_key}\n{'='*60}", flush=True)
|
||||
comp_results = run_one_config(model_key, s, args.skip, inter_path_cooldown=20)
|
||||
for r in comp_results:
|
||||
r["model_key"] = model_key
|
||||
db.insert_result(conn, run_id, r)
|
||||
conn.commit()
|
||||
all_results.extend(comp_results)
|
||||
|
||||
# ── Phase 2: search sweep — python_luminal + rust across all budgets ──
|
||||
if args.no_sweep:
|
||||
continue
|
||||
print(f"\n{'='*60}\nUR-TEST SWEEP: {model_key}\n{'='*60}", flush=True)
|
||||
sweep_skip_base = set(args.skip) | {"python_baseline", "python_torch_compile"}
|
||||
# Memory peak in egglog Search grows monotonically with search-iters.
|
||||
# If a path SIGKILLs (-9) at budget N, every higher budget will too —
|
||||
# skip it to avoid wasting another ~hour per model on guaranteed OOMs.
|
||||
oom_paths: set[str] = set()
|
||||
for n in SEARCH_SWEEP_ITERS:
|
||||
print(f" [cooldown 20s before s={n}]", flush=True)
|
||||
time.sleep(20)
|
||||
sweep_skip = list(sweep_skip_base | oom_paths)
|
||||
if oom_paths:
|
||||
print(f" [skip-on-prior-OOM] {sorted(oom_paths)} OOM'd at lower budget; skipping at s={n}", flush=True)
|
||||
sweep_s = {**s, "search_iters": n}
|
||||
results_n = run_one_config(f"s={n}", sweep_s, sweep_skip, inter_path_cooldown=20)
|
||||
for r in results_n:
|
||||
r["model_key"] = model_key # preserve ur-test model identity for dashboard
|
||||
db.insert_result(conn, run_id, r)
|
||||
if "code -9" in (r.get("error") or ""):
|
||||
oom_paths.add(r["path"])
|
||||
conn.commit()
|
||||
all_results.extend(results_n)
|
||||
|
||||
print("\nGenerate report with:")
|
||||
print(f" python3 benchmarks/ttft/gen_report.py --db benchmarks/ttft/bench.db --run {run_id} \\")
|
||||
print(" --out benchmarks/ttft/report.html")
|
||||
print("\nGenerate dashboard with:")
|
||||
print(" python3 benchmarks/ttft/gen_dashboard.py --out benchmarks/ttft/dashboard.html")
|
||||
|
||||
return all_results
|
||||
|
||||
|
||||
def _git_info():
|
||||
"""Return (short_commit, branch) from the repo, or ('unknown', 'unknown') if unavailable."""
|
||||
try:
|
||||
commit = subprocess.check_output(
|
||||
["git", "rev-parse", "--short", "HEAD"],
|
||||
cwd=REPO_ROOT, stderr=subprocess.DEVNULL, text=True,
|
||||
).strip()
|
||||
branch = subprocess.check_output(
|
||||
["git", "rev-parse", "--abbrev-ref", "HEAD"],
|
||||
cwd=REPO_ROOT, stderr=subprocess.DEVNULL, text=True,
|
||||
).strip()
|
||||
return commit, branch
|
||||
except Exception:
|
||||
return "unknown", "unknown"
|
||||
|
||||
|
||||
def _gpu_info() -> dict:
|
||||
"""Return GPU metadata from nvidia-smi, or empty dict if unavailable."""
|
||||
try:
|
||||
out = subprocess.check_output(
|
||||
[
|
||||
"nvidia-smi",
|
||||
"--query-gpu=name,driver_version,memory.total",
|
||||
"--format=csv,noheader,nounits",
|
||||
],
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True,
|
||||
).strip()
|
||||
if not out:
|
||||
return {}
|
||||
parts = [p.strip() for p in out.splitlines()[0].split(",")]
|
||||
if len(parts) < 3:
|
||||
return {}
|
||||
return {
|
||||
"gpu_name": parts[0],
|
||||
"gpu_driver": parts[1],
|
||||
"gpu_vram_mb": int(parts[2]),
|
||||
}
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _cuda_version() -> str:
|
||||
"""Return CUDA version string from nvidia-smi, or 'unknown'."""
|
||||
try:
|
||||
out = subprocess.check_output(
|
||||
["nvidia-smi", "--query", "--display=COMPUTE"],
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
for line in out.splitlines():
|
||||
if "CUDA Version" in line:
|
||||
return line.split(":")[-1].strip()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
out = subprocess.check_output(
|
||||
["nvidia-smi"], stderr=subprocess.DEVNULL, text=True
|
||||
)
|
||||
import re as _re
|
||||
m = _re.search(r"CUDA Version:\s*([\d.]+)", out)
|
||||
if m:
|
||||
return m.group(1)
|
||||
except Exception:
|
||||
pass
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _record_run(conn, mode):
|
||||
"""Insert a `runs` row capturing this orchestrator invocation. Returns run_id.
|
||||
|
||||
Uses microsecond resolution in the run_id so two invocations within the
|
||||
same wallclock second never collide on the runs PRIMARY KEY (insert_run
|
||||
defaults to OR IGNORE, which would otherwise silently merge them and
|
||||
corrupt history). Microseconds also let the dashboard plot back-to-back
|
||||
runs at distinct x-positions instead of stacking them on one date label.
|
||||
"""
|
||||
now = datetime.datetime.now()
|
||||
run_id = now.strftime("%Y-%m-%dT%H-%M-%S-%f")
|
||||
commit, branch = _git_info()
|
||||
db.insert_run(
|
||||
conn,
|
||||
run_id=run_id,
|
||||
timestamp=now.isoformat(),
|
||||
mode=mode,
|
||||
git_commit=commit,
|
||||
git_branch=branch,
|
||||
cuda_version=_cuda_version(),
|
||||
**_gpu_info(),
|
||||
)
|
||||
conn.commit()
|
||||
return run_id
|
||||
|
||||
|
||||
def _settings_from_args(args):
|
||||
"""Build a settings dict from parsed CLI args."""
|
||||
return {
|
||||
"model": args.model,
|
||||
"rust_package": args.rust_package,
|
||||
"prompt": args.prompt,
|
||||
"iters": args.iters,
|
||||
"warmups": args.warmups,
|
||||
"decode_tokens": args.decode_tokens,
|
||||
"search_iters": args.search_iters,
|
||||
"dtype": args.dtype,
|
||||
"skip": [],
|
||||
}
|
||||
|
||||
|
||||
def _settings_for_config(config_name, args):
|
||||
"""Merge CONFIGS[config_name] over CLI arg defaults."""
|
||||
cfg = CONFIGS[config_name]
|
||||
return {
|
||||
"model": cfg.get("model", args.model),
|
||||
"rust_package": cfg.get("rust_package", args.rust_package),
|
||||
"prompt": cfg.get("prompt", args.prompt),
|
||||
"iters": cfg.get("iters", args.iters),
|
||||
"warmups": cfg.get("warmups", args.warmups),
|
||||
"decode_tokens":cfg.get("decode_tokens",args.decode_tokens),
|
||||
"search_iters": cfg.get("search_iters", args.search_iters),
|
||||
"dtype": cfg.get("dtype", args.dtype),
|
||||
"skip": cfg.get("skip", []),
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument(
|
||||
"--config",
|
||||
choices=list(CONFIGS),
|
||||
default=None,
|
||||
help="Named benchmark configuration. Sets parameter defaults; explicit flags override.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--all-configs",
|
||||
action="store_true",
|
||||
dest="all_configs",
|
||||
help="Run every entry in CONFIGS into a single run_id in the DB.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--search-sweep",
|
||||
action="store_true",
|
||||
dest="search_sweep",
|
||||
help=(
|
||||
"Run python_luminal + rust across all SEARCH_SWEEP_ITERS budgets "
|
||||
f"({SEARCH_SWEEP_ITERS}). Uses --config (default: llama-8b) as the base settings."
|
||||
),
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-configs",
|
||||
nargs="*",
|
||||
default=[],
|
||||
choices=list(CONFIGS),
|
||||
dest="skip_configs",
|
||||
metavar="CONFIG",
|
||||
help="Config names to exclude when using --all-configs.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--no-sweep",
|
||||
action="store_true",
|
||||
dest="no_sweep",
|
||||
help=(
|
||||
"With --ur-test: skip the search-budget sweep phase and only run "
|
||||
"the 4-path comparison for each model. ~1.5 hr instead of ~5 hr."
|
||||
),
|
||||
)
|
||||
ap.add_argument("--model", default=DEFAULT_MODEL)
|
||||
ap.add_argument("--rust-package", default="llama", dest="rust_package",
|
||||
help="Cargo package name for the rust bench (examples/<name>).")
|
||||
ap.add_argument("--prompt", default=DEFAULT_PROMPT)
|
||||
ap.add_argument("--iters", type=int, default=3)
|
||||
ap.add_argument("--warmups", type=int, default=1)
|
||||
ap.add_argument("--skip", nargs="*", default=[],
|
||||
choices=["rust", "python_luminal", "python_baseline", "python_torch_compile"])
|
||||
ap.add_argument("--out", default=str(BENCH_DIR / "ttft.png"))
|
||||
ap.add_argument("--db", default=str(db.DEFAULT_DB_PATH),
|
||||
help="SQLite database file (default: benchmarks/ttft/bench.db).")
|
||||
ap.add_argument("--run", default=None, dest="run",
|
||||
help="With --render-only: run_id to render (default: latest).")
|
||||
ap.add_argument(
|
||||
"--decode-tokens", type=int, default=50,
|
||||
help="Tokens to generate for TPOT measurement (0 = skip TPOT).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--search-iters", type=int, default=500,
|
||||
help="Egraph search iterations for the python_luminal path.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--dtype", default="float32",
|
||||
choices=["float32", "bfloat16", "float16"],
|
||||
help="Torch dtype for the python paths. Configs may override per-model.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--render-only", action="store_true",
|
||||
help="Skip running benches; render an existing run from the DB. "
|
||||
"Use --run RUN_ID to pick a specific run, otherwise the latest is used.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--ur-test", action="store_true", dest="ur_test",
|
||||
help=(
|
||||
f"The mega-test: run all 4 paths at default budget + full search sweep "
|
||||
f"({SEARCH_SWEEP_ITERS}) for each of {UR_TEST_MODELS}."
|
||||
),
|
||||
)
|
||||
|
||||
# Pre-parse to apply named config as argparse defaults so explicit CLI
|
||||
# flags still override them.
|
||||
pre, _ = ap.parse_known_args()
|
||||
if pre.config and not (pre.all_configs or getattr(pre, "search_sweep", False)):
|
||||
cfg = CONFIGS[pre.config]
|
||||
ap.set_defaults(**{k: v for k, v in cfg.items() if k not in ("skip",)})
|
||||
args = ap.parse_args()
|
||||
if pre.config and not args.all_configs and not args.search_sweep:
|
||||
for path in CONFIGS[pre.config].get("skip", []):
|
||||
if path not in args.skip:
|
||||
args.skip.append(path)
|
||||
|
||||
conn = db.connect(args.db)
|
||||
|
||||
if args.render_only:
|
||||
run_id = args.run or db.latest_run_id(conn)
|
||||
if run_id is None:
|
||||
sys.exit(f"--render-only: no runs found in {args.db}")
|
||||
results = db.load_results(conn, run_id)
|
||||
if not results:
|
||||
sys.exit(f"--render-only: no results found for run {run_id} in {args.db}")
|
||||
print(f"rendering run {run_id} ({len(results)} results)")
|
||||
else:
|
||||
mode = (
|
||||
("ur-test-fast" if args.no_sweep else "ur-test") if args.ur_test
|
||||
else "search-sweep" if args.search_sweep
|
||||
else "all-configs" if args.all_configs
|
||||
else "single"
|
||||
)
|
||||
run_id = _record_run(conn, mode)
|
||||
print(f"run_id: {run_id} → {args.db}")
|
||||
|
||||
if args.ur_test:
|
||||
results = run_ur_test(args, conn, run_id)
|
||||
elif args.search_sweep:
|
||||
results = []
|
||||
# Base settings come from --config (default: llama-8b) or bare CLI args.
|
||||
base = (
|
||||
_settings_for_config(args.config, args)
|
||||
if args.config
|
||||
else _settings_for_config("llama-8b", args)
|
||||
)
|
||||
sweep_skip = set(args.skip) | {"python_baseline", "python_torch_compile"}
|
||||
for i, n in enumerate(SEARCH_SWEEP_ITERS):
|
||||
if i > 0:
|
||||
print(f" [cooldown 20s — letting CUDA free previous model memory]", flush=True)
|
||||
time.sleep(20)
|
||||
print(f"\n{'='*60}\nSEARCH SWEEP: s={n}\n{'='*60}", flush=True)
|
||||
s = {**base, "search_iters": n}
|
||||
rs = run_one_config(f"s={n}", s, list(sweep_skip))
|
||||
for r in rs:
|
||||
db.insert_result(conn, run_id, r)
|
||||
conn.commit()
|
||||
results.extend(rs)
|
||||
elif args.all_configs:
|
||||
results = []
|
||||
for config_name in CONFIGS:
|
||||
if config_name in args.skip_configs:
|
||||
continue
|
||||
print(f"\n{'='*60}\nCONFIG: {config_name}\n{'='*60}", flush=True)
|
||||
settings = _settings_for_config(config_name, args)
|
||||
rs = run_one_config(config_name, settings, args.skip)
|
||||
for r in rs:
|
||||
db.insert_result(conn, run_id, r)
|
||||
conn.commit()
|
||||
results.extend(rs)
|
||||
else:
|
||||
config_name = args.config or "default"
|
||||
settings = (
|
||||
_settings_for_config(args.config, args)
|
||||
if args.config
|
||||
else _settings_from_args(args)
|
||||
)
|
||||
results = run_one_config(config_name, settings, args.skip)
|
||||
for r in results:
|
||||
db.insert_result(conn, run_id, r)
|
||||
conn.commit()
|
||||
|
||||
# Summary
|
||||
configs_in_results = list(dict.fromkeys(r.get("config", "default") for r in results))
|
||||
for cfg in configs_in_results:
|
||||
group = [r for r in results if r.get("config", "default") == cfg]
|
||||
print(f"\nSummary ({cfg}):")
|
||||
for r in group:
|
||||
if r.get("error"):
|
||||
print(f" {r['path']:>22}: FAILED — {r['error']}")
|
||||
continue
|
||||
if r.get("ttft_ms") is None:
|
||||
print(f" {r['path']:>22}: no data")
|
||||
continue
|
||||
compile_ms = r.get("compile_ms")
|
||||
compile_str = f" compile {compile_ms:.0f} ms" if compile_ms is not None else ""
|
||||
tpot = r.get("tpot_ms")
|
||||
tput = r.get("throughput_tps")
|
||||
tpot_str = f" TPOT {tpot:.2f} ms ({tput:.1f} tok/s)" if tpot is not None else ""
|
||||
print(f" {r['path']:>22}: TTFT {r['ttft_ms']:.2f} ms{compile_str}{tpot_str}")
|
||||
|
||||
plot(results, args.out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
7
benchmarks/ttft/run.sh
Executable file
7
benchmarks/ttft/run.sh
Executable file
@@ -0,0 +1,7 @@
|
||||
#!/bin/bash
|
||||
# TTFT benchmark entrypoint. Runs via uv against the luminal_python venv.
|
||||
set -e
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||
REPO_ROOT="$( cd "$SCRIPT_DIR/../.." && pwd )"
|
||||
cd "$REPO_ROOT/crates/luminal_python"
|
||||
exec uv run python "$SCRIPT_DIR/run.py" "$@"
|
||||
BIN
benchmarks/ttft/ttft.png
Normal file
BIN
benchmarks/ttft/ttft.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 272 KiB |
@@ -106,13 +106,13 @@ impl Case {
|
||||
let out = match self {
|
||||
Case::Mul => {
|
||||
let x = cx.tensor(size);
|
||||
x.clone() * x
|
||||
x * x
|
||||
}
|
||||
Case::Sigmoid => cx.tensor(size).sigmoid(),
|
||||
Case::Tanh => cx.tensor(size).tanh(),
|
||||
Case::GeluInner => {
|
||||
let x = cx.tensor(size);
|
||||
(0.797_884_560_8_f32 * x.clone() * (1. + 0.044_715_f32 * x.clone() * x)).tanh()
|
||||
(0.797_884_6_f32 * x * (1. + 0.044_715_f32 * x * x)).tanh()
|
||||
}
|
||||
Case::Gelu => cx.tensor(size).gelu(),
|
||||
Case::LayerNorm => {
|
||||
@@ -447,10 +447,10 @@ where
|
||||
if let Some(ref backend) = backend_analysis {
|
||||
print_lowering_analysis(backend);
|
||||
}
|
||||
} else if !args.inspect_ops.is_empty() {
|
||||
if let Some(ref backend) = backend_analysis {
|
||||
print_lowering_analysis(backend);
|
||||
}
|
||||
} else if !args.inspect_ops.is_empty()
|
||||
&& let Some(ref backend) = backend_analysis
|
||||
{
|
||||
print_lowering_analysis(backend);
|
||||
}
|
||||
|
||||
// Trace facts for explicit variables.
|
||||
|
||||
75
crates/luminal_cuda_lite/src/dyn_backend.rs
Normal file
75
crates/luminal_cuda_lite/src/dyn_backend.rs
Normal file
@@ -0,0 +1,75 @@
|
||||
//! [`DynBackend`] implementation for the CUDA lite runtime.
|
||||
|
||||
use luminal::dtype::DType;
|
||||
use luminal::dyn_backend::{BackendCompileArgs, DynBackend, compile_backend};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::cudarc::driver::CudaContext;
|
||||
use crate::runtime::CudaRuntime;
|
||||
|
||||
/// [`DynBackend`] wrapper for [`CudaRuntime`].
|
||||
pub struct CudaLiteDynBackend {
|
||||
pub runtime: CudaRuntime,
|
||||
}
|
||||
|
||||
impl DynBackend for CudaLiteDynBackend {
|
||||
fn name(&self) -> &str {
|
||||
"cuda_lite"
|
||||
}
|
||||
fn device_type(&self) -> &str {
|
||||
"cuda"
|
||||
}
|
||||
|
||||
fn set_data_bytes(&mut self, node: NodeIndex, bytes: Vec<u8>, _dtype: DType) {
|
||||
self.runtime.set_data(node, bytes);
|
||||
}
|
||||
fn set_data_f32(&mut self, node: NodeIndex, data: Vec<f32>) {
|
||||
self.runtime.set_data(node, data);
|
||||
}
|
||||
fn get_output_f32(&self, node: NodeIndex) -> Vec<f32> {
|
||||
self.runtime.get_f32(node)
|
||||
}
|
||||
fn get_output_i32(&self, node: NodeIndex) -> Vec<i32> {
|
||||
self.runtime.get_i32(node)
|
||||
}
|
||||
fn get_output_bool(&self, node: NodeIndex) -> Vec<bool> {
|
||||
self.runtime.get_bool(node)
|
||||
}
|
||||
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
self.runtime.execute(dyn_map);
|
||||
}
|
||||
|
||||
fn supports_device_ptrs(&self) -> bool {
|
||||
true
|
||||
}
|
||||
unsafe fn set_device_ptr(&mut self, node: NodeIndex, ptr: u64, n: usize) {
|
||||
unsafe { self.runtime.set_device_ptr(node, ptr, n) }
|
||||
}
|
||||
unsafe fn set_output_device_ptr(&mut self, node: NodeIndex, ptr: u64, n: usize) {
|
||||
unsafe { self.runtime.set_output_device_ptr(node, ptr, n) }
|
||||
}
|
||||
fn output_is_zero_copy(&self, node: NodeIndex) -> bool {
|
||||
self.runtime.output_is_zero_copy(node)
|
||||
}
|
||||
unsafe fn copy_output_to_device_ptr(&self, node: NodeIndex, ptr: u64, n: usize) {
|
||||
unsafe { self.runtime.copy_output_to_device_ptr(node, ptr, n) }
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cuda_lite_factory(
|
||||
graph: &mut Graph,
|
||||
args: BackendCompileArgs,
|
||||
) -> Result<Box<dyn DynBackend>, String> {
|
||||
let cuda_ctx = CudaContext::new(0).map_err(|e| format!("CUDA init failed: {e}"))?;
|
||||
let stream = cuda_ctx.default_stream();
|
||||
compile_backend::<CudaRuntime>(
|
||||
graph,
|
||||
args,
|
||||
|| Ok(CudaRuntime::initialize(stream)),
|
||||
|rt, node, bytes, _dtype| {
|
||||
rt.set_data(node, bytes);
|
||||
},
|
||||
Some(&|rt, node, ptr, n| unsafe { rt.set_device_ptr(node, ptr, n) }),
|
||||
|rt| Box::new(CudaLiteDynBackend { runtime: rt }),
|
||||
)
|
||||
}
|
||||
@@ -32,6 +32,7 @@ use crate::{
|
||||
driver::{CudaSlice, CudaStream, DevicePtr},
|
||||
},
|
||||
host::{HostOp, cublas::parse_cublas_op},
|
||||
try_create_cublaslt,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -248,6 +249,19 @@ fn dtype_to_cuda_types(dtype: DType) -> (cudaDataType, cublasComputeType_t, cuda
|
||||
}
|
||||
}
|
||||
|
||||
impl CuBlasLt {
|
||||
fn get_cublaslt(&self, stream: &Arc<CudaStream>) -> anyhow::Result<Arc<CudaBlasLT>> {
|
||||
if let Some(cublaslt) = self.cublaslt.get() {
|
||||
return Ok(cublaslt.clone());
|
||||
}
|
||||
let created = try_create_cublaslt(stream.clone()).map_err(|message| {
|
||||
anyhow::anyhow!("cuBLASLt unavailable on this machine: {message}")
|
||||
})?;
|
||||
let _ = self.cublaslt.set(created.clone());
|
||||
Ok(created)
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for CuBlasLt {
|
||||
fn execute(
|
||||
&self,
|
||||
@@ -324,9 +338,7 @@ impl HostOp for CuBlasLt {
|
||||
)
|
||||
.entered();
|
||||
|
||||
let cublaslt = self
|
||||
.cublaslt
|
||||
.get_or_init(|| Arc::new(CudaBlasLT::new(stream.clone()).unwrap()));
|
||||
let cublaslt = self.get_cublaslt(stream)?;
|
||||
|
||||
let mut matmul_desc: cublasLtMatmulDesc_t = std::ptr::null_mut();
|
||||
let mut a_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
|
||||
|
||||
@@ -1,128 +1,213 @@
|
||||
; GLUMoE: Match the expert computation subgraph of a Gated MoE (SwiGLU variant).
|
||||
; GLUMoE: Match the expert computation subgraph of a gated MoE.
|
||||
;
|
||||
; This matches the pattern produced by QwenMoE::forward() starting from the
|
||||
; expert gathers through to the final weighted sum, and replaces it with a
|
||||
; fused GLUMoE HostOp.
|
||||
; One fused op supports two activation modes:
|
||||
; mode=0: Qwen-style SwiGLU (silu(gate) * up)
|
||||
; mode=1: Gemma-style GELU (gate * sigmoid(1.595769 * gate * (1 + 0.044715 * gate^2)))
|
||||
;
|
||||
; Inputs extracted:
|
||||
; ?x - input activations [s, H] F32
|
||||
; ?topk_idx - top-k expert indices [s, k] Int (from argsort+slice)
|
||||
; ?topk_vals - top-k routing values [s, k] F32 (from gather on softmax)
|
||||
; ?gate_up_w - stacked gate+up expert weights [E, intermediate*2, H] BF16
|
||||
; ?down_w - stacked down expert weights [E, H, intermediate] BF16
|
||||
;
|
||||
; The pattern captures:
|
||||
; 1. Gate-up expert gather (Iota, Mul, Cast, Iota, Cast, Add, Cast, Gather)
|
||||
; 2. Cast BF16→F32 of gathered gate-up weights
|
||||
; 3. Gate-up batched matmul (Mul + SumReduce)
|
||||
; 4. Gate/Up split via Iota+Gather (slice semantics)
|
||||
; 5. SwiGLU: silu(gate) * up
|
||||
; 6. Down expert gather (same pattern as gate-up)
|
||||
; 7. Cast BF16→F32 of gathered down weights
|
||||
; 8. Down batched matmul (Mul + SumReduce)
|
||||
; 9. Weighted sum: (down_out * topk_values) summed over k
|
||||
;
|
||||
; Variables with ? prefix are egglog pattern variables.
|
||||
; We use wildcards (?_xxx) for shapes/strides we don't extract.
|
||||
; To keep matching fast, we stage through marker states:
|
||||
; 1) Shared gate-up matmul marker
|
||||
; 2) Activation marker (separate swiglu / gemma_gelu paths)
|
||||
; 3) Down matmul marker (separate swiglu / gemma_gelu paths)
|
||||
; 4) Final GLUMoE fusion (separate swiglu / gemma_gelu rules)
|
||||
|
||||
(datatype*
|
||||
(GLUMoEGateUpState
|
||||
(MkGLUMoEGateUpState Expression Expression Expression IR IR IR)
|
||||
)
|
||||
(GLUMoESwiGLUState
|
||||
(MkGLUMoESwiGLUState GLUMoEGateUpState)
|
||||
)
|
||||
(GLUMoEGemmaGELUState
|
||||
(MkGLUMoEGemmaGELUState GLUMoEGateUpState)
|
||||
)
|
||||
(GLUMoESwiGLUDownState
|
||||
(MkGLUMoESwiGLUDownState Expression Expression Expression GLUMoESwiGLUState IR IR)
|
||||
)
|
||||
(GLUMoEGemmaDownState
|
||||
(MkGLUMoEGemmaDownState Expression Expression Expression GLUMoEGemmaGELUState IR IR)
|
||||
)
|
||||
)
|
||||
|
||||
(function glumoe_gate_up (IR) GLUMoEGateUpState :merge new)
|
||||
(function glumoe_swiglu (IR) GLUMoESwiGLUState :merge new)
|
||||
(function glumoe_gemma_gelu (IR) GLUMoEGemmaGELUState :merge new)
|
||||
(function glumoe_swiglu_down (IR) GLUMoESwiGLUDownState :merge new)
|
||||
(function glumoe_gemma_down (IR) GLUMoEGemmaDownState :merge new)
|
||||
|
||||
(rule
|
||||
(
|
||||
; ===== Gate-up expert gather =====
|
||||
; t51: Iota for base index (expert_idx * io_gu)
|
||||
(= ?gu_iota_base (Op (Iota ?gu_io ?gu_iota_base_range) (INil)))
|
||||
; t52: Mul topk_indices * io → base offsets [s, k]
|
||||
(= ?gu_mul_base (Op (Mul ?gu_mul_base_shape ?gu_mul_base_a_stride ?gu_mul_base_b_stride ?gu_mul_base_out_stride) (ICons ?topk_idx (ICons ?gu_iota_base (INil)))))
|
||||
; t53: Cast to F32
|
||||
(= ?gu_cast_base (Op (Cast ?gu_cast_base_size (F32)) (ICons ?gu_mul_base (INil))))
|
||||
; t54: Iota for within-expert index
|
||||
(= ?gu_iota_within (Op (Iota (MIter) ?gu_iota_within_range) (INil)))
|
||||
; t55: Cast within to F32
|
||||
(= ?gu_cast_within (Op (Cast ?gu_cast_within_size (F32)) (ICons ?gu_iota_within (INil))))
|
||||
; t56: Add base + within → flat gather indices
|
||||
(= ?gu_add_idx (Op (Add ?gu_add_shape ?gu_add_a_stride ?gu_add_b_stride ?gu_add_out_stride) (ICons ?gu_cast_base (ICons ?gu_cast_within (INil)))))
|
||||
; t57: Cast to Int
|
||||
(= ?gu_cast_idx (Op (Cast ?gu_cast_idx_size (Int)) (ICons ?gu_add_idx (INil))))
|
||||
; t58: Gather gate_up weights
|
||||
(= ?gu_gathered (Op (Gather ?gu_gather_idx_shape ?gu_gather_idx_stride ?gu_gather_data_shape ?gu_gather_data_stride) (ICons ?gu_cast_idx (ICons ?gate_up_w (INil)))))
|
||||
(= ?gu_add_idx (Op (Add ?gu_add_shape ?gu_add_a_stride ?gu_add_b_stride ?gu_add_out_stride) (ICons ?gu_mul_base (ICons ?gu_iota_within (INil)))))
|
||||
(= ?gu_gathered (Op (Gather ?gu_gather_idx_shape ?gu_gather_idx_stride ?gu_gather_data_shape ?gu_gather_data_stride) (ICons ?gu_add_idx (ICons ?gate_up_w (INil)))))
|
||||
|
||||
; ===== Cast BF16→F32 =====
|
||||
; t59: Cast gathered gate_up to F32
|
||||
(= ?gu_f32 (Op (Cast ?gu_f32_size (F32)) (ICons ?gu_gathered (INil))))
|
||||
|
||||
; ===== Gate-up batched matmul =====
|
||||
; t60: Mul x * gathered_gu (broadcast multiply)
|
||||
(= ?gu_matmul_mul (Op (Mul ?gu_matmul_mul_shape ?gu_matmul_a_stride ?gu_matmul_b_stride ?gu_matmul_mul_out_stride) (ICons ?x (ICons ?gu_f32 (INil)))))
|
||||
; t61: SumReduce over K dimension
|
||||
(= ?gu_matmul (Op (Sum ?gu_matmul_out_shape ?gu_matmul_k ?gu_matmul_in_stride ?gu_matmul_k_stride ?gu_matmul_out_stride) (ICons ?gu_matmul_mul (INil))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_gate_up ?gu_matmul)
|
||||
(MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_iota_within_range ?x ?topk_idx ?gate_up_w))
|
||||
)
|
||||
:name "GLUMoE gate-up matmul marker"
|
||||
)
|
||||
|
||||
; ===== SwiGLU activation marker =====
|
||||
(rule
|
||||
(
|
||||
(= ?gate_up_state (glumoe_gate_up ?gu_matmul))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
; ===== Up slice via Iota+Gather =====
|
||||
; t62: Iota with complex expression (slicing the "up" half)
|
||||
(= ?up_iota (Op (Iota ?up_iota_expr ?up_iota_range) (INil)))
|
||||
; t63: Gather to select up portion from matmul result
|
||||
(= ?up_slice (Op (Gather ?up_gather_idx_shape ?up_gather_idx_stride ?up_gather_data_shape ?up_gather_data_stride) (ICons ?up_iota (ICons ?gu_matmul (INil)))))
|
||||
|
||||
; ===== SwiGLU: silu(gate) * up =====
|
||||
; t64: Constant(-1)
|
||||
(= ?neg1 (Op (Constant -1.000000) (INil)))
|
||||
; t65: gate * -1
|
||||
(= ?neg_gate (Op (Mul ?silu_shape1 ?silu_a_stride1 ?silu_b_stride1 ?silu_out_stride1) (ICons ?gu_matmul (ICons ?neg1 (INil)))))
|
||||
; t66: Constant(log2e)
|
||||
(= ?log2e (Op (Constant 1.442695) (INil)))
|
||||
; t67: neg_gate * log2e
|
||||
(= ?scaled (Op (Mul ?silu_shape2 ?silu_a_stride2 ?silu_b_stride2 ?silu_out_stride2) (ICons ?neg_gate (ICons ?log2e (INil)))))
|
||||
; t68: exp2
|
||||
(= ?exp2_val (Op (Exp2 ?silu_shape3 ?silu_in_stride3 ?silu_out_stride3) (ICons ?scaled (INil))))
|
||||
; t69: Constant(1)
|
||||
(= ?one (Op (Constant 1.000000) (INil)))
|
||||
; t70: exp2 + 1
|
||||
(= ?plus1 (Op (Add ?silu_shape4 ?silu_a_stride4 ?silu_b_stride4 ?silu_out_stride4) (ICons ?exp2_val (ICons ?one (INil)))))
|
||||
; t71: recip
|
||||
(= ?sigmoid (Op (Recip ?silu_shape5 ?silu_in_stride5 ?silu_out_stride5) (ICons ?plus1 (INil))))
|
||||
; t72: gate * sigmoid(gate) = silu(gate)
|
||||
(= ?silu_out (Op (Mul ?silu_shape6 ?silu_a_stride6 ?silu_b_stride6 ?silu_out_stride6) (ICons ?gu_matmul (ICons ?sigmoid (INil)))))
|
||||
; t73: silu(gate) * up
|
||||
(= ?swiglu_out (Op (Mul ?swiglu_shape ?swiglu_a_stride ?swiglu_b_stride ?swiglu_out_stride) (ICons ?silu_out (ICons ?up_slice (INil)))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_swiglu ?swiglu_out) (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
)
|
||||
:name "GLUMoE swiglu marker"
|
||||
)
|
||||
|
||||
; ===== Gemma GELU activation marker =====
|
||||
(rule
|
||||
(
|
||||
(= ?gate_up_state (glumoe_gate_up ?gu_matmul))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
(= ?up_iota (Op (Iota ?up_iota_expr ?up_iota_range) (INil)))
|
||||
(= ?up_slice (Op (Gather ?up_gather_idx_shape ?up_gather_idx_stride ?up_gather_data_shape ?up_gather_data_stride) (ICons ?up_iota (ICons ?gu_matmul (INil)))))
|
||||
|
||||
(= ?gelu_coeff_inner (Op (Constant 0.044715) (INil)))
|
||||
(= ?gelu_inner_scaled (Op (Mul ?gelu_inner_scaled_shape ?gelu_inner_scaled_a_stride ?gelu_inner_scaled_b_stride ?gelu_inner_scaled_out_stride) (ICons ?gu_matmul (ICons ?gelu_coeff_inner (INil)))))
|
||||
(= ?gelu_inner_quad (Op (Mul ?gelu_inner_quad_shape ?gelu_inner_quad_a_stride ?gelu_inner_quad_b_stride ?gelu_inner_quad_out_stride) (ICons ?gelu_inner_scaled (ICons ?gu_matmul (INil)))))
|
||||
(= ?gelu_one (Op (Constant 1.000000) (INil)))
|
||||
(= ?gelu_poly (Op (Add ?gelu_poly_shape ?gelu_poly_a_stride ?gelu_poly_b_stride ?gelu_poly_out_stride) (ICons ?gelu_inner_quad (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_coeff_outer (Op (Constant 1.595769) (INil)))
|
||||
(= ?gelu_outer_scaled (Op (Mul ?gelu_outer_scaled_shape ?gelu_outer_scaled_a_stride ?gelu_outer_scaled_b_stride ?gelu_outer_scaled_out_stride) (ICons ?gu_matmul (ICons ?gelu_coeff_outer (INil)))))
|
||||
(= ?gelu_scaled (Op (Mul ?gelu_scaled_shape ?gelu_scaled_a_stride ?gelu_scaled_b_stride ?gelu_scaled_out_stride) (ICons ?gelu_outer_scaled (ICons ?gelu_poly (INil)))))
|
||||
(= ?neg1 (Op (Constant -1.000000) (INil)))
|
||||
(= ?gelu_neg (Op (Mul ?gelu_neg_shape ?gelu_neg_a_stride ?gelu_neg_b_stride ?gelu_neg_out_stride) (ICons ?gelu_scaled (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant 1.442695) (INil)))
|
||||
(= ?gelu_exp_scaled (Op (Mul ?gelu_exp_scaled_shape ?gelu_exp_scaled_a_stride ?gelu_exp_scaled_b_stride ?gelu_exp_scaled_out_stride) (ICons ?gelu_neg (ICons ?log2e (INil)))))
|
||||
(= ?gelu_exp2_val (Op (Exp2 ?gelu_exp_shape ?gelu_exp_in_stride ?gelu_exp_out_stride) (ICons ?gelu_exp_scaled (INil))))
|
||||
(= ?gelu_plus1 (Op (Add ?gelu_plus1_shape ?gelu_plus1_a_stride ?gelu_plus1_b_stride ?gelu_plus1_out_stride) (ICons ?gelu_exp2_val (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_sigmoid (Op (Recip ?gelu_sigmoid_shape ?gelu_sigmoid_in_stride ?gelu_sigmoid_out_stride) (ICons ?gelu_plus1 (INil))))
|
||||
(= ?gelu_out (Op (Mul ?gelu_out_shape ?gelu_out_a_stride ?gelu_out_b_stride ?gelu_out_out_stride) (ICons ?gu_matmul (ICons ?gelu_sigmoid (INil)))))
|
||||
(= ?gemma_out (Op (Mul ?geglu_shape ?geglu_a_stride ?geglu_b_stride ?geglu_out_stride) (ICons ?gelu_out (ICons ?up_slice (INil)))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_gemma_gelu ?gemma_out) (MkGLUMoEGemmaGELUState ?gate_up_state))
|
||||
)
|
||||
:name "GLUMoE gemma gelu marker"
|
||||
)
|
||||
|
||||
; ===== SwiGLU down marker =====
|
||||
(rule
|
||||
(
|
||||
(= ?swiglu_state (glumoe_swiglu ?swiglu_out))
|
||||
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
|
||||
; ===== Down expert gather =====
|
||||
; t74: Iota for base index (expert_idx * io_down)
|
||||
(= ?dn_iota_base (Op (Iota ?dn_io ?dn_iota_base_range) (INil)))
|
||||
; t75: Mul topk_indices * io_down
|
||||
(= ?dn_mul_base (Op (Mul ?dn_mul_base_shape ?dn_mul_base_a_stride ?dn_mul_base_b_stride ?dn_mul_base_out_stride) (ICons ?topk_idx (ICons ?dn_iota_base (INil)))))
|
||||
; t76: Cast to F32
|
||||
(= ?dn_cast_base (Op (Cast ?dn_cast_base_size (F32)) (ICons ?dn_mul_base (INil))))
|
||||
; t77: Iota for within-expert index
|
||||
(= ?dn_iota_within (Op (Iota (MIter) ?dn_iota_within_range) (INil)))
|
||||
; t78: Cast within to F32
|
||||
(= ?dn_cast_within (Op (Cast ?dn_cast_within_size (F32)) (ICons ?dn_iota_within (INil))))
|
||||
; t79: Add base + within
|
||||
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_cast_base (ICons ?dn_cast_within (INil)))))
|
||||
; t80: Cast to Int
|
||||
(= ?dn_cast_idx (Op (Cast ?dn_cast_idx_size (Int)) (ICons ?dn_add_idx (INil))))
|
||||
; t81: Gather down weights
|
||||
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_cast_idx (ICons ?down_w (INil)))))
|
||||
|
||||
; ===== Cast BF16→F32 =====
|
||||
; t82: Cast gathered down to F32
|
||||
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_mul_base (ICons ?dn_iota_within (INil)))))
|
||||
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_add_idx (ICons ?down_w (INil)))))
|
||||
(= ?dn_f32 (Op (Cast ?dn_f32_size (F32)) (ICons ?dn_gathered (INil))))
|
||||
|
||||
; ===== Down batched matmul =====
|
||||
; t83: Mul swiglu_out * gathered_down (broadcast multiply)
|
||||
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?swiglu_out (ICons ?dn_f32 (INil)))))
|
||||
; t84: SumReduce
|
||||
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_swiglu_down ?dn_matmul)
|
||||
(MkGLUMoESwiGLUDownState ?dn_io ?dn_matmul_k ?dn_iota_within_range ?swiglu_state ?topk_idx ?down_w))
|
||||
)
|
||||
:name "GLUMoE swiglu down marker"
|
||||
)
|
||||
|
||||
; ===== Gemma GELU down marker =====
|
||||
(rule
|
||||
(
|
||||
(= ?gemma_state (glumoe_gemma_gelu ?gemma_out))
|
||||
(= ?gemma_state (MkGLUMoEGemmaGELUState ?gate_up_state))
|
||||
|
||||
(= ?dn_iota_base (Op (Iota ?dn_io ?dn_iota_base_range) (INil)))
|
||||
(= ?dn_mul_base (Op (Mul ?dn_mul_base_shape ?dn_mul_base_a_stride ?dn_mul_base_b_stride ?dn_mul_base_out_stride) (ICons ?topk_idx (ICons ?dn_iota_base (INil)))))
|
||||
(= ?dn_iota_within (Op (Iota (MIter) ?dn_iota_within_range) (INil)))
|
||||
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_mul_base (ICons ?dn_iota_within (INil)))))
|
||||
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_add_idx (ICons ?down_w (INil)))))
|
||||
(= ?dn_f32 (Op (Cast ?dn_f32_size (F32)) (ICons ?dn_gathered (INil))))
|
||||
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?gemma_out (ICons ?dn_f32 (INil)))))
|
||||
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_gemma_down ?dn_matmul)
|
||||
(MkGLUMoEGemmaDownState ?dn_io ?dn_matmul_k ?dn_iota_within_range ?gemma_state ?topk_idx ?down_w))
|
||||
)
|
||||
:name "GLUMoE gemma down marker"
|
||||
)
|
||||
|
||||
; ===== Final fusion: mode 0 (SwiGLU) =====
|
||||
(rule
|
||||
(
|
||||
(= ?down_state (glumoe_swiglu_down ?dn_matmul))
|
||||
(= ?down_state (MkGLUMoESwiGLUDownState ?dn_io ?dn_matmul_k ?dn_within_range ?swiglu_state ?topk_idx ?down_w))
|
||||
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
; ===== Weighted sum over k experts =====
|
||||
; t85: Mul down_out * topk_values
|
||||
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?topk_vals (INil)))))
|
||||
; t86: SumReduce over k dimension → [s, H]
|
||||
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
)
|
||||
(
|
||||
(let ?glumoe (Op (GLUMoE
|
||||
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
|
||||
?gu_iota_within_range ?dn_iota_within_range)
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (INil))))))))
|
||||
?gu_within_range ?dn_within_range (MNum 0))
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?topk_vals (INil)))))))))
|
||||
(union ?output ?glumoe)
|
||||
)
|
||||
:name "GLUMoE fused expert computation"
|
||||
:name "GLUMoE fused expert computation (swiglu)"
|
||||
)
|
||||
|
||||
; ===== Final fusion: mode 1 (Gemma GELU) =====
|
||||
(rule
|
||||
(
|
||||
(= ?down_state (glumoe_gemma_down ?dn_matmul))
|
||||
(= ?down_state (MkGLUMoEGemmaDownState ?dn_io ?dn_matmul_k ?dn_within_range ?gemma_state ?topk_idx ?down_w))
|
||||
(= ?gemma_state (MkGLUMoEGemmaGELUState ?gate_up_state))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
; Gemma expert weights: topk_weights = normed_topk * per_expert_scale.gather(topk_idx)
|
||||
(= ?per_expert_vals (Op (Gather ?scale_gather_idx_shape ?scale_gather_idx_stride ?scale_gather_data_shape ?scale_gather_data_stride) (ICons ?topk_idx (ICons ?per_expert_scale (INil)))))
|
||||
(= ?topk_row_offsets (Op (Iota ?topk_row_offsets_expr ?topk_row_offsets_range) (INil)))
|
||||
(= ?topk_flat_idx (Op (Add ?topk_flat_idx_shape ?topk_flat_idx_a_stride ?topk_flat_idx_b_stride ?topk_flat_idx_out_stride) (ICons ?topk_row_offsets (ICons ?topk_idx (INil)))))
|
||||
(= ?topk_vals (Op (Gather ?topk_vals_gather_idx_shape ?topk_vals_gather_idx_stride ?topk_vals_gather_data_shape ?topk_vals_gather_data_stride) (ICons ?topk_flat_idx (ICons ?routing_weights (INil)))))
|
||||
(= ?topk_norm (Op (Sum ?topk_norm_shape ?output_k ?topk_norm_in_stride ?topk_norm_k_stride ?topk_norm_out_stride) (ICons ?topk_vals (INil))))
|
||||
(= ?topk_norm_factor (Op (Recip ?topk_norm_recip_shape ?topk_norm_recip_in_stride ?topk_norm_recip_out_stride) (ICons ?topk_norm (INil))))
|
||||
(= ?normed_topk (Op (Mul ?normed_topk_shape ?normed_topk_a_stride ?normed_topk_b_stride ?normed_topk_out_stride) (ICons ?topk_vals (ICons ?topk_norm_factor (INil)))))
|
||||
(= ?expert_weights (Op (Mul ?expert_weights_shape ?expert_weights_a_stride ?expert_weights_b_stride ?expert_weights_out_stride) (ICons ?normed_topk (ICons ?per_expert_vals (INil)))))
|
||||
|
||||
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?expert_weights (INil)))))
|
||||
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
)
|
||||
(
|
||||
(let ?glumoe (Op (GLUMoE
|
||||
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
|
||||
?gu_within_range ?dn_within_range (MNum 1))
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?per_expert_scale (INil)))))))))
|
||||
(union ?output ?glumoe)
|
||||
)
|
||||
:name "GLUMoE fused expert computation (gemma_gelu)"
|
||||
)
|
||||
|
||||
@@ -33,14 +33,15 @@ use crate::{
|
||||
},
|
||||
},
|
||||
host::HostOp,
|
||||
try_create_cublaslt,
|
||||
};
|
||||
|
||||
const WORKSPACE_SIZE: usize = 32 * 1024 * 1024; // 32 MiB
|
||||
|
||||
/// Fused GLU-MoE HostOp matched via egglog pattern.
|
||||
///
|
||||
/// Replaces the expert computation subgraph (expert gathers + matmuls + SwiGLU
|
||||
/// + weighted sum) with an efficient cuBLASLt implementation.
|
||||
/// Replaces the expert computation subgraph (expert gathers + matmuls + gated
|
||||
/// activation + weighted sum) with an efficient cuBLASLt implementation.
|
||||
///
|
||||
/// Inputs (graph edges, in order):
|
||||
/// 0: x [seq, hidden] F32
|
||||
@@ -48,9 +49,13 @@ const WORKSPACE_SIZE: usize = 32 * 1024 * 1024; // 32 MiB
|
||||
/// 2: topk_values [seq, k] F32
|
||||
/// 3: gate_up_w [E, gate_up_dim, hidden] BF16
|
||||
/// 4: down_w [E, hidden, intermediate] BF16
|
||||
/// 5: mode_aux
|
||||
/// - SwiGLU: ignored (rewriter wires `topk_values` again)
|
||||
/// - GemmaGELU: per_expert_scale [E] F32
|
||||
///
|
||||
/// Output: [seq, hidden] F32
|
||||
pub struct GLUMoE {
|
||||
pub(crate) mode: GLUMoEMode,
|
||||
/// Product of gate_up weight dimensions per expert (gate_up_dim * hidden) used for gather stride
|
||||
gu_io: Expression,
|
||||
/// Product of down weight dimensions per expert (hidden * intermediate) used for gather stride
|
||||
@@ -69,9 +74,35 @@ pub struct GLUMoE {
|
||||
module: OnceLock<(Arc<CudaModule>, CudaFunction, CudaFunction)>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) enum GLUMoEMode {
|
||||
SwiGLU,
|
||||
GemmaGELU,
|
||||
}
|
||||
|
||||
impl GLUMoEMode {
|
||||
fn from_mode_id(mode_id: usize) -> Self {
|
||||
match mode_id {
|
||||
0 => Self::SwiGLU,
|
||||
1 => Self::GemmaGELU,
|
||||
other => {
|
||||
panic!("Unknown GLUMoE mode id: {other}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn activation_kernel_mode(self) -> i32 {
|
||||
match self {
|
||||
Self::SwiGLU => 0,
|
||||
Self::GemmaGELU => 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for GLUMoE {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
mode: GLUMoEMode::SwiGLU,
|
||||
gu_io: Expression::default(),
|
||||
dn_io: Expression::default(),
|
||||
gu_matmul_k: Expression::default(),
|
||||
@@ -88,6 +119,7 @@ impl Default for GLUMoE {
|
||||
impl std::fmt::Debug for GLUMoE {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("GLUMoE")
|
||||
.field("mode", &self.mode)
|
||||
.field("gu_io", &self.gu_io)
|
||||
.field("dn_io", &self.dn_io)
|
||||
.field("gu_matmul_k", &self.gu_matmul_k)
|
||||
@@ -100,6 +132,7 @@ impl std::fmt::Debug for GLUMoE {
|
||||
impl Clone for GLUMoE {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
mode: self.mode,
|
||||
gu_io: self.gu_io,
|
||||
dn_io: self.dn_io,
|
||||
gu_matmul_k: self.gu_matmul_k,
|
||||
@@ -114,9 +147,15 @@ impl Clone for GLUMoE {
|
||||
}
|
||||
|
||||
impl GLUMoE {
|
||||
fn get_cublaslt(&self, stream: &Arc<CudaStream>) -> &Arc<CudaBlasLT> {
|
||||
self.cublaslt
|
||||
.get_or_init(|| Arc::new(CudaBlasLT::new(stream.clone()).unwrap()))
|
||||
fn get_cublaslt(&self, stream: &Arc<CudaStream>) -> anyhow::Result<Arc<CudaBlasLT>> {
|
||||
if let Some(cublaslt) = self.cublaslt.get() {
|
||||
return Ok(cublaslt.clone());
|
||||
}
|
||||
let created = try_create_cublaslt(stream.clone()).map_err(|message| {
|
||||
anyhow::anyhow!("cuBLASLt unavailable on this machine: {message}")
|
||||
})?;
|
||||
let _ = self.cublaslt.set(created.clone());
|
||||
Ok(created)
|
||||
}
|
||||
|
||||
fn get_kernels(
|
||||
@@ -134,23 +173,34 @@ extern "C" __global__ void f32_to_bf16(unsigned long long in_ptr, unsigned long
|
||||
if (i < n) out[i] = __float2bfloat16(in_[i]);
|
||||
}
|
||||
|
||||
extern "C" __global__ void swiglu_bf16(unsigned long long gate_up_ptr, unsigned long long out_ptr, int intermediate) {
|
||||
extern "C" __global__ void glu_activation_bf16(
|
||||
unsigned long long gate_up_ptr,
|
||||
unsigned long long out_ptr,
|
||||
int intermediate,
|
||||
int mode
|
||||
) {
|
||||
const __nv_bfloat16* gate_up = (const __nv_bfloat16*)gate_up_ptr;
|
||||
__nv_bfloat16* out = (__nv_bfloat16*)out_ptr;
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < intermediate) {
|
||||
float gate = __bfloat162float(gate_up[i]);
|
||||
float up = __bfloat162float(gate_up[i + intermediate]);
|
||||
float silu = gate / (1.0f + expf(-gate));
|
||||
out[i] = __float2bfloat16(silu * up);
|
||||
float activated;
|
||||
if (mode == 0) {
|
||||
activated = gate / (1.0f + expf(-gate));
|
||||
} else {
|
||||
float scaled = 1.5957691216f * gate * (1.0f + 0.044715f * gate * gate);
|
||||
activated = gate / (1.0f + expf(-scaled));
|
||||
}
|
||||
out[i] = __float2bfloat16(activated * up);
|
||||
}
|
||||
}
|
||||
"#;
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), src).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let f32_to_bf16 = module.load_function("f32_to_bf16").unwrap();
|
||||
let swiglu = module.load_function("swiglu_bf16").unwrap();
|
||||
(module, f32_to_bf16, swiglu)
|
||||
let activation = module.load_function("glu_activation_bf16").unwrap();
|
||||
(module, f32_to_bf16, activation)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -168,12 +218,27 @@ impl EgglogOp for GLUMoE {
|
||||
("output_k", EXPRESSION),
|
||||
("gu_within_range", EXPRESSION),
|
||||
("dn_within_range", EXPRESSION),
|
||||
("mode", EXPRESSION),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?e (Op (GLUMoE ?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k ?gu_within_range ?dn_within_range ?mode) ?inputs))
|
||||
)
|
||||
(
|
||||
(set (dtype ?e) (F32))
|
||||
)
|
||||
:ruleset dtype_prop
|
||||
)",
|
||||
)]
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
5
|
||||
6
|
||||
}
|
||||
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
@@ -195,8 +260,14 @@ impl EgglogOp for GLUMoE {
|
||||
let output_k = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
|
||||
let gu_within_range = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
|
||||
let dn_within_range = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
|
||||
let mode_expr = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
|
||||
let mode_id = mode_expr
|
||||
.to_usize()
|
||||
.unwrap_or_else(|| panic!("GLUMoE mode must be static, got expression: {mode_expr}"));
|
||||
let mode = GLUMoEMode::from_mode_id(mode_id);
|
||||
|
||||
let extracted = GLUMoE {
|
||||
mode,
|
||||
gu_io,
|
||||
dn_io,
|
||||
gu_matmul_k,
|
||||
@@ -209,7 +280,7 @@ impl EgglogOp for GLUMoE {
|
||||
};
|
||||
|
||||
let op = LLIROp::new::<dyn HostOp>(Box::new(extracted) as Box<dyn HostOp>);
|
||||
// Return the 5 IR inputs: x, topk_idx, topk_vals, gate_up_w, down_w
|
||||
// Return the 6 IR inputs: x, topk_idx, topk_values, gate_up_w, down_w, mode_aux
|
||||
(op, input_enodes)
|
||||
}
|
||||
|
||||
@@ -230,9 +301,9 @@ impl HostOp for GLUMoE {
|
||||
// Resolve dimensions
|
||||
let hidden = self.gu_matmul_k.exec(dyn_map).unwrap();
|
||||
let intermediate = self.dn_matmul_k.exec(dyn_map).unwrap();
|
||||
let top_k = self.output_k.exec(dyn_map).unwrap();
|
||||
let top_k_expected = self.output_k.exec(dyn_map).unwrap();
|
||||
let gate_up_dim = self.gu_io.exec(dyn_map).unwrap() / hidden; // gate_up_dim = gu_io / hidden
|
||||
let _num_experts = self.gu_within_range.exec(dyn_map).unwrap() / (gate_up_dim * hidden);
|
||||
let num_experts = self.gu_within_range.exec(dyn_map).unwrap() / (gate_up_dim * hidden);
|
||||
|
||||
// Derive seq from x buffer size: x is [seq, hidden] F32 → seq = len / (hidden * 4)
|
||||
let x_buf = buffers[&inputs[0]];
|
||||
@@ -243,6 +314,7 @@ impl HostOp for GLUMoE {
|
||||
let topk_vals_buf = buffers[&inputs[2]]; // [seq, k] F32
|
||||
let gate_up_buf = buffers[&inputs[3]]; // [E, gate_up_dim, hidden] BF16
|
||||
let down_buf = buffers[&inputs[4]]; // [E, hidden, intermediate] BF16
|
||||
let mode_aux_buf = buffers[&inputs[5]];
|
||||
let output_buf = buffers[&self_node]; // [seq, hidden] F32
|
||||
|
||||
// Get raw device pointer addresses
|
||||
@@ -251,14 +323,59 @@ impl HostOp for GLUMoE {
|
||||
let down_ptr = buf_ptr(down_buf, stream);
|
||||
let output_ptr = buf_ptr(output_buf, stream);
|
||||
|
||||
let cublaslt = self.get_cublaslt(stream);
|
||||
let (_, f32_to_bf16_fn, swiglu_fn) = self.get_kernels(stream);
|
||||
let cublaslt = self.get_cublaslt(stream)?;
|
||||
let (_, f32_to_bf16_fn, activation_fn) = self.get_kernels(stream);
|
||||
|
||||
// Read topk indices and values from GPU
|
||||
// Read top-k routing values from GPU
|
||||
let topk_idx_host: Vec<u8> = stream.clone_dtoh(topk_idx_buf)?;
|
||||
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host);
|
||||
let topk_vals_host: Vec<u8> = stream.clone_dtoh(topk_vals_buf)?;
|
||||
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host);
|
||||
let idx_k = topk_idx_i32
|
||||
.len()
|
||||
.checked_div(seq)
|
||||
.unwrap_or(top_k_expected);
|
||||
let val_k = topk_vals_f32
|
||||
.len()
|
||||
.checked_div(seq)
|
||||
.unwrap_or(top_k_expected);
|
||||
let top_k = idx_k.min(val_k);
|
||||
if seq > 0 && top_k == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Mode-dependent expert weights used for the final reduction:
|
||||
// - SwiGLU: direct topk values
|
||||
// - GemmaGELU: normalize topk values and scale by per-expert factors
|
||||
let mut expert_weights_storage: Vec<f32> = Vec::new();
|
||||
let expert_weights_f32: &[f32] = match self.mode {
|
||||
GLUMoEMode::SwiGLU => topk_vals_f32,
|
||||
GLUMoEMode::GemmaGELU => {
|
||||
let per_expert_scale_host: Vec<u8> = stream.clone_dtoh(mode_aux_buf)?;
|
||||
let per_expert_scale_f32: &[f32] = bytemuck::cast_slice(&per_expert_scale_host);
|
||||
debug_assert!(per_expert_scale_f32.len() >= num_experts);
|
||||
expert_weights_storage.resize(seq * top_k, 0.0);
|
||||
for t in 0..seq {
|
||||
let base = t * top_k;
|
||||
let vals = &topk_vals_f32[base..base + top_k];
|
||||
let norm = vals.iter().copied().sum::<f32>();
|
||||
let inv_norm = if norm != 0.0 { norm.recip() } else { 0.0 };
|
||||
for i in 0..top_k {
|
||||
let expert_idx = topk_idx_i32[base + i] as usize;
|
||||
if expert_idx >= per_expert_scale_f32.len() {
|
||||
anyhow::bail!(
|
||||
"GLUMoE Gemma mode expert index {} out of bounds {}",
|
||||
expert_idx,
|
||||
per_expert_scale_f32.len()
|
||||
);
|
||||
}
|
||||
let scale = per_expert_scale_f32[expert_idx];
|
||||
expert_weights_storage[base + i] = vals[i] * inv_norm * scale;
|
||||
}
|
||||
}
|
||||
&expert_weights_storage
|
||||
}
|
||||
};
|
||||
|
||||
// Allocate temp buffers
|
||||
let x_bf16_buf = unsafe { stream.alloc::<u8>(seq * hidden * 2)? }; // BF16
|
||||
@@ -291,22 +408,10 @@ impl HostOp for GLUMoE {
|
||||
let gu_stride = (gate_up_dim * hidden * 2) as u64; // bytes per expert gate_up (BF16)
|
||||
let down_stride = (hidden * intermediate * 2) as u64; // bytes per expert down (BF16)
|
||||
|
||||
// Normalize top-k values per token (norm_topk_prob=true)
|
||||
let mut normalized_vals = topk_vals_f32.to_vec();
|
||||
for t in 0..seq {
|
||||
let row = &mut normalized_vals[t * top_k..(t + 1) * top_k];
|
||||
let sum: f32 = row.iter().sum();
|
||||
if sum > 0.0 {
|
||||
for v in row.iter_mut() {
|
||||
*v /= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for t in 0..seq {
|
||||
let x_t_ptr = xbf16_ptr + (t * hidden * 2) as u64; // BF16
|
||||
let expert_indices = &topk_idx_i32[t * top_k..(t + 1) * top_k];
|
||||
let weights = &normalized_vals[t * top_k..(t + 1) * top_k];
|
||||
let weights = &expert_weights_f32[t * top_k..(t + 1) * top_k];
|
||||
|
||||
for (i, (&expert_idx, &weight)) in expert_indices.iter().zip(weights.iter()).enumerate()
|
||||
{
|
||||
@@ -316,7 +421,7 @@ impl HostOp for GLUMoE {
|
||||
let expert_gu_ptr = gate_up_ptr + expert_idx as u64 * gu_stride;
|
||||
cublas_matmul(
|
||||
stream,
|
||||
cublaslt,
|
||||
&cublaslt,
|
||||
ws_ptr,
|
||||
gate_up_dim as u64,
|
||||
1,
|
||||
@@ -335,17 +440,19 @@ impl HostOp for GLUMoE {
|
||||
0.0f32,
|
||||
)?;
|
||||
|
||||
// b. SwiGLU kernel (BF16 → BF16)
|
||||
// b. Mode-specific gated activation (BF16 → BF16)
|
||||
let moe_int = intermediate as i32;
|
||||
let swiglu_blocks = (moe_int as u32).div_ceil(256);
|
||||
let activation_mode = self.mode.activation_kernel_mode();
|
||||
let activation_blocks = (moe_int as u32).div_ceil(256);
|
||||
unsafe {
|
||||
stream
|
||||
.launch_builder(swiglu_fn)
|
||||
.launch_builder(activation_fn)
|
||||
.arg(&gu_out_ptr)
|
||||
.arg(&hid_ptr)
|
||||
.arg(&moe_int)
|
||||
.arg(&activation_mode)
|
||||
.launch(LaunchConfig {
|
||||
grid_dim: (swiglu_blocks, 1, 1),
|
||||
grid_dim: (activation_blocks, 1, 1),
|
||||
block_dim: (256, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
})?;
|
||||
@@ -358,7 +465,7 @@ impl HostOp for GLUMoE {
|
||||
let beta = if i == 0 { 0.0f32 } else { 1.0f32 };
|
||||
cublas_matmul_mixed(
|
||||
stream,
|
||||
cublaslt,
|
||||
&cublaslt,
|
||||
ws_ptr,
|
||||
hidden as u64,
|
||||
1,
|
||||
|
||||
451
crates/luminal_cuda_lite/src/kernel/fusion/fused_ops.rs
Normal file
451
crates/luminal_cuda_lite/src/kernel/fusion/fused_ops.rs
Normal file
@@ -0,0 +1,451 @@
|
||||
// =========================================================================
|
||||
// Fused elementwise op variants used inside FusionStart/FusionEnd regions.
|
||||
//
|
||||
// Each `FusedX` struct mirrors its un-fused `KernelX` sibling field-for-field
|
||||
// and serves a single purpose: give the egglog rules a distinct sort to
|
||||
// rewrite into so a pair-fuse rule's RHS can never re-match its own LHS
|
||||
// pattern. Cascade prevention by typing.
|
||||
//
|
||||
// `compile()` is a *fallback* path. The fast path collapses each FE-rooted
|
||||
// region into one CUDA kernel inside `region_codegen` and FusedX/FS/FE
|
||||
// never reach kernel_to_host's compile loop. But extraction can produce
|
||||
// LLIR shapes the detector doesn't sweep into a region, so each FusedX's
|
||||
// standalone `compile()` falls back to emitting the same kernel its
|
||||
// un-fused KernelX sibling would — correct, just one launch per op.
|
||||
// =========================================================================
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, OP_KIND},
|
||||
extract_dtype, extract_expr_list,
|
||||
},
|
||||
op::*,
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
compile_module_image_for_current_device, cuda_dtype,
|
||||
kernel::KernelOp,
|
||||
kernel::hlir::{dtype_includes, generate_dyn_dims_defines},
|
||||
};
|
||||
|
||||
pub type Ops = (
|
||||
FusedSin,
|
||||
FusedSqrt,
|
||||
FusedExp,
|
||||
FusedExp2,
|
||||
FusedLog2,
|
||||
FusedRecip,
|
||||
FusedAdd,
|
||||
FusedMul,
|
||||
);
|
||||
|
||||
// Standard `compile()` return tuple (matches the trait signature).
|
||||
type CompileOut = (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
);
|
||||
|
||||
// =========================================================================
|
||||
// Fallback kernel templates — used when a FusedX op reaches
|
||||
// `kernel_to_host` standalone (region detection missed it). Same CUDA as
|
||||
// the matching un-fused KernelX would emit, parameterised by the per-op
|
||||
// body expression. The fast path goes through `region_codegen`.
|
||||
// =========================================================================
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn compile_unary_fallback(
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
kernel_name: &str,
|
||||
body_expr: &str, // CUDA expression on `in[{in_idx}]`, e.g. "sinf(in[{in_idx}])"
|
||||
shape: &[Expression],
|
||||
in_strides: &[Expression],
|
||||
out_strides: &[Expression],
|
||||
dtype: DType,
|
||||
) -> CompileOut {
|
||||
let vars = shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(in_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(out_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect::<FxHashSet<_>>();
|
||||
let cuda_ty = cuda_dtype(dtype);
|
||||
let includes = dtype_includes(&[dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let n_elements = shape.iter().copied().product::<Expression>().to_kernel();
|
||||
let out_idx = flatten_strides(shape, out_strides).to_kernel();
|
||||
let in_idx = flatten_strides(shape, in_strides).to_kernel();
|
||||
let body = body_expr.replace("{in_idx}", &in_idx);
|
||||
let kernel = format!(
|
||||
"{includes}\n{dyn_defines}\nextern \"C\" {{\n\
|
||||
\x20 __global__ void {kernel_name}({cuda_ty} *out, const {cuda_ty} *in{dyn_dims_param}) {{\n\
|
||||
\x20 long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;\n\
|
||||
\x20 if (const_z >= {n_elements}) return;\n\
|
||||
\x20 out[{out_idx}] = {body};\n\
|
||||
\x20 }}\n}}"
|
||||
);
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function(kernel_name).unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let out_size = shape.iter().copied().product::<Expression>();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn compile_binary_fallback(
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
kernel_name: &str,
|
||||
op_str: &str, // CUDA infix operator, e.g. "+", "*"
|
||||
out_shape: &[Expression],
|
||||
a_stride: &[Expression],
|
||||
b_stride: &[Expression],
|
||||
out_stride: &[Expression],
|
||||
dtype: DType,
|
||||
) -> CompileOut {
|
||||
let vars = out_shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(a_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(b_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(out_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect::<FxHashSet<_>>();
|
||||
let cuda_ty = cuda_dtype(dtype);
|
||||
let includes = dtype_includes(&[dtype, dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let n_elements = out_shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.to_kernel();
|
||||
let out_idx = flatten_strides(out_shape, out_stride).to_kernel();
|
||||
let a_idx = flatten_strides(out_shape, a_stride).to_kernel();
|
||||
let b_idx = flatten_strides(out_shape, b_stride).to_kernel();
|
||||
let kernel = format!(
|
||||
"{includes}\n{dyn_defines}\nextern \"C\" {{\n\
|
||||
\x20 __global__ void {kernel_name}({cuda_ty} *C, const {cuda_ty} *A, const {cuda_ty} *B{dyn_dims_param}) {{\n\
|
||||
\x20 long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;\n\
|
||||
\x20 if (const_z >= {n_elements}) return;\n\
|
||||
\x20 C[{out_idx}] = A[{a_idx}] {op_str} B[{b_idx}];\n\
|
||||
\x20 }}\n}}"
|
||||
);
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function(kernel_name).unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let out_size = out_shape.iter().copied().product::<Expression>();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Generate `pub struct $Name { … unary fields … }` plus its `EgglogOp` and
|
||||
/// `KernelOp` impls. `$kernel_name` names the CUDA function (and the cache
|
||||
/// key); `$body` is the per-op CUDA expression, e.g. `"sinf(in[{in_idx}])"`.
|
||||
macro_rules! impl_fused_unary {
|
||||
($Name:ident, $sort:literal, $kernel_name:literal, $body:literal) => {
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct $Name {
|
||||
pub(crate) shape: Vec<Expression>,
|
||||
pub(crate) in_strides: Vec<Expression>,
|
||||
pub(crate) out_strides: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for $Name {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
$sort,
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
Vec::new()
|
||||
}
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
in_strides: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[1],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
out_strides: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[2],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[3]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for $Name {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
compile_unary_fallback(
|
||||
stream,
|
||||
compile_cache,
|
||||
$kernel_name,
|
||||
$body,
|
||||
&self.shape,
|
||||
&self.in_strides,
|
||||
&self.out_strides,
|
||||
self.dtype,
|
||||
)
|
||||
}
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
fn flops(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
$sort
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// As `impl_fused_unary!` but for binary ops: 5-field sort signature
|
||||
/// (shape + per-input strides + out_stride + dtype), n_inputs = 2.
|
||||
/// `$op_str` is the CUDA infix operator, e.g. `"+"`, `"*"`.
|
||||
macro_rules! impl_fused_binary {
|
||||
($Name:ident, $sort:literal, $kernel_name:literal, $op_str:literal) => {
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct $Name {
|
||||
pub(crate) out_shape: Vec<Expression>,
|
||||
pub(crate) a_stride: Vec<Expression>,
|
||||
pub(crate) b_stride: Vec<Expression>,
|
||||
pub(crate) out_stride: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for $Name {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
$sort,
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("a_strides", ELIST),
|
||||
("b_strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
Vec::new()
|
||||
}
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[0],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
a_stride: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[1],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
b_stride: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[2],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
out_stride: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[3],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[4]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for $Name {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
compile_binary_fallback(
|
||||
stream,
|
||||
compile_cache,
|
||||
$kernel_name,
|
||||
$op_str,
|
||||
&self.out_shape,
|
||||
&self.a_stride,
|
||||
&self.b_stride,
|
||||
&self.out_stride,
|
||||
self.dtype,
|
||||
)
|
||||
}
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
let bytes = (self.output_size() * self.dtype.bits()).ceil_div(8);
|
||||
bytes + bytes
|
||||
}
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
fn flops(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
$sort
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_fused_unary!(FusedSin, "FusedSin", "fused_sin_k", "sinf(in[{in_idx}])");
|
||||
impl_fused_unary!(
|
||||
FusedSqrt,
|
||||
"FusedSqrt",
|
||||
"fused_sqrt_k",
|
||||
"sqrtf(in[{in_idx}])"
|
||||
);
|
||||
impl_fused_unary!(FusedExp, "FusedExp", "fused_exp_k", "expf(in[{in_idx}])");
|
||||
impl_fused_unary!(
|
||||
FusedExp2,
|
||||
"FusedExp2",
|
||||
"fused_exp2_k",
|
||||
"exp2f(in[{in_idx}])"
|
||||
);
|
||||
impl_fused_unary!(
|
||||
FusedLog2,
|
||||
"FusedLog2",
|
||||
"fused_log2_k",
|
||||
"log2f(in[{in_idx}])"
|
||||
);
|
||||
impl_fused_unary!(
|
||||
FusedRecip,
|
||||
"FusedRecip",
|
||||
"fused_recip_k",
|
||||
"1.0f / in[{in_idx}]"
|
||||
);
|
||||
|
||||
impl_fused_binary!(FusedAdd, "FusedAdd", "fused_add_k", "+");
|
||||
impl_fused_binary!(FusedMul, "FusedMul", "fused_mul_k", "*");
|
||||
490
crates/luminal_cuda_lite/src/kernel/fusion/markers.rs
Normal file
490
crates/luminal_cuda_lite/src/kernel/fusion/markers.rs
Normal file
@@ -0,0 +1,490 @@
|
||||
// =========================================================================
|
||||
// Fusion boundary markers — FusionStart and FusionEnd.
|
||||
//
|
||||
// Tag-like LLIR ops that bracket a region of elementwise ops destined to
|
||||
// be emitted as a single CUDA kernel:
|
||||
// - N FusionStart nodes per region (one per FS leaf — distinct external
|
||||
// reads),
|
||||
// - exactly 1 FusionEnd per region.
|
||||
//
|
||||
// `FusionEnd::rewrites()` carries the seven rule families that build and
|
||||
// extend regions (pair-fuse / grow / merge); the actual single-kernel
|
||||
// codegen lives in `region_codegen`. Like FusedX, both markers'
|
||||
// `compile()` is `unreachable!()` — region codegen folds them away
|
||||
// before kernel_to_host's compile loop reaches an interior node.
|
||||
// =========================================================================
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, OP_KIND},
|
||||
extract_dtype, extract_expr_list,
|
||||
},
|
||||
op::*,
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
compile_module_image_for_current_device, cuda_dtype,
|
||||
kernel::KernelOp,
|
||||
kernel::hlir::{dtype_includes, generate_dyn_dims_defines},
|
||||
};
|
||||
|
||||
/// Identity-memcpy kernel used as a *fallback* when a FusionStart or
|
||||
/// FusionEnd reaches `kernel_to_host`'s compile loop standalone (i.e.,
|
||||
/// region detection didn't sweep it into a `CompileUnit::Region`). The
|
||||
/// fast path is region collapse, but model-fuzz extraction sometimes
|
||||
/// produces LLIR shapes the detector doesn't catch; this keeps
|
||||
/// execution correct in those cases.
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn compile_identity_kernel(
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
kernel_name: &str,
|
||||
shape: &[Expression],
|
||||
strides: &[Expression],
|
||||
dtype: DType,
|
||||
) -> CompileOut {
|
||||
let vars = shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect::<FxHashSet<_>>();
|
||||
let cuda_ty = cuda_dtype(dtype);
|
||||
let includes = dtype_includes(&[dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let n_elements = shape.iter().copied().product::<Expression>().to_kernel();
|
||||
let idx = flatten_strides(shape, strides).to_kernel();
|
||||
let kernel = format!(
|
||||
"{includes}\n{dyn_defines}\nextern \"C\" {{\n\
|
||||
\x20 __global__ void {kernel_name}({cuda_ty} *out, const {cuda_ty} *in{dyn_dims_param}) {{\n\
|
||||
\x20 long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;\n\
|
||||
\x20 if (const_z >= {n_elements}) return;\n\
|
||||
\x20 out[{idx}] = in[{idx}];\n\
|
||||
\x20 }}\n}}"
|
||||
);
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function(kernel_name).unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let out_size = shape.iter().copied().product::<Expression>();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
pub type Ops = (FusionStart, FusionEnd);
|
||||
|
||||
type CompileOut = (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
);
|
||||
|
||||
// =========================================================================
|
||||
// FusionStart
|
||||
// =========================================================================
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct FusionStart {
|
||||
pub(crate) shape: Vec<Expression>,
|
||||
pub(crate) strides: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for FusionStart {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"FusionStart",
|
||||
&[("shape", ELIST), ("strides", ELIST), ("dtype", DTYPE)],
|
||||
)
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// No idempotence rule. `FusionStart(FusionStart(x)) ≡ FusionStart(x)`
|
||||
// would unify nested markers and create eclass cycles via the
|
||||
// pair-fuse rules; without it, occasional re-firings produce extra
|
||||
// semantically-correct identity layers, bounded by the run schedule.
|
||||
Vec::new()
|
||||
}
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap(),
|
||||
strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[2]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for FusionStart {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
compile_identity_kernel(
|
||||
stream,
|
||||
compile_cache,
|
||||
"fusion_start_k",
|
||||
&self.shape,
|
||||
&self.strides,
|
||||
self.dtype,
|
||||
)
|
||||
}
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"FusionStart"
|
||||
}
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// FusionEnd
|
||||
// =========================================================================
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct FusionEnd {
|
||||
pub(crate) shape: Vec<Expression>,
|
||||
pub(crate) strides: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for FusionEnd {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"FusionEnd",
|
||||
&[("shape", ELIST), ("strides", ELIST), ("dtype", DTYPE)],
|
||||
)
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// Ablation switch: with `LUMINAL_DISABLE_BINARY_FUSION=1` set, do
|
||||
// not register any fusion rules. The e-graph never sees the FS/FE
|
||||
// bracketed alternative, extraction always picks the un-fused
|
||||
// form, and the runtime path matches main with no fusion at all.
|
||||
// Used to A/B fusion's runtime impact on a single binary.
|
||||
if std::env::var("LUMINAL_DISABLE_BINARY_FUSION").is_ok() {
|
||||
return Vec::new();
|
||||
}
|
||||
// Seven rule families build and extend FE-bracketed regions. Each
|
||||
// pair-fuse rule's LHS pattern matches *un-fused* `KernelX` ops; the
|
||||
// RHS produces `FusedX` variants in a different egglog sort, so the
|
||||
// rule's own output cannot re-match its LHS — cascade is prevented
|
||||
// by typing rather than by a discriminator field.
|
||||
//
|
||||
// Stride compatibility is expressed by reusing variable names: a
|
||||
// unary inside a region matches `(KernelU ?shape ?s ?s ?dt)` (in =
|
||||
// out, no transpose); a binary feeding a downstream op binds the
|
||||
// binary's out-stride to the downstream op's in-stride along the
|
||||
// connecting side.
|
||||
let mut rules = Vec::new();
|
||||
|
||||
// (KernelX kind, FusedX kind)
|
||||
let unaries: &[(&str, &str)] = &[
|
||||
("KernelSin", "FusedSin"),
|
||||
("KernelSqrt", "FusedSqrt"),
|
||||
("KernelExp", "FusedExp"),
|
||||
("KernelExp2", "FusedExp2"),
|
||||
("KernelLog2", "FusedLog2"),
|
||||
("KernelRecip", "FusedRecip"),
|
||||
];
|
||||
// (KernelX kind, FusedX kind, rule-name label)
|
||||
let binaries: &[(&str, &str, &str)] = &[
|
||||
("KernelAdd", "FusedAdd", "Add"),
|
||||
("KernelMul", "FusedMul", "Mul"),
|
||||
];
|
||||
|
||||
// 1. Pair-fuse U → U: U2(U1(x)) → FE(FU2(FU1(FS(x)))).
|
||||
for (ki1, fi1) in unaries {
|
||||
for (ko2, fo2) in unaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?u1 (Op ({ki1} ?shape ?s ?s ?dt) (ICons ?x (INil))))
|
||||
(= ?u2 (Op ({ko2} ?shape ?s ?s ?dt) (ICons ?u1 (INil))))
|
||||
) (
|
||||
(let ?fs (Op (FusionStart ?shape ?s ?dt) (ICons ?x (INil))))
|
||||
(let ?fu1 (Op ({fi1} ?shape ?s ?s ?dt) (ICons ?fs (INil))))
|
||||
(let ?fu2 (Op ({fo2} ?shape ?s ?s ?dt) (ICons ?fu1 (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?fu2 (INil))))
|
||||
(union ?u2 ?fe)
|
||||
) :name \"pair-fuse-U-U-{ki1}-{ko2}\")"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Pair-fuse B → U: U(B(a, b)) → FE(FU(FB(FS(a), FS(b)))).
|
||||
for (kb, fb, lb) in binaries {
|
||||
for (ku, fu) in unaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?bin (Op ({kb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?u (Op ({ku} ?shape ?o_s ?o_s ?dt) (ICons ?bin (INil))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fbin (Op ({fb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(let ?fu (Op ({fu} ?shape ?o_s ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fu (INil))))
|
||||
(union ?u ?fe)
|
||||
) :name \"pair-fuse-B-U-{lb}-{ku}\")"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Pair-fuse U → B (lhs / rhs): unary feeds binary's A or B input.
|
||||
// LHS: B(U(a), b) → FE(FB(FU(FS(a)), FS(b))).
|
||||
// RHS: B(a, U(b)) → FE(FB(FS(a), FU(FS(b)))).
|
||||
for (ku, fu) in unaries {
|
||||
for (kb, fb, lb) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?u (Op ({ku} ?shape ?u_s ?u_s ?dt) (ICons ?a (INil))))
|
||||
(= ?bin (Op ({kb} ?shape ?u_s ?b_s ?o_s ?dt)
|
||||
(ICons ?u (ICons ?b (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?u_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fu (Op ({fu} ?shape ?u_s ?u_s ?dt) (ICons ?fs_a (INil))))
|
||||
(let ?fbin (Op ({fb} ?shape ?u_s ?b_s ?o_s ?dt)
|
||||
(ICons ?fu (ICons ?fs_b (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?fe)
|
||||
) :name \"pair-fuse-U-B-lhs-{ku}-{lb}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?u (Op ({ku} ?shape ?u_s ?u_s ?dt) (ICons ?b (INil))))
|
||||
(= ?bin (Op ({kb} ?shape ?a_s ?u_s ?o_s ?dt)
|
||||
(ICons ?a (ICons ?u (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?u_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fu (Op ({fu} ?shape ?u_s ?u_s ?dt) (ICons ?fs_b (INil))))
|
||||
(let ?fbin (Op ({fb} ?shape ?a_s ?u_s ?o_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fu (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?fe)
|
||||
) :name \"pair-fuse-U-B-rhs-{ku}-{lb}\")"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Pair-fuse B → B (lhs / rhs): inner binary feeds outer's A or B.
|
||||
for (kbi, fbi, lbi) in binaries {
|
||||
for (kbo, fbo, lbo) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?bi (Op ({kbi} ?shape ?ai_s ?bi_s ?oi_s ?dt)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?bo (Op ({kbo} ?shape ?oi_s ?co_s ?oo_s ?dt)
|
||||
(ICons ?bi (ICons ?c (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?ai_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?bi_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fs_c (Op (FusionStart ?shape ?co_s ?dt) (ICons ?c (INil))))
|
||||
(let ?fbi (Op ({fbi} ?shape ?ai_s ?bi_s ?oi_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(let ?fbo (Op ({fbo} ?shape ?oi_s ?co_s ?oo_s ?dt)
|
||||
(ICons ?fbi (ICons ?fs_c (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?oo_s ?dt) (ICons ?fbo (INil))))
|
||||
(union ?bo ?fe)
|
||||
) :name \"pair-fuse-B-B-lhs-{lbi}-{lbo}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?bi (Op ({kbi} ?shape ?ai_s ?bi_s ?oi_s ?dt)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?bo (Op ({kbo} ?shape ?co_s ?oi_s ?oo_s ?dt)
|
||||
(ICons ?c (ICons ?bi (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?ai_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?bi_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fs_c (Op (FusionStart ?shape ?co_s ?dt) (ICons ?c (INil))))
|
||||
(let ?fbi (Op ({fbi} ?shape ?ai_s ?bi_s ?oi_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(let ?fbo (Op ({fbo} ?shape ?co_s ?oi_s ?oo_s ?dt)
|
||||
(ICons ?fs_c (ICons ?fbi (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?oo_s ?dt) (ICons ?fbo (INil))))
|
||||
(union ?bo ?fe)
|
||||
) :name \"pair-fuse-B-B-rhs-{lbi}-{lbo}\")"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Grow FE → U: U(FE(inner)) → FE(FU(inner)). No new FS.
|
||||
for (ku, fu) in unaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?inner (INil))))
|
||||
(= ?u (Op ({ku} ?shape ?s ?s ?dt) (ICons ?fe (INil))))
|
||||
) (
|
||||
(let ?fu (Op ({fu} ?shape ?s ?s ?dt) (ICons ?inner (INil))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?fu (INil))))
|
||||
(union ?u ?new_fe)
|
||||
) :name \"grow-FE-U-{ku}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
// 6. Grow FE → B (lhs / rhs): one input is the FE, the other external.
|
||||
for (kb, fb, lb) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?shape ?a_s ?dt) (ICons ?inner_a (INil))))
|
||||
(= ?bin (Op ({kb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?fe (ICons ?b (INil)))))
|
||||
) (
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fbin (Op ({fb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?inner_a (ICons ?fs_b (INil)))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?new_fe)
|
||||
) :name \"grow-FE-B-lhs-{lb}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?shape ?b_s ?dt) (ICons ?inner_b (INil))))
|
||||
(= ?bin (Op ({kb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?a (ICons ?fe (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fbin (Op ({fb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?fs_a (ICons ?inner_b (INil)))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?new_fe)
|
||||
) :name \"grow-FE-B-rhs-{lb}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
// 7. Merge two FEs at a binary: B(FE(ia), FE(ib)) → FE(FB(ia, ib)).
|
||||
// Both inners reused, no new FS — shared external tensors with
|
||||
// upstream FSes stay at one FS.
|
||||
for (kb, fb, lb) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe_a (Op (FusionEnd ?shape ?a_s ?dt) (ICons ?inner_a (INil))))
|
||||
(= ?fe_b (Op (FusionEnd ?shape ?b_s ?dt) (ICons ?inner_b (INil))))
|
||||
(= ?bin (Op ({kb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?fe_a (ICons ?fe_b (INil)))))
|
||||
) (
|
||||
(let ?fbin (Op ({fb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?inner_a (ICons ?inner_b (INil)))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?new_fe)
|
||||
) :name \"merge-FE-FE-{lb}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
// No dissolve rule (`FS(FE(x)) → x`): unioning FS's eclass with FE's
|
||||
// inner eclass creates self-referential eclasses after grow rules
|
||||
// extend the downstream region, and extraction then panics with
|
||||
// `Cycle(NodeIndex(_))`. Grow rules already compose adjacent regions
|
||||
// correctly without dissolve.
|
||||
|
||||
rules
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap(),
|
||||
strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[2]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for FusionEnd {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
compile_identity_kernel(
|
||||
stream,
|
||||
compile_cache,
|
||||
"fusion_end_k",
|
||||
&self.shape,
|
||||
&self.strides,
|
||||
self.dtype,
|
||||
)
|
||||
}
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"FusionEnd"
|
||||
}
|
||||
}
|
||||
26
crates/luminal_cuda_lite/src/kernel/fusion/mod.rs
Normal file
26
crates/luminal_cuda_lite/src/kernel/fusion/mod.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
//! Binary-inclusive elementwise kernel fusion.
|
||||
//!
|
||||
//! - `markers` — `FusionStart` / `FusionEnd` ops + the seven egglog rule
|
||||
//! families that build and extend FE-bracketed regions.
|
||||
//! - `fused_ops` — eight `FusedX` op variants (interior to a region) so
|
||||
//! pair-fuse rules' RHS sit in a different egglog sort than their LHS,
|
||||
//! blocking cascade by typing.
|
||||
//! - `region_codegen` — `kernel_to_host` calls into here to collapse each
|
||||
//! FE-rooted region into a single CUDA kernel at compile time.
|
||||
//!
|
||||
//! The LLIR keeps `FusionStart` / `FusedX` / `FusionEnd` nodes after
|
||||
//! extraction; `region_codegen` is the only place that walks them.
|
||||
|
||||
pub mod fused_ops;
|
||||
pub mod markers;
|
||||
pub mod region_codegen;
|
||||
|
||||
pub use fused_ops::{
|
||||
FusedAdd, FusedExp, FusedExp2, FusedLog2, FusedMul, FusedRecip, FusedSin, FusedSqrt,
|
||||
};
|
||||
pub use markers::{FusionEnd, FusionStart};
|
||||
|
||||
/// All fusion-related op types that the egglog runtime needs to know about
|
||||
/// (markers + interior FusedX variants). Combined into a flat tuple for the
|
||||
/// `Ops` registry in `kernel::mod`.
|
||||
pub type Ops = (markers::Ops, fused_ops::Ops);
|
||||
479
crates/luminal_cuda_lite/src/kernel/fusion/region_codegen.rs
Normal file
479
crates/luminal_cuda_lite/src/kernel/fusion/region_codegen.rs
Normal file
@@ -0,0 +1,479 @@
|
||||
// =========================================================================
|
||||
// Region codegen for FusionStart / FusionEnd-bracketed fused regions.
|
||||
//
|
||||
// PR1 left FusedX / FusionStart / FusionEnd nodes in the post-extraction
|
||||
// LLIR, each compiling to its own standalone CUDA kernel. PR2 collapses
|
||||
// every FusionEnd-rooted region into ONE fused CUDA kernel at codegen
|
||||
// time — without rewriting the LLIR.
|
||||
//
|
||||
// Pipeline:
|
||||
// `kernel_to_host` builds a Vec<CompileUnit> from the topo order:
|
||||
// - CompileUnit::Single(node) — un-fused KernelX, compiled as before.
|
||||
// - CompileUnit::Region(rgn) — one FE + its interior FusedX DAG +
|
||||
// its FS leaves. Compiled here as a
|
||||
// single CUDA kernel that reads from
|
||||
// the region's external inputs once,
|
||||
// chains all FusedX bodies through
|
||||
// register-resident locals, and writes
|
||||
// the FE's output.
|
||||
//
|
||||
// The CompiledKernel for a Region is keyed on the FE node and stores
|
||||
// `inputs = external producer NodeIndices` (one per interior FusionStart),
|
||||
// so the existing buffer-pointer wiring in to_host.rs picks up the right
|
||||
// device pointers at execute time. Interior FusedX / FusionStart nodes
|
||||
// never enter the kernels Vec — they have no buffers, no launches.
|
||||
// =========================================================================
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
graph::LLIRGraph,
|
||||
prelude::{
|
||||
petgraph::{Direction, algo::toposort, visit::EdgeRef},
|
||||
*,
|
||||
},
|
||||
};
|
||||
|
||||
use as_any::Downcast;
|
||||
|
||||
use crate::{
|
||||
compile_module_image_for_current_device, cuda_dtype,
|
||||
kernel::KernelOp,
|
||||
kernel::fusion::markers::{FusionEnd, FusionStart},
|
||||
kernel::hlir::{dtype_includes, generate_dyn_dims_defines},
|
||||
};
|
||||
|
||||
// =========================================================================
|
||||
// Compile units — what `kernel_to_host` iterates over instead of nodes.
|
||||
// =========================================================================
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct RegionUnit {
|
||||
/// The FusionEnd node that anchors this region.
|
||||
pub fe_node: NodeIndex,
|
||||
/// Interior FusedX nodes, in topological order (predecessors before
|
||||
/// consumers). Used to emit register-binding statements in dependency
|
||||
/// order in the fused CUDA kernel body.
|
||||
pub fusedx_topo: Vec<NodeIndex>,
|
||||
/// FusionStart nodes that bound the region's leaves. One per external
|
||||
/// read site — duplicates (different FS LLIR nodes wrapping the same
|
||||
/// upstream tensor) are kept separate so each read uses its own
|
||||
/// strides; the host launch passes the same device pointer twice.
|
||||
pub fs_nodes: Vec<NodeIndex>,
|
||||
/// External producer NodeIndices, one per `fs_nodes` entry in the same
|
||||
/// order. Becomes the `inputs` field of the FE's `CompiledKernel`, and
|
||||
/// the kernel function's `in0`, `in1`, ... parameters in that order.
|
||||
pub external_inputs: Vec<NodeIndex>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) enum CompileUnit {
|
||||
Single(NodeIndex),
|
||||
Region(RegionUnit),
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Region detection.
|
||||
// =========================================================================
|
||||
|
||||
/// Group a sub-DAG's topo order into compile units. Each FusionEnd node
|
||||
/// becomes the root of a `CompileUnit::Region`; the region's interior
|
||||
/// FusedX and FusionStart nodes are absorbed into that region and removed
|
||||
/// from the per-node iteration. Anything else is wrapped in
|
||||
/// `CompileUnit::Single`.
|
||||
/// Globally-absorbed FS / FE markers — the set of marker nodes that any
|
||||
/// `FusionEnd` in the LLIR walks back to during region detection. A
|
||||
/// marker is "absorbed" iff some FE in the LLIR can reach it by walking
|
||||
/// incoming edges through `FusionEnd` / `FusedX` nodes, stopping at
|
||||
/// `FusionStart` leaves.
|
||||
///
|
||||
/// This is computed once over the full LLIR rather than per-convex-
|
||||
/// subgraph, because `partition_marked_convex` may put a shared FS leaf
|
||||
/// (one whose e-graph congruence-deduplicated it across multiple
|
||||
/// regions) into a different subgraph than the FE that absorbs it.
|
||||
/// Without this global view, `build_compile_units` running on the FS's
|
||||
/// subgraph would not see any FE walking back to the FS, would emit the
|
||||
/// FS as `CompileUnit::Single`, and the markers' identity-memcpy
|
||||
/// fallback would compile and launch — pure overhead at runtime.
|
||||
pub(crate) fn globally_absorbed_markers(llir_graph: &LLIRGraph) -> FxHashSet<NodeIndex> {
|
||||
let name_of = |idx: NodeIndex| -> Option<&'static str> {
|
||||
llir_graph
|
||||
.node_weight(idx)
|
||||
.and_then(|op| op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name()))
|
||||
};
|
||||
|
||||
let mut absorbed: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
for fe in llir_graph.node_indices() {
|
||||
if name_of(fe) != Some("FusionEnd") {
|
||||
continue;
|
||||
}
|
||||
let mut visited: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
let mut stack: Vec<NodeIndex> = vec![fe];
|
||||
visited.insert(fe);
|
||||
while let Some(cur) = stack.pop() {
|
||||
for pred in llir_graph.neighbors_directed(cur, Direction::Incoming) {
|
||||
if !visited.insert(pred) {
|
||||
continue;
|
||||
}
|
||||
match name_of(pred) {
|
||||
Some("FusionStart") => {
|
||||
absorbed.insert(pred);
|
||||
}
|
||||
Some("FusionEnd") => {
|
||||
absorbed.insert(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
Some(other) if other.starts_with("Fused") => {
|
||||
absorbed.insert(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
absorbed
|
||||
}
|
||||
|
||||
pub(crate) fn build_compile_units(
|
||||
topo_order: &[NodeIndex],
|
||||
llir_graph: &LLIRGraph,
|
||||
globally_absorbed: &FxHashSet<NodeIndex>,
|
||||
) -> Vec<CompileUnit> {
|
||||
let name_of = |idx: NodeIndex| -> Option<&'static str> {
|
||||
llir_graph
|
||||
.node_weight(idx)
|
||||
.and_then(|op| op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name()))
|
||||
};
|
||||
|
||||
// First pass: every FusionEnd in the subgraph anchors a region; gather
|
||||
// the region's interior + FS leaves by walking incoming edges
|
||||
// backward, stopping at FusionStart (a leaf — its predecessor is the
|
||||
// external producer, outside the region).
|
||||
let mut absorbed: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
let mut regions: FxHashMap<NodeIndex, RegionUnit> = FxHashMap::default();
|
||||
|
||||
for &node in topo_order {
|
||||
if name_of(node) != Some("FusionEnd") {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut interior: Vec<NodeIndex> = Vec::new();
|
||||
let mut fs_nodes: Vec<NodeIndex> = Vec::new();
|
||||
let mut visited: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
let mut stack: Vec<NodeIndex> = Vec::new();
|
||||
stack.push(node);
|
||||
visited.insert(node);
|
||||
|
||||
while let Some(cur) = stack.pop() {
|
||||
for pred in llir_graph.neighbors_directed(cur, Direction::Incoming) {
|
||||
if !visited.insert(pred) {
|
||||
continue;
|
||||
}
|
||||
match name_of(pred) {
|
||||
Some("FusionStart") => {
|
||||
fs_nodes.push(pred);
|
||||
// Don't recurse past FS — its predecessor is
|
||||
// external (outside the region).
|
||||
}
|
||||
Some("FusionEnd") => {
|
||||
// A nested FE inside a region. Under the current
|
||||
// rule design these are cascade artifacts — treat
|
||||
// them as transparent (walk through) rather than
|
||||
// as a separate region. The outer region absorbs
|
||||
// them. They do not become CompileUnit::Region
|
||||
// anchors because their eclass is already the
|
||||
// outer region's.
|
||||
absorbed.insert(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
Some(other) if other.starts_with("Fused") => {
|
||||
interior.push(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
_ => {
|
||||
// Non-marker, non-FusedX predecessor inside what
|
||||
// we thought was a region. Shouldn't happen with
|
||||
// the current rules; treat conservatively: do
|
||||
// not absorb — let the kernel_to_host single
|
||||
// path handle it. This means the region is
|
||||
// malformed and we likely should not have a
|
||||
// region at all. Caller will see incomplete
|
||||
// interior; the safer thing is to fall back.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Topological order on the interior + FS nodes (so the kernel
|
||||
// emits `let v = ...;` lines after their inputs are bound). We
|
||||
// use the parent graph's toposort filtered to in-region nodes.
|
||||
let mut region_set: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
region_set.extend(interior.iter().copied());
|
||||
region_set.extend(fs_nodes.iter().copied());
|
||||
let topo = toposort(llir_graph, None).expect("LLIR cycle in region detection");
|
||||
let interior_topo: Vec<NodeIndex> = topo
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|n| region_set.contains(n) && interior.contains(n))
|
||||
.collect();
|
||||
let fs_topo: Vec<NodeIndex> = topo
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|n| region_set.contains(n) && fs_nodes.contains(n))
|
||||
.collect();
|
||||
|
||||
// External producer for each FS leaf, in the same order.
|
||||
let external_inputs: Vec<NodeIndex> = fs_topo
|
||||
.iter()
|
||||
.map(|&fs| {
|
||||
llir_graph
|
||||
.neighbors_directed(fs, Direction::Incoming)
|
||||
.next()
|
||||
.expect("FusionStart with no predecessor")
|
||||
})
|
||||
.collect();
|
||||
|
||||
absorbed.extend(interior_topo.iter().copied());
|
||||
absorbed.extend(fs_topo.iter().copied());
|
||||
|
||||
regions.insert(
|
||||
node,
|
||||
RegionUnit {
|
||||
fe_node: node,
|
||||
fusedx_topo: interior_topo,
|
||||
fs_nodes: fs_topo,
|
||||
external_inputs,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Second pass: emit compile units in original topo order, replacing
|
||||
// FE nodes with their RegionUnit and skipping anything absorbed —
|
||||
// either by a region in *this* subgraph (`absorbed`) or by any
|
||||
// region anywhere in the LLIR (`globally_absorbed`). Skipping the
|
||||
// latter prevents the identity-memcpy fallback from firing on
|
||||
// shared FS markers whose consumers live in other convex subgraphs:
|
||||
// those FSes are absorbed by some other region, and the consuming
|
||||
// region reads from FS's external producer, so the FS never needs
|
||||
// its own kernel.
|
||||
let mut units: Vec<CompileUnit> = Vec::new();
|
||||
for &node in topo_order {
|
||||
if let Some(region) = regions.remove(&node) {
|
||||
units.push(CompileUnit::Region(region));
|
||||
} else if absorbed.contains(&node) || globally_absorbed.contains(&node) {
|
||||
continue;
|
||||
} else {
|
||||
units.push(CompileUnit::Single(node));
|
||||
}
|
||||
}
|
||||
units
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Per-FusedX body templates.
|
||||
//
|
||||
// Each entry takes the names of the local variables holding the op's
|
||||
// inputs and returns a CUDA expression evaluating to the op's output
|
||||
// (a register-resident value, no buffer involved).
|
||||
// =========================================================================
|
||||
|
||||
fn fused_body(name: &str, locals: &[&str]) -> String {
|
||||
match name {
|
||||
"FusedSin" => format!("sinf({})", locals[0]),
|
||||
"FusedSqrt" => format!("sqrtf({})", locals[0]),
|
||||
"FusedExp" => format!("expf({})", locals[0]),
|
||||
"FusedExp2" => format!("exp2f({})", locals[0]),
|
||||
"FusedLog2" => format!("log2f({})", locals[0]),
|
||||
"FusedRecip" => format!("1.0f / {}", locals[0]),
|
||||
"FusedAdd" => format!("{} + {}", locals[0], locals[1]),
|
||||
"FusedMul" => format!("{} * {}", locals[0], locals[1]),
|
||||
other => panic!("region_codegen: unknown FusedX op {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Region compilation — emit one CUDA kernel for the whole region.
|
||||
// =========================================================================
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub(crate) struct CompiledRegion {
|
||||
pub function: CudaFunction,
|
||||
pub module: Arc<CudaModule>,
|
||||
pub kernel_str: String,
|
||||
pub grid: (Expression, Expression, Expression),
|
||||
pub block: (Expression, Expression, Expression),
|
||||
pub shared_mem: Expression,
|
||||
pub constants: FxHashMap<char, CudaSlice<u8>>,
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub(crate) fn compile_region(
|
||||
region: &RegionUnit,
|
||||
llir_graph: &LLIRGraph,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompiledRegion {
|
||||
// Resolve FE: shape, strides (for the write), dtype.
|
||||
let fe_op = llir_graph[region.fe_node]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.expect("FE node must be a KernelOp");
|
||||
let fe_struct: &FusionEnd = (***fe_op)
|
||||
.downcast_ref::<FusionEnd>()
|
||||
.expect("region root must be FusionEnd");
|
||||
let out_shape: &[Expression] = &fe_struct.shape;
|
||||
let out_strides: &[Expression] = &fe_struct.strides;
|
||||
let dtype: DType = fe_struct.dtype;
|
||||
|
||||
// Aggregate all dynamic vars used anywhere in the region (FS strides,
|
||||
// FE strides, FusedX shape — all FusedX share `out_shape`, but their
|
||||
// own strides are likewise relevant for any future stride-affine ops).
|
||||
let mut all_vars: FxHashSet<char> = FxHashSet::default();
|
||||
all_vars.extend(out_shape.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(out_strides.iter().flat_map(|e| e.dyn_vars()));
|
||||
for &fs_idx in ®ion.fs_nodes {
|
||||
let fs_op = llir_graph[fs_idx].to_dialect::<dyn KernelOp>().unwrap();
|
||||
let fs_struct: &FusionStart = (***fs_op).downcast_ref::<FusionStart>().unwrap();
|
||||
all_vars.extend(fs_struct.strides.iter().flat_map(|e| e.dyn_vars()));
|
||||
}
|
||||
|
||||
let cuda_ty = cuda_dtype(dtype);
|
||||
let includes = dtype_includes(&[dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&all_vars);
|
||||
let dyn_dims_param = if all_vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
let n_elements = out_shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.to_kernel();
|
||||
|
||||
// Build kernel signature: out, then one input per FS leaf in
|
||||
// `region.fs_nodes` order. The `external_inputs` list (parallel to
|
||||
// `fs_nodes`) is what the host wires into the launch params.
|
||||
let mut signature_params: Vec<String> = vec![format!("{cuda_ty} *out")];
|
||||
for i in 0..region.fs_nodes.len() {
|
||||
signature_params.push(format!("const {cuda_ty} *in{i}"));
|
||||
}
|
||||
let signature = signature_params.join(", ");
|
||||
|
||||
// Body: read FS leaves, then walk FusedX in topo order emitting a
|
||||
// local per op, then write FE output. Every node gets a local keyed
|
||||
// by a position-in-region index so the kernel string is invariant
|
||||
// under NodeIndex churn (each `egglog_to_llir` reissues NodeIndexes,
|
||||
// so naming locals by `n.index()` would invalidate the kernel
|
||||
// string cache on every search candidate). Indices: FS leaves get
|
||||
// 0..fs_nodes.len(), FusedX get fs_nodes.len()..(+ fusedx_topo.len()).
|
||||
let mut local_idx_map: FxHashMap<NodeIndex, usize> = FxHashMap::default();
|
||||
for (i, &fs_idx) in region.fs_nodes.iter().enumerate() {
|
||||
local_idx_map.insert(fs_idx, i);
|
||||
}
|
||||
let fs_count = region.fs_nodes.len();
|
||||
for (i, &op_idx) in region.fusedx_topo.iter().enumerate() {
|
||||
local_idx_map.insert(op_idx, fs_count + i);
|
||||
}
|
||||
let local_name = |n: NodeIndex| format!("v_{}", local_idx_map[&n]);
|
||||
|
||||
let mut body = String::new();
|
||||
body.push_str(&format!(
|
||||
" long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;\n\
|
||||
\x20 if (const_z >= {n_elements}) return;\n"
|
||||
));
|
||||
|
||||
// FS leaves: each reads from its corresponding `in_i` parameter using
|
||||
// its own strides.
|
||||
for (i, &fs_idx) in region.fs_nodes.iter().enumerate() {
|
||||
let fs_op = llir_graph[fs_idx].to_dialect::<dyn KernelOp>().unwrap();
|
||||
let fs_struct: &FusionStart = (***fs_op).downcast_ref::<FusionStart>().unwrap();
|
||||
let read_idx = flatten_strides(out_shape, &fs_struct.strides).to_kernel();
|
||||
body.push_str(&format!(
|
||||
" {cuda_ty} {name} = in{i}[{read_idx}];\n",
|
||||
name = local_name(fs_idx),
|
||||
));
|
||||
}
|
||||
|
||||
// FusedX ops in topo order. Each looks up its predecessor locals
|
||||
// (in incoming-edge id order to match the original op's input
|
||||
// arity / position).
|
||||
for &op_idx in ®ion.fusedx_topo {
|
||||
let op_ref = llir_graph[op_idx].to_dialect::<dyn KernelOp>().unwrap();
|
||||
let op_name = op_ref.kernel_name();
|
||||
|
||||
let mut input_locals: Vec<String> = llir_graph
|
||||
.edges_directed(op_idx, Direction::Incoming)
|
||||
.map(|e| (e.id(), e.source()))
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()
|
||||
.map(|(_, src)| local_name(src))
|
||||
.collect();
|
||||
// Sort by edge id like the rest of the codegen does for stable
|
||||
// input ordering.
|
||||
let mut edges: Vec<(_, NodeIndex)> = llir_graph
|
||||
.edges_directed(op_idx, Direction::Incoming)
|
||||
.map(|e| (e.id(), e.source()))
|
||||
.collect();
|
||||
edges.sort_by_key(|(eid, _)| *eid);
|
||||
input_locals = edges.into_iter().map(|(_, src)| local_name(src)).collect();
|
||||
let inputs_ref: Vec<&str> = input_locals.iter().map(|s| s.as_str()).collect();
|
||||
|
||||
let expr = fused_body(op_name, &inputs_ref);
|
||||
body.push_str(&format!(
|
||||
" {cuda_ty} {name} = {expr};\n",
|
||||
name = local_name(op_idx),
|
||||
));
|
||||
}
|
||||
|
||||
// FE write: pick the FusedX feeding FE (its single incoming edge in
|
||||
// the region — a FusedX or, in degenerate single-FS regions which
|
||||
// shouldn't arise, an FS).
|
||||
let fe_input: NodeIndex = llir_graph
|
||||
.neighbors_directed(region.fe_node, Direction::Incoming)
|
||||
.next()
|
||||
.expect("FusionEnd with no predecessor");
|
||||
let fe_input_local = local_name(fe_input);
|
||||
let write_idx = flatten_strides(out_shape, out_strides).to_kernel();
|
||||
body.push_str(&format!(" out[{write_idx}] = {fe_input_local};\n"));
|
||||
|
||||
let kernel = format!(
|
||||
"{includes}\n\
|
||||
{dyn_defines}\n\
|
||||
extern \"C\" {{\n\
|
||||
\x20 __global__ void fused_region_k({signature}{dyn_dims_param}) {{\n\
|
||||
{body}\
|
||||
\x20 }}\n\
|
||||
}}"
|
||||
);
|
||||
|
||||
let (module, function) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel)
|
||||
.expect("region kernel PTX compile failed");
|
||||
let module = stream
|
||||
.context()
|
||||
.load_module(ptx)
|
||||
.expect("module load failed");
|
||||
let function = module
|
||||
.load_function("fused_region_k")
|
||||
.expect("region kernel function not found");
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), function.clone()));
|
||||
(module, function)
|
||||
};
|
||||
|
||||
let out_size = out_shape.iter().copied().product::<Expression>();
|
||||
|
||||
CompiledRegion {
|
||||
function,
|
||||
module,
|
||||
kernel_str: kernel,
|
||||
grid: (out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
block: (out_size.min(256), 1.into(), 1.into()),
|
||||
shared_mem: 0.into(),
|
||||
constants: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
@@ -1200,7 +1200,25 @@ impl KernelOp for KernelScatter {
|
||||
|
||||
// Single-kernel scatter: copy dest→output then scatter src→output[indexes]
|
||||
// Launched as 1 block of 1024 threads with __syncthreads() barrier.
|
||||
// Uses float4 vectorized copy (4x throughput) for the copy phase.
|
||||
// Uses float4 vectorized copy (16 bytes per op) for the copy phase.
|
||||
//
|
||||
// The number of dtype elements that fit in a float4 (16 bytes) depends
|
||||
// on the element size. Computing `n_vec = n_dest / 4` would only be
|
||||
// correct for 4-byte dtypes — for bf16 it walks 2× past the end of
|
||||
// `out`, producing CUDA_ERROR_ILLEGAL_ADDRESS once the OOB region
|
||||
// happens to land on an unmapped page.
|
||||
let elements_per_vec: usize = match self.dtype {
|
||||
DType::F64 => 2,
|
||||
DType::F32 | DType::Int => 4,
|
||||
DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 8,
|
||||
DType::Bool
|
||||
| DType::I8
|
||||
| DType::U8
|
||||
| DType::F8UE8M0
|
||||
| DType::F8E4M3
|
||||
| DType::F8E5M2 => 16,
|
||||
other => panic!("Unsupported dtype for scatter vectorization: {other:?}"),
|
||||
};
|
||||
let n_src_elements = self
|
||||
.index_shape
|
||||
.iter()
|
||||
@@ -1225,15 +1243,17 @@ extern \"C\" {{
|
||||
int tid = threadIdx.x;
|
||||
long long n_dest = {n_dest_elements};
|
||||
long long n_src = {n_src_elements};
|
||||
// Phase 1: vectorized copy dest → output (float4 = 4 elements per op)
|
||||
long long n_vec = n_dest / 4;
|
||||
// Phase 1: vectorized copy dest → output (float4 = 16 bytes / iter,
|
||||
// i.e. {elements_per_vec} {dtype} elements). n_vec is sized so the
|
||||
// total bytes covered (`n_vec * 16`) never exceed `n_dest * sizeof({dtype})`.
|
||||
long long n_vec = n_dest / {elements_per_vec};
|
||||
float4 *out4 = (float4 *)out;
|
||||
const float4 *dest4 = (const float4 *)dest;
|
||||
for (long long i = tid; i < n_vec; i += blockDim.x) {{
|
||||
out4[i] = dest4[i];
|
||||
}}
|
||||
// Handle remaining elements
|
||||
long long remainder_start = n_vec * 4;
|
||||
// Handle remaining elements (the dtype-tail past the last full float4).
|
||||
long long remainder_start = n_vec * {elements_per_vec};
|
||||
for (long long i = remainder_start + tid; i < n_dest; i += blockDim.x) {{
|
||||
out[i] = dest[i];
|
||||
}}
|
||||
@@ -2060,7 +2080,7 @@ extern \"C\" {{
|
||||
__global__ void recip_k({dtype} *out, const {dtype} *in{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {n_elements}) return;
|
||||
out[{out_idx}] = 1.0f / in[{in_idx}];
|
||||
out[{out_idx}] = ({dtype})1.0f / in[{in_idx}];
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
|
||||
@@ -10,12 +10,13 @@ use luminal_tracing::schema::{
|
||||
use uuid::Uuid;
|
||||
|
||||
pub mod cuda_graph;
|
||||
pub mod fusion;
|
||||
pub mod hlir;
|
||||
pub mod other_ops;
|
||||
|
||||
pub use cuda_graph::*;
|
||||
|
||||
pub type Ops = (hlir::Ops, other_ops::Ops);
|
||||
pub type Ops = (hlir::Ops, other_ops::Ops, fusion::Ops);
|
||||
|
||||
/// Build a mapping from interned string IDs to their string values for a given sequence.
|
||||
fn build_interned_strings(trace: &schema::Trace) -> std::collections::HashMap<(u32, u64), String> {
|
||||
|
||||
@@ -10,7 +10,7 @@ use itertools::Itertools;
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, EXPRESSION, OP_KIND},
|
||||
base::{DTYPE, ELIST, EXPRESSION, OP_KIND, STRING},
|
||||
extract_dtype, extract_expr, extract_expr_list,
|
||||
},
|
||||
op::*,
|
||||
|
||||
@@ -7,7 +7,8 @@ use std::cell::RefCell;
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{
|
||||
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, sys::CUgraphNode,
|
||||
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr,
|
||||
sys::{CUgraphNode, CUresult, cuLaunchKernel},
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use luminal::{
|
||||
@@ -26,6 +27,7 @@ use crate::{
|
||||
kernel::{
|
||||
CudaFunctionExt, CudaGraphExecHandle, CudaGraphHandle, KernelOp, create_cuda_event,
|
||||
destroy_cuda_event,
|
||||
fusion::region_codegen::{self, CompileUnit},
|
||||
hlir::{clear_global_dyn_dims, get_global_dyn_dims, set_global_dyn_dims},
|
||||
},
|
||||
runtime::partition_marked_convex,
|
||||
@@ -274,6 +276,14 @@ impl CudaGraphOp {
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
// Debug path: launch each kernel sequentially with sync between, so the
|
||||
// failing kernel surfaces instead of the generic "CudaGraph" panic.
|
||||
// Enable via `LUMINAL_DEBUG_SEQ=1`. Slow — only for diagnosing
|
||||
// CUDA_ERROR_ILLEGAL_ADDRESS / NaN / wrong-output bugs in graph batching.
|
||||
if std::env::var("LUMINAL_DEBUG_SEQ").is_ok() {
|
||||
return self.execute_sequential_for_debug(stream, buffers, dyn_map);
|
||||
}
|
||||
|
||||
let mut state = self.state.borrow_mut();
|
||||
let _span = span!(Level::TRACE, "cuda_graph", kernels = state.kernels.len()).entered();
|
||||
|
||||
@@ -446,6 +456,152 @@ impl CudaGraphOp {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Diagnostic path for kernel-level errors that surface as a generic
|
||||
/// `CUDA_ERROR_ILLEGAL_ADDRESS` panic from the batched cuda_graph_exec
|
||||
/// launch. Bypasses CUDA-graph batching entirely: builds params per
|
||||
/// kernel and launches each via `cuLaunchKernel`, syncing afterwards so
|
||||
/// the offending kernel reports itself instead of being hidden inside
|
||||
/// the graph's atomic launch.
|
||||
///
|
||||
/// Enabled via `LUMINAL_DEBUG_SEQ=1`. ~10–100× slower than the graph
|
||||
/// path; not for production.
|
||||
fn execute_sequential_for_debug(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut state = self.state.borrow_mut();
|
||||
let num_kernels = state.kernels.len();
|
||||
|
||||
// Allocate dyn_dims_buffer if needed and copy current values.
|
||||
if !self.dyn_dims_order.is_empty() && state.dyn_dims_buffer.is_none() {
|
||||
state.dyn_dims_buffer = Some(stream.alloc_zeros::<i32>(self.dyn_dims_order.len())?);
|
||||
}
|
||||
if !self.dyn_dims_order.is_empty() {
|
||||
let values: Vec<i32> = self
|
||||
.dyn_dims_order
|
||||
.iter()
|
||||
.map(|d| dyn_map.get(d).copied().unwrap_or(0) as i32)
|
||||
.collect();
|
||||
if let Some(buf) = state.dyn_dims_buffer.as_mut() {
|
||||
stream.memcpy_htod(&values, buf)?;
|
||||
}
|
||||
}
|
||||
let dyn_dims_ptr = state
|
||||
.dyn_dims_buffer
|
||||
.as_ref()
|
||||
.map(|buf| buf.device_ptr(stream).0)
|
||||
.unwrap_or(0);
|
||||
|
||||
// Collect buffer pointers (mirrors the graph path).
|
||||
let mut buffer_ptrs: FxHashMap<NodeIndex, u64> = FxHashMap::default();
|
||||
for &node in &self.buffer_nodes {
|
||||
if let Some(buf) = buffers.get(&node) {
|
||||
buffer_ptrs.insert(node, buf.device_ptr(stream).0);
|
||||
}
|
||||
}
|
||||
for kernel in state.kernels.iter() {
|
||||
if let Some(input_idx) = kernel.kernel_op.output_aliases_input()
|
||||
&& let Some(&input_ptr) = buffer_ptrs.get(&kernel.inputs[input_idx])
|
||||
{
|
||||
buffer_ptrs.insert(kernel.node, input_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate internal buffers + run pre_execute for every kernel up front.
|
||||
for idx in 0..num_kernels {
|
||||
let kernel = &mut state.kernels[idx];
|
||||
if kernel.internal_bufs.is_empty() {
|
||||
kernel.internal_bufs = kernel.kernel_op.allocate_internal_buffers(stream, dyn_map);
|
||||
}
|
||||
kernel.kernel_op.pre_execute(
|
||||
stream,
|
||||
&mut kernel.internal_bufs,
|
||||
&mut kernel.constants,
|
||||
&buffer_ptrs,
|
||||
dyn_map,
|
||||
);
|
||||
}
|
||||
|
||||
let cu_stream = stream.cu_stream();
|
||||
|
||||
for idx in 0..num_kernels {
|
||||
let kernel = &state.kernels[idx];
|
||||
let kernel_name = kernel.kernel_op.kernel_name();
|
||||
let node = kernel.node;
|
||||
|
||||
let grid = (
|
||||
kernel.grid.0.exec(dyn_map).unwrap() as u32,
|
||||
kernel.grid.1.exec(dyn_map).unwrap() as u32,
|
||||
kernel.grid.2.exec(dyn_map).unwrap() as u32,
|
||||
);
|
||||
let block = (
|
||||
kernel.block.0.exec(dyn_map).unwrap() as u32,
|
||||
kernel.block.1.exec(dyn_map).unwrap() as u32,
|
||||
kernel.block.2.exec(dyn_map).unwrap() as u32,
|
||||
);
|
||||
let shared_mem = kernel.shared_mem.exec(dyn_map).unwrap() as u32;
|
||||
|
||||
let output_ptr = buffer_ptrs.get(&node).copied().unwrap_or(0);
|
||||
let input_ptrs: Vec<u64> = kernel
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|inp| buffer_ptrs.get(inp).copied().unwrap_or(0))
|
||||
.collect();
|
||||
|
||||
let param_values = kernel.kernel_op.build_params(
|
||||
stream,
|
||||
output_ptr,
|
||||
&input_ptrs,
|
||||
&kernel.internal_bufs,
|
||||
dyn_dims_ptr,
|
||||
);
|
||||
let mut params = UnifiedKernelParams::new(param_values);
|
||||
let cu_func = unsafe { kernel.function.raw_function() };
|
||||
|
||||
let result = unsafe {
|
||||
cuLaunchKernel(
|
||||
cu_func,
|
||||
grid.0,
|
||||
grid.1,
|
||||
grid.2,
|
||||
block.0,
|
||||
block.1,
|
||||
block.2,
|
||||
shared_mem,
|
||||
cu_stream,
|
||||
params.as_cuda_params(),
|
||||
std::ptr::null_mut(),
|
||||
)
|
||||
};
|
||||
if result != CUresult::CUDA_SUCCESS {
|
||||
eprintln!(
|
||||
"[seq-debug] kernel #{idx}/{num_kernels} '{kernel_name}' \
|
||||
node={node:?} grid={grid:?} block={block:?} \
|
||||
output_ptr={output_ptr:#x} inputs={input_ptrs:#x?} \
|
||||
LAUNCH FAILED: {result:?}"
|
||||
);
|
||||
anyhow::bail!(
|
||||
"kernel #{idx} '{kernel_name}' (node {node:?}) launch failed: {result:?}"
|
||||
);
|
||||
}
|
||||
if let Err(e) = stream.synchronize() {
|
||||
eprintln!(
|
||||
"[seq-debug] kernel #{idx}/{num_kernels} '{kernel_name}' \
|
||||
node={node:?} grid={grid:?} block={block:?} \
|
||||
output_ptr={output_ptr:#x} inputs={input_ptrs:#x?} \
|
||||
SYNC FAILED: {e}"
|
||||
);
|
||||
anyhow::bail!(
|
||||
"kernel #{idx} '{kernel_name}' (node {node:?}) sync failed: {e}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Build the CUDA graph from compiled kernels.
|
||||
fn build_graph(
|
||||
&self,
|
||||
@@ -655,6 +811,11 @@ pub fn kernel_to_host(
|
||||
}
|
||||
|
||||
let kernel_subgraphs = partition_marked_convex(llir_graph, &kernel_ops_in_graph).unwrap();
|
||||
// Compute the set of FS / FE / FusedX nodes globally absorbed by some
|
||||
// FusionEnd in the LLIR. Used by `build_compile_units` to suppress the
|
||||
// identity-memcpy fallback for shared FS leaves whose consumers live
|
||||
// in a different convex subgraph than the FS itself.
|
||||
let globally_absorbed = region_codegen::globally_absorbed_markers(llir_graph);
|
||||
|
||||
// Track which kernel node belongs to which CudaGraphOp (for later edge creation)
|
||||
let mut kernel_to_cuda_graph: FxHashMap<NodeIndex, NodeIndex> = FxHashMap::default();
|
||||
@@ -689,45 +850,98 @@ pub fn kernel_to_host(
|
||||
set_global_dyn_dims(global_dyn_dims.clone());
|
||||
}
|
||||
|
||||
// Compile all kernels with global ordering for correct dyn_dims indices
|
||||
let mut kernels = Vec::with_capacity(topo_order.len());
|
||||
for kernel_node_idx in &topo_order {
|
||||
let kernel_op_ref = llir_graph[*kernel_node_idx]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.unwrap();
|
||||
// Group the topo order into compile units: each FusionEnd-rooted
|
||||
// region collapses to a single CompileUnit::Region (one fused
|
||||
// CUDA kernel for the whole DAG); everything else stays as
|
||||
// CompileUnit::Single (the existing per-op compile path).
|
||||
let compile_units =
|
||||
region_codegen::build_compile_units(&topo_order, llir_graph, &globally_absorbed);
|
||||
|
||||
let (kernel_function, _, _kernel_str, grid, block, shared_mem, constants) =
|
||||
kernel_op_ref.compile(cuda_stream, kernel_cache);
|
||||
// Compile all units with global ordering for correct dyn_dims indices
|
||||
let mut kernels = Vec::with_capacity(compile_units.len());
|
||||
for unit in &compile_units {
|
||||
match unit {
|
||||
CompileUnit::Single(kernel_node_idx) => {
|
||||
let kernel_op_ref = llir_graph[*kernel_node_idx]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.unwrap();
|
||||
|
||||
// Collect inputs from graph edges
|
||||
let mut inputs: Vec<NodeIndex> = llir_graph
|
||||
.edges_directed(*kernel_node_idx, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect_vec();
|
||||
let (kernel_function, _, _kernel_str, grid, block, shared_mem, constants) =
|
||||
kernel_op_ref.compile(cuda_stream, kernel_cache);
|
||||
|
||||
// Collect buffer nodes and sizes
|
||||
// Only add kernel nodes with non-zero output size (MegakernelOps have size 0)
|
||||
let output_size = kernel_op_ref.output_size();
|
||||
if output_size.exec(&FxHashMap::default()).unwrap_or(1) != 0 {
|
||||
all_buffer_nodes.insert(*kernel_node_idx);
|
||||
all_buffer_sizes.insert(*kernel_node_idx, output_size);
|
||||
// Collect inputs from graph edges
|
||||
let inputs: Vec<NodeIndex> = llir_graph
|
||||
.edges_directed(*kernel_node_idx, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect_vec();
|
||||
|
||||
// Collect buffer nodes and sizes
|
||||
// Only add kernel nodes with non-zero output size (MegakernelOps have size 0)
|
||||
let output_size = kernel_op_ref.output_size();
|
||||
if output_size.exec(&FxHashMap::default()).unwrap_or(1) != 0 {
|
||||
all_buffer_nodes.insert(*kernel_node_idx);
|
||||
all_buffer_sizes.insert(*kernel_node_idx, output_size);
|
||||
}
|
||||
all_buffer_nodes.extend(inputs.iter().copied());
|
||||
|
||||
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(kernel_op_ref);
|
||||
|
||||
kernels.push(CompiledKernel::new(
|
||||
*kernel_node_idx,
|
||||
kernel_function,
|
||||
grid,
|
||||
block,
|
||||
shared_mem,
|
||||
inputs,
|
||||
kernel_op.clone(),
|
||||
constants,
|
||||
kernel_op.kernel_name(),
|
||||
));
|
||||
}
|
||||
CompileUnit::Region(region) => {
|
||||
// Generate one fused CUDA kernel for the whole region.
|
||||
let compiled = region_codegen::compile_region(
|
||||
region,
|
||||
llir_graph,
|
||||
cuda_stream,
|
||||
kernel_cache,
|
||||
);
|
||||
|
||||
// The region's CompiledKernel is keyed on the FE node
|
||||
// (so FE provides trait methods like output_size /
|
||||
// build_params) but its `inputs` are the external
|
||||
// producers, not FE's literal LLIR predecessors —
|
||||
// those are interior FusedX nodes that don't exist
|
||||
// as buffer-bearing nodes from the host's view.
|
||||
let fe_op_ref = llir_graph[region.fe_node]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.unwrap();
|
||||
|
||||
let inputs: Vec<NodeIndex> = region.external_inputs.clone();
|
||||
|
||||
let output_size = fe_op_ref.output_size();
|
||||
if output_size.exec(&FxHashMap::default()).unwrap_or(1) != 0 {
|
||||
all_buffer_nodes.insert(region.fe_node);
|
||||
all_buffer_sizes.insert(region.fe_node, output_size);
|
||||
}
|
||||
all_buffer_nodes.extend(inputs.iter().copied());
|
||||
|
||||
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(fe_op_ref);
|
||||
|
||||
kernels.push(CompiledKernel::new(
|
||||
region.fe_node,
|
||||
compiled.function,
|
||||
compiled.grid,
|
||||
compiled.block,
|
||||
compiled.shared_mem,
|
||||
inputs,
|
||||
kernel_op,
|
||||
compiled.constants,
|
||||
"FusedRegion",
|
||||
));
|
||||
}
|
||||
}
|
||||
all_buffer_nodes.extend(inputs.iter().copied());
|
||||
|
||||
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(kernel_op_ref);
|
||||
|
||||
kernels.push(CompiledKernel::new(
|
||||
*kernel_node_idx,
|
||||
kernel_function,
|
||||
grid,
|
||||
block,
|
||||
shared_mem,
|
||||
inputs,
|
||||
kernel_op.clone(),
|
||||
constants,
|
||||
kernel_op.kernel_name(),
|
||||
));
|
||||
}
|
||||
|
||||
// Get the possibly-extended global ordering (kernels may have discovered new dims)
|
||||
@@ -820,22 +1034,41 @@ pub fn kernel_to_host(
|
||||
}
|
||||
}
|
||||
|
||||
// Add collected edges (deduplicate), skipping back-edges to preserve DAG property
|
||||
// Add each cross-CudaGraphOp dep edge iff it would carry new ordering
|
||||
// information without closing a cycle. The previous topo-position gate
|
||||
// ("skip when src_pos >= dst_pos") was too coarse: it dropped edges
|
||||
// whose src happened to land later in the toposort than their dst even
|
||||
// when no path dst→src actually existed, leaving consumers free to run
|
||||
// before the producer wrote their input buffer (wrong outputs); and it
|
||||
// also added edges that were already implied by an existing src→dst
|
||||
// path (extra serialization, no new info).
|
||||
let edges_to_add: FxHashSet<(NodeIndex, NodeIndex)> = edges_to_add.into_iter().collect();
|
||||
let topo = toposort(&*llir_graph, None).unwrap();
|
||||
let mut topo_pos: FxHashMap<NodeIndex, usize> = FxHashMap::default();
|
||||
for (i, n) in topo.iter().enumerate() {
|
||||
topo_pos.insert(*n, i);
|
||||
}
|
||||
use petgraph::algo::has_path_connecting;
|
||||
for (src, dst) in edges_to_add {
|
||||
// Only add forward edges (src before dst in topo order) to avoid creating cycles
|
||||
let src_pos = topo_pos.get(&src).copied().unwrap_or(usize::MAX);
|
||||
let dst_pos = topo_pos.get(&dst).copied().unwrap_or(usize::MAX);
|
||||
if src_pos >= dst_pos {
|
||||
continue; // Skip back-edges
|
||||
if has_path_connecting(&*llir_graph, src, dst, None) {
|
||||
continue; // already ordered src→dst by some path; edge redundant
|
||||
}
|
||||
if !llir_graph.edges_connecting(src, dst).any(|_| true) {
|
||||
llir_graph.add_edge(src, dst, ());
|
||||
if has_path_connecting(&*llir_graph, dst, src, None) {
|
||||
continue; // adding src→dst would close a cycle
|
||||
}
|
||||
llir_graph.add_edge(src, dst, ());
|
||||
}
|
||||
|
||||
// Strip fully-absorbed marker nodes (FusionStart, nested FusionEnd,
|
||||
// FusedX) from the LLIR. Region codegen has already folded them into
|
||||
// a single fused CUDA function anchored at each region's root
|
||||
// FusionEnd; the absorbed nodes have no consumers outside the region
|
||||
// and never need their own buffers. Removing them keeps later
|
||||
// per-execute walks (e.g., `allocate_intermediate_buffers`) from
|
||||
// chewing through dead nodes every decode token.
|
||||
//
|
||||
// Root FusionEnd nodes are NOT in `globally_absorbed` (they were the
|
||||
// walks' starting points), so we keep them — they're the kernel
|
||||
// anchor for the region's compiled kernel.
|
||||
for node in globally_absorbed {
|
||||
// Defensive: only remove if the node still exists.
|
||||
if llir_graph.node_weight(node).is_some() {
|
||||
llir_graph.remove_node(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod dyn_backend;
|
||||
pub mod host;
|
||||
pub mod kernel;
|
||||
pub mod runtime;
|
||||
@@ -9,6 +10,8 @@ use std::{
|
||||
|
||||
pub use cudarc;
|
||||
|
||||
use cudarc::{cublaslt::CudaBlasLT, driver::CudaStream};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
@@ -137,6 +140,25 @@ fn cuda_driver_diagnostics() -> (Option<i32>, Option<i32>) {
|
||||
(driver_version, None)
|
||||
}
|
||||
|
||||
pub(crate) fn try_create_cublaslt(
|
||||
stream: Arc<CudaStream>,
|
||||
) -> std::result::Result<Arc<CudaBlasLT>, String> {
|
||||
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| CudaBlasLT::new(stream))) {
|
||||
Ok(Ok(handle)) => Ok(Arc::new(handle)),
|
||||
Ok(Err(err)) => Err(err.to_string()),
|
||||
Err(payload) => {
|
||||
let message = if let Some(message) = payload.downcast_ref::<String>() {
|
||||
message.clone()
|
||||
} else if let Some(message) = payload.downcast_ref::<&str>() {
|
||||
message.to_string()
|
||||
} else {
|
||||
"cuBLASLt initialization panicked".to_string()
|
||||
};
|
||||
Err(message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn cuda_nvrtc_compile_options(target_arch: &str) -> Vec<String> {
|
||||
let mut options = cuda_nvrtc_include_paths()
|
||||
.into_iter()
|
||||
@@ -186,9 +208,9 @@ fn get_cubin(program: nvrtc_sys::nvrtcProgram) -> Result<Vec<u8>, NvrtcError> {
|
||||
}
|
||||
|
||||
let mut cubin = Vec::with_capacity(cubin_size);
|
||||
cubin.resize(cubin_size, 0);
|
||||
unsafe { nvrtc_sys::nvrtcGetCUBIN(program, cubin.as_mut_ptr()) }.result()?;
|
||||
Ok(cubin.into_iter().map(|byte| byte as u8).collect())
|
||||
cubin.resize(cubin_size, 0u8);
|
||||
unsafe { nvrtc_sys::nvrtcGetCUBIN(program, cubin.as_mut_ptr() as *mut _) }.result()?;
|
||||
Ok(cubin)
|
||||
}
|
||||
|
||||
pub(crate) fn compile_module_image_for_current_device<S: AsRef<str>>(
|
||||
|
||||
@@ -664,6 +664,22 @@ impl CudaRuntime {
|
||||
if bucket.llir_graph[node].to_op::<Input>().is_some() {
|
||||
continue;
|
||||
}
|
||||
// Skip fusion marker / interior nodes. Region codegen folds
|
||||
// FusionStart / FusionEnd / FusedX into a single CUDA function
|
||||
// anchored at the FusionEnd; these marker nodes never need a
|
||||
// device buffer of their own at runtime, so walking them here
|
||||
// each step (with `p` incrementing every decode token) is
|
||||
// pure overhead. Skipping them recovers ~2 ms / token on
|
||||
// llama with fusion enabled.
|
||||
if let Some(op) = bucket.llir_graph[node].to_dialect::<dyn KernelOp>() {
|
||||
let kn = op.kernel_name();
|
||||
if kn == "FusionStart" || kn.starts_with("Fused") {
|
||||
continue;
|
||||
}
|
||||
// Note: we deliberately keep "FusionEnd" because it is the
|
||||
// anchor for the region's compiled kernel and DOES need a
|
||||
// buffer for the region's output.
|
||||
}
|
||||
let needed_bytes =
|
||||
if let Some(op) = bucket.llir_graph[node].to_dialect::<dyn KernelOp>() {
|
||||
let out_bytes = op.output_bytes();
|
||||
@@ -860,6 +876,10 @@ impl Runtime for CudaRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
fn aggregate_profile_metrics(metrics: &[Self::ProfileMetric]) -> Self::ProfileMetric {
|
||||
metrics.iter().copied().sum()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
fn load_llir(&mut self, llir_graph: &LLIRGraph) {
|
||||
// Sync before clearing old data to ensure all operations complete
|
||||
@@ -892,15 +912,13 @@ impl Runtime for CudaRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
fn allocate_dummy_input(&mut self, node_index: usize, num_elements: usize) {
|
||||
// Use small non-zero values (ones) instead of zeros so that NaN-producing
|
||||
// graph variants are detected during profiling. Zero inputs often hide
|
||||
// numerical issues that appear with real data.
|
||||
let host_data = vec![1.0f32; num_elements];
|
||||
let buf = self
|
||||
.cuda_stream
|
||||
.clone_htod(bytemuck::cast_slice::<f32, u8>(&host_data))
|
||||
.unwrap();
|
||||
fn allocate_dummy_input(&mut self, node_index: usize, num_bytes: usize) {
|
||||
// Boundary scratch buffers are sized in raw bytes and may represent
|
||||
// non-float tensors such as gather/scatter indices. Initialize with zero
|
||||
// bytes so integer boundaries stay in-range and the raw allocation size
|
||||
// matches the requested tensor storage.
|
||||
let host_data = vec![0u8; num_bytes];
|
||||
let buf = self.cuda_stream.clone_htod(&host_data).unwrap();
|
||||
let id = NodeIndex::new(node_index);
|
||||
self.hlir_buffers.insert(id, CudaInput::Buffer(buf));
|
||||
self.changed_hlir.insert(id);
|
||||
|
||||
@@ -301,9 +301,8 @@ fn test_scatter_kv_cache_roundtrip() {
|
||||
}
|
||||
|
||||
/// Test scatter with TWO cache buffers and dual outputs (closer to llama K+V pattern).
|
||||
/// Also verifies graph_break interaction.
|
||||
#[test]
|
||||
fn test_scatter_dual_cache_with_graph_break() {
|
||||
fn test_scatter_dual_cache() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
986
crates/luminal_cuda_lite/src/tests/fusion.rs
Normal file
986
crates/luminal_cuda_lite/src/tests/fusion.rs
Normal file
@@ -0,0 +1,986 @@
|
||||
use luminal::egglog_utils::{egglog_to_llir, random_initial_choice};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::kernel::KernelOp;
|
||||
use crate::runtime::CudaRuntime;
|
||||
use crate::tests::utilities::{
|
||||
TOLERANCE_SAFETY_FACTOR, dtype_epsilon, random_f32_vec, test_binary_cuda, test_unary_cuda,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_two_unary_ops_fuse() {
|
||||
// Marker form: `a.sin().sqrt()` should fuse into a region with FusedSin
|
||||
// and FusedSqrt under one FusionEnd (per pair-fuse U→U).
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let _b = a.sin().sqrt().output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedSin", "FusedSqrt"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 1 && r.end_count == 1),
|
||||
"expected a marker region of {expected:?} with 1 FusionStart, got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stride_mismatch_prevents_fusion() {
|
||||
// A permute between sin and sqrt gives sqrt a non-contiguous view of sin's
|
||||
// contiguous output, so sqrt's in_strides != its out_strides and the
|
||||
// non-linear `?s ?s` match in the pair-fuse U→U rule can't fire.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((3, 4));
|
||||
let _b = a.sin().permute((1, 0)).sqrt().output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
for r in ®ions {
|
||||
let has_sin = r.internal_ops_sorted.iter().any(|n| n == "FusedSin");
|
||||
let has_sqrt = r.internal_ops_sorted.iter().any(|n| n == "FusedSqrt");
|
||||
assert!(
|
||||
!(has_sin && has_sqrt),
|
||||
"permute between sin and sqrt must prevent them sharing a fused region, \
|
||||
but found: {r:#?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reduction_prevents_unary_fusion() {
|
||||
// A reduction between two unaries is not elementwise, so pair-fuse U→U
|
||||
// (which only matches adjacent elementwise pairs) must not fire across
|
||||
// the reduction.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((4, 4));
|
||||
let _b = a.sin().sum(1).sqrt().output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
for r in ®ions {
|
||||
let has_sin = r.internal_ops_sorted.iter().any(|n| n == "FusedSin");
|
||||
let has_sqrt = r.internal_ops_sorted.iter().any(|n| n == "FusedSqrt");
|
||||
assert!(
|
||||
!(has_sin && has_sqrt),
|
||||
"reduction between sin and sqrt must prevent them sharing a fused region, \
|
||||
but found: {r:#?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unary_fusion_preserves_output() {
|
||||
// End-to-end numerical check: sqrt(sin(x)) must produce the same values
|
||||
// whether or not the fusion rule fired. Runs on GPU when available;
|
||||
// silently no-ops otherwise via get_cuda_stream().
|
||||
let seed = 0xC0FFEEu64;
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, 0.0, 1.0);
|
||||
test_unary_cuda::<f32>(
|
||||
8,
|
||||
|a| a.sin().sqrt(),
|
||||
|a| a.sin().unwrap().sqrt().unwrap(),
|
||||
gen_lambda,
|
||||
seed,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_three_unary_ops_fuse() {
|
||||
// A chain of 3 pure-elementwise unaries with matching strides should be
|
||||
// reachable as a single marker region containing all three FusedX ops.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(16);
|
||||
let _b = a.sin().sqrt().exp2().output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedSin", "FusedSqrt", "FusedExp2"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 1 && r.end_count == 1),
|
||||
"expected a marker region of {expected:?} with 1 FusionStart, got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_four_unary_ops_fuse() {
|
||||
// 4-op chain should collapse into a single marker region containing all
|
||||
// four FusedX ops (one pair-fuse + repeated grow-FE→U firings).
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(16);
|
||||
let _b = a.sin().sqrt().exp2().log2().output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedSin", "FusedSqrt", "FusedExp2", "FusedLog2"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 1 && r.end_count == 1),
|
||||
"expected a marker region of {expected:?} with 1 FusionStart, got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_three_unary_chain_preserves_output() {
|
||||
// End-to-end numerical check for a 3-op chain.
|
||||
// Uses sin→sqrt→sin because candle lacks exp2/log2 and this still exercises
|
||||
// a 3-link chain. The structural tests above cover the distinct-ops shape.
|
||||
let seed = 0xBEEFu64;
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, 0.0, 1.0);
|
||||
test_unary_cuda::<f32>(
|
||||
16,
|
||||
|a| a.sin().sqrt().sin(),
|
||||
|a| a.sin().unwrap().sqrt().unwrap().sin().unwrap(),
|
||||
gen_lambda,
|
||||
seed,
|
||||
);
|
||||
}
|
||||
|
||||
/// Isolated per-kernel microbenchmark: time two unfused kernels
|
||||
/// (`sqrt_k` then `recip_k`) vs one fused kernel (`fused_k` that does
|
||||
/// `1.0f / sqrtf(x)` in a single launch) on a fixed-size input, using
|
||||
/// CUDA events for device-side timing.
|
||||
///
|
||||
/// Ignored by default — run with
|
||||
/// `cargo test -p luminal_cuda_lite -- --ignored bench_fused_vs_unfused_sqrt_recip --nocapture`.
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn bench_fused_vs_unfused_sqrt_recip() {
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use cudarc::driver::{CudaContext, LaunchConfig, PushKernelArg};
|
||||
|
||||
const N: usize = 1 << 20; // 1M elements
|
||||
const WARMUP: usize = 100;
|
||||
const TRIALS: usize = 2000;
|
||||
|
||||
let ctx = match CudaContext::new(0) {
|
||||
Ok(c) => c,
|
||||
Err(_) => return, // no GPU available, skip
|
||||
};
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
// Prepare input (values in (0, 1] so sqrt/recip are well-defined).
|
||||
let host_input: Vec<f32> = (0..N).map(|i| (i as f32 + 1.0) / (N as f32)).collect();
|
||||
let d_in = stream.clone_htod(&host_input).unwrap();
|
||||
let mut d_scratch = stream.alloc_zeros::<f32>(N).unwrap();
|
||||
let mut d_out = stream.alloc_zeros::<f32>(N).unwrap();
|
||||
|
||||
let compile = |src: &str, name: &str| {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), src).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
module.load_function(name).unwrap()
|
||||
};
|
||||
|
||||
let sqrt_k = compile(
|
||||
r#"
|
||||
extern "C" __global__ void sqrt_k(float* out, const float* in, long long n) {
|
||||
long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= n) return;
|
||||
out[i] = sqrtf(in[i]);
|
||||
}
|
||||
"#,
|
||||
"sqrt_k",
|
||||
);
|
||||
let recip_k = compile(
|
||||
r#"
|
||||
extern "C" __global__ void recip_k(float* out, const float* in, long long n) {
|
||||
long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= n) return;
|
||||
out[i] = 1.0f / in[i];
|
||||
}
|
||||
"#,
|
||||
"recip_k",
|
||||
);
|
||||
let fused_k = compile(
|
||||
r#"
|
||||
extern "C" __global__ void fused_k(float* out, const float* in, long long n) {
|
||||
long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= n) return;
|
||||
float v = in[i];
|
||||
v = sqrtf(v);
|
||||
v = 1.0f / v;
|
||||
out[i] = v;
|
||||
}
|
||||
"#,
|
||||
"fused_k",
|
||||
);
|
||||
|
||||
let cfg = LaunchConfig::for_num_elems(N as u32);
|
||||
let n_arg: i64 = N as i64;
|
||||
|
||||
let launch_unfused = |d_out: &mut cudarc::driver::CudaSlice<f32>,
|
||||
d_scratch: &mut cudarc::driver::CudaSlice<f32>| {
|
||||
let mut b = stream.launch_builder(&sqrt_k);
|
||||
b.arg(&mut *d_scratch).arg(&d_in).arg(&n_arg);
|
||||
unsafe { b.launch(cfg) }.unwrap();
|
||||
let mut b = stream.launch_builder(&recip_k);
|
||||
b.arg(d_out).arg(&*d_scratch).arg(&n_arg);
|
||||
unsafe { b.launch(cfg) }.unwrap();
|
||||
};
|
||||
let launch_fused = |d_out: &mut cudarc::driver::CudaSlice<f32>| {
|
||||
let mut b = stream.launch_builder(&fused_k);
|
||||
b.arg(d_out).arg(&d_in).arg(&n_arg);
|
||||
unsafe { b.launch(cfg) }.unwrap();
|
||||
};
|
||||
|
||||
// Warmup
|
||||
for _ in 0..WARMUP {
|
||||
launch_unfused(&mut d_out, &mut d_scratch);
|
||||
launch_fused(&mut d_out);
|
||||
}
|
||||
stream.synchronize().unwrap();
|
||||
|
||||
let start = ctx.new_event(None).unwrap();
|
||||
let end = ctx.new_event(None).unwrap();
|
||||
|
||||
// Time unfused
|
||||
start.record(&stream).unwrap();
|
||||
for _ in 0..TRIALS {
|
||||
launch_unfused(&mut d_out, &mut d_scratch);
|
||||
}
|
||||
end.record(&stream).unwrap();
|
||||
end.synchronize().unwrap();
|
||||
let unfused_total_ms = start.elapsed_ms(&end).unwrap();
|
||||
|
||||
// Time fused
|
||||
start.record(&stream).unwrap();
|
||||
for _ in 0..TRIALS {
|
||||
launch_fused(&mut d_out);
|
||||
}
|
||||
end.record(&stream).unwrap();
|
||||
end.synchronize().unwrap();
|
||||
let fused_total_ms = start.elapsed_ms(&end).unwrap();
|
||||
|
||||
let unfused_us = unfused_total_ms as f64 * 1_000.0 / TRIALS as f64;
|
||||
let fused_us = fused_total_ms as f64 * 1_000.0 / TRIALS as f64;
|
||||
let speedup = unfused_us / fused_us;
|
||||
|
||||
println!(
|
||||
"\n[fusion microbench, N={N}, trials={TRIALS}]\n\
|
||||
unfused (sqrt_k; recip_k): {unfused_us:8.3} us/iter ({unfused_total_ms:.2} ms total)\n\
|
||||
fused (sqrtf; 1.0f/): {fused_us:8.3} us/iter ({fused_total_ms:.2} ms total)\n\
|
||||
speedup: {speedup:.2}x"
|
||||
);
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Binary-inclusive fusion tests (marker-based FusionStart / FusionEnd scheme).
|
||||
//
|
||||
// Detects fused regions by walking backward from each `FusionEnd`-tagged LLIR
|
||||
// node through `Direction::Incoming` edges until a `FusionStart` is reached.
|
||||
// The walker stops at FusionStarts (they mark the external-input boundary of
|
||||
// the region). A region's summary is: the sorted set of internal op names,
|
||||
// the count of distinct FusionStart nodes reached, and the count of FusionEnd
|
||||
// nodes (invariant: always 1 per region).
|
||||
// =========================================================================
|
||||
|
||||
/// A single fused region extracted from the LLIR graph after egglog.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
struct FusedRegion {
|
||||
/// Sorted internal op `kernel_name()`s, excluding the `FusionStart` /
|
||||
/// `FusionEnd` markers. Sorted so DAG traversal order doesn't produce
|
||||
/// spurious "distinct" regions.
|
||||
internal_ops_sorted: Vec<String>,
|
||||
/// Number of distinct `FusionStart` nodes reached by the walk. Per design
|
||||
/// this equals the number of distinct external input tensors.
|
||||
start_count: usize,
|
||||
/// Number of `FusionEnd` nodes in the region. Per design this is always 1.
|
||||
end_count: usize,
|
||||
}
|
||||
|
||||
/// Helper: collect every distinct fused region reachable across many random
|
||||
/// extractions of the search space.
|
||||
fn extract_all_fused_regions(cx: &mut Graph) -> Vec<FusedRegion> {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let egraph = cx.egraph().expect("egraph not built");
|
||||
let ops = cx.egglog_ops().expect("ops not built");
|
||||
let custom_ops = &cx.custom_ops;
|
||||
|
||||
let mut seen: Vec<FusedRegion> = Vec::new();
|
||||
// 200 samples: the random extractor picks one e-node per e-class per
|
||||
// call, and the fully-fused diamond form lives in an e-class with
|
||||
// many equivalent forms. 50 was flaky; 200 is reliably stable and
|
||||
// each sample is cheap (~100 µs).
|
||||
for _ in 0..200 {
|
||||
let choices = random_initial_choice(egraph, &mut rand::rng());
|
||||
let mut list_cache = Default::default();
|
||||
let mut expr_cache = Default::default();
|
||||
let llir = egglog_to_llir(
|
||||
egraph,
|
||||
choices,
|
||||
ops,
|
||||
custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
|
||||
let name_of = |idx: NodeIndex| -> Option<String> {
|
||||
llir.node_weight(idx).and_then(|op| {
|
||||
op.to_dialect::<dyn KernelOp>()
|
||||
.map(|k| k.kernel_name().to_string())
|
||||
})
|
||||
};
|
||||
|
||||
let end_nodes: Vec<NodeIndex> = llir
|
||||
.node_indices()
|
||||
.filter(|&idx| name_of(idx).as_deref() == Some("FusionEnd"))
|
||||
.collect();
|
||||
|
||||
for end in end_nodes {
|
||||
let mut internal: Vec<String> = Vec::new();
|
||||
// Count distinct external input *tensors*, not distinct FusionStart
|
||||
// node indices. Egglog rule firings can emit multiple FusionStart
|
||||
// enodes that all wrap the same source tensor (e.g. when the same
|
||||
// `a` is consumed at two sites inside the fused region, each
|
||||
// pair-fuse / grow firing mints its own FusionStart). Those are
|
||||
// logically one FusionStart per the design invariant
|
||||
// ("N = number of distinct external input tensors").
|
||||
let mut start_sources: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
let mut visited: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
visited.insert(end);
|
||||
let mut stack = vec![end];
|
||||
|
||||
// Resolve chains of nested FusionStart wrappers (cascade artifact)
|
||||
// to the real external source. A FusionStart whose incoming neighbor
|
||||
// is itself a FusionStart — or a FusionEnd whose region is fully
|
||||
// inside ours — is a cascade layer, not a new external tensor.
|
||||
let resolve_source = |mut n: NodeIndex| -> NodeIndex {
|
||||
loop {
|
||||
match name_of(n).as_deref() {
|
||||
Some("FusionStart") | Some("FusionEnd") => {
|
||||
let mut inc = llir.neighbors_directed(n, petgraph::Direction::Incoming);
|
||||
match inc.next() {
|
||||
Some(p) => n = p,
|
||||
None => return n,
|
||||
}
|
||||
}
|
||||
_ => return n,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
while let Some(node) = stack.pop() {
|
||||
for pred in llir.neighbors_directed(node, petgraph::Direction::Incoming) {
|
||||
if !visited.insert(pred) {
|
||||
continue;
|
||||
}
|
||||
match name_of(pred).as_deref() {
|
||||
Some("FusionStart") => {
|
||||
// If this FS's predecessor is itself a FE (or a
|
||||
// chain of FS/FE wrappers that eventually hits a
|
||||
// non-marker op inside the region), the FS is a
|
||||
// cascade artifact, not a real external boundary.
|
||||
// Walk past it and its upstream FE into the same
|
||||
// region. Otherwise treat the predecessor as the
|
||||
// external source tensor — which may be a KernelOp
|
||||
// *or* a non-KernelOp (HLIR loadable) node, so we
|
||||
// can't gate counting on `name_of` being `Some`.
|
||||
let mut inc =
|
||||
llir.neighbors_directed(pred, petgraph::Direction::Incoming);
|
||||
match inc.next() {
|
||||
Some(src_node)
|
||||
if name_of(src_node).as_deref() == Some("FusionEnd") =>
|
||||
{
|
||||
// Merge adjacent regions — treat the FS/FE
|
||||
// pair as internal; walk past the upstream
|
||||
// FE into its region.
|
||||
visited.insert(src_node);
|
||||
stack.push(src_node);
|
||||
}
|
||||
Some(src_node) => {
|
||||
start_sources.insert(resolve_source(src_node));
|
||||
}
|
||||
None => {
|
||||
// FS with no predecessor — degenerate.
|
||||
}
|
||||
}
|
||||
}
|
||||
Some("FusionEnd") => {
|
||||
// Transparent: inner FusionEnds are cascade-wart
|
||||
// artifacts from grow rules re-firing and creating
|
||||
// nested `FE(Op(FE(...)))` wrappers. They don't
|
||||
// represent real work or a real boundary — walk
|
||||
// past them and do not count them as internal ops.
|
||||
stack.push(pred);
|
||||
}
|
||||
Some(other) => {
|
||||
internal.push(other.to_string());
|
||||
stack.push(pred);
|
||||
}
|
||||
None => {
|
||||
// Non-KernelOp predecessor (shouldn't appear inside a
|
||||
// fused region under the design). Stop walking this path.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal.sort();
|
||||
// Skip singleton regions: every elementwise op has a seeded
|
||||
// `FE(Op(FS(...)))` form, so random extraction will surface
|
||||
// many one-op regions that are equivalent to not fusing. We
|
||||
// only care about regions that represent real multi-op fusion.
|
||||
if internal.len() < 2 {
|
||||
continue;
|
||||
}
|
||||
let region = FusedRegion {
|
||||
internal_ops_sorted: internal,
|
||||
start_count: start_sources.len(),
|
||||
end_count: 1,
|
||||
};
|
||||
if !seen.contains(®ion) {
|
||||
seen.push(region);
|
||||
}
|
||||
}
|
||||
}
|
||||
seen
|
||||
}
|
||||
|
||||
fn sorted_names(items: &[&str]) -> Vec<String> {
|
||||
let mut v: Vec<String> = items.iter().map(|s| (*s).to_string()).collect();
|
||||
v.sort();
|
||||
v
|
||||
}
|
||||
|
||||
// ---- Structural tests: the expected fused shape is reachable ----
|
||||
|
||||
#[test]
|
||||
fn test_single_binary_does_not_fuse_alone() {
|
||||
// A lone elementwise op gets a seeded singleton region by design; we
|
||||
// filter singletons out in `extract_all_fused_regions`. What this test
|
||||
// asserts is that no *multi-op* region appears for a standalone binary
|
||||
// — nothing to grow into.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let _c = (a + b).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
assert!(
|
||||
regions.is_empty(),
|
||||
"a solo binary op should not form a multi-op fused region, but got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chain_of_binaries_fuses() {
|
||||
// `(a + b) * c`: three external inputs collapse into one region with
|
||||
// internal [Add, Mul] and 3 FusionStarts.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let c = cx.tensor(8);
|
||||
let _d = ((a + b) * c).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedMul"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 3),
|
||||
"expected a fused region of {expected:?} with 3 FusionStarts, got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_then_unary_fuses() {
|
||||
// `sin(a + b)`: binary feeds a unary inside one fused region.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let _c = (a + b).sin().output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedSin"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 2),
|
||||
"expected a fused region of {expected:?} with 2 FusionStarts, got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unary_then_binary_fuses() {
|
||||
// `sin(a) + b`: unary feeds a binary inside one fused region.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let _c = (a.sin() + b).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedSin"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 2),
|
||||
"expected a fused region of {expected:?} with 2 FusionStarts, got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diamond_dag_fuses() {
|
||||
// The canonical diamond-DAG example agreed with the user:
|
||||
// t = a + b; u = exp2(t); v = sin(t); w = u * a; out = w + v
|
||||
// `a` is reused (feeds outer Add and Mul) and `t` is reused (feeds Exp2 and
|
||||
// Sin). Expected: one fused region with internal ops [Add, Add, Exp2, Mul,
|
||||
// Sin], 2 FusionStarts (distinct tensors a, b), 1 FusionEnd.
|
||||
// We use exp2 rather than exp because the frontend's exp() desugars to
|
||||
// Mul(x, LOG2E).exp2(), which would add a constant input and a Mul op and
|
||||
// obscure the diamond topology this test is checking.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let t = a + b;
|
||||
let u = t.exp2();
|
||||
let v = t.sin();
|
||||
let w = u * a;
|
||||
let _out = (w + v).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedAdd", "FusedExp2", "FusedMul", "FusedSin"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 2 && r.end_count == 1),
|
||||
"expected diamond DAG to fuse into one region with ops {expected:?}, \
|
||||
2 FusionStarts, 1 FusionEnd. Got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
// ---- Negative tests: fusion must NOT happen across these blockers ----
|
||||
|
||||
#[test]
|
||||
fn test_reduction_blocks_binary_fusion() {
|
||||
// A reduction between a binary and anything downstream is not elementwise,
|
||||
// so Add and SumReduce must never appear in the same fused region.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((4, 4));
|
||||
let b = cx.tensor((4, 4));
|
||||
let _c = (a + b).sum(1).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
for r in ®ions {
|
||||
let has_add = r.internal_ops_sorted.iter().any(|n| n == "FusedAdd");
|
||||
let has_sum = r.internal_ops_sorted.iter().any(|n| n == "SumReduce");
|
||||
assert!(
|
||||
!(has_add && has_sum),
|
||||
"FusedAdd and SumReduce must not share a fused region, but got: {r:#?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stride_mismatch_blocks_binary_fusion() {
|
||||
// A permute gives `b` a non-contiguous view whose strides do not match `a`'s,
|
||||
// so the binary fusion rule's stride-compatibility check must prevent the
|
||||
// Add from being absorbed into any fused region.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((3, 4));
|
||||
let b = cx.tensor((4, 3));
|
||||
let _c = (a + b.permute((1, 0))).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
for r in ®ions {
|
||||
assert!(
|
||||
!r.internal_ops_sorted.iter().any(|n| n == "FusedAdd"),
|
||||
"permuted binary must not fuse into a region, but found: {r:#?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Numerical parity tests: fused output matches candle reference ----
|
||||
|
||||
#[test]
|
||||
fn test_simple_binary_fusion_preserves_output() {
|
||||
// End-to-end numerical check: `a + b` on GPU matches candle's add across
|
||||
// all reachable genomes (fused or unfused) via test_binary_cuda's fuzzer.
|
||||
let seed = 0xADDBEEFu64;
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
let tol = eps * TOLERANCE_SAFETY_FACTOR;
|
||||
test_binary_cuda::<f32>(
|
||||
16,
|
||||
16,
|
||||
|a, b| a + b,
|
||||
|a, b| (a + b).unwrap(),
|
||||
|n, s| random_f32_vec(n, s, 0.0, 1.0),
|
||||
|n, s| random_f32_vec(n, s, 0.0, 1.0),
|
||||
seed,
|
||||
tol,
|
||||
tol,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diamond_dag_preserves_output() {
|
||||
// Numerical parity for the diamond DAG: `(exp(a+b) * a) + sin(a+b)`
|
||||
// matches candle's equivalent across fused and unfused genomes.
|
||||
// Inputs are drawn from [-1, 1] so exp() doesn't overflow.
|
||||
let seed = 0xD1A_0D1Au64;
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
// Five-op chain with exp + sin: allow ~5x safety to absorb accumulated
|
||||
// rounding vs candle's kernels.
|
||||
let tol = eps * TOLERANCE_SAFETY_FACTOR * 5.0;
|
||||
test_binary_cuda::<f32>(
|
||||
16,
|
||||
16,
|
||||
|a, b| {
|
||||
let t = a + b;
|
||||
let u = t.exp();
|
||||
let v = t.sin();
|
||||
let w = u * a;
|
||||
w + v
|
||||
},
|
||||
|a, b| {
|
||||
let t = (&a + &b).unwrap();
|
||||
let u = t.exp().unwrap();
|
||||
let v = t.sin().unwrap();
|
||||
let w = (&u * &a).unwrap();
|
||||
(&w + &v).unwrap()
|
||||
},
|
||||
|n, s| random_f32_vec(n, s, -1.0, 1.0),
|
||||
|n, s| random_f32_vec(n, s, -1.0, 1.0),
|
||||
seed,
|
||||
tol,
|
||||
tol,
|
||||
);
|
||||
}
|
||||
|
||||
// ---- Marker invariant tests ----
|
||||
|
||||
#[test]
|
||||
fn test_fused_region_has_exactly_one_end() {
|
||||
// Design invariant: a fused region always has exactly one FusionEnd.
|
||||
// Uses the diamond DAG so there's real fan-in/out inside the region.
|
||||
// See test_diamond_dag_fuses for why we use exp2 directly.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let t = a + b;
|
||||
let u = t.exp2();
|
||||
let v = t.sin();
|
||||
let w = u * a;
|
||||
let _out = (w + v).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedAdd", "FusedExp2", "FusedMul", "FusedSin"]);
|
||||
let full = regions
|
||||
.iter()
|
||||
.find(|r| r.internal_ops_sorted == expected)
|
||||
.expect("expected at least one extraction to produce the full 5-op diamond region");
|
||||
assert_eq!(
|
||||
full.end_count, 1,
|
||||
"fused region must have exactly one FusionEnd, got {}",
|
||||
full.end_count
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fused_region_starts_match_distinct_external_tensors() {
|
||||
// Design invariant: FusionStart count == number of distinct external input
|
||||
// tensors, NOT number of edges crossing the boundary. In the diamond DAG
|
||||
// `a` is consumed inside the region by two ops (outer Add + Mul), so a
|
||||
// per-edge counting scheme would give 3; the correct per-distinct-tensor
|
||||
// count is 2 ({a, b}).
|
||||
// See test_diamond_dag_fuses for why we use exp2 directly.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let t = a + b;
|
||||
let u = t.exp2();
|
||||
let v = t.sin();
|
||||
let w = u * a;
|
||||
let _out = (w + v).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedAdd", "FusedExp2", "FusedMul", "FusedSin"]);
|
||||
// Multiple 5-op extractions are reachable: the merge-FE-FE rule fires
|
||||
// across paths that may have minted distinct FS enodes for the shared
|
||||
// tensor `a` at separate sites. The design invariant is that *some*
|
||||
// extraction collapses those into the deduped form (one FS per distinct
|
||||
// tensor → 2 FS for {a, b}); we don't require every random sample to.
|
||||
let matching: Vec<&FusedRegion> = regions
|
||||
.iter()
|
||||
.filter(|r| r.internal_ops_sorted == expected)
|
||||
.collect();
|
||||
assert!(
|
||||
!matching.is_empty(),
|
||||
"expected at least one extraction to produce the full 5-op diamond region, \
|
||||
got: {regions:#?}"
|
||||
);
|
||||
assert!(
|
||||
matching
|
||||
.iter()
|
||||
.any(|r| r.start_count == 2 && r.end_count == 1),
|
||||
"expected at least one 5-op diamond extraction with FusionStart count == 2 \
|
||||
(one per distinct external tensor) and FusionEnd count == 1; got: {matching:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
// ---- Targeted rule-family tests (one per family / orientation) ----
|
||||
//
|
||||
// The structural and diamond tests above hit several rule families at once.
|
||||
// These narrow tests pin each rule family / orientation independently so a
|
||||
// regression in one rule shows up as a single failing test rather than a
|
||||
// confusing diamond mismatch.
|
||||
|
||||
#[test]
|
||||
fn test_pair_fuse_unary_unary_marker_form() {
|
||||
// Pair-fuse U→U: `a.sin().sqrt()` should be reachable as a marker-bracketed
|
||||
// region containing FusedSin and FusedSqrt (with one FusionStart for `a`).
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let _b = a.sin().sqrt().output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedSin", "FusedSqrt"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 1 && r.end_count == 1),
|
||||
"expected marker region of {expected:?} with 1 FusionStart, got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pair_fuse_unary_to_binary_rhs() {
|
||||
// Pair-fuse U→B (RHS variant): `a + b.sin()`. The unary is on the
|
||||
// binary's B input, so the rule's RHS-orientation version is what fires.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let _c = (a + b.sin()).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedSin"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 2),
|
||||
"expected a fused region of {expected:?} with 2 FusionStarts (RHS-side unary), \
|
||||
got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pair_fuse_binary_to_binary_rhs() {
|
||||
// Pair-fuse B→B (RHS variant): `c * (a + b)`. The inner binary feeds the
|
||||
// outer binary's B input, exercising the mirror direction of the rule
|
||||
// covered by test_chain_of_binaries_fuses.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let c = cx.tensor(8);
|
||||
let _d = (c * (a + b)).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedMul"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 3),
|
||||
"expected a fused region of {expected:?} with 3 FusionStarts (RHS-side inner binary), \
|
||||
got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_grow_fe_to_binary_rhs() {
|
||||
// Grow FE→B (RHS variant): `c + (a.sin() + b)`. Once the inner
|
||||
// `a.sin() + b` is fused, the outer `+ c` consumes that FE on its B input
|
||||
// (because we wrote `c + (...)` — `c` is on LHS, FE on RHS), exercising
|
||||
// grow-FE-B-rhs to absorb the outer Add into the same region.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let c = cx.tensor(8);
|
||||
let _d = (c + (a.sin() + b)).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedAdd", "FusedSin"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 3),
|
||||
"expected a 3-op fused region of {expected:?} with 3 FusionStarts (grow into RHS), \
|
||||
got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_two_regions_at_outer_binary() {
|
||||
// Merge: `(sin(a) + b) + (sqrt(c) + d)`. Each side independently pair-fuses
|
||||
// U→B on its own (the unary gives the inner Add a fusion partner that
|
||||
// doesn't pull in the outer Add), so both sides become FEs. The outer Add
|
||||
// then fires merge-FE-FE-Add to collapse them into a single region.
|
||||
// Without the unaries, `(a+b) + (c+d)` would only ever pair-fuse one
|
||||
// inner Add at a time with the outer Add — merge wouldn't have two FEs to
|
||||
// combine because the inner Adds never become singleton FEs on their own.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let c = cx.tensor(8);
|
||||
let d = cx.tensor(8);
|
||||
let _e = ((a.sin() + b) + (c.sqrt() + d)).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedAdd", "FusedAdd", "FusedSin", "FusedSqrt"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 4),
|
||||
"expected a 5-op merged region (two pair-fused sides combined at outer Add) with \
|
||||
4 FusionStarts, got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
/// Microbench: time three unfused kernels (`add_k` → `sin_k` → `sqrt_k`)
|
||||
/// vs one fused kernel (`(a + b).sin().sqrt()` in a single launch) on a
|
||||
/// fixed-size input, using CUDA events for device-side timing. Mirrors
|
||||
/// the existing sqrt→recip bench but on the binary-inclusive 3-op DAG
|
||||
/// PR2's region codegen targets.
|
||||
///
|
||||
/// Ignored by default — run with
|
||||
/// `cargo test -p luminal_cuda_lite -- --ignored bench_fused_region_vs_unfused_3op --nocapture`.
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn bench_fused_region_vs_unfused_3op() {
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use cudarc::driver::{CudaContext, LaunchConfig, PushKernelArg};
|
||||
|
||||
const N: usize = 1 << 20; // 1M elements
|
||||
const WARMUP: usize = 100;
|
||||
const TRIALS: usize = 2000;
|
||||
|
||||
let ctx = match CudaContext::new(0) {
|
||||
Ok(c) => c,
|
||||
Err(_) => return, // no GPU available, skip
|
||||
};
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
// Inputs in (0, 1] keep `sin` < 1 and `sqrt` well-defined post-add.
|
||||
let host_a: Vec<f32> = (0..N)
|
||||
.map(|i| (i as f32 + 1.0) / (N as f32) * 0.5)
|
||||
.collect();
|
||||
let host_b: Vec<f32> = (0..N)
|
||||
.map(|i| (i as f32 + 1.0) / (N as f32) * 0.5)
|
||||
.collect();
|
||||
let d_a = stream.clone_htod(&host_a).unwrap();
|
||||
let d_b = stream.clone_htod(&host_b).unwrap();
|
||||
let mut d_scratch1 = stream.alloc_zeros::<f32>(N).unwrap();
|
||||
let mut d_scratch2 = stream.alloc_zeros::<f32>(N).unwrap();
|
||||
let mut d_out = stream.alloc_zeros::<f32>(N).unwrap();
|
||||
|
||||
let compile = |src: &str, name: &str| {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), src).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
module.load_function(name).unwrap()
|
||||
};
|
||||
|
||||
let add_k = compile(
|
||||
r#"
|
||||
extern "C" __global__ void add_k(float* out, const float* a, const float* b, long long n) {
|
||||
long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= n) return;
|
||||
out[i] = a[i] + b[i];
|
||||
}
|
||||
"#,
|
||||
"add_k",
|
||||
);
|
||||
let sin_k = compile(
|
||||
r#"
|
||||
extern "C" __global__ void sin_k(float* out, const float* in, long long n) {
|
||||
long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= n) return;
|
||||
out[i] = sinf(in[i]);
|
||||
}
|
||||
"#,
|
||||
"sin_k",
|
||||
);
|
||||
let sqrt_k = compile(
|
||||
r#"
|
||||
extern "C" __global__ void sqrt_k(float* out, const float* in, long long n) {
|
||||
long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= n) return;
|
||||
out[i] = sqrtf(in[i]);
|
||||
}
|
||||
"#,
|
||||
"sqrt_k",
|
||||
);
|
||||
let fused_k = compile(
|
||||
r#"
|
||||
extern "C" __global__ void fused_k(float* out, const float* a, const float* b, long long n) {
|
||||
long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= n) return;
|
||||
float v = a[i] + b[i];
|
||||
v = sinf(v);
|
||||
v = sqrtf(v);
|
||||
out[i] = v;
|
||||
}
|
||||
"#,
|
||||
"fused_k",
|
||||
);
|
||||
|
||||
let cfg = LaunchConfig::for_num_elems(N as u32);
|
||||
let n_arg: i64 = N as i64;
|
||||
|
||||
let launch_unfused =
|
||||
|d_out: &mut cudarc::driver::CudaSlice<f32>,
|
||||
d_scratch1: &mut cudarc::driver::CudaSlice<f32>,
|
||||
d_scratch2: &mut cudarc::driver::CudaSlice<f32>| {
|
||||
let mut b = stream.launch_builder(&add_k);
|
||||
b.arg(&mut *d_scratch1).arg(&d_a).arg(&d_b).arg(&n_arg);
|
||||
unsafe { b.launch(cfg) }.unwrap();
|
||||
let mut b = stream.launch_builder(&sin_k);
|
||||
b.arg(&mut *d_scratch2).arg(&*d_scratch1).arg(&n_arg);
|
||||
unsafe { b.launch(cfg) }.unwrap();
|
||||
let mut b = stream.launch_builder(&sqrt_k);
|
||||
b.arg(d_out).arg(&*d_scratch2).arg(&n_arg);
|
||||
unsafe { b.launch(cfg) }.unwrap();
|
||||
};
|
||||
let launch_fused = |d_out: &mut cudarc::driver::CudaSlice<f32>| {
|
||||
let mut b = stream.launch_builder(&fused_k);
|
||||
b.arg(d_out).arg(&d_a).arg(&d_b).arg(&n_arg);
|
||||
unsafe { b.launch(cfg) }.unwrap();
|
||||
};
|
||||
|
||||
// Warmup
|
||||
for _ in 0..WARMUP {
|
||||
launch_unfused(&mut d_out, &mut d_scratch1, &mut d_scratch2);
|
||||
launch_fused(&mut d_out);
|
||||
}
|
||||
stream.synchronize().unwrap();
|
||||
|
||||
// Host-side wall-clock timing: synchronize before/after each batch so the
|
||||
// measured interval covers exactly the GPU work for `TRIALS` iterations.
|
||||
// (CUDA event-based timing is the more precise option in principle, but
|
||||
// `event.elapsed_ms` on this driver/cudarc combo errors with
|
||||
// CUDA_ERROR_INVALID_HANDLE — see bench_fused_vs_unfused_sqrt_recip
|
||||
// above which fails the same way. Wall-clock is reliable here.)
|
||||
let unfused_start = std::time::Instant::now();
|
||||
for _ in 0..TRIALS {
|
||||
launch_unfused(&mut d_out, &mut d_scratch1, &mut d_scratch2);
|
||||
}
|
||||
stream.synchronize().unwrap();
|
||||
let unfused_total_ms = unfused_start.elapsed().as_secs_f64() * 1_000.0;
|
||||
|
||||
let fused_start = std::time::Instant::now();
|
||||
for _ in 0..TRIALS {
|
||||
launch_fused(&mut d_out);
|
||||
}
|
||||
stream.synchronize().unwrap();
|
||||
let fused_total_ms = fused_start.elapsed().as_secs_f64() * 1_000.0;
|
||||
|
||||
let unfused_us = unfused_total_ms * 1_000.0 / TRIALS as f64;
|
||||
let fused_us = fused_total_ms * 1_000.0 / TRIALS as f64;
|
||||
let speedup = unfused_us / fused_us;
|
||||
|
||||
println!(
|
||||
"\n[fusion microbench, (a+b).sin().sqrt(), N={N}, trials={TRIALS}]\n\
|
||||
unfused (add_k; sin_k; sqrt_k): {unfused_us:8.3} us/iter ({unfused_total_ms:.2} ms total)\n\
|
||||
fused (one kernel): {fused_us:8.3} us/iter ({fused_total_ms:.2} ms total)\n\
|
||||
speedup: {speedup:.2}x"
|
||||
);
|
||||
}
|
||||
@@ -5,10 +5,14 @@ mod bucket_tests;
|
||||
#[cfg(test)]
|
||||
mod consumed_buffer_tests;
|
||||
#[cfg(test)]
|
||||
mod fusion;
|
||||
#[cfg(test)]
|
||||
mod model_fuzz;
|
||||
#[cfg(test)]
|
||||
mod op_functional_tests;
|
||||
#[cfg(test)]
|
||||
mod performance_tests;
|
||||
#[cfg(test)]
|
||||
mod qwen3_moe_rewrite;
|
||||
#[cfg(test)]
|
||||
mod transformer;
|
||||
|
||||
314
crates/luminal_cuda_lite/src/tests/qwen3_moe_rewrite.rs
Normal file
314
crates/luminal_cuda_lite/src/tests/qwen3_moe_rewrite.rs
Normal file
@@ -0,0 +1,314 @@
|
||||
use half::bf16;
|
||||
use luminal::{dtype::DType, prelude::*, shape::Expression};
|
||||
|
||||
use super::utilities::{assert_close, get_cuda_stream, random_f32_vec};
|
||||
use crate::{
|
||||
host::{
|
||||
HostOp,
|
||||
moe::{GLUMoE, GLUMoEMode},
|
||||
},
|
||||
runtime::CudaRuntime,
|
||||
};
|
||||
|
||||
const SEQ: usize = 2;
|
||||
const HIDDEN: usize = 16;
|
||||
const NUM_EXPERTS: usize = 8;
|
||||
const TOP_K: usize = 2;
|
||||
const MOE_INTERMEDIATE: usize = 6;
|
||||
const RMS_NORM_EPS: f32 = 1e-6;
|
||||
|
||||
struct QwenMoeGraph {
|
||||
graph: Graph,
|
||||
x: GraphTensor,
|
||||
router: GraphTensor,
|
||||
gate_up_weights: GraphTensor,
|
||||
down_weights: GraphTensor,
|
||||
output: GraphTensor,
|
||||
}
|
||||
|
||||
struct GemmaMoeGraph {
|
||||
graph: Graph,
|
||||
router_input: GraphTensor,
|
||||
expert_input: GraphTensor,
|
||||
router_scale: GraphTensor,
|
||||
router_proj: GraphTensor,
|
||||
per_expert_scale: GraphTensor,
|
||||
gate_up_weights: GraphTensor,
|
||||
down_weights: GraphTensor,
|
||||
output: GraphTensor,
|
||||
}
|
||||
|
||||
fn build_qwen_moe_graph() -> QwenMoeGraph {
|
||||
let mut cx = Graph::default();
|
||||
let x = cx.tensor(('s', HIDDEN));
|
||||
let router = cx.tensor((NUM_EXPERTS, HIDDEN));
|
||||
let gate_up_weights = cx
|
||||
.tensor((NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = x.dims().len();
|
||||
let e_dim = *router.dims().first().unwrap();
|
||||
let k_expr = Expression::from(TOP_K);
|
||||
|
||||
let routing_weights = x.matmul(router.t()).softmax(n - 1);
|
||||
let top_k_indices = routing_weights.topk_indexes(TOP_K, n - 1);
|
||||
|
||||
let row_offsets = x
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
|
||||
let gate_up_gathered = gather_experts(x, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = x.expand_dim(n - 1, TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gate.silu() * up;
|
||||
|
||||
let down_gathered = gather_experts(x, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let output = (down_out * top_k_values.unsqueeze(top_k_values.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
|
||||
QwenMoeGraph {
|
||||
graph: cx,
|
||||
x,
|
||||
router,
|
||||
gate_up_weights,
|
||||
down_weights,
|
||||
output,
|
||||
}
|
||||
}
|
||||
|
||||
fn build_gemma_moe_graph() -> GemmaMoeGraph {
|
||||
let mut cx = Graph::default();
|
||||
let router_input = cx.tensor(('s', HIDDEN));
|
||||
let expert_input = cx.tensor(('s', HIDDEN));
|
||||
let router_scale = cx.tensor(HIDDEN);
|
||||
let router_proj = cx.tensor((NUM_EXPERTS, HIDDEN));
|
||||
let per_expert_scale = cx.tensor(NUM_EXPERTS);
|
||||
let gate_up_weights = cx
|
||||
.tensor((NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = router_input.dims().len();
|
||||
let e_dim = *router_proj.dims().first().unwrap();
|
||||
let k_expr = Expression::from(TOP_K);
|
||||
|
||||
let router_hidden = router_input.std_norm(n - 1, RMS_NORM_EPS)
|
||||
* router_scale.expand_lhs(&router_input.dims()[..n - 1])
|
||||
* (HIDDEN as f32).sqrt().recip();
|
||||
let routing_weights = router_hidden.matmul(router_proj.t()).softmax(n - 1);
|
||||
|
||||
let top_k_indices = routing_weights.topk_indexes(TOP_K, n - 1);
|
||||
let row_offsets = router_input
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
let top_k_norm = top_k_values.sum(n - 1).expand_dim(n - 1, TOP_K);
|
||||
let top_k_weights = (top_k_values / top_k_norm) * per_expert_scale.gather(top_k_indices);
|
||||
|
||||
let gate_up_gathered =
|
||||
gather_experts(expert_input, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = expert_input.expand_dim(n - 1, TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gemma_gelu(gate) * up;
|
||||
|
||||
let down_gathered = gather_experts(expert_input, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let output = (down_out * top_k_weights.unsqueeze(top_k_weights.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
|
||||
GemmaMoeGraph {
|
||||
graph: cx,
|
||||
router_input,
|
||||
expert_input,
|
||||
router_scale,
|
||||
router_proj,
|
||||
per_expert_scale,
|
||||
gate_up_weights,
|
||||
down_weights,
|
||||
output,
|
||||
}
|
||||
}
|
||||
|
||||
fn gather_experts(
|
||||
graph_source: GraphTensor,
|
||||
top_k_indices: GraphTensor,
|
||||
weights: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let (_, d1, d2) = weights.dims3();
|
||||
let io = d1 * d2;
|
||||
let base = top_k_indices * io;
|
||||
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
|
||||
let n_base = base.dims().len();
|
||||
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
|
||||
let mut exp_within = within;
|
||||
for (axis, dim) in base.dims().iter().enumerate() {
|
||||
exp_within = exp_within.expand_dim(axis, *dim);
|
||||
}
|
||||
let expert_flat_idx = exp_base + exp_within;
|
||||
weights.gather(expert_flat_idx)
|
||||
}
|
||||
|
||||
#[allow(clippy::excessive_precision)]
|
||||
fn gemma_gelu(x: GraphTensor) -> GraphTensor {
|
||||
let scaled = 1.5957691216 * x * (1. + 0.044715 * x * x);
|
||||
x * scaled.sigmoid()
|
||||
}
|
||||
|
||||
fn glumoe_modes(rt: &CudaRuntime) -> Vec<GLUMoEMode> {
|
||||
rt.llir_graph()
|
||||
.node_weights()
|
||||
.filter_map(|node| {
|
||||
let op = node.to_dialect::<dyn HostOp>()?;
|
||||
op.as_any()
|
||||
.downcast_ref::<GLUMoE>()
|
||||
.map(|glumoe| glumoe.mode)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn run_qwen_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return (vec![], vec![]);
|
||||
};
|
||||
|
||||
let mut model = build_qwen_moe_graph();
|
||||
model.graph.set_dim('s', SEQ);
|
||||
if use_glumoe {
|
||||
model.graph.build_search_space::<CudaRuntime>();
|
||||
} else {
|
||||
model
|
||||
.graph
|
||||
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>();
|
||||
}
|
||||
|
||||
let x_data = random_f32_vec(SEQ * HIDDEN, 11, -0.15, 0.15);
|
||||
let router_data = random_f32_vec(NUM_EXPERTS * HIDDEN, 12, -0.2, 0.2);
|
||||
let gate_up_data = random_f32_vec(NUM_EXPERTS * MOE_INTERMEDIATE * 2 * HIDDEN, 13, -0.1, 0.1)
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect::<Vec<_>>();
|
||||
let down_data = random_f32_vec(NUM_EXPERTS * HIDDEN * MOE_INTERMEDIATE, 14, -0.1, 0.1)
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(model.x, x_data);
|
||||
rt.set_data(model.router, router_data);
|
||||
rt.set_data(model.gate_up_weights, gate_up_data);
|
||||
rt.set_data(model.down_weights, down_data);
|
||||
rt = model.graph.search(rt, 10);
|
||||
rt.execute(&model.graph.dyn_map);
|
||||
|
||||
(rt.get_f32(model.output.id), glumoe_modes(&rt))
|
||||
}
|
||||
|
||||
fn run_gemma_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return (vec![], vec![]);
|
||||
};
|
||||
|
||||
let mut model = build_gemma_moe_graph();
|
||||
model.graph.set_dim('s', SEQ);
|
||||
if use_glumoe {
|
||||
model.graph.build_search_space::<CudaRuntime>();
|
||||
} else {
|
||||
model
|
||||
.graph
|
||||
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>();
|
||||
}
|
||||
|
||||
let router_input_data = random_f32_vec(SEQ * HIDDEN, 21, -0.15, 0.15);
|
||||
let expert_input_data = random_f32_vec(SEQ * HIDDEN, 22, -0.15, 0.15);
|
||||
let router_scale_data = random_f32_vec(HIDDEN, 23, 0.7, 1.3);
|
||||
let router_proj_data = random_f32_vec(NUM_EXPERTS * HIDDEN, 24, -0.2, 0.2);
|
||||
let per_expert_scale_data = random_f32_vec(NUM_EXPERTS, 25, 0.5, 1.5);
|
||||
let gate_up_data = random_f32_vec(NUM_EXPERTS * MOE_INTERMEDIATE * 2 * HIDDEN, 26, -0.1, 0.1)
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect::<Vec<_>>();
|
||||
let down_data = random_f32_vec(NUM_EXPERTS * HIDDEN * MOE_INTERMEDIATE, 27, -0.1, 0.1)
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(model.router_input, router_input_data);
|
||||
rt.set_data(model.expert_input, expert_input_data);
|
||||
rt.set_data(model.router_scale, router_scale_data);
|
||||
rt.set_data(model.router_proj, router_proj_data);
|
||||
rt.set_data(model.per_expert_scale, per_expert_scale_data);
|
||||
rt.set_data(model.gate_up_weights, gate_up_data);
|
||||
rt.set_data(model.down_weights, down_data);
|
||||
rt = model.graph.search(rt, 10);
|
||||
rt.execute(&model.graph.dyn_map);
|
||||
|
||||
(rt.get_f32(model.output.id), glumoe_modes(&rt))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_matches_qwen_swiglu_pattern() {
|
||||
let (_result, modes) = run_qwen_moe(true);
|
||||
if modes.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
assert_eq!(modes, vec![GLUMoEMode::SwiGLU]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_matches_gemma_gelu_pattern() {
|
||||
let (_result, modes) = run_gemma_moe(true);
|
||||
if modes.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
assert_eq!(modes, vec![GLUMoEMode::GemmaGELU]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_swiglu_matches_unfused_output() {
|
||||
let (expected, baseline_modes) = run_qwen_moe(false);
|
||||
if expected.is_empty() {
|
||||
return;
|
||||
}
|
||||
assert!(baseline_modes.is_empty());
|
||||
|
||||
let (actual, fused_modes) = run_qwen_moe(true);
|
||||
assert_eq!(fused_modes, vec![GLUMoEMode::SwiGLU]);
|
||||
assert_close(&actual, &expected, 3e-2, 3e-2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_gemma_gelu_matches_unfused_output() {
|
||||
let (expected, baseline_modes) = run_gemma_moe(false);
|
||||
if expected.is_empty() {
|
||||
return;
|
||||
}
|
||||
assert!(baseline_modes.is_empty());
|
||||
|
||||
let (actual, fused_modes) = run_gemma_moe(true);
|
||||
assert_eq!(fused_modes, vec![GLUMoEMode::GemmaGELU]);
|
||||
assert_close(&actual, &expected, 3e-2, 3e-2);
|
||||
}
|
||||
@@ -300,7 +300,7 @@ fn test_mini_transformer_two_layers() {
|
||||
let input = cx.tensor((SEQ, HIDDEN));
|
||||
let layer1 = MiniTransformerLayer::init(&mut cx);
|
||||
let layer2 = MiniTransformerLayer::init(&mut cx);
|
||||
let x = layer1.forward(input).graph_break();
|
||||
let x = layer1.forward(input);
|
||||
let out = layer2.forward(x).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
@@ -508,3 +508,32 @@ fn test_swiglu_mlp_cuda() {
|
||||
|
||||
assert_close(&result, &expected, 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
/// Body=1, trips=3 chain of scalar Muls plus a residual back to the
|
||||
/// chain's initial value. Auto-rolling sees this as a state-carrying loop
|
||||
/// with state at input position 0; the rolled HLIR must round-trip through
|
||||
/// egglog (rolled body Mul + LoopStart/LoopInput/LoopEnd markers) and
|
||||
/// `unroll_loops_in_llir` must reconstruct the flat 3-mul chain plus
|
||||
/// rewire the residual edge to reference the chain's initial input
|
||||
/// (outside the body) — not a per-iter clone.
|
||||
#[test]
|
||||
fn test_rolled_chained_scalar_muls() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
let mut cx = Graph::default();
|
||||
let x = cx.tensor((1, 4, 32));
|
||||
let chained = ((x * 2.0_f32) * 3.0_f32) * 5.0_f32;
|
||||
let out = (chained + x).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let x_data = random_f32_vec(4 * 32, 101, -0.5, 0.5);
|
||||
rt.set_data(x, x_data.clone());
|
||||
rt = cx.search(rt, 3);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(out);
|
||||
let expected: Vec<f32> = x_data.iter().map(|v| v * 2.0 * 3.0 * 5.0 + v).collect();
|
||||
assert_close(&result, &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
@@ -468,7 +468,7 @@ pub fn fuzz_genomes<T: TestDType>(
|
||||
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
let llir_graph = egglog_to_llir(
|
||||
let mut llir_graph = egglog_to_llir(
|
||||
egraph,
|
||||
genome.clone(),
|
||||
ops,
|
||||
@@ -477,6 +477,12 @@ pub fn fuzz_genomes<T: TestDType>(
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
// Same finalization as `Graph::search` performs on the chosen
|
||||
// best LLIR: collapse the rolled body's loop markers into a
|
||||
// fully-unrolled LLIR. The runtime cannot execute LoopStart /
|
||||
// LoopEnd / LoopInput / LoopOutput markers — they exist only as
|
||||
// a search-time scaffold the auto-roll prepass introduces.
|
||||
unroll_loops_in_llir(&mut llir_graph);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.load_llir(&llir_graph);
|
||||
|
||||
48
crates/luminal_metal/src/dyn_backend.rs
Normal file
48
crates/luminal_metal/src/dyn_backend.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
//! [`DynBackend`] implementation for the Metal runtime.
|
||||
|
||||
use luminal::dtype::DType;
|
||||
use luminal::dyn_backend::{bytes_to_native_data, compile_backend, BackendCompileArgs, DynBackend};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::runtime::MetalRuntime;
|
||||
|
||||
/// [`DynBackend`] wrapper for [`MetalRuntime`].
|
||||
pub struct MetalDynBackend {
|
||||
pub runtime: MetalRuntime,
|
||||
}
|
||||
|
||||
impl DynBackend for MetalDynBackend {
|
||||
fn name(&self) -> &str {
|
||||
"metal"
|
||||
}
|
||||
|
||||
fn set_data_bytes(&mut self, node: NodeIndex, bytes: Vec<u8>, dtype: DType) {
|
||||
self.runtime
|
||||
.set_data(node, bytes_to_native_data(bytes, dtype));
|
||||
}
|
||||
fn set_data_f32(&mut self, node: NodeIndex, data: Vec<f32>) {
|
||||
self.runtime.set_data(node, data);
|
||||
}
|
||||
fn get_output_f32(&self, node: NodeIndex) -> Vec<f32> {
|
||||
self.runtime.get_f32(node)
|
||||
}
|
||||
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
self.runtime.execute(dyn_map);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn metal_factory(
|
||||
graph: &mut Graph,
|
||||
args: BackendCompileArgs,
|
||||
) -> Result<Box<dyn DynBackend>, String> {
|
||||
compile_backend::<MetalRuntime>(
|
||||
graph,
|
||||
args,
|
||||
|| Ok(MetalRuntime::initialize(())),
|
||||
|rt, node, bytes, dtype| {
|
||||
rt.set_data(node, bytes_to_native_data(bytes, dtype));
|
||||
},
|
||||
None,
|
||||
|rt| Box::new(MetalDynBackend { runtime: rt }),
|
||||
)
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod dyn_backend;
|
||||
pub mod kernel;
|
||||
pub mod runtime;
|
||||
|
||||
|
||||
@@ -234,6 +234,10 @@ impl Runtime for MetalRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
fn aggregate_profile_metrics(metrics: &[Self::ProfileMetric]) -> Self::ProfileMetric {
|
||||
metrics.iter().copied().sum()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
fn load_llir(&mut self, llir_graph: &LLIRGraph) {
|
||||
self.pipelines.clear();
|
||||
|
||||
@@ -749,6 +749,92 @@ candidates rejected" during search, check whether the rejection is from actual f
|
||||
or from dtype misinterpretation — the key diagnostic is whether the NaN pattern is
|
||||
identical across all attempts (dtype issue) vs varying (actual numerical issue).
|
||||
|
||||
## 2026-04-22 — Benchmark python_luminal Path: NativeRuntime Panic on CUDA Weights
|
||||
|
||||
### What the symptom was
|
||||
|
||||
Running `benchmarks/ttft/run.py` with the `python_luminal` path panicked deep in Rust:
|
||||
|
||||
```
|
||||
thread panicked at src/hlir.rs:2239:40: no entry found for key
|
||||
```
|
||||
|
||||
The panic occurred in `NativeRuntime::execute` when the `Output` node tried to read its
|
||||
predecessor's buffer from `self.buffers` — and the buffer wasn't there.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
The luminal Python wheel was built without `--features cuda` (plain `maturin build --release`).
|
||||
This means `_cuda_lite_factory_capsule` is not compiled into the `.so` file. In `main.py`,
|
||||
`_detect_factory_capsule` catches the resulting `ImportError` and **silently** falls back to
|
||||
`_native_factory_capsule` (NativeRuntime / CPU runtime).
|
||||
|
||||
The benchmark model (`LlamaForCausalLM.from_pretrained(...).to("cuda")`) has all weights as
|
||||
CUDA device pointers. `BackendCompileArgs.device_ptrs` is populated with these GPU pointers.
|
||||
NativeRuntime has no mechanism to handle GPU-resident weight data — the `device_ptrs` map is
|
||||
simply ignored. After search completes (it can search because it uses dummy CPU data during
|
||||
profiling), the first real `execute()` call processes the graph:
|
||||
|
||||
1. `Input` nodes are skipped (their buffers should be pre-populated by `set_input_from_ptr`)
|
||||
2. Weight `Input` nodes were set via `set_input_device_ptr` — but NativeRuntime's
|
||||
`set_input_device_ptr` likely no-ops or stores garbage, leaving those buffers empty
|
||||
3. The `Output` node looks up its predecessor's buffer → key not found → panic
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **Silent fallback**: `_detect_factory_capsule` catches `ImportError` without logging a
|
||||
warning. Nothing in stdout indicates you're running on CPU when the model is on GPU.
|
||||
2. **Search succeeds**: The e-graph search runs to completion (searches 1 group, 1 chunk in
|
||||
~15s) because it uses 1.0f32 dummy data that doesn't need GPU. The failure only occurs at
|
||||
first real execution.
|
||||
3. **Misleading error site**: `hlir.rs:2239` is in NativeRuntime's buffer-copy loop for Output
|
||||
nodes — it gives no indication that the root cause is a missing CUDA feature flag at build time.
|
||||
4. **Backtrace required**: Without `RUST_BACKTRACE=1`, only the panic message is visible;
|
||||
the `NativeRuntime` frame that reveals the CPU fallback is hidden.
|
||||
|
||||
### The fix
|
||||
|
||||
Rebuild the wheel with CUDA support:
|
||||
```bash
|
||||
maturin build --release --features cuda
|
||||
pip install target/wheels/luminal_python-*.whl --force-reinstall
|
||||
```
|
||||
|
||||
Or via the test runner: `./run_tests_cuda.sh` uses `maturin develop --features cuda -r`.
|
||||
|
||||
Consider adding an explicit warning or error in `_detect_factory_capsule` when CUDA inputs are
|
||||
detected but no CUDA factory is available:
|
||||
|
||||
```python
|
||||
if device.type == "cuda":
|
||||
try:
|
||||
from .luminal import _cuda_lite_factory_capsule
|
||||
return _cuda_lite_factory_capsule()
|
||||
except ImportError:
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"CUDA inputs detected but luminal was built without --features cuda. "
|
||||
"Falling back to NativeRuntime (CPU) — this will likely panic at runtime.",
|
||||
RuntimeWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
```
|
||||
|
||||
### The regression test
|
||||
|
||||
`test_hf_llama3_8b_instruct_1layer` in `tests/test_llama3.py` — tests the exact architecture
|
||||
from the benchmark (Meta-Llama-3-8B-Instruct, 4096 hidden, 32 attn heads, 8 KV heads) with
|
||||
1 layer and random weights. This test passes with `--features cuda` and panics without it.
|
||||
|
||||
### General principle
|
||||
|
||||
**When a feature gate silently changes the runtime backend, assert that the selected backend
|
||||
is compatible with the input device.** A CUDA tensor flowing into a CPU-only runtime is always
|
||||
a programming error, not a graceful degradation. The failure should surface at factory
|
||||
selection time (with a clear error message), not deep in a Rust buffer-copy loop.
|
||||
|
||||
---
|
||||
|
||||
## 2026-03-25 — KernelExp/KernelSigmoid: Fused CUDA Kernels for Precision
|
||||
|
||||
1. **Symptom**: `test_hf_llama3_full` (16-layer Llama-3.2-1B) had ~1e-4 max diff vs PyTorch.
|
||||
@@ -756,3 +842,211 @@ identical across all attempts (dtype issue) vs varying (actual numerical issue).
|
||||
3. **Why hard**: Per-operation error was ~1e-7 but compounded over 16 layers × ~25 extra materializations. The egglog `Exp` rewrite depends on exact constant format matching.
|
||||
4. **Fix**: Added `KernelExp` (uses `expf()`), `KernelSigmoid` (uses `1/(1+expf(-x))`), and Kahan summation in SumReduce. Each uses both `kernel_rewrite` and a direct egglog pattern match with range checks (e.g., `(> ?val 1.44) (< ?val 1.45)`) to bypass constant format dependency.
|
||||
5. **Principle**: When decomposed CUDA kernel chains cause precision loss, add fused kernels via `kernel_rewrite`. For robustness, add BOTH the logical-op rewrite path AND a direct HLIR pattern match — the constant format in egglog can be fragile.
|
||||
|
||||
---
|
||||
|
||||
## 2026-04-23 — NativeRuntime Multi-Call Panic: Input Buffers Cleared After Each Run
|
||||
|
||||
1. **Symptom**: The compiled model panicked with `hlir.rs:XXXX: no entry found for key` on the second call. First call succeeded; subsequent calls failed.
|
||||
2. **Root cause**: `NativeRuntime::execute` in `src/hlir.rs` called `self.buffers.retain(|k, _| output_nodes.contains(k))` after each run to free intermediate buffers. This correctly pruned temporary buffers but also pruned the Input-node buffers that hold model weights — so on the second call, the weight tensors were gone.
|
||||
3. **Why hard**: The bug never manifested in the test suite because every test called the compiled model exactly once per compile. The issue only appeared when running a bench loop that called the model multiple times. The panic location (deep in buffer lookup) gave no indication that the root cause was in the buffer retention policy.
|
||||
4. **Fix**: Changed the retain predicate to keep both `Output` and `Input` nodes:
|
||||
```rust
|
||||
let keep_nodes = graph.node_indices()
|
||||
.filter(|n| is::<Output> || is::<Input>)
|
||||
.collect();
|
||||
self.buffers.retain(|k, _| keep_nodes.contains(k));
|
||||
```
|
||||
5. **Principle**: When buffer lifetime policies are changed to free memory after a run, always verify that *persistent* state (model weights stored in Input nodes) is excluded from the cleanup sweep. A test that compiles + calls once per test function will never catch a multi-call regression — add a dedicated multi-call test for any compiled runtime.
|
||||
|
||||
---
|
||||
|
||||
## 2026-04-23 — PT2 USER_INPUT_MUTATION Outputs Confuse Dynamo Caller
|
||||
|
||||
1. **Symptom**: With `StaticCache`, the compiled model returned `[1]` (cumulative_length update) instead of `[1, vocab_size]` logits. The wrong tensor was silently mapped to the output variable.
|
||||
2. **Root cause**: When `torch.export` encounters in-place mutations to input tensors (KV cache updates via `index_copy_`), it lifts them as `USER_INPUT_MUTATION` output specs, placed *before* the actual `USER_OUTPUT` logits in `ep.graph_signature.output_specs`. The compiled model returned all outputs; dynamo mapped index 0 (the mutation) to the first return value.
|
||||
3. **Why hard**: The output shape `[1]` from `cumulative_length` looked like a valid (though wrong) output. No error was raised — just wrong logits. Required inspecting `ep.graph_signature.output_specs` and understanding the ordering convention for different `OutputKind` values.
|
||||
4. **Fix**: In `pt2_backend`, parse `output_specs` to build a `mutation_mappings` list and `user_output_indices`. Wrap the compiled model to: (a) copy mutation outputs back into the corresponding input tensors, and (b) return only the `USER_OUTPUT` tensors.
|
||||
5. **Principle**: After `torch.export(...).run_decompositions()`, always inspect `ep.graph_signature.output_specs` when the model has in-place operations (KV cache, BN running stats). The output ordering is: mutations first, then actual outputs — and the caller only expects actual outputs.
|
||||
|
||||
---
|
||||
|
||||
## 2026-04-23 — CUDA Version Mismatch: torch+cuXXX Must Match System Driver
|
||||
|
||||
1. **Symptom**: `torch.cuda.is_available()` returned `False` despite `nvidia-smi` showing a GPU. Warning: "CUDA initialization: The NVIDIA driver on your system is too old (found version 12080)."
|
||||
2. **Root cause**: `torch==2.11.0+cu130` requires CUDA 13.0 which needs driver >= 575. The system has driver 570 (CUDA 12.8 max). The mismatch caused silent CPU fallback — no error, just False from `is_available()`.
|
||||
3. **Why hard**: The bench appeared to start successfully (model loaded, compilation ran) but produced no results because it was running an 8B model on CPU. Zero output with exit code 0 looked like a hang or silent crash.
|
||||
4. **Fix**: Installed `torch==2.11.0+cu128` from `https://download.pytorch.org/whl/cu128`. CUDA 12.8 matches driver 570. Also needed matching `torchvision==0.26.0+cu128` and the `nvidia-cusparselt-cu12` runtime library.
|
||||
5. **Principle**: Before running any CUDA-dependent bench or test, verify `torch.cuda.is_available()` returns `True`. Check `nvidia-smi` CUDA Version field against the `+cuXXX` suffix in `torch.__version__` — they must match (CUDA runtime ≤ driver's max supported version). Never assume CPU fallback "works" for large model benchmarks.
|
||||
|
||||
---
|
||||
|
||||
## 2026-04-26 — Loop unroll-union rules silently disabled in full egglog stage
|
||||
|
||||
1. **Symptom**: Python `test_llama_transformer_block` (CUDA backend) produced output ~1e-2 off from PyTorch (atol=1e-4) on the `loop_rolling` branch. All component tests (RMSNorm, attention, SwiGLU, RoPE) passed. The diff pattern was suspicious: row 0 of the (1,4,32) output matched exactly, rows 1–3 differed slightly. Disabling rolling fixed it.
|
||||
2. **Root cause**: The auto-roll prepass folds three sequential scalar muls in PyTorch's `pow(2)` decomposition (`exp2(log2(x) * 0.693 * 2.0 * 1.442)` — the last constant is `log2(e)`). The kernel `direct-exp-fusion` egglog rule rewrites `Mul(?x, log2_e_const) → Exp2(...)` into `KernelExp(?x)` (single `expf()` instead of separate exp2f + multiply by truncated log2(e)). Without rolling, this fusion fires and the float chain stays stable; with rolling the fusion can't see through the `LoopStart`/`LoopEnd` markers, so the chain stays as `KernelMul → KernelExp2`, and the truncated `log2(e)` constant accumulates ~1e-7 error per layer that compounds into ~1e-2 over the full block.
|
||||
|
||||
The unroll-union rules I'd added (`Mul`/`Add`/etc. binary-op rules that union a rolled body with its fully-unrolled equivalent) were registered only in `EgglogOp::early_rewrites()`, not `rewrites()`. The egglog driver feeds `early_rewrites` only into the early-stage program and `rewrites` only into the full-stage program. So the unrolled chain materialised in the early egraph, the early→full extract picked the (cheaper) rolled form, the unrolled chain was lost, and `direct-exp-fusion` (which runs in the full stage) had nothing to match against.
|
||||
3. **Why hard**: The post-unroll LLIR for the rolled vs un-rolled paths *looked* nearly identical when scanned visually — both had the Log2 → Mul × 3 → Exp2 chain. The diff was 2 extra Muls vs no-rolling, and the actual semantic gap was visible only in op-name counts: WITH-rolling had 3 `KernelExp2` and 0 `KernelExp`, WITHOUT-rolling had 1 `KernelExp2` and 2 `KernelExp`. Tracking the missing fusion to the early/full ruleset split required reading the egglog driver carefully and noticing that `OpTextParts` builds `early_rewrites` and `full_rewrites` from disjoint method calls.
|
||||
4. **Fix**: Register `binary_op_unroll_rules` in BOTH `early_rewrites()` (so fusion patterns like GLUMoE can match before the early-stage extract, which is what fixed `test_glumoe_gemma_gelu_matches_unfused_output` earlier in the session) AND `rewrites()` (so kernel-level rewrites like `direct-exp-fusion` can match in the full stage on the unrolled chain). One block per binary op (`Add`, `Mul`, `Mod`, `LessThan`).
|
||||
5. **Principle**: When egglog has multiple stages (early/full) with disjoint rule sets, any rewrite that materialises new HLIR/IR enodes (rather than just lowering to LLIR) needs to fire in BOTH stages if downstream rewrites in BOTH stages might want to see the new structure. Putting "preparatory" rewrites only in `early_rewrites` means their effect is lost across the early→full handoff. The narrow rule of thumb: if your rule's outputs are intended to enable matches by other rules, audit which stages those other rules run in and register accordingly.
|
||||
|
||||
---
|
||||
|
||||
## 2026-04-26 — `unroll_loops_in_llir` panicked on iteration-invariant body producers
|
||||
|
||||
1. **Symptom**: Modal CI/CD job for the gemma example panicked at `src/graph.rs:1867` with `no entry found for key`. The line is `clone_map[i - 1][&body_producer]` inside `unroll_loops_in_llir`'s `resolve_src` closure — `body_producer` (the LoopEnd's incoming source for that slot) wasn't a key in the per-iteration clone map. cuda_lite/python tests didn't repro: only triggered by the specific genome and graph shapes that gemma's longer search settles on.
|
||||
2. **Root cause**: `body_nodes` is computed by walking *forward* from each LoopStart/LoopInput/LoopInputStatic outgoing edge, stopping at markers and `Output` ops. Some egglog-extracted LLIRs land a `body_producer` that isn't reachable via that forward walk — i.e., its only ancestors are non-marker (a constant, an external input, or an op whose chain was congruence-merged off the marker chain by rules like `LoopInputStatic inline`). Semantically this is a degenerate "iteration-invariant body": every iter computes the same value, so the loop's state never changes. The per-iter clone path needed a fallback for that case.
|
||||
3. **Why hard**: cuda_lite and python tests don't generate genomes that produce this shape, so local runs always pass. The forward-walk-only definition of `body_nodes` is *almost* always right — only specific extraction shapes from longer searches expose the gap. Test-driven debugging has limited reach when the failure mode depends on a search trajectory the local fuzzers don't explore.
|
||||
4. **Fix**: in `unroll_loops_in_llir::resolve_src`, when the LoopStart-resolved `body_producer` isn't in `body_nodes`, return `body_producer` itself for iter > 0 instead of indexing `clone_map[i - 1]`. The body op didn't depend on the loop variable, so every iter > 0 carries the same value forward — using `body_producer` directly is semantically correct. Mirrored the same `unwrap_or(body_producer)` fallback in the post-loop substitution map (`marker_post_sub` for LoopEnd / LoopOutputSelect). Added a backward-walk-from-end-markers backfill in `collapse_loops_to_first_iter` so its body-node iteration also covers these nodes (it doesn't have a clone_map, but does need to rewire body ops' incoming edges before deleting markers).
|
||||
5. **Principle**: When a graph-walk-derived set is used as a hashmap key requirement, every code path that *could* produce a key outside that set needs a graceful fallback — not just a defensive `expect`. For loop unrolling specifically, the rule is: `body_nodes` is the set of "ops that participate in per-iter computation"; ops on the LoopEnd's path that *don't* participate (iteration-invariant) are still legitimate, and need a "no clone, share across iters" path through `resolve_src` and `marker_post_sub`. Forward-walk-only `body_nodes` is correct only when extraction never produces iteration-invariant body producers — and in an egglog-driven search, that's not a guarantee you can make.
|
||||
|
||||
---
|
||||
|
||||
## 2026-04-26 — Iteration-invariant state slots are a first-class concept, not a defensive fallback
|
||||
|
||||
1. **Symptom + fix recap**: gemma Modal CI panicked at `clone_map[i-1][&body_producer]` because some state slots' `body_producer` (LoopEnd's incoming) isn't in `body_nodes` (forward walk from input markers). The first commit pair (16de9638 / 93fb02c4) caught this with `.unwrap_or(body_producer)` — which works but reads as "defensive, unclear *why* this case exists."
|
||||
2. **What's actually happening**: extracted LLIR from gemma legitimately puts a `KernelConstant` at LoopEnd's incoming for some state slots. e.g. for one slot of gemma's body=104 trips=5 rolling: `initial = KernelConstant 1.442695` (log2 e), `body_producer = same node`. For another: `body_producer = KernelConstant 9.21034` (ln 10000, RoPE's frequency base after `Log2 * ln(2)` simplification). egglog's kernel-level rewrites legitimately union body-slot eclasses with these constants when the body chain provably reduces to them. The state really is iteration-invariant — every iter sees the same value.
|
||||
3. **Why "defensive fallback" framing is misleading**: it implies the LLIR is broken. It isn't. The forward-walk-only `body_nodes` definition just doesn't cover this case, because the case requires no per-iter cloning at all. A *node not reachable from any loop input marker has no input-marker ancestor*, so by construction its value doesn't depend on the loop's per-iter state.
|
||||
4. **Cleaner formulation**: name the concept. Compute an `iteration_invariant_slots: HashSet<LoopStart>` set at the same time `start_meta` is built, with the rule `body_producer ∉ body_nodes ⇒ iteration_invariant`. `resolve_src` and `marker_post_sub` then have explicit branches: if the slot is invariant, use `body_producer` directly; otherwise the standard per-iter clone lookup. The behavior is the same as the `unwrap_or` band-aid, but the code now documents that this is a real, sound case the unroll handles correctly — not a panic suppressor.
|
||||
5. **Principle**: when an `unwrap_or` papers over a case that turns out to be semantically valid, the right cleanup isn't to keep the `unwrap_or` and add a comment — it's to name the case. Hoist the predicate into a set or enum and branch on it explicitly. The compiler then enforces that every consumer of the per-iter cloning machinery has an opinion on iteration-invariant slots, instead of silently relying on a `Map::get` returning `None` at the right moment.
|
||||
|
||||
---
|
||||
|
||||
## 2026-04-30 — `translate_grouped_mm` casted the full expert weight to F32, OOMing search on Qwen3-MoE
|
||||
|
||||
### What the symptom was
|
||||
|
||||
`benchmarks/ttft/run.py --config qwen3-moe` crashed every search-profile attempt with:
|
||||
```
|
||||
crates/luminal_cuda_lite/src/runtime.rs:711: called `Result::unwrap()` on an `Err` value:
|
||||
DriverError(CUDA_ERROR_OUT_OF_MEMORY, "out of memory")
|
||||
```
|
||||
The DB shows this had been failing every run for ~2 weeks. The rust `examples/qwen3_moe` ran fine end-to-end. python_baseline / python_torch_compile / qwen3-4b were all fine — only python_luminal × qwen3-moe failed.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
`translate_grouped_mm` in `crates/luminal_python/rust/src/translator/tensor.rs` was lowering HF's `_grouped_mm(input, weight, offs)` op to a *full-broadcast* batched matmul plus a group-mask:
|
||||
|
||||
```rust
|
||||
let weight_f = weight.cast(DType::F32); // [G=128, K, N] cast → 1.5 GB / layer
|
||||
let input_batched = input_f.expand_dim(0, g);
|
||||
let all_out = input_batched.matmul(weight_f); // [G, S, N]
|
||||
let mask = ... (g_arange == expert_id).cast(F32);
|
||||
let out = (all_out * mask.expand_dim(2, n)).sum(0); // mask + sum over G
|
||||
```
|
||||
|
||||
The full `[G, K, N]` F32 cast intermediate is 1.5 GB / layer for gate-up and 0.6 GB / layer for down on Qwen3-30B-A3B. With 60 GB of persistent bf16 weights already on a 97 GB GPU, the search-time profiler ran out of memory allocating those casts.
|
||||
|
||||
By contrast, `examples/qwen3_moe`'s `gather_experts` gathers only the top-K active experts per token first, then casts that small `[s, k, d1, d2]` slice (~100 MB / layer). The GLUMoE host op (`crates/luminal_cuda_lite/src/host/moe/glumoe_rewrite.egg`) is also wired to this gather pattern.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **Code path was reasonable in isolation**: at small scale (`test_grouped_mm_fallback`: g=2, K=8, N=16) the broadcast version was fine — the F32 cast was only 1 KB, and search profiling never noticed.
|
||||
2. **The error reported "out of memory" but the rest of the system looked healthy**: 60 GB weights + 37 GB headroom looks like plenty until you realise 48 layers × 2.1 GB cast intermediates per layer doesn't fit, even after loop rolling.
|
||||
3. **The DB's `code 1` failures looked the same as a Python exception** — the actual panic site (`runtime.rs:711:64` `stream.alloc_zeros(needed_bytes).unwrap()`) had to be recovered from a tmux scrollback because the orchestrator's stdout was already torn down by the time we looked.
|
||||
|
||||
### The fix
|
||||
|
||||
Rewrote `translate_grouped_mm` to gather first, matmul second:
|
||||
|
||||
```rust
|
||||
// expert_id[m] = first g s.t. m < offs[g], clamped to [0, G-1]
|
||||
let expert_id = ge_boundary.sum(0).minimum_f32(g_max_f).cast(DType::Int);
|
||||
|
||||
// flat_idx = expert_id * (K*N) + iota('z', (K, N)) — same shape as
|
||||
// rust qwen3_moe's `gather_experts`
|
||||
let flat_idx = (expert_id * (k * n))
|
||||
.expand_dim(1, k).expand_dim(2, n)
|
||||
+ self.graph.iota(Expression::from('z'), (k, n)).expand_dim(0, s);
|
||||
|
||||
let weight_gathered = weight.gather(flat_idx); // [S, K, N], bf16
|
||||
let result = input.cast(F32).unsqueeze(1)
|
||||
.matmul(weight_gathered.cast(F32)) // [S, 1, N]
|
||||
.squeeze(1);
|
||||
```
|
||||
|
||||
Two important details:
|
||||
|
||||
1. **Clamp `expert_id` to `[0, G-1]`**: at search time, dummy data fills `offs` with all-1s (`make_ones_bytes` in `compile_backend`). For S>1 that pushes `expert_id` to G (boundary count = G), which is one past the last valid expert and OOBs the gather. HF's own grouped-MM forward also clamps for the same reason (invalid expert IDs from EP).
|
||||
2. **Don't cast the full weight**: the cast moved from before the batched-matmul (over `[G, K, N]`) to after the gather (over `[S, K, N]`). 16× shrink at prefill (S=top_k=8 vs G=128).
|
||||
|
||||
### Result
|
||||
|
||||
`search-iters=1` end-to-end works on Qwen3-30B-A3B: `BENCH_RESULT … "ttft_ms": 9350.5, "tpot_ms": 1166.7`. The OOM is gone.
|
||||
|
||||
`search-iters>=5` still crashes — but with a *different*, downstream `CUDA_ERROR_ILLEGAL_ADDRESS` during execution after search completes. That looks like the same family as the 2026-03-07 / 2026-03-09 egglog-extractor non-determinism bugs (some mutation during search picks a kernel/rewrite combo that's broken at this scale). It's a separate investigation — the gather-based lowering is correct in isolation (`test_grouped_mm_fallback` passes; a synthetic `g=128, S=8, K=2048, N=1536` bf16 test passes with max-diff ~2.4e-4).
|
||||
|
||||
### General principle
|
||||
|
||||
**When lowering an op that takes a per-row index over a large parameter, gather first and cast second — never cast the full parameter to F32 just because your matmul kernel is F32-only.** A "broadcast over G + mask" pattern is mathematically equivalent to "gather per-row" but materialises a G× larger intermediate — fine for tests, ruinous on real MoE checkpoints. When in doubt, mirror the rust example's pattern: the egglog fusion rules (GLUMoE here) are written to recognise the gather form, not the broadcast-and-mask form.
|
||||
|
||||
Also: search-time dummy-1 inputs are not the same shape as runtime inputs. Anything you compute from a runtime tensor (cumsum offsets, routing indices, mask boundaries) needs to remain in-bounds for the dummy. Clamp index-producing chains as a matter of course, not just when the math says you "should" — `make_ones_bytes` is a hostile witness.
|
||||
|
||||
---
|
||||
|
||||
## 2026-05-01 — `KernelScatter` float4 vectorization wrote 2× past end of buffer for bf16/f16 KV cache
|
||||
|
||||
### What the symptom was
|
||||
|
||||
After the `translate_grouped_mm` gather rewrite (above) cleared the OOM, the qwen3-moe bench progressed past search but panicked during execution roughly 40% of the time:
|
||||
```
|
||||
crates/luminal_cuda_lite/src/runtime.rs:1204:
|
||||
CUDA execute error in "CudaGraph":
|
||||
DriverError(CUDA_ERROR_ILLEGAL_ADDRESS, "an illegal memory access was encountered")
|
||||
```
|
||||
qwen3-4b (dense) was unaffected; the bf16 KV cache in HF `StaticCache` was the only path triggering it. The rust `examples/qwen3_moe` ran fine because it uses an F32 KV cache.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
`KernelScatter::compile` in `crates/luminal_cuda_lite/src/kernel/hlir.rs` emitted a hand-written CUDA copy phase that vectorised through `float4` (16-byte) reads/writes:
|
||||
|
||||
```cuda
|
||||
long long n_vec = n_dest / 4; // ← assumes 4-byte dtype
|
||||
float4 *out4 = (float4 *)out;
|
||||
const float4 *dest4 = (const float4 *)dest;
|
||||
for (long long i = tid; i < n_vec; i += blockDim.x) {
|
||||
out4[i] = dest4[i]; // ← writes 16 B per iteration
|
||||
}
|
||||
long long remainder_start = n_vec * 4; // ← also assumes 4 elem/vec
|
||||
```
|
||||
|
||||
For `dtype=F32` (4 bytes), `n_vec * 16 = n_dest * 4` bytes — exactly fills the buffer. For `dtype=Bf16` (2 bytes), `n_vec * 16 = (n_dest/4) * 16 = n_dest * 4` bytes, which is **2× the actual buffer size of `n_dest * 2` bytes**. The write walks half the buffer past the end of `out` (and reads past `dest`).
|
||||
|
||||
Whether that produced an `ILLEGAL_ADDRESS` depended on whether the OOB region happened to land on an unmapped page. For different search outcomes, the surrounding allocator state differed → ~60% it was silent corruption, ~40% it crashed the CUDA context. That probabilistic mix is why the bug had been hidden — no test exercised a bf16 scatter (every existing scatter test uses F32 by default), and the rust example uses F32 KV cache so it was never seen there either.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **Probabilistic, but search-determinate**: the rewrite from HLIR `Scatter` → `KernelScatter` always fires (it's the only non-NoCopy path), so the kernel is always present. The crash depends on memory layout, which depends on which other kernels the search picked. Made it look like an egglog-mutation issue rather than a kernel-correctness issue.
|
||||
2. **Existing test coverage was F32-only**: `test_scatter_execution_correctness` (in `tests/consumed_buffer_tests.rs`) explicitly tries 50 random extractions to cover both `Scatter` and `ScatterNoCopy`, but always with `cx.tensor(5)` which defaults to F32. The bug would never surface there.
|
||||
3. **The panic message hid the kernel name**: it surfaced as a generic `"CudaGraph"` host-op panic — the cuda_graph_exec batches all kernels into one atomic launch, so the failing kernel disappears into the batch. To localize it I had to add a `LUMINAL_DEBUG_SEQ` env var to `CudaGraphOp::execute_internal` that bypasses graph batching and launches each kernel via `cuLaunchKernel` with a sync afterwards, surfacing kernel name + node + grid/block/pointers when one fails.
|
||||
|
||||
### The fix
|
||||
|
||||
Parameterise `n_vec` and the remainder-loop start by the number of dtype elements that fit in 16 bytes:
|
||||
|
||||
```rust
|
||||
let elements_per_vec: usize = match self.dtype {
|
||||
DType::F64 => 2,
|
||||
DType::F32 | DType::Int => 4,
|
||||
DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 8,
|
||||
DType::Bool | DType::I8 | DType::U8
|
||||
| DType::F8UE8M0 | DType::F8E4M3 | DType::F8E5M2 => 16,
|
||||
other => panic!("Unsupported dtype for scatter vectorization: {other:?}"),
|
||||
};
|
||||
```
|
||||
and substitute `{elements_per_vec}` into the kernel template (both the `n_vec` calc and `remainder_start`). For F32 / Int the generated code is byte-for-byte identical to before, so existing F32 tests are unaffected; for any other dtype the byte coverage now exactly equals `n_dest * sizeof(dtype)` as intended.
|
||||
|
||||
### Result
|
||||
|
||||
Before fix: 3/5 success at iters=10 (probabilistic).
|
||||
After fix: 5/5 at iters=10, 3/3 at iters=50. All 206 HLIR tests still pass. TTFT/TPOT identical (~9.35s / ~1.17s).
|
||||
|
||||
### General principle
|
||||
|
||||
**Hand-rolled CUDA vectorisation with a fixed-width type (`float4`, `float2`, `int4`, …) is almost always specialised to one element size.** When the same kernel template is parameterised by `dtype`, every byte-count expression has to be too. The cheapest correct form is "elements per vector load" computed from the dtype's byte size — never hardcode `/4`.
|
||||
|
||||
Also: **F32 is not a representative test dtype for kernels with vector loads.** When a kernel is written generic-over-dtype, the test matrix needs to actually exercise the dtypes (bf16, f16, bool) where the vector-element-count differs. A `test_scatter_bf16` would have caught this years before the qwen3-moe bench did. Same trap likely exists wherever else `float4` is cast over a `{dtype} *` template.
|
||||
|
||||
Diagnostic also added: `LUMINAL_DEBUG_SEQ=1` on the python_luminal path will now bypass `CudaGraphOp` batching at execute time, launching each kernel sequentially with a sync afterwards. If a future ILLEGAL_ADDRESS hides inside a batched graph again, this surfaces the kernel name and node index immediately.
|
||||
|
||||
|
||||
@@ -186,7 +186,7 @@ class TestRunner:
|
||||
env = os.environ.copy()
|
||||
existing = env.get("PYTHONPATH")
|
||||
env["PYTHONPATH"] = f"{SRC_PATH}:{existing}" if existing else SRC_PATH
|
||||
env["LUMINAL_BACKEND"] = "cuda"
|
||||
env["LUMINAL_TEST_DEVICE"] = "cuda"
|
||||
env["UV_PROJECT_ENVIRONMENT"] = VENV_PATH
|
||||
env["MATURIN_PEP517_ARGS"] = "--features cuda --profile release"
|
||||
env["CUDARC_CUDA_VERSION"] = CUDARC_CUDA_VERSION
|
||||
|
||||
@@ -46,4 +46,5 @@ dev = [
|
||||
"transformers>=4.40.0",
|
||||
"diffusers>=0.35.0",
|
||||
"modal>=1.3.5",
|
||||
"matplotlib>=3.8",
|
||||
]
|
||||
|
||||
@@ -28,7 +28,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
echo ""
|
||||
echo "--- 2a: CUDA ---"
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest $CUDA_TESTS -m "not slow" -v
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest $CUDA_TESTS -m "not slow" -v
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
|
||||
@@ -14,7 +14,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
# Run pytest with CUDA backend
|
||||
echo "Step 3: Running pytest with CUDA backend..."
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
|
||||
@@ -1,17 +1,13 @@
|
||||
#[cfg(feature = "cuda")]
|
||||
use luminal::prelude::tracing::{trace, warn};
|
||||
use luminal::{
|
||||
hlir::{NativeData, Output},
|
||||
dyn_backend::{BackendCompileArgs, BackendFactory, DynBackend},
|
||||
prelude::*,
|
||||
shape::Expression,
|
||||
visualization::ToDot,
|
||||
};
|
||||
use pyo3::prelude::*;
|
||||
use std::collections::HashMap;
|
||||
#[cfg(feature = "cuda")]
|
||||
use std::collections::HashSet;
|
||||
|
||||
use crate::{runtime::RuntimeBackend, typed_data::TypedData};
|
||||
use crate::typed_data::TypedData;
|
||||
|
||||
/// Maps symbolic dimension parameter names (e.g. "seq_len") to luminal Expression variable chars.
|
||||
pub type DimParamMap = HashMap<String, char>;
|
||||
@@ -59,7 +55,7 @@ pub struct WeightData {
|
||||
#[pyclass(unsendable)]
|
||||
pub struct CompiledGraph {
|
||||
pub graph: Graph,
|
||||
pub runtime: RuntimeBackend,
|
||||
pub runtime: Box<dyn DynBackend>,
|
||||
pub tensor_ids: HashMap<String, NodeIndex>,
|
||||
/// Cached label → NodeIndex map for O(1) lookups in set_weight_* methods.
|
||||
label_map: HashMap<String, NodeIndex>,
|
||||
@@ -76,12 +72,12 @@ impl CompiledGraph {
|
||||
/// Compilation pipeline for PT2/FX graphs.
|
||||
///
|
||||
/// Takes a `GraphTranslation` (produced by `translate_pt2`) and `WeightData`,
|
||||
/// builds the backend, loads weights, and
|
||||
/// builds the backend via the global registry, loads weights, and
|
||||
/// returns a ready-to-execute `CompiledGraph`.
|
||||
pub fn parse_graph(
|
||||
translation: GraphTranslation,
|
||||
weight_data: WeightData,
|
||||
backend: &str,
|
||||
factory: BackendFactory,
|
||||
search_iters: usize,
|
||||
) -> Result<CompiledGraph, String> {
|
||||
let GraphTranslation {
|
||||
@@ -95,45 +91,29 @@ impl CompiledGraph {
|
||||
dim_param_map,
|
||||
} = translation;
|
||||
|
||||
let rt = match backend {
|
||||
#[cfg(feature = "cuda")]
|
||||
"cuda" | "gpu" => {
|
||||
CompiledGraph::build_cuda_backend(&mut graph, &weight_data, search_iters)?
|
||||
}
|
||||
"native" | "cpu" => {
|
||||
CompiledGraph::build_native_backend(&mut graph, &weight_data, search_iters)?
|
||||
}
|
||||
_ => {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
return Err(format!(
|
||||
"Invalid backend '{}'. Must be 'native' or 'cuda'",
|
||||
backend
|
||||
));
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
{
|
||||
if backend == "cuda" {
|
||||
return Err(
|
||||
"CUDA backend requested, but this luminal extension was built without the `cuda` feature. Rebuild with `maturin develop --features cuda -r` or use backend='native'."
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
return Err(format!(
|
||||
"Invalid backend '{}'. This build only supports 'native'. Rebuild with the `cuda` feature to enable 'cuda'.",
|
||||
backend
|
||||
));
|
||||
}
|
||||
}
|
||||
// Build compile args from WeightData (convert TypedData -> raw bytes + dtype)
|
||||
let compile_args = BackendCompileArgs {
|
||||
search_iters,
|
||||
weights: weight_data
|
||||
.weights
|
||||
.iter()
|
||||
.map(|(label, td)| (label.clone(), td.bytes.clone(), td.dtype))
|
||||
.collect(),
|
||||
tensor_sizes: weight_data.tensor_sizes,
|
||||
device_ptrs: weight_data.device_ptrs,
|
||||
};
|
||||
|
||||
// Create backend via the factory directly
|
||||
let rt =
|
||||
luminal::dyn_backend::compile_backend_from_factory(factory, &mut graph, compile_args)?;
|
||||
|
||||
// Resolve concrete output shapes from expressions
|
||||
let output_shapes: Vec<Vec<usize>> = output_shape_exprs
|
||||
.iter()
|
||||
.map(|exprs| exprs.iter().map(|e| e.to_usize().unwrap_or(1)).collect())
|
||||
.collect();
|
||||
|
||||
let label_map = CompiledGraph::build_label_map(&graph);
|
||||
let label_map = luminal::dyn_backend::build_label_map(&graph);
|
||||
|
||||
Ok(CompiledGraph {
|
||||
graph,
|
||||
@@ -149,160 +129,6 @@ impl CompiledGraph {
|
||||
dim_param_map,
|
||||
})
|
||||
}
|
||||
|
||||
/// Build a label → NodeIndex map for all Input nodes in the graph.
|
||||
/// Used for efficient weight loading by label matching.
|
||||
fn build_label_map(graph: &Graph) -> HashMap<String, NodeIndex> {
|
||||
graph
|
||||
.graph
|
||||
.node_indices()
|
||||
.filter_map(|node_id| {
|
||||
(*graph.graph[node_id])
|
||||
.as_any()
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
.map(|input| (input.label.clone(), node_id))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn build_cuda_backend(
|
||||
graph: &mut Graph,
|
||||
weight_data: &WeightData,
|
||||
search_iters: usize,
|
||||
) -> Result<RuntimeBackend, String> {
|
||||
let device_ptrs = &weight_data.device_ptrs;
|
||||
use luminal_cuda_lite::cudarc::driver::CudaContext;
|
||||
use luminal_cuda_lite::runtime::CudaRuntime;
|
||||
|
||||
let cuda_ctx = CudaContext::new(0).map_err(|e| format!("CUDA context init failed: {e}"))?;
|
||||
let stream = cuda_ctx.default_stream();
|
||||
|
||||
graph.build_search_space::<CudaRuntime>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
// Build label → NodeIndex map for device pointer matching.
|
||||
let label_map = CompiledGraph::build_label_map(graph);
|
||||
|
||||
// For weights with device pointers: use them directly (zero-copy).
|
||||
// This avoids allocating ~N GB of dummy data during search.
|
||||
// The pointers survive search because profiling mode skips buffer consumption,
|
||||
// and graph-level .persist() ensures they survive post-search execution too.
|
||||
let mut device_ptr_nodes: HashSet<NodeIndex> = HashSet::new();
|
||||
let mut matched_count = 0usize;
|
||||
let mut missed_labels: Vec<String> = Vec::new();
|
||||
for (label, &(ptr, n_bytes)) in device_ptrs {
|
||||
if let Some(&node_id) = label_map.get(label) {
|
||||
unsafe { rt.set_device_ptr(node_id, ptr, n_bytes) };
|
||||
device_ptr_nodes.insert(node_id);
|
||||
matched_count += 1;
|
||||
} else {
|
||||
missed_labels.push(label.clone());
|
||||
}
|
||||
}
|
||||
let total_device_bytes: usize = device_ptrs.values().map(|(_, n)| *n).sum();
|
||||
trace!(
|
||||
"[CUDA BUILD] Device pointers: {} matched, {} missed out of {} total ({:.3} GiB)",
|
||||
matched_count,
|
||||
missed_labels.len(),
|
||||
device_ptrs.len(),
|
||||
total_device_bytes as f64 / (1024.0 * 1024.0 * 1024.0),
|
||||
);
|
||||
if !missed_labels.is_empty() {
|
||||
warn!(
|
||||
"[CUDA BUILD] {} device-ptr labels did not match any Input node (first 10): {:?}",
|
||||
missed_labels.len(),
|
||||
&missed_labels[..missed_labels.len().min(10)]
|
||||
);
|
||||
let available: Vec<&String> = label_map.keys().take(10).collect();
|
||||
warn!(
|
||||
"[CUDA BUILD] Available label_map keys (first 10): {:?}",
|
||||
available
|
||||
);
|
||||
}
|
||||
|
||||
// Set dummy 1.0 data for remaining Input nodes (user inputs, constants without
|
||||
// device pointers) for safe search profiling.
|
||||
// IMPORTANT: Must use 1.0, NOT 0.0. Zero inputs cause NaN in many ops:
|
||||
// - fmod(0, 0) = NaN (Mod)
|
||||
// - recip(0) = inf → weight * inf = NaN (Div)
|
||||
// - log(0) = -inf (Pow)
|
||||
// - chain ops with zero produce NaN (Erf)
|
||||
let mut dummy_total_elements = 0usize;
|
||||
let mut dummy_count = 0usize;
|
||||
for node_id in graph.graph.node_indices() {
|
||||
if device_ptr_nodes.contains(&node_id) {
|
||||
continue;
|
||||
}
|
||||
if let Some(input) = (*graph.graph[node_id])
|
||||
.as_any()
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
{
|
||||
if let Some(&n) = weight_data.tensor_sizes.get(&input.label) {
|
||||
if n > 0 {
|
||||
dummy_total_elements += n;
|
||||
dummy_count += 1;
|
||||
// Use dtype-aware dummy data: TypedData::ones produces correct
|
||||
// byte patterns for every dtype (f32, f16, bf16, i32, bool, f8, etc.).
|
||||
// Must use 1, not 0 — zero inputs cause NaN in many ops.
|
||||
rt.set_data(node_id, TypedData::ones(n, input.dtype).bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
trace!(
|
||||
"[CUDA BUILD] Dummy data: {} nodes, {} elements ({:.3} GiB as f32)",
|
||||
dummy_count,
|
||||
dummy_total_elements,
|
||||
(dummy_total_elements * 4) as f64 / (1024.0 * 1024.0 * 1024.0),
|
||||
);
|
||||
|
||||
// Search (device-pointer weights are used directly; dummy data for the rest)
|
||||
let mut rt = graph.search(rt, search_iters);
|
||||
|
||||
// Load real weight data for non-device-ptr weights (constants from PT2 archive, etc.)
|
||||
let mut loaded_weight_bytes = 0usize;
|
||||
let mut loaded_weight_count = 0usize;
|
||||
for (label, data) in &weight_data.weights {
|
||||
if !device_ptrs.contains_key(label) {
|
||||
if let Some(&node_id) = label_map.get(label) {
|
||||
loaded_weight_bytes += data.n_bytes();
|
||||
loaded_weight_count += 1;
|
||||
rt.set_data(node_id, data.bytes.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
trace!(
|
||||
"[CUDA BUILD] Post-search weight load: {} weights, {:.3} GiB",
|
||||
loaded_weight_count,
|
||||
loaded_weight_bytes as f64 / (1024.0 * 1024.0 * 1024.0),
|
||||
);
|
||||
|
||||
Ok(RuntimeBackend::Cuda(Box::new(rt)))
|
||||
}
|
||||
|
||||
fn build_native_backend(
|
||||
graph: &mut Graph,
|
||||
weight_data: &WeightData,
|
||||
search_iters: usize,
|
||||
) -> Result<RuntimeBackend, String> {
|
||||
graph.build_search_space::<NativeRuntime>();
|
||||
let mut rt = graph.search(NativeRuntime::default(), search_iters);
|
||||
|
||||
// Load weight data after search, preserving native dtype.
|
||||
// TypedData -> NativeData conversion (From<TypedData>) handles mapping to the
|
||||
// correct NativeData variant (F32, F16, Bf16, Int, Bool).
|
||||
let label_map = CompiledGraph::build_label_map(graph);
|
||||
for (label, data) in &weight_data.weights {
|
||||
if let Some(&node_id) = label_map.get(label) {
|
||||
let native: NativeData = data.into();
|
||||
rt.set_data(node_id, native);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(RuntimeBackend::Native(rt))
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
@@ -349,12 +175,24 @@ impl CompiledGraph {
|
||||
self.tensor_ids.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Get the name of the active backend (native or cuda).
|
||||
/// Get the name of the active backend.
|
||||
#[getter]
|
||||
fn backend(&self) -> &'static str {
|
||||
fn backend(&self) -> &str {
|
||||
self.runtime.name()
|
||||
}
|
||||
|
||||
/// The device type this backend operates on (e.g. "cpu", "cuda").
|
||||
#[getter]
|
||||
fn device_type(&self) -> &str {
|
||||
self.runtime.device_type()
|
||||
}
|
||||
|
||||
/// Whether the active backend supports device pointer operations (zero-copy GPU I/O).
|
||||
#[getter]
|
||||
fn supports_device_ptrs(&self) -> bool {
|
||||
self.runtime.supports_device_ptrs()
|
||||
}
|
||||
|
||||
/// Whether this graph has dynamic (symbolic) dimensions.
|
||||
#[getter]
|
||||
fn has_dynamic_dims(&self) -> bool {
|
||||
@@ -445,91 +283,83 @@ impl CompiledGraph {
|
||||
})?;
|
||||
let raw_bytes = unsafe { std::slice::from_raw_parts(ptr as *const u8, n_bytes).to_vec() };
|
||||
let typed = TypedData::from_pytorch_bytes(raw_bytes, dtype_code);
|
||||
self.runtime.set_data(*node_id, typed);
|
||||
self.runtime
|
||||
.set_data_bytes(*node_id, typed.bytes, typed.dtype);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set input from a CUDA device pointer. Zero-copy on device.
|
||||
/// The pointer must be a valid CUDA device allocation with at least n_bytes bytes.
|
||||
#[cfg(feature = "cuda")]
|
||||
/// Set input from a device pointer. Zero-copy on device.
|
||||
/// The pointer must be a valid device allocation with at least n_bytes bytes.
|
||||
/// Requires a GPU backend (e.g. CUDA).
|
||||
fn set_input_device_ptr(
|
||||
&mut self,
|
||||
name: &str,
|
||||
device_ptr: u64,
|
||||
n_bytes: usize,
|
||||
) -> PyResult<()> {
|
||||
if !self.runtime.supports_device_ptrs() {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"set_input_device_ptr requires a GPU backend",
|
||||
));
|
||||
}
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
|
||||
})?;
|
||||
match &mut self.runtime {
|
||||
RuntimeBackend::Cuda(rt) => unsafe { rt.set_device_ptr(*node_id, device_ptr, n_bytes) },
|
||||
_ => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"set_input_device_ptr requires CUDA backend",
|
||||
));
|
||||
}
|
||||
}
|
||||
unsafe { self.runtime.set_device_ptr(*node_id, device_ptr, n_bytes) };
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// For PT2 weights (e.g. "fc1.weight"). Persistence is handled at graph level via .persist().
|
||||
#[cfg(feature = "cuda")]
|
||||
/// Set a weight from a device pointer (e.g. "fc1.weight"). Zero-copy on device.
|
||||
/// Requires a GPU backend.
|
||||
fn set_weight_device_ptr(
|
||||
&mut self,
|
||||
label: &str,
|
||||
device_ptr: u64,
|
||||
n_bytes: usize,
|
||||
) -> PyResult<()> {
|
||||
if !self.runtime.supports_device_ptrs() {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"set_weight_device_ptr requires a GPU backend",
|
||||
));
|
||||
}
|
||||
let &node_id = self.label_map.get(label).ok_or_else(|| {
|
||||
pyo3::exceptions::PyKeyError::new_err(format!("No Input node with label: {}", label))
|
||||
})?;
|
||||
match &mut self.runtime {
|
||||
RuntimeBackend::Cuda(rt) => {
|
||||
unsafe { rt.set_device_ptr(node_id, device_ptr, n_bytes) };
|
||||
}
|
||||
_ => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"set_weight_device_ptr requires CUDA backend",
|
||||
));
|
||||
}
|
||||
}
|
||||
unsafe { self.runtime.set_device_ptr(node_id, device_ptr, n_bytes) };
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Register an external device pointer for an output tensor (zero-copy output).
|
||||
/// Call before run() — the runtime will write kernel results directly into this buffer.
|
||||
/// For aliased outputs (in-place ops), falls back to DtoD copy; check output_is_zero_copy() after run().
|
||||
#[cfg(feature = "cuda")]
|
||||
/// Requires a GPU backend.
|
||||
fn set_output_device_ptr(
|
||||
&mut self,
|
||||
name: &str,
|
||||
device_ptr: u64,
|
||||
n_bytes: usize,
|
||||
) -> PyResult<()> {
|
||||
if !self.runtime.supports_device_ptrs() {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"set_output_device_ptr requires a GPU backend",
|
||||
));
|
||||
}
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
|
||||
match &mut self.runtime {
|
||||
RuntimeBackend::Cuda(rt) => {
|
||||
unsafe { rt.set_output_device_ptr(*node_id, device_ptr, n_bytes) };
|
||||
}
|
||||
_ => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"set_output_device_ptr requires CUDA backend",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
unsafe {
|
||||
self.runtime
|
||||
.set_output_device_ptr(*node_id, device_ptr, n_bytes)
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check whether an output tensor was zero-copied (written directly to the registered pointer).
|
||||
/// Returns false for aliased outputs that need a fallback DtoD copy. Must be called after run().
|
||||
#[cfg(feature = "cuda")]
|
||||
/// Returns false for aliased outputs that need a fallback DtoD copy, or if no GPU backend.
|
||||
/// Must be called after run().
|
||||
fn output_is_zero_copy(&self, name: &str) -> PyResult<bool> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
@@ -537,11 +367,7 @@ impl CompiledGraph {
|
||||
name
|
||||
))
|
||||
})?;
|
||||
|
||||
match &self.runtime {
|
||||
RuntimeBackend::Cuda(rt) => Ok(rt.output_is_zero_copy(*node_id)),
|
||||
_ => Ok(false),
|
||||
}
|
||||
Ok(self.runtime.output_is_zero_copy(*node_id))
|
||||
}
|
||||
|
||||
/// Set a weight tensor from a CPU host pointer, matching by Input node label (dtype-aware).
|
||||
@@ -559,7 +385,8 @@ impl CompiledGraph {
|
||||
})?;
|
||||
let bytes = unsafe { std::slice::from_raw_parts(ptr as *const u8, n_bytes).to_vec() };
|
||||
let typed = TypedData::from_pytorch_bytes(bytes, dtype_code);
|
||||
self.runtime.set_data(node_id, typed);
|
||||
self.runtime
|
||||
.set_data_bytes(node_id, typed.bytes, typed.dtype);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -585,9 +412,6 @@ impl CompiledGraph {
|
||||
}
|
||||
|
||||
/// Get output tensor data by name as f32 (copies to host).
|
||||
/// For native backend: handles any NativeData variant by converting to f32.
|
||||
/// The native runtime may produce NativeData::Int or NativeData::Bool for some ops
|
||||
/// (e.g., Cast chains), so we can't assume NativeData::F32.
|
||||
fn get_output(&self, name: &str) -> PyResult<Vec<f32>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
@@ -595,57 +419,50 @@ impl CompiledGraph {
|
||||
name
|
||||
))
|
||||
})?;
|
||||
match &self.runtime {
|
||||
RuntimeBackend::Native(rt) => {
|
||||
let id = *node_id;
|
||||
let output_id = rt
|
||||
.graph
|
||||
.node_indices()
|
||||
.find(|n| {
|
||||
if let Some(out) = (**rt.graph[*n]).as_any().downcast_ref::<Output>() {
|
||||
out.node == id.index()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
pyo3::exceptions::PyRuntimeError::new_err(format!(
|
||||
"No output node found for tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
let data = rt.buffers.get(&output_id).ok_or_else(|| {
|
||||
pyo3::exceptions::PyRuntimeError::new_err(format!(
|
||||
"No buffer data for output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
// Convert any NativeData variant to f32
|
||||
Ok((0..data.len()).map(|i| data.f32(i)).collect())
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(rt) => Ok(rt.get_f32(*node_id)),
|
||||
}
|
||||
Ok(self.runtime.get_output_f32(*node_id))
|
||||
}
|
||||
|
||||
/// Copy output tensor data directly to a CUDA device pointer (DtoD).
|
||||
/// Avoids the DtoH + HtoD round-trip of get_output() + .to(device).
|
||||
#[cfg(feature = "cuda")]
|
||||
fn copy_output_to_device_ptr(&self, name: &str, dest_ptr: u64, n_bytes: usize) -> PyResult<()> {
|
||||
/// Get output tensor data by name as i32 (copies to host).
|
||||
fn get_output_i32(&self, name: &str) -> PyResult<Vec<i32>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
match &self.runtime {
|
||||
RuntimeBackend::Cuda(rt) => {
|
||||
unsafe { rt.copy_output_to_device_ptr(*node_id, dest_ptr, n_bytes) };
|
||||
Ok(())
|
||||
}
|
||||
_ => Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"copy_output_to_device_ptr requires CUDA backend",
|
||||
)),
|
||||
Ok(self.runtime.get_output_i32(*node_id))
|
||||
}
|
||||
|
||||
/// Get output tensor data by name as bool (copies to host).
|
||||
fn get_output_bool(&self, name: &str) -> PyResult<Vec<bool>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
Ok(self.runtime.get_output_bool(*node_id))
|
||||
}
|
||||
|
||||
/// Copy output tensor data directly to a device pointer (DtoD).
|
||||
/// Avoids the DtoH + HtoD round-trip of get_output() + .to(device).
|
||||
/// Requires a GPU backend.
|
||||
fn copy_output_to_device_ptr(&self, name: &str, dest_ptr: u64, n_bytes: usize) -> PyResult<()> {
|
||||
if !self.runtime.supports_device_ptrs() {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"copy_output_to_device_ptr requires a GPU backend",
|
||||
));
|
||||
}
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
unsafe {
|
||||
self.runtime
|
||||
.copy_output_to_device_ptr(*node_id, dest_ptr, n_bytes)
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
mod compiled_graph;
|
||||
mod runtime;
|
||||
pub mod typed_data;
|
||||
|
||||
// PT2 modules
|
||||
@@ -12,10 +11,40 @@ mod translator;
|
||||
use compiled_graph::CompiledGraph;
|
||||
use pt2_compiled_model::process_pt2;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyCapsule;
|
||||
|
||||
#[pymodule]
|
||||
fn luminal(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(process_pt2, m)?)?;
|
||||
m.add_class::<CompiledGraph>()?;
|
||||
m.add_function(wrap_pyfunction!(_native_factory_capsule, m)?)?;
|
||||
#[cfg(feature = "cuda")]
|
||||
m.add_function(wrap_pyfunction!(_cuda_lite_factory_capsule, m)?)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Factory capsule helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Wrapper to put a function pointer into a PyCapsule.
|
||||
#[allow(dead_code)]
|
||||
struct FnPtrWrapper(pub *const std::ffi::c_void);
|
||||
unsafe impl Send for FnPtrWrapper {}
|
||||
|
||||
/// PyCapsule wrapping the native (CPU) backend factory.
|
||||
#[pyfunction]
|
||||
fn _native_factory_capsule<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyCapsule>> {
|
||||
let fptr = ::luminal::dyn_backend::native_factory as *const std::ffi::c_void;
|
||||
let name = ::luminal::dyn_backend::BACKEND_FACTORY_CAPSULE_NAME.to_owned();
|
||||
PyCapsule::new(py, FnPtrWrapper(fptr), Some(name))
|
||||
}
|
||||
|
||||
/// PyCapsule wrapping the cuda_lite backend factory.
|
||||
#[cfg(feature = "cuda")]
|
||||
#[pyfunction]
|
||||
fn _cuda_lite_factory_capsule<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyCapsule>> {
|
||||
let fptr = luminal_cuda_lite::dyn_backend::cuda_lite_factory as *const std::ffi::c_void;
|
||||
let name = ::luminal::dyn_backend::BACKEND_FACTORY_CAPSULE_NAME.to_owned();
|
||||
PyCapsule::new(py, FnPtrWrapper(fptr), Some(name))
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use luminal::dyn_backend::BackendFactory;
|
||||
use luminal::prelude::tracing::warn;
|
||||
use luminal::prelude::*;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::{PyCapsule, PyCapsuleMethods};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::compiled_graph::{CompiledGraph, DimParamMap, GraphTranslation, WeightData};
|
||||
@@ -36,20 +38,54 @@ fn resolve_dim_sizes(
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (pt2_path, weights_path, backend, search_iters, weight_device_ptrs=None))]
|
||||
#[pyo3(signature = (pt2_path, weights_path, search_iters, factory_capsule, weight_device_ptrs=None))]
|
||||
pub fn process_pt2(
|
||||
pt2_path: &str,
|
||||
weights_path: &str,
|
||||
backend: &str,
|
||||
search_iters: usize,
|
||||
factory_capsule: &Bound<'_, PyCapsule>,
|
||||
weight_device_ptrs: Option<HashMap<String, (u64, usize)>>,
|
||||
) -> PyResult<CompiledGraph> {
|
||||
let factory: BackendFactory = {
|
||||
let expected = ::luminal::dyn_backend::BACKEND_FACTORY_CAPSULE_NAME;
|
||||
match factory_capsule.name()? {
|
||||
Some(name) => {
|
||||
// SAFETY: the &CStr is used immediately (for a byte-wise
|
||||
// comparison) and never stored; the capsule is borrowed for
|
||||
// the duration of this function, so the name pointer stays
|
||||
// valid for as long as we read it here.
|
||||
let actual = unsafe { name.as_cstr() };
|
||||
if actual != expected {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
||||
"factory_capsule has wrong name: expected {:?}, got {:?}",
|
||||
expected, actual,
|
||||
)));
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"factory_capsule has no name; expected \"luminal.backend_factory\"",
|
||||
));
|
||||
}
|
||||
}
|
||||
let wrapper_ptr = factory_capsule
|
||||
.pointer_checked(Some(expected))
|
||||
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?
|
||||
.as_ptr() as *const *const std::ffi::c_void;
|
||||
let fn_ptr = unsafe { *wrapper_ptr };
|
||||
if fn_ptr.is_null() {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"factory_capsule inner function pointer is null",
|
||||
));
|
||||
}
|
||||
unsafe { std::mem::transmute(fn_ptr) }
|
||||
};
|
||||
compile_pt2(
|
||||
pt2_path,
|
||||
weights_path,
|
||||
backend,
|
||||
search_iters,
|
||||
weight_device_ptrs.unwrap_or_default(),
|
||||
factory,
|
||||
)
|
||||
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e:#}")))
|
||||
}
|
||||
@@ -57,14 +93,14 @@ pub fn process_pt2(
|
||||
fn compile_pt2(
|
||||
pt2_path: &str,
|
||||
weights_path: &str,
|
||||
backend: &str,
|
||||
search_iters: usize,
|
||||
weight_device_ptrs: HashMap<String, (u64, usize)>,
|
||||
factory: BackendFactory,
|
||||
) -> anyhow::Result<CompiledGraph> {
|
||||
let (translation, mut weights) = translate_pt2(pt2_path, weights_path)?;
|
||||
weights.device_ptrs = weight_device_ptrs;
|
||||
|
||||
CompiledGraph::parse_graph(translation, weights, backend, search_iters)
|
||||
CompiledGraph::parse_graph(translation, weights, factory, search_iters)
|
||||
.map_err(|e| anyhow::anyhow!(e))
|
||||
}
|
||||
|
||||
|
||||
@@ -160,7 +160,31 @@ pub fn parse_pt2(path: &str) -> Result<ParsedPT2> {
|
||||
let file = File::open(path).with_context(|| format!("Failed to open PT2 file: {path}"))?;
|
||||
let mut archive = ZipArchive::new(file).context("Failed to read PT2 ZIP archive")?;
|
||||
|
||||
// Determine archive prefix from the first entry
|
||||
// Torch >= 2.6 uses a flat archive with no prefix directory; detect by presence of the
|
||||
// well-known root-level file. Older torch used a prefix (e.g. "archive/models/model.json").
|
||||
let is_new_format = archive
|
||||
.file_names()
|
||||
.any(|n| n == "serialized_exported_program.json");
|
||||
|
||||
if is_new_format {
|
||||
let program: ExportedProgram = {
|
||||
let mut entry = archive.by_name("serialized_exported_program.json")?;
|
||||
let mut buf = String::new();
|
||||
entry.read_to_string(&mut buf)?;
|
||||
serde_json::from_str(&buf)
|
||||
.context("Failed to parse serialized_exported_program.json")?
|
||||
};
|
||||
// Tensor constants live in serialized_constants.pt; Python extracts them
|
||||
// and loads them post-compile via set_weight_from_ptr.
|
||||
return Ok(ParsedPT2 {
|
||||
program,
|
||||
constants_config: None,
|
||||
archive_prefix: String::new(),
|
||||
pt2_path: path.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Old prefix-based format.
|
||||
let archive_prefix = {
|
||||
let first = archive
|
||||
.file_names()
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
use luminal::hlir::NativeData;
|
||||
use luminal::prelude::*;
|
||||
#[cfg(feature = "cuda")]
|
||||
use luminal_cuda_lite::cudarc::driver::{CudaContext, CudaStream};
|
||||
#[cfg(feature = "cuda")]
|
||||
use luminal_cuda_lite::runtime::CudaRuntime;
|
||||
use rustc_hash::FxHashMap;
|
||||
#[cfg(feature = "cuda")]
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::typed_data::TypedData;
|
||||
|
||||
/// Enum wrapper for runtime backends allowing runtime selection.
|
||||
pub enum RuntimeBackend {
|
||||
Native(NativeRuntime),
|
||||
#[cfg(feature = "cuda")]
|
||||
Cuda(Box<CudaRuntime>),
|
||||
}
|
||||
|
||||
impl RuntimeBackend {
|
||||
/// Set input data for a tensor node (dtype-aware).
|
||||
pub fn set_data(&mut self, node: NodeIndex, data: TypedData) {
|
||||
match self {
|
||||
RuntimeBackend::Native(rt) => {
|
||||
let native: NativeData = data.into();
|
||||
rt.set_data(node, native);
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(rt) => {
|
||||
// CUDA runtime stores raw bytes — just upload directly
|
||||
rt.set_data(node, data.bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Set input data from a Vec<f32> (convenience for backward compatibility).
|
||||
pub fn set_data_f32(&mut self, node: NodeIndex, data: Vec<f32>) {
|
||||
match self {
|
||||
RuntimeBackend::Native(rt) => rt.set_data(node, data),
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(rt) => rt.set_data(node, data),
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute the compiled graph.
|
||||
pub fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
match self {
|
||||
RuntimeBackend::Native(rt) => rt.execute(dyn_map),
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(rt) => rt.execute(dyn_map),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get output data as f32 from a tensor node.
|
||||
pub fn get_f32(&self, node: NodeIndex) -> Vec<f32> {
|
||||
match self {
|
||||
RuntimeBackend::Native(rt) => rt.get_f32(node).to_vec(),
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(rt) => rt.get_f32(node),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the name of the active backend.
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
RuntimeBackend::Native(_) => "native",
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(_) => "cuda",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Two-phase initialization for CUDA (required because profiling executes graph)
|
||||
// ============================================================================
|
||||
|
||||
/// Prepare CUDA runtime: build search space and create runtime, but don't search yet.
|
||||
/// Returns the unoptimized runtime that can have data set on it.
|
||||
///
|
||||
/// Use this with `finalize_cuda` for proper CUDA initialization:
|
||||
/// 1. Call `prepare_cuda` to get the runtime
|
||||
/// 2. Set data on the runtime using `rt.set_data(node_id, data)`
|
||||
/// 3. Call `finalize_cuda` to run profiling with data available
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn prepare_cuda(context: &mut Graph) -> Result<(CudaRuntime, Arc<CudaStream>), String> {
|
||||
let cuda_ctx =
|
||||
CudaContext::new(0).map_err(|e| format!("Failed to init CUDA context: {}", e))?;
|
||||
let stream = cuda_ctx.default_stream();
|
||||
context.build_search_space::<CudaRuntime>();
|
||||
let rt = CudaRuntime::initialize(stream.clone());
|
||||
Ok((rt, stream))
|
||||
}
|
||||
|
||||
/// Finalize CUDA runtime: run search with data already set.
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn finalize_cuda(context: &mut Graph, rt: CudaRuntime) -> RuntimeBackend {
|
||||
let optimized_rt = context.search(rt, 10);
|
||||
RuntimeBackend::Cuda(Box::new(optimized_rt))
|
||||
}
|
||||
@@ -51,6 +51,7 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.sub.Scalar" => self.translate_binary_scalar_op(node, BinaryOp::Sub)?,
|
||||
"torch.ops.aten.div.Tensor" => self.translate_binary_op(node, BinaryOp::Div)?,
|
||||
"torch.ops.aten.div.Scalar" => self.translate_binary_scalar_op(node, BinaryOp::Div)?,
|
||||
"torch.ops.aten.div.Tensor_mode" => self.translate_div_tensor_mode(node)?,
|
||||
|
||||
// Unary ops
|
||||
"torch.ops.aten.neg.default" => self.translate_unary_op(node, |a| a * (-1.0))?,
|
||||
@@ -71,6 +72,8 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.log.default" => self.translate_unary_op(node, |a| a.log())?,
|
||||
"torch.ops.aten.log2.default" => self.translate_unary_op(node, |a| a.log2())?,
|
||||
"torch.ops.aten.exp2.default" => self.translate_unary_op(node, |a| a.exp2())?,
|
||||
"torch.ops.aten.sign.default" => self.translate_sign(node)?,
|
||||
"torch.ops.aten.bitwise_not.default" => self.translate_bitwise_not(node)?,
|
||||
|
||||
// Cast
|
||||
"torch.ops.aten._to_copy.default" => self.translate_to_copy(node)?,
|
||||
@@ -109,6 +112,7 @@ impl<'a> Translator<'a> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
if !a.shape.is_contiguous() { a + 0.0 } else { a }
|
||||
}
|
||||
"torch.ops.aten.argsort.default" => self.translate_argsort(node)?,
|
||||
|
||||
// Matmul
|
||||
"torch.ops.aten.mm.default" | "torch.ops.aten.bmm.default" => {
|
||||
@@ -159,6 +163,8 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Where
|
||||
"torch.ops.aten.where.self" => self.translate_where(node)?,
|
||||
"torch.ops.aten.where.ScalarOther" => self.translate_where_scalar_other(node)?,
|
||||
"torch.ops.aten.masked_fill.Scalar" => self.translate_masked_fill_scalar(node)?,
|
||||
|
||||
// Pow
|
||||
"torch.ops.aten.pow.Tensor_Scalar" => {
|
||||
@@ -176,6 +182,21 @@ impl<'a> Translator<'a> {
|
||||
// Creation ops
|
||||
"torch.ops.aten.arange.start_step" => self.translate_arange(node)?,
|
||||
"torch.ops.aten.full.default" => self.translate_full(node)?,
|
||||
"torch.ops.aten.full_like.default" => self.translate_full_like(node)?,
|
||||
"torch.ops.aten.empty_permuted.default"
|
||||
| "torch.ops.aten.empty.memory_format" => self.translate_empty(node)?,
|
||||
"torch.ops.aten.histc.default" => self.translate_histc(node)?,
|
||||
|
||||
// Grouped matmul (MoE expert dispatch).
|
||||
// aten._grouped_mm is the native op; transformers::grouped_mm_fallback
|
||||
// is a Python-implemented custom_op (transformers/integrations/moe.py)
|
||||
// used by HF MoE when _grouped_mm isn't available for the activation
|
||||
// dtype. Both have identical (input, weight, offs) signature; route
|
||||
// both through the same batched-matmul + group-mask lowering.
|
||||
"torch.ops.aten._grouped_mm.default"
|
||||
| "torch.ops.transformers.grouped_mm_fallback.default" => {
|
||||
self.translate_grouped_mm(node)?
|
||||
}
|
||||
"torch.ops.aten.scalar_tensor.default" => {
|
||||
let val = self.get_float_arg(node, 0)? as f32;
|
||||
self.graph.constant_float(val)
|
||||
@@ -185,6 +206,7 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.lt.Scalar" => self.translate_scalar_comparison(node, |a, s| a.lt(s))?,
|
||||
"torch.ops.aten.ge.Scalar" => self.translate_scalar_comparison(node, |a, s| a.ge(s))?,
|
||||
"torch.ops.aten.le.Scalar" => self.translate_scalar_comparison(node, |a, s| a.le(s))?,
|
||||
"torch.ops.aten.eq.Scalar" => self.translate_scalar_comparison(node, |a, s| a.eq(s))?,
|
||||
|
||||
// Tensor comparisons
|
||||
"torch.ops.aten.ne.Scalar" => {
|
||||
@@ -219,7 +241,11 @@ impl<'a> Translator<'a> {
|
||||
let b = b.cast(DType::F32);
|
||||
(a * b).cast(DType::Bool)
|
||||
}
|
||||
"torch.ops.aten.logical_or.default" => {
|
||||
"torch.ops.aten.bitwise_or.Tensor" | "torch.ops.aten.logical_or.default" => {
|
||||
// Both arms use the same bool-OR lowering. Gemma-4's sliding+full
|
||||
// attention mask fusion emits bitwise_or on boolean tensors; the
|
||||
// integer semantics of bitwise_or aren't exercised by any op in
|
||||
// the test suite, so we rely on inputs being boolean-typed.
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
@@ -271,24 +297,40 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
"torch.ops.aten.erf.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
// Abramowitz & Stegun approximation 7.1.28 (max error ~1.5e-7)
|
||||
// erf(x) = sign(x) * (1 - poly(t) * exp(-x^2))
|
||||
// where t = 1/(1 + 0.3275911*|x|), poly in Horner form
|
||||
let ax = a.abs();
|
||||
let x2 = a * a;
|
||||
let t = (ax * 0.3275911_f32 + 1.0).reciprocal();
|
||||
// Horner: t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
|
||||
let poly = t
|
||||
* (t * (t
|
||||
* (t * (t * 1.061_405_4_f32 + (-1.453_152_1_f32)) + 1.421_413_8_f32)
|
||||
+ (-0.284_496_72_f32))
|
||||
+ 0.254_829_6_f32);
|
||||
let result_abs =
|
||||
self.graph.constant_float(1.0).expand_rhs(a.shape) - poly * (x2 * (-1.0)).exp();
|
||||
// sign(x) = 2*(x >= 0) - 1
|
||||
let zero = self.graph.constant_float(0.0).expand_rhs(a.shape);
|
||||
let sign = a.ge(zero).cast(DType::F32) * 2.0 - 1.0;
|
||||
result_abs * sign
|
||||
self.erf_approx(a)
|
||||
}
|
||||
"torch.ops.aten.gelu.default" => {
|
||||
let a_in = self.get_input_tensor(node, 0)?;
|
||||
// PyTorch's gelu has a kwarg `approximate` (default "none").
|
||||
// "none" → 0.5 * x * (1 + erf(x / sqrt(2))) (exact)
|
||||
// "tanh" → 0.5 * x * (1 + tanh(c * (x + 0.044715*x^3)))
|
||||
// where c = sqrt(2/pi) ≈ 0.7978845608
|
||||
// Gemma family uses approximate="tanh" but lowering may emit
|
||||
// either form; honour whatever the FX graph carries.
|
||||
let approximate = node.inputs.iter().find_map(|input| {
|
||||
if input.name == "approximate"
|
||||
&& let Argument::Other(val) = &input.arg
|
||||
{
|
||||
return val.as_str().map(|s| s.to_string());
|
||||
}
|
||||
None
|
||||
});
|
||||
// Promote to F32 around the constants/comparisons (same reason
|
||||
// as clamp/erf — luminal binary ops assert matching dtypes).
|
||||
let orig = a_in.dtype;
|
||||
let a = if orig == DType::F32 { a_in } else { a_in.cast(DType::F32) };
|
||||
let half = self.graph.constant_float(0.5).expand_rhs(a.shape);
|
||||
let one = self.graph.constant_float(1.0).expand_rhs(a.shape);
|
||||
let result = if approximate.as_deref() == Some("tanh") {
|
||||
let x2 = a * a;
|
||||
let inner = a * (x2 * 0.044715_f32 + 1.0) * 0.797_884_56_f32;
|
||||
half * a * (one + inner.tanh())
|
||||
} else {
|
||||
let scaled = a * 0.707_106_77_f32; // 1 / sqrt(2)
|
||||
let erf_val = self.erf_approx(scaled);
|
||||
half * a * (one + erf_val)
|
||||
};
|
||||
if orig == DType::F32 { result } else { result.cast(orig) }
|
||||
}
|
||||
"torch.ops.aten.isnan.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
@@ -349,7 +391,17 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Scatter ops
|
||||
"torch.ops.aten.scatter.src" => self.translate_scatter_src(node)?,
|
||||
"torch.ops.aten.index_put.default" => self.translate_index_put(node)?,
|
||||
"torch.ops.aten.scatter.value" => self.translate_scatter_value(node)?,
|
||||
"torch.ops.aten.index_put_.default" | "torch.ops.aten.index_put.default" => {
|
||||
self.translate_index_put(node)?
|
||||
}
|
||||
|
||||
// Integer routing math
|
||||
"torch.ops.aten.floor_divide.default" => self.translate_floor_divide(node)?,
|
||||
|
||||
// Triangular
|
||||
"torch.ops.aten.tril.default" => self.translate_tril(node)?,
|
||||
"torch.ops.aten.triu.default" => self.translate_triu(node)?,
|
||||
|
||||
// TopK — handles its own output storage, returns early
|
||||
"torch.ops.aten.topk.default" => {
|
||||
@@ -357,6 +409,12 @@ impl<'a> Translator<'a> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Sort — handles its own output storage, returns early
|
||||
"torch.ops.aten.sort.default" => {
|
||||
self.translate_sort(node)?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Split
|
||||
"torch.ops.aten.split_with_sizes.default" => self.translate_split_with_sizes(node)?,
|
||||
|
||||
|
||||
@@ -68,6 +68,9 @@ impl<'a> Translator<'a> {
|
||||
fn translate_graph(&mut self) -> Result<()> {
|
||||
self.create_inputs()?;
|
||||
|
||||
// Per-block partitioning is now handled automatically by the upstream
|
||||
// loop-rolling prepass; this translator no longer needs to insert
|
||||
// manual graph breaks at RMSNorm boundaries.
|
||||
let nodes = &self.parsed.program.graph_module.graph.nodes;
|
||||
for (i, node) in nodes.iter().enumerate() {
|
||||
self.translate_node(node)
|
||||
@@ -77,11 +80,12 @@ impl<'a> Translator<'a> {
|
||||
let output_names = self.parsed.output_names();
|
||||
for name in &output_names {
|
||||
let tensor = self.get_tensor(name)?;
|
||||
// Cast non-float outputs (Bool, Int) to F32 for the runtime.
|
||||
// Preserve F16/BF16/F32 as-is to avoid corrupting half-precision models.
|
||||
let tensor = match tensor.dtype {
|
||||
DType::Bool | DType::Int => tensor.cast(DType::F32) + 0.0,
|
||||
_ => tensor + 0.0,
|
||||
let tensor = if tensor.dtype == DType::Bool {
|
||||
tensor.cast(DType::Int).cast(DType::Bool)
|
||||
} else if tensor.dtype == DType::Int {
|
||||
tensor
|
||||
} else {
|
||||
tensor + 0.0
|
||||
};
|
||||
tensor.output();
|
||||
self.output_ids.push((name.clone(), tensor.id));
|
||||
@@ -155,6 +159,12 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// --- Helper methods ---
|
||||
|
||||
pub(crate) fn tensor_meta(&self, name: &str) -> Option<&TensorMeta> {
|
||||
self.extra_tensor_values
|
||||
.get(name)
|
||||
.or_else(|| self.parsed.tensor_meta(name))
|
||||
}
|
||||
|
||||
pub(crate) fn get_tensor(&self, name: &str) -> Result<GraphTensor> {
|
||||
self.tensors
|
||||
.get(name)
|
||||
@@ -329,3 +339,4 @@ impl<'a> Translator<'a> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,11 @@ use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
const SCATTER_INPUT_ARG: usize = 0;
|
||||
const SCATTER_DIM_ARG: usize = 1;
|
||||
const SCATTER_INDEX_ARG: usize = 2;
|
||||
const SCATTER_VALUE_ARG: usize = 3;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_reshape(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
@@ -359,27 +364,125 @@ impl<'a> Translator<'a> {
|
||||
Ok(a.scatter_elements(indices.cast(DType::Int), src, dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_scatter_value(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, SCATTER_INPUT_ARG)?;
|
||||
let dim = self.get_int_arg(node, SCATTER_DIM_ARG)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let indices = self.get_input_tensor(node, SCATTER_INDEX_ARG)?;
|
||||
let value_arg = &node
|
||||
.inputs
|
||||
.get(SCATTER_VALUE_ARG)
|
||||
.context("scatter.value missing value input")?
|
||||
.arg;
|
||||
let value = if let Some(b) = value_arg.as_bool() {
|
||||
self.graph.constant(if b { 1 } else { 0 }).cast(a.dtype)
|
||||
} else if let Some(i) = value_arg.as_int() {
|
||||
self.graph.constant(i).cast(a.dtype)
|
||||
} else if let Some(f) = value_arg.as_float() {
|
||||
self.graph.constant_float(f as f32).cast(a.dtype)
|
||||
} else {
|
||||
bail!("scatter.value: unsupported scalar argument {:?}", value_arg);
|
||||
}
|
||||
.expand_rhs(indices.shape);
|
||||
Ok(a.scatter_elements(indices.cast(DType::Int), value, dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_index_put(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let index_names = node.inputs[1]
|
||||
.arg
|
||||
.as_tensors()
|
||||
.context("index_put: indices not as_tensors")?;
|
||||
let values = self.get_input_tensor(node, 2)?;
|
||||
|
||||
if index_names.len() == 1 {
|
||||
let indices = self.get_tensor(&index_names[0].name)?.cast(DType::Int);
|
||||
// scatter_nd expects indices of shape [batch, K] where K = number of index dims.
|
||||
// PT2's index_put gives 1D indices [batch]; reshape to [batch, 1].
|
||||
let indices = if indices.shape.len() == 1 {
|
||||
indices.expand_dim(1, Expression::from(1usize))
|
||||
} else {
|
||||
indices
|
||||
};
|
||||
Ok(a.scatter_nd(indices, values))
|
||||
} else {
|
||||
bail!("index_put with multiple index tensors not yet supported");
|
||||
// --- all-tensor indices: bool-mask blend or scatter_nd ---
|
||||
if let Some(index_names) = node.inputs[1].arg.as_tensors() {
|
||||
if index_names.len() == 1 {
|
||||
let idx_tensor = self.get_tensor(&index_names[0].name)?;
|
||||
|
||||
// Boolean-mask index_put: when the only index is a Bool tensor whose
|
||||
// shape matches the data tensor, PyTorch semantics are
|
||||
// data[mask] = value ↔ where(mask, value, data)
|
||||
// NOT a scatter into positions. Casting the Bool mask to Int and
|
||||
// feeding it to scatter_nd would reinterpret True/False as row
|
||||
// indices 1/0 and silently corrupt the data. Reproducer:
|
||||
// x = arange(16).reshape(4, 4); mask = zeros(4, 4, dtype=bool)
|
||||
// y = x.clone(); y[mask] = 99 # eager: y == x (no-op)
|
||||
// Pre-fix the compiled graph wrote 99 to row 0; this branch
|
||||
// ensures the bool-mask path lowers to a where-blend instead.
|
||||
if idx_tensor.dtype == DType::Bool && idx_tensor.shape.dims == a.shape.dims {
|
||||
let mask_f = idx_tensor.cast(a.dtype);
|
||||
let values_b = values.cast(a.dtype).expand_rhs(a.shape);
|
||||
// Implements where(mask, value, a) as
|
||||
// a*(1 - mask) + value*mask
|
||||
// — works without a dedicated cond op for any numeric dtype.
|
||||
let one = self
|
||||
.graph
|
||||
.constant_float(1.0)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
return Ok(a * (one - mask_f) + values_b * mask_f);
|
||||
}
|
||||
|
||||
// Integer-index scatter: index_put with indices=[idx_tensor] writes
|
||||
// into dim 0 of `a` at every position named in idx_tensor (flattened),
|
||||
// broadcasting values across the trailing dims of `a`. Always pad
|
||||
// a trailing size-1 dim so rank-1 and rank-N cases share a path.
|
||||
let indices = idx_tensor.cast(DType::Int);
|
||||
let new_last = indices.shape.len();
|
||||
let indices = indices.expand_dim(new_last, Expression::from(1usize));
|
||||
return Ok(a.scatter_nd(indices, values));
|
||||
}
|
||||
bail!("index_put with multiple all-tensor indices not yet supported");
|
||||
}
|
||||
|
||||
// --- optional-tensor indices: [None, arange_tensor, None, ...] ---
|
||||
// Each None means "all of that dimension"; one tensor means "index into that dim".
|
||||
// StaticCache uses this for KV updates: cache[:, :, position, :] = new_value.
|
||||
if let Some(opt_tensors) = node.inputs[1].arg.as_optional_tensors() {
|
||||
use crate::pt2_schema::OptionalTensorEntry;
|
||||
let mut first_non_none_dim = 0usize;
|
||||
let mut idx_name: Option<String> = None;
|
||||
let mut non_none_count = 0usize;
|
||||
|
||||
for (i, entry) in opt_tensors.iter().enumerate() {
|
||||
if let OptionalTensorEntry::Tensor(t) = entry {
|
||||
if idx_name.is_none() {
|
||||
first_non_none_dim = i;
|
||||
}
|
||||
idx_name = Some(t.as_tensor.name.clone());
|
||||
non_none_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if non_none_count != 1 {
|
||||
bail!(
|
||||
"index_put with optional tensors: only single non-None index supported \
|
||||
(got {non_none_count})"
|
||||
);
|
||||
}
|
||||
|
||||
let mut indices = self.get_tensor(&idx_name.unwrap())?.cast(DType::Int);
|
||||
|
||||
// Expand 1-D indices [P] to values.shape for scatter_elements:
|
||||
// Build [1, ..., 1, P, 1, ..., 1] with P at first_non_none_dim, then broadcast.
|
||||
let rank = a.shape.len();
|
||||
// Insert singleton dims before first_non_none_dim
|
||||
for i in 0..first_non_none_dim {
|
||||
indices = indices.expand_dim(i, Expression::from(1usize));
|
||||
}
|
||||
// Insert singleton dims after first_non_none_dim
|
||||
let current_rank = indices.shape.len();
|
||||
for j in current_rank..rank {
|
||||
indices = indices.expand_dim(j, Expression::from(1usize));
|
||||
}
|
||||
// Broadcast singletons to values shape
|
||||
let values_shape: Vec<Expression> = values.shape.dims[..rank].to_vec();
|
||||
indices.shape.expand(values_shape);
|
||||
|
||||
return Ok(a.scatter_elements(indices, values, first_non_none_dim));
|
||||
}
|
||||
|
||||
bail!(
|
||||
"index_put: unsupported indices format: {:?}",
|
||||
node.inputs[1].arg
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_split_with_sizes(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
|
||||
@@ -6,6 +6,27 @@ use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
const FULL_SHAPE_ARG: usize = 0;
|
||||
const FULL_VALUE_ARG: usize = 1;
|
||||
|
||||
const FULL_LIKE_INPUT_ARG: usize = 0;
|
||||
const FULL_LIKE_VALUE_ARG: usize = 1;
|
||||
|
||||
const TOPK_INPUT_ARG: usize = 0;
|
||||
const TOPK_K_ARG: usize = 1;
|
||||
const TOPK_DIM_ARG: usize = 2;
|
||||
|
||||
const SORT_INPUT_ARG: usize = 0;
|
||||
const SORT_DIM_ARG: usize = 1;
|
||||
const SORT_DESCENDING_ARG: usize = 2;
|
||||
|
||||
const WHERE_COND_ARG: usize = 0;
|
||||
const WHERE_X_ARG: usize = 1;
|
||||
const WHERE_OTHER_ARG: usize = 2;
|
||||
|
||||
const TRIANGULAR_INPUT_ARG: usize = 0;
|
||||
const TRIANGULAR_DIAGONAL_ARG: usize = 1;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_arange(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let positional_args: Vec<Expression> = node
|
||||
@@ -30,19 +51,218 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
pub(crate) fn translate_full(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let shape = self.get_exprs_arg(node, 0)?;
|
||||
let shape = self.get_exprs_arg(node, FULL_SHAPE_ARG)?;
|
||||
// fill_value can be float, int, or bool after decomposition
|
||||
let val = if let Ok(f) = self.get_float_arg(node, 1) {
|
||||
let val = if let Ok(f) = self.get_float_arg(node, FULL_VALUE_ARG) {
|
||||
f as f32
|
||||
} else if let Ok(b) = self.get_bool_arg(node, 1) {
|
||||
} else if let Ok(b) = self.get_bool_arg(node, FULL_VALUE_ARG) {
|
||||
if b { 1.0 } else { 0.0 }
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"full: unsupported fill value type: {:?}",
|
||||
node.inputs.get(1)
|
||||
node.inputs.get(FULL_VALUE_ARG)
|
||||
);
|
||||
};
|
||||
Ok(self.graph.constant_float(val).expand_rhs(shape))
|
||||
let dtype = self.output_meta_dtype(node)?;
|
||||
let value = self.graph.constant_float(val).cast(dtype);
|
||||
Ok(if shape.is_empty() {
|
||||
value
|
||||
} else {
|
||||
value.expand_rhs(shape)
|
||||
})
|
||||
}
|
||||
|
||||
/// Translate `aten.histc.default(input, bins, min, max)` → `Tensor[bins]`.
|
||||
///
|
||||
/// Counts how many input elements fall in each of `bins` equal-width
|
||||
/// buckets over `[min, max]`. PyTorch's histc accepts only 1D input;
|
||||
/// HF MoE forwards emit it on flattened expert-assignment tensors to
|
||||
/// produce per-expert token counts (one_hot + sum, essentially).
|
||||
///
|
||||
/// Implementation: arange over bins, broadcast to [G, N], element-wise
|
||||
/// `(lower <= input < upper)` into a F32 mask, sum over the input axis.
|
||||
/// The right edge of the last bin is technically inclusive in PyTorch;
|
||||
/// we treat it as exclusive — for the typical MoE use (integer expert
|
||||
/// IDs in `[0, num_experts)`), no input ever equals `max` so this is
|
||||
/// indistinguishable.
|
||||
pub(crate) fn translate_histc(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let bins = self.get_int_arg(node, 1)? as usize;
|
||||
let min_val = self.get_float_arg(node, 2)? as f32;
|
||||
let max_val = self.get_float_arg(node, 3)? as f32;
|
||||
|
||||
anyhow::ensure!(
|
||||
input.shape.len() == 1,
|
||||
"histc: only 1D input supported (got {}D)",
|
||||
input.shape.len()
|
||||
);
|
||||
let n = input.shape.dims[0];
|
||||
let g = Expression::from(bins);
|
||||
|
||||
let input_f = input.cast(DType::F32);
|
||||
let step = (max_val - min_val) / bins as f32;
|
||||
|
||||
// Per-bin lower edges: arange(bins) * step + min.
|
||||
let bin_idx = self.graph.arange(g).cast(DType::F32);
|
||||
let lower_1d = bin_idx * step + min_val;
|
||||
let upper_1d = lower_1d + step;
|
||||
|
||||
// Broadcast to [G, N] and produce the boolean mask.
|
||||
let input_b = input_f.expand_dim(0, g);
|
||||
let lower = lower_1d.expand_dim(1, n);
|
||||
let upper = upper_1d.expand_dim(1, n);
|
||||
|
||||
let in_lower = input_b.ge(lower).cast(DType::F32);
|
||||
let in_upper = input_b.lt(upper).cast(DType::F32);
|
||||
let mask = in_lower * in_upper;
|
||||
|
||||
Ok(mask.sum(1))
|
||||
}
|
||||
|
||||
/// Translate `aten.empty_permuted.default(size, physical_layout, **kwargs)`
|
||||
/// → zero-filled tensor of shape `size`.
|
||||
///
|
||||
/// PyTorch's `empty_permuted` allocates uninitialized memory with a given
|
||||
/// stride permutation; downstream code typically overwrites every element
|
||||
/// before reading. Luminal's tensor abstraction doesn't expose strides, so
|
||||
/// the physical_layout hint is irrelevant — we just emit a zero tensor of
|
||||
/// the requested shape and dtype. (Same approach works for `aten.empty`
|
||||
/// variants when they show up.)
|
||||
pub(crate) fn translate_empty(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let shape = self.get_exprs_arg(node, 0)?;
|
||||
let dtype = self.output_meta_dtype(node)?;
|
||||
let value = self.graph.constant_float(0.0).cast(dtype);
|
||||
Ok(if shape.is_empty() {
|
||||
value
|
||||
} else {
|
||||
value.expand_rhs(shape)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_full_like(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let reference = self.get_input_tensor(node, FULL_LIKE_INPUT_ARG)?;
|
||||
let val = if let Ok(f) = self.get_float_arg(node, FULL_LIKE_VALUE_ARG) {
|
||||
f as f32
|
||||
} else if let Ok(b) = self.get_bool_arg(node, FULL_LIKE_VALUE_ARG) {
|
||||
if b { 1.0 } else { 0.0 }
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"full_like: unsupported fill value type: {:?}",
|
||||
node.inputs.get(FULL_LIKE_VALUE_ARG)
|
||||
);
|
||||
};
|
||||
let dtype = self.output_meta_dtype(node)?;
|
||||
let value = self.graph.constant_float(val).cast(dtype);
|
||||
Ok(value.expand_rhs(reference.shape))
|
||||
}
|
||||
|
||||
fn output_meta_dtype(&self, node: &Node) -> Result<DType> {
|
||||
let output_name = node
|
||||
.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref())
|
||||
.map(|t| t.name.clone())
|
||||
.unwrap_or_default();
|
||||
let meta = self
|
||||
.tensor_meta(&output_name)
|
||||
.context("Missing tensor meta for output dtype")?;
|
||||
Ok(torch_dtype_int_to_luminal(meta.dtype))
|
||||
}
|
||||
|
||||
/// Translate `aten._grouped_mm.default(input, weight, offs)` → `Tensor[S, N]`.
|
||||
///
|
||||
/// Grouped matmul: `input` is `[S, K]` (tokens sorted by expert), `weight` is
|
||||
/// `[G, K, N]` (per-expert weights), `offs` is `[G]` cumulative token counts.
|
||||
/// Output `[S, N]` where token m (in group g s.t. `offs[g-1] <= m < offs[g]`)
|
||||
/// is multiplied by `weight[g]`.
|
||||
///
|
||||
/// Implementation: for each token m we (a) compute its expert id from offs,
|
||||
/// (b) gather only that expert's `[K, N]` slice from weight, and (c) do a
|
||||
/// single per-token matmul. The gather pattern mirrors the rust qwen3_moe
|
||||
/// example's `gather_experts`, which the GLUMoE host-op fusion in
|
||||
/// `luminal_cuda_lite` is designed to recognise.
|
||||
///
|
||||
/// Why not the straightforward `[G, S, K] @ [G, K, N] → [G, S, N]` + mask:
|
||||
/// it forces a full F32 cast of the entire `[G, K, N]` weight tensor as
|
||||
/// search-time intermediate, which OOMs on real MoE checkpoints
|
||||
/// (Qwen3-30B-A3B: 1.5 GB / layer × 48 layers for gate-up alone). Gathering
|
||||
/// first keeps the F32 cast on `[S, K, N]` instead — for prefill (S = top_k)
|
||||
/// that is a 16× shrink (G=128, top_k=8).
|
||||
///
|
||||
/// `offs` flows through as a runtime tensor — the routing decision is computed
|
||||
/// at execution time by the gate network and the same compiled graph handles
|
||||
/// any routing pattern without recompilation.
|
||||
pub(crate) fn translate_grouped_mm(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let weight = self.get_input_tensor(node, 1)?;
|
||||
let offs = self.get_input_tensor(node, 2)?;
|
||||
|
||||
anyhow::ensure!(
|
||||
input.shape.len() == 2,
|
||||
"_grouped_mm: input must be 2D, got {}D",
|
||||
input.shape.len()
|
||||
);
|
||||
anyhow::ensure!(
|
||||
weight.shape.len() == 3,
|
||||
"_grouped_mm: weight must be 3D, got {}D",
|
||||
weight.shape.len()
|
||||
);
|
||||
anyhow::ensure!(
|
||||
offs.shape.len() == 1,
|
||||
"_grouped_mm: offs must be 1D, got {}D",
|
||||
offs.shape.len()
|
||||
);
|
||||
|
||||
let s = input.shape.dims[0];
|
||||
let g = weight.shape.dims[0];
|
||||
let k = weight.shape.dims[1];
|
||||
let n = weight.shape.dims[2];
|
||||
|
||||
// expert_id[m] = number of g s.t. m >= offs[g]
|
||||
// = first g s.t. m < offs[g], i.e. the expert assigned to m.
|
||||
// Clamp to [0, G-1] before using as gather index. Matches HF MoE's
|
||||
// `expert_ids.clamp(0, num_experts-1)` for invalid IDs from EP, AND
|
||||
// protects search-time profiling: dummy-1 input bytes give offs=[1,…,1],
|
||||
// which makes `m >= offs[g]` true for m≥1 and pushes expert_id to G,
|
||||
// out of bounds for the weight gather. Clamping keeps the gather safe.
|
||||
let g_max_f = (g
|
||||
.to_usize()
|
||||
.context("_grouped_mm: G (num_experts) must be concrete")?
|
||||
as f32)
|
||||
- 1.0;
|
||||
let offs_f = offs.cast(DType::F32);
|
||||
let s_arange_f = self.graph.arange(s).cast(DType::F32);
|
||||
let ge_boundary = s_arange_f
|
||||
.expand_dim(0, g)
|
||||
.ge(offs_f.expand_dim(1, s))
|
||||
.cast(DType::F32);
|
||||
let expert_id = ge_boundary
|
||||
.sum(0)
|
||||
.minimum_f32(g_max_f)
|
||||
.cast(DType::Int); // [S] Int
|
||||
|
||||
// Flat gather index into weight (treated as a length-G*K*N 1D buffer):
|
||||
// flat[m, k_, n_] = expert_id[m] * (K*N) + k_ * N + n_
|
||||
// Encoded as `Mul(expert_id, Iota(io_const)) + Iota(MIter, K*N)` so the
|
||||
// resulting Gather matches the GLUMoE / gather-experts egglog patterns.
|
||||
let io = k * n;
|
||||
let base = expert_id * io;
|
||||
let within = self.graph.iota(Expression::from('z'), (k, n));
|
||||
let exp_base = base.expand_dim(1, k).expand_dim(2, n);
|
||||
let exp_within = within.expand_dim(0, s);
|
||||
let flat_idx = exp_base + exp_within;
|
||||
|
||||
// Gather → [S, K, N]. Preserves weight's native dtype (bf16 stays bf16).
|
||||
let weight_gathered = weight.gather(flat_idx);
|
||||
|
||||
// Cast for matmul — now on the small gathered slice, not the full weight.
|
||||
let input_f = input.cast(DType::F32);
|
||||
let weight_f = weight_gathered.cast(DType::F32);
|
||||
|
||||
// Per-token matmul: [S, 1, K] @ [S, K, N] → [S, 1, N] → [S, N].
|
||||
let result = input_f.unsqueeze(1).matmul(weight_f).squeeze(1);
|
||||
|
||||
Ok(result.cast(input.dtype))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_where(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
@@ -62,11 +282,64 @@ impl<'a> Translator<'a> {
|
||||
Ok(c * x_f + (one - c) * y_f)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_where_scalar_other(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let cond = self.get_input_tensor(node, WHERE_COND_ARG)?;
|
||||
let x = self.get_input_tensor(node, WHERE_X_ARG)?;
|
||||
let other_val = self.get_float_arg(node, WHERE_OTHER_ARG)? as f32;
|
||||
// Broadcast cond and x to a common shape
|
||||
let (cond_b, x_b) = broadcast_binary(cond, x);
|
||||
let c = cond_b.cast(DType::F32);
|
||||
let one = self.graph.constant_float(1.0).expand_rhs(c.shape);
|
||||
let other = self.graph.constant_float(other_val).expand_rhs(c.shape);
|
||||
Ok(c * x_b + (one - c) * other)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_tril(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
self.translate_triangular(node, false)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_triu(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
self.translate_triangular(node, true)
|
||||
}
|
||||
|
||||
fn translate_triangular(&mut self, node: &Node, upper: bool) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, TRIANGULAR_INPUT_ARG)?;
|
||||
let diagonal = if node.inputs.len() > TRIANGULAR_DIAGONAL_ARG {
|
||||
self.get_int_arg(node, TRIANGULAR_DIAGONAL_ARG).unwrap_or(0) as i32
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let dims = a.shape.dims;
|
||||
let rows = dims[dims.len() - 2];
|
||||
let cols = dims[dims.len() - 1];
|
||||
let (r_val, c_val) = match (rows.to_usize(), cols.to_usize()) {
|
||||
(Some(r), Some(c)) => (r, c),
|
||||
_ => anyhow::bail!("tril/triu requires concrete matrix dimensions"),
|
||||
};
|
||||
let size = r_val.max(c_val);
|
||||
let mask = if upper {
|
||||
self.graph.triu(size, diagonal)
|
||||
} else {
|
||||
self.graph.tril(size, diagonal)
|
||||
}
|
||||
.cast(DType::F32);
|
||||
let mask = if rows != cols {
|
||||
mask.slice_along(0..r_val, 0).slice_along(0..c_val, 1)
|
||||
} else {
|
||||
mask
|
||||
};
|
||||
let mut mask_expanded = mask;
|
||||
for i in (0..dims.len() - 2).rev() {
|
||||
mask_expanded = mask_expanded.expand_dim(0, dims[i]);
|
||||
}
|
||||
Ok(a * mask_expanded)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_topk(&mut self, node: &Node) -> Result<()> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let k = self.get_int_arg(node, 1)? as usize;
|
||||
let dim = if node.inputs.len() > 2 {
|
||||
self.get_int_arg(node, 2).unwrap_or(-1)
|
||||
let a = self.get_input_tensor(node, TOPK_INPUT_ARG)?;
|
||||
let k = self.get_int_arg(node, TOPK_K_ARG)? as usize;
|
||||
let dim = if node.inputs.len() > TOPK_DIM_ARG {
|
||||
self.get_int_arg(node, TOPK_DIM_ARG).unwrap_or(-1)
|
||||
} else {
|
||||
-1
|
||||
};
|
||||
@@ -86,13 +359,10 @@ impl<'a> Translator<'a> {
|
||||
None
|
||||
};
|
||||
|
||||
// Use full argsort then slice, rather than topk_indexes/topk_values directly.
|
||||
// This avoids a CUDA gather kernel bug when data and index shapes differ
|
||||
// along the gather axis (topk_indexes returns a sliced tensor).
|
||||
let full_argsort = a.argsort(dim, true);
|
||||
// Build top-k outputs from a full stable argsort, then slice to k.
|
||||
let full_argsort = a.stable_argsort(dim, true);
|
||||
|
||||
// Only build each branch when its output is consumed.
|
||||
// Dead nodes in the graph can confuse the CUDA optimizer.
|
||||
// Only build the outputs that are consumed.
|
||||
if let Some(val_name) = values_name
|
||||
&& !val_name.is_empty()
|
||||
{
|
||||
@@ -100,8 +370,7 @@ impl<'a> Translator<'a> {
|
||||
self.tensors.insert(val_name, values);
|
||||
}
|
||||
if let Some(idx_name) = indices_name {
|
||||
// Materialize Int indices as F32 with `* 1.0` to force a contiguous copy.
|
||||
// Without this, CUDA can't correctly read the sliced Int view.
|
||||
// Materialize the sliced indices through a copy before storing them.
|
||||
let indices = full_argsort.slice_along(..k, dim) * 1.0;
|
||||
self.tensors.insert(idx_name, indices);
|
||||
}
|
||||
@@ -109,6 +378,51 @@ impl<'a> Translator<'a> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn translate_sort(&mut self, node: &Node) -> Result<()> {
|
||||
let a = self.get_input_tensor(node, SORT_INPUT_ARG)?;
|
||||
let dim = if node.inputs.len() > SORT_DIM_ARG {
|
||||
self.get_int_arg(node, SORT_DIM_ARG).unwrap_or(-1)
|
||||
} else {
|
||||
-1
|
||||
};
|
||||
let descending = if node.inputs.len() > SORT_DESCENDING_ARG {
|
||||
self.get_bool_arg(node, SORT_DESCENDING_ARG)
|
||||
.unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
|
||||
// Determine output names (sort returns (values, indices))
|
||||
let values_name = node
|
||||
.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()));
|
||||
let indices_name =
|
||||
if let Some(ts) = node.outputs.first().and_then(|o| o.as_tensors.as_ref()) {
|
||||
ts.get(1).map(|t| t.name.clone())
|
||||
} else if node.outputs.len() > 1 {
|
||||
node.outputs[1].as_tensor.as_ref().map(|t| t.name.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let full_argsort = a.stable_argsort(dim, descending);
|
||||
|
||||
if let Some(val_name) = values_name
|
||||
&& !val_name.is_empty()
|
||||
{
|
||||
let values = a.gather_elements(full_argsort, dim);
|
||||
self.tensors.insert(val_name, values);
|
||||
}
|
||||
if let Some(idx_name) = indices_name {
|
||||
let indices = full_argsort * 1.0;
|
||||
self.tensors.insert(idx_name, indices);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn translate_wrap_set_grad(&mut self, node: &Node) -> Result<()> {
|
||||
let subgraph = node.inputs[1]
|
||||
.arg
|
||||
|
||||
@@ -6,7 +6,38 @@ use crate::pt2_util::{broadcast_binary, torch_dtype_int_to_luminal};
|
||||
|
||||
use super::Translator;
|
||||
|
||||
const ARGSORT_INPUT_ARG: usize = 0;
|
||||
const ARGSORT_DIM_ARG: usize = 1;
|
||||
const ARGSORT_DESCENDING_ARG: usize = 2;
|
||||
|
||||
const MASKED_FILL_INPUT_ARG: usize = 0;
|
||||
const MASKED_FILL_MASK_ARG: usize = 1;
|
||||
const MASKED_FILL_VALUE_ARG: usize = 2;
|
||||
|
||||
const FLOOR_DIVIDE_INPUT_ARG: usize = 0;
|
||||
const FLOOR_DIVIDE_OTHER_ARG: usize = 1;
|
||||
|
||||
const DIV_MODE_INPUT_ARG: usize = 0;
|
||||
const DIV_MODE_OTHER_ARG: usize = 1;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_argsort(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, ARGSORT_INPUT_ARG)?;
|
||||
let dim = if node.inputs.len() > ARGSORT_DIM_ARG {
|
||||
self.get_int_arg(node, ARGSORT_DIM_ARG).unwrap_or(-1)
|
||||
} else {
|
||||
-1
|
||||
};
|
||||
let descending = if node.inputs.len() > ARGSORT_DESCENDING_ARG {
|
||||
self.get_bool_arg(node, ARGSORT_DESCENDING_ARG)
|
||||
.unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let dim = crate::pt2_util::normalize_dim(dim, a.shape.len());
|
||||
Ok(a.stable_argsort(dim, descending))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_unary_op(
|
||||
&mut self,
|
||||
node: &Node,
|
||||
@@ -19,11 +50,15 @@ impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_to_copy(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
for input in &node.inputs {
|
||||
if input.name == "dtype"
|
||||
&& let Some(dtype_int) = input.arg.as_int()
|
||||
{
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
|
||||
return Ok(a.cast(dtype));
|
||||
if input.name == "dtype" {
|
||||
if let Some(dtype_int) = input.arg.as_int() {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
|
||||
return Ok(a.cast(dtype));
|
||||
}
|
||||
if let Some(dtype_int) = input.arg.as_scalar_type() {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int);
|
||||
return Ok(a.cast(dtype));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(a)
|
||||
@@ -60,6 +95,155 @@ impl<'a> Translator<'a> {
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_sign(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let zero = self
|
||||
.graph
|
||||
.constant_float(0.0)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
let pos = a.gt(zero).cast(DType::Int);
|
||||
let neg = a.lt(zero).cast(DType::Int);
|
||||
let signed = pos - neg;
|
||||
Ok(if a.dtype == DType::Int {
|
||||
signed
|
||||
} else {
|
||||
signed.cast(a.dtype)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_bitwise_not(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
Ok(match a.dtype {
|
||||
DType::Bool => {
|
||||
let one = self
|
||||
.graph
|
||||
.constant_float(1.0)
|
||||
.cast(DType::Int)
|
||||
.expand_rhs(a.shape);
|
||||
(one - a.cast(DType::Int)).cast(DType::Bool)
|
||||
}
|
||||
DType::Int => (a + 1) * -1.0,
|
||||
other => {
|
||||
anyhow::bail!("bitwise_not only supports Bool/Int routing tensors, got {other:?}")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_masked_fill_scalar(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, MASKED_FILL_INPUT_ARG)?;
|
||||
let mask = self.get_input_tensor(node, MASKED_FILL_MASK_ARG)?;
|
||||
let fill = self.get_float_arg(node, MASKED_FILL_VALUE_ARG)? as f32;
|
||||
let (input, mask) = broadcast_binary(input, mask);
|
||||
let work_dtype = if input.dtype == DType::Bool {
|
||||
DType::Int
|
||||
} else {
|
||||
input.dtype
|
||||
};
|
||||
let input_work = if input.dtype == DType::Bool {
|
||||
input.cast(DType::Int)
|
||||
} else {
|
||||
input
|
||||
};
|
||||
let mask_work = mask.cast(work_dtype);
|
||||
let fill_work = self
|
||||
.graph
|
||||
.constant_float(fill)
|
||||
.cast(work_dtype)
|
||||
.expand_rhs(input_work.shape);
|
||||
let one = self
|
||||
.graph
|
||||
.constant_float(1.0)
|
||||
.cast(work_dtype)
|
||||
.expand_rhs(input_work.shape);
|
||||
let result = mask_work * fill_work + (one - mask_work) * input_work;
|
||||
Ok(if input.dtype == DType::Bool {
|
||||
result.cast(DType::Bool)
|
||||
} else {
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_floor_divide(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, FLOOR_DIVIDE_INPUT_ARG)?;
|
||||
let b = if let Some(name) = node
|
||||
.inputs
|
||||
.get(FLOOR_DIVIDE_OTHER_ARG)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
{
|
||||
self.get_tensor(name)?
|
||||
} else {
|
||||
let scalar = self.get_float_arg(node, FLOOR_DIVIDE_OTHER_ARG)? as f32;
|
||||
self.graph
|
||||
.constant_float(scalar)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape)
|
||||
};
|
||||
let (a, b) = crate::pt2_util::ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
let quotient = a.cast(DType::F32) / b.cast(DType::F32);
|
||||
let trunc = quotient.cast(DType::Int).cast(DType::F32);
|
||||
let adjust = quotient.lt(trunc).cast(DType::F32);
|
||||
let floored = trunc - adjust;
|
||||
Ok(if a.dtype == DType::Int {
|
||||
floored.cast(DType::Int)
|
||||
} else {
|
||||
floored.cast(a.dtype)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_div_tensor_mode(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, DIV_MODE_INPUT_ARG)?;
|
||||
let b = if let Some(name) = node
|
||||
.inputs
|
||||
.get(DIV_MODE_OTHER_ARG)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
{
|
||||
self.get_tensor(name)?
|
||||
} else {
|
||||
let scalar = self.get_float_arg(node, DIV_MODE_OTHER_ARG)? as f32;
|
||||
self.graph
|
||||
.constant_float(scalar)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape)
|
||||
};
|
||||
let (a, b) = crate::pt2_util::ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
|
||||
// Check rounding_mode kwarg
|
||||
let rounding_mode = node.inputs.iter().find_map(|input| {
|
||||
if input.name == "rounding_mode"
|
||||
&& let Argument::Other(val) = &input.arg
|
||||
{
|
||||
return val.as_str().map(|s| s.to_string());
|
||||
}
|
||||
None
|
||||
});
|
||||
|
||||
let quotient = a.cast(DType::F32) / b.cast(DType::F32);
|
||||
match rounding_mode.as_deref() {
|
||||
Some("floor") => {
|
||||
let trunc = quotient.cast(DType::Int).cast(DType::F32);
|
||||
let adjust = quotient.lt(trunc).cast(DType::F32);
|
||||
let floored = trunc - adjust;
|
||||
Ok(if a.dtype == DType::Int {
|
||||
floored.cast(DType::Int)
|
||||
} else {
|
||||
floored.cast(a.dtype)
|
||||
})
|
||||
}
|
||||
Some("trunc") => Ok(if a.dtype == DType::Int {
|
||||
quotient.cast(DType::Int)
|
||||
} else {
|
||||
quotient.cast(DType::Int).cast(a.dtype)
|
||||
}),
|
||||
_ => {
|
||||
// No rounding mode — regular division
|
||||
Ok(quotient.cast(a.dtype))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn translate_clamp(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let min_val = if node.inputs.len() > 1 {
|
||||
@@ -73,13 +257,54 @@ impl<'a> Translator<'a> {
|
||||
None
|
||||
};
|
||||
|
||||
let mut result = a;
|
||||
// maximum_f32 / minimum_f32 internally use `.lt(F32 scalar)`, which
|
||||
// asserts matching tensor dtypes. Without this, clamp on an Int tensor
|
||||
// (e.g. Qwen3-MoE routes `cache_position.clamp(...)` through here)
|
||||
// panics inside luminal core. Promote to F32 around the bounds check
|
||||
// and cast back at the end.
|
||||
let original_dtype = a.dtype;
|
||||
let needs_promote = original_dtype != DType::F32;
|
||||
let mut result = if needs_promote { a.cast(DType::F32) } else { a };
|
||||
if let Some(min) = min_val {
|
||||
result = result.maximum_f32(min);
|
||||
}
|
||||
if let Some(max) = max_val {
|
||||
result = result.minimum_f32(max);
|
||||
}
|
||||
if needs_promote {
|
||||
result = result.cast(original_dtype);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Compute `erf(a)` via the Abramowitz & Stegun 7.1.28 approximation
|
||||
/// (max error ~1.5e-7). Shared by `aten.erf.default` and the exact
|
||||
/// `aten.gelu.default` (which is `0.5 * x * (1 + erf(x / sqrt(2)))`).
|
||||
///
|
||||
/// erf(x) = sign(x) * (1 - poly(t) * exp(-x^2))
|
||||
/// where t = 1/(1 + 0.3275911*|x|), poly is degree 5 in Horner form.
|
||||
///
|
||||
/// Promotes the input to F32 internally (the approximation constants are
|
||||
/// F32 anyway, and luminal's binary ops assert matching dtypes — running
|
||||
/// this on Bf16 input directly trips the assertion at `a.ge(zero)`).
|
||||
/// Restores the original dtype on return.
|
||||
pub(crate) fn erf_approx(&mut self, a: GraphTensor) -> GraphTensor {
|
||||
let orig = a.dtype;
|
||||
let a = if orig == DType::F32 { a } else { a.cast(DType::F32) };
|
||||
let ax = a.abs();
|
||||
let x2 = a * a;
|
||||
let t = (ax * 0.3275911_f32 + 1.0).reciprocal();
|
||||
let poly = t
|
||||
* (t * (t
|
||||
* (t * (t * 1.061_405_4_f32 + (-1.453_152_1_f32)) + 1.421_413_8_f32)
|
||||
+ (-0.284_496_72_f32))
|
||||
+ 0.254_829_6_f32);
|
||||
let result_abs =
|
||||
self.graph.constant_float(1.0).expand_rhs(a.shape) - poly * (x2 * (-1.0)).exp();
|
||||
// sign(x) = 2*(x >= 0) - 1
|
||||
let zero = self.graph.constant_float(0.0).expand_rhs(a.shape);
|
||||
let sign = a.ge(zero).cast(DType::F32) * 2.0 - 1.0;
|
||||
let result = result_abs * sign;
|
||||
if orig == DType::F32 { result } else { result.cast(orig) }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,20 +2,65 @@
|
||||
|
||||
# Import Python components
|
||||
# Register DynamicCache pytree serialization once at import time
|
||||
import torch.export._unlift as _torch_export_unlift
|
||||
|
||||
from .cache_utils import _register_cache_serialization
|
||||
from .compiled_model import CompiledModel
|
||||
|
||||
# Import Rust extension components (built by maturin)
|
||||
# These are available directly in the package namespace
|
||||
from .luminal import CompiledGraph, process_pt2
|
||||
from .main import luminal_backend
|
||||
from .main import luminal_backend, register_backend
|
||||
|
||||
_register_cache_serialization()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Suppress torch.export's `_guards_fn` insertion when luminal is on the stack.
|
||||
#
|
||||
# When `torch._dynamo.config.automatic_dynamic_shapes=True` (the default) and
|
||||
# a model is called with shapes that vary across calls, dynamo promotes the
|
||||
# changing dim to a SymInt and re-traces. During the re-trace, torch.export's
|
||||
# `_unlift_exported_program_lifted_states` (in `torch/export/_unlift.py`)
|
||||
# generates a `_guards_fn` submodule whose body closes over `L` — dynamo's
|
||||
# locals namespace. When aot_autograd later evaluates the resulting
|
||||
# GraphModule via fx.Interpreter, the closure's free `L` reference doesn't
|
||||
# resolve and we get
|
||||
# NameError: name 'L' is not defined
|
||||
# (gemma3 + StaticCache reproduces this deterministically).
|
||||
#
|
||||
# torch.export's own opt-out — `_ok_to_generate_guards_fn` — already walks
|
||||
# the call stack for filename patterns to suppress guard generation for
|
||||
# specific embedders (executorch, modai, on_device_ai, torchao). Add
|
||||
# "luminal" to the same suppression set by monkey-patching the function.
|
||||
# Net effect: torch.export never inserts `_guards_fn`, so re-tracing
|
||||
# succeeds, dynamic-shape compile-once-run-many works, and StaticCache
|
||||
# decode loops compile in ~one shot instead of per-token.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_orig_ok_to_generate_guards_fn = _torch_export_unlift._ok_to_generate_guards_fn
|
||||
|
||||
|
||||
def _luminal_aware_ok_to_generate_guards_fn() -> bool:
|
||||
"""Return False whenever luminal is anywhere in the call stack."""
|
||||
import inspect
|
||||
|
||||
frame = inspect.currentframe()
|
||||
try:
|
||||
while frame is not None:
|
||||
if "luminal" in frame.f_code.co_filename:
|
||||
return False
|
||||
frame = frame.f_back
|
||||
finally:
|
||||
del frame # avoid reference cycle
|
||||
return _orig_ok_to_generate_guards_fn()
|
||||
|
||||
|
||||
_torch_export_unlift._ok_to_generate_guards_fn = _luminal_aware_ok_to_generate_guards_fn
|
||||
|
||||
# Re-export everything for clean package interface
|
||||
__all__ = [
|
||||
"CompiledModel",
|
||||
"luminal_backend",
|
||||
"register_backend",
|
||||
"CompiledGraph",
|
||||
"process_pt2",
|
||||
]
|
||||
|
||||
@@ -31,7 +31,10 @@ class CompiledModel:
|
||||
self._has_dynamic_dims = getattr(graph_result, "has_dynamic_dims", False)
|
||||
self._weight_refs = weight_refs or []
|
||||
self._user_indices = user_indices
|
||||
self._is_cuda = graph_result.backend == "cuda"
|
||||
self._is_gpu = getattr(graph_result, "device_type", "cpu") != "cpu"
|
||||
self._supports_device_ptrs = getattr(
|
||||
graph_result, "supports_device_ptrs", False
|
||||
)
|
||||
# Expected input dtypes from graph (used to convert user inputs)
|
||||
input_dtype_codes = graph_result.input_dtypes
|
||||
self._input_dtypes = [
|
||||
@@ -89,13 +92,13 @@ class CompiledModel:
|
||||
for name, tensor, expected_dtype in zip(
|
||||
self._input_names, user_inputs, self._input_dtypes
|
||||
):
|
||||
if self._is_cuda and tensor.is_cuda:
|
||||
if self._supports_device_ptrs and tensor.is_cuda:
|
||||
t = tensor.detach().contiguous().to(expected_dtype)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
self._graph.set_input_device_ptr(name, t.data_ptr(), n_bytes)
|
||||
_input_refs.append(t)
|
||||
else:
|
||||
t = tensor.detach().cpu().contiguous()
|
||||
t = tensor.detach().cpu().contiguous().to(expected_dtype)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
dtype_code = _torch_dtype_code(t.dtype)
|
||||
self._graph.set_input_from_ptr(name, t.data_ptr(), n_bytes, dtype_code)
|
||||
@@ -110,7 +113,7 @@ class CompiledModel:
|
||||
|
||||
# CUDA zero-copy path: pre-allocate output tensors and register their device
|
||||
# pointers so the final kernel writes directly into PyTorch's buffer.
|
||||
_use_zero_copy = self._is_cuda and hasattr(self._graph, "set_output_device_ptr")
|
||||
_use_zero_copy = self._supports_device_ptrs
|
||||
output_tensors = []
|
||||
if _use_zero_copy:
|
||||
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
|
||||
@@ -120,9 +123,10 @@ class CompiledModel:
|
||||
else torch.float32
|
||||
)
|
||||
out = torch.empty(shape, dtype=out_dtype, device=input_device)
|
||||
self._graph.set_output_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
if out_dtype.is_floating_point:
|
||||
self._graph.set_output_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
output_tensors.append(out)
|
||||
|
||||
# Run the graph
|
||||
@@ -130,13 +134,42 @@ class CompiledModel:
|
||||
|
||||
# Collect outputs
|
||||
if _use_zero_copy:
|
||||
# For aliased outputs that couldn't be zero-copied, fall back to DtoD copy.
|
||||
for name, out in zip(self._output_names, output_tensors):
|
||||
if not self._graph.output_is_zero_copy(name):
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
outputs = []
|
||||
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
|
||||
out_dtype = (
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
out = output_tensors[i]
|
||||
if out_dtype.is_floating_point:
|
||||
if not self._graph.output_is_zero_copy(name):
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
elif out_dtype == torch.int32:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.int32)
|
||||
.reshape(tuple(shape))
|
||||
.to(input_device)
|
||||
)
|
||||
outputs = output_tensors
|
||||
elif out_dtype == torch.bool:
|
||||
data = self._graph.get_output_bool(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.bool)
|
||||
.reshape(tuple(shape))
|
||||
.to(input_device)
|
||||
)
|
||||
else:
|
||||
data = self._graph.get_output(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.float32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
.to(input_device)
|
||||
)
|
||||
outputs.append(out)
|
||||
else:
|
||||
# Native path: retrieve as f32, then convert to target dtype if needed.
|
||||
outputs = []
|
||||
@@ -146,13 +179,20 @@ class CompiledModel:
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
data = self._graph.get_output(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.float32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
.to(input_device)
|
||||
)
|
||||
if out_dtype == torch.int32:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = torch.tensor(data, dtype=torch.int32).reshape(tuple(shape))
|
||||
elif out_dtype == torch.bool:
|
||||
data = self._graph.get_output_bool(name)
|
||||
out = torch.tensor(data, dtype=torch.bool).reshape(tuple(shape))
|
||||
else:
|
||||
data = self._graph.get_output(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.float32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
)
|
||||
out = out.to(input_device)
|
||||
outputs.append(out)
|
||||
|
||||
return tuple(outputs)
|
||||
|
||||
@@ -9,20 +9,37 @@ from .dtype_util import torch_dtype_code as _torch_dtype_code
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _detect_backend(example_inputs):
|
||||
"""Detect backend from input device. Returns 'cuda' or 'native'."""
|
||||
device = example_inputs[0].device if example_inputs else torch.device("cpu")
|
||||
return "cuda" if device.type == "cuda" else "native"
|
||||
def _detect_factory_capsule(example_inputs):
|
||||
"""Pick the best built-in factory capsule based on input device.
|
||||
|
||||
Walks example_inputs for the first Tensor to read .device from. With
|
||||
dynamic=True, dynamo may pass SymInt/SymFloat alongside Tensors and those
|
||||
don't have a .device attribute — falling back to CPU on a SymInt-only call
|
||||
would silently route to the wrong backend, so prefer the first Tensor."""
|
||||
device = torch.device("cpu")
|
||||
for v in example_inputs or ():
|
||||
if isinstance(v, torch.Tensor):
|
||||
device = v.device
|
||||
break
|
||||
if device.type == "cuda":
|
||||
try:
|
||||
from .luminal import _cuda_lite_factory_capsule
|
||||
|
||||
return _cuda_lite_factory_capsule()
|
||||
except ImportError:
|
||||
pass
|
||||
from .luminal import _native_factory_capsule
|
||||
|
||||
return _native_factory_capsule()
|
||||
|
||||
|
||||
def _collect_weight_pointers(weights, backend):
|
||||
def _collect_weight_pointers(weights):
|
||||
"""Partition weight tensors into CUDA device pointers and CPU host pointers.
|
||||
|
||||
Preserves native dtype — no forced conversion to float32.
|
||||
|
||||
Args:
|
||||
weights: dict of name -> torch.Tensor
|
||||
backend: "cuda", "gpu", "cpu", or "native"
|
||||
|
||||
Returns:
|
||||
(keep_alive, device_ptrs, cpu_ptrs) where:
|
||||
@@ -36,7 +53,7 @@ def _collect_weight_pointers(weights, backend):
|
||||
for name, tensor in weights.items():
|
||||
t = tensor.detach().contiguous()
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
if backend in ("cuda", "gpu") and t.is_cuda:
|
||||
if t.is_cuda:
|
||||
keep_alive.append(t)
|
||||
device_ptrs[name] = (t.data_ptr(), n_bytes)
|
||||
else:
|
||||
@@ -53,18 +70,41 @@ def _load_cpu_weights(compiled_graph, cpu_weights):
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# torch.compile backend entry point
|
||||
# Backend registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def register_backend(factory_capsule):
|
||||
"""Wrap a backend factory PyCapsule into a torch.compile-compatible callable.
|
||||
|
||||
Args:
|
||||
factory_capsule: PyCapsule wrapping a BackendFactory fn pointer.
|
||||
|
||||
Returns:
|
||||
A callable(gm, example_inputs, options=None) suitable for torch.compile.
|
||||
"""
|
||||
|
||||
def backend(gm, example_inputs, options=None):
|
||||
return _compile_pt2(gm, example_inputs, factory_capsule, options=options)
|
||||
|
||||
return backend
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# torch.compile backend entry point (auto-detecting)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def luminal_backend(gm, example_inputs, options=None):
|
||||
"""Luminal torch.compile backend.
|
||||
"""Auto-detecting torch.compile backend.
|
||||
|
||||
Usage:
|
||||
torch.compile(model, backend=luminal_backend)
|
||||
Picks cuda_lite if inputs are on CUDA (and cuda feature is compiled in),
|
||||
native otherwise.
|
||||
|
||||
For external backends, use register_backend with the backend's factory capsule.
|
||||
"""
|
||||
backend = _detect_backend(example_inputs)
|
||||
return _compile_pt2(gm, example_inputs, backend)
|
||||
capsule = _detect_factory_capsule(example_inputs)
|
||||
return _compile_pt2(gm, example_inputs, capsule, options=options)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -72,8 +112,16 @@ def luminal_backend(gm, example_inputs, options=None):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compile_pt2(gm, example_inputs, backend):
|
||||
def _compile_pt2(gm, example_inputs, factory_capsule, options=None):
|
||||
"""PT2/torch.export path — delegates to pt2.pt2_backend."""
|
||||
from .pt2 import pt2_backend
|
||||
|
||||
return pt2_backend(gm, example_inputs, backend=backend)
|
||||
search_iterations = None
|
||||
if options is not None:
|
||||
search_iterations = options.get("search_iterations")
|
||||
return pt2_backend(
|
||||
gm,
|
||||
example_inputs,
|
||||
factory=factory_capsule,
|
||||
search_iterations=search_iterations,
|
||||
)
|
||||
|
||||
@@ -14,7 +14,85 @@ import torch
|
||||
|
||||
from .compiled_model import CompiledModel
|
||||
from .luminal import process_pt2
|
||||
from .main import _collect_weight_pointers, _detect_backend, _load_cpu_weights
|
||||
from .main import _collect_weight_pointers, _detect_factory_capsule, _load_cpu_weights
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DynamicCache <> pytree registration
|
||||
#
|
||||
# Without this, torch.export.export raises when handed an HF model that
|
||||
# returns CausalLMOutputWithPast(past_key_values=DynamicCache(...)), which
|
||||
# is every model with use_cache=True. The registration mirrors the one in
|
||||
# transformers.integrations.executorch.register_dynamic_cache_export_support
|
||||
# — same dict-based flatten (key_cache / value_cache lists), same replay via
|
||||
# cache.update(k, v, idx), and the matching torch.fx._pytree spec for FX
|
||||
# graphs. Done at module import so both entry points (pt2_backend via
|
||||
# torch.compile and the direct compile() call) get it for free.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_cache_dict(cache):
|
||||
"""Flatten a DynamicCache to a dict of parallel key/value lists."""
|
||||
return {
|
||||
"key_cache": [layer.keys for layer in cache.layers if layer.keys is not None],
|
||||
"value_cache": [
|
||||
layer.values for layer in cache.layers if layer.values is not None
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _flatten_dynamic_cache(cache):
|
||||
return torch.utils._pytree._dict_flatten(_get_cache_dict(cache))
|
||||
|
||||
|
||||
def _flatten_with_keys_dynamic_cache(cache):
|
||||
return torch.utils._pytree._dict_flatten_with_keys(_get_cache_dict(cache))
|
||||
|
||||
|
||||
def _unflatten_dynamic_cache(values, context):
|
||||
from transformers.cache_utils import DynamicCache
|
||||
|
||||
dictionary = torch.utils._pytree._dict_unflatten(values, context)
|
||||
cache = DynamicCache()
|
||||
key_list = dictionary.get("key_cache", [])
|
||||
value_list = dictionary.get("value_cache", [])
|
||||
for idx in range(max(len(key_list), len(value_list))):
|
||||
k = key_list[idx] if idx < len(key_list) else None
|
||||
v = value_list[idx] if idx < len(value_list) else None
|
||||
cache.update(k, v, idx)
|
||||
return cache
|
||||
|
||||
|
||||
def _register_cache_serialization():
|
||||
"""Register DynamicCache with both torch.utils._pytree and torch.fx._pytree.
|
||||
|
||||
Idempotent: a second call is a no-op. Silently skipped if transformers is
|
||||
not installed.
|
||||
"""
|
||||
try:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
if DynamicCache in torch.utils._pytree.SUPPORTED_NODES:
|
||||
return
|
||||
|
||||
torch.utils._pytree.register_pytree_node(
|
||||
DynamicCache,
|
||||
_flatten_dynamic_cache,
|
||||
_unflatten_dynamic_cache,
|
||||
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
|
||||
flatten_with_keys_fn=_flatten_with_keys_dynamic_cache,
|
||||
)
|
||||
torch.fx._pytree.register_pytree_flatten_spec(
|
||||
DynamicCache,
|
||||
lambda cache, spec: torch.fx._pytree._dict_flatten_spec(
|
||||
_get_cache_dict(cache), spec
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
_register_cache_serialization()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -32,12 +110,51 @@ def _export_kwargs():
|
||||
return kwargs
|
||||
|
||||
|
||||
def _save_and_compile(ep_or_path, backend, search_iterations, original_weights=None):
|
||||
def _extract_pt2_constants(pt2_path):
|
||||
"""Extract tensor constants from the new flat PT2 format (torch >= 2.6).
|
||||
|
||||
In the new format, inline constants (e.g. ``torch.tensor([1., 2.])``) are
|
||||
stored in ``serialized_constants.pt`` rather than individual ZIP entries.
|
||||
The Rust parser skips them (returns constants_config=None); this function
|
||||
reads them back and returns a cpu_ptrs dict ready for _load_cpu_weights.
|
||||
|
||||
Returns (keep_alive, cpu_ptrs) — keep_alive must stay alive until after
|
||||
_load_cpu_weights returns (set_weight_from_ptr copies the bytes).
|
||||
"""
|
||||
import io
|
||||
import zipfile
|
||||
|
||||
from .dtype_util import torch_dtype_code as _torch_dtype_code
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(pt2_path) as z:
|
||||
if "serialized_constants.pt" not in z.namelist():
|
||||
return [], {}
|
||||
data = z.read("serialized_constants.pt")
|
||||
except Exception:
|
||||
return [], {}
|
||||
|
||||
constants = torch.load(io.BytesIO(data), weights_only=False)
|
||||
if not constants:
|
||||
return [], {}
|
||||
|
||||
keep_alive = []
|
||||
cpu_ptrs = {}
|
||||
for name, tensor in constants.items():
|
||||
t = tensor.detach().cpu().contiguous()
|
||||
keep_alive.append(t)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
cpu_ptrs[name] = (t.data_ptr(), n_bytes, _torch_dtype_code(t.dtype))
|
||||
return keep_alive, cpu_ptrs
|
||||
|
||||
|
||||
def _save_and_compile(ep_or_path, factory, search_iterations, original_weights=None):
|
||||
"""Compile a PT2 model via Rust, return CompiledModel.
|
||||
|
||||
Args:
|
||||
ep_or_path: Either an ExportedProgram (will be saved to a temp file) or
|
||||
a path to an already-saved .pt2 file.
|
||||
factory: PyCapsule wrapping the BackendFactory to use.
|
||||
original_weights: Optional dict mapping state_dict key -> original PyTorch tensor.
|
||||
When provided, device pointers are taken from these tensors instead of
|
||||
ep.state_dict (which torch.export may have cloned), enabling true zero-copy
|
||||
@@ -58,16 +175,20 @@ def _save_and_compile(ep_or_path, backend, search_iterations, original_weights=N
|
||||
|
||||
# Collect weight pointers for Rust (avoids duplicate GPU buffer allocation)
|
||||
keep_alive, weight_device_ptrs, cpu_weights = _collect_weight_pointers(
|
||||
weight_source, backend
|
||||
weight_source
|
||||
)
|
||||
|
||||
# Compile with device pointers — search uses actual weight memory (zero-copy)
|
||||
compiled = process_pt2(
|
||||
pt2_path, "", backend, search_iterations, weight_device_ptrs
|
||||
pt2_path, "", search_iterations, factory, weight_device_ptrs
|
||||
)
|
||||
|
||||
# Load CPU weights after compilation
|
||||
# Load CPU weights; also load inline tensor constants from the new flat
|
||||
# PT2 format (torch >= 2.6 stores them in serialized_constants.pt).
|
||||
const_keep_alive, const_cpu_weights = _extract_pt2_constants(pt2_path)
|
||||
cpu_weights.update(const_cpu_weights)
|
||||
_load_cpu_weights(compiled, cpu_weights)
|
||||
del const_keep_alive # bytes were copied by set_weight_from_ptr
|
||||
|
||||
return CompiledModel(compiled, weight_refs=keep_alive)
|
||||
finally:
|
||||
@@ -79,13 +200,21 @@ def _reinternalize_lifted_params(gm, example_inputs):
|
||||
"""Re-internalize lifted params as buffers so torch.export sees them as model state.
|
||||
|
||||
torch.compile lifts model parameters out of the module and passes them as
|
||||
extra elements in example_inputs. The Rust PT2 compiler may expect weights in
|
||||
the .pt2 state dict, not as runtime inputs. This function reverses the
|
||||
extra elements in example_inputs. The Rust PT2 compiler may expect weights in
|
||||
the .pt2 state dict, not as runtime inputs. This function reverses the
|
||||
lifting by registering them as buffers and replacing the placeholder nodes
|
||||
with get_attr nodes.
|
||||
|
||||
SymInt/SymFloat/SymBool values in example_inputs are rejected by
|
||||
torch.export.export as user inputs ("Unsupported input type
|
||||
<class 'torch.SymInt'>"). We don't restructure the graph for this — we
|
||||
specialize the *value* to its concrete hint (a plain int/float/bool), which
|
||||
torch.export accepts. The placeholder stays in place; the traced graph
|
||||
proceeds as if dynamo had specialized this dim. Invisible to callers of
|
||||
`torch.compile(..., backend=luminal_backend)`.
|
||||
|
||||
Returns (gm, user_inputs, original_weights) where:
|
||||
- user_inputs contains only the real inputs
|
||||
- user_inputs contains only real inputs (Tensors and concrete scalars)
|
||||
- original_weights maps buffer name -> original tensor (for zero-copy device pointers)
|
||||
"""
|
||||
buffer_indices = []
|
||||
@@ -119,14 +248,49 @@ def _reinternalize_lifted_params(gm, example_inputs):
|
||||
gm.graph.lint()
|
||||
gm.recompile()
|
||||
|
||||
user_inputs = (
|
||||
raw_user_inputs = (
|
||||
[example_inputs[i] for i in user_indices]
|
||||
if user_indices
|
||||
else list(example_inputs)
|
||||
)
|
||||
user_inputs = [
|
||||
_specialize_sym_scalar(v) if _is_sym_scalar(v) else v
|
||||
for v in raw_user_inputs
|
||||
]
|
||||
return gm, user_inputs, original_weights
|
||||
|
||||
|
||||
def _is_sym_scalar(val) -> bool:
|
||||
"""True for torch SymInt/SymFloat/SymBool — anything torch.export's fakify
|
||||
rejects as a user input. Plain int/float/bool are fine; only the symbolic
|
||||
wrappers need specialization."""
|
||||
if val is None:
|
||||
return False
|
||||
if isinstance(val, torch.Tensor):
|
||||
return False
|
||||
return type(val).__name__ in ("SymInt", "SymFloat", "SymBool") or isinstance(
|
||||
val, (torch.SymInt, torch.SymFloat, torch.SymBool)
|
||||
)
|
||||
|
||||
|
||||
def _specialize_sym_scalar(val):
|
||||
"""Resolve a SymInt/SymFloat/SymBool to its concrete hint. Falls back to
|
||||
str(val) -> primitive parse if the SymNode hint is missing."""
|
||||
try:
|
||||
if isinstance(val, torch.SymBool):
|
||||
return bool(val)
|
||||
if isinstance(val, torch.SymFloat):
|
||||
return float(val)
|
||||
return int(val)
|
||||
except Exception:
|
||||
# SymNodes without a hint — try parsing the str form as a last resort.
|
||||
s = str(val)
|
||||
try:
|
||||
return int(s)
|
||||
except ValueError:
|
||||
return float(s)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -136,7 +300,7 @@ def compile(
|
||||
model,
|
||||
example_input,
|
||||
search_iterations=25,
|
||||
backend=None,
|
||||
factory=None,
|
||||
export_kwargs=None,
|
||||
dynamic_dim=None,
|
||||
):
|
||||
@@ -146,7 +310,7 @@ def compile(
|
||||
model: A PyTorch nn.Module.
|
||||
example_input: Example input tensor(s) for tracing.
|
||||
search_iterations: Number of optimization search iterations.
|
||||
backend: "native" or "cuda". Auto-detected if None.
|
||||
factory: PyCapsule wrapping a BackendFactory. Auto-detected if None.
|
||||
export_kwargs: Extra kwargs passed to torch.export.export.
|
||||
dynamic_dim: Which input dimension to make dynamic.
|
||||
|
||||
@@ -156,10 +320,8 @@ def compile(
|
||||
if dynamic_dim is None:
|
||||
dynamic_dim = "auto"
|
||||
|
||||
if backend is None:
|
||||
backend = os.environ.get("LUMINAL_BACKEND", None)
|
||||
if backend is None:
|
||||
backend = "cuda" if torch.cuda.is_available() else "native"
|
||||
if factory is None:
|
||||
factory = _detect_factory_capsule([example_input])
|
||||
|
||||
kwargs = export_kwargs or {}
|
||||
extra = _export_kwargs()
|
||||
@@ -208,18 +370,20 @@ def compile(
|
||||
)
|
||||
ep = ep.run_decompositions()
|
||||
|
||||
return _save_and_compile(ep, backend, search_iterations)
|
||||
return _save_and_compile(ep, factory, search_iterations)
|
||||
|
||||
|
||||
def pt2_backend(gm, example_inputs, backend=None):
|
||||
def pt2_backend(gm, example_inputs, factory=None, search_iterations=None):
|
||||
"""torch.compile backend using PT2 pipeline.
|
||||
|
||||
Usage: torch.compile(model, backend=luminal.pt2.pt2_backend)
|
||||
Usage: torch.compile(model, backend=luminal.register_backend(capsule))
|
||||
"""
|
||||
import gc
|
||||
|
||||
if backend is None:
|
||||
backend = _detect_backend(example_inputs)
|
||||
if factory is None:
|
||||
factory = _detect_factory_capsule(example_inputs)
|
||||
if search_iterations is None:
|
||||
search_iterations = 10
|
||||
|
||||
gm = gm.eval()
|
||||
gm, user_inputs, original_weights = _reinternalize_lifted_params(gm, example_inputs)
|
||||
@@ -227,6 +391,28 @@ def pt2_backend(gm, example_inputs, backend=None):
|
||||
ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs())
|
||||
ep = ep.run_decompositions()
|
||||
|
||||
# Detect USER_INPUT_MUTATION outputs (e.g., in-place KV cache updates).
|
||||
# These must be written back to the original input tensors after each call.
|
||||
# Only USER_OUTPUT results are returned to the torch.compile caller.
|
||||
try:
|
||||
from torch.export.graph_signature import OutputKind
|
||||
|
||||
mutation_mappings = [] # list of (compiled_output_idx, user_input_idx)
|
||||
user_output_indices = []
|
||||
for i, spec in enumerate(ep.graph_signature.output_specs):
|
||||
if spec.kind == OutputKind.USER_INPUT_MUTATION:
|
||||
# target is 'args_N' — index into user_inputs
|
||||
try:
|
||||
arg_idx = int(spec.target.split("_")[1])
|
||||
mutation_mappings.append((i, arg_idx))
|
||||
except (ValueError, IndexError):
|
||||
user_output_indices.append(i)
|
||||
else:
|
||||
user_output_indices.append(i)
|
||||
except ImportError:
|
||||
mutation_mappings = []
|
||||
user_output_indices = None # unknown; return all outputs
|
||||
|
||||
# When using shared memory (original_weights), strip large weight buffers from
|
||||
# the EP before saving. The Rust side uses device pointers for these weights,
|
||||
# not the .pt2 file data, so serializing them is pure IO waste (~32 GB for 8B
|
||||
@@ -252,8 +438,27 @@ def pt2_backend(gm, example_inputs, backend=None):
|
||||
|
||||
try:
|
||||
result = _save_and_compile(
|
||||
pt2_path, backend, 10, original_weights=original_weights
|
||||
pt2_path, factory, search_iterations, original_weights=original_weights
|
||||
)
|
||||
return result
|
||||
finally:
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
|
||||
# Wrap the compiled model to handle USER_INPUT_MUTATION: write updated tensors
|
||||
# back into the original input buffers and return only USER_OUTPUT tensors.
|
||||
if mutation_mappings:
|
||||
_compiled = result
|
||||
_mut = mutation_mappings
|
||||
_usr = user_output_indices
|
||||
|
||||
def _mutation_wrapper(*inputs):
|
||||
outputs = _compiled(*inputs)
|
||||
for out_idx, inp_idx in _mut:
|
||||
if inp_idx < len(inputs) and out_idx < len(outputs):
|
||||
inputs[inp_idx].copy_(outputs[out_idx])
|
||||
if _usr is not None:
|
||||
return tuple(outputs[i] for i in _usr if i < len(outputs))
|
||||
return outputs
|
||||
|
||||
return _mutation_wrapper
|
||||
|
||||
return result
|
||||
|
||||
@@ -7,8 +7,8 @@ try:
|
||||
import maturin_import_hook
|
||||
from maturin_import_hook.settings import MaturinSettings
|
||||
|
||||
backend = os.getenv("LUMINAL_BACKEND", "native").lower()
|
||||
settings = MaturinSettings(features=["cuda"]) if backend == "cuda" else None
|
||||
use_cuda = os.getenv("LUMINAL_TEST_DEVICE", "cpu").lower() == "cuda"
|
||||
settings = MaturinSettings(features=["cuda"]) if use_cuda else None
|
||||
maturin_import_hook.install(settings=settings)
|
||||
except ImportError:
|
||||
pass # Hook not available, rebuilds will be manual
|
||||
@@ -22,23 +22,17 @@ torch.set_float32_matmul_precision("highest")
|
||||
|
||||
@pytest.fixture
|
||||
def device() -> torch.device:
|
||||
backend = os.getenv("LUMINAL_BACKEND", "native").lower()
|
||||
return torch.device("cuda") if backend == "cuda" else torch.device("cpu")
|
||||
if (
|
||||
os.getenv("LUMINAL_TEST_DEVICE", "cpu").lower() == "cuda"
|
||||
and torch.cuda.is_available()
|
||||
):
|
||||
return torch.device("cuda")
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="function")
|
||||
def reset_torch_dynamo():
|
||||
# We need this for two reasons
|
||||
# 1. Some of our casts tests use the same model, but those graph have some state to them
|
||||
# and the cache will return old models
|
||||
# 2. The cache adds a large preformace hit to the test suite
|
||||
torch._dynamo.config.cache_size_limit = 1
|
||||
# Disable silent fallback to eager mode so backend errors surface as test failures
|
||||
torch._dynamo.config.suppress_errors = False
|
||||
"""Reset PyTorch Dynamo state after each test to prevent state leakage.
|
||||
|
||||
This fixture automatically runs after every test function to clear
|
||||
torch._dynamo's compilation cache, ensuring test isolation.
|
||||
"""
|
||||
yield # Test runs here
|
||||
yield
|
||||
torch._dynamo.reset()
|
||||
|
||||
34
crates/luminal_python/tests/test_capsule_validation.py
Normal file
34
crates/luminal_python/tests/test_capsule_validation.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""FFI-boundary tests for process_pt2's capsule validation.
|
||||
|
||||
Deviates from the standard `torch.compile(..., backend=luminal_backend)`
|
||||
pattern in CLAUDE.md because the thing under test is the capsule-name
|
||||
check itself, not a feature behavior. Exercising it through torch.compile
|
||||
would only cover the happy path (`_native_factory_capsule` produces a
|
||||
correctly-named capsule, so validation passes trivially).
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
|
||||
import pytest
|
||||
|
||||
from luminal import process_pt2
|
||||
|
||||
|
||||
def _new_capsule(name: bytes):
|
||||
PyCapsule_New = ctypes.pythonapi.PyCapsule_New
|
||||
PyCapsule_New.restype = ctypes.py_object
|
||||
PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
|
||||
dummy = ctypes.c_void_p(0xDEADBEEF)
|
||||
return PyCapsule_New(ctypes.byref(dummy), name, None)
|
||||
|
||||
|
||||
def test_process_pt2_rejects_capsule_with_wrong_name():
|
||||
bogus = _new_capsule(b"not.luminal.backend_factory")
|
||||
with pytest.raises(ValueError, match="luminal.backend_factory"):
|
||||
process_pt2("/dev/null", "/dev/null", 0, bogus, None)
|
||||
|
||||
|
||||
def test_process_pt2_rejects_capsule_with_no_name():
|
||||
unnamed = _new_capsule(None)
|
||||
with pytest.raises(ValueError, match="luminal.backend_factory"):
|
||||
process_pt2("/dev/null", "/dev/null", 0, unnamed, None)
|
||||
@@ -170,7 +170,8 @@ from test_models import (
|
||||
ScatterElementsAxis0TestModel,
|
||||
# ScatterElements models
|
||||
ScatterElementsTestModel,
|
||||
# ScatterND model
|
||||
# ScatterND / IndexPut models
|
||||
IndexPutOptionalModel,
|
||||
ScatterNDTestModel,
|
||||
ShapeReshapeBatchFlattenModel,
|
||||
ShapeReshapeKeepBatchModel,
|
||||
@@ -215,6 +216,7 @@ from test_models import (
|
||||
WhereWithConstantModel,
|
||||
# Xor model
|
||||
XorTestModel,
|
||||
ArgsortStableDuplicatesModel,
|
||||
# Conv models
|
||||
Conv1dNoPadModel,
|
||||
Conv1dSamePadModel,
|
||||
@@ -231,6 +233,7 @@ from test_models import (
|
||||
GroupedConv2dModel,
|
||||
GroupedConv2dGroups3Model,
|
||||
MambaConvBlockModel,
|
||||
TinyMoERoutingModel,
|
||||
)
|
||||
|
||||
from luminal import luminal_backend
|
||||
@@ -1634,6 +1637,21 @@ def test_or(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
def test_bitwise_or(device: torch.device):
|
||||
"""Test bitwise_or on boolean tensors. PyTorch's `a | b` on Bool tensors
|
||||
emits `aten.bitwise_or.Tensor`, NOT `aten.logical_or.default` — Gemma-style
|
||||
sliding+full attention mask fusion takes this path."""
|
||||
from test_models import BitwiseOrTestModel
|
||||
|
||||
model: torch.nn.Module = BitwiseOrTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
a = torch.tensor([True, False, True, False, True, True], device=device)
|
||||
b = torch.tensor([False, True, True, False, False, True], device=device)
|
||||
original = model(a, b)
|
||||
output = model_compiled(a, b)
|
||||
assert torch.equal(output, original)
|
||||
|
||||
|
||||
# ========== PT2 Xor Node Tests ==========
|
||||
|
||||
|
||||
@@ -1948,6 +1966,54 @@ def test_split(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x))
|
||||
|
||||
|
||||
# ========== Argsort / MoE Routing Tests ==========
|
||||
|
||||
|
||||
def test_argsort_stable_duplicates(device: torch.device):
|
||||
"""Duplicate values should follow stable lower-index-first tie-breaking."""
|
||||
model: torch.nn.Module = ArgsortStableDuplicatesModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.tensor(
|
||||
[[2.0, 1.0, 1.0, 3.0]],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert output.dtype == torch.int32
|
||||
assert torch.equal(output, original.to(torch.int32))
|
||||
|
||||
|
||||
def test_tiny_moe_routing(device: torch.device):
|
||||
"""Focused proof for build MoE routing support."""
|
||||
model: torch.nn.Module = TinyMoERoutingModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
scores = torch.tensor(
|
||||
[[0.1, 0.9, 0.4, 0.7], [0.6, -0.8, 0.95, 0.2]],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
expected = model(scores)
|
||||
output = model_compiled(scores)
|
||||
|
||||
expected_dtypes = (
|
||||
torch.int32,
|
||||
torch.float32,
|
||||
torch.int32,
|
||||
torch.bool,
|
||||
torch.int32,
|
||||
torch.float32,
|
||||
)
|
||||
for actual, eager, expected_dtype in zip(output, expected, expected_dtypes):
|
||||
assert actual.dtype == expected_dtype
|
||||
eager = eager.to(actual.dtype)
|
||||
if actual.dtype.is_floating_point:
|
||||
assert torch.allclose(actual, eager)
|
||||
else:
|
||||
assert torch.equal(actual, eager)
|
||||
|
||||
|
||||
# ========== PT2 TopK Node Tests ==========
|
||||
|
||||
|
||||
@@ -2016,6 +2082,271 @@ def test_scatter_nd(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
def test_index_put_optional(device: torch.device):
|
||||
"""Tests index_put with optional (None) indices — mirrors StaticCache KV update."""
|
||||
model: torch.nn.Module = IndexPutOptionalModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.zeros(2, 2, 8, 4, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-5)
|
||||
|
||||
|
||||
# ========== Bool-mask index_put correctness tests ==========
|
||||
#
|
||||
# `x[bool_mask] = scalar` is semantically `where(mask, scalar, x)`, NOT a
|
||||
# scatter into Int(mask) positions. Pre-fix, the translator cast the Bool
|
||||
# mask to Int and routed through scatter_nd, reinterpreting True/False as
|
||||
# row indices 1/0 and silently corrupting `x`. Each variant below exercises
|
||||
# a different mask configuration; together they would catch any regression
|
||||
# in the bool-mask blend path.
|
||||
|
||||
|
||||
def _check_bool_mask(
|
||||
device: torch.device, model_cls, x: torch.Tensor, mask: torch.Tensor
|
||||
):
|
||||
"""Shared body: compile, run eager + compiled, assert exact equality."""
|
||||
from test_models import (
|
||||
BoolMaskAssign3DModel,
|
||||
BoolMaskAssignFloatModel,
|
||||
BoolMaskAssignIntModel,
|
||||
)
|
||||
|
||||
_ = (BoolMaskAssign3DModel, BoolMaskAssignFloatModel, BoolMaskAssignIntModel)
|
||||
model: torch.nn.Module = model_cls().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
original: torch.Tensor = model(x, mask)
|
||||
output: torch.Tensor = model_compiled(x, mask)
|
||||
# Bit-equal (not allclose) — the lowering should produce identical
|
||||
# results to eager for bool-mask blends.
|
||||
assert torch.equal(output, original), (
|
||||
f"bool-mask index_put mismatch:\n"
|
||||
f" mask = {mask.flatten().tolist()}\n"
|
||||
f" eager = {original.flatten().tolist()}\n"
|
||||
f" out = {output.flatten().tolist()}"
|
||||
)
|
||||
|
||||
|
||||
def test_bool_mask_index_put_all_false(device: torch.device):
|
||||
"""All-False mask must be a no-op. Pre-fix this *silently* corrupted row 0
|
||||
— the regression that drove the Gemma-4 ~30-magnitude logits drift."""
|
||||
from test_models import BoolMaskAssignIntModel
|
||||
|
||||
x = torch.arange(16, device=device, dtype=torch.long).reshape(4, 4)
|
||||
mask = torch.zeros(4, 4, dtype=torch.bool, device=device)
|
||||
_check_bool_mask(device, BoolMaskAssignIntModel, x, mask)
|
||||
|
||||
|
||||
def test_bool_mask_index_put_one_true(device: torch.device):
|
||||
"""Single True position — only that position should change."""
|
||||
from test_models import BoolMaskAssignIntModel
|
||||
|
||||
x = torch.arange(16, device=device, dtype=torch.long).reshape(4, 4)
|
||||
mask = torch.zeros(4, 4, dtype=torch.bool, device=device)
|
||||
mask[1, 2] = True
|
||||
_check_bool_mask(device, BoolMaskAssignIntModel, x, mask)
|
||||
|
||||
|
||||
def test_bool_mask_index_put_many_true(device: torch.device):
|
||||
"""Multiple scattered True positions — each should be replaced independently."""
|
||||
from test_models import BoolMaskAssignIntModel
|
||||
|
||||
x = torch.arange(16, device=device, dtype=torch.long).reshape(4, 4)
|
||||
mask = torch.tensor(
|
||||
[
|
||||
[True, False, False, True],
|
||||
[False, False, True, False],
|
||||
[True, False, False, False],
|
||||
[False, True, False, True],
|
||||
],
|
||||
dtype=torch.bool,
|
||||
device=device,
|
||||
)
|
||||
_check_bool_mask(device, BoolMaskAssignIntModel, x, mask)
|
||||
|
||||
|
||||
def test_bool_mask_index_put_all_true(device: torch.device):
|
||||
"""All-True mask — every element should become the scalar value."""
|
||||
from test_models import BoolMaskAssignIntModel
|
||||
|
||||
x = torch.arange(16, device=device, dtype=torch.long).reshape(4, 4)
|
||||
mask = torch.ones(4, 4, dtype=torch.bool, device=device)
|
||||
_check_bool_mask(device, BoolMaskAssignIntModel, x, mask)
|
||||
|
||||
|
||||
def test_bool_mask_index_put_float(device: torch.device):
|
||||
"""Float data + float scalar value. Verifies the where-blend works for
|
||||
non-integer dtypes — the blend formula `a*(1-mask) + value*mask` casts
|
||||
mask to data's dtype, so dtype-specific paths must compose correctly."""
|
||||
from test_models import BoolMaskAssignFloatModel
|
||||
|
||||
x = torch.arange(20, device=device, dtype=torch.float32).reshape(4, 5)
|
||||
mask = torch.tensor(
|
||||
[
|
||||
[True, False, False, True, False],
|
||||
[False, True, False, False, True],
|
||||
[True, True, False, False, False],
|
||||
[False, False, False, True, True],
|
||||
],
|
||||
dtype=torch.bool,
|
||||
device=device,
|
||||
)
|
||||
model = BoolMaskAssignFloatModel().to(device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
original = model(x, mask)
|
||||
output = compiled(x, mask)
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
def test_bool_mask_index_put_3d(device: torch.device):
|
||||
"""3-D `x` with a 3-D bool mask of matching shape. Catches regressions
|
||||
where the bool-mask detection only works at one specific rank — the
|
||||
`idx_tensor.shape.dims == a.shape.dims` check has to handle arbitrary
|
||||
ranks, not just 2-D."""
|
||||
from test_models import BoolMaskAssign3DModel
|
||||
|
||||
x = torch.arange(24, device=device, dtype=torch.float32).reshape(2, 3, 4)
|
||||
mask = torch.zeros(2, 3, 4, dtype=torch.bool, device=device)
|
||||
mask[0, 1, 2] = True
|
||||
mask[1, 0, 0] = True
|
||||
mask[1, 2, 3] = True
|
||||
model = BoolMaskAssign3DModel().to(device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
original = model(x, mask)
|
||||
output = compiled(x, mask)
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
def test_int_index_put_scalar_src(device: torch.device):
|
||||
"""`x[indices] = scalar` with int indices: the scatter path receives a
|
||||
scalar src against a 1D index tensor. Pre-fix `GraphTensor::scatter`
|
||||
panicked at `flatten_strides` (rank mismatch: index_shape=[2],
|
||||
src_strides=[]). With the zero-stride padding the scalar broadcasts
|
||||
across all indexed positions correctly."""
|
||||
from test_models import IntIndexAssignScalarModel
|
||||
|
||||
x = torch.arange(20, device=device, dtype=torch.float32).reshape(5, 4)
|
||||
indices = torch.tensor([0, 3], device=device, dtype=torch.long)
|
||||
model = IntIndexAssignScalarModel().to(device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
original = model(x, indices)
|
||||
output = compiled(x, indices)
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
def test_grouped_mm_fallback(device: torch.device):
|
||||
"""Tests transformers::grouped_mm_fallback — the per-expert batched matmul
|
||||
used by HF MoE forward passes (DeepSeek-V2/V3, Qwen2/3-MoE, Mixtral, ...).
|
||||
|
||||
Importing transformers.integrations.moe registers the custom_op via
|
||||
`torch.library.custom_op("transformers::grouped_mm_fallback", ...)`. After
|
||||
import, `torch.ops.transformers.grouped_mm_fallback` is callable directly.
|
||||
"""
|
||||
# Side-effect import: registers the custom_op via torch.library.custom_op.
|
||||
# The name itself isn't referenced — ruff's F401 must be suppressed.
|
||||
import transformers.integrations.moe # noqa: F401
|
||||
from test_models import GroupedMMFallbackTestModel
|
||||
|
||||
model: torch.nn.Module = GroupedMMFallbackTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
# 2 experts, 4 tokens, K=8, N=16. Tokens [0,1] go to expert 0, [2,3] to expert 1.
|
||||
g, s, k, n = 2, 4, 8, 16
|
||||
input = torch.randn(s, k, device=device)
|
||||
weight = torch.randn(g, k, n, device=device)
|
||||
offs = torch.tensor([2, 4], device=device, dtype=torch.int32)
|
||||
original: torch.Tensor = model(input, weight, offs)
|
||||
output: torch.Tensor = model_compiled(input, weight, offs)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_grouped_mm_fallback_routing_invariance(device: torch.device):
|
||||
"""The MoE forest, not just the trees: one compile must correctly handle
|
||||
*any* routing pattern at the same shape.
|
||||
|
||||
`translate_grouped_mm` is correct only if `offs` flows through as a runtime
|
||||
tensor — the gate's top-k decision varies per token batch, and the same
|
||||
compiled graph has to dispatch tokens to the right experts for whatever
|
||||
`offs` arrives at execution. If our lowering accidentally specialized on a
|
||||
particular `offs` value (baking in expert assignments), `compiled(input_b,
|
||||
weight, offs_b)` would either silently produce wrong-expert output or
|
||||
trigger a recompile.
|
||||
|
||||
This test asserts three things at once:
|
||||
(a) Different `offs` (= different routing) doesn't trigger a recompile.
|
||||
(b) `offs` appears as an FX graph node, not a baked constant.
|
||||
(c) The same compiled graph produces correct output for both routings,
|
||||
and outputs *differ* between routings (else the test is moot).
|
||||
"""
|
||||
import transformers.integrations.moe # noqa: F401
|
||||
from test_models import GroupedMMFallbackTestModel
|
||||
|
||||
g, s, k, n = 2, 4, 8, 16
|
||||
|
||||
# Wrap luminal_backend to capture the FX graph(s) dynamo hands us.
|
||||
captured = []
|
||||
|
||||
def capturing_backend(gm, example_inputs):
|
||||
captured.append(gm)
|
||||
return luminal_backend(gm, example_inputs)
|
||||
|
||||
model = GroupedMMFallbackTestModel().to(device)
|
||||
compiled = torch.compile(model, backend=capturing_backend)
|
||||
|
||||
# Same shapes, different data → different routing patterns.
|
||||
weight = torch.randn(g, k, n, device=device)
|
||||
input_a = torch.randn(s, k, device=device)
|
||||
input_b = torch.randn(s, k, device=device)
|
||||
# offs[i] = cumulative tokens through expert i. Different routings:
|
||||
# offs_a: 1 token to expert 0, 3 to expert 1
|
||||
# offs_b: 3 tokens to expert 0, 1 to expert 1
|
||||
offs_a = torch.tensor([1, 4], device=device, dtype=torch.int32)
|
||||
offs_b = torch.tensor([3, 4], device=device, dtype=torch.int32)
|
||||
|
||||
with torch.no_grad():
|
||||
ref_a = model(input_a, weight, offs_a)
|
||||
out_a = compiled(input_a, weight, offs_a)
|
||||
n_compiles_after_first = len(captured)
|
||||
|
||||
ref_b = model(input_b, weight, offs_b)
|
||||
out_b = compiled(input_b, weight, offs_b)
|
||||
|
||||
# (a) No recompile between distinct routings.
|
||||
assert len(captured) == n_compiles_after_first, (
|
||||
f"Different routings triggered a recompile: "
|
||||
f"{n_compiles_after_first} → {len(captured)}"
|
||||
)
|
||||
|
||||
# (b) offs is an FX graph node, not a baked constant.
|
||||
grouped_nodes = [
|
||||
node for node in captured[0].graph.nodes if "grouped_mm" in str(node.target)
|
||||
]
|
||||
assert len(grouped_nodes) == 1, (
|
||||
f"Expected exactly one grouped_mm node, got {len(grouped_nodes)}"
|
||||
)
|
||||
grouped_node = grouped_nodes[0]
|
||||
# transformers::grouped_mm_fallback emits offs as a kwarg; aten._grouped_mm
|
||||
# may emit it as a positional. Accept either.
|
||||
offs_arg = grouped_node.kwargs.get("offs")
|
||||
if offs_arg is None and len(grouped_node.args) > 2:
|
||||
offs_arg = grouped_node.args[2]
|
||||
assert hasattr(offs_arg, "op"), (
|
||||
f"offs argument should be an FX graph node, got {offs_arg!r} "
|
||||
f"({type(offs_arg).__name__}) — looks baked as constant"
|
||||
)
|
||||
|
||||
# (c) Both routings produce correct output, and outputs differ.
|
||||
assert torch.allclose(out_a, ref_a, atol=1e-4), (
|
||||
f"routing A: max_diff={torch.max(torch.abs(out_a - ref_a)).item():.2e}"
|
||||
)
|
||||
assert torch.allclose(out_b, ref_b, atol=1e-4), (
|
||||
f"routing B: max_diff={torch.max(torch.abs(out_b - ref_b)).item():.2e}"
|
||||
)
|
||||
assert not torch.allclose(out_a, out_b, atol=1e-3), (
|
||||
"Outputs of routing A and B should differ — otherwise routing isn't "
|
||||
"actually being exercised."
|
||||
)
|
||||
|
||||
|
||||
# ========== Dtype Round-Trip Tests ==========
|
||||
|
||||
|
||||
|
||||
94
crates/luminal_python/tests/test_kv_cache_comparison.py
Normal file
94
crates/luminal_python/tests/test_kv_cache_comparison.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""KV Cache decode loop test.
|
||||
|
||||
Compiles a tiny 1-layer Llama model with use_cache=True, then:
|
||||
1. Prefill: model(input_ids) -> logits + K/V cache
|
||||
2. Decode: model(next_token, past_key_values=cache) -> logits + updated K/V
|
||||
|
||||
Verifies correctness of both steps and writes DOT graphs for comparison.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
def _capturing_backend(captured):
|
||||
"""Wrap luminal_backend to capture CompiledModels for DOT extraction."""
|
||||
|
||||
def backend(gm, example_inputs):
|
||||
compiled = luminal_backend(gm, example_inputs)
|
||||
captured.append(compiled)
|
||||
return compiled
|
||||
|
||||
return backend
|
||||
|
||||
|
||||
def test_kv_cache_decode_loop():
|
||||
"""Full prefill -> decode loop through luminal with KV cache."""
|
||||
from transformers import LlamaConfig, LlamaForCausalLM
|
||||
|
||||
# Allow both prefill and decode compilations (conftest sets limit=1)
|
||||
torch._dynamo.config.cache_size_limit = 2
|
||||
|
||||
config = LlamaConfig(
|
||||
hidden_size=64,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
num_hidden_layers=1,
|
||||
intermediate_size=128,
|
||||
vocab_size=256,
|
||||
max_position_embeddings=128,
|
||||
use_cache=True,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
model = LlamaForCausalLM(config).eval()
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]])
|
||||
|
||||
captured = []
|
||||
compiled = torch.compile(model, backend=_capturing_backend(captured))
|
||||
|
||||
# --- Prefill step ---
|
||||
with torch.no_grad():
|
||||
ref_prefill = model(input_ids)
|
||||
out_prefill = compiled(input_ids)
|
||||
|
||||
assert torch.allclose(out_prefill.logits, ref_prefill.logits, atol=1e-5)
|
||||
assert out_prefill.past_key_values is not None, "Prefill should return KV cache"
|
||||
|
||||
# --- Decode step ---
|
||||
next_token = ref_prefill.logits[0, -1, :].argmax().unsqueeze(0).unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
ref_decode = model(next_token, past_key_values=ref_prefill.past_key_values)
|
||||
out_decode = compiled(next_token, past_key_values=out_prefill.past_key_values)
|
||||
|
||||
assert torch.allclose(out_decode.logits, ref_decode.logits, atol=1e-5)
|
||||
|
||||
# --- DOT graph comparison ---
|
||||
# captured[0] = prefill graph, captured[1] = decode graph (recompiled by dynamo)
|
||||
assert len(captured) >= 2, (
|
||||
f"Expected 2 compilations (prefill+decode), got {len(captured)}"
|
||||
)
|
||||
|
||||
out_dir = "/tmp/luminal_kv_cache_comparison"
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
prefill_dot = captured[0]._graph.to_dot()
|
||||
decode_dot = captured[1]._graph.to_dot()
|
||||
|
||||
with open(os.path.join(out_dir, "prefill.dot"), "w") as f:
|
||||
f.write(prefill_dot)
|
||||
with open(os.path.join(out_dir, "decode.dot"), "w") as f:
|
||||
f.write(decode_dot)
|
||||
|
||||
print(f"\n=== DOT files written to {out_dir} ===")
|
||||
print(f"Prefill: {len(prefill_dot)} chars, inputs: {captured[0]._input_names}")
|
||||
print(f"Decode: {len(decode_dot)} chars, inputs: {captured[1]._input_names}")
|
||||
|
||||
# Decode graph should have more inputs (past K/V cache tensors)
|
||||
assert len(captured[1]._input_names) > len(captured[0]._input_names), (
|
||||
f"Decode should have more inputs than prefill: "
|
||||
f"{len(captured[1]._input_names)} vs {len(captured[0]._input_names)}"
|
||||
)
|
||||
194
crates/luminal_python/tests/test_kv_cache_growing.py
Normal file
194
crates/luminal_python/tests/test_kv_cache_growing.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""KV Cache growing decode loop test.
|
||||
|
||||
Compiles a tiny 1-layer Llama model with use_cache=True, then runs a
|
||||
multi-step autoregressive decode loop:
|
||||
|
||||
1. Prefill: model(input_ids) -> logits + initial KV cache
|
||||
2. Decode x N: model(next_token, past_key_values=cache) -> logits + grown KV cache
|
||||
|
||||
At each step, prints the KV cache tensor shapes so you can see the
|
||||
sequence dimension grow: (1, n_kv_heads, 4, head_dim) -> (1, n_kv_heads, 5, ...) -> ...
|
||||
|
||||
Verifies luminal output matches PyTorch reference at every step.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
NUM_DECODE_STEPS = 5
|
||||
|
||||
|
||||
def test_kv_cache_growing():
|
||||
"""Multi-step prefill + decode loop showing KV cache growth."""
|
||||
from transformers import LlamaConfig, LlamaForCausalLM
|
||||
|
||||
# We need 1 compilation for prefill + 1 per unique decode cache size
|
||||
torch._dynamo.config.cache_size_limit = NUM_DECODE_STEPS + 2
|
||||
# Disable automatic dynamic shapes — dynamo would otherwise try to use SymInt
|
||||
# for the varying cache seq_len dimension, which torch.export doesn't support.
|
||||
# Instead, we want a fresh recompilation for each new cache size.
|
||||
torch._dynamo.config.automatic_dynamic_shapes = False
|
||||
|
||||
config = LlamaConfig(
|
||||
hidden_size=64,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
num_hidden_layers=4,
|
||||
intermediate_size=128,
|
||||
vocab_size=256,
|
||||
max_position_embeddings=128,
|
||||
use_cache=True,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
model = LlamaForCausalLM(config).eval()
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]])
|
||||
|
||||
# ---- Prefill ----
|
||||
with torch.no_grad():
|
||||
ref_out = model(input_ids)
|
||||
lum_out = compiled(input_ids)
|
||||
|
||||
assert ref_out.past_key_values is not None, "Reference should return KV cache"
|
||||
assert lum_out.past_key_values is not None, "Luminal should return KV cache"
|
||||
|
||||
assert torch.allclose(lum_out.logits, ref_out.logits, atol=1e-5), (
|
||||
f"Prefill mismatch: max_diff="
|
||||
f"{torch.max(torch.abs(lum_out.logits - ref_out.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
_print_cache_shapes("Prefill", ref_out.past_key_values, lum_out.past_key_values)
|
||||
|
||||
ref_cache = ref_out.past_key_values
|
||||
lum_cache = lum_out.past_key_values
|
||||
|
||||
# ---- Decode loop ----
|
||||
for step in range(NUM_DECODE_STEPS):
|
||||
# Greedy next token from reference logits
|
||||
next_token = ref_out.logits[0, -1, :].argmax().unsqueeze(0).unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
ref_out = model(next_token, past_key_values=ref_cache)
|
||||
lum_out = compiled(next_token, past_key_values=lum_cache)
|
||||
|
||||
assert torch.allclose(lum_out.logits, ref_out.logits, atol=1e-5), (
|
||||
f"Decode step {step} mismatch: max_diff="
|
||||
f"{torch.max(torch.abs(lum_out.logits - ref_out.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
ref_cache = ref_out.past_key_values
|
||||
lum_cache = lum_out.past_key_values
|
||||
|
||||
_print_cache_shapes(f"Decode step {step}", ref_cache, lum_cache)
|
||||
|
||||
# Final sanity check: cache seq_len should equal prompt + decode steps
|
||||
expected_seq = input_ids.shape[1] + NUM_DECODE_STEPS
|
||||
final_k = ref_cache.layers[0].keys
|
||||
assert final_k.shape[2] == expected_seq, (
|
||||
f"Expected cache seq_len={expected_seq}, got {final_k.shape[2]}"
|
||||
)
|
||||
print(
|
||||
f"\nAll {NUM_DECODE_STEPS} decode steps passed. "
|
||||
f"Cache grew from seq_len={input_ids.shape[1]} to {expected_seq}."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="R1 full-width 1-layer is too memory-heavy for CPU native backend",
|
||||
)
|
||||
def test_kv_cache_growing_r1_mla(device: torch.device):
|
||||
"""Growing-cache decode loop on DeepSeek-R1 (MLA + decoupled RoPE), 1 layer.
|
||||
|
||||
Exercises MLA: q_lora / kv_lora low-rank projections, decoupled RoPE split
|
||||
(qk_nope_head_dim + qk_rope_head_dim), and DynamicCache crossing the compile
|
||||
boundary through the MLA update path (`cache_utils.py:102-121`).
|
||||
|
||||
Runs in fp32 — in bf16, MLA's empty-tensor-cat inside DynamicLayer.update
|
||||
has a precision drift on the compiled path (logits ~3.7 on 1 layer) that
|
||||
does not affect standard GQA (Llama in bf16 is bit-identical). Investigate
|
||||
separately.
|
||||
"""
|
||||
from transformers import AutoConfig, DeepseekV3ForCausalLM
|
||||
|
||||
torch._dynamo.config.cache_size_limit = NUM_DECODE_STEPS + 2
|
||||
torch._dynamo.config.automatic_dynamic_shapes = False
|
||||
|
||||
# Release any memory accumulated by previous tests in the same pytest
|
||||
# process — full-width R1 instantiation needs ~3 GB and the test runner's
|
||||
# GPU is shared with ~230 prior tests' allocations.
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-R1")
|
||||
config.num_hidden_layers = 1
|
||||
# first_k_dense_replace=3 (default) makes the 1 layer dense, so we avoid
|
||||
# the 256-expert MoE path and the associated memory pressure.
|
||||
config._attn_implementation = "eager"
|
||||
config.torch_dtype = torch.float32
|
||||
# Aggressively shrink the embedding / LM head / FFN dimensions while
|
||||
# preserving the MLA-specific knobs that the test is actually exercising
|
||||
# (q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim).
|
||||
# Full R1 has vocab=129280, intermediate=18432, hidden=7168 — at fp32 the
|
||||
# embedding + LM head alone is ~3.5 GB, which OOMs the 40 GB test runner
|
||||
# after prior tests' allocations. The MLA path is unchanged at vocab=256.
|
||||
config.vocab_size = 256
|
||||
config.intermediate_size = 512
|
||||
config.max_position_embeddings = 128
|
||||
model = DeepseekV3ForCausalLM(config).eval().to(dtype=torch.float32, device=device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]], device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
ref_out = model(input_ids)
|
||||
lum_out = compiled(input_ids)
|
||||
|
||||
# fp32 MLA matches to ~1e-5 — see diagnose_dtype.py. Keep the tolerance
|
||||
# tight here so regressions in the MLA cat/split path show up immediately.
|
||||
assert torch.allclose(lum_out.logits, ref_out.logits, atol=1e-4), (
|
||||
f"Prefill: max_diff={torch.max(torch.abs(lum_out.logits - ref_out.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
ref_cache = ref_out.past_key_values
|
||||
lum_cache = lum_out.past_key_values
|
||||
|
||||
# Run a single decode step — enough to confirm the cache flows through as an
|
||||
# explicit input on the second compile (the key signal from
|
||||
# _test_kv_cache_comparison.py's "decode has more inputs than prefill"
|
||||
# assertion). Full 5-step growth is covered by the Llama test above.
|
||||
next_token = ref_out.logits[0, -1, :].argmax().view(1, 1).to(device)
|
||||
with torch.no_grad():
|
||||
ref_dec = model(next_token, past_key_values=ref_cache)
|
||||
lum_dec = compiled(next_token, past_key_values=lum_cache)
|
||||
|
||||
assert torch.allclose(lum_dec.logits, ref_dec.logits, atol=1e-4), (
|
||||
f"Decode: max_diff={torch.max(torch.abs(lum_dec.logits - ref_dec.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
def _print_cache_shapes(label, ref_cache, lum_cache):
|
||||
"""Print KV cache shapes for both reference and luminal."""
|
||||
print(f"\n--- {label} ---")
|
||||
for layer_idx, ref_layer in enumerate(ref_cache.layers):
|
||||
ref_k, ref_v = ref_layer.keys, ref_layer.values
|
||||
lum_layer = lum_cache.layers[layer_idx]
|
||||
lum_k, lum_v = lum_layer.keys, lum_layer.values
|
||||
print(
|
||||
f" Layer {layer_idx}: "
|
||||
f"K ref={list(ref_k.shape)} lum={list(lum_k.shape)} | "
|
||||
f"V ref={list(ref_v.shape)} lum={list(lum_v.shape)}"
|
||||
)
|
||||
# Verify cache tensors match
|
||||
assert torch.allclose(lum_k, ref_k, atol=1e-5), (
|
||||
f"{label} layer {layer_idx} K mismatch: "
|
||||
f"max_diff={torch.max(torch.abs(lum_k - ref_k)).item():.2e}"
|
||||
)
|
||||
assert torch.allclose(lum_v, ref_v, atol=1e-5), (
|
||||
f"{label} layer {layer_idx} V mismatch: "
|
||||
f"max_diff={torch.max(torch.abs(lum_v - ref_v)).item():.2e}"
|
||||
)
|
||||
@@ -66,12 +66,15 @@ def test_causal_self_attention(device: torch.device):
|
||||
|
||||
def test_llama_transformer_block(device: torch.device):
|
||||
"""Test full Llama transformer block: RMSNorm -> Attn -> Residual -> RMSNorm -> MLP -> Residual."""
|
||||
torch.manual_seed(0)
|
||||
model: torch.nn.Module = LlamaTransformerBlockModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.rand((1, 4, 32), device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
assert torch.allclose(output, original, atol=1e-3), (
|
||||
f"max_diff={torch.max(torch.abs(output - original)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
# ========== HuggingFace LlamaForCausalLM Tests ==========
|
||||
@@ -392,11 +395,11 @@ def test_dynamic_dim_reuse_no_recompile(device: torch.device):
|
||||
return self.proj(self.embed(x))
|
||||
|
||||
model = DynamicSeqModel().eval().to(device)
|
||||
backend = "cuda" if device.type == "cuda" else "native"
|
||||
|
||||
# Compile once with dynamic seq dim (auto-detected for integer inputs)
|
||||
# Compile once with dynamic seq dim (auto-detected for integer inputs).
|
||||
# Factory capsule is auto-detected from example.device.
|
||||
example = torch.tensor([[1, 2, 3, 4]], device=device)
|
||||
compiled = luminal_compile(model, example, search_iterations=5, backend=backend)
|
||||
compiled = luminal_compile(model, example, search_iterations=5)
|
||||
|
||||
# Execute with multiple different seq lengths — each call reuses the
|
||||
# same compiled graph, updating dynamic dims in-place.
|
||||
@@ -411,6 +414,71 @@ def test_dynamic_dim_reuse_no_recompile(device: torch.device):
|
||||
)
|
||||
|
||||
|
||||
def test_hf_llama3_8b_instruct_1layer(device: torch.device):
|
||||
"""HuggingFace LlamaForCausalLM — Llama-3-8B-Instruct architecture, 1 layer, random weights.
|
||||
|
||||
Uses the exact model architecture from the TTFT benchmark
|
||||
(NousResearch/Meta-Llama-3-8B-Instruct) with num_hidden_layers=1. Full 8B width:
|
||||
4096 hidden, 32 attn heads, 8 KV heads, 14336 intermediate, 128256 vocab.
|
||||
Random weights — tests that compilation and execution complete without error.
|
||||
|
||||
Regression for: NativeRuntime panic 'no entry found for key' (hlir.rs:2239) when the
|
||||
wheel is built without --features cuda. The CUDA factory capsule silently falls back
|
||||
to NativeRuntime, which cannot process GPU-resident weight device pointers, leaving
|
||||
Output-node predecessor buffers unpopulated.
|
||||
"""
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
config = AutoConfig.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
|
||||
config.num_hidden_layers = 1
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
model = LlamaForCausalLM(config).eval().to(device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]], device=device)
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
out = compiled(input_ids)
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-4), (
|
||||
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol at full 8B scale")
|
||||
def test_hf_llama3_8b_instruct_full(device: torch.device):
|
||||
"""HuggingFace LlamaForCausalLM — full Llama-3-8B-Instruct with real pretrained weights.
|
||||
|
||||
Direct reproduction of the TTFT benchmark scenario. All 32 layers at full width.
|
||||
Loads actual weights from NousResearch/Meta-Llama-3-8B-Instruct (~30 GB in fp32).
|
||||
Marked slow (requires model download) and xfail (numerical precision at this scale).
|
||||
"""
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
config = AutoConfig.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
model = (
|
||||
LlamaForCausalLM.from_pretrained(
|
||||
"NousResearch/Meta-Llama-3-8B-Instruct",
|
||||
config=config,
|
||||
torch_dtype=torch.float32,
|
||||
)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]], device=device)
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
out = compiled(input_ids)
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-5), (
|
||||
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
|
||||
def test_hf_llama38b_full(device: torch.device):
|
||||
"""HuggingFace LlamaForCausalLM — full Llama-3.1-8B-Instruct with real pretrained weights.
|
||||
|
||||
@@ -1619,6 +1619,73 @@ class SplitTestModel(torch.nn.Module):
|
||||
return a + b
|
||||
|
||||
|
||||
# ========== Argsort / MoE Routing Test Models ==========
|
||||
|
||||
|
||||
class ArgsortStableDuplicatesModel(torch.nn.Module):
|
||||
"""Tests deterministic duplicate ordering for exported argsort."""
|
||||
|
||||
SORT_DIM = 1
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.argsort(x, dim=self.SORT_DIM)
|
||||
|
||||
|
||||
class TinyMoERoutingModel(torch.nn.Module):
|
||||
"""Minimal deterministic MoE-style routing proof for PT2/native and CUDA."""
|
||||
|
||||
TOP_K = 2
|
||||
ROUTING_DIM = -1
|
||||
ZERO_FILL = 0.0
|
||||
DISPATCH_ON = 1
|
||||
GROUP_SIZE = 2
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.register_buffer(
|
||||
"expert_scale",
|
||||
torch.tensor([1.5, -0.5, 2.0, 0.25], dtype=torch.float32),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, scores: torch.Tensor
|
||||
) -> tuple[
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
]:
|
||||
topk_values, topk_indices = torch.topk(scores, self.TOP_K, dim=self.ROUTING_DIM)
|
||||
regroup_order = torch.argsort(topk_indices, dim=self.ROUTING_DIM)
|
||||
routed_indices = torch.gather(topk_indices, self.ROUTING_DIM, regroup_order)
|
||||
routed_values = torch.gather(topk_values, self.ROUTING_DIM, regroup_order)
|
||||
|
||||
expert_scale = self.expert_scale.unsqueeze(0).expand(scores.shape[0], -1)
|
||||
gathered_scale = torch.gather(expert_scale, self.ROUTING_DIM, routed_indices)
|
||||
weighted = routed_values * gathered_scale
|
||||
|
||||
inactive_mask = torch.bitwise_not(weighted > 0)
|
||||
masked_values = weighted.masked_fill(inactive_mask, self.ZERO_FILL)
|
||||
|
||||
slots = torch.zeros_like(routed_indices).scatter(
|
||||
self.ROUTING_DIM, regroup_order, self.DISPATCH_ON
|
||||
)
|
||||
active_slots = torch.bitwise_not(inactive_mask).to(slots.dtype)
|
||||
dispatch = slots * active_slots
|
||||
group_ids = torch.floor_divide(routed_indices, self.GROUP_SIZE)
|
||||
routing_sign = torch.sign(masked_values)
|
||||
return (
|
||||
routed_indices,
|
||||
masked_values,
|
||||
dispatch,
|
||||
inactive_mask,
|
||||
group_ids,
|
||||
routing_sign,
|
||||
)
|
||||
|
||||
|
||||
# ========== TopK Node Test Models ==========
|
||||
|
||||
|
||||
@@ -1685,6 +1752,22 @@ class ScatterNDTestModel(torch.nn.Module):
|
||||
return result
|
||||
|
||||
|
||||
class IndexPutOptionalModel(torch.nn.Module):
|
||||
"""Tests index_put with optional (None) indices — mirrors StaticCache KV update.
|
||||
|
||||
result[:, :, pos, :] = ones → index_put([None, None, pos_tensor, (implied None)], ones)
|
||||
Input: (2, 2, 8, 4) Output: same shape with dim-2 position 0 set to 1.
|
||||
Batch size > 1 is required so PT2 preserves the full rank of the values tensor.
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
pos = torch.zeros(1, dtype=torch.long, device=x.device)
|
||||
v = torch.ones(2, 2, 1, 4, device=x.device)
|
||||
result = x.clone()
|
||||
result[:, :, pos, :] = v
|
||||
return result
|
||||
|
||||
|
||||
# ========== Llama3 Component Test Models ==========
|
||||
|
||||
|
||||
@@ -1840,9 +1923,14 @@ class LlamaTransformerBlockModel(torch.nn.Module):
|
||||
class Conv1dNoPadModel(torch.nn.Module):
|
||||
"""Conv1d with no padding: output length shrinks by (kernel-1)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
PADDING = 0
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv1d(8, 16, kernel_size=3, padding=0, bias=False)
|
||||
self.conv = torch.nn.Conv1d(
|
||||
8, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
@@ -1851,9 +1939,14 @@ class Conv1dNoPadModel(torch.nn.Module):
|
||||
class Conv1dSamePadModel(torch.nn.Module):
|
||||
"""Conv1d with same-size padding (output length == input length)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv1d(8, 16, kernel_size=3, padding=1, bias=False)
|
||||
self.conv = torch.nn.Conv1d(
|
||||
8, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
@@ -1862,9 +1955,14 @@ class Conv1dSamePadModel(torch.nn.Module):
|
||||
class Conv1dBiasModel(torch.nn.Module):
|
||||
"""Conv1d with bias."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv1d(8, 16, kernel_size=3, padding=1, bias=True)
|
||||
self.conv = torch.nn.Conv1d(
|
||||
8, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=True
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
@@ -1873,9 +1971,14 @@ class Conv1dBiasModel(torch.nn.Module):
|
||||
class Conv2dNoPadModel(torch.nn.Module):
|
||||
"""Conv2d with no padding: output spatial dims shrink by (kernel-1)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
PADDING = 0
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=0, bias=False)
|
||||
self.conv = torch.nn.Conv2d(
|
||||
3, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
@@ -1884,9 +1987,14 @@ class Conv2dNoPadModel(torch.nn.Module):
|
||||
class Conv2dSamePadModel(torch.nn.Module):
|
||||
"""Conv2d with same-size padding."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
|
||||
self.conv = torch.nn.Conv2d(
|
||||
3, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
@@ -1895,9 +2003,14 @@ class Conv2dSamePadModel(torch.nn.Module):
|
||||
class Conv2dBiasModel(torch.nn.Module):
|
||||
"""Conv2d with bias."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=True)
|
||||
self.conv = torch.nn.Conv2d(
|
||||
3, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=True
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
@@ -1906,10 +2019,19 @@ class Conv2dBiasModel(torch.nn.Module):
|
||||
class Conv2dStrideModel(torch.nn.Module):
|
||||
"""Conv2d with stride=2 (output dims halved)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
STRIDE = 2
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
3, 16, kernel_size=3, stride=2, padding=1, bias=False
|
||||
3,
|
||||
16,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
stride=self.STRIDE,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -1919,10 +2041,19 @@ class Conv2dStrideModel(torch.nn.Module):
|
||||
class Conv2dDilationModel(torch.nn.Module):
|
||||
"""Conv2d with dilation=2 and padding chosen to preserve spatial size."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
DILATION = 2
|
||||
PADDING = 2
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
8, 16, kernel_size=3, dilation=2, padding=2, bias=False
|
||||
8,
|
||||
16,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
dilation=self.DILATION,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -1932,9 +2063,14 @@ class Conv2dDilationModel(torch.nn.Module):
|
||||
class Conv3dSamePadModel(torch.nn.Module):
|
||||
"""Conv3d with padding=1 to preserve spatial dimensions."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv3d(4, 8, kernel_size=3, padding=1, bias=False)
|
||||
self.conv = torch.nn.Conv3d(
|
||||
4, 8, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
@@ -1943,10 +2079,19 @@ class Conv3dSamePadModel(torch.nn.Module):
|
||||
class DepthwiseConv1dModel(torch.nn.Module):
|
||||
"""Depthwise Conv1d as used in Mamba (groups == in_channels)."""
|
||||
|
||||
KERNEL_SIZE = 4
|
||||
GROUPS = 16
|
||||
PADDING = 3
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv1d(
|
||||
16, 16, kernel_size=4, groups=16, padding=3, bias=True
|
||||
16,
|
||||
16,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
groups=self.GROUPS,
|
||||
padding=self.PADDING,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -1957,10 +2102,19 @@ class DepthwiseConv1dModel(torch.nn.Module):
|
||||
class DepthwiseConv2dModel(torch.nn.Module):
|
||||
"""Depthwise Conv2d (groups == in_channels)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
GROUPS = 8
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
8, 8, kernel_size=3, groups=8, padding=1, bias=False
|
||||
8,
|
||||
8,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
groups=self.GROUPS,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -1970,10 +2124,19 @@ class DepthwiseConv2dModel(torch.nn.Module):
|
||||
class DepthwiseMultiplierConv2dModel(torch.nn.Module):
|
||||
"""Depthwise Conv2d with channel multiplier 2 (out_channels = 2 * in_channels)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
GROUPS = 8
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
8, 16, kernel_size=3, groups=8, padding=1, bias=False
|
||||
8,
|
||||
16,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
groups=self.GROUPS,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -1983,10 +2146,19 @@ class DepthwiseMultiplierConv2dModel(torch.nn.Module):
|
||||
class GroupedConv2dModel(torch.nn.Module):
|
||||
"""Conv2d with groups=4 (not depthwise, but grouped)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
GROUPS = 4
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
16, 32, kernel_size=3, groups=4, padding=1, bias=False
|
||||
16,
|
||||
32,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
groups=self.GROUPS,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -1996,10 +2168,19 @@ class GroupedConv2dModel(torch.nn.Module):
|
||||
class GroupedConv2dGroups3Model(torch.nn.Module):
|
||||
"""Conv2d with groups=3 and ch_per_group=4."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
GROUPS = 3
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
12, 12, kernel_size=3, groups=3, padding=1, bias=False
|
||||
12,
|
||||
12,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
groups=self.GROUPS,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -2015,9 +2196,16 @@ class MambaConvBlockModel(torch.nn.Module):
|
||||
def __init__(self, d_model: int = 16, d_conv: int = 4, expand: int = 2) -> None:
|
||||
super().__init__()
|
||||
d_inner = d_model * expand
|
||||
groups = d_inner
|
||||
padding = d_conv - 1
|
||||
self.in_proj = torch.nn.Linear(d_model, d_inner * 2, bias=False)
|
||||
self.conv1d = torch.nn.Conv1d(
|
||||
d_inner, d_inner, d_conv, groups=d_inner, padding=d_conv - 1, bias=True
|
||||
d_inner,
|
||||
d_inner,
|
||||
d_conv,
|
||||
groups=groups,
|
||||
padding=padding,
|
||||
bias=True,
|
||||
)
|
||||
self.out_proj = torch.nn.Linear(d_inner, d_model, bias=False)
|
||||
|
||||
@@ -2029,3 +2217,82 @@ class MambaConvBlockModel(torch.nn.Module):
|
||||
return self.out_proj(
|
||||
torch.nn.functional.silu(x_part) * torch.nn.functional.silu(z)
|
||||
)
|
||||
|
||||
|
||||
class BitwiseOrTestModel(torch.nn.Module):
|
||||
"""Tests bitwise_or on boolean tensors — the pattern Gemma-style models
|
||||
emit when fusing sliding-window and full-attention masks
|
||||
(`mask = sliding_mask | full_mask`)."""
|
||||
|
||||
def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||
return a | b
|
||||
|
||||
|
||||
class GroupedMMFallbackTestModel(torch.nn.Module):
|
||||
"""Tests transformers::grouped_mm_fallback — the per-expert batched
|
||||
matmul HF MoE models emit (DeepSeek-V2, Qwen-MoE, Mixtral, etc.).
|
||||
|
||||
Calls the registered custom_op directly with shapes that match a
|
||||
realistic MoE expert dispatch: input is `(S, K)` of tokens already
|
||||
sorted by expert, weight is `(G, K, N)` per-expert weights, offs is
|
||||
`(G,)` cumulative token counts.
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor, weight: torch.Tensor, offs: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.transformers.grouped_mm_fallback(input, weight, offs)
|
||||
|
||||
|
||||
class BoolMaskAssignIntModel(torch.nn.Module):
|
||||
"""`x[mask] = scalar` on integer data with a Bool-dtype mask whose shape
|
||||
matches `x`.
|
||||
|
||||
PyTorch decomposes this to `aten.index_put_(x, [mask], scalar)`. The
|
||||
correct lowering is `where(mask, scalar, x)` — NOT a scatter into Int(mask)
|
||||
positions. Pre-fix, the compiled output silently corrupted row 0 of `x`
|
||||
even when the mask was all-False (the silent-data-corruption case driven
|
||||
by Gemma-4's multimodal_mask path).
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
out = x.clone()
|
||||
out[mask] = 99
|
||||
return out
|
||||
|
||||
|
||||
class BoolMaskAssignFloatModel(torch.nn.Module):
|
||||
"""Same as BoolMaskAssignIntModel but with float data + a float scalar.
|
||||
|
||||
Verifies the `where` blend works for non-integer dtypes too.
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
out = x.clone()
|
||||
out[mask] = 7.5
|
||||
return out
|
||||
|
||||
|
||||
class BoolMaskAssign3DModel(torch.nn.Module):
|
||||
"""Multi-dimensional `x[mask] = scalar` — Bool mask shape must match `x`'s
|
||||
full shape, not just be 1D. Catches regressions where the bool-mask
|
||||
detection only works at one specific rank.
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
out = x.clone()
|
||||
out[mask] = -1.0
|
||||
return out
|
||||
|
||||
|
||||
class IntIndexAssignScalarModel(torch.nn.Module):
|
||||
"""`x[indices] = scalar_tensor` with a rank-1 index tensor and a 0-D
|
||||
scalar value. After PT2 decomposition this hits the scatter path with a
|
||||
scalar src; the lowering must broadcast the scalar across all indexed
|
||||
positions (zero-stride padding in `GraphTensor::scatter`).
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
|
||||
out = x.clone()
|
||||
out[indices] = 42.0
|
||||
return out
|
||||
|
||||
BIN
docs/logo/inference_at_the_speed_of_light.png
Normal file
BIN
docs/logo/inference_at_the_speed_of_light.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 380 KiB |
@@ -13,11 +13,21 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
const REPO_ID: &str = "unsloth/gemma-3-4b-it";
|
||||
|
||||
// Default configuration — override at runtime via env vars.
|
||||
const DEFAULT_MAX_SEQ_LEN: usize = 4096;
|
||||
const DEFAULT_SEARCH_GRAPHS: usize = 50;
|
||||
const DEFAULT_GEN_TOKENS: usize = 500;
|
||||
|
||||
fn env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name).ok().and_then(|s| s.parse().ok()).unwrap_or(default)
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 500;
|
||||
let search_graphs = 500;
|
||||
let prompt = "Explain what a neural network is in simple terms:";
|
||||
let max_seq_len = env_usize("MAX_SEQ_LEN", DEFAULT_MAX_SEQ_LEN);
|
||||
let gen_tokens = env_usize("GEN_TOKENS", DEFAULT_GEN_TOKENS);
|
||||
let search_graphs = env_usize("SEARCH_GRAPHS", DEFAULT_SEARCH_GRAPHS);
|
||||
let prompt = std::env::var("PROMPT")
|
||||
.unwrap_or_else(|_| "Explain what a neural network is in simple terms:".to_string());
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
@@ -46,6 +56,7 @@ fn main() {
|
||||
}
|
||||
|
||||
println!("Building E-Graph...");
|
||||
let compile_start = std::time::Instant::now();
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
println!("Loading weights...");
|
||||
@@ -65,36 +76,90 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(token_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
println!(" COMPILE: {:.2} ms", compile_start.elapsed().as_secs_f64() * 1e3);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
// Full-prompt warmup: run the complete prompt to bring the GPU to steady state before timing
|
||||
for (w_step, &w_token) in prompt_tokens.iter().enumerate() {
|
||||
let p = w_step + 1;
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', p);
|
||||
runtime.set_data(input, vec![w_token as i32]);
|
||||
runtime.set_data(token_ids, vec![p as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
}
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
let iters = env_usize("ITERS", 3);
|
||||
let prompt_len = prompt_tokens.len();
|
||||
|
||||
println!("Prompt: {} tokens, generating up to {} tokens", prompt_len, gen_tokens);
|
||||
|
||||
// ── TTFT: prefill-only timing over N iterations ───────────────────────
|
||||
let mut ttft_samples_ms: Vec<f64> = vec![];
|
||||
for _ in 0..iters {
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
let mut prev_seq = 1usize;
|
||||
let mut step_times = vec![];
|
||||
for step in 0..prompt_len {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![prompt_tokens[step] as i32]);
|
||||
runtime.set_data(token_ids, vec![prev_seq as i32]);
|
||||
let t = std::time::Instant::now();
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let _ = runtime.get_f32(logits);
|
||||
step_times.push(t.elapsed());
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
prev_seq += 1;
|
||||
}
|
||||
ttft_samples_ms.push(step_times.iter().sum::<Duration>().as_secs_f64() * 1e3);
|
||||
}
|
||||
ttft_samples_ms.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
let ttft_ms = ttft_samples_ms[ttft_samples_ms.len() / 2];
|
||||
|
||||
// ── Text generation: one pass for TPOT + visible output ───────────────
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
let mut prev_seq = 1usize;
|
||||
let mut sentence = vec![prompt_tokens[0]];
|
||||
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
|
||||
let prompt_len = prompt_tokens.len();
|
||||
let mut fwd_durations = vec![];
|
||||
let total_steps = prompt_len - 1 + gen_tokens;
|
||||
let mut decode_step_times: Vec<Duration> = vec![];
|
||||
let mut seen_tokens = FxHashSet::default();
|
||||
let repetition_penalty: f32 = 1.05;
|
||||
|
||||
const EOS_TOKEN: u32 = 1;
|
||||
const STOP_TOKEN: u32 = 107;
|
||||
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
|
||||
for i in 0..total_steps {
|
||||
let start = std::time::Instant::now();
|
||||
let is_prefill = i < prompt_len - 1;
|
||||
let seq_len = sentence.len();
|
||||
|
||||
cx.set_dim('s', seq_len);
|
||||
cx.set_dim('p', prev_seq);
|
||||
|
||||
runtime.set_data(
|
||||
input,
|
||||
sentence.iter().map(|t| *t as i32).collect::<Vec<_>>(),
|
||||
@@ -104,26 +169,26 @@ fn main() {
|
||||
(prev_seq as i32..(seq_len + prev_seq) as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
let step_start = std::time::Instant::now();
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let step_elapsed = step_start.elapsed();
|
||||
|
||||
// Round-trip KV cache
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
|
||||
prev_seq += seq_len;
|
||||
fwd_durations.push(start.elapsed());
|
||||
|
||||
if is_prefill {
|
||||
sentence = vec![prompt_tokens[i + 1]];
|
||||
continue;
|
||||
}
|
||||
|
||||
// Greedy decode with repetition penalty
|
||||
decode_step_times.push(step_elapsed);
|
||||
|
||||
let mut last_row = logits_data[logits_data.len() - VOCAB_SIZE..].to_vec();
|
||||
for &tok in &seen_tokens {
|
||||
let logit = &mut last_row[tok as usize];
|
||||
@@ -152,21 +217,13 @@ fn main() {
|
||||
}
|
||||
println!();
|
||||
|
||||
// Benchmarks
|
||||
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
|
||||
if decode_durations.len() > 2 {
|
||||
println!(
|
||||
" TTFT: {:.2} ms",
|
||||
fwd_durations[..prompt_len]
|
||||
.iter()
|
||||
.sum::<Duration>()
|
||||
.as_secs_f64()
|
||||
* 1e3
|
||||
);
|
||||
// ── Report ────────────────────────────────────────────────────────────
|
||||
println!(" TTFT: {:.2} ms", ttft_ms);
|
||||
if decode_step_times.len() > 1 {
|
||||
println!(
|
||||
" TPOT: {:.2} ms",
|
||||
(decode_durations.iter().skip(1).copied().sum::<Duration>()
|
||||
/ (decode_durations.len() - 1) as u32)
|
||||
(decode_step_times.iter().skip(1).sum::<Duration>()
|
||||
/ (decode_step_times.len() - 1) as u32)
|
||||
.as_secs_f64()
|
||||
* 1_000.
|
||||
);
|
||||
|
||||
@@ -199,7 +199,7 @@ impl Gemma {
|
||||
kv_cache.v_caches[i],
|
||||
kv_cache.max_seq,
|
||||
);
|
||||
x = x_new.graph_break();
|
||||
x = x_new;
|
||||
cache_outputs.push((k_out, v_out));
|
||||
}
|
||||
let logits = self.lm_norm.forward(x).matmul(self.lm_head.t());
|
||||
|
||||
22
examples/gemma4_moe/Cargo.toml
Normal file
22
examples/gemma4_moe/Cargo.toml
Normal file
@@ -0,0 +1,22 @@
|
||||
[package]
|
||||
name = "gemma4_moe"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
tokenizers = "0.22.2"
|
||||
rustc-hash = "2"
|
||||
|
||||
# HuggingFace model download
|
||||
hf-hub = { version = "0.4", default-features = false, features = ["rustls-tls", "ureq"] }
|
||||
safetensors = "0.7.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
half = { version = "2.7.1", features = ["bytemuck"] }
|
||||
bytemuck = "1.24.0"
|
||||
memmap2 = "0.9.9"
|
||||
227
examples/gemma4_moe/src/hf.rs
Normal file
227
examples/gemma4_moe/src/hf.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
use half::{bf16, f16};
|
||||
use hf_hub::api::sync::Api;
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::{tensor::TensorView, Dtype, SafeTensors};
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fs::File,
|
||||
io::Write,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use crate::model::HIDDEN;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct SafetensorsIndex {
|
||||
weight_map: HashMap<String, String>,
|
||||
}
|
||||
|
||||
enum TensorData {
|
||||
F32(Vec<f32>),
|
||||
BF16(Vec<u8>),
|
||||
}
|
||||
|
||||
struct StoredTensor {
|
||||
shape: Vec<usize>,
|
||||
data: TensorData,
|
||||
}
|
||||
|
||||
pub fn download_hf_model(repo_id: &str) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
let api = Api::new()?;
|
||||
let repo = api.model(repo_id.to_string());
|
||||
|
||||
let tokenizer_path = repo.get("tokenizer.json")?;
|
||||
let model_dir = tokenizer_path.parent().unwrap().to_path_buf();
|
||||
|
||||
if repo.get("model.safetensors").is_ok() {
|
||||
return Ok(model_dir);
|
||||
}
|
||||
|
||||
let index_path = repo.get("model.safetensors.index.json")?;
|
||||
let index_content = std::fs::read_to_string(&index_path)?;
|
||||
let index: SafetensorsIndex = serde_json::from_str(&index_content)?;
|
||||
|
||||
let mut shard_files: Vec<String> = index.weight_map.values().cloned().collect();
|
||||
shard_files.sort();
|
||||
shard_files.dedup();
|
||||
|
||||
for shard_file in &shard_files {
|
||||
repo.get(shard_file)?;
|
||||
}
|
||||
|
||||
Ok(model_dir)
|
||||
}
|
||||
|
||||
fn tensor_to_f32(tensor: &safetensors::tensor::TensorView) -> Vec<f32> {
|
||||
match tensor.dtype() {
|
||||
Dtype::F32 => bytemuck::cast_slice::<u8, f32>(tensor.data()).to_vec(),
|
||||
Dtype::F16 => {
|
||||
let f16_slice: &[f16] = bytemuck::cast_slice(tensor.data());
|
||||
f16_slice.iter().map(|x| x.to_f32()).collect()
|
||||
}
|
||||
Dtype::BF16 => {
|
||||
let bf16_slice: &[bf16] = bytemuck::cast_slice(tensor.data());
|
||||
bf16_slice.iter().map(|x| x.to_f32()).collect()
|
||||
}
|
||||
other => panic!("Unsupported dtype for conversion: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn tensor_to_bf16_bytes(tensor: &safetensors::tensor::TensorView) -> Vec<u8> {
|
||||
match tensor.dtype() {
|
||||
Dtype::BF16 => tensor.data().to_vec(),
|
||||
Dtype::F16 => {
|
||||
let f16_slice: &[f16] = bytemuck::cast_slice(tensor.data());
|
||||
let bf16_data: Vec<bf16> = f16_slice
|
||||
.iter()
|
||||
.map(|x| bf16::from_f32(x.to_f32()))
|
||||
.collect();
|
||||
bytemuck::cast_slice(&bf16_data).to_vec()
|
||||
}
|
||||
Dtype::F32 => {
|
||||
let f32_slice: &[f32] = bytemuck::cast_slice(tensor.data());
|
||||
let bf16_data: Vec<bf16> = f32_slice.iter().map(|x| bf16::from_f32(*x)).collect();
|
||||
bytemuck::cast_slice(&bf16_data).to_vec()
|
||||
}
|
||||
other => panic!("Unsupported dtype for conversion: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_text_weight(name: &str) -> bool {
|
||||
name.starts_with("model.language_model.")
|
||||
}
|
||||
|
||||
fn is_expert_weight(name: &str) -> bool {
|
||||
name.contains(".experts.")
|
||||
}
|
||||
|
||||
pub fn combine_safetensors(model_dir: &Path) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
let output_path = model_dir.join("model_combined.safetensors");
|
||||
if output_path.exists() {
|
||||
return Ok(output_path);
|
||||
}
|
||||
|
||||
let index_path = model_dir.join("model.safetensors.index.json");
|
||||
let single_shard_path = model_dir.join("model.safetensors");
|
||||
|
||||
let shard_files: Vec<PathBuf> = if single_shard_path.exists() && !index_path.exists() {
|
||||
println!("Single shard model detected...");
|
||||
vec![single_shard_path]
|
||||
} else if index_path.exists() {
|
||||
let index_content = std::fs::read_to_string(&index_path)?;
|
||||
let index: SafetensorsIndex = serde_json::from_str(&index_content)?;
|
||||
|
||||
let mut files: Vec<String> = index.weight_map.values().cloned().collect();
|
||||
files.sort();
|
||||
files.dedup();
|
||||
|
||||
println!("Loading {} shard files...", files.len());
|
||||
files.into_iter().map(|f| model_dir.join(f)).collect()
|
||||
} else {
|
||||
return Err("No model.safetensors or model.safetensors.index.json found".into());
|
||||
};
|
||||
|
||||
let mut all_tensors: HashMap<String, StoredTensor> = HashMap::new();
|
||||
|
||||
for shard_path in &shard_files {
|
||||
println!(
|
||||
" Loading {}...",
|
||||
shard_path.file_name().unwrap().to_string_lossy()
|
||||
);
|
||||
let file = File::open(shard_path)?;
|
||||
let mmap = unsafe { MmapOptions::new().map(&file)? };
|
||||
let st = SafeTensors::deserialize(&mmap)?;
|
||||
|
||||
for name in st.names() {
|
||||
if !is_text_weight(name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let new_name = name.replacen("model.language_model.", "model.", 1);
|
||||
let tensor = st.tensor(name)?;
|
||||
|
||||
if new_name.ends_with(".layer_scalar") {
|
||||
let scalar = tensor_to_f32(&tensor);
|
||||
let scalar = *scalar.first().expect("layer_scalar tensor is empty");
|
||||
all_tensors.insert(
|
||||
new_name,
|
||||
StoredTensor {
|
||||
shape: vec![HIDDEN],
|
||||
data: TensorData::F32(vec![scalar; HIDDEN]),
|
||||
},
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let shape = tensor.shape().to_vec();
|
||||
let data = if is_expert_weight(&new_name) {
|
||||
TensorData::BF16(tensor_to_bf16_bytes(&tensor))
|
||||
} else {
|
||||
TensorData::F32(tensor_to_f32(&tensor))
|
||||
};
|
||||
|
||||
all_tensors.insert(new_name, StoredTensor { shape, data });
|
||||
}
|
||||
}
|
||||
|
||||
println!("Extracted {} text tensors", all_tensors.len());
|
||||
|
||||
let embed_key = "model.embed_tokens.weight";
|
||||
if let Some(embed_tensor) = all_tensors.get(embed_key) {
|
||||
let (shape, embed_data) = match &embed_tensor.data {
|
||||
TensorData::F32(data) => (embed_tensor.shape.clone(), data.clone()),
|
||||
TensorData::BF16(_) => unreachable!("Embedding weights should stay in F32"),
|
||||
};
|
||||
|
||||
all_tensors.insert(
|
||||
"lm_head.weight".to_string(),
|
||||
StoredTensor {
|
||||
shape,
|
||||
data: TensorData::F32(embed_data.clone()),
|
||||
},
|
||||
);
|
||||
|
||||
let embed_scale = (HIDDEN as f32).sqrt();
|
||||
if let Some(stored) = all_tensors.get_mut(embed_key) {
|
||||
match &mut stored.data {
|
||||
TensorData::F32(data) => {
|
||||
for value in data {
|
||||
*value *= embed_scale;
|
||||
}
|
||||
}
|
||||
TensorData::BF16(_) => unreachable!("Embedding weights should stay in F32"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("Saving combined model (BF16 experts + F32 rest)...");
|
||||
let tensor_views: HashMap<String, TensorView<'_>> = all_tensors
|
||||
.iter()
|
||||
.map(|(name, stored)| {
|
||||
let view = match &stored.data {
|
||||
TensorData::F32(data) => {
|
||||
let bytes: &[u8] = bytemuck::cast_slice(data);
|
||||
TensorView::new(Dtype::F32, stored.shape.clone(), bytes).unwrap()
|
||||
}
|
||||
TensorData::BF16(bytes) => {
|
||||
TensorView::new(Dtype::BF16, stored.shape.clone(), bytes).unwrap()
|
||||
}
|
||||
};
|
||||
(name.clone(), view)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let serialized = safetensors::serialize(&tensor_views, None)?;
|
||||
let mut file = File::create(&output_path)?;
|
||||
file.write_all(&serialized)?;
|
||||
|
||||
println!("Combined model saved successfully!");
|
||||
Ok(output_path)
|
||||
}
|
||||
|
||||
pub fn prepare_hf_model(repo_id: &str) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
let model_dir = download_hf_model(repo_id)?;
|
||||
combine_safetensors(&model_dir)?;
|
||||
Ok(model_dir)
|
||||
}
|
||||
245
examples/gemma4_moe/src/main.rs
Normal file
245
examples/gemma4_moe/src/main.rs
Normal file
@@ -0,0 +1,245 @@
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
use hf::prepare_hf_model;
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const REPO_ID: &str = "google/gemma-4-26B-A4B";
|
||||
|
||||
fn env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
fn env_bool(name: &str) -> bool {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.is_some_and(|s| matches!(s.as_str(), "1" | "true" | "TRUE" | "yes" | "YES"))
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let max_seq_len = env_usize("MAX_SEQ_LEN", 4096);
|
||||
let gen_tokens = env_usize("GEN_TOKENS", 30);
|
||||
let search_graphs = env_usize("SEARCH_GRAPHS", 50);
|
||||
let prompt = std::env::var("PROMPT").unwrap_or_else(|_| "The capital of France is".to_string());
|
||||
let print_token_ids = env_bool("PRINT_TOKEN_IDS");
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
let model_dir = prepare_hf_model(REPO_ID).expect("Failed to prepare model");
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
let pos_ids = cx.named_tensor("pos_ids", 's').as_dtype(DType::Int);
|
||||
let kv_cache = KVCache::new(&mut cx, max_seq_len);
|
||||
let (logits, cache_outputs) = Gemma4MoE::init(&mut cx).forward(input, pos_ids, &kv_cache);
|
||||
let logits = logits.output();
|
||||
for (k_out, v_out) in &cache_outputs {
|
||||
k_out.output();
|
||||
v_out.output();
|
||||
}
|
||||
|
||||
println!("Building E-Graph...");
|
||||
let compile_start = std::time::Instant::now();
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
println!("Loading weights...");
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
let weights_path = model_dir.join("model_combined.safetensors");
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', 1);
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(pos_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
println!(" COMPILE: {:.2} ms", compile_start.elapsed().as_secs_f64() * 1e3);
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
|
||||
// Full-prompt warmup: run the complete prompt to bring the GPU to steady state before timing
|
||||
for (w_pos, &w_token) in prompt_tokens.iter().enumerate() {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', w_pos);
|
||||
runtime.set_data(input, vec![w_token as i32]);
|
||||
runtime.set_data(pos_ids, vec![w_pos as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
}
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
|
||||
let iters = env_usize("ITERS", 3);
|
||||
let prompt_len = prompt_tokens.len();
|
||||
|
||||
// ── TTFT: prefill-only timing over N iterations ───────────────────────
|
||||
let mut ttft_samples_ms: Vec<f64> = vec![];
|
||||
for _ in 0..iters {
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
let prefill_start = std::time::Instant::now();
|
||||
let mut prev_seq = 0usize;
|
||||
for &token in &prompt_tokens {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![token as i32]);
|
||||
runtime.set_data(pos_ids, vec![prev_seq as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
prev_seq += 1;
|
||||
}
|
||||
ttft_samples_ms.push(prefill_start.elapsed().as_secs_f64() * 1e3);
|
||||
}
|
||||
ttft_samples_ms.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
let ttft_ms = ttft_samples_ms[ttft_samples_ms.len() / 2];
|
||||
|
||||
// ── Text generation: one pass for TPOT + visible output ───────────────
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
|
||||
let mut prev_seq = 0usize;
|
||||
let mut fwd_durations = vec![];
|
||||
let mut seen_tokens = FxHashSet::default();
|
||||
let mut generated_token_ids = vec![];
|
||||
let repetition_penalty: f32 = 1.05;
|
||||
|
||||
const EOS_TOKEN: u32 = 1;
|
||||
|
||||
for &token in &prompt_tokens {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![token as i32]);
|
||||
runtime.set_data(pos_ids, vec![prev_seq as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
prev_seq += 1;
|
||||
}
|
||||
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let last_row = &logits_data[..VOCAB_SIZE];
|
||||
let mut next_token = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
generated_token_ids.push(next_token);
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
for _ in 1..gen_tokens {
|
||||
let start = std::time::Instant::now();
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![next_token as i32]);
|
||||
runtime.set_data(pos_ids, vec![prev_seq as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
prev_seq += 1;
|
||||
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let mut last_row = logits_data[..VOCAB_SIZE].to_vec();
|
||||
for &tok in &seen_tokens {
|
||||
let logit = &mut last_row[tok as usize];
|
||||
if *logit > 0.0 {
|
||||
*logit /= repetition_penalty;
|
||||
} else {
|
||||
*logit *= repetition_penalty;
|
||||
}
|
||||
}
|
||||
next_token = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
generated_token_ids.push(next_token);
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
if next_token == EOS_TOKEN {
|
||||
break;
|
||||
}
|
||||
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
fwd_durations.push(start.elapsed());
|
||||
}
|
||||
println!();
|
||||
if print_token_ids {
|
||||
println!("Generated token ids: {generated_token_ids:?}");
|
||||
}
|
||||
|
||||
// ── Report ────────────────────────────────────────────────────────────
|
||||
println!(
|
||||
" TTFT: {:.2} ms ({} prompt tokens)",
|
||||
ttft_ms, prompt_len
|
||||
);
|
||||
if fwd_durations.len() > 1 {
|
||||
println!(
|
||||
" TPOT: {:.2} ms",
|
||||
(fwd_durations.iter().skip(1).sum::<Duration>() / (fwd_durations.len() - 1) as u32)
|
||||
.as_secs_f64()
|
||||
* 1_000.
|
||||
);
|
||||
}
|
||||
}
|
||||
621
examples/gemma4_moe/src/model.rs
Normal file
621
examples/gemma4_moe/src/model.rs
Normal file
@@ -0,0 +1,621 @@
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
graph::Graph,
|
||||
prelude::{F32Pow, GraphTensor},
|
||||
shape::Expression,
|
||||
};
|
||||
use luminal_nn::LayerNorm;
|
||||
|
||||
pub const LAYERS: usize = 30;
|
||||
pub const HIDDEN: usize = 2816;
|
||||
pub const INTERMEDIATE: usize = 2112;
|
||||
pub const MOE_INTERMEDIATE: usize = 704;
|
||||
pub const NUM_EXPERTS: usize = 128;
|
||||
pub const TOP_K: usize = 8;
|
||||
pub const N_HEADS: usize = 16;
|
||||
pub const SLIDING_HEAD_DIM: usize = 256;
|
||||
pub const FULL_HEAD_DIM: usize = 512;
|
||||
pub const SLIDING_KV_HEADS: usize = 8;
|
||||
pub const FULL_KV_HEADS: usize = 2;
|
||||
pub const VOCAB_SIZE: usize = 262144;
|
||||
pub const RMS_NORM_EPS: f32 = 1e-6;
|
||||
pub const SLIDING_WINDOW_SIZE: usize = 1024;
|
||||
pub const SLIDING_ROPE_THETA: f32 = 10_000.0;
|
||||
pub const FULL_ROPE_THETA: f32 = 1_000_000.0;
|
||||
pub const FULL_PARTIAL_ROTARY_FACTOR: f32 = 0.25;
|
||||
pub const FINAL_LOGIT_SOFTCAP: f32 = 30.0;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct LayerSpec {
|
||||
is_sliding: bool,
|
||||
head_dim: usize,
|
||||
q_dim: usize,
|
||||
num_kv_heads: usize,
|
||||
kv_dim: usize,
|
||||
kv_groups: usize,
|
||||
rope_theta: f32,
|
||||
partial_rotary_factor: f32,
|
||||
has_v_proj: bool,
|
||||
}
|
||||
|
||||
fn layer_spec(layer: usize) -> LayerSpec {
|
||||
if !(layer + 1).is_multiple_of(6) {
|
||||
LayerSpec {
|
||||
is_sliding: true,
|
||||
head_dim: SLIDING_HEAD_DIM,
|
||||
q_dim: N_HEADS * SLIDING_HEAD_DIM,
|
||||
num_kv_heads: SLIDING_KV_HEADS,
|
||||
kv_dim: SLIDING_KV_HEADS * SLIDING_HEAD_DIM,
|
||||
kv_groups: N_HEADS / SLIDING_KV_HEADS,
|
||||
rope_theta: SLIDING_ROPE_THETA,
|
||||
partial_rotary_factor: 1.0,
|
||||
has_v_proj: true,
|
||||
}
|
||||
} else {
|
||||
LayerSpec {
|
||||
is_sliding: false,
|
||||
head_dim: FULL_HEAD_DIM,
|
||||
q_dim: N_HEADS * FULL_HEAD_DIM,
|
||||
num_kv_heads: FULL_KV_HEADS,
|
||||
kv_dim: FULL_KV_HEADS * FULL_HEAD_DIM,
|
||||
kv_groups: N_HEADS / FULL_KV_HEADS,
|
||||
rope_theta: FULL_ROPE_THETA,
|
||||
partial_rotary_factor: FULL_PARTIAL_ROTARY_FACTOR,
|
||||
has_v_proj: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cache_bytes_for_layer(layer: usize, max_seq: usize) -> usize {
|
||||
let spec = layer_spec(layer);
|
||||
spec.num_kv_heads * max_seq * spec.head_dim * std::mem::size_of::<f32>()
|
||||
}
|
||||
|
||||
pub struct KVCache {
|
||||
pub k_caches: Vec<GraphTensor>,
|
||||
pub v_caches: Vec<GraphTensor>,
|
||||
pub max_seq: usize,
|
||||
}
|
||||
|
||||
impl KVCache {
|
||||
pub fn new(cx: &mut Graph, max_seq: usize) -> Self {
|
||||
let mut k_caches = Vec::with_capacity(LAYERS);
|
||||
let mut v_caches = Vec::with_capacity(LAYERS);
|
||||
for layer in 0..LAYERS {
|
||||
let spec = layer_spec(layer);
|
||||
let k = cx
|
||||
.named_tensor(
|
||||
format!("kv_cache.{layer}.k"),
|
||||
(spec.num_kv_heads, max_seq, spec.head_dim),
|
||||
)
|
||||
.persist();
|
||||
let v = cx
|
||||
.named_tensor(
|
||||
format!("kv_cache.{layer}.v"),
|
||||
(spec.num_kv_heads, max_seq, spec.head_dim),
|
||||
)
|
||||
.persist();
|
||||
k_caches.push(k);
|
||||
v_caches.push(v);
|
||||
}
|
||||
Self {
|
||||
k_caches,
|
||||
v_caches,
|
||||
max_seq,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Gemma4MoE {
|
||||
embedding: GraphTensor,
|
||||
lm_head: GraphTensor,
|
||||
layers: Vec<Gemma4Layer>,
|
||||
lm_norm: LayerNorm,
|
||||
}
|
||||
|
||||
impl Gemma4MoE {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
let mut layers = Vec::with_capacity(LAYERS);
|
||||
for layer in 0..LAYERS {
|
||||
let spec = layer_spec(layer);
|
||||
let gate = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.mlp.gate_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let up = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.mlp.up_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let down = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.mlp.down_proj.weight"),
|
||||
(HIDDEN, INTERMEDIATE),
|
||||
)
|
||||
.persist();
|
||||
|
||||
let q_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.q_proj.weight"),
|
||||
(spec.q_dim, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let k_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.k_proj.weight"),
|
||||
(spec.kv_dim, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let v_proj = spec.has_v_proj.then(|| {
|
||||
cx.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.v_proj.weight"),
|
||||
(spec.kv_dim, HIDDEN),
|
||||
)
|
||||
.persist()
|
||||
});
|
||||
let o_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, spec.q_dim),
|
||||
)
|
||||
.persist();
|
||||
let q_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.q_norm.weight"),
|
||||
spec.head_dim,
|
||||
)
|
||||
.persist();
|
||||
let k_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.k_norm.weight"),
|
||||
spec.head_dim,
|
||||
)
|
||||
.persist();
|
||||
let layer_scalar = cx
|
||||
.named_tensor(format!("model.layers.{layer}.layer_scalar"), HIDDEN)
|
||||
.persist();
|
||||
|
||||
let router_scale = cx
|
||||
.named_tensor(format!("model.layers.{layer}.router.scale"), HIDDEN)
|
||||
.persist();
|
||||
let router_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.router.proj.weight"),
|
||||
(NUM_EXPERTS, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let per_expert_scale = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.router.per_expert_scale"),
|
||||
NUM_EXPERTS,
|
||||
)
|
||||
.persist();
|
||||
let gate_up_weights = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.experts.gate_up_proj"),
|
||||
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
|
||||
)
|
||||
.persist()
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.experts.down_proj"),
|
||||
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
|
||||
)
|
||||
.persist()
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
layers.push(Gemma4Layer {
|
||||
spec,
|
||||
gate,
|
||||
up,
|
||||
down,
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
q_norm,
|
||||
k_norm,
|
||||
layer_scalar,
|
||||
input_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.input_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_attention_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_attention_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
pre_feedforward_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.pre_feedforward_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_feedforward_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm_1: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_feedforward_layernorm_1.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm_2: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_feedforward_layernorm_2.weight"),
|
||||
cx,
|
||||
),
|
||||
pre_feedforward_layernorm_2: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.pre_feedforward_layernorm_2.weight"),
|
||||
cx,
|
||||
),
|
||||
moe: Gemma4SparseMoE {
|
||||
router_scale,
|
||||
router_proj,
|
||||
per_expert_scale,
|
||||
gate_up_weights,
|
||||
down_weights,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
let embedding = cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
let lm_head = cx
|
||||
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
let lm_norm = gemma4_norm(HIDDEN, "model.norm.weight", cx);
|
||||
|
||||
Self {
|
||||
embedding,
|
||||
lm_head,
|
||||
layers,
|
||||
lm_norm,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
token_ids: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = token_ids.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
|
||||
let mut cache_outputs = Vec::with_capacity(LAYERS);
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
x,
|
||||
pos_ids,
|
||||
kv_cache.k_caches[layer_idx],
|
||||
kv_cache.v_caches[layer_idx],
|
||||
kv_cache.max_seq,
|
||||
);
|
||||
x = x_new;
|
||||
cache_outputs.push((k_out, v_out));
|
||||
}
|
||||
|
||||
let logits = self.lm_norm.forward(x).matmul(self.lm_head.t());
|
||||
let logits = (logits / FINAL_LOGIT_SOFTCAP).tanh() * FINAL_LOGIT_SOFTCAP;
|
||||
(logits, cache_outputs)
|
||||
}
|
||||
}
|
||||
|
||||
struct Gemma4Layer {
|
||||
spec: LayerSpec,
|
||||
gate: GraphTensor,
|
||||
up: GraphTensor,
|
||||
down: GraphTensor,
|
||||
q_proj: GraphTensor,
|
||||
k_proj: GraphTensor,
|
||||
v_proj: Option<GraphTensor>,
|
||||
o_proj: GraphTensor,
|
||||
q_norm: GraphTensor,
|
||||
k_norm: GraphTensor,
|
||||
layer_scalar: GraphTensor,
|
||||
input_layernorm: LayerNorm,
|
||||
post_attention_layernorm: LayerNorm,
|
||||
pre_feedforward_layernorm: LayerNorm,
|
||||
post_feedforward_layernorm: LayerNorm,
|
||||
post_feedforward_layernorm_1: LayerNorm,
|
||||
post_feedforward_layernorm_2: LayerNorm,
|
||||
pre_feedforward_layernorm_2: LayerNorm,
|
||||
moe: Gemma4SparseMoE,
|
||||
}
|
||||
|
||||
struct Gemma4SparseMoE {
|
||||
router_scale: GraphTensor,
|
||||
router_proj: GraphTensor,
|
||||
per_expert_scale: GraphTensor,
|
||||
gate_up_weights: GraphTensor,
|
||||
down_weights: GraphTensor,
|
||||
}
|
||||
|
||||
fn gemma4_norm(dim: usize, weight_name: &str, cx: &mut Graph) -> LayerNorm {
|
||||
LayerNorm::new(dim, Some(weight_name), None, false, RMS_NORM_EPS, cx)
|
||||
}
|
||||
|
||||
#[allow(clippy::excessive_precision)]
|
||||
fn gemma_gelu(x: GraphTensor) -> GraphTensor {
|
||||
let scaled = 1.5957691216 * x * (1. + 0.044715 * x * x);
|
||||
x * scaled.sigmoid()
|
||||
}
|
||||
|
||||
fn qk_norm(x: GraphTensor, weight: GraphTensor, n_heads: usize, head_dim: usize) -> GraphTensor {
|
||||
let seq = x.dims()[0];
|
||||
let reshaped = x.split_dims(1, head_dim);
|
||||
let normed = reshaped.std_norm(2, RMS_NORM_EPS);
|
||||
let w = weight.expand_dim(0, n_heads).expand_dim(0, seq);
|
||||
(normed * w).merge_dims(1, 2)
|
||||
}
|
||||
|
||||
fn value_norm(x: GraphTensor, head_dim: usize) -> GraphTensor {
|
||||
x.split_dims(1, head_dim)
|
||||
.std_norm(2, RMS_NORM_EPS)
|
||||
.merge_dims(1, 2)
|
||||
}
|
||||
|
||||
fn gemma4_rotary_embeddings(
|
||||
input: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
n_heads: usize,
|
||||
head_dim: usize,
|
||||
rope_theta: f32,
|
||||
partial_rotary_factor: f32,
|
||||
) -> GraphTensor {
|
||||
let input = input.split_dims(1, head_dim).transpose(0, 1);
|
||||
let half_dim = head_dim / 2;
|
||||
let rope_angles = ((partial_rotary_factor * head_dim as f32) / 2.0).floor() as usize;
|
||||
|
||||
let rotated = input
|
||||
.graph()
|
||||
.arange_options(0, rope_angles * 2, 2)
|
||||
.cast(DType::F32)
|
||||
/ head_dim as f32;
|
||||
let rotated = rope_theta.pow(rotated).reciprocal();
|
||||
let inv_freqs = if rope_angles < half_dim {
|
||||
let zeros = input
|
||||
.graph()
|
||||
.arange(half_dim - rope_angles)
|
||||
.cast(DType::F32)
|
||||
* 0.0;
|
||||
rotated.concat_along(zeros, 0)
|
||||
} else {
|
||||
rotated
|
||||
};
|
||||
|
||||
let emb = pos_ids
|
||||
.cast(DType::F32)
|
||||
.expand_dim(1, 1)
|
||||
.matmul(inv_freqs.expand_dim(0, 1));
|
||||
|
||||
let x0 = input.slice((.., .., ..half_dim));
|
||||
let x1 = input.slice((.., .., half_dim..));
|
||||
|
||||
let cos = emb.cos().expand_dim(0, n_heads);
|
||||
let sin = emb.sin().expand_dim(0, n_heads);
|
||||
let x0_out = x0 * cos - x1 * sin;
|
||||
let x1_out = x1 * cos + x0 * sin;
|
||||
|
||||
x0_out
|
||||
.concat_along(x1_out, 2)
|
||||
.transpose(0, 1)
|
||||
.merge_dims(1, 2)
|
||||
}
|
||||
|
||||
fn gather_experts(
|
||||
graph_source: GraphTensor,
|
||||
top_k_indices: GraphTensor,
|
||||
weights: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let (_, d1, d2) = weights.dims3();
|
||||
let io = d1 * d2;
|
||||
let base = top_k_indices * io;
|
||||
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
|
||||
let n_base = base.dims().len();
|
||||
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
|
||||
let mut exp_within = within;
|
||||
for (axis, dim) in base.dims().iter().enumerate() {
|
||||
exp_within = exp_within.expand_dim(axis, *dim);
|
||||
}
|
||||
let expert_flat_idx = exp_base + exp_within;
|
||||
weights.gather(expert_flat_idx)
|
||||
}
|
||||
|
||||
fn hlir_attention(
|
||||
q_rope: GraphTensor,
|
||||
k_rope: GraphTensor,
|
||||
v: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
spec: LayerSpec,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let cx = q_rope.graph();
|
||||
let seq = q_rope.dims()[0];
|
||||
let prev = Expression::from('p');
|
||||
let total_seq = prev + seq;
|
||||
|
||||
let k_new = k_rope.split_dims(1, spec.head_dim).transpose(0, 1);
|
||||
let v_new = v.split_dims(1, spec.head_dim).transpose(0, 1);
|
||||
|
||||
let h_offset = cx.arange(spec.num_kv_heads) * (max_seq * spec.head_dim);
|
||||
let p_offset = (cx.arange(seq) + prev) * spec.head_dim;
|
||||
let d_offset = cx.arange(spec.head_dim);
|
||||
let scatter_idx = h_offset.expand_dim(1, seq).expand_dim(2, spec.head_dim)
|
||||
+ p_offset
|
||||
.expand_dim(0, spec.num_kv_heads)
|
||||
.expand_dim(2, spec.head_dim)
|
||||
+ d_offset.expand_dim(0, spec.num_kv_heads).expand_dim(1, seq);
|
||||
|
||||
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
|
||||
let k_3d = k_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
|
||||
let v_3d = v_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
|
||||
let q = q_rope.split_dims(1, spec.head_dim).transpose(0, 1);
|
||||
|
||||
// Gemma 4's text attention uses Q/K normalization and then leaves the
|
||||
// attention scaling at 1.0 in the reference implementation.
|
||||
let scores = q.matmul(k_3d.transpose(1, 2));
|
||||
|
||||
let q_abs = cx.arange(seq).cast(DType::F32) + prev;
|
||||
let k_pos = cx.arange(total_seq).cast(DType::F32);
|
||||
let future_mask = k_pos
|
||||
.expand_dim(0, seq)
|
||||
.gt(q_abs.expand_dim(1, total_seq))
|
||||
.cast(DType::F32);
|
||||
|
||||
let mask_2d = if spec.is_sliding {
|
||||
let window_start = q_abs - (SLIDING_WINDOW_SIZE - 1) as f32;
|
||||
let past_mask = window_start
|
||||
.expand_dim(1, total_seq)
|
||||
.gt(k_pos.expand_dim(0, seq))
|
||||
.cast(DType::F32);
|
||||
future_mask + past_mask
|
||||
} else {
|
||||
future_mask
|
||||
};
|
||||
let mask_3d = mask_2d.expand_dim(0, N_HEADS);
|
||||
let masked_scores = scores + mask_3d * (-1e10f32);
|
||||
|
||||
let attn_weights = masked_scores.softmax(2);
|
||||
let attn_out = attn_weights.matmul(v_3d);
|
||||
let out = attn_out.transpose(0, 1).merge_dims(1, 2);
|
||||
|
||||
(out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl Gemma4Layer {
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let residual = x;
|
||||
let x_attn = self.input_layernorm.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k_base = x_attn.matmul(self.k_proj.t());
|
||||
let v_base = if let Some(v_proj) = self.v_proj {
|
||||
x_attn.matmul(v_proj.t())
|
||||
} else {
|
||||
k_base
|
||||
};
|
||||
|
||||
let q_normed = qk_norm(q, self.q_norm, N_HEADS, self.spec.head_dim);
|
||||
let k_normed = qk_norm(
|
||||
k_base,
|
||||
self.k_norm,
|
||||
self.spec.num_kv_heads,
|
||||
self.spec.head_dim,
|
||||
);
|
||||
let v_normed = value_norm(v_base, self.spec.head_dim);
|
||||
|
||||
let q_rope = gemma4_rotary_embeddings(
|
||||
q_normed,
|
||||
pos_ids,
|
||||
N_HEADS,
|
||||
self.spec.head_dim,
|
||||
self.spec.rope_theta,
|
||||
self.spec.partial_rotary_factor,
|
||||
);
|
||||
let k_rope = gemma4_rotary_embeddings(
|
||||
k_normed,
|
||||
pos_ids,
|
||||
self.spec.num_kv_heads,
|
||||
self.spec.head_dim,
|
||||
self.spec.rope_theta,
|
||||
self.spec.partial_rotary_factor,
|
||||
);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
|
||||
q_rope, k_rope, v_normed, k_cache_in, v_cache_in, max_seq, self.spec,
|
||||
);
|
||||
|
||||
let attn_proj = attn_out.matmul(self.o_proj.t());
|
||||
let x = residual + self.post_attention_layernorm.forward(attn_proj);
|
||||
|
||||
let dense_ff = dense_ffn(
|
||||
self.pre_feedforward_layernorm.forward(x),
|
||||
self.gate,
|
||||
self.up,
|
||||
self.down,
|
||||
);
|
||||
let dense_ff = self.post_feedforward_layernorm_1.forward(dense_ff);
|
||||
|
||||
let moe_out = self
|
||||
.moe
|
||||
.forward(x, self.pre_feedforward_layernorm_2.forward(x));
|
||||
let moe_out = self.post_feedforward_layernorm_2.forward(moe_out);
|
||||
|
||||
let ff_out = self.post_feedforward_layernorm.forward(dense_ff + moe_out);
|
||||
let x = x + ff_out;
|
||||
let x = x * self
|
||||
.layer_scalar
|
||||
.expand_lhs(&x.dims()[..x.dims().len() - 1]);
|
||||
|
||||
(x, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
fn dense_ffn(x: GraphTensor, gate: GraphTensor, up: GraphTensor, down: GraphTensor) -> GraphTensor {
|
||||
(gemma_gelu(x.matmul(gate.t())) * x.matmul(up.t())).matmul(down.t())
|
||||
}
|
||||
|
||||
impl Gemma4SparseMoE {
|
||||
fn forward(&self, router_input: GraphTensor, expert_input: GraphTensor) -> GraphTensor {
|
||||
let n = router_input.dims().len();
|
||||
let e_dim = *self.router_proj.dims().first().unwrap();
|
||||
let k_expr = Expression::from(TOP_K);
|
||||
|
||||
let router_hidden = router_input.std_norm(router_input.dims().len() - 1, RMS_NORM_EPS)
|
||||
* self
|
||||
.router_scale
|
||||
.expand_lhs(&router_input.dims()[..router_input.dims().len() - 1])
|
||||
* (HIDDEN as f32).sqrt().recip();
|
||||
let routing_weights = router_hidden.matmul(self.router_proj.t()).softmax(n - 1);
|
||||
|
||||
let top_k_indices = routing_weights.topk_indexes(TOP_K, n - 1);
|
||||
let row_offsets = router_input
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
let top_k_norm = top_k_values.sum(n - 1).expand_dim(n - 1, TOP_K);
|
||||
let top_k_weights =
|
||||
(top_k_values / top_k_norm) * self.per_expert_scale.gather(top_k_indices);
|
||||
|
||||
let gate_up_gathered =
|
||||
gather_experts(expert_input, top_k_indices, self.gate_up_weights).cast(DType::F32);
|
||||
let x_exp = expert_input.expand_dim(n - 1, TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gemma_gelu(gate) * up;
|
||||
|
||||
let down_gathered =
|
||||
gather_experts(expert_input, top_k_indices, self.down_weights).cast(DType::F32);
|
||||
let hidden_exp = hidden.unsqueeze(2);
|
||||
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2);
|
||||
|
||||
(down_out * top_k_weights.unsqueeze(top_k_weights.dims().len())).sum(n - 1)
|
||||
}
|
||||
}
|
||||
@@ -13,11 +13,21 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
const REPO_ID: &str = "NousResearch/Meta-Llama-3-8B-Instruct";
|
||||
|
||||
// Default configuration — override at runtime via env vars.
|
||||
const DEFAULT_MAX_SEQ_LEN: usize = 4096;
|
||||
const DEFAULT_SEARCH_GRAPHS: usize = 500;
|
||||
const DEFAULT_GEN_TOKENS: usize = 500;
|
||||
|
||||
fn env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name).ok().and_then(|s| s.parse().ok()).unwrap_or(default)
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 500;
|
||||
let search_graphs = 500;
|
||||
let prompt = "Explain what a neural network is in a paragraph.";
|
||||
let max_seq_len = env_usize("MAX_SEQ_LEN", DEFAULT_MAX_SEQ_LEN);
|
||||
let gen_tokens = env_usize("GEN_TOKENS", DEFAULT_GEN_TOKENS);
|
||||
let search_graphs = env_usize("SEARCH_GRAPHS", DEFAULT_SEARCH_GRAPHS);
|
||||
let prompt = std::env::var("PROMPT")
|
||||
.unwrap_or_else(|_| "Explain what a neural network is in a paragraph.".to_string());
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
@@ -53,6 +63,7 @@ fn main() {
|
||||
}
|
||||
|
||||
println!("Building E-Graph...");
|
||||
let compile_start = std::time::Instant::now();
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
println!("Loading weights...");
|
||||
@@ -72,36 +83,90 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(token_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
println!(" COMPILE: {:.2} ms", compile_start.elapsed().as_secs_f64() * 1e3);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
// Full-prompt warmup: run the complete prompt to bring the GPU to steady state before timing
|
||||
for (w_step, &w_token) in prompt_tokens.iter().enumerate() {
|
||||
let p = w_step + 1;
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', p);
|
||||
runtime.set_data(input, vec![w_token as i32]);
|
||||
runtime.set_data(token_ids, vec![p as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
}
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
let iters = env_usize("ITERS", 3);
|
||||
let prompt_len = prompt_tokens.len();
|
||||
|
||||
println!("Prompt: {} tokens, generating up to {} tokens", prompt_len, gen_tokens);
|
||||
|
||||
// ── TTFT: prefill-only timing over N iterations ───────────────────────
|
||||
let mut ttft_samples_ms: Vec<f64> = vec![];
|
||||
for _ in 0..iters {
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
let mut prev_seq = 1usize;
|
||||
let mut step_times = vec![];
|
||||
for step in 0..prompt_len {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![prompt_tokens[step] as i32]);
|
||||
runtime.set_data(token_ids, vec![prev_seq as i32]);
|
||||
let t = std::time::Instant::now();
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let _ = runtime.get_f32(logits);
|
||||
step_times.push(t.elapsed());
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
prev_seq += 1;
|
||||
}
|
||||
ttft_samples_ms.push(step_times.iter().sum::<Duration>().as_secs_f64() * 1e3);
|
||||
}
|
||||
ttft_samples_ms.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
let ttft_ms = ttft_samples_ms[ttft_samples_ms.len() / 2];
|
||||
|
||||
// ── Text generation: one pass for TPOT + visible output ───────────────
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
let mut prev_seq = 1usize;
|
||||
let mut sentence = vec![prompt_tokens[0]];
|
||||
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
|
||||
let prompt_len = prompt_tokens.len();
|
||||
let mut fwd_durations = vec![];
|
||||
let total_steps = prompt_len - 1 + gen_tokens;
|
||||
let mut decode_step_times: Vec<Duration> = vec![];
|
||||
let mut seen_tokens = FxHashSet::default();
|
||||
let repetition_penalty: f32 = 1.05;
|
||||
|
||||
const EOS_TOKEN: u32 = 128009;
|
||||
const STOP_TOKEN: u32 = 128001;
|
||||
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
|
||||
for i in 0..total_steps {
|
||||
let start = std::time::Instant::now();
|
||||
let is_prefill = i < prompt_len - 1;
|
||||
let seq_len = sentence.len();
|
||||
|
||||
cx.set_dim('s', seq_len);
|
||||
cx.set_dim('p', prev_seq);
|
||||
|
||||
runtime.set_data(
|
||||
input,
|
||||
sentence.iter().map(|t| *t as i32).collect::<Vec<_>>(),
|
||||
@@ -111,26 +176,26 @@ fn main() {
|
||||
(prev_seq as i32..(seq_len + prev_seq) as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
let step_start = std::time::Instant::now();
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let step_elapsed = step_start.elapsed();
|
||||
|
||||
// Round-trip KV cache
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
|
||||
prev_seq += seq_len;
|
||||
fwd_durations.push(start.elapsed());
|
||||
|
||||
if is_prefill {
|
||||
sentence = vec![prompt_tokens[i + 1]];
|
||||
continue;
|
||||
}
|
||||
|
||||
// Greedy decode with repetition penalty
|
||||
decode_step_times.push(step_elapsed);
|
||||
|
||||
let mut last_row = logits_data[logits_data.len() - VOCAB_SIZE..].to_vec();
|
||||
for &tok in &seen_tokens {
|
||||
let logit = &mut last_row[tok as usize];
|
||||
@@ -159,21 +224,13 @@ fn main() {
|
||||
}
|
||||
println!();
|
||||
|
||||
// Benchmarks
|
||||
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
|
||||
if decode_durations.len() > 2 {
|
||||
println!(
|
||||
" TTFT: {:.2} ms",
|
||||
fwd_durations[..prompt_len]
|
||||
.iter()
|
||||
.sum::<Duration>()
|
||||
.as_secs_f64()
|
||||
* 1e3
|
||||
);
|
||||
// ── Report ────────────────────────────────────────────────────────────
|
||||
println!(" TTFT: {:.2} ms", ttft_ms);
|
||||
if decode_step_times.len() > 1 {
|
||||
println!(
|
||||
" TPOT: {:.2} ms",
|
||||
(decode_durations.iter().skip(1).copied().sum::<Duration>()
|
||||
/ (decode_durations.len() - 1) as u32)
|
||||
(decode_step_times.iter().skip(1).sum::<Duration>()
|
||||
/ (decode_step_times.len() - 1) as u32)
|
||||
.as_secs_f64()
|
||||
* 1_000.
|
||||
);
|
||||
|
||||
@@ -159,7 +159,8 @@ impl Llama {
|
||||
kv_cache.v_caches[i],
|
||||
kv_cache.max_seq,
|
||||
);
|
||||
x = x_new.graph_break();
|
||||
x = x_new;
|
||||
//x = x_new.graph_break();
|
||||
cache_outputs.push((k_out, v_out));
|
||||
}
|
||||
let logits = self.lm_norm.forward(x).matmul(self.lm_head.t());
|
||||
|
||||
@@ -157,7 +157,7 @@ impl Llama {
|
||||
kv_cache.k_caches[i],
|
||||
kv_cache.v_caches[i],
|
||||
);
|
||||
x = x_new.graph_break();
|
||||
x = x_new;
|
||||
cache_outputs.push((k_out, v_out));
|
||||
}
|
||||
let logits = self.lm_norm.forward(x).matmul(self.lm_head.t());
|
||||
|
||||
@@ -13,11 +13,21 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
const REPO_ID: &str = "Qwen/Qwen3-4B";
|
||||
|
||||
// Default configuration — override at runtime via env vars.
|
||||
const DEFAULT_MAX_SEQ_LEN: usize = 4096;
|
||||
const DEFAULT_SEARCH_GRAPHS: usize = 50;
|
||||
const DEFAULT_GEN_TOKENS: usize = 500;
|
||||
|
||||
fn env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name).ok().and_then(|s| s.parse().ok()).unwrap_or(default)
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 500;
|
||||
let search_graphs = 500;
|
||||
let prompt = "Explain what a neural network is in a paragraph.";
|
||||
let max_seq_len = env_usize("MAX_SEQ_LEN", DEFAULT_MAX_SEQ_LEN);
|
||||
let gen_tokens = env_usize("GEN_TOKENS", DEFAULT_GEN_TOKENS);
|
||||
let search_graphs = env_usize("SEARCH_GRAPHS", DEFAULT_SEARCH_GRAPHS);
|
||||
let prompt = std::env::var("PROMPT")
|
||||
.unwrap_or_else(|_| "Explain what a neural network is in a paragraph.".to_string());
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
@@ -46,6 +56,7 @@ fn main() {
|
||||
}
|
||||
|
||||
println!("Building E-Graph...");
|
||||
let compile_start = std::time::Instant::now();
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
println!("Loading weights...");
|
||||
@@ -65,36 +76,90 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(token_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
println!(" COMPILE: {:.2} ms", compile_start.elapsed().as_secs_f64() * 1e3);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
// Full-prompt warmup: run the complete prompt to bring the GPU to steady state before timing
|
||||
for (w_step, &w_token) in prompt_tokens.iter().enumerate() {
|
||||
let p = w_step + 1;
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', p);
|
||||
runtime.set_data(input, vec![w_token as i32]);
|
||||
runtime.set_data(token_ids, vec![p as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
}
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
let iters = env_usize("ITERS", 3);
|
||||
let prompt_len = prompt_tokens.len();
|
||||
|
||||
println!("Prompt: {} tokens, generating up to {} tokens", prompt_len, gen_tokens);
|
||||
|
||||
// ── TTFT: prefill-only timing over N iterations ───────────────────────
|
||||
let mut ttft_samples_ms: Vec<f64> = vec![];
|
||||
for _ in 0..iters {
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
let mut prev_seq = 1usize;
|
||||
let mut step_times = vec![];
|
||||
for step in 0..prompt_len {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![prompt_tokens[step] as i32]);
|
||||
runtime.set_data(token_ids, vec![prev_seq as i32]);
|
||||
let t = std::time::Instant::now();
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let _ = runtime.get_f32(logits);
|
||||
step_times.push(t.elapsed());
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
prev_seq += 1;
|
||||
}
|
||||
ttft_samples_ms.push(step_times.iter().sum::<Duration>().as_secs_f64() * 1e3);
|
||||
}
|
||||
ttft_samples_ms.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
let ttft_ms = ttft_samples_ms[ttft_samples_ms.len() / 2];
|
||||
|
||||
// ── Text generation: one pass for TPOT + visible output ───────────────
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
let mut prev_seq = 1usize;
|
||||
let mut sentence = vec![prompt_tokens[0]];
|
||||
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
|
||||
let prompt_len = prompt_tokens.len();
|
||||
let mut fwd_durations = vec![];
|
||||
let total_steps = prompt_len - 1 + gen_tokens;
|
||||
let mut decode_step_times: Vec<Duration> = vec![];
|
||||
let mut seen_tokens = FxHashSet::default();
|
||||
let repetition_penalty: f32 = 1.05;
|
||||
|
||||
const EOS_TOKEN: u32 = 151645; // <|endoftext|>
|
||||
const STOP_TOKEN: u32 = 151643; // <|end|>
|
||||
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
|
||||
for i in 0..total_steps {
|
||||
let start = std::time::Instant::now();
|
||||
let is_prefill = i < prompt_len - 1;
|
||||
let seq_len = sentence.len();
|
||||
|
||||
cx.set_dim('s', seq_len);
|
||||
cx.set_dim('p', prev_seq);
|
||||
|
||||
runtime.set_data(
|
||||
input,
|
||||
sentence.iter().map(|t| *t as i32).collect::<Vec<_>>(),
|
||||
@@ -104,26 +169,26 @@ fn main() {
|
||||
(prev_seq as i32..(seq_len + prev_seq) as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
let step_start = std::time::Instant::now();
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let step_elapsed = step_start.elapsed();
|
||||
|
||||
// Round-trip KV cache
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
|
||||
prev_seq += seq_len;
|
||||
fwd_durations.push(start.elapsed());
|
||||
|
||||
if is_prefill {
|
||||
sentence = vec![prompt_tokens[i + 1]];
|
||||
continue;
|
||||
}
|
||||
|
||||
// Greedy decode with repetition penalty
|
||||
decode_step_times.push(step_elapsed);
|
||||
|
||||
let mut last_row = logits_data[logits_data.len() - VOCAB_SIZE..].to_vec();
|
||||
for &tok in &seen_tokens {
|
||||
let logit = &mut last_row[tok as usize];
|
||||
@@ -152,21 +217,13 @@ fn main() {
|
||||
}
|
||||
println!();
|
||||
|
||||
// Benchmarks
|
||||
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
|
||||
if decode_durations.len() > 2 {
|
||||
println!(
|
||||
" TTFT: {:.2} ms",
|
||||
fwd_durations[..prompt_len]
|
||||
.iter()
|
||||
.sum::<Duration>()
|
||||
.as_secs_f64()
|
||||
* 1e3
|
||||
);
|
||||
// ── Report ────────────────────────────────────────────────────────────
|
||||
println!(" TTFT: {:.2} ms", ttft_ms);
|
||||
if decode_step_times.len() > 1 {
|
||||
println!(
|
||||
" TPOT: {:.2} ms",
|
||||
(decode_durations.iter().skip(1).copied().sum::<Duration>()
|
||||
/ (decode_durations.len() - 1) as u32)
|
||||
(decode_step_times.iter().skip(1).sum::<Duration>()
|
||||
/ (decode_step_times.len() - 1) as u32)
|
||||
.as_secs_f64()
|
||||
* 1_000.
|
||||
);
|
||||
|
||||
@@ -178,7 +178,7 @@ impl Qwen {
|
||||
kv_cache.v_caches[i],
|
||||
kv_cache.max_seq,
|
||||
);
|
||||
x = x_new.graph_break();
|
||||
x = x_new;
|
||||
cache_outputs.push((k_out, v_out));
|
||||
}
|
||||
// Tied embeddings: lm_head = embedding.t()
|
||||
|
||||
@@ -11,11 +11,21 @@ use tokenizers::Tokenizer;
|
||||
|
||||
const REPO_ID: &str = "Qwen/Qwen3-30B-A3B";
|
||||
|
||||
// Default configuration — override at runtime via env vars.
|
||||
const DEFAULT_MAX_SEQ_LEN: usize = 4096;
|
||||
const DEFAULT_SEARCH_GRAPHS: usize = 50;
|
||||
const DEFAULT_GEN_TOKENS: usize = 30;
|
||||
|
||||
fn env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name).ok().and_then(|s| s.parse().ok()).unwrap_or(default)
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 30;
|
||||
let search_graphs = 50;
|
||||
let prompt = "The capital of France is";
|
||||
let max_seq_len = env_usize("MAX_SEQ_LEN", DEFAULT_MAX_SEQ_LEN);
|
||||
let gen_tokens = env_usize("GEN_TOKENS", DEFAULT_GEN_TOKENS);
|
||||
let search_graphs = env_usize("SEARCH_GRAPHS", DEFAULT_SEARCH_GRAPHS);
|
||||
let prompt = std::env::var("PROMPT")
|
||||
.unwrap_or_else(|_| "The capital of France is".to_string());
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
@@ -24,7 +34,7 @@ fn main() {
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
let prompt_tokens = tokenizer.encode(prompt.as_str(), true).unwrap().get_ids().to_vec();
|
||||
|
||||
// Build graph
|
||||
let mut cx = Graph::default();
|
||||
@@ -39,6 +49,7 @@ fn main() {
|
||||
}
|
||||
|
||||
println!("Building E-Graph...");
|
||||
let compile_start = std::time::Instant::now();
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
println!("Loading weights...");
|
||||
@@ -58,12 +69,68 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(pos_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
println!(" COMPILE: {:.2} ms", compile_start.elapsed().as_secs_f64() * 1e3);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
// Full-prompt warmup: run the complete prompt to bring the GPU to steady state before timing
|
||||
for (w_pos, &w_token) in prompt_tokens.iter().enumerate() {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', w_pos);
|
||||
runtime.set_data(input, vec![w_token as i32]);
|
||||
runtime.set_data(pos_ids, vec![w_pos as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
}
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
let iters = env_usize("ITERS", 3);
|
||||
let prompt_len = prompt_tokens.len();
|
||||
|
||||
// ── TTFT: prefill-only timing over N iterations ───────────────────────
|
||||
let mut ttft_samples_ms: Vec<f64> = vec![];
|
||||
for _ in 0..iters {
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
let prefill_start = std::time::Instant::now();
|
||||
let mut prev_seq = 0usize;
|
||||
for &token in &prompt_tokens {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![token as i32]);
|
||||
runtime.set_data(pos_ids, vec![prev_seq as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
prev_seq += 1;
|
||||
}
|
||||
ttft_samples_ms.push(prefill_start.elapsed().as_secs_f64() * 1e3);
|
||||
}
|
||||
ttft_samples_ms.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
let ttft_ms = ttft_samples_ms[ttft_samples_ms.len() / 2];
|
||||
|
||||
// ── Text generation: one pass for TPOT + visible output ───────────────
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
|
||||
@@ -75,28 +142,21 @@ fn main() {
|
||||
const EOS_TOKEN: u32 = 151645;
|
||||
const STOP_TOKEN: u32 = 151643;
|
||||
|
||||
// Prefill: process prompt tokens one at a time
|
||||
let prefill_start = std::time::Instant::now();
|
||||
for &token in &prompt_tokens {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![token as i32]);
|
||||
runtime.set_data(pos_ids, vec![prev_seq as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
|
||||
// Round-trip KV cache
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
|
||||
prev_seq += 1;
|
||||
}
|
||||
let prefill_duration = prefill_start.elapsed();
|
||||
|
||||
// Get logits from last prefill step and sample first new token
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let last_row = &logits_data[..VOCAB_SIZE];
|
||||
let mut next_token = last_row
|
||||
@@ -109,7 +169,6 @@ fn main() {
|
||||
std::io::stdout().flush().unwrap();
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
// Decode loop
|
||||
for _ in 1..gen_tokens {
|
||||
let start = std::time::Instant::now();
|
||||
cx.set_dim('s', 1);
|
||||
@@ -117,15 +176,12 @@ fn main() {
|
||||
runtime.set_data(input, vec![next_token as i32]);
|
||||
runtime.set_data(pos_ids, vec![prev_seq as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
|
||||
// Round-trip KV cache
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
|
||||
prev_seq += 1;
|
||||
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
@@ -156,11 +212,10 @@ fn main() {
|
||||
}
|
||||
println!();
|
||||
|
||||
// Report benchmarks
|
||||
// ── Report ────────────────────────────────────────────────────────────
|
||||
println!(
|
||||
" TTFT: {:.2} ms ({} prompt tokens)",
|
||||
prefill_duration.as_secs_f64() * 1e3,
|
||||
prompt_tokens.len()
|
||||
ttft_ms, prompt_len
|
||||
);
|
||||
if fwd_durations.len() > 1 {
|
||||
println!(
|
||||
|
||||
@@ -186,7 +186,7 @@ impl Qwen3MoE {
|
||||
kv_cache.v_caches[i],
|
||||
kv_cache.max_seq,
|
||||
);
|
||||
x = x_new.graph_break();
|
||||
x = x_new;
|
||||
cache_outputs.push((k_out, v_out));
|
||||
}
|
||||
let logits = self.lm_norm.forward(x).matmul(self.lm_head.t());
|
||||
@@ -239,7 +239,6 @@ impl Qwen3MoELayer {
|
||||
let (attn_out, k_cache_out, v_cache_out) =
|
||||
attention(q_rope, k_rope, v, k_cache_in, v_cache_in, max_seq);
|
||||
x += attn_out.matmul(self.o_proj.t());
|
||||
x = x.graph_break();
|
||||
|
||||
// MoE FFN
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
@@ -264,8 +263,7 @@ impl QwenMoE {
|
||||
let row_offsets = x
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx =
|
||||
(row_offsets.cast(DType::F32) + top_k_indices.cast(DType::F32)).cast(DType::Int);
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
|
||||
// 4. Gather gate_up expert weights → [s, k, intermediate*2, H]
|
||||
@@ -303,18 +301,18 @@ fn gather_experts(
|
||||
) -> GraphTensor {
|
||||
let (_, d1, d2) = weights.dims3();
|
||||
let io = d1 * d2;
|
||||
let base = (top_k_indices * io).cast(DType::F32);
|
||||
let within = graph_source
|
||||
.graph()
|
||||
.iota(Expression::from('z'), (d1, d2))
|
||||
.cast(DType::F32);
|
||||
// Keep expert gather indices in Int all the way through. Routing them through
|
||||
// F32 loses exactness once the flat offsets exceed 2^24, which Qwen's expert
|
||||
// tensors do at realistic hidden sizes.
|
||||
let base = top_k_indices * io;
|
||||
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
|
||||
let n_base = base.dims().len();
|
||||
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
|
||||
let mut exp_within = within;
|
||||
for (i, dim) in base.dims().iter().enumerate() {
|
||||
exp_within = exp_within.expand_dim(i, *dim);
|
||||
}
|
||||
let expert_flat_idx = (exp_base + exp_within).cast(DType::Int);
|
||||
let expert_flat_idx = exp_base + exp_within;
|
||||
weights.gather(expert_flat_idx)
|
||||
}
|
||||
|
||||
|
||||
339
src/dyn_backend.rs
Normal file
339
src/dyn_backend.rs
Normal file
@@ -0,0 +1,339 @@
|
||||
//! Dynamic backend trait and factory-based compilation.
|
||||
//!
|
||||
//! This module provides:
|
||||
//! - [`DynBackend`]: an object-safe trait for dynamic backend dispatch
|
||||
//! - [`compile_backend`]: generic helper that handles the full compilation pipeline
|
||||
//! - [`BackendFactory`]: function pointer type for backend factories
|
||||
//! - [`NativeDynBackend`]: the reference implementation for CPU
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use half::{bf16, f16};
|
||||
use petgraph::stable_graph::NodeIndex;
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::dtype::DType;
|
||||
use crate::graph::Graph;
|
||||
use crate::hlir::{NativeData, NativeRuntime, Output};
|
||||
use crate::op::Runtime;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// DynBackend trait
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Object-safe backend trait for dynamic dispatch.
|
||||
///
|
||||
/// Wraps a concrete [`Runtime`] implementor, providing a uniform interface
|
||||
/// for `luminal_python` (and other dynamic consumers) without requiring
|
||||
/// generic type parameters.
|
||||
pub trait DynBackend {
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// The device type this backend operates on (e.g. "cpu", "cuda").
|
||||
/// Used by the Python frontend to decide input tensor placement.
|
||||
fn device_type(&self) -> &str {
|
||||
"cpu"
|
||||
}
|
||||
|
||||
fn set_data_bytes(&mut self, node: NodeIndex, bytes: Vec<u8>, dtype: DType);
|
||||
fn set_data_f32(&mut self, node: NodeIndex, data: Vec<f32>);
|
||||
fn get_output_f32(&self, node: NodeIndex) -> Vec<f32>;
|
||||
fn get_output_i32(&self, _node: NodeIndex) -> Vec<i32> {
|
||||
panic!("get_output_i32 not supported by '{}'", self.name());
|
||||
}
|
||||
fn get_output_bool(&self, _node: NodeIndex) -> Vec<bool> {
|
||||
panic!("get_output_bool not supported by '{}'", self.name());
|
||||
}
|
||||
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>);
|
||||
|
||||
// --- Optional device pointer support (GPU backends) --------------------
|
||||
|
||||
fn supports_device_ptrs(&self) -> bool {
|
||||
false
|
||||
}
|
||||
/// # Safety
|
||||
/// Device pointer must be valid and point to at least `n_bytes` bytes.
|
||||
unsafe fn set_device_ptr(&mut self, _node: NodeIndex, _ptr: u64, _n_bytes: usize) {
|
||||
panic!("set_device_ptr not supported by '{}'", self.name());
|
||||
}
|
||||
/// # Safety
|
||||
/// Device pointer must remain valid through the next `execute()` call.
|
||||
unsafe fn set_output_device_ptr(&mut self, _node: NodeIndex, _ptr: u64, _n_bytes: usize) {
|
||||
panic!("set_output_device_ptr not supported by '{}'", self.name());
|
||||
}
|
||||
fn output_is_zero_copy(&self, _node: NodeIndex) -> bool {
|
||||
false
|
||||
}
|
||||
/// # Safety
|
||||
/// `dest_ptr` must be a valid device allocation with at least `n_bytes`.
|
||||
unsafe fn copy_output_to_device_ptr(&self, _node: NodeIndex, _dest_ptr: u64, _n_bytes: usize) {
|
||||
panic!(
|
||||
"copy_output_to_device_ptr not supported by '{}'",
|
||||
self.name()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BackendCompileArgs + BackendFactory + Registry
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Arguments passed to a backend factory during compilation.
|
||||
pub struct BackendCompileArgs {
|
||||
pub search_iters: usize,
|
||||
pub weights: Vec<(String, Vec<u8>, DType)>,
|
||||
pub tensor_sizes: HashMap<String, usize>,
|
||||
pub device_ptrs: HashMap<String, (u64, usize)>,
|
||||
}
|
||||
|
||||
/// Canonical PyCapsule name for [`BackendFactory`] function-pointer capsules.
|
||||
///
|
||||
/// Value MUST remain `"luminal.backend_factory"` for compatibility with
|
||||
/// external plugin producers built against older versions of this crate.
|
||||
pub const BACKEND_FACTORY_CAPSULE_NAME: &std::ffi::CStr = c"luminal.backend_factory";
|
||||
|
||||
/// A factory function that compiles a [`Graph`] into a ready-to-execute [`DynBackend`].
|
||||
pub type BackendFactory = fn(&mut Graph, BackendCompileArgs) -> Result<Box<dyn DynBackend>, String>;
|
||||
|
||||
/// Compile a graph using a factory function directly.
|
||||
pub fn compile_backend_from_factory(
|
||||
factory: BackendFactory,
|
||||
graph: &mut Graph,
|
||||
args: BackendCompileArgs,
|
||||
) -> Result<Box<dyn DynBackend>, String> {
|
||||
factory(graph, args)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// compile_backend — generic compilation helper
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Optional callback for uploading a device pointer + byte count to a node.
|
||||
pub type SetDevicePtrFn<'a, Rt> = &'a dyn Fn(&mut Rt, NodeIndex, u64, usize);
|
||||
|
||||
/// Generic compilation pipeline shared by all backends.
|
||||
///
|
||||
/// Handles: build search space → init runtime → set device ptrs → set dummy
|
||||
/// data → search → load weights → wrap as `Box<dyn DynBackend>`.
|
||||
///
|
||||
/// Backend-specific behavior is injected via callbacks:
|
||||
/// - `init`: create the concrete runtime
|
||||
/// - `set_raw`: upload raw bytes + dtype to a node
|
||||
/// - `set_device_ptr`: optional zero-copy device pointer setter
|
||||
/// - `wrap`: wrap the final runtime in a `Box<dyn DynBackend>`
|
||||
pub fn compile_backend<Rt: Runtime + 'static>(
|
||||
graph: &mut Graph,
|
||||
args: BackendCompileArgs,
|
||||
init: impl FnOnce() -> Result<Rt, String>,
|
||||
set_raw: impl Fn(&mut Rt, NodeIndex, Vec<u8>, DType),
|
||||
set_device_ptr: Option<SetDevicePtrFn<'_, Rt>>,
|
||||
wrap: impl FnOnce(Rt) -> Box<dyn DynBackend>,
|
||||
) -> Result<Box<dyn DynBackend>, String> {
|
||||
// Build label map from input_meta (plain data — no downcast needed,
|
||||
// survives cross-binary type identity mismatches with external plugins).
|
||||
let label_map = build_label_map(graph);
|
||||
|
||||
graph.build_search_space::<Rt>();
|
||||
|
||||
let mut rt = init()?;
|
||||
|
||||
// Set device pointers for zero-copy weights (GPU backends)
|
||||
let mut device_ptr_nodes = rustc_hash::FxHashSet::default();
|
||||
if let Some(set_ptr) = set_device_ptr {
|
||||
for (label, &(ptr, n_bytes)) in &args.device_ptrs {
|
||||
if let Some(&node_id) = label_map.get(label) {
|
||||
set_ptr(&mut rt, node_id, ptr, n_bytes);
|
||||
device_ptr_nodes.insert(node_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set dummy ones for Input nodes (required for search profiling).
|
||||
// Must use 1, NOT 0 — zero inputs cause NaN in many ops.
|
||||
for (&node_id, (label, dtype)) in &graph.input_meta {
|
||||
if device_ptr_nodes.contains(&node_id) {
|
||||
continue;
|
||||
}
|
||||
if let Some(&n) = args.tensor_sizes.get(label) {
|
||||
if n > 0 {
|
||||
set_raw(&mut rt, node_id, make_ones_bytes(n, *dtype), *dtype);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Search
|
||||
let mut rt = graph.search(rt, args.search_iters);
|
||||
|
||||
// Rebuild label map after search (graph may have changed)
|
||||
let label_map = build_label_map(graph);
|
||||
|
||||
// Load real weights post-search (skip device-ptr weights)
|
||||
for (label, bytes, dtype) in &args.weights {
|
||||
if !args.device_ptrs.contains_key(label) {
|
||||
if let Some(&node_id) = label_map.get(label) {
|
||||
set_raw(&mut rt, node_id, bytes.clone(), *dtype);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(wrap(rt))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared utilities
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Build a `label → NodeIndex` map for all Input nodes in the graph.
|
||||
///
|
||||
/// Uses `graph.input_meta` (plain data) rather than downcasting, so it works
|
||||
/// correctly when the graph was built by a different compilation unit (e.g.
|
||||
/// an external backend plugin compiled as a separate wheel).
|
||||
pub fn build_label_map(graph: &Graph) -> HashMap<String, NodeIndex> {
|
||||
graph
|
||||
.input_meta
|
||||
.iter()
|
||||
.map(|(&node_id, (label, _))| (label.clone(), node_id))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Create a byte buffer of `n_elements` ones for the given dtype.
|
||||
///
|
||||
/// IMPORTANT: Must use 1, NOT 0 — zero inputs cause NaN in many ops
|
||||
/// (fmod, recip, log, etc.) during search profiling.
|
||||
pub fn make_ones_bytes(n_elements: usize, dtype: DType) -> Vec<u8> {
|
||||
// Safety: all source types have defined bit representations; we just
|
||||
// reinterpret the backing Vec<u8> without changing the allocation.
|
||||
unsafe fn as_bytes<T>(v: Vec<T>) -> Vec<u8> {
|
||||
let mut v = std::mem::ManuallyDrop::new(v);
|
||||
let ptr = v.as_mut_ptr() as *mut u8;
|
||||
let len = v.len() * std::mem::size_of::<T>();
|
||||
unsafe { Vec::from_raw_parts(ptr, len, len) }
|
||||
}
|
||||
match dtype {
|
||||
DType::F32 | DType::TF32 => unsafe { as_bytes(vec![1.0f32; n_elements]) },
|
||||
DType::F64 => unsafe { as_bytes(vec![1.0f64; n_elements]) },
|
||||
DType::F16 => unsafe { as_bytes(vec![f16::from_f32(1.0); n_elements]) },
|
||||
DType::Bf16 => unsafe { as_bytes(vec![bf16::from_f32(1.0); n_elements]) },
|
||||
DType::Int => unsafe { as_bytes(vec![1i32; n_elements]) },
|
||||
DType::I16 => unsafe { as_bytes(vec![1i16; n_elements]) },
|
||||
DType::U16 => unsafe { as_bytes(vec![1u16; n_elements]) },
|
||||
_ => vec![1u8; n_elements], // I8, U8, Bool, sub-byte types
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert raw bytes + [`DType`] to [`NativeData`].
|
||||
pub fn bytes_to_native_data(bytes: Vec<u8>, dtype: DType) -> NativeData {
|
||||
// Safety: source bytes are from a valid typed buffer; we reinterpret.
|
||||
unsafe fn from_bytes<T: Copy>(bytes: Vec<u8>) -> Vec<T> {
|
||||
let n = bytes.len() / std::mem::size_of::<T>();
|
||||
let mut bytes = std::mem::ManuallyDrop::new(bytes);
|
||||
unsafe { Vec::from_raw_parts(bytes.as_mut_ptr() as *mut T, n, n) }
|
||||
}
|
||||
match dtype {
|
||||
DType::F32 | DType::TF32 => NativeData::F32(unsafe { from_bytes(bytes) }),
|
||||
DType::F64 => {
|
||||
let f64s: Vec<f64> = unsafe { from_bytes(bytes) };
|
||||
NativeData::F32(f64s.into_iter().map(|v| v as f32).collect())
|
||||
}
|
||||
DType::F16 => NativeData::F16(unsafe { from_bytes(bytes) }),
|
||||
DType::Bf16 => NativeData::Bf16(unsafe { from_bytes(bytes) }),
|
||||
DType::Int => NativeData::Int(unsafe { from_bytes(bytes) }),
|
||||
DType::Bool => NativeData::Bool(bytes.into_iter().map(|b| b != 0).collect()),
|
||||
DType::I8 => NativeData::Int(bytes.iter().map(|&b| b as i8 as i32).collect()),
|
||||
DType::U8 => NativeData::Int(bytes.iter().map(|&b| b as i32).collect()),
|
||||
DType::I16 => {
|
||||
let i16s: Vec<i16> = unsafe { from_bytes(bytes) };
|
||||
NativeData::Int(i16s.into_iter().map(|v| v as i32).collect())
|
||||
}
|
||||
DType::U16 => {
|
||||
let u16s: Vec<u16> = unsafe { from_bytes(bytes) };
|
||||
NativeData::Int(u16s.into_iter().map(|v| v as i32).collect())
|
||||
}
|
||||
_ => NativeData::F32(vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// NativeDynBackend
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// [`DynBackend`] wrapper for the native (CPU) runtime.
|
||||
pub struct NativeDynBackend {
|
||||
pub runtime: NativeRuntime,
|
||||
}
|
||||
|
||||
impl DynBackend for NativeDynBackend {
|
||||
fn name(&self) -> &str {
|
||||
"native"
|
||||
}
|
||||
|
||||
fn set_data_bytes(&mut self, node: NodeIndex, bytes: Vec<u8>, dtype: DType) {
|
||||
self.runtime
|
||||
.set_data(node, bytes_to_native_data(bytes, dtype));
|
||||
}
|
||||
|
||||
fn set_data_f32(&mut self, node: NodeIndex, data: Vec<f32>) {
|
||||
self.runtime.set_data(node, data);
|
||||
}
|
||||
|
||||
fn get_output_f32(&self, node: NodeIndex) -> Vec<f32> {
|
||||
let data = self.output_buffer(node);
|
||||
(0..data.len()).map(|i| data.f32(i)).collect()
|
||||
}
|
||||
|
||||
fn get_output_i32(&self, node: NodeIndex) -> Vec<i32> {
|
||||
let data = self.output_buffer(node);
|
||||
(0..data.len()).map(|i| data.i32(i)).collect()
|
||||
}
|
||||
|
||||
fn get_output_bool(&self, node: NodeIndex) -> Vec<bool> {
|
||||
let data = self.output_buffer(node);
|
||||
(0..data.len()).map(|i| data.bool(i)).collect()
|
||||
}
|
||||
|
||||
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
self.runtime.execute(dyn_map);
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeDynBackend {
|
||||
fn output_buffer(&self, node: NodeIndex) -> &NativeData {
|
||||
let output_id = self
|
||||
.runtime
|
||||
.graph
|
||||
.node_indices()
|
||||
.find(|n| {
|
||||
(**self.runtime.graph[*n])
|
||||
.as_any()
|
||||
.downcast_ref::<Output>()
|
||||
.is_some_and(|out| out.node == node.index())
|
||||
})
|
||||
.unwrap_or_else(|| panic!("No output node found for {:?}", node));
|
||||
self.runtime
|
||||
.buffers
|
||||
.get(&output_id)
|
||||
.unwrap_or_else(|| panic!("No buffer data for output {:?}", node))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn native_factory(
|
||||
graph: &mut Graph,
|
||||
args: BackendCompileArgs,
|
||||
) -> Result<Box<dyn DynBackend>, String> {
|
||||
compile_backend::<NativeRuntime>(
|
||||
graph,
|
||||
args,
|
||||
|| Ok(NativeRuntime::default()),
|
||||
// NativeRuntime::set_data requires the LLIR graph to be loaded (it searches
|
||||
// for Input nodes in the LLIR). Before search, the LLIR is empty. We guard
|
||||
// against that: if rt.graph is empty, skip (dummy data isn't needed for
|
||||
// native since its profile is a no-op).
|
||||
|rt, node, bytes, dtype| {
|
||||
if rt.graph.node_count() > 0 {
|
||||
rt.set_data(node, bytes_to_native_data(bytes, dtype));
|
||||
}
|
||||
},
|
||||
None,
|
||||
|rt| Box::new(NativeDynBackend { runtime: rt }),
|
||||
)
|
||||
}
|
||||
@@ -6,7 +6,7 @@ use rand::Rng;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::{str, sync::Arc};
|
||||
use std::{str, sync::Arc, time::Duration};
|
||||
use tracing::trace;
|
||||
|
||||
pub mod api;
|
||||
@@ -112,24 +112,64 @@ pub fn early_egglog(
|
||||
ops: &[Arc<Box<dyn EgglogOp>>],
|
||||
cleanup: bool,
|
||||
) -> String {
|
||||
let parts = OpTextParts::new(ops, cleanup);
|
||||
early_egglog_with(program, root, &parts)
|
||||
}
|
||||
|
||||
pub fn full_egglog(program: &str, ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool) -> String {
|
||||
let parts = OpTextParts::new(ops, cleanup);
|
||||
full_egglog_with(program, &parts)
|
||||
}
|
||||
|
||||
/// Pre-computed per-op text fragments. `run_egglog` calls early + full back
|
||||
/// to back with identical `ops`; materialising all op-derived strings once
|
||||
/// up front means callers that want to drive multiple egglog runs in parallel
|
||||
/// only need to share `&str` references and never touch the non-Send trait
|
||||
/// objects in `ops`.
|
||||
pub struct OpTextParts {
|
||||
op_defs: String,
|
||||
cleanups: String,
|
||||
early_rewrites: String,
|
||||
full_rewrites: String,
|
||||
}
|
||||
|
||||
impl OpTextParts {
|
||||
pub fn new(ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool) -> Self {
|
||||
Self {
|
||||
op_defs: op_defs_string(ops),
|
||||
cleanups: if cleanup {
|
||||
op_cleanups_string(ops)
|
||||
} else {
|
||||
String::new()
|
||||
},
|
||||
early_rewrites: ops
|
||||
.iter()
|
||||
.flat_map(|o| o.early_rewrites())
|
||||
.map(|r| r.to_egglog_string())
|
||||
.join("\n"),
|
||||
full_rewrites: ops
|
||||
.iter()
|
||||
.flat_map(|o| o.rewrites())
|
||||
.map(|r| r.to_egglog_string())
|
||||
.join("\n"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn early_egglog_with(program: &str, root: &str, parts: &OpTextParts) -> String {
|
||||
[
|
||||
base::base_expression_egglog(),
|
||||
op_defs_string(ops),
|
||||
ops.iter()
|
||||
.flat_map(|o| o.early_rewrites())
|
||||
.map(|r| r.to_egglog_string())
|
||||
.join("\n"),
|
||||
if cleanup {
|
||||
op_cleanups_string(ops)
|
||||
} else {
|
||||
"".to_string()
|
||||
},
|
||||
parts.op_defs.clone(),
|
||||
parts.early_rewrites.clone(),
|
||||
parts.cleanups.clone(),
|
||||
base::base_cleanup_egglog(),
|
||||
program.to_string(),
|
||||
format!(
|
||||
"(run-schedule
|
||||
(saturate expr)
|
||||
(run)
|
||||
(repeat 6
|
||||
(saturate expr)
|
||||
(run)
|
||||
)
|
||||
(saturate base_cleanup)
|
||||
)
|
||||
(extract {root})"
|
||||
@@ -138,20 +178,13 @@ pub fn early_egglog(
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
pub fn full_egglog(program: &str, ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool) -> String {
|
||||
fn full_egglog_with(program: &str, parts: &OpTextParts) -> String {
|
||||
[
|
||||
base::base_expression_egglog(),
|
||||
op_defs_string(ops),
|
||||
if cleanup {
|
||||
op_cleanups_string(ops)
|
||||
} else {
|
||||
"".to_string()
|
||||
},
|
||||
parts.op_defs.clone(),
|
||||
parts.cleanups.clone(),
|
||||
base::base_cleanup_egglog(),
|
||||
ops.iter()
|
||||
.flat_map(|o| o.rewrites())
|
||||
.map(|r| r.to_egglog_string())
|
||||
.join("\n"),
|
||||
parts.full_rewrites.clone(),
|
||||
program.to_string(),
|
||||
RUN_SCHEDULE.to_string(),
|
||||
]
|
||||
@@ -160,8 +193,7 @@ pub fn full_egglog(program: &str, ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool)
|
||||
|
||||
use crate::{
|
||||
dtype::DType,
|
||||
graph::{Graph, LLIRGraph, SubgraphDescriptor},
|
||||
hlir::{Input, Output},
|
||||
graph::{Graph, LLIRGraph},
|
||||
op::{CustomOp, EgglogOp},
|
||||
prelude::FxHashMap,
|
||||
shape::Expression,
|
||||
@@ -179,6 +211,20 @@ pub struct SerializedEGraph {
|
||||
pub roots: Vec<ClassId>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct EgglogStageReport {
|
||||
pub num_matches_per_rule: FxHashMap<String, usize>,
|
||||
pub search_and_apply_time_per_rule: FxHashMap<String, Duration>,
|
||||
pub total_time: Duration,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct EgglogRunReport {
|
||||
pub early: EgglogStageReport,
|
||||
pub full: EgglogStageReport,
|
||||
pub total_time: Duration,
|
||||
}
|
||||
|
||||
impl SerializedEGraph {
|
||||
/// This is an opinionated function which does more than strictly take the state of the egglog object.
|
||||
/// It also filters out "[...]" nodes and then changes the structure from the e-termDAG that egraph-serialize
|
||||
@@ -320,11 +366,17 @@ pub fn hash_egglog_normalized(text: &str) -> u64 {
|
||||
for line in text.lines() {
|
||||
if line.contains("(Input ") {
|
||||
// Format: (let tN (Input NODE "LABEL" (DTYPE)))
|
||||
// Strip the node index and label, keep only the dtype.
|
||||
// Strip the node index and label identity, but preserve whether this
|
||||
// is a synthetic boundary input or a real graph input.
|
||||
// The dtype is the last parenthesized token, e.g. "(F32)".
|
||||
if let Some(dtype_start) = line.rfind(" (") {
|
||||
let dtype = &line[dtype_start + 1..];
|
||||
("INPUT", dtype).hash(&mut hasher);
|
||||
let kind = if line.contains("\"boundary\"") {
|
||||
"BOUNDARY_INPUT"
|
||||
} else {
|
||||
"REAL_INPUT"
|
||||
};
|
||||
(kind, dtype).hash(&mut hasher);
|
||||
} else {
|
||||
line.hash(&mut hasher);
|
||||
}
|
||||
@@ -391,8 +443,10 @@ pub fn hlir_to_egglog(graph: &Graph) -> (String, String) {
|
||||
|
||||
// 2. Map <node-id> → <egglog var name>
|
||||
let mut names: HashMap<NodeIndex, String> = HashMap::new();
|
||||
let mut out = String::new();
|
||||
// Pre-size output to avoid growth reallocations; ops emit ~100-200 chars each.
|
||||
let mut out = String::with_capacity(topo_order.len() * 160);
|
||||
|
||||
use std::fmt::Write;
|
||||
let mut curr_id = 0;
|
||||
for n in topo_order {
|
||||
let sources: Vec<(NodeIndex, String)> = graph
|
||||
@@ -401,7 +455,9 @@ pub fn hlir_to_egglog(graph: &Graph) -> (String, String) {
|
||||
.map(|src| (src, names[&src].clone()))
|
||||
.collect_vec();
|
||||
let code = graph[n].to_egglog(&sources);
|
||||
out.push_str(&format!("(let t{curr_id} {code})\n"));
|
||||
// write!() into the existing buffer skips the intermediate String
|
||||
// that format! would otherwise allocate for each node.
|
||||
let _ = writeln!(out, "(let t{curr_id} {code})");
|
||||
names.insert(n, format!("t{curr_id}"));
|
||||
curr_id += 1;
|
||||
}
|
||||
@@ -414,145 +470,12 @@ pub fn hlir_to_egglog(graph: &Graph) -> (String, String) {
|
||||
let mut root = names[0].clone();
|
||||
for node in names.into_iter().skip(1) {
|
||||
curr_id += 1;
|
||||
out.push_str(&format!("(let t{curr_id} (OutputJoin {root} {node}))\n"));
|
||||
let _ = writeln!(out, "(let t{curr_id} (OutputJoin {root} {node}))");
|
||||
root = format!("t{curr_id}");
|
||||
}
|
||||
(out.replace("(MVar \"z\")", "(MIter)"), root)
|
||||
}
|
||||
|
||||
/// Convert a subgraph of the HLIR to egglog, injecting synthetic Input/Output
|
||||
/// nodes at graph break boundaries.
|
||||
pub fn hlir_subgraph_to_egglog(graph: &Graph, subgraph: &SubgraphDescriptor) -> (String, String) {
|
||||
use std::cmp::Reverse;
|
||||
use std::collections::{BinaryHeap, HashMap};
|
||||
|
||||
let mut names: HashMap<NodeIndex, String> = HashMap::new();
|
||||
let mut out = String::new();
|
||||
let mut curr_id = 0;
|
||||
|
||||
// Emit synthetic Input nodes for boundary inputs
|
||||
for boundary in &subgraph.boundary_inputs {
|
||||
let var_name = format!("t{curr_id}");
|
||||
let code = format!(
|
||||
"(Input {} \"boundary\" ({:?}))",
|
||||
boundary.break_node.index(),
|
||||
boundary.dtype
|
||||
);
|
||||
out.push_str(&format!("(let {var_name} {code})\n"));
|
||||
// Map the GraphBreak node to this synthetic Input variable.
|
||||
// When downstream nodes reference the GraphBreak as a source, they'll use this.
|
||||
names.insert(boundary.break_node, var_name);
|
||||
curr_id += 1;
|
||||
}
|
||||
|
||||
// Topo-order only the nodes in this subgraph
|
||||
// Build sub-indeg map restricted to subgraph nodes
|
||||
let mut indeg: HashMap<NodeIndex, usize> = HashMap::new();
|
||||
for &n in &subgraph.nodes {
|
||||
let count = graph
|
||||
.graph
|
||||
.neighbors_directed(n, Direction::Incoming)
|
||||
.filter(|pred| subgraph.nodes.contains(pred))
|
||||
.count();
|
||||
indeg.insert(n, count);
|
||||
}
|
||||
|
||||
let mut ready: BinaryHeap<(Reverse<usize>, NodeIndex)> = BinaryHeap::new();
|
||||
for (&n, &d) in &indeg {
|
||||
if d == 0 {
|
||||
ready.push((Reverse(n.index()), n));
|
||||
}
|
||||
}
|
||||
|
||||
let mut topo_order: Vec<NodeIndex> = Vec::with_capacity(indeg.len());
|
||||
while let Some((_, n)) = ready.pop() {
|
||||
topo_order.push(n);
|
||||
for succ in graph.graph.neighbors_directed(n, Direction::Outgoing) {
|
||||
if let Some(e) = indeg.get_mut(&succ) {
|
||||
*e -= 1;
|
||||
if *e == 0 {
|
||||
ready.push((Reverse(succ.index()), succ));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert each node in topological order to egglog
|
||||
for n in topo_order {
|
||||
let sources: Vec<(NodeIndex, String)> = graph
|
||||
.get_sources(n)
|
||||
.into_iter()
|
||||
.map(|src| {
|
||||
let name = names
|
||||
.get(&src)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| panic!("Missing egglog name for node {:?}", src));
|
||||
(src, name)
|
||||
})
|
||||
.collect_vec();
|
||||
let code = graph.graph[n].to_egglog(&sources);
|
||||
out.push_str(&format!("(let t{curr_id} {code})\n"));
|
||||
names.insert(n, format!("t{curr_id}"));
|
||||
curr_id += 1;
|
||||
}
|
||||
|
||||
// Emit synthetic Output nodes for boundary outputs
|
||||
for &brk in &subgraph.boundary_outputs {
|
||||
// The predecessor of the GraphBreak is the actual producer
|
||||
let pred = graph
|
||||
.graph
|
||||
.neighbors_directed(brk, Direction::Incoming)
|
||||
.next()
|
||||
.expect("GraphBreak must have exactly one input");
|
||||
let pred_name = names.get(&pred).cloned().unwrap_or_else(|| {
|
||||
panic!(
|
||||
"Missing egglog name for boundary output predecessor {:?}",
|
||||
pred
|
||||
)
|
||||
});
|
||||
let code = format!("(Output {} {})", pred_name, brk.index());
|
||||
out.push_str(&format!("(let t{curr_id} {code})\n"));
|
||||
names.insert(brk, format!("t{curr_id}"));
|
||||
curr_id += 1;
|
||||
}
|
||||
|
||||
// Join outputs: real outputs (nodes with no outgoing edges within the subgraph)
|
||||
// plus boundary outputs
|
||||
let mut output_names: Vec<String> = vec![];
|
||||
|
||||
// Boundary outputs
|
||||
for &brk in &subgraph.boundary_outputs {
|
||||
if let Some(name) = names.get(&brk) {
|
||||
output_names.push(name.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Real outputs: only actual Output HLIR ops that exist in this subgraph
|
||||
// (not arbitrary nodes that happen to have no subgraph successors)
|
||||
for &n in &subgraph.nodes {
|
||||
if graph.try_get_op::<Output>(n).is_some() {
|
||||
if let Some(name) = names.get(&n) {
|
||||
output_names.push(name.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if output_names.is_empty() {
|
||||
// Fallback: use the last node added
|
||||
output_names.push(format!("t{}", curr_id - 1));
|
||||
}
|
||||
|
||||
// Join with OutputJoin
|
||||
let mut root = output_names[0].clone();
|
||||
for node in output_names.into_iter().skip(1) {
|
||||
curr_id += 1;
|
||||
out.push_str(&format!("(let t{curr_id} (OutputJoin {root} {node}))\n"));
|
||||
root = format!("t{curr_id}");
|
||||
}
|
||||
|
||||
(out.replace("(MVar \"z\")", "(MIter)"), root)
|
||||
}
|
||||
|
||||
pub fn elist_to_egglog(shape: &[Expression]) -> String {
|
||||
list_to_egglog(
|
||||
&shape.iter().map(|e| e.to_egglog()).collect_vec(),
|
||||
@@ -589,41 +512,34 @@ fn termdag_to_egglog(td: &egglog::TermDag, root: egglog::TermId) -> (String, Str
|
||||
(out.replace("(MVar \"z\")", "(MIter)"), format!("t{root}"))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn run_egglog(
|
||||
program: &str,
|
||||
root: &str,
|
||||
ops: &[Arc<Box<dyn EgglogOp>>],
|
||||
cleanup: bool,
|
||||
) -> Result<SerializedEGraph, egglog::Error> {
|
||||
let start = std::time::Instant::now();
|
||||
let code = early_egglog(program, root, ops, cleanup);
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let commands = egraph.parser.get_program_from_string(None, &code)?;
|
||||
let outputs = egraph.run_program(commands)?;
|
||||
let CommandOutput::ExtractBest(termdag, _cost, term) = outputs.last().unwrap() else {
|
||||
panic!();
|
||||
};
|
||||
let (program, root) = termdag_to_egglog(termdag, termdag.lookup(term));
|
||||
let code = full_egglog(&program, ops, cleanup);
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let commands = egraph.parser.get_program_from_string(None, &code)?;
|
||||
trace!("{}", "Egglog running...".green());
|
||||
let _outputs = egraph.run_program(commands)?;
|
||||
trace!("{}", "---- Egglog Rule Matches ----".green());
|
||||
fn stage_report(egraph: &egglog::EGraph, total_time: Duration) -> EgglogStageReport {
|
||||
let run_report = egraph.get_overall_run_report();
|
||||
EgglogStageReport {
|
||||
num_matches_per_rule: run_report
|
||||
.num_matches_per_rule
|
||||
.iter()
|
||||
.map(|(name, matches)| (name.to_string(), *matches))
|
||||
.collect(),
|
||||
search_and_apply_time_per_rule: run_report
|
||||
.search_and_apply_time_per_rule
|
||||
.iter()
|
||||
.map(|(name, elapsed)| (name.to_string(), *elapsed))
|
||||
.collect(),
|
||||
total_time,
|
||||
}
|
||||
}
|
||||
|
||||
fn trace_stage_report(header: &str, report: &EgglogStageReport) {
|
||||
trace!("{}", header.green());
|
||||
trace!(
|
||||
"{}",
|
||||
run_report
|
||||
report
|
||||
.num_matches_per_rule
|
||||
.iter()
|
||||
.filter(|(k, _)| !k.contains("("))
|
||||
.map(|(k, v)| format!(
|
||||
"{k}: {v} ({})",
|
||||
pretty_duration::pretty_duration(
|
||||
&run_report.search_and_apply_time_per_rule[k],
|
||||
None
|
||||
)
|
||||
pretty_duration::pretty_duration(&report.search_and_apply_time_per_rule[k], None)
|
||||
))
|
||||
.join("\n")
|
||||
.green()
|
||||
@@ -631,11 +547,73 @@ pub fn run_egglog(
|
||||
trace!(
|
||||
"{}",
|
||||
format!(
|
||||
"---- Egglog Took {} ----",
|
||||
pretty_duration::pretty_duration(&start.elapsed(), None).bold()
|
||||
"---- {} Took {} ----",
|
||||
header,
|
||||
pretty_duration::pretty_duration(&report.total_time, None).bold()
|
||||
)
|
||||
.green()
|
||||
);
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn run_egglog_with_report(
|
||||
program: &str,
|
||||
root: &str,
|
||||
ops: &[Arc<Box<dyn EgglogOp>>],
|
||||
cleanup: bool,
|
||||
) -> Result<(SerializedEGraph, EgglogRunReport), egglog::Error> {
|
||||
let op_parts = OpTextParts::new(ops, cleanup);
|
||||
run_egglog_with_report_parts(program, root, &op_parts)
|
||||
}
|
||||
|
||||
/// Same as [`run_egglog_with_report`], but takes pre-computed [`OpTextParts`].
|
||||
/// Useful when a caller runs many egglog invocations with the same op set
|
||||
/// and wants to factor the op-derived text work out of a parallel loop.
|
||||
/// Takes only `&str` / `&OpTextParts` inputs so the whole function is `Send`.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn run_egglog_with_report_parts(
|
||||
program: &str,
|
||||
root: &str,
|
||||
op_parts: &OpTextParts,
|
||||
) -> Result<(SerializedEGraph, EgglogRunReport), egglog::Error> {
|
||||
let total_start = std::time::Instant::now();
|
||||
|
||||
let early_start = std::time::Instant::now();
|
||||
let code = early_egglog_with(program, root, op_parts);
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let commands = egraph.parser.get_program_from_string(None, &code)?;
|
||||
let outputs = egraph.run_program(commands)?;
|
||||
let early_report = stage_report(&egraph, early_start.elapsed());
|
||||
|
||||
let CommandOutput::ExtractBest(termdag, _cost, term) = outputs.last().unwrap() else {
|
||||
panic!();
|
||||
};
|
||||
let (program, root) = termdag_to_egglog(termdag, termdag.lookup(term));
|
||||
|
||||
let full_start = std::time::Instant::now();
|
||||
let code = full_egglog_with(&program, op_parts);
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let commands = egraph.parser.get_program_from_string(None, &code)?;
|
||||
trace!("{}", "Egglog running...".green());
|
||||
let _outputs = egraph.run_program(commands)?;
|
||||
let full_report = stage_report(&egraph, full_start.elapsed());
|
||||
trace_stage_report("---- Egglog Early Rule Matches ----", &early_report);
|
||||
trace_stage_report("---- Egglog Full Rule Matches ----", &full_report);
|
||||
|
||||
let run_report = EgglogRunReport {
|
||||
early: early_report,
|
||||
full: full_report,
|
||||
total_time: total_start.elapsed(),
|
||||
};
|
||||
trace!(
|
||||
"{}",
|
||||
format!(
|
||||
"---- Egglog Total Took {} ----",
|
||||
pretty_duration::pretty_duration(&run_report.total_time, None).bold()
|
||||
)
|
||||
.green()
|
||||
);
|
||||
|
||||
let (sort, value) = egraph.eval_expr(&var!(root))?;
|
||||
let s = egraph.serialize(egglog::SerializeConfig {
|
||||
root_eclasses: vec![(sort, value)],
|
||||
@@ -720,7 +698,28 @@ pub fn run_egglog(
|
||||
"No valid graphs present in the e-graph!"
|
||||
);
|
||||
|
||||
Ok(egraph)
|
||||
Ok((egraph, run_report))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn run_egglog(
|
||||
program: &str,
|
||||
root: &str,
|
||||
ops: &[Arc<Box<dyn EgglogOp>>],
|
||||
cleanup: bool,
|
||||
) -> Result<SerializedEGraph, egglog::Error> {
|
||||
run_egglog_with_report(program, root, ops, cleanup).map(|(egraph, _)| egraph)
|
||||
}
|
||||
|
||||
/// Same as [`run_egglog`] but takes pre-computed [`OpTextParts`], so the
|
||||
/// whole function is `Send`. Used by the parallel grouped-egraphs build.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn run_egglog_with(
|
||||
program: &str,
|
||||
root: &str,
|
||||
op_parts: &OpTextParts,
|
||||
) -> Result<SerializedEGraph, egglog::Error> {
|
||||
run_egglog_with_report_parts(program, root, op_parts).map(|(egraph, _)| egraph)
|
||||
}
|
||||
|
||||
pub fn extract_expr_list<'a>(
|
||||
@@ -996,19 +995,33 @@ pub fn validate_choice_set<'a>(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Hash a choice set for uniqueness checking
|
||||
pub fn hash_choice_set(choices: &EGraphChoiceSet) -> u64 {
|
||||
/// Hash a single (class_id, node_id) entry. Used both for the full
|
||||
/// choice-set hash and for the incremental updates in
|
||||
/// `extract_generation`.
|
||||
fn hash_choice_entry(class_id: &ClassId, node_id: &NodeId) -> u64 {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
// Sort by ClassId for deterministic hashing
|
||||
let mut sorted: Vec<_> = choices.iter().collect();
|
||||
sorted.sort_by(|(k1, _), (k2, _)| k1.as_ref().cmp(k2.as_ref()));
|
||||
for (class_id, node_id) in sorted {
|
||||
class_id.hash(&mut hasher);
|
||||
node_id.hash(&mut hasher);
|
||||
}
|
||||
class_id.hash(&mut hasher);
|
||||
node_id.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
/// Hash a choice set for uniqueness checking. Order-independent XOR
|
||||
/// of per-entry hashes. The XOR design lets `extract_generation`
|
||||
/// update the hash incrementally on each `insert(k, new)` by XORing
|
||||
/// out `hash_choice_entry(k, old)` and XORing in
|
||||
/// `hash_choice_entry(k, new)`, dropping the per-attempt cost from
|
||||
/// O(N log N) over the full choice set to O(M) where M = mutations
|
||||
/// applied. On large e-graphs (e.g. Gemma's ~3.5M-entry choice set)
|
||||
/// that's the difference between ~135 seconds and a few milliseconds
|
||||
/// per generation.
|
||||
pub fn hash_choice_set(choices: &EGraphChoiceSet) -> u64 {
|
||||
let mut h = 0u64;
|
||||
for (k, v) in choices {
|
||||
h ^= hash_choice_entry(k, v);
|
||||
}
|
||||
h
|
||||
}
|
||||
|
||||
/// Extract a generation of mutated offspring from a base genome.
|
||||
///
|
||||
/// Takes a base `EGraphChoiceSet` and produces up to `generation_size` mutated offspring,
|
||||
@@ -1048,25 +1061,38 @@ pub fn extract_generation<'a>(
|
||||
// Limit attempts to avoid infinite loops when search space is exhausted
|
||||
let max_attempts = generation_size * 100;
|
||||
let mut attempts = 0;
|
||||
// Compute the base's full hash exactly once. Each attempt starts from
|
||||
// this and applies XOR diffs for its mutations — no per-attempt
|
||||
// O(N log N) sort+hash over the full choice set.
|
||||
let base_hash = hash_choice_set(base);
|
||||
|
||||
while offspring.len() < generation_size && attempts < max_attempts {
|
||||
attempts += 1;
|
||||
|
||||
// Create a mutated offspring from base
|
||||
let mut child = base.clone();
|
||||
let mut child_hash = base_hash;
|
||||
|
||||
for _ in 0..rng.random_range(1..=mutations_per_generation) {
|
||||
// Pick a random mutable eclass
|
||||
let class_id = mutable_classes[rng.random_range(0..mutable_classes.len())];
|
||||
let (_, enodes) = &egraph.eclasses[class_id];
|
||||
// Pick a random enode for this class
|
||||
child.insert(class_id, &enodes[rng.random_range(0..enodes.len())]);
|
||||
let new_node = &enodes[rng.random_range(0..enodes.len())];
|
||||
// Insert returns the previous binding (if any); fold the diff
|
||||
// into the running hash. If the new pick equals the old one,
|
||||
// the two XORs cancel and `child_hash` is unchanged — exactly
|
||||
// the right behaviour.
|
||||
let old_node = child.insert(class_id, new_node);
|
||||
if let Some(old_node) = old_node {
|
||||
child_hash ^= hash_choice_entry(class_id, old_node);
|
||||
}
|
||||
child_hash ^= hash_choice_entry(class_id, new_node);
|
||||
}
|
||||
|
||||
// Hash and check if seen before
|
||||
let h = hash_choice_set(&child);
|
||||
if !prev_selected.contains(&h) {
|
||||
prev_selected.insert(h);
|
||||
if !prev_selected.contains(&child_hash) {
|
||||
prev_selected.insert(child_hash);
|
||||
offspring.push(child);
|
||||
}
|
||||
}
|
||||
@@ -1104,11 +1130,34 @@ pub fn egglog_to_llir<'a>(
|
||||
list_cache: &mut FxHashMap<&'a NodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a NodeId, Expression>,
|
||||
custom_op_id_remap: Option<&FxHashMap<usize, usize>>,
|
||||
) -> LLIRGraph {
|
||||
egglog_to_llir_from_root(
|
||||
egraph,
|
||||
choices,
|
||||
ops,
|
||||
custom_ops,
|
||||
list_cache,
|
||||
expr_cache,
|
||||
custom_op_id_remap,
|
||||
&egraph.roots[0],
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn egglog_to_llir_from_root<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
choices: EGraphChoiceSet<'a>,
|
||||
ops: &'a Vec<Arc<Box<dyn EgglogOp>>>,
|
||||
custom_ops: &[Box<dyn CustomOp>],
|
||||
list_cache: &mut FxHashMap<&'a NodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a NodeId, Expression>,
|
||||
custom_op_id_remap: Option<&FxHashMap<usize, usize>>,
|
||||
root_class: &ClassId,
|
||||
) -> LLIRGraph {
|
||||
// Make reachability set from root
|
||||
let mut reachable = FxHashSet::default();
|
||||
reachable.insert(choices[&egraph.roots[0]]);
|
||||
let mut reachability_stack = vec![choices[&egraph.roots[0]]];
|
||||
reachable.insert(choices[root_class]);
|
||||
let mut reachability_stack = vec![choices[root_class]];
|
||||
while let Some(r) = reachability_stack.pop() {
|
||||
for ch in &egraph.enodes[r].1 {
|
||||
if egraph.eclasses[ch].0.contains("IR") || egraph.eclasses[ch].0.contains("IList") {
|
||||
@@ -1123,12 +1172,17 @@ pub fn egglog_to_llir<'a>(
|
||||
let mut graph = LLIRGraph::default();
|
||||
let mut edges_to_place = vec![];
|
||||
let mut enode_to_node = FxHashMap::default();
|
||||
for &node in choices.values() {
|
||||
if !reachable.contains(node) {
|
||||
continue;
|
||||
}
|
||||
// Iterate the small reachable set rather than the full choice set.
|
||||
// On large e-graphs (e.g., Gemma's ~3.48M-entry choice set produced
|
||||
// by the binary-fusion grow rules cascading through super-block
|
||||
// chains), `reachable` is ~3K nodes and the choice set is ~1000×
|
||||
// larger. Filtering the choice set against `reachable` was
|
||||
// dominating per-candidate `egglog_to_llir` time.
|
||||
for &node in &reachable {
|
||||
if egraph.eclasses[&egraph.node_to_class[node]].0 != "IR" {
|
||||
// Skip IList / OpKind
|
||||
// Skip IList enodes — `reachable` includes them because the
|
||||
// reachability walk follows IList children, but only IR
|
||||
// enodes become LLIR nodes.
|
||||
continue;
|
||||
}
|
||||
let enode_label = egraph.enodes[node].0.as_str();
|
||||
@@ -1229,135 +1283,10 @@ pub fn egglog_to_llir<'a>(
|
||||
// )
|
||||
// .unwrap();
|
||||
// }
|
||||
// Loop markers (LoopStart/End/Input/InputStatic/Output) are intentionally
|
||||
// preserved here — `crate::graph::collapse_loops_to_first_iter` produces
|
||||
// a single-iteration LLIR for fast per-candidate profiling, and the full
|
||||
// `crate::graph::unroll_loops_in_llir` runs once on the chosen best LLIR
|
||||
// before it is loaded into the runtime.
|
||||
graph
|
||||
}
|
||||
|
||||
/// Merge multiple per-chunk LLIR graphs into a single LLIR graph,
|
||||
/// resolving boundary Input/Output nodes at graph break boundaries.
|
||||
pub fn stitch_llir_graphs(
|
||||
chunk_llirs: &[LLIRGraph],
|
||||
descriptors: &[SubgraphDescriptor],
|
||||
) -> LLIRGraph {
|
||||
use petgraph::stable_graph::NodeIndex;
|
||||
|
||||
let mut merged = LLIRGraph::default();
|
||||
|
||||
// Collect the set of boundary break_node indices for matching
|
||||
let mut boundary_output_set: FxHashSet<usize> = FxHashSet::default();
|
||||
let mut boundary_input_set: FxHashSet<usize> = FxHashSet::default();
|
||||
for desc in descriptors {
|
||||
for brk in &desc.boundary_outputs {
|
||||
boundary_output_set.insert(brk.index());
|
||||
}
|
||||
for bi in &desc.boundary_inputs {
|
||||
boundary_input_set.insert(bi.break_node.index());
|
||||
}
|
||||
}
|
||||
|
||||
// Per-chunk node mapping: old NodeIndex -> new NodeIndex in merged graph
|
||||
let mut node_maps: Vec<FxHashMap<NodeIndex, NodeIndex>> = Vec::with_capacity(chunk_llirs.len());
|
||||
|
||||
// Track boundary producers: break_node_index -> new NodeIndex of the actual producer
|
||||
let mut boundary_producers: FxHashMap<usize, NodeIndex> = FxHashMap::default();
|
||||
|
||||
// Track real Input node deduplication: Input.node -> new NodeIndex
|
||||
let mut real_inputs: FxHashMap<usize, NodeIndex> = FxHashMap::default();
|
||||
|
||||
for (_chunk_idx, chunk_graph) in chunk_llirs.iter().enumerate() {
|
||||
let mut this_map: FxHashMap<NodeIndex, NodeIndex> = FxHashMap::default();
|
||||
|
||||
// Pass 1: Add all non-boundary nodes
|
||||
for old_node in chunk_graph.node_indices() {
|
||||
let op = &chunk_graph[old_node];
|
||||
|
||||
// Check if this is a boundary Output
|
||||
if let Some(output_op) = op.to_op::<Output>() {
|
||||
if boundary_output_set.contains(&output_op.node) {
|
||||
// Skip — will resolve in pass 2
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this is a boundary Input
|
||||
if let Some(input_op) = op.to_op::<Input>() {
|
||||
if boundary_input_set.contains(&input_op.node) {
|
||||
// Skip — will resolve in pass 2
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if this is a real Input that was already added (dedup)
|
||||
if let Some(&existing) = real_inputs.get(&input_op.node) {
|
||||
this_map.insert(old_node, existing);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let new_node = merged.add_node(op.clone());
|
||||
this_map.insert(old_node, new_node);
|
||||
|
||||
// Track real inputs for deduplication
|
||||
if let Some(input_op) = op.to_op::<Input>() {
|
||||
real_inputs.insert(input_op.node, new_node);
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 2: Resolve boundary Output nodes (record the producer)
|
||||
for old_node in chunk_graph.node_indices() {
|
||||
let op = &chunk_graph[old_node];
|
||||
if let Some(output_op) = op.to_op::<Output>() {
|
||||
if boundary_output_set.contains(&output_op.node) {
|
||||
// Find the predecessor (the actual producer)
|
||||
let pred = chunk_graph
|
||||
.neighbors_directed(old_node, petgraph::Direction::Incoming)
|
||||
.next()
|
||||
.expect("Boundary Output must have exactly one input");
|
||||
if let Some(&producer_new) = this_map.get(&pred) {
|
||||
boundary_producers.insert(output_op.node, producer_new);
|
||||
} else {
|
||||
eprintln!(
|
||||
"[stitch] WARNING: chunk {}: boundary Output node={} predecessor {:?} not in this_map!",
|
||||
_chunk_idx,
|
||||
output_op.node,
|
||||
pred.index()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 2b: Resolve boundary Input nodes (map to producer from prior chunk)
|
||||
for old_node in chunk_graph.node_indices() {
|
||||
let op = &chunk_graph[old_node];
|
||||
if let Some(input_op) = op.to_op::<Input>() {
|
||||
if boundary_input_set.contains(&input_op.node) {
|
||||
if let Some(&producer) = boundary_producers.get(&input_op.node) {
|
||||
this_map.insert(old_node, producer);
|
||||
} else {
|
||||
eprintln!(
|
||||
"[stitch] WARNING: chunk {}: boundary Input node={} has no producer in boundary_producers!",
|
||||
_chunk_idx, input_op.node
|
||||
);
|
||||
eprintln!(
|
||||
"[stitch] available producers: {:?}",
|
||||
boundary_producers.keys().collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 3: Add edges (preserving duplicate edges for ops like x*x)
|
||||
for edge in chunk_graph.edge_indices() {
|
||||
let (src, dst) = chunk_graph.edge_endpoints(edge).unwrap();
|
||||
if let (Some(&new_src), Some(&new_dst)) = (this_map.get(&src), this_map.get(&dst)) {
|
||||
if new_src != new_dst {
|
||||
merged.add_edge(new_src, new_dst, ());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
node_maps.push(this_map);
|
||||
}
|
||||
|
||||
merged
|
||||
}
|
||||
|
||||
@@ -403,13 +403,24 @@ impl GraphTensor {
|
||||
DType::Int,
|
||||
"Scatter indexes must have an integer dtype!"
|
||||
);
|
||||
// Pad src_strides with leading zero-strides when src has lower rank
|
||||
// than indexes. A zero stride reads the same src element at every
|
||||
// index position — matches PyTorch's broadcast semantics for
|
||||
// `x[idx] = scalar`. Without this, KernelScatter::compile calls
|
||||
// flatten_strides(index_shape, src_strides) with mismatched lengths
|
||||
// and panics with `assertion `left == right` failed, left: 1 right: 0`.
|
||||
let mut src_strides = self.shape.strides.to_vec();
|
||||
let target_rank = indexes.shape.dims.len();
|
||||
while src_strides.len() < target_rank {
|
||||
src_strides.insert(0, Expression::from(0));
|
||||
}
|
||||
let id = self.graph().add_op(
|
||||
Scatter {
|
||||
dest_shape: dest.shape.dims.to_vec(),
|
||||
dest_strides: dest.shape.strides.to_vec(),
|
||||
index_shape: indexes.shape.dims.to_vec(),
|
||||
index_strides: indexes.shape.strides.to_vec(),
|
||||
src_strides: self.shape.strides.to_vec(),
|
||||
src_strides,
|
||||
},
|
||||
&[dest.id, indexes.id, self.id],
|
||||
);
|
||||
|
||||
@@ -105,6 +105,9 @@ impl GraphTensor {
|
||||
if let Some(gmem) = self.graph().try_get_op_mut::<Input>(self.id) {
|
||||
gmem.dtype = dtype;
|
||||
}
|
||||
if let Some((_, d)) = self.graph().input_meta.get_mut(&self.id) {
|
||||
*d = dtype;
|
||||
}
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,16 +152,6 @@ impl GraphTensor {
|
||||
GraphTensor::from_id(new_id, self.shape.contiguous(), self.graph_ref, self.dtype)
|
||||
}
|
||||
|
||||
pub fn graph_break(self) -> GraphTensor {
|
||||
let new_id = self.graph().add_op(
|
||||
crate::hlir::GraphBreak {
|
||||
input_shape: self.shape,
|
||||
},
|
||||
&[self.id],
|
||||
);
|
||||
GraphTensor::from_id(new_id, self.shape.contiguous(), self.graph_ref, self.dtype)
|
||||
}
|
||||
|
||||
/// Scale so std is 1.0
|
||||
pub fn std_norm<T>(self, axes: impl ToAxes, epsilon: T) -> GraphTensor
|
||||
where
|
||||
@@ -663,7 +653,7 @@ pub(super) mod tests {
|
||||
let mut out: Vec<(NotNan<f32>, usize)> =
|
||||
heap.into_iter().map(|std::cmp::Reverse(t)| t).collect();
|
||||
|
||||
out.sort_unstable_by(|a, b| b.0.cmp(&a.0));
|
||||
out.sort_unstable_by_key(|b| std::cmp::Reverse(b.0));
|
||||
out.into_iter().map(|(_, i)| i).collect()
|
||||
}
|
||||
test_unary(
|
||||
|
||||
2838
src/graph.rs
2838
src/graph.rs
File diff suppressed because it is too large
Load Diff
761
src/hlir.rs
761
src/hlir.rs
@@ -119,6 +119,90 @@ pub fn binary_sort(name: &str) -> SortDef {
|
||||
)
|
||||
}
|
||||
|
||||
/// Generate egglog rewrite rules that union a small rolled `body=1, trips=N`
|
||||
/// single-binary-op loop with its fully-unrolled equivalent in the same
|
||||
/// eclass. Both representations coexist; the cost-based extractor picks
|
||||
/// whichever one downstream patterns prefer — the unrolled form when fusions
|
||||
/// (e.g. GLUMoE GemmaGELU, KernelExp's `direct-exp-fusion`) match through
|
||||
/// the flat chain, the rolled form otherwise. Without these unions, rolling
|
||||
/// a tiny chain blocks the fusion entirely and the extracted graph is
|
||||
/// strictly worse than not rolling.
|
||||
///
|
||||
/// **Register in both `EgglogOp::early_rewrites()` AND `rewrites()`.** The
|
||||
/// driver feeds `early_rewrites` into the early-stage program only and
|
||||
/// `rewrites` into the full-stage program only; we need the unrolled chain
|
||||
/// visible in both stages so early-stage fusion patterns (GLUMoE) AND
|
||||
/// full-stage kernel rewrites (`direct-exp-fusion`) can both match it.
|
||||
///
|
||||
/// Generates 2 rules per iter count (state at body input position 0 vs 1)
|
||||
/// for every `n_iters` in `2..=max_trips`. Larger trips stay rolled-only —
|
||||
/// real transformer-block rolls are body ≫ 1 anyway, and carrying both
|
||||
/// forms beyond a small N adds search-time cost without an upside.
|
||||
///
|
||||
/// Each rule matches the rolled shape `LoopEnd(body)` where `body` is the
|
||||
/// binary op consuming `LoopStart(initial)` and `LoopInput(s0..s_{N-1})`,
|
||||
/// and unions `LoopEnd` with the chain
|
||||
/// `u0 = <kind>(initial, s0); u1 = <kind>(u0, s1); … u_{N-1}`.
|
||||
/// (or symmetric for state at position 1.)
|
||||
pub fn binary_op_unroll_rules(op_kind: &str, max_trips: usize) -> Vec<Rule> {
|
||||
let mut rules = Vec::with_capacity((max_trips.saturating_sub(1)) * 2);
|
||||
for n_iters in 2..=max_trips {
|
||||
for state_pos in 0..2 {
|
||||
rules.push(binary_op_unroll_rule(op_kind, n_iters, state_pos));
|
||||
}
|
||||
}
|
||||
rules
|
||||
}
|
||||
|
||||
fn binary_op_unroll_rule(op_kind: &str, n_iters: usize, state_pos: usize) -> Rule {
|
||||
// Swap (state, per_iter) → (input0, input1) by `state_pos`. Both the
|
||||
// body match pattern and the unrolled chain bodies follow this mapping
|
||||
// so a/b stride positions stay aligned.
|
||||
debug_assert!(state_pos < 2);
|
||||
let order = |state: &str, per_iter: &str| -> String {
|
||||
if state_pos == 0 {
|
||||
format!("(ICons {state} (ICons {per_iter} (INil)))")
|
||||
} else {
|
||||
format!("(ICons {per_iter} (ICons {state} (INil)))")
|
||||
}
|
||||
};
|
||||
let li_sources = (0..n_iters).rev().fold(String::from("(INil)"), |acc, i| {
|
||||
format!("(ICons ?s{i} {acc})")
|
||||
});
|
||||
let chain = (0..n_iters)
|
||||
.map(|i| {
|
||||
let prev = if i == 0 {
|
||||
"?initial".to_string()
|
||||
} else {
|
||||
format!("?u{}", i - 1)
|
||||
};
|
||||
format!(
|
||||
" (let ?u{i} (Op ({op_kind} ?sh ?as ?bs ?os) {}))",
|
||||
order(&prev, &format!("?s{i}"))
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
Rule::raw(format!(
|
||||
"(rule
|
||||
(
|
||||
(= ?ls (LoopStart ?initial ?loop_id ?slot_idx (MNum {n_iters}) ?dt))
|
||||
(= ?li (Op (LoopInput ?loop_id ?stream ?dt) {li_sources}))
|
||||
(= ?body (Op ({op_kind} ?sh ?as ?bs ?os) {body_pat}))
|
||||
(= ?le (LoopEnd ?body ?loop_id ?slot_idx ?dt))
|
||||
)
|
||||
(
|
||||
{chain}
|
||||
(union ?le ?u{last})
|
||||
)
|
||||
:ruleset expr
|
||||
:name \"unroll {op_kind} body trips={n_iters} state={state_pos}\"
|
||||
)",
|
||||
body_pat = order("?ls", "?li"),
|
||||
last = n_iters - 1,
|
||||
))
|
||||
}
|
||||
|
||||
/// Reduce op kind: (shape: EList, iters: Expression, strides: EList, iter_stride: Expression, out_strides: EList), IList: [inp]
|
||||
pub fn reduce_sort(name: &str) -> SortDef {
|
||||
sort(
|
||||
@@ -138,6 +222,12 @@ pub type HLIROps = (
|
||||
Input,
|
||||
Output,
|
||||
CustomOpKind,
|
||||
LoopStart,
|
||||
LoopEnd,
|
||||
LoopInput,
|
||||
LoopInputStatic,
|
||||
LoopOutput,
|
||||
LoopOutputSelect,
|
||||
Constant,
|
||||
Cast,
|
||||
Iota,
|
||||
@@ -336,6 +426,607 @@ impl NativeOp for CustomOpKind {
|
||||
}
|
||||
}
|
||||
|
||||
// --- Loop ops ---------------------------------------------------------------
|
||||
//
|
||||
// Automatic loop-rolling replaces N unrolled copies of a repeating body with
|
||||
// a single body plus structural marker ops. All four ops in one loop share a
|
||||
// `loop_id`. `iters` lives on `LoopStart` only; every other op references the
|
||||
// same loop via `loop_id`.
|
||||
//
|
||||
// LoopStart — one per loop-carried slot; takes the initial value, yields
|
||||
// the current iteration's value into the body.
|
||||
// LoopEnd — mirror of LoopStart; takes the body's final value for the
|
||||
// slot, yields the post-loop value.
|
||||
// LoopInput — OpKind (variable-arity). Takes N input tensors (one per
|
||||
// iteration) and yields the current iteration's tensor.
|
||||
// LoopOutput — OpKind (variable-arity, sink). Takes the body's value + N
|
||||
// target tensors; writes body[i] -> target[i] each iteration.
|
||||
//
|
||||
// Execution semantics and iteration driving live in the runtime compilation
|
||||
// step; these ops just carry the structure through HLIR/egglog/LLIR.
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct LoopStart {
|
||||
pub loop_id: usize,
|
||||
pub slot_idx: usize,
|
||||
pub iters: Expression,
|
||||
pub dtype: DType,
|
||||
}
|
||||
|
||||
impl Display for LoopStart {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"LoopStart(id={}, slot={}, iters={:?}, {})",
|
||||
self.loop_id, self.slot_idx, self.iters, self.dtype
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for LoopStart {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
IR,
|
||||
"LoopStart",
|
||||
&[
|
||||
("inp", IR),
|
||||
("loop_id", I64),
|
||||
("slot_idx", I64),
|
||||
("iters", EXPRESSION),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_from_field_rule(&self.sort(), "dtype")]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
_input_enodes: Vec<&'a ENodeId>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let loop_id = egraph.enodes[kind_children[1]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let slot_idx = egraph.enodes[kind_children[2]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let iters = extract_expr(egraph, kind_children[3], expr_cache).unwrap();
|
||||
let dtype = extract_dtype(egraph, kind_children[4]);
|
||||
(
|
||||
LLIROp::new::<LoopStart>(Box::new(Self {
|
||||
loop_id,
|
||||
slot_idx,
|
||||
iters,
|
||||
dtype,
|
||||
})),
|
||||
vec![kind_children[0]],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl HLIROp for LoopStart {
|
||||
fn to_egglog(&self, inp: &[(NodeIndex, String)]) -> String {
|
||||
format!(
|
||||
"(LoopStart {} {} {} {} ({:?}))",
|
||||
inp[0].1,
|
||||
self.loop_id,
|
||||
self.slot_idx,
|
||||
self.iters.to_egglog(),
|
||||
self.dtype,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeOp for LoopStart {
|
||||
fn execute(&self, _: Vec<&NativeData>, _: &FxHashMap<char, usize>) -> NativeData {
|
||||
unimplemented!("LoopStart is driven by the runtime loop compiler")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct LoopEnd {
|
||||
pub loop_id: usize,
|
||||
pub slot_idx: usize,
|
||||
pub dtype: DType,
|
||||
}
|
||||
|
||||
impl Display for LoopEnd {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"LoopEnd(id={}, slot={}, {})",
|
||||
self.loop_id, self.slot_idx, self.dtype
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for LoopEnd {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
IR,
|
||||
"LoopEnd",
|
||||
&[
|
||||
("inp", IR),
|
||||
("loop_id", I64),
|
||||
("slot_idx", I64),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_from_field_rule(&self.sort(), "dtype")]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
_input_enodes: Vec<&'a ENodeId>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let loop_id = egraph.enodes[kind_children[1]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let slot_idx = egraph.enodes[kind_children[2]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let dtype = extract_dtype(egraph, kind_children[3]);
|
||||
(
|
||||
LLIROp::new::<LoopEnd>(Box::new(Self {
|
||||
loop_id,
|
||||
slot_idx,
|
||||
dtype,
|
||||
})),
|
||||
vec![kind_children[0]],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl HLIROp for LoopEnd {
|
||||
fn to_egglog(&self, inp: &[(NodeIndex, String)]) -> String {
|
||||
format!(
|
||||
"(LoopEnd {} {} {} ({:?}))",
|
||||
inp[0].1, self.loop_id, self.slot_idx, self.dtype,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeOp for LoopEnd {
|
||||
fn execute(&self, _: Vec<&NativeData>, _: &FxHashMap<char, usize>) -> NativeData {
|
||||
unimplemented!("LoopEnd is driven by the runtime loop compiler")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct LoopInput {
|
||||
pub loop_id: usize,
|
||||
pub stream_id: usize,
|
||||
pub dtype: DType,
|
||||
}
|
||||
|
||||
impl Display for LoopInput {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"LoopInput(id={}, stream={}, {})",
|
||||
self.loop_id, self.stream_id, self.dtype
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for LoopInput {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"LoopInput",
|
||||
&[("loop_id", I64), ("stream_id", I64), ("dtype", DTYPE)],
|
||||
)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_from_kind_field(&self.sort(), "dtype")]
|
||||
}
|
||||
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
// Declare the `identical_inputs` relation and the three-way unification
|
||||
// chain between `LoopInput`, `LoopInputStatic`, and an inlined source.
|
||||
// Running in Stage 1 alongside fusion rules (e.g. GLUMoE) so that
|
||||
// fusion patterns that expect raw op kinds at boundary positions can
|
||||
// match via the unioned eclass.
|
||||
vec![Rule::raw(
|
||||
r#"
|
||||
(relation identical_inputs (IList))
|
||||
|
||||
; All four rules live in the `expr` ruleset, which the early/full
|
||||
; schedules saturate each iteration. Default-ruleset scheduling
|
||||
; only runs each rule once per outer step, which is not enough to
|
||||
; propagate `identical_inputs` through an N-element IList.
|
||||
|
||||
; Base: single-element list is trivially identical.
|
||||
(rule ((= ?l (ICons ?x (INil))))
|
||||
((identical_inputs ?l))
|
||||
:ruleset expr
|
||||
:name "identical_inputs base")
|
||||
|
||||
; Inductive: head equals next-head, and the tail starting at next-head is identical.
|
||||
(rule ((= ?l (ICons ?x (ICons ?x ?tail)))
|
||||
(identical_inputs (ICons ?x ?tail)))
|
||||
((identical_inputs ?l))
|
||||
:ruleset expr
|
||||
:name "identical_inputs ind")
|
||||
|
||||
; LoopInput with an identical IList is equivalent to LoopInputStatic over a single copy.
|
||||
(rule ((= ?e (Op (LoopInput ?id ?stream ?dt) (ICons ?x ?cont)))
|
||||
(identical_inputs (ICons ?x ?cont)))
|
||||
((let ?static (Op (LoopInputStatic ?id ?stream ?dt) (ICons ?x (INil))))
|
||||
(union ?e ?static))
|
||||
:ruleset expr
|
||||
:name "LoopInput to LoopInputStatic")
|
||||
|
||||
; LoopInputStatic is equivalent to its single inner value — collapses the boundary
|
||||
; wrapper for pattern-matching and extraction purposes.
|
||||
(rule ((= ?e (Op (LoopInputStatic ?id ?stream ?dt) (ICons ?x (INil)))))
|
||||
((union ?e ?x))
|
||||
:ruleset expr
|
||||
:name "LoopInputStatic inline")
|
||||
"#,
|
||||
)]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let loop_id = egraph.enodes[kind_children[0]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let stream_id = egraph.enodes[kind_children[1]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let dtype = extract_dtype(egraph, kind_children[2]);
|
||||
(
|
||||
LLIROp::new::<LoopInput>(Box::new(Self {
|
||||
loop_id,
|
||||
stream_id,
|
||||
dtype,
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl HLIROp for LoopInput {
|
||||
fn to_egglog(&self, inp: &[(NodeIndex, String)]) -> String {
|
||||
format!(
|
||||
"(Op (LoopInput {} {} ({:?})) {})",
|
||||
self.loop_id,
|
||||
self.stream_id,
|
||||
self.dtype,
|
||||
list_to_egglog(&inp.iter().map(|i| &i.1).collect_vec(), "ICons", "INil"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeOp for LoopInput {
|
||||
fn execute(&self, _: Vec<&NativeData>, _: &FxHashMap<char, usize>) -> NativeData {
|
||||
unimplemented!("LoopInput is driven by the runtime loop compiler")
|
||||
}
|
||||
}
|
||||
|
||||
/// Iteration-independent boundary input: the same value flows into every
|
||||
/// iteration of a loop. Structurally a `LoopInput` whose per-iteration
|
||||
/// sources have all been proven equal (via the `identical_inputs` egglog
|
||||
/// relation) collapses into `LoopInputStatic` with a single-element IList,
|
||||
/// and that in turn collapses via a further rewrite into just its inner
|
||||
/// value — so egglog search can explore any of the three representations.
|
||||
/// At unroll time `LoopInputStatic` lowers to a plain edge: every cloned
|
||||
/// body node in every iteration references the single shared source.
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct LoopInputStatic {
|
||||
pub loop_id: usize,
|
||||
pub stream_id: usize,
|
||||
pub dtype: DType,
|
||||
}
|
||||
|
||||
impl Display for LoopInputStatic {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"LoopInputStatic(id={}, stream={}, {})",
|
||||
self.loop_id, self.stream_id, self.dtype
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for LoopInputStatic {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"LoopInputStatic",
|
||||
&[("loop_id", I64), ("stream_id", I64), ("dtype", DTYPE)],
|
||||
)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_from_kind_field(&self.sort(), "dtype")]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let loop_id = egraph.enodes[kind_children[0]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let stream_id = egraph.enodes[kind_children[1]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let dtype = extract_dtype(egraph, kind_children[2]);
|
||||
(
|
||||
LLIROp::new::<LoopInputStatic>(Box::new(Self {
|
||||
loop_id,
|
||||
stream_id,
|
||||
dtype,
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl HLIROp for LoopInputStatic {
|
||||
fn to_egglog(&self, inp: &[(NodeIndex, String)]) -> String {
|
||||
format!(
|
||||
"(Op (LoopInputStatic {} {} ({:?})) {})",
|
||||
self.loop_id,
|
||||
self.stream_id,
|
||||
self.dtype,
|
||||
list_to_egglog(&inp.iter().map(|i| &i.1).collect_vec(), "ICons", "INil"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeOp for LoopInputStatic {
|
||||
fn execute(&self, _: Vec<&NativeData>, _: &FxHashMap<char, usize>) -> NativeData {
|
||||
unimplemented!("LoopInputStatic is driven by the runtime loop compiler")
|
||||
}
|
||||
}
|
||||
|
||||
/// Marker for the per-iter output stream of a rolled loop. Mirrors `LoopInput`
|
||||
/// in reverse: a single body producer (one incoming edge) feeds the marker, and
|
||||
/// `LoopOutputSelect(i)` nodes hang off it to pluck iteration `i`'s value for
|
||||
/// downstream consumers (any post-region op — `Output` HLIR, downstream
|
||||
/// computation, etc.).
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct LoopOutput {
|
||||
pub loop_id: usize,
|
||||
pub stream_id: usize,
|
||||
pub dtype: DType,
|
||||
}
|
||||
|
||||
impl Display for LoopOutput {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"LoopOutput(id={}, stream={}, {})",
|
||||
self.loop_id, self.stream_id, self.dtype
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for LoopOutput {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"LoopOutput",
|
||||
&[("loop_id", I64), ("stream_id", I64), ("dtype", DTYPE)],
|
||||
)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_from_kind_field(&self.sort(), "dtype")]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let loop_id = egraph.enodes[kind_children[0]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let stream_id = egraph.enodes[kind_children[1]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let dtype = extract_dtype(egraph, kind_children[2]);
|
||||
(
|
||||
LLIROp::new::<LoopOutput>(Box::new(Self {
|
||||
loop_id,
|
||||
stream_id,
|
||||
dtype,
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl HLIROp for LoopOutput {
|
||||
fn to_egglog(&self, inp: &[(NodeIndex, String)]) -> String {
|
||||
format!(
|
||||
"(Op (LoopOutput {} {} ({:?})) {})",
|
||||
self.loop_id,
|
||||
self.stream_id,
|
||||
self.dtype,
|
||||
list_to_egglog(&inp.iter().map(|i| &i.1).collect_vec(), "ICons", "INil"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeOp for LoopOutput {
|
||||
fn execute(&self, _: Vec<&NativeData>, _: &FxHashMap<char, usize>) -> NativeData {
|
||||
unimplemented!("LoopOutput is driven by the runtime loop compiler")
|
||||
}
|
||||
}
|
||||
|
||||
/// Per-iteration extractor for a `LoopOutput` stream. Mirrors a per-iter
|
||||
/// `LoopInput` source slot in reverse: every cross-region edge that originally
|
||||
/// went from iteration `i`'s body producer to a post-region consumer is
|
||||
/// rewired through `LoopOutputSelect { iter: i, ... }`. At unroll time
|
||||
/// `Select(i)` lowers to the iter-`i` body clone's producer; at collapse time
|
||||
/// every Select lowers to iter-0's producer.
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct LoopOutputSelect {
|
||||
pub loop_id: usize,
|
||||
pub stream_id: usize,
|
||||
pub iter: usize,
|
||||
pub dtype: DType,
|
||||
}
|
||||
|
||||
impl Display for LoopOutputSelect {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"LoopOutputSelect(id={}, stream={}, iter={}, {})",
|
||||
self.loop_id, self.stream_id, self.iter, self.dtype
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for LoopOutputSelect {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"LoopOutputSelect",
|
||||
&[
|
||||
("loop_id", I64),
|
||||
("stream_id", I64),
|
||||
("iter", I64),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_from_kind_field(&self.sort(), "dtype")]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let loop_id = egraph.enodes[kind_children[0]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let stream_id = egraph.enodes[kind_children[1]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let iter = egraph.enodes[kind_children[2]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let dtype = extract_dtype(egraph, kind_children[3]);
|
||||
(
|
||||
LLIROp::new::<LoopOutputSelect>(Box::new(Self {
|
||||
loop_id,
|
||||
stream_id,
|
||||
iter,
|
||||
dtype,
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl HLIROp for LoopOutputSelect {
|
||||
fn to_egglog(&self, inp: &[(NodeIndex, String)]) -> String {
|
||||
format!(
|
||||
"(Op (LoopOutputSelect {} {} {} ({:?})) {})",
|
||||
self.loop_id,
|
||||
self.stream_id,
|
||||
self.iter,
|
||||
self.dtype,
|
||||
list_to_egglog(&inp.iter().map(|i| &i.1).collect_vec(), "ICons", "INil"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeOp for LoopOutputSelect {
|
||||
fn execute(&self, _: Vec<&NativeData>, _: &FxHashMap<char, usize>) -> NativeData {
|
||||
unimplemented!("LoopOutputSelect is driven by the runtime loop compiler")
|
||||
}
|
||||
}
|
||||
|
||||
/// Produces a single number constant from an expression or a float
|
||||
#[derive(Clone, PartialEq, Default)]
|
||||
pub struct Constant(pub f32);
|
||||
@@ -555,28 +1246,6 @@ impl NativeOp for Cast {
|
||||
}
|
||||
}
|
||||
|
||||
/// Graph break for chunking search graphs
|
||||
#[derive(Clone, PartialEq, Default)]
|
||||
pub struct GraphBreak {
|
||||
pub input_shape: ShapeTracker,
|
||||
}
|
||||
impl Debug for GraphBreak {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "GraphBreak")
|
||||
}
|
||||
}
|
||||
impl Display for GraphBreak {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "GraphBreak")
|
||||
}
|
||||
}
|
||||
|
||||
impl HLIROp for GraphBreak {
|
||||
fn to_egglog(&self, _: &[(NodeIndex, String)]) -> String {
|
||||
panic!("Cannot turn GraphBreak into egglog op!");
|
||||
}
|
||||
}
|
||||
|
||||
// Unary Op (A -> A)
|
||||
|
||||
fn unary_impl(
|
||||
@@ -1009,7 +1678,12 @@ impl EgglogOp for Add {
|
||||
2
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_propagation_op(&self.sort())]
|
||||
let mut r = vec![dtype_propagation_op(&self.sort())];
|
||||
r.extend(self.early_rewrites());
|
||||
r
|
||||
}
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
binary_op_unroll_rules("Add", 4)
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
@@ -1094,7 +1768,12 @@ impl EgglogOp for Mul {
|
||||
2
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_propagation_op(&self.sort())]
|
||||
let mut r = vec![dtype_propagation_op(&self.sort())];
|
||||
r.extend(self.early_rewrites());
|
||||
r
|
||||
}
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
binary_op_unroll_rules("Mul", 4)
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
@@ -1179,7 +1858,12 @@ impl EgglogOp for Mod {
|
||||
2
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_propagation_op(&self.sort())]
|
||||
let mut r = vec![dtype_propagation_op(&self.sort())];
|
||||
r.extend(self.early_rewrites());
|
||||
r
|
||||
}
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
binary_op_unroll_rules("Mod", 4)
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
@@ -1264,8 +1948,13 @@ impl EgglogOp for LessThan {
|
||||
2
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// Comparison operations always output Bool
|
||||
vec![dtype_fixed_op(&self.sort(), &SORTS.bool_dt)]
|
||||
// Comparisons output Bool, not the input dtype.
|
||||
let mut r = vec![dtype_fixed_op(&self.sort(), &SORTS.bool_dt)];
|
||||
r.extend(self.early_rewrites());
|
||||
r
|
||||
}
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
binary_op_unroll_rules("LessThan", 4)
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
@@ -2200,6 +2889,10 @@ impl Runtime for NativeRuntime {
|
||||
(0, "0 ms".to_string())
|
||||
}
|
||||
|
||||
fn aggregate_profile_metrics(metrics: &[Self::ProfileMetric]) -> Self::ProfileMetric {
|
||||
metrics.iter().copied().sum()
|
||||
}
|
||||
|
||||
fn load_llir(&mut self, llir_graph: &LLIRGraph) {
|
||||
// Extract nativeop graph
|
||||
let mut graph = StableGraph::new();
|
||||
@@ -2253,13 +2946,19 @@ impl Runtime for NativeRuntime {
|
||||
self.buffers.insert(node, output);
|
||||
}
|
||||
|
||||
// Consume all non-Output buffers (inputs + intermediates)
|
||||
let output_nodes: FxHashSet<NodeIndex> = self
|
||||
// Free intermediate computation buffers; keep Input (weights/user data) and Output nodes.
|
||||
// Keeping Input buffers allows the graph to be called multiple times without re-loading
|
||||
// weights. User inputs are re-set before each call via set_data, so stale values are
|
||||
// overwritten. Weight inputs are set once and must survive across calls.
|
||||
let keep_nodes: FxHashSet<NodeIndex> = self
|
||||
.graph
|
||||
.node_indices()
|
||||
.filter(|n| (**self.graph[*n]).as_any().is::<Output>())
|
||||
.filter(|n| {
|
||||
(**self.graph[*n]).as_any().is::<Output>()
|
||||
|| (**self.graph[*n]).as_any().is::<Input>()
|
||||
})
|
||||
.collect();
|
||||
self.buffers.retain(|k, _| output_nodes.contains(k));
|
||||
self.buffers.retain(|k, _| keep_nodes.contains(k));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod dtype;
|
||||
pub mod dyn_backend;
|
||||
pub mod egglog_utils;
|
||||
pub mod frontend;
|
||||
pub mod graph;
|
||||
|
||||
16
src/op.rs
16
src/op.rs
@@ -21,6 +21,16 @@ pub trait Runtime {
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
trials: usize,
|
||||
) -> (Self::ProfileMetric, String);
|
||||
/// Aggregate multiple profile metrics into one comparable metric.
|
||||
/// Used for regionalized profiling where one candidate maps to multiple LLIR regions.
|
||||
fn aggregate_profile_metrics(metrics: &[Self::ProfileMetric]) -> Self::ProfileMetric {
|
||||
metrics
|
||||
.first()
|
||||
.unwrap_or_else(|| panic!("aggregate_profile_metrics called with empty metrics"))
|
||||
.clone()
|
||||
}
|
||||
/// Optional per-candidate profiling timeout used by search.
|
||||
fn set_profile_timeout(&mut self, _timeout: Option<std::time::Duration>) {}
|
||||
/// Allocate a dummy input buffer for a boundary node during per-chunk profiling.
|
||||
/// `node_index` is the HLIR node index used in the Input op's `node` field.
|
||||
/// `num_bytes` is the number of bytes to allocate.
|
||||
@@ -226,7 +236,11 @@ impl LLIROp {
|
||||
assert!(
|
||||
op.type_name().contains("dyn")
|
||||
|| op.type_name().contains("Input")
|
||||
|| op.type_name().contains("Output"),
|
||||
|| op.type_name().contains("Output")
|
||||
|| op.type_name().contains("LoopStart")
|
||||
|| op.type_name().contains("LoopEnd")
|
||||
|| op.type_name().contains("LoopInput")
|
||||
|| op.type_name().contains("LoopOutput"),
|
||||
"op types must be erased into dialect traits for dialect casting to work!"
|
||||
);
|
||||
Self(Arc::new(Box::new(DialectOp::new(op))))
|
||||
|
||||
@@ -7,7 +7,17 @@ pub use tracker::*;
|
||||
use std::ops::{Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeTo, RangeToInclusive};
|
||||
|
||||
pub fn flatten_strides(range: &[Expression], strides: &[Expression]) -> Expression {
|
||||
assert_eq!(range.len(), strides.len());
|
||||
assert_eq!(
|
||||
range.len(),
|
||||
strides.len(),
|
||||
"flatten_strides: shape and strides must have matching dimensionality \
|
||||
(got shape len {}, strides len {}). This typically means an HLIR op \
|
||||
was constructed or extracted with mismatched fields — common culprit \
|
||||
is a Scatter / Gather kernel whose index_strides or src_strides list \
|
||||
wasn't populated alongside index_shape.",
|
||||
range.len(),
|
||||
strides.len(),
|
||||
);
|
||||
let mut current_elem_size = Expression::from(1);
|
||||
let mut flat_stride = Expression::from(0);
|
||||
for (dim, (range, stride)) in range.iter().zip(strides).enumerate().rev() {
|
||||
|
||||
@@ -485,3 +485,56 @@ fn test_only_outputs_remain() {
|
||||
.count();
|
||||
assert_eq!(rt.buffers.len(), output_count);
|
||||
}
|
||||
|
||||
fn build_repeated_block_graph(
|
||||
layers: usize,
|
||||
width: usize,
|
||||
) -> (Graph, NodeIndex, Vec<NodeIndex>, NodeIndex) {
|
||||
let mut cx = Graph::new();
|
||||
let x = cx.tensor(width);
|
||||
let mut state = x;
|
||||
let mut weight_nodes = Vec::with_capacity(layers * 2);
|
||||
for i in 0..layers {
|
||||
let w = cx.named_tensor(format!("w_{i}"), width);
|
||||
let b = cx.named_tensor(format!("b_{i}"), width);
|
||||
weight_nodes.push(w.id);
|
||||
weight_nodes.push(b.id);
|
||||
state = ((state * w) + b).sin();
|
||||
}
|
||||
let y = state.output();
|
||||
(cx, x.id, weight_nodes, y.id)
|
||||
}
|
||||
|
||||
fn repeated_block_reference(layers: usize, input: &[f32], weights: &[Vec<f32>]) -> Vec<f32> {
|
||||
let mut state = input.to_vec();
|
||||
for i in 0..layers {
|
||||
let w = &weights[i * 2];
|
||||
let b = &weights[i * 2 + 1];
|
||||
for ((s, wi), bi) in state.iter_mut().zip(w.iter()).zip(b.iter()) {
|
||||
*s = (*s * *wi + *bi).sin();
|
||||
}
|
||||
}
|
||||
state
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn integration_auto_loop_rolling_matches_reference_native_runtime() {
|
||||
let layers = 12;
|
||||
let width = 16;
|
||||
let input = random_vec(width);
|
||||
let weights: Vec<Vec<f32>> = (0..layers * 2).map(|_| random_vec(width)).collect();
|
||||
|
||||
let reference = repeated_block_reference(layers, &input, &weights);
|
||||
|
||||
let (mut graph, input_id, weight_ids, output_id) = build_repeated_block_graph(layers, width);
|
||||
graph.build_search_space::<NativeRuntime>();
|
||||
let mut rt = graph.search(NativeRuntime::default(), 1);
|
||||
rt.set_data(input_id, input);
|
||||
for (node, data) in weight_ids.iter().zip(weights.iter()) {
|
||||
rt.set_data(*node, data.clone());
|
||||
}
|
||||
rt.execute(&graph.dyn_map);
|
||||
let out = rt.get_f32(output_id);
|
||||
|
||||
assert_close(&reference, out);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user