mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
6 Commits
codex/rust
...
codex-lumi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
79d00a4827 | ||
|
|
acad3a625a | ||
|
|
07ad11d101 | ||
|
|
98f4f2102b | ||
|
|
896c4b7c7e | ||
|
|
0134aa425a |
4
.github/workflows/modal-examples.yml
vendored
4
.github/workflows/modal-examples.yml
vendored
@@ -18,11 +18,11 @@ jobs:
|
||||
name: "${{ matrix.example }} (Modal ${{ matrix.gpu.type }})"
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 120
|
||||
timeout-minutes: 70
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
example: [llama, gemma, qwen, qwen3_moe, gemma4_moe, whisper]
|
||||
example: [llama, gemma, qwen, qwen3_moe]
|
||||
gpu:
|
||||
- { type: "A100-80GB" }
|
||||
# To add more GPUs, just append another entry:
|
||||
|
||||
2
.github/workflows/test-core.yml
vendored
2
.github/workflows/test-core.yml
vendored
@@ -21,4 +21,4 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run tests
|
||||
run: cargo test --release --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --verbose
|
||||
run: cargo test --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --verbose
|
||||
|
||||
2
.github/workflows/test-cuda.yml
vendored
2
.github/workflows/test-cuda.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
name: Cuda Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 120
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
2
.github/workflows/test-metal.yml
vendored
2
.github/workflows/test-metal.yml
vendored
@@ -16,4 +16,4 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run Metal crate tests
|
||||
run: rustup update; cargo test --release -p luminal_metal --verbose -- --test-threads=1
|
||||
run: rustup update; cargo test -p luminal_metal --verbose -- --test-threads=1
|
||||
|
||||
4
.github/workflows/test-python-cuda.yml
vendored
4
.github/workflows/test-python-cuda.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
name: Python CUDA Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 120
|
||||
timeout-minutes: 60
|
||||
defaults:
|
||||
run:
|
||||
working-directory: crates/luminal_python
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: modal run modal_pytest_runner.py --gpu A100 --timeout 7200 --profile --profile-output-dir luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }} tests/ -v -s -m "not slow"
|
||||
run: modal run modal_pytest_runner.py --gpu A100 --timeout 3300 --profile --profile-output-dir luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }} tests/ -v -s -m "not slow"
|
||||
- name: Upload Modal pytest profiling artifacts
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
|
||||
2
.github/workflows/test-python-native.yml
vendored
2
.github/workflows/test-python-native.yml
vendored
@@ -23,6 +23,6 @@ jobs:
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- name: Build maturin extension
|
||||
run: uv run maturin develop --manifest-path rust/Cargo.toml --profile release
|
||||
run: uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
- name: Run pytest
|
||||
run: uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v -m "not slow"
|
||||
|
||||
@@ -25,7 +25,6 @@ generational-box = "0.5.6"
|
||||
serde_json = "1.0.140"
|
||||
egglog = {git="https://github.com/egraphs-good/egglog", rev="0a8cc35a6c68d0460c20449d5fa19ca3caba2923"}
|
||||
egglog-ast = {git="https://github.com/egraphs-good/egglog", rev="0a8cc35a6c68d0460c20449d5fa19ca3caba2923"}
|
||||
egglog-reports = {git="https://github.com/egraphs-good/egglog", rev="0a8cc35a6c68d0460c20449d5fa19ca3caba2923"}
|
||||
egraph-serialize = { version = "0.3.0", default-features = false, features = ["graphviz", "serde"]}
|
||||
tracing = "0.1.43"
|
||||
paste = "1.0.15"
|
||||
|
||||
@@ -28,7 +28,7 @@ cuda_image = (
|
||||
@app.function(
|
||||
image=cuda_image,
|
||||
gpu=gpu_type,
|
||||
timeout=7200, # 2 hours
|
||||
timeout=1800, # 30 minutes
|
||||
)
|
||||
def run_cargo_test():
|
||||
"""Run cargo test for luminal_cuda_lite on a Modal GPU."""
|
||||
@@ -47,7 +47,6 @@ def run_cargo_test():
|
||||
[
|
||||
"cargo",
|
||||
"test",
|
||||
"--release",
|
||||
"-p",
|
||||
"luminal_cuda_lite",
|
||||
"--verbose",
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import modal
|
||||
import subprocess
|
||||
import os
|
||||
|
||||
example = os.environ.get("EXAMPLE", "llama")
|
||||
gpu_type = os.environ.get("GPU_TYPE", "A100-80GB")
|
||||
@@ -21,79 +18,6 @@ hf_cache = modal.Volume.from_name(
|
||||
|
||||
WORKDIR = "/workspace/luminal"
|
||||
|
||||
ANSI_ESCAPE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")
|
||||
|
||||
EXPECTED_OUTPUT = {
|
||||
"llama": [
|
||||
"complex system modeled after the structure and function of the human brain",
|
||||
],
|
||||
"gemma": [
|
||||
"recognize pictures of cats",
|
||||
"little detectives looking for specific features",
|
||||
],
|
||||
"qwen": [
|
||||
"computational model inspired by the structure and function of the human brain",
|
||||
],
|
||||
"qwen3_moe": [
|
||||
"The capital of France is Paris",
|
||||
],
|
||||
"gemma4_moe": [
|
||||
"city of romance, art and culture",
|
||||
],
|
||||
"whisper": [
|
||||
"ask not what your country can do for you",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def run_and_capture(command: list[str], *, cwd: str, env: dict[str, str]) -> str:
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert process.stdout is not None
|
||||
|
||||
chunks = []
|
||||
while True:
|
||||
chunk = process.stdout.read1(4096)
|
||||
if not chunk:
|
||||
break
|
||||
sys.stdout.buffer.write(chunk)
|
||||
sys.stdout.buffer.flush()
|
||||
chunks.append(chunk)
|
||||
|
||||
return_code = process.wait()
|
||||
output = b"".join(chunks).decode("utf-8", errors="replace")
|
||||
if return_code:
|
||||
raise subprocess.CalledProcessError(return_code, command, output=output)
|
||||
return output
|
||||
|
||||
|
||||
def normalize_output(output: str) -> str:
|
||||
output = ANSI_ESCAPE.sub("", output)
|
||||
output = output.replace("\r", "\n")
|
||||
return re.sub(r"\s+", " ", output).casefold()
|
||||
|
||||
|
||||
def validate_output(example: str, output: str):
|
||||
expected_phrases = EXPECTED_OUTPUT.get(example)
|
||||
if expected_phrases is None:
|
||||
raise ValueError(f"No expected output phrases configured for example {example!r}")
|
||||
|
||||
normalized_output = normalize_output(output)
|
||||
for phrase in expected_phrases:
|
||||
if normalize_output(phrase) in normalized_output:
|
||||
print(f"\nOutput check passed for {example!r}: found {phrase!r}")
|
||||
return
|
||||
|
||||
expected = "\n - ".join(expected_phrases)
|
||||
raise AssertionError(
|
||||
f"Output check failed for {example!r}. Expected one of:\n - {expected}"
|
||||
)
|
||||
|
||||
cuda_image = (
|
||||
modal.Image.from_registry(
|
||||
"nvcr.io/nvidia/pytorch:25.03-py3"
|
||||
@@ -115,7 +39,7 @@ cuda_image = (
|
||||
@app.function(
|
||||
image=cuda_image,
|
||||
gpu=gpu_type,
|
||||
timeout=7200, # 2 hours
|
||||
timeout=3600, # 60 minutes
|
||||
volumes={
|
||||
HF_CACHE_PATH: hf_cache,
|
||||
},
|
||||
@@ -124,17 +48,16 @@ def run_example(example: str):
|
||||
"""Build and run a luminal example on a Modal GPU."""
|
||||
subprocess.run(["nvidia-smi"], check=True)
|
||||
|
||||
run_env = {
|
||||
**os.environ,
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
"HF_HOME": HF_CACHE_PATH,
|
||||
}
|
||||
output = run_and_capture(
|
||||
subprocess.run(
|
||||
["cargo", "run", "--release"],
|
||||
cwd=f"{WORKDIR}/examples/{example}",
|
||||
env=run_env,
|
||||
env={
|
||||
**os.environ,
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
"HF_HOME": HF_CACHE_PATH,
|
||||
},
|
||||
check=True,
|
||||
)
|
||||
validate_output(example, output)
|
||||
|
||||
hf_cache.commit()
|
||||
|
||||
|
||||
@@ -10,8 +10,7 @@ license = "MIT OR Apache-2.0"
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_tracing = { path = "../luminal_tracing" }
|
||||
cudarc = {version="0.19.4", features=["cuda-version-from-build-system", "fallback-latest"]}
|
||||
anyhow = "1.0"
|
||||
cudarc = {version="0.18.2", features=["cuda-version-from-build-system", "fallback-latest"]}
|
||||
as-any = "0.3.2"
|
||||
itertools = "0.12.1"
|
||||
fixedbitset = "0.5.7"
|
||||
@@ -24,7 +23,6 @@ memmap2 = "0.9.9"
|
||||
uuid = {version="1.19.0", features=["v4"]}
|
||||
lru = "0.16.2"
|
||||
libc = "0.2"
|
||||
libloading = "0.8"
|
||||
colorize = "*"
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -1,611 +0,0 @@
|
||||
use std::{collections::BTreeMap, sync::Arc, time::Instant};
|
||||
|
||||
use itertools::Itertools;
|
||||
use luminal::prelude::egglog::{ast::Span, prelude::RustSpan};
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
egglog_utils::{
|
||||
base::{base_cleanup_egglog, base_expression_egglog},
|
||||
hlir_to_egglog,
|
||||
},
|
||||
hlir::HLIROps,
|
||||
op::{EgglogOp, IntoEgglogOp, Runtime},
|
||||
prelude::*,
|
||||
shape::Expression,
|
||||
};
|
||||
use luminal_cuda_lite::runtime::CudaRuntime;
|
||||
|
||||
const DEFAULT_PASSES: usize = 256;
|
||||
const EGGLOG_RULESETS: &[&str] = &[
|
||||
"matmul_flatten",
|
||||
"kernel_lower",
|
||||
"direct_kernel",
|
||||
"kernel_specialize",
|
||||
"buffer_reuse",
|
||||
"matmul_backend",
|
||||
"glumoe",
|
||||
"fusion_pair",
|
||||
"fusion_grow",
|
||||
"fusion_merge",
|
||||
];
|
||||
const MOE_SEQ: usize = 2;
|
||||
const MOE_HIDDEN: usize = 16;
|
||||
const MOE_NUM_EXPERTS: usize = 8;
|
||||
const MOE_TOP_K: usize = 2;
|
||||
const MOE_INTERMEDIATE: usize = 6;
|
||||
const GEMMA_RMS_NORM_EPS: f32 = 1e-6;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum Backend {
|
||||
Native,
|
||||
Cuda,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum Mode {
|
||||
Current,
|
||||
Steps,
|
||||
FullDefault,
|
||||
FullCycle,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum Case {
|
||||
Mul,
|
||||
UnaryChain(usize),
|
||||
Gelu,
|
||||
Softmax,
|
||||
LayerNorm,
|
||||
Matmul,
|
||||
Attention,
|
||||
QwenMoe,
|
||||
GemmaMoe,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Args {
|
||||
backend: Backend,
|
||||
mode: Mode,
|
||||
case: Case,
|
||||
passes: usize,
|
||||
cleanup: bool,
|
||||
skip_roll: bool,
|
||||
}
|
||||
|
||||
fn parse_args() -> Args {
|
||||
let mut args = Args {
|
||||
backend: Backend::Cuda,
|
||||
mode: Mode::Current,
|
||||
case: Case::Gelu,
|
||||
passes: DEFAULT_PASSES,
|
||||
cleanup: true,
|
||||
skip_roll: false,
|
||||
};
|
||||
|
||||
let mut iter = std::env::args().skip(1);
|
||||
while let Some(arg) = iter.next() {
|
||||
match arg.as_str() {
|
||||
"--backend" => {
|
||||
args.backend = match iter.next().as_deref() {
|
||||
Some("native") => Backend::Native,
|
||||
Some("cuda") => Backend::Cuda,
|
||||
other => panic!("invalid --backend {other:?}; use native|cuda"),
|
||||
};
|
||||
}
|
||||
"--mode" => {
|
||||
args.mode = match iter.next().as_deref() {
|
||||
Some("current") => Mode::Current,
|
||||
Some("steps") => Mode::Steps,
|
||||
Some("full-default") => Mode::FullDefault,
|
||||
Some("full-cycle") => Mode::FullCycle,
|
||||
other => panic!(
|
||||
"invalid --mode {other:?}; use current|steps|full-default|full-cycle"
|
||||
),
|
||||
};
|
||||
}
|
||||
"--case" => {
|
||||
args.case = parse_case(&iter.next().expect("missing --case value"));
|
||||
}
|
||||
"--passes" => {
|
||||
args.passes = iter
|
||||
.next()
|
||||
.expect("missing --passes value")
|
||||
.parse()
|
||||
.expect("invalid --passes value");
|
||||
}
|
||||
"--no-cleanup" => args.cleanup = false,
|
||||
"--skip-roll" => args.skip_roll = true,
|
||||
"--help" | "-h" => {
|
||||
println!(
|
||||
"Usage: egglog_saturation [OPTIONS]\n\
|
||||
\n\
|
||||
Options:\n\
|
||||
--backend native|cuda default: cuda\n\
|
||||
--mode current|steps|full-default|full-cycle\n\
|
||||
--case mul|unary-chain:N|gelu|softmax|layer-norm|matmul|attention|qwen-moe|gemma-moe\n\
|
||||
--passes N default: 256\n\
|
||||
--no-cleanup omit backend/HLIR cleanup rules\n\
|
||||
--skip-roll skip auto loop rolling prepass"
|
||||
);
|
||||
std::process::exit(0);
|
||||
}
|
||||
other => panic!("unknown argument {other}; use --help"),
|
||||
}
|
||||
}
|
||||
|
||||
args
|
||||
}
|
||||
|
||||
fn parse_case(s: &str) -> Case {
|
||||
if let Some(n) = s.strip_prefix("unary-chain:") {
|
||||
return Case::UnaryChain(n.parse().expect("invalid unary-chain length"));
|
||||
}
|
||||
match s {
|
||||
"mul" => Case::Mul,
|
||||
"gelu" => Case::Gelu,
|
||||
"softmax" => Case::Softmax,
|
||||
"layer-norm" | "layer_norm" => Case::LayerNorm,
|
||||
"matmul" => Case::Matmul,
|
||||
"attention" => Case::Attention,
|
||||
"qwen-moe" | "qwen_moe" => Case::QwenMoe,
|
||||
"gemma-moe" | "gemma_moe" => Case::GemmaMoe,
|
||||
other => panic!("unknown case {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_case(case: Case) -> Graph {
|
||||
let mut cx = Graph::new();
|
||||
let out = match case {
|
||||
Case::Mul => {
|
||||
let x = cx.tensor((64, 64));
|
||||
x * x
|
||||
}
|
||||
Case::UnaryChain(n) => {
|
||||
let mut x = cx.tensor((64, 64));
|
||||
for i in 0..n {
|
||||
x = match i % 6 {
|
||||
0 => x.sin(),
|
||||
1 => x.sqrt(),
|
||||
2 => x.reciprocal(),
|
||||
3 => x.exp2(),
|
||||
4 => x.log2(),
|
||||
_ => x * 1.125,
|
||||
};
|
||||
}
|
||||
x
|
||||
}
|
||||
Case::Gelu => cx.tensor((64, 64)).gelu(),
|
||||
Case::Softmax => cx.tensor((128, 128)).softmax(1),
|
||||
Case::LayerNorm => cx.tensor((128, 128)).layer_norm(1, 1e-5),
|
||||
Case::Matmul => {
|
||||
let a = cx.tensor((32, 64));
|
||||
let b = cx.tensor((64, 32));
|
||||
a.matmul(b)
|
||||
}
|
||||
Case::Attention => {
|
||||
let q = cx.tensor((64, 32));
|
||||
let k = cx.tensor((64, 32));
|
||||
let v = cx.tensor((64, 32));
|
||||
let scores = q.matmul(k.permute((1, 0))) * (1.0 / 32.0_f32.sqrt());
|
||||
scores.softmax(1).matmul(v)
|
||||
}
|
||||
Case::QwenMoe => build_qwen_moe(&mut cx),
|
||||
Case::GemmaMoe => build_gemma_moe(&mut cx),
|
||||
};
|
||||
let _ = out.output();
|
||||
cx
|
||||
}
|
||||
|
||||
fn build_qwen_moe(cx: &mut Graph) -> GraphTensor {
|
||||
cx.set_dim('s', MOE_SEQ);
|
||||
let x = cx.tensor(('s', MOE_HIDDEN));
|
||||
let router = cx.tensor((MOE_NUM_EXPERTS, MOE_HIDDEN));
|
||||
let gate_up_weights = cx
|
||||
.tensor((MOE_NUM_EXPERTS, MOE_INTERMEDIATE * 2, MOE_HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((MOE_NUM_EXPERTS, MOE_HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = x.dims().len();
|
||||
let e_dim = *router.dims().first().unwrap();
|
||||
let k_expr = Expression::from(MOE_TOP_K);
|
||||
|
||||
let routing_weights = x.matmul(router.t()).softmax(n - 1);
|
||||
let top_k_indices = routing_weights.topk_indexes(MOE_TOP_K, n - 1);
|
||||
let row_offsets = x
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
|
||||
let gate_up_gathered = gather_experts(x, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = x.expand_dim(n - 1, MOE_TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gate.silu() * up;
|
||||
|
||||
let down_gathered = gather_experts(x, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
|
||||
fn build_gemma_moe(cx: &mut Graph) -> GraphTensor {
|
||||
cx.set_dim('s', MOE_SEQ);
|
||||
let router_input = cx.tensor(('s', MOE_HIDDEN));
|
||||
let expert_input = cx.tensor(('s', MOE_HIDDEN));
|
||||
let router_scale = cx.tensor(MOE_HIDDEN);
|
||||
let router_proj = cx.tensor((MOE_NUM_EXPERTS, MOE_HIDDEN));
|
||||
let per_expert_scale = cx.tensor(MOE_NUM_EXPERTS);
|
||||
let gate_up_weights = cx
|
||||
.tensor((MOE_NUM_EXPERTS, MOE_INTERMEDIATE * 2, MOE_HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((MOE_NUM_EXPERTS, MOE_HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = router_input.dims().len();
|
||||
let e_dim = *router_proj.dims().first().unwrap();
|
||||
let k_expr = Expression::from(MOE_TOP_K);
|
||||
|
||||
let router_hidden = router_input.std_norm(n - 1, GEMMA_RMS_NORM_EPS)
|
||||
* router_scale.expand_lhs(&router_input.dims()[..n - 1])
|
||||
* (MOE_HIDDEN as f32).sqrt().recip();
|
||||
let routing_weights = router_hidden.matmul(router_proj.t()).softmax(n - 1);
|
||||
let top_k_indices = routing_weights.topk_indexes(MOE_TOP_K, n - 1);
|
||||
let row_offsets = router_input
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
let top_k_norm = top_k_values.sum(n - 1).expand_dim(n - 1, MOE_TOP_K);
|
||||
let top_k_weights = (top_k_values / top_k_norm) * per_expert_scale.gather(top_k_indices);
|
||||
|
||||
let gate_up_gathered =
|
||||
gather_experts(expert_input, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = expert_input.expand_dim(n - 1, MOE_TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gemma_gelu(gate) * up;
|
||||
|
||||
let down_gathered = gather_experts(expert_input, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
|
||||
fn gather_experts(
|
||||
graph_source: GraphTensor,
|
||||
top_k_indices: GraphTensor,
|
||||
weights: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let (_, d1, d2) = weights.dims3();
|
||||
let io = d1 * d2;
|
||||
let base = top_k_indices * io;
|
||||
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
|
||||
let n_base = base.dims().len();
|
||||
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
|
||||
let mut exp_within = within;
|
||||
for (axis, dim) in base.dims().iter().enumerate() {
|
||||
exp_within = exp_within.expand_dim(axis, *dim);
|
||||
}
|
||||
weights.gather(exp_base + exp_within)
|
||||
}
|
||||
|
||||
#[allow(clippy::excessive_precision)]
|
||||
fn gemma_gelu(x: GraphTensor) -> GraphTensor {
|
||||
let scaled = 1.5957691216 * x * (1. + 0.044715 * x * x);
|
||||
x * scaled.sigmoid()
|
||||
}
|
||||
|
||||
fn op_defs_string(ops: &[Arc<Box<dyn EgglogOp>>]) -> String {
|
||||
let mut ir_variants = Vec::new();
|
||||
let mut opkind_variants = Vec::new();
|
||||
for op in ops {
|
||||
let sort = op.sort();
|
||||
let variant = format!(
|
||||
"({} {})",
|
||||
sort.name,
|
||||
sort.fields.iter().map(|field| &field.sort).join(" ")
|
||||
);
|
||||
match sort.class.as_str() {
|
||||
"IR" => ir_variants.push(variant),
|
||||
"OpKind" => opkind_variants.push(variant),
|
||||
other => panic!("unknown sort class {other} for {}", sort.name),
|
||||
}
|
||||
}
|
||||
let extra_ir = ops.iter().flat_map(|op| op.ir_defs()).unique().join("\n");
|
||||
format!(
|
||||
"
|
||||
(datatype*
|
||||
(IR
|
||||
(OutputJoin IR IR)
|
||||
(Op OpKind IList)
|
||||
{extra_ir}
|
||||
{}
|
||||
)
|
||||
(OpKind
|
||||
{}
|
||||
)
|
||||
(IList
|
||||
(ICons IR IList)
|
||||
(INil)
|
||||
)
|
||||
)
|
||||
(function dtype (IR) DType :merge new)
|
||||
",
|
||||
ir_variants.join("\n"),
|
||||
opkind_variants.join("\n")
|
||||
)
|
||||
}
|
||||
|
||||
fn op_cleanups_string(ops: &[Arc<Box<dyn EgglogOp>>]) -> String {
|
||||
ops.iter()
|
||||
.filter(|op| op.cleanup())
|
||||
.map(|op| {
|
||||
let sort = op.sort();
|
||||
let fields = (0..sort.fields.len())
|
||||
.map(|i| (b'a' + i as u8) as char)
|
||||
.join(" ");
|
||||
if sort.class == "OpKind" {
|
||||
format!(
|
||||
"(rule
|
||||
((= ?m (Op ({} {fields}) ?__cleanup_inputs)))
|
||||
((delete (Op ({} {fields}) ?__cleanup_inputs)))
|
||||
:ruleset cleanup)",
|
||||
sort.name, sort.name
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"(rule
|
||||
((= ?m ({} {fields})))
|
||||
((delete ({} {fields})))
|
||||
:ruleset cleanup)",
|
||||
sort.name, sort.name
|
||||
)
|
||||
}
|
||||
})
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
fn setup_program(program: &str, ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool) -> String {
|
||||
let rewrites = ops
|
||||
.iter()
|
||||
.flat_map(|op| op.rewrites())
|
||||
.map(|rule| rule.to_egglog_string())
|
||||
.join("\n");
|
||||
[
|
||||
EGGLOG_RULESETS
|
||||
.iter()
|
||||
.map(|ruleset| format!("(ruleset {ruleset})"))
|
||||
.join("\n"),
|
||||
base_expression_egglog(),
|
||||
op_defs_string(ops),
|
||||
if cleanup {
|
||||
op_cleanups_string(ops)
|
||||
} else {
|
||||
String::new()
|
||||
},
|
||||
base_cleanup_egglog(),
|
||||
rewrites,
|
||||
program.to_string(),
|
||||
]
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
fn producer_schedule() -> String {
|
||||
"(seq
|
||||
(saturate expr)
|
||||
(saturate dtype_prop)
|
||||
(run matmul_flatten)
|
||||
(run kernel_lower)
|
||||
(run direct_kernel)
|
||||
(run kernel_specialize)
|
||||
(run buffer_reuse)
|
||||
(run matmul_backend)
|
||||
(run glumoe)
|
||||
(run fusion_pair)
|
||||
)"
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn fusion_schedule() -> String {
|
||||
"(seq
|
||||
(saturate expr)
|
||||
(saturate dtype_prop)
|
||||
(run fusion_grow)
|
||||
(run fusion_merge)
|
||||
)"
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn split_cycle() -> Vec<(&'static str, String)> {
|
||||
vec![
|
||||
("producers", format!("(saturate {})", producer_schedule())),
|
||||
("fusion", format!("(saturate {})", fusion_schedule())),
|
||||
]
|
||||
}
|
||||
|
||||
fn split_cycle_schedule() -> String {
|
||||
format!(
|
||||
"(seq
|
||||
(saturate {})
|
||||
(saturate {})
|
||||
)",
|
||||
producer_schedule(),
|
||||
fusion_schedule()
|
||||
)
|
||||
}
|
||||
|
||||
fn phase(egraph: &mut egglog::EGraph, name: &str, schedule: &str) -> bool {
|
||||
let before = egraph.num_tuples();
|
||||
let start = Instant::now();
|
||||
let command = format!("(run-schedule {schedule})");
|
||||
let outputs = egraph
|
||||
.parse_and_run_program(None, &command)
|
||||
.unwrap_or_else(|err| panic!("failed phase {name} schedule {schedule}: {err}"));
|
||||
let elapsed = start.elapsed();
|
||||
let after = egraph.num_tuples();
|
||||
let report = outputs
|
||||
.into_iter()
|
||||
.find_map(|output| match output {
|
||||
egglog::CommandOutput::RunSchedule(report) => Some(report),
|
||||
_ => None,
|
||||
})
|
||||
.expect("run-schedule did not return a report");
|
||||
let mut rules = report
|
||||
.search_and_apply_time_per_rule
|
||||
.iter()
|
||||
.map(|(rule, time)| {
|
||||
(
|
||||
rule.to_string(),
|
||||
*time,
|
||||
report
|
||||
.num_matches_per_rule
|
||||
.get(rule)
|
||||
.copied()
|
||||
.unwrap_or_default(),
|
||||
)
|
||||
})
|
||||
.collect_vec();
|
||||
rules.sort_by_key(|(_, time, matches)| (std::cmp::Reverse(*time), std::cmp::Reverse(*matches)));
|
||||
let matches = report.num_matches_per_rule.values().sum::<usize>();
|
||||
println!(
|
||||
"phase {name:<18} {elapsed_ms:>8.2} ms | tuples {before} -> {after} ({delta:+}) | updated={updated} | iters={iters} | matches={matches}",
|
||||
elapsed_ms = elapsed.as_secs_f64() * 1000.0,
|
||||
delta = after as isize - before as isize,
|
||||
updated = report.updated,
|
||||
iters = report.iterations.len(),
|
||||
);
|
||||
for (rule, time, matches) in rules
|
||||
.into_iter()
|
||||
.filter(|(_, time, matches)| !time.is_zero() || *matches > 0)
|
||||
.take(8)
|
||||
{
|
||||
println!(
|
||||
" rule {rule:<82} {ms:>8.2} ms | matches {matches}",
|
||||
ms = time.as_secs_f64() * 1000.0,
|
||||
);
|
||||
}
|
||||
report.updated
|
||||
}
|
||||
|
||||
fn serialize_summary(egraph: &mut egglog::EGraph, root: &str) {
|
||||
let (sort, value) = egraph.eval_expr(&egglog::var!(root.to_string())).unwrap();
|
||||
let output = egraph.serialize(egglog::SerializeConfig {
|
||||
root_eclasses: vec![(sort, value)],
|
||||
max_functions: None,
|
||||
include_temporary_functions: false,
|
||||
max_calls_per_function: None,
|
||||
});
|
||||
let mut classes = std::collections::BTreeSet::new();
|
||||
let mut top_ops = BTreeMap::<String, usize>::new();
|
||||
let mut nodes = 0usize;
|
||||
for node in output.egraph.nodes.values().filter(|node| !node.subsumed) {
|
||||
nodes += 1;
|
||||
classes.insert(node.eclass.clone());
|
||||
*top_ops.entry(node.op.clone()).or_default() += 1;
|
||||
}
|
||||
let top_ops = top_ops
|
||||
.into_iter()
|
||||
.sorted_by_key(|(_, count)| std::cmp::Reverse(*count))
|
||||
.take(12)
|
||||
.map(|(op, count)| format!("{op}={count}"))
|
||||
.join(", ");
|
||||
println!(
|
||||
"serialize nodes={nodes} classes={} roots={} top_ops={top_ops}",
|
||||
classes.len(),
|
||||
output.egraph.root_eclasses.len()
|
||||
);
|
||||
}
|
||||
|
||||
fn run(args: Args) {
|
||||
let mut graph = build_case(args.case);
|
||||
let rolled = if args.skip_roll {
|
||||
0
|
||||
} else {
|
||||
graph.auto_roll_loops_prepass()
|
||||
};
|
||||
let (program, root) = hlir_to_egglog(&graph);
|
||||
|
||||
let mut ops = match args.backend {
|
||||
Backend::Native => <NativeRuntime as Runtime>::Ops::into_vec(),
|
||||
Backend::Cuda => <CudaRuntime as Runtime>::Ops::into_vec(),
|
||||
};
|
||||
ops.extend(<HLIROps as IntoEgglogOp>::into_vec());
|
||||
let cleanup = args.cleanup && matches!(args.backend, Backend::Cuda);
|
||||
let setup = setup_program(&program, &ops, cleanup);
|
||||
|
||||
println!(
|
||||
"case={:?} backend={:?} mode={:?} passes={} cleanup={} rolled={} hlir_nodes={} setup_lines={} setup_bytes={} root={root}",
|
||||
args.case,
|
||||
args.backend,
|
||||
args.mode,
|
||||
args.passes,
|
||||
cleanup,
|
||||
rolled,
|
||||
graph.graph.node_count(),
|
||||
setup.lines().count(),
|
||||
setup.len(),
|
||||
);
|
||||
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let before = egraph.num_tuples();
|
||||
let start = Instant::now();
|
||||
let commands = egraph.parser.get_program_from_string(None, &setup).unwrap();
|
||||
egraph.run_program(commands).unwrap();
|
||||
println!(
|
||||
"setup {:>8.2} ms | tuples {before} -> {} ({:+})",
|
||||
start.elapsed().as_secs_f64() * 1000.0,
|
||||
egraph.num_tuples(),
|
||||
egraph.num_tuples() as isize - before as isize,
|
||||
);
|
||||
|
||||
match args.mode {
|
||||
Mode::Current | Mode::Steps => {
|
||||
for pass in 1..=args.passes {
|
||||
let mut updated = false;
|
||||
for (name, schedule) in split_cycle() {
|
||||
updated |= phase(&mut egraph, &format!("{pass:03} {name}"), &schedule);
|
||||
}
|
||||
if matches!(args.mode, Mode::Current) && !updated {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Mode::FullDefault => {
|
||||
phase(&mut egraph, "expr", "(saturate expr)");
|
||||
phase(&mut egraph, "dtype", "(saturate dtype_prop)");
|
||||
phase(&mut egraph, "default-full", "(saturate (run))");
|
||||
}
|
||||
Mode::FullCycle => {
|
||||
phase(
|
||||
&mut egraph,
|
||||
"cycle-full",
|
||||
&format!("(saturate {})", split_cycle_schedule()),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
phase(&mut egraph, "final expr", "(saturate expr)");
|
||||
if cleanup {
|
||||
phase(&mut egraph, "cleanup", "(saturate cleanup)");
|
||||
}
|
||||
phase(&mut egraph, "base cleanup", "(saturate base_cleanup)");
|
||||
serialize_summary(&mut egraph, &root);
|
||||
}
|
||||
|
||||
fn main() {
|
||||
run(parse_args());
|
||||
}
|
||||
@@ -1,198 +0,0 @@
|
||||
//! ComputeAttnMask — fused op that computes the paged attention mask from indptrs.
|
||||
//!
|
||||
//! This op exists so the indptr tensors (qo_indptr, kv_indptr) are visible in the
|
||||
//! same e-graph chunk as the attention pattern, letting the FlashInfer egglog rule
|
||||
//! capture them directly.
|
||||
//!
|
||||
//! Inputs (3): q_pos (s,) Int, qo_indptr (r,) Int, kv_indptr (r,) Int.
|
||||
//! Output: mask (s, c) F32 where mask[i, j] = 0.0 (attend) or -1e10 (block).
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{EXPRESSION, OP_KIND},
|
||||
extract_expr,
|
||||
},
|
||||
op::{EgglogOp, HLIROp, LLIROp},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
cudarc::driver::{CudaStream, result},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
};
|
||||
|
||||
/// Computes the paged attention mask from indptr arrays.
|
||||
///
|
||||
/// The mask encodes both request-membership and causality:
|
||||
/// `mask[i, j] = 0.0` if query `i` and context `j` belong to the same request AND
|
||||
/// context `j`'s local position is `<= q_pos[i]`; `-1e10` otherwise.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ComputeAttnMask {
|
||||
pub s_dim: Expression,
|
||||
pub c_dim: Expression,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ComputeAttnMask {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "ComputeAttnMask(s={}, c={})", self.s_dim, self.c_dim)
|
||||
}
|
||||
}
|
||||
|
||||
impl HLIROp for ComputeAttnMask {
|
||||
fn to_egglog(&self, inputs: &[(NodeIndex, String)]) -> String {
|
||||
format!(
|
||||
"(Op (ComputeAttnMask {} {}) (ICons {} (ICons {} (ICons {} (INil)))))",
|
||||
self.s_dim.to_egglog(),
|
||||
self.c_dim.to_egglog(),
|
||||
inputs[0].1, // q_pos
|
||||
inputs[1].1, // qo_indptr
|
||||
inputs[2].1, // kv_indptr
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for ComputeAttnMask {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"ComputeAttnMask",
|
||||
&[("s_dim", EXPRESSION), ("c_dim", EXPRESSION)],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// No rewrites — inserted directly by model code.
|
||||
vec![]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let s_dim = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
|
||||
let c_dim = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
|
||||
let op = Self { s_dim, c_dim };
|
||||
let llir_op = LLIROp::new::<dyn HostOp>(Box::new(op) as Box<dyn HostOp>);
|
||||
(llir_op, input_enodes)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for ComputeAttnMask {
|
||||
fn execute(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
if inputs.len() < 3 {
|
||||
anyhow::bail!(
|
||||
"ComputeAttnMask expects 3 inputs (q_pos, qo_indptr, kv_indptr), got {}",
|
||||
inputs.len()
|
||||
);
|
||||
}
|
||||
|
||||
let s = self
|
||||
.s_dim
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("ComputeAttnMask s_dim unresolved"))?;
|
||||
let c = self
|
||||
.c_dim
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("ComputeAttnMask c_dim unresolved"))?;
|
||||
let r = *dyn_map
|
||||
.get(&'r')
|
||||
.ok_or_else(|| anyhow::anyhow!("ComputeAttnMask requires dynamic dim 'r'"))?;
|
||||
|
||||
let get_buf = |name: &str, node: NodeIndex| -> anyhow::Result<DeviceBuffer> {
|
||||
buffers.get(&node).copied().ok_or_else(|| {
|
||||
anyhow::anyhow!("ComputeAttnMask missing {name} buffer for {node:?}")
|
||||
})
|
||||
};
|
||||
|
||||
let q_pos_buf = get_buf("q_pos", inputs[0])?;
|
||||
let qo_indptr_buf = get_buf("qo_indptr", inputs[1])?;
|
||||
let kv_indptr_buf = get_buf("kv_indptr", inputs[2])?;
|
||||
let out_buf = get_buf("output", self_node)?;
|
||||
|
||||
let q_pos = dtoh_i32(stream, q_pos_buf.ptr(), s)?;
|
||||
let qo_indptr = dtoh_i32(stream, qo_indptr_buf.ptr(), r)?;
|
||||
let kv_indptr = dtoh_i32(stream, kv_indptr_buf.ptr(), r)?;
|
||||
|
||||
let mut mask = vec![-1e10f32; s * c];
|
||||
for i in 0..s {
|
||||
let q_req = indptr_to_request(&qo_indptr, i as i32);
|
||||
for j in 0..c {
|
||||
let c_req = indptr_to_request(&kv_indptr, j as i32);
|
||||
if q_req == c_req && q_req >= 0 {
|
||||
let c_local = j as i32 - kv_indptr[c_req as usize];
|
||||
if c_local <= q_pos[i] {
|
||||
mask[i * c + j] = 0.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mask_bytes =
|
||||
unsafe { std::slice::from_raw_parts(mask.as_ptr() as *const u8, mask.len() * 4) };
|
||||
unsafe {
|
||||
let res = cudarc::driver::sys::cuMemcpyHtoD_v2(
|
||||
out_buf.ptr(),
|
||||
mask_bytes.as_ptr() as *const std::ffi::c_void,
|
||||
mask_bytes.len(),
|
||||
);
|
||||
if res != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
|
||||
anyhow::bail!("ComputeAttnMask cuMemcpyHtoD failed: {res:?}");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.s_dim * self.c_dim
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn stats_name(&self) -> Option<&'static str> {
|
||||
Some("ComputeAttnMask")
|
||||
}
|
||||
}
|
||||
|
||||
fn dtoh_i32(stream: &Arc<CudaStream>, dev_ptr: u64, len: usize) -> anyhow::Result<Vec<i32>> {
|
||||
let mut host = vec![0u8; len * std::mem::size_of::<i32>()];
|
||||
unsafe {
|
||||
result::memcpy_dtoh_async(&mut host, dev_ptr, stream.cu_stream())?;
|
||||
}
|
||||
stream.synchronize()?;
|
||||
let v = unsafe {
|
||||
let mut bytes = std::mem::ManuallyDrop::new(host);
|
||||
Vec::from_raw_parts(bytes.as_mut_ptr() as *mut i32, len, len)
|
||||
};
|
||||
Ok(v)
|
||||
}
|
||||
|
||||
/// Given an indptr array `[0, a, b, ...]`, find which segment `idx` belongs to.
|
||||
/// Returns `count(indptr[i] <= idx) - 1`.
|
||||
fn indptr_to_request(indptr: &[i32], idx: i32) -> i32 {
|
||||
indptr.iter().filter(|&&v| v <= idx).count() as i32 - 1
|
||||
}
|
||||
@@ -19,9 +19,9 @@ use crate::{
|
||||
CudaBlas,
|
||||
sys::{cublasOperation_t, cublasSetStream_v2, cublasSgemm_v2, cublasStatus_t},
|
||||
},
|
||||
driver::CudaStream,
|
||||
driver::{CudaSlice, CudaStream, DevicePtr},
|
||||
},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
host::HostOp,
|
||||
};
|
||||
|
||||
/// Global shared cuBLAS handle to avoid per-operation workspace allocation
|
||||
@@ -156,7 +156,7 @@ impl HostOp for CuBlasSgemmV2 {
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
// GEMM parameters
|
||||
@@ -178,9 +178,9 @@ impl HostOp for CuBlasSgemmV2 {
|
||||
let b_buf = buffers[&inputs[1]];
|
||||
|
||||
// Get device pointers
|
||||
let a_ptr = a_buf.ptr();
|
||||
let b_ptr = b_buf.ptr();
|
||||
let c_ptr = c_buf.ptr();
|
||||
let (a_ptr, _a_guard) = a_buf.device_ptr(stream);
|
||||
let (b_ptr, _b_guard) = b_buf.device_ptr(stream);
|
||||
let (c_ptr, _c_guard) = c_buf.device_ptr(stream);
|
||||
|
||||
// Debug: Check buffer sizes
|
||||
trace!(
|
||||
|
||||
@@ -68,6 +68,5 @@
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm column-major × column-major"
|
||||
)
|
||||
|
||||
@@ -68,6 +68,5 @@
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm column-major × row-major"
|
||||
)
|
||||
|
||||
@@ -68,6 +68,5 @@
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm row-major × column-major"
|
||||
)
|
||||
)
|
||||
@@ -68,6 +68,5 @@
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm row-major"
|
||||
)
|
||||
)
|
||||
@@ -42,7 +42,6 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; For column-major A × column-major B with cuBLAS:
|
||||
@@ -53,22 +52,18 @@
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
"COL" "COL" "COL" "COL" ; A/B/C/D matrix orders
|
||||
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
|
||||
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?n ; ldd = ldc for current row-major output rewrites
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
(MNum 0) ; stride_d = 0
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT") ; type tuple, alpha, beta
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt column-major × column-major"
|
||||
)
|
||||
|
||||
@@ -116,28 +111,23 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; cuBLAS: cublas(OP_T, OP_T, n, m, k, B, lda=b_n_stride, A, ldb=a_k_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "T"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride ; lda (cuBLAS A = our B, column stride)
|
||||
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
|
||||
?n ; ldc
|
||||
?n ; ldd
|
||||
?batch
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
(MMul ?m ?n) ; stride_d
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched column-major × column-major"
|
||||
)
|
||||
|
||||
@@ -42,7 +42,6 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; For column-major A × row-major B with cuBLAS:
|
||||
@@ -53,22 +52,18 @@
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
"COL" "COL" "COL" "COL" ; A/B/C/D matrix orders
|
||||
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
|
||||
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?n ; ldd = ldc for current row-major output rewrites
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
(MNum 0) ; stride_d = 0
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT") ; type tuple, alpha, beta
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt column-major × row-major"
|
||||
)
|
||||
|
||||
@@ -116,28 +111,23 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; cuBLAS: cublas(OP_N, OP_T, n, m, k, B, lda=b_k_stride, A, ldb=a_k_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"N" "T"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_k_stride ; lda (cuBLAS A = our B, row stride)
|
||||
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
|
||||
?n ; ldc
|
||||
?n ; ldd
|
||||
?batch
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
(MMul ?m ?n) ; stride_d
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched column-major × row-major"
|
||||
)
|
||||
|
||||
@@ -42,7 +42,6 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; For row-major A × column-major B with cuBLAS:
|
||||
@@ -53,22 +52,18 @@
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major, need B^T)
|
||||
"N" ; transb = No transpose
|
||||
"COL" "COL" "COL" "COL" ; A/B/C/D matrix orders
|
||||
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
|
||||
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?n ; ldd = ldc for current row-major output rewrites
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
(MNum 0) ; stride_d = 0
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT") ; type tuple, alpha, beta
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-major × column-major"
|
||||
)
|
||||
|
||||
@@ -116,28 +111,23 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; cuBLAS: cublas(OP_T, OP_N, n, m, k, B, lda=b_n_stride, A, ldb=a_m_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride ; lda (cuBLAS A = our B, column stride)
|
||||
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
|
||||
?n ; ldc
|
||||
?n ; ldd
|
||||
?batch
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
(MMul ?m ?n) ; stride_d
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched row-major × column-major"
|
||||
)
|
||||
|
||||
@@ -42,7 +42,6 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; For row-major C = A × B with cuBLAS (column-major):
|
||||
@@ -53,22 +52,18 @@
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose
|
||||
"N" ; transb = No transpose
|
||||
"COL" "COL" "COL" "COL" ; A/B/C/D matrix orders
|
||||
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
|
||||
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?n ; ldd = ldc for current row-major output rewrites
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
(MNum 0) ; stride_d = 0
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT") ; type tuple, alpha, beta
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-major x row-major"
|
||||
)
|
||||
|
||||
@@ -121,7 +116,6 @@
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
; cuBLAS swap: C^T[n,m] = B^T[n,k] × A^T[k,m] per batch
|
||||
@@ -129,21 +123,17 @@
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"N" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_k_stride ; lda (cuBLAS A = our B, row stride)
|
||||
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
|
||||
?n ; ldc (contiguous output per batch)
|
||||
?n ; ldd
|
||||
?batch ; batch_count
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
(MMul ?m ?n) ; stride_d
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched row-major × row-major"
|
||||
)
|
||||
|
||||
@@ -1,428 +0,0 @@
|
||||
; Fuse a row-major Add on top of an existing cuBLASLt matmul into
|
||||
; D = alpha * A * B + beta * C.
|
||||
;
|
||||
; The existing matmul rewrites view Luminal's row-major output [m,n] as a
|
||||
; column-major cuBLASLt matrix [n,m]. A row-major C input with logical strides
|
||||
; [row_stride, 1] therefore maps to ldc=row_stride. This lets a C slice from a
|
||||
; wider parent tensor use a larger ldc while D keeps the matmul output layout.
|
||||
; cuBLASLt requires out-of-place C and D to have the same matrix order, so these
|
||||
; beta rules only fuse C layouts that map to the current COL-ordered D layout.
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "COL"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?n (ECons ?m (ENil)))
|
||||
?matmul_add_strides
|
||||
?c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_add_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "COL" "COL"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d matmul plus c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "COL"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?n (ECons ?m (ENil)))
|
||||
?c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_add_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "COL" "COL"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d c plus matmul beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "COL"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?n (ECons ?m (ENil))))
|
||||
?matmul_add_strides
|
||||
?c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_add_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "COL" "COL"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched matmul plus c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "COL"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?n (ECons ?m (ENil))))
|
||||
?c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_add_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "COL" "COL"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched c plus matmul beta"
|
||||
)
|
||||
|
||||
; ROW-ordered D beta fusions. These pair with cublaslt_row_order_rewrite.egg,
|
||||
; where the cuBLASLt problem dimensions match Luminal's logical output [m,n].
|
||||
; A row-major C input with logical strides [row_stride, 1] maps directly to a
|
||||
; ROW-ordered cuBLASLt C[m,n] descriptor with ldc=row_stride.
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?matmul_add_strides
|
||||
?c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_add_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order 2d matmul plus c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_add_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order 2d c plus matmul beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?matmul_add_strides
|
||||
?c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_add_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched matmul plus c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(!= ?epilogue "RELU")
|
||||
(!= ?epilogue "RELU_BIAS")
|
||||
(!= ?epilogue "GELU")
|
||||
(!= ?epilogue "GELU_BIAS")
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_add_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 1.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched c plus matmul beta"
|
||||
)
|
||||
@@ -1,614 +0,0 @@
|
||||
; cuBLASLt epilogue rewrites.
|
||||
;
|
||||
; ReLU in the frontend lowers through maximum_f32(0.0):
|
||||
;
|
||||
; (matmul < 0) * 0 + cast(cast((-cast(matmul < 0) + 1) as bool) as f32) * matmul
|
||||
;
|
||||
; These rules fuse that expression back into CUBLASLT_EPILOGUE_RELU.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?zero (Op (Constant 0.0) (INil)))
|
||||
(= ?neg_one (Op (Constant -1.0) (INil)))
|
||||
(= ?one (Op (Constant 1.0) (INil)))
|
||||
|
||||
(= ?lt (Op (LessThan
|
||||
?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?mask_strides)
|
||||
(ICons ?matmul (ICons ?zero (INil)))))
|
||||
(= ?lt_f32 (Op (Cast ?size (F32)) (ICons ?lt (INil))))
|
||||
|
||||
(= ?zeroed (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?zeroed_strides)
|
||||
(ICons ?lt_f32 (ICons ?zero (INil)))))
|
||||
|
||||
(= ?neg_mask (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?neg_mask_strides)
|
||||
(ICons ?lt_f32 (ICons ?neg_one (INil)))))
|
||||
(= ?not_mask_f32 (Op (Add
|
||||
?shape
|
||||
?neg_mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?not_mask_f32_strides)
|
||||
(ICons ?neg_mask (ICons ?one (INil)))))
|
||||
(= ?not_mask_bool (Op (Cast ?size (Bool)) (ICons ?not_mask_f32 (INil))))
|
||||
(= ?not_mask (Op (Cast ?size (F32)) (ICons ?not_mask_bool (INil))))
|
||||
|
||||
(= ?positive (Op (Mul
|
||||
?shape
|
||||
?not_mask_f32_strides
|
||||
?matmul_strides
|
||||
?positive_strides)
|
||||
(ICons ?not_mask (ICons ?matmul (INil)))))
|
||||
(= ?relu (Op (Add
|
||||
?shape
|
||||
?zeroed_strides
|
||||
?positive_strides
|
||||
?relu_strides)
|
||||
(ICons ?zeroed (ICons ?positive (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "RELU")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?relu ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d relu epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?zero (Op (Constant 0.0) (INil)))
|
||||
(= ?neg_one (Op (Constant -1.0) (INil)))
|
||||
(= ?one (Op (Constant 1.0) (INil)))
|
||||
|
||||
(= ?lt (Op (LessThan
|
||||
?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?mask_strides)
|
||||
(ICons ?matmul (ICons ?zero (INil)))))
|
||||
(= ?lt_f32 (Op (Cast ?size (F32)) (ICons ?lt (INil))))
|
||||
|
||||
(= ?zeroed (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?zeroed_strides)
|
||||
(ICons ?lt_f32 (ICons ?zero (INil)))))
|
||||
|
||||
(= ?neg_mask (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?neg_mask_strides)
|
||||
(ICons ?lt_f32 (ICons ?neg_one (INil)))))
|
||||
(= ?not_mask_f32 (Op (Add
|
||||
?shape
|
||||
?neg_mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?not_mask_f32_strides)
|
||||
(ICons ?neg_mask (ICons ?one (INil)))))
|
||||
(= ?not_mask_bool (Op (Cast ?size (Bool)) (ICons ?not_mask_f32 (INil))))
|
||||
(= ?not_mask (Op (Cast ?size (F32)) (ICons ?not_mask_bool (INil))))
|
||||
|
||||
(= ?positive (Op (Mul
|
||||
?shape
|
||||
?not_mask_f32_strides
|
||||
?matmul_strides
|
||||
?positive_strides)
|
||||
(ICons ?not_mask (ICons ?matmul (INil)))))
|
||||
(= ?relu (Op (Add
|
||||
?shape
|
||||
?zeroed_strides
|
||||
?positive_strides
|
||||
?relu_strides)
|
||||
(ICons ?zeroed (ICons ?positive (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "RELU")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?relu ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched relu epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?zero (Op (Constant 0.0) (INil)))
|
||||
(= ?neg_one (Op (Constant -1.0) (INil)))
|
||||
(= ?one (Op (Constant 1.0) (INil)))
|
||||
|
||||
(= ?lt (Op (LessThan
|
||||
?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?mask_strides)
|
||||
(ICons ?matmul (ICons ?zero (INil)))))
|
||||
(= ?lt_f32 (Op (Cast ?size (F32)) (ICons ?lt (INil))))
|
||||
|
||||
(= ?zeroed (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?zeroed_strides)
|
||||
(ICons ?lt_f32 (ICons ?zero (INil)))))
|
||||
|
||||
(= ?neg_mask (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?neg_mask_strides)
|
||||
(ICons ?lt_f32 (ICons ?neg_one (INil)))))
|
||||
(= ?not_mask_f32 (Op (Add
|
||||
?shape
|
||||
?neg_mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?not_mask_f32_strides)
|
||||
(ICons ?neg_mask (ICons ?one (INil)))))
|
||||
(= ?not_mask_bool (Op (Cast ?size (Bool)) (ICons ?not_mask_f32 (INil))))
|
||||
(= ?not_mask (Op (Cast ?size (F32)) (ICons ?not_mask_bool (INil))))
|
||||
|
||||
(= ?positive (Op (Mul
|
||||
?shape
|
||||
?not_mask_f32_strides
|
||||
?matmul_strides
|
||||
?positive_strides)
|
||||
(ICons ?not_mask (ICons ?matmul (INil)))))
|
||||
(= ?relu (Op (Add
|
||||
?shape
|
||||
?zeroed_strides
|
||||
?positive_strides
|
||||
?relu_strides)
|
||||
(ICons ?zeroed (ICons ?positive (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "RELU_BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?relu ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d relu bias epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?zero (Op (Constant 0.0) (INil)))
|
||||
(= ?neg_one (Op (Constant -1.0) (INil)))
|
||||
(= ?one (Op (Constant 1.0) (INil)))
|
||||
|
||||
(= ?lt (Op (LessThan
|
||||
?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?mask_strides)
|
||||
(ICons ?matmul (ICons ?zero (INil)))))
|
||||
(= ?lt_f32 (Op (Cast ?size (F32)) (ICons ?lt (INil))))
|
||||
|
||||
(= ?zeroed (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?zeroed_strides)
|
||||
(ICons ?lt_f32 (ICons ?zero (INil)))))
|
||||
|
||||
(= ?neg_mask (Op (Mul
|
||||
?shape
|
||||
?mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?neg_mask_strides)
|
||||
(ICons ?lt_f32 (ICons ?neg_one (INil)))))
|
||||
(= ?not_mask_f32 (Op (Add
|
||||
?shape
|
||||
?neg_mask_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?not_mask_f32_strides)
|
||||
(ICons ?neg_mask (ICons ?one (INil)))))
|
||||
(= ?not_mask_bool (Op (Cast ?size (Bool)) (ICons ?not_mask_f32 (INil))))
|
||||
(= ?not_mask (Op (Cast ?size (F32)) (ICons ?not_mask_bool (INil))))
|
||||
|
||||
(= ?positive (Op (Mul
|
||||
?shape
|
||||
?not_mask_f32_strides
|
||||
?matmul_strides
|
||||
?positive_strides)
|
||||
(ICons ?not_mask (ICons ?matmul (INil)))))
|
||||
(= ?relu (Op (Add
|
||||
?shape
|
||||
?zeroed_strides
|
||||
?positive_strides
|
||||
?relu_strides)
|
||||
(ICons ?zeroed (ICons ?positive (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "RELU_BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?relu ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched relu bias epilogue"
|
||||
)
|
||||
|
||||
; Canonical tanh-approx GELU can also appear directly as:
|
||||
;
|
||||
; x * sigmoid(1.5957691216 * x * (1 + 0.044715 * x * x))
|
||||
;
|
||||
; Match that sigmoid form and fuse it into the cuBLASLt GELU epilogues.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?gelu_coeff_inner (Op (Constant 0.044715) (INil)))
|
||||
(= ?gelu_inner_scaled (Op (Mul ?gelu_inner_scaled_shape ?gelu_inner_scaled_a_stride ?gelu_inner_scaled_b_stride ?gelu_inner_scaled_out_stride) (ICons ?matmul (ICons ?gelu_coeff_inner (INil)))))
|
||||
(= ?gelu_inner_quad (Op (Mul ?gelu_inner_quad_shape ?gelu_inner_quad_a_stride ?gelu_inner_quad_b_stride ?gelu_inner_quad_out_stride) (ICons ?gelu_inner_scaled (ICons ?matmul (INil)))))
|
||||
(= ?gelu_one (Op (Constant 1.000000) (INil)))
|
||||
(= ?gelu_poly (Op (Add ?gelu_poly_shape ?gelu_poly_a_stride ?gelu_poly_b_stride ?gelu_poly_out_stride) (ICons ?gelu_inner_quad (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_coeff_outer (Op (Constant 1.595769) (INil)))
|
||||
(= ?gelu_outer_scaled (Op (Mul ?gelu_outer_scaled_shape ?gelu_outer_scaled_a_stride ?gelu_outer_scaled_b_stride ?gelu_outer_scaled_out_stride) (ICons ?matmul (ICons ?gelu_coeff_outer (INil)))))
|
||||
(= ?gelu_scaled (Op (Mul ?gelu_scaled_shape ?gelu_scaled_a_stride ?gelu_scaled_b_stride ?gelu_scaled_out_stride) (ICons ?gelu_outer_scaled (ICons ?gelu_poly (INil)))))
|
||||
(= ?neg1 (Op (Constant -1.000000) (INil)))
|
||||
(= ?gelu_neg (Op (Mul ?gelu_neg_shape ?gelu_neg_a_stride ?gelu_neg_b_stride ?gelu_neg_out_stride) (ICons ?gelu_scaled (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant 1.442695) (INil)))
|
||||
(= ?gelu_exp_scaled (Op (Mul ?gelu_exp_scaled_shape ?gelu_exp_scaled_a_stride ?gelu_exp_scaled_b_stride ?gelu_exp_scaled_out_stride) (ICons ?gelu_neg (ICons ?log2e (INil)))))
|
||||
(= ?gelu_exp2_val (Op (Exp2 ?gelu_exp_shape ?gelu_exp_in_stride ?gelu_exp_out_stride) (ICons ?gelu_exp_scaled (INil))))
|
||||
(= ?gelu_plus1 (Op (Add ?gelu_plus1_shape ?gelu_plus1_a_stride ?gelu_plus1_b_stride ?gelu_plus1_out_stride) (ICons ?gelu_exp2_val (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_sigmoid (Op (Recip ?gelu_sigmoid_shape ?gelu_sigmoid_in_stride ?gelu_sigmoid_out_stride) (ICons ?gelu_plus1 (INil))))
|
||||
(= ?gelu_out (Op (Mul ?gelu_out_shape ?gelu_out_a_stride ?gelu_out_b_stride ?gelu_out_out_stride) (ICons ?matmul (ICons ?gelu_sigmoid (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "GELU")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?gelu_out ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt gelu epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?gelu_coeff_inner (Op (Constant 0.044715) (INil)))
|
||||
(= ?gelu_inner_scaled (Op (Mul ?gelu_inner_scaled_shape ?gelu_inner_scaled_a_stride ?gelu_inner_scaled_b_stride ?gelu_inner_scaled_out_stride) (ICons ?matmul (ICons ?gelu_coeff_inner (INil)))))
|
||||
(= ?gelu_inner_quad (Op (Mul ?gelu_inner_quad_shape ?gelu_inner_quad_a_stride ?gelu_inner_quad_b_stride ?gelu_inner_quad_out_stride) (ICons ?gelu_inner_scaled (ICons ?matmul (INil)))))
|
||||
(= ?gelu_one (Op (Constant 1.000000) (INil)))
|
||||
(= ?gelu_poly (Op (Add ?gelu_poly_shape ?gelu_poly_a_stride ?gelu_poly_b_stride ?gelu_poly_out_stride) (ICons ?gelu_inner_quad (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_coeff_outer (Op (Constant 1.595769) (INil)))
|
||||
(= ?gelu_outer_scaled (Op (Mul ?gelu_outer_scaled_shape ?gelu_outer_scaled_a_stride ?gelu_outer_scaled_b_stride ?gelu_outer_scaled_out_stride) (ICons ?matmul (ICons ?gelu_coeff_outer (INil)))))
|
||||
(= ?gelu_scaled (Op (Mul ?gelu_scaled_shape ?gelu_scaled_a_stride ?gelu_scaled_b_stride ?gelu_scaled_out_stride) (ICons ?gelu_outer_scaled (ICons ?gelu_poly (INil)))))
|
||||
(= ?neg1 (Op (Constant -1.000000) (INil)))
|
||||
(= ?gelu_neg (Op (Mul ?gelu_neg_shape ?gelu_neg_a_stride ?gelu_neg_b_stride ?gelu_neg_out_stride) (ICons ?gelu_scaled (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant 1.442695) (INil)))
|
||||
(= ?gelu_exp_scaled (Op (Mul ?gelu_exp_scaled_shape ?gelu_exp_scaled_a_stride ?gelu_exp_scaled_b_stride ?gelu_exp_scaled_out_stride) (ICons ?gelu_neg (ICons ?log2e (INil)))))
|
||||
(= ?gelu_exp2_val (Op (Exp2 ?gelu_exp_shape ?gelu_exp_in_stride ?gelu_exp_out_stride) (ICons ?gelu_exp_scaled (INil))))
|
||||
(= ?gelu_plus1 (Op (Add ?gelu_plus1_shape ?gelu_plus1_a_stride ?gelu_plus1_b_stride ?gelu_plus1_out_stride) (ICons ?gelu_exp2_val (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_sigmoid (Op (Recip ?gelu_sigmoid_shape ?gelu_sigmoid_in_stride ?gelu_sigmoid_out_stride) (ICons ?gelu_plus1 (INil))))
|
||||
(= ?gelu_out (Op (Mul ?gelu_out_shape ?gelu_out_a_stride ?gelu_out_b_stride ?gelu_out_out_stride) (ICons ?matmul (ICons ?gelu_sigmoid (INil)))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "GELU_BIAS")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?gelu_out ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt gelu bias epilogue"
|
||||
)
|
||||
|
||||
; This first slice fuses column-bias adds into CUBLASLT_EPILOGUE_BIAS for the
|
||||
; older COL-ordered output view. In that view Luminal's logical [m,n] output is
|
||||
; represented as a cuBLASLt [n,m] matrix, so cuBLASLt's row-broadcast bias maps
|
||||
; to the common logical column bias of length n.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?n (ECons ?m (ENil)))
|
||||
?matmul_add_strides
|
||||
?bias_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?bias (INil)))))
|
||||
|
||||
(= ?bias_add_strides (ECons (MNum 0) (ECons (MIter) (ENil))))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?d_dtype (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b (ICons ?bias (INil))))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d matmul plus column bias epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?n (ECons ?m (ENil)))
|
||||
?bias_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?bias (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?bias_add_strides (ECons (MNum 0) (ECons (MIter) (ENil))))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?d_dtype (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b (ICons ?bias (INil))))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d column bias plus matmul epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?n (ECons ?m (ENil))))
|
||||
?matmul_add_strides
|
||||
?bias_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?bias (INil)))))
|
||||
|
||||
(= ?bias_add_strides (ECons (MNum 0) (ECons (MNum 0) (ECons (MIter) (ENil)))))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?d_dtype (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b (ICons ?bias (INil))))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched matmul plus column bias epilogue"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?n (ECons ?m (ENil))))
|
||||
?bias_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?bias (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?bias_add_strides (ECons (MNum 0) (ECons (MNum 0) (ECons (MIter) (ENil)))))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?d_dtype (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order "COL"
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "BIAS")
|
||||
(ICons ?a (ICons ?b (ICons ?bias (INil))))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched column bias plus matmul epilogue"
|
||||
)
|
||||
@@ -1,345 +0,0 @@
|
||||
; FP8 support is narrower than "any FP8 x any FP8". cuBLASLt's regular FP8
|
||||
; matmul table supports these A/B descriptor pairs for F32 outputs:
|
||||
; E4M3 x E4M3
|
||||
; E4M3 x E5M2
|
||||
; E5M2 x E4M3
|
||||
; and requires TN format on Ada/Hopper-class GPUs. These rules therefore match
|
||||
; row-major x column-major Luminal matmuls, which the existing COL-order lowering
|
||||
; describes as descriptor A = logical B, descriptor B = logical A, transa=T,
|
||||
; transb=N.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F8E4M3) (dtype ?a))
|
||||
(= (F8E4M3) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F8E4M3) (F8E4M3) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e4m3/e4m3 row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F8E4M3) (dtype ?a))
|
||||
(= (F8E5M2) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F8E5M2) (F8E4M3) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e5m2/e4m3 row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F8E5M2) (dtype ?a))
|
||||
(= (F8E4M3) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F8E4M3) (F8E5M2) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e4m3/e5m2 row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= (F8E4M3) (dtype ?a))
|
||||
(= (F8E4M3) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?b_batch_stride
|
||||
?a_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
(F8E4M3) (F8E4M3) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e4m3/e4m3 batched row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= (F8E4M3) (dtype ?a))
|
||||
(= (F8E5M2) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?b_batch_stride
|
||||
?a_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
(F8E5M2) (F8E4M3) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e5m2/e4m3 batched row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= (F8E5M2) (dtype ?a))
|
||||
(= (F8E4M3) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?b_batch_stride
|
||||
?a_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
(F8E4M3) (F8E5M2) (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?cast ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt fp8 e4m3/e5m2 batched row-major x column-major f32 output"
|
||||
)
|
||||
@@ -1,75 +0,0 @@
|
||||
; Mixed output dtype rewrites for cuBLASLt.
|
||||
;
|
||||
; The first mixed mode we need for low-precision matmuls is:
|
||||
;
|
||||
; D[f32] = A[fp16/bf16] * B[fp16/bf16]
|
||||
;
|
||||
; Luminal graphs express this today as a Cast(F32) around a low-precision
|
||||
; matmul. cuBLASLt can write the f32 output directly, so expose that candidate
|
||||
; before beta fusion tries to consume an f32 C input.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
(F16) (F16) (F16) (F16)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
?inputs))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?matmul (INil))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout ?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
(F16) (F16) (F32) (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
?inputs))
|
||||
(union ?cast ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt f16 matmul cast f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
(Bf16) (Bf16) (Bf16) (Bf16)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
?inputs))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?matmul (INil))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout ?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
(Bf16) (Bf16) (F32) (F32)
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
?inputs))
|
||||
(union ?cast ?fused)
|
||||
(set (dtype ?fused) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt bf16 matmul cast f32 output"
|
||||
)
|
||||
@@ -1,452 +0,0 @@
|
||||
; Natural cuBLASLt row-order output rewrites. These keep Luminal's logical
|
||||
; output C[m,n] as a cuBLASLt ROW-ordered D[m,n] instead of using the older
|
||||
; swapped COL-ordered D[n,m] view. A and B orders mirror their matched logical
|
||||
; layouts, so this family is the legal base for future ROW-ordered beta fusions.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"ROW" "ROW" "ROW" "ROW"
|
||||
?a_m_stride
|
||||
?b_k_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order row-major x row-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"ROW" "COL" "ROW" "ROW"
|
||||
?a_m_stride
|
||||
?b_n_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order row-major x column-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"COL" "ROW" "ROW" "ROW"
|
||||
?a_k_stride
|
||||
?b_k_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order column-major x row-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"COL" "COL" "ROW" "ROW"
|
||||
?a_k_stride
|
||||
?b_n_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order column-major x column-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?k ?b_k_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"ROW" "ROW" "ROW" "ROW"
|
||||
?a_m_stride
|
||||
?b_k_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?a_batch_stride
|
||||
?b_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched row-major x row-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"ROW" "COL" "ROW" "ROW"
|
||||
?a_m_stride
|
||||
?b_n_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?a_batch_stride
|
||||
?b_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched row-major x column-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?a_batch_stride (MMul ?k ?a_k_stride))
|
||||
(= ?b_batch_stride (MMul ?k ?b_k_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"COL" "ROW" "ROW" "ROW"
|
||||
?a_k_stride
|
||||
?b_k_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?a_batch_stride
|
||||
?b_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched column-major x row-major"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?k ?a_k_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
(cublaslt_base_dtype ?dt)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?m ?n ?k
|
||||
"N" "N"
|
||||
"COL" "COL" "ROW" "ROW"
|
||||
?a_k_stride
|
||||
?b_n_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?a_batch_stride
|
||||
?b_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
?dt ?dt ?dt ?dt "default" "default" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched column-major x column-major"
|
||||
)
|
||||
@@ -1,316 +0,0 @@
|
||||
; Scalar alpha/beta rewrites for cuBLASLt. These rules target scalar constants
|
||||
; expanded across the matmul/add shape, i.e. zero strides on every logical axis.
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?scale (Op (Constant ?alpha) (INil)))
|
||||
; alpha=1.0 hash-conses ?fused == ?matmul; the union merges Mul into ?matmul's eclass and saturate diverges.
|
||||
(!= ?alpha 1.0)
|
||||
(= ?scaled (Op (Mul ?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_out_strides)
|
||||
(ICons ?matmul (ICons ?scale (INil)))))
|
||||
(= ?matmul_strides ?scaled_out_strides)
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?scaled ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt 2d alpha scale"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
1.0 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?scale (Op (Constant ?alpha) (INil)))
|
||||
; See 2d alpha scale: alpha=1.0 makes (saturate ...) diverge.
|
||||
(!= ?alpha 1.0)
|
||||
(= ?scaled (Op (Mul ?shape
|
||||
?matmul_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?scaled_out_strides)
|
||||
(ICons ?matmul (ICons ?scale (INil)))))
|
||||
(= ?matmul_strides ?scaled_out_strides)
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?c_order ?d_order
|
||||
?lda ?ldb ?ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 "DEFAULT")
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
(union ?scaled ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt batched alpha scale"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?beta_node (Op (Constant ?beta) (INil)))
|
||||
(= ?scaled_c (Op (Mul
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?c_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_c_out_strides)
|
||||
(ICons ?c (ICons ?beta_node (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?matmul_add_strides
|
||||
?scaled_c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?scaled_c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?scaled_c_add_strides ?scaled_c_out_strides)
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order 2d scaled c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?beta_node (Op (Constant ?beta) (INil)))
|
||||
(= ?scaled_c (Op (Mul
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?c_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_c_out_strides)
|
||||
(ICons ?c (ICons ?beta_node (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?m (ECons ?n (ENil)))
|
||||
?scaled_c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?scaled_c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?c_strides (ECons ?c_row_stride (ECons ?c_col_stride (ENil))))
|
||||
(= ?add_out_strides (ECons ?d_row_stride (ECons ?d_col_stride (ENil))))
|
||||
(= ?scaled_c_add_strides ?scaled_c_out_strides)
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
(MNum 1)
|
||||
?stride_a ?stride_b (MNum 0) ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order 2d scaled c plus matmul beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?beta_node (Op (Constant ?beta) (INil)))
|
||||
(= ?scaled_c (Op (Mul
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?c_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?scaled_c_out_strides)
|
||||
(ICons ?c (ICons ?beta_node (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?matmul_add_strides
|
||||
?scaled_c_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?matmul (ICons ?scaled_c (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?scaled_c_add_strides ?scaled_c_out_strides)
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched scaled c beta"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?matmul (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order ?matmul_c_order "ROW"
|
||||
?lda ?ldb ?matmul_ldc ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?matmul_stride_c ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha 0.0 ?epilogue)
|
||||
(ICons ?a (ICons ?b ?matmul_tail))))
|
||||
|
||||
(= ?beta_node (Op (Constant ?beta) (INil)))
|
||||
(= ?scaled_c (Op (Mul
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?c_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?scaled_c_out_strides)
|
||||
(ICons ?c (ICons ?beta_node (INil)))))
|
||||
|
||||
(= ?add (Op (Add
|
||||
(ECons ?batch (ECons ?m (ECons ?n (ENil))))
|
||||
?scaled_c_add_strides
|
||||
?matmul_add_strides
|
||||
?add_out_strides)
|
||||
(ICons ?scaled_c (ICons ?matmul (INil)))))
|
||||
|
||||
(= ?matmul_add_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?c_strides (ECons ?c_batch_stride (ECons ?c_row_stride (ECons ?c_col_stride (ENil)))))
|
||||
(= ?add_out_strides (ECons ?d_batch_stride (ECons ?d_row_stride (ECons ?d_col_stride (ENil)))))
|
||||
(= ?scaled_c_add_strides ?scaled_c_out_strides)
|
||||
(= ?c_col_stride (MIter))
|
||||
(!= ?c_row_stride (MNum 0))
|
||||
(= ?matmul_add_strides ?add_out_strides)
|
||||
(= ?c_dtype (dtype ?c))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (cublaslt
|
||||
?m ?n ?k
|
||||
?a_layout ?b_layout
|
||||
?a_order ?b_order "ROW" "ROW"
|
||||
?lda ?ldb ?c_row_stride ?ldd
|
||||
?batch
|
||||
?stride_a ?stride_b ?c_batch_stride ?stride_d
|
||||
?a_dtype ?b_dtype ?c_dtype ?d_dtype
|
||||
?compute_type ?scale_dtype
|
||||
?alpha ?beta ?epilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?c ?matmul_tail)))))
|
||||
(union ?add ?fused)
|
||||
(set (dtype ?fused) ?d_dtype)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt row-order batched scaled c plus matmul beta"
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,124 +0,0 @@
|
||||
# FlashInfer Integration
|
||||
|
||||
FlashInfer replaces the multi-op attention pattern (Q×K^T → scale → mask → softmax → ×V) with a single fused GPU kernel via [FlashInfer](https://github.com/flashinfer-ai/flashinfer)'s batch decode and batch prefill APIs.
|
||||
|
||||
## Current State
|
||||
|
||||
**Working:**
|
||||
- Egglog rewrite rule matches any GQA paged attention pattern (model-agnostic shapes)
|
||||
- GA search selects FlashInfer when it wins profiling — verified on Llama 3 8B (32 layers) and Qwen 3 4B (36 layers)
|
||||
- **BatchDecode** (s=1): fp32 natively — FlashInfer's decode kernel uses scalar vectorized dot products, no tensor cores
|
||||
- **BatchPrefill**: template-instantiated for fp16 but **not callable from fp32** — FlashInfer's prefill kernel requires tensor core MMA (`mma.sync.aligned.m16n8k16`) and `ldmatrix` which physically only operate on 16-bit types; the C API stubs return -1 for fp32; will be enabled when native fp16/bf16 pipeline is added
|
||||
- Decode handles all cases in the current fp32 pipeline (prefill uses cuBLAS attention via dim bucketing)
|
||||
- Indptr-based mask: `qo_indptr` and `kv_indptr` are computed in-graph so the egglog rule can see them in the same chunk as the attention ops
|
||||
|
||||
**Not yet implemented:**
|
||||
- Native fp16 / bf16 pipeline (would eliminate the cast overhead in prefill)
|
||||
- Page sizes > 1
|
||||
|
||||
---
|
||||
|
||||
## File Organization
|
||||
|
||||
```
|
||||
src/host/flashinfer/
|
||||
flashinfer_attention.egg — egglog rewrite rule (pattern match → FlashInferAttention)
|
||||
mod.rs — FlashInferAttention op (EgglogOp + HostOp impl)
|
||||
jit.rs — JIT compilation: nvcc wrapper.cu → .so, dlopen, fn pointers
|
||||
find_indptrs.rs — walks the mask e-graph node to locate qo_indptr / kv_indptr inputs
|
||||
wrapper.cu — CUDA: FlashInfer template instantiation + helper kernels
|
||||
wrapper.h — C API header for wrapper.cu
|
||||
README.md — this file
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
### 1. Egglog Pattern Matching
|
||||
|
||||
The rule in `flashinfer_attention.egg` matches the structural pattern of paged GQA attention:
|
||||
|
||||
```
|
||||
Gather(K_cache, idx) → GQA broadcast (Mul×1.0) → Q×K^T → Sum → scale → mask Add → softmax → attn×V → Sum → output
|
||||
Gather(V_cache, idx) → GQA broadcast (Mul×1.0) ──────────────────────────────────────────→ attn×V → Sum → output
|
||||
```
|
||||
|
||||
Key anchors that prevent false matches on MLP or other ops:
|
||||
- Two Gather ops from 2D cache pools (MLP never uses Gather)
|
||||
- GQA broadcast via `Mul(gathered, Constant(1.0))` with all-zero strides
|
||||
- Mask Add with zero-stride broadcast in the first (nheads) dimension
|
||||
- Two sequential matmul+Sum pairs connected through softmax
|
||||
|
||||
Shape dimensions are egglog variables, not pinned constants — the rule works for any model with GQA (Llama, Qwen, Mistral, etc.). The structural invariants (dimension count, zero-stride positions, Gather from 2D) are enough to avoid combinatorial explosion during saturation.
|
||||
|
||||
When the rule fires, it unions `FlashInferAttention` with the original attention output, making it an equivalent alternative in the e-graph. The GA search then profiles both paths and picks the faster one.
|
||||
|
||||
### 2. Extraction: Finding Indptrs
|
||||
|
||||
During `extract()` (called when egglog selects the FlashInferAttention e-node), `find_indptrs.rs` walks backward from the mask node in the e-graph to locate the `qo_indptr` and `kv_indptr` Input nodes. It validates the mask structure by checking for the `Mul(allowed, Constant(1e10))` pattern that `compute_attn_mask()` produces.
|
||||
|
||||
The indptrs are appended as inputs 5 and 6 to the FlashInferAttention op, so the runtime can build the CSR page table directly without recomputing anything.
|
||||
|
||||
### 3. JIT Compilation
|
||||
|
||||
FlashInfer requires `HEAD_DIM` as a compile-time template parameter. Rather than baking it at `cargo build` time, `jit.rs` JIT-compiles `wrapper.cu` with the model's actual HEAD_DIM:
|
||||
|
||||
1. First call to `ensure_compiled(head_dim)` runs `nvcc` with `-DLUMINAL_HEAD_DIM=<N>`
|
||||
2. The compiled `.so` is cached at `~/.cache/luminal/flashinfer/libflashinfer_hd<N>_<arch>.so`
|
||||
3. Subsequent calls load the cached library via `dlopen`
|
||||
4. Function pointers (plan, run, transpose, etc.) are resolved and stored in a `static OnceLock`
|
||||
|
||||
Supported HEAD_DIM values: 64, 128, 256.
|
||||
|
||||
### 4. Runtime Execution
|
||||
|
||||
`FlashInferAttention::execute()` dispatches to decode or prefill based on `total_q_tokens vs batch_size`:
|
||||
|
||||
**Common steps:**
|
||||
1. **Extract kv_indices** — a helper kernel converts the flat gather index `(c, KV_DIM)` to slot indices `(c,)`
|
||||
2. **Read indptrs to host** — copied to CPU for the plan phase
|
||||
3. **Plan** — queries GPU occupancy and decides split-KV decomposition
|
||||
4. **Run** — the fused kernel writes `(total_q_tokens, num_qo_heads, head_dim)`
|
||||
5. **Transpose** — transposes to `(num_qo_heads, total_q_tokens, head_dim)` to match the Sum reduction layout
|
||||
|
||||
**Decode path** (current, fp32): Always used. Runs FlashInfer's BatchDecode directly on fp32 buffers.
|
||||
|
||||
**Prefill path** (future, fp16/bf16 only): The prefill kernel templates are compiled into the JIT .so for fp16 (CTA_TILE_Q=16/64/128, causal mask). The C API stubs currently return -1 since the pipeline is fp32. When native fp16/bf16 dtype support is added, `execute()` will dispatch to prefill when `total_q_tokens > batch_size`.
|
||||
|
||||
Global workspaces (`static OnceLock`) are shared across all FlashInferAttention instances to avoid ~4ms allocation overhead per GA profiling candidate. Without this, the GA never selects FlashInfer because the first-run allocation cost dwarfs the kernel time.
|
||||
|
||||
## How the Attention Mask Enables FlashInfer
|
||||
|
||||
For the egglog rule to fire, the `qo_indptr` and `kv_indptr` tensors must be visible in the same e-graph chunk as the attention ops. This is why the mask is computed *inside* each layer (via `compute_attn_mask()` in the model) rather than passed as a pre-computed input.
|
||||
|
||||
The mask computation uses a specific structure:
|
||||
```rust
|
||||
let allowed = same_request * causal;
|
||||
allowed * 1e10 - 1e10 // → 0.0 for allowed, -1e10 for blocked
|
||||
```
|
||||
|
||||
The `Mul(allowed, Constant(1e10))` pattern is the anchor that `find_indptrs.rs` uses to walk backward and locate the indptr inputs.
|
||||
|
||||
## Roadmap
|
||||
|
||||
Items listed in priority order. Checked items are done.
|
||||
|
||||
- [x] Model-agnostic egglog rule (shape variables instead of Llama-specific constants)
|
||||
- [x] bs>1 supersequence decode
|
||||
- [x] Indptr-based attention mask (replaces CPU-computed mask)
|
||||
- [x] Multi-model support (verified on Llama 3 8B and Qwen 3 4B)
|
||||
- [x] BatchPrefill kernel compiled for fp16 (causal mask, CTA_TILE_Q=16/64/128)
|
||||
- [ ] Native fp16 / bf16 pipeline (enables prefill, reduces memory, eliminates cuBLAS prefill fallback)
|
||||
- [ ] HEAD_DIM dispatch for 64, 96 (JIT supports 64/128/256; wrapper.cu needs 96 for Phi)
|
||||
- [ ] Page sizes > 1 (currently page_size=1; larger pages reduce CSR overhead)
|
||||
- [ ] Sliding window, ALiBi, logits soft cap (FlashInfer `AttentionVariant` templates)
|
||||
- [ ] MHA / MQA / arbitrary GQA ratios beyond {1, 2, 4, 8}
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
- **page_size=1**: Each KV cache slot is one "page". This simplifies the CSR page table (`kv_indices` = physical slot indices directly) and matches the flat `(num_slots, KV_DIM)` cache layout.
|
||||
|
||||
- **Pinned structural anchors**: The egglog rule pins the *structure* (number of dimensions, which dims are zero-stride, presence of Gather from 2D cache) but uses variables for the *values* (head counts, head_dim). This prevents saturation blowup while remaining model-agnostic.
|
||||
|
||||
- **Prefill requires fp16/bf16**: FlashInfer's prefill kernel uses tensor core MMA instructions (`mma.sync.aligned.m16n8k16`) and `ldmatrix` which physically require 16-bit inputs — there is no fp32 tensor core matmul instruction. The prefill kernel templates are compiled into the .so for fp16 but the C API returns -1 for fp32 callers. When native fp16/bf16 is added, prefill will be enabled automatically.
|
||||
|
||||
- **Global workspaces**: Float workspace (128 MiB), int workspace (8 MiB), and a page-locked host buffer are allocated once via `static OnceLock` and shared across all instances.
|
||||
@@ -1,248 +0,0 @@
|
||||
//! Walk the e-graph from the mask node to find qo_indptr and kv_indptr Input nodes.
|
||||
//!
|
||||
//! The mask is produced by `compute_attn_mask(q_pos, qo_indptr, kv_indptr)` using
|
||||
//! primitive HLIR ops. This module validates the mask's structure and extracts the
|
||||
//! indptr Input node IDs so FlashInfer can use them directly.
|
||||
|
||||
use luminal::egglog_utils::{ClassId, NodeId, SerializedEGraph};
|
||||
use luminal::prelude::FxHashSet;
|
||||
|
||||
/// Result of walking the mask computation chain.
|
||||
#[derive(Debug)]
|
||||
pub struct IndptrNodes<'a> {
|
||||
pub qo_indptr: &'a NodeId,
|
||||
pub kv_indptr: &'a NodeId,
|
||||
}
|
||||
|
||||
/// Find the qo_indptr and kv_indptr Input nodes by walking backwards from the mask.
|
||||
///
|
||||
/// Validates the mask structure: `allowed * 1e10 + (-1e10)`. Then does a BFS from
|
||||
/// the `allowed` subtree to find all reachable Input nodes with names containing
|
||||
/// "qo_indptr" and "kv_indptr".
|
||||
///
|
||||
/// Panics with a diagnostic message if the structure doesn't match or the
|
||||
/// indptr inputs can't be found.
|
||||
pub fn find_indptr_inputs<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
mask_node: &'a NodeId,
|
||||
) -> IndptrNodes<'a> {
|
||||
// Step 1: Validate mask = Add(scaled_allowed, neg_constant)
|
||||
let (mask_label, mask_children) = &egraph.enodes[mask_node];
|
||||
assert!(
|
||||
mask_label == "Op",
|
||||
"find_indptr_inputs: mask node is not an Op (label={mask_label})"
|
||||
);
|
||||
let mask_kind = resolve_first_node(egraph, &mask_children[0]);
|
||||
let mask_kind_label = &egraph.enodes[mask_kind].0;
|
||||
assert!(
|
||||
mask_kind_label.contains("Add"),
|
||||
"find_indptr_inputs: mask is not an Add (kind={mask_kind_label})"
|
||||
);
|
||||
|
||||
let mask_inputs = walk_ilist_simple(egraph, &mask_children[1]);
|
||||
assert_eq!(
|
||||
mask_inputs.len(),
|
||||
2,
|
||||
"find_indptr_inputs: mask Add should have 2 inputs, got {}",
|
||||
mask_inputs.len()
|
||||
);
|
||||
|
||||
// Step 2: One of the inputs should be Mul(allowed, Constant(1e10))
|
||||
let (scaled_allowed, allowed_node) = find_1e10_mul(egraph, &mask_inputs);
|
||||
|
||||
// Step 3: BFS from `allowed` to find all reachable Input nodes
|
||||
let reachable_inputs = find_reachable_inputs(egraph, allowed_node);
|
||||
|
||||
// Step 4: Match by name
|
||||
let mut qo_indptr: Option<&NodeId> = None;
|
||||
let mut kv_indptr: Option<&NodeId> = None;
|
||||
|
||||
for (node_id, name) in &reachable_inputs {
|
||||
if name.contains("qo_indptr") {
|
||||
qo_indptr = Some(node_id);
|
||||
} else if name.contains("kv_indptr") {
|
||||
kv_indptr = Some(node_id);
|
||||
}
|
||||
}
|
||||
|
||||
let qo = qo_indptr.unwrap_or_else(|| {
|
||||
let found_names: Vec<&str> = reachable_inputs.iter().map(|(_, n)| n.as_str()).collect();
|
||||
panic!(
|
||||
"find_indptr_inputs: could not find 'qo_indptr' Input reachable from mask.\n\
|
||||
Found inputs: {:?}\n\
|
||||
Mask node: {:?}\n\
|
||||
Scaled allowed node: {:?}",
|
||||
found_names, mask_node, scaled_allowed
|
||||
);
|
||||
});
|
||||
|
||||
let kv = kv_indptr.unwrap_or_else(|| {
|
||||
let found_names: Vec<&str> = reachable_inputs.iter().map(|(_, n)| n.as_str()).collect();
|
||||
panic!(
|
||||
"find_indptr_inputs: could not find 'kv_indptr' Input reachable from mask.\n\
|
||||
Found inputs: {:?}\n\
|
||||
Mask node: {:?}\n\
|
||||
Scaled allowed node: {:?}",
|
||||
found_names, mask_node, scaled_allowed
|
||||
);
|
||||
});
|
||||
|
||||
IndptrNodes {
|
||||
qo_indptr: qo,
|
||||
kv_indptr: kv,
|
||||
}
|
||||
}
|
||||
|
||||
fn find_1e10_mul<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
mask_add_inputs: &[&'a NodeId],
|
||||
) -> (&'a NodeId, &'a NodeId) {
|
||||
for &input_node in mask_add_inputs {
|
||||
let (label, children) = &egraph.enodes[input_node];
|
||||
if label != "Op" {
|
||||
continue;
|
||||
}
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
if !egraph.enodes[kind].0.contains("Mul") {
|
||||
continue;
|
||||
}
|
||||
let mul_inputs = walk_ilist_simple(egraph, &children[1]);
|
||||
if mul_inputs.len() != 2 {
|
||||
continue;
|
||||
}
|
||||
for (i, &inp) in mul_inputs.iter().enumerate() {
|
||||
if is_constant(egraph, inp, 1e10) {
|
||||
let other = mul_inputs[1 - i];
|
||||
return (input_node, other);
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut debug_info = String::new();
|
||||
for (i, &input_node) in mask_add_inputs.iter().enumerate() {
|
||||
let (label, children) = &egraph.enodes[input_node];
|
||||
debug_info.push_str(&format!("\n input[{i}]: label={label}"));
|
||||
if label == "Op" && !children.is_empty() {
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
let kind_label = &egraph.enodes[kind].0;
|
||||
debug_info.push_str(&format!(" kind={kind_label}"));
|
||||
for (j, kc) in egraph.enodes[kind].1.iter().enumerate() {
|
||||
let kc_node = resolve_first_node(egraph, kc);
|
||||
debug_info.push_str(&format!(" child[{j}]={}", egraph.enodes[kc_node].0));
|
||||
}
|
||||
if kind_label.contains("Mul") && children.len() >= 2 {
|
||||
let mul_inputs = walk_ilist_simple(egraph, &children[1]);
|
||||
for (j, &mi) in mul_inputs.iter().enumerate() {
|
||||
let (ml, mc) = &egraph.enodes[mi];
|
||||
debug_info.push_str(&format!("\n mul_input[{j}]: label={ml}"));
|
||||
if ml == "Op" && !mc.is_empty() {
|
||||
let mk = resolve_first_node(egraph, &mc[0]);
|
||||
debug_info.push_str(&format!(" kind={}", egraph.enodes[mk].0));
|
||||
for (k, mkc) in egraph.enodes[mk].1.iter().enumerate() {
|
||||
let mkc_node = resolve_first_node(egraph, mkc);
|
||||
debug_info.push_str(&format!(" ch[{k}]={}", egraph.enodes[mkc_node].0));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
panic!(
|
||||
"find_indptr_inputs: could not find Mul(allowed, Constant(1e10)) in mask Add inputs.{debug_info}"
|
||||
);
|
||||
}
|
||||
|
||||
fn is_constant(egraph: &SerializedEGraph, node: &NodeId, expected: f32) -> bool {
|
||||
let (label, children) = &egraph.enodes[node];
|
||||
if label != "Op" {
|
||||
return false;
|
||||
}
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
let kind_label = &egraph.enodes[kind].0;
|
||||
if !kind_label.contains("Constant") {
|
||||
return false;
|
||||
}
|
||||
let val_children = &egraph.enodes[kind].1;
|
||||
if val_children.is_empty() {
|
||||
return false;
|
||||
}
|
||||
let val_node = resolve_first_node(egraph, &val_children[0]);
|
||||
let val_str = &egraph.enodes[val_node].0;
|
||||
if let Ok(val) = val_str.parse::<f64>() {
|
||||
(val as f32 - expected).abs() < 1.0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn find_reachable_inputs<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
start: &'a NodeId,
|
||||
) -> Vec<(&'a NodeId, String)> {
|
||||
let mut found = Vec::new();
|
||||
let mut visited = FxHashSet::default();
|
||||
let mut stack = vec![start];
|
||||
|
||||
while let Some(node) = stack.pop() {
|
||||
if !visited.insert(node) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let (label, children) = &egraph.enodes[node];
|
||||
|
||||
if label == "Input" {
|
||||
if children.len() >= 2 {
|
||||
let name_node = resolve_first_node(egraph, &children[1]);
|
||||
let name = egraph.enodes[name_node].0.trim_matches('"').to_string();
|
||||
found.push((node, name));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if label == "Op" && children.len() >= 2 {
|
||||
let ir_inputs = walk_ilist_simple(egraph, &children[1]);
|
||||
for inp in ir_inputs {
|
||||
stack.push(inp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
found
|
||||
}
|
||||
|
||||
fn walk_ilist_simple<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
ilist_eclass: &'a ClassId,
|
||||
) -> Vec<&'a NodeId> {
|
||||
let mut inputs = Vec::new();
|
||||
let mut current = resolve_first_node(egraph, ilist_eclass);
|
||||
|
||||
loop {
|
||||
let (label, children) = &egraph.enodes[current];
|
||||
if label == "INil" {
|
||||
break;
|
||||
}
|
||||
if label != "ICons" {
|
||||
break;
|
||||
}
|
||||
let ir_node = resolve_first_ir_node(egraph, &children[0]);
|
||||
inputs.push(ir_node);
|
||||
current = resolve_first_node(egraph, &children[1]);
|
||||
}
|
||||
|
||||
inputs
|
||||
}
|
||||
|
||||
fn resolve_first_node<'a>(egraph: &'a SerializedEGraph, eclass: &ClassId) -> &'a NodeId {
|
||||
&egraph.eclasses[eclass].1[0]
|
||||
}
|
||||
|
||||
fn resolve_first_ir_node<'a>(egraph: &'a SerializedEGraph, eclass: &ClassId) -> &'a NodeId {
|
||||
let nodes = &egraph.eclasses[eclass].1;
|
||||
for node in nodes {
|
||||
let label = &egraph.enodes[node].0;
|
||||
if label == "Op" || label == "Input" {
|
||||
return node;
|
||||
}
|
||||
}
|
||||
&nodes[0]
|
||||
}
|
||||
@@ -1,125 +0,0 @@
|
||||
; FlashInfer batch decode attention rewrite rule.
|
||||
;
|
||||
; Matches the paged attention pattern for ANY model with GQA:
|
||||
; Gather(K_cache) → GQA broadcast → Q*K^T matmul → scale → add mask → softmax → attn*V matmul
|
||||
; Gather(V_cache) → GQA broadcast ──────────────────────────────────────────→ attn*V matmul
|
||||
;
|
||||
; Structural anchors (prevent false matches on MLP/other ops):
|
||||
; - Gather ops from 2D cache pools (MLP never uses Gather)
|
||||
; - GQA broadcast via Mul(gathered, Constant(1.0)) with all-zero strides
|
||||
; - Scale Mul(QK, constant) connecting QK scores to mask Add
|
||||
; - Mask Add with zero-stride broadcast in first dim (nheads broadcast)
|
||||
; - Data flow: two sequential matmul+reduce pairs connected through softmax
|
||||
;
|
||||
; The egglog rule captures the mask as 5th input. During extract(), a Rust
|
||||
; function walks the mask's computation chain in the e-graph to locate the
|
||||
; qo_indptr and kv_indptr Input nodes (validated via the Constant(1e10) anchor
|
||||
; and structural checks). These are appended as inputs 5 and 6 so FlashInfer
|
||||
; can build the CSR page table directly — no runtime derivation needed.
|
||||
;
|
||||
; Shape dimensions are egglog variables, not pinned constants.
|
||||
; Dynamic dims "s" (batch/seq) and "c" (context) stay pinned as MVar.
|
||||
|
||||
(rule
|
||||
(
|
||||
; ── Second matmul: Mul(softmax_out, V_gqa) ──
|
||||
; Shape: (nheads, s, hdim, c) — 4D
|
||||
(= ?mul2 (Op (Mul
|
||||
(ECons ?nheads (ECons (MVar "s") (ECons ?hdim (ECons (MVar "c") (ENil)))))
|
||||
?mul2_a_strides
|
||||
?mul2_b_strides
|
||||
?mul2_out_strides)
|
||||
(ICons ?soft (ICons ?v_gqa (INil)))))
|
||||
|
||||
; ── Second matmul: Sum (reduction over c) → output ──
|
||||
; Shape: (nheads, s, hdim) — reduces c
|
||||
(= ?output (Op (Sum
|
||||
(ECons ?nheads2 (ECons (MVar "s") (ECons ?hdim2 (ENil))))
|
||||
(MVar "c")
|
||||
?out_in_strides
|
||||
(MIter)
|
||||
?out_out_strides)
|
||||
(ICons ?mul2 (INil))))
|
||||
|
||||
; ── V GQA broadcast: Mul(V_gathered, 1.0) with zero-stride constant ──
|
||||
; Shape: (nheads, c, hdim) — 3D
|
||||
(= ?v_gqa_const (Op (Constant 1.000000) (INil)))
|
||||
(= ?v_gqa (Op (Mul
|
||||
(ECons ?nheads3 (ECons (MVar "c") (ECons ?hdim3 (ENil))))
|
||||
?v_gqa_a_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?v_gqa_out_strides)
|
||||
(ICons ?v_gathered (ICons ?v_gqa_const (INil)))))
|
||||
|
||||
; ── V Gather: rows from V_cache (2D) ──
|
||||
; Shape: (c, kvdim), Source: (num_slots, kvdim)
|
||||
(= ?v_gathered (Op (Gather
|
||||
(ECons (MVar "c") (ECons ?kvdim (ENil)))
|
||||
?v_gather_strides
|
||||
(ECons ?num_slots_v (ECons ?kvdim2 (ENil)))
|
||||
?v_src_strides)
|
||||
(ICons ?v_idx (ICons ?v_cache (INil)))))
|
||||
|
||||
; ── First matmul: Mul(Q, K_gqa) ──
|
||||
; Shape: (nheads, s, c, hdim) — 4D
|
||||
(= ?mul1 (Op (Mul
|
||||
(ECons ?nheads4 (ECons (MVar "s") (ECons (MVar "c") (ECons ?hdim4 (ENil)))))
|
||||
?mul1_a_strides
|
||||
?mul1_b_strides
|
||||
?mul1_out_strides)
|
||||
(ICons ?q (ICons ?k_gqa (INil)))))
|
||||
|
||||
; ── First matmul: Sum (reduction over hdim) → QK scores ──
|
||||
; Shape: (nheads, s, c) — reduces hdim
|
||||
(= ?qk (Op (Sum
|
||||
(ECons ?nheads5 (ECons (MVar "s") (ECons (MVar "c") (ENil))))
|
||||
?hdim5
|
||||
?qk_in_strides
|
||||
(MIter)
|
||||
?qk_out_strides)
|
||||
(ICons ?mul1 (INil))))
|
||||
|
||||
; ── Mask Add: Add(scaled_QK, mask) ──
|
||||
; Shape: (nheads, s, c) — 3D
|
||||
; Mask is broadcast from (s, c) via zero-stride in first dim (nheads).
|
||||
(= ?masked (Op (Add
|
||||
(ECons ?nheads8 (ECons (MVar "s") (ECons (MVar "c") (ENil))))
|
||||
?mask_add_a_strides
|
||||
(ECons (MNum 0) ?mask_rest_strides)
|
||||
?mask_add_out_strides)
|
||||
(ICons ?scaled_qk (ICons ?mask (INil)))))
|
||||
|
||||
; ── K GQA broadcast: Mul(K_gathered, 1.0) with zero-stride constant ──
|
||||
; Shape: (nheads, hdim, c) — 3D
|
||||
(= ?k_gqa_const (Op (Constant 1.000000) (INil)))
|
||||
(= ?k_gqa (Op (Mul
|
||||
(ECons ?nheads6 (ECons ?hdim6 (ECons (MVar "c") (ENil))))
|
||||
?k_gqa_a_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?k_gqa_out_strides)
|
||||
(ICons ?k_gathered (ICons ?k_gqa_const (INil)))))
|
||||
|
||||
; ── K Gather: rows from K_cache (2D) ──
|
||||
; Shape: (c, kvdim), Source: (num_slots, kvdim)
|
||||
(= ?k_gathered (Op (Gather
|
||||
(ECons (MVar "c") (ECons ?kvdim3 (ENil)))
|
||||
?k_gather_strides
|
||||
(ECons ?num_slots_k (ECons ?kvdim4 (ENil)))
|
||||
?k_src_strides)
|
||||
(ICons ?k_idx (ICons ?k_cache (INil)))))
|
||||
|
||||
; ── Dtype consistency ──
|
||||
(= ?dt (dtype ?q))
|
||||
(= ?dt (dtype ?k_cache))
|
||||
(= ?dt (dtype ?v_cache))
|
||||
)
|
||||
(
|
||||
(let ?fi (Op (FlashInferAttention
|
||||
?nheads (MDiv ?kvdim ?hdim) ?hdim (MNum 1) (MVar "s"))
|
||||
(ICons ?q (ICons ?k_cache (ICons ?v_cache (ICons ?k_idx (ICons ?mask (INil))))))))
|
||||
(union ?output ?fi)
|
||||
(set (dtype ?fi) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "FlashInfer batch decode attention"
|
||||
)
|
||||
@@ -1,504 +0,0 @@
|
||||
//! JIT compilation and dynamic loading of FlashInfer kernels.
|
||||
//!
|
||||
//! Everything runs at compile / profiling time — there is no `build.rs`.
|
||||
//! `wrapper.cu` and `wrapper.h` are embedded via `include_str!()` and
|
||||
//! extracted to the cache directory on first use. The FlashInfer + CUTLASS
|
||||
//! header trees are located by probing `LUMINAL_FLASHINFER_DIR`, a small set
|
||||
//! of default paths, and (as a last resort) by `git clone`-ing FlashInfer at
|
||||
//! a pinned commit into the cache. `nvcc` is then invoked with the model's
|
||||
//! actual `HEAD_DIM` and the resulting `.so` is `dlopen`'d.
|
||||
//!
|
||||
//! `ensure_compiled` is called from `FlashInferAttention::extract()`, i.e.
|
||||
//! during luminal's compile / GA-profiling phase, not from `execute()`. After
|
||||
//! the first call the `OnceLock` makes subsequent lookups free.
|
||||
|
||||
use std::{
|
||||
ffi::c_void,
|
||||
hash::{Hash, Hasher},
|
||||
path::{Path, PathBuf},
|
||||
process::Command,
|
||||
sync::OnceLock,
|
||||
};
|
||||
|
||||
// ── Function pointer types matching wrapper.h ──
|
||||
|
||||
pub type PlanFn = unsafe extern "C" fn(
|
||||
float_workspace: *mut c_void,
|
||||
float_ws_size: usize,
|
||||
int_workspace: *mut c_void,
|
||||
int_ws_size: usize,
|
||||
page_locked_int_workspace: *mut c_void,
|
||||
indptr_h: *mut i32,
|
||||
batch_size: i32,
|
||||
num_qo_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
page_size: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
plan_info_out: *mut i64,
|
||||
plan_info_len_out: *mut i32,
|
||||
) -> i32;
|
||||
|
||||
pub type RunFn = unsafe extern "C" fn(
|
||||
float_workspace: *mut c_void,
|
||||
float_ws_size: usize,
|
||||
int_workspace: *mut c_void,
|
||||
plan_info_vec: *mut i64,
|
||||
plan_info_len: i32,
|
||||
q: *mut f32,
|
||||
k_cache: *mut f32,
|
||||
v_cache: *mut f32,
|
||||
kv_indptr: *mut i32,
|
||||
kv_indices: *mut i32,
|
||||
kv_last_page_len: *mut i32,
|
||||
output: *mut f32,
|
||||
batch_size: i32,
|
||||
num_qo_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
page_size: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
) -> i32;
|
||||
|
||||
pub type ExtractFn = unsafe extern "C" fn(
|
||||
flat_idx: *const i32,
|
||||
out: *mut i32,
|
||||
c: i32,
|
||||
kv_dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
|
||||
pub type DeriveIndptrFn =
|
||||
unsafe extern "C" fn(mask: *const f32, indptr: *mut i32, s: i32, c: i32, stream: *mut c_void);
|
||||
|
||||
pub type TransposeOutputFn = unsafe extern "C" fn(
|
||||
src: *const f32,
|
||||
dst: *mut f32,
|
||||
batch: i32,
|
||||
heads: i32,
|
||||
dim: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
|
||||
pub type PrefillPlanFn = unsafe extern "C" fn(
|
||||
float_workspace: *mut c_void,
|
||||
float_ws_size: usize,
|
||||
int_workspace: *mut c_void,
|
||||
int_ws_size: usize,
|
||||
page_locked_int_workspace: *mut c_void,
|
||||
qo_indptr_h: *mut i32,
|
||||
kv_indptr_h: *mut i32,
|
||||
total_num_rows: i32,
|
||||
batch_size: i32,
|
||||
num_qo_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
page_size: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
plan_info_out: *mut i64,
|
||||
plan_info_len_out: *mut i32,
|
||||
) -> i32;
|
||||
|
||||
pub type PrefillRunFn = unsafe extern "C" fn(
|
||||
float_workspace: *mut c_void,
|
||||
float_ws_size: usize,
|
||||
int_workspace: *mut c_void,
|
||||
plan_info_vec: *mut i64,
|
||||
plan_info_len: i32,
|
||||
q: *mut f32,
|
||||
k_cache: *mut f32,
|
||||
v_cache: *mut f32,
|
||||
qo_indptr: *mut i32,
|
||||
kv_indptr: *mut i32,
|
||||
kv_indices: *mut i32,
|
||||
kv_last_page_len: *mut i32,
|
||||
output: *mut f32,
|
||||
total_num_rows: i32,
|
||||
batch_size: i32,
|
||||
num_qo_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
page_size: i32,
|
||||
head_dim: i32,
|
||||
stream: *mut c_void,
|
||||
) -> i32;
|
||||
|
||||
// ── Embedded CUDA sources ──
|
||||
|
||||
const WRAPPER_CU: &str = include_str!("wrapper.cu");
|
||||
const WRAPPER_H: &str = include_str!("wrapper.h");
|
||||
|
||||
// ── Loaded library handle ──
|
||||
|
||||
pub struct FlashInferLib {
|
||||
// Keep the handle alive so the dlopen'd .so remains mapped.
|
||||
_lib: libloading::Library,
|
||||
pub plan: PlanFn,
|
||||
pub run: RunFn,
|
||||
pub extract_slot_indices: ExtractFn,
|
||||
pub derive_indptr_from_mask: DeriveIndptrFn,
|
||||
pub transpose_output: TransposeOutputFn,
|
||||
pub prefill_plan: PrefillPlanFn,
|
||||
pub prefill_run: PrefillRunFn,
|
||||
}
|
||||
|
||||
// SAFETY: The library handle and function pointers are valid for the lifetime
|
||||
// of the process. All functions are called with proper CUDA stream serialization.
|
||||
unsafe impl Send for FlashInferLib {}
|
||||
unsafe impl Sync for FlashInferLib {}
|
||||
|
||||
static FLASHINFER_LIB: OnceLock<FlashInferLib> = OnceLock::new();
|
||||
|
||||
/// Ensure the FlashInfer library is compiled and loaded for the given HEAD_DIM.
|
||||
/// Returns a reference to the loaded library. Thread-safe via OnceLock.
|
||||
pub fn ensure_compiled(head_dim: usize) -> &'static FlashInferLib {
|
||||
FLASHINFER_LIB.get_or_init(|| {
|
||||
assert!(
|
||||
matches!(head_dim, 64 | 128 | 256),
|
||||
"FlashInfer: unsupported HEAD_DIM={} (must be 64, 128, or 256 for f32)",
|
||||
head_dim
|
||||
);
|
||||
let so_path = compile_or_cache(head_dim);
|
||||
unsafe {
|
||||
FlashInferLib::load(&so_path)
|
||||
.unwrap_or_else(|e| panic!("Failed to load FlashInfer library: {e}"))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
impl FlashInferLib {
|
||||
/// Load a compiled FlashInfer .so and resolve function pointers.
|
||||
///
|
||||
/// # Safety
|
||||
/// The .so must be a valid FlashInfer wrapper compiled from wrapper.cu.
|
||||
unsafe fn load(path: &Path) -> Result<Self, libloading::Error> {
|
||||
let lib = unsafe { libloading::Library::new(path)? };
|
||||
let plan: PlanFn = unsafe { *lib.get::<PlanFn>(b"flashinfer_batch_decode_plan\0")? };
|
||||
let run: RunFn = unsafe { *lib.get::<RunFn>(b"flashinfer_batch_decode_run\0")? };
|
||||
let extract_slot_indices: ExtractFn =
|
||||
unsafe { *lib.get::<ExtractFn>(b"flashinfer_extract_slot_indices\0")? };
|
||||
let derive_indptr_from_mask: DeriveIndptrFn =
|
||||
unsafe { *lib.get::<DeriveIndptrFn>(b"flashinfer_derive_indptr_from_mask\0")? };
|
||||
let transpose_output: TransposeOutputFn =
|
||||
unsafe { *lib.get::<TransposeOutputFn>(b"flashinfer_transpose_output\0")? };
|
||||
let prefill_plan: PrefillPlanFn =
|
||||
unsafe { *lib.get::<PrefillPlanFn>(b"flashinfer_batch_prefill_plan\0")? };
|
||||
let prefill_run: PrefillRunFn =
|
||||
unsafe { *lib.get::<PrefillRunFn>(b"flashinfer_batch_prefill_run\0")? };
|
||||
Ok(Self {
|
||||
_lib: lib,
|
||||
plan,
|
||||
run,
|
||||
extract_slot_indices,
|
||||
derive_indptr_from_mask,
|
||||
transpose_output,
|
||||
prefill_plan,
|
||||
prefill_run,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Compile wrapper.cu for the given HEAD_DIM, or return cached .so path.
|
||||
fn compile_or_cache(head_dim: usize) -> PathBuf {
|
||||
let cache_dir = cache_directory();
|
||||
std::fs::create_dir_all(&cache_dir).expect("Failed to create FlashInfer cache directory");
|
||||
|
||||
// Extract bundled wrapper sources to the cache so nvcc can compile them.
|
||||
let (wrapper_cu_path, wrapper_h_dir) = extract_wrapper_sources(&cache_dir);
|
||||
|
||||
let arch = detect_cuda_arch();
|
||||
// Bake a hash of the embedded wrapper into the .so name so old caches are
|
||||
// discarded automatically when wrapper.cu or wrapper.h change.
|
||||
let wrapper_hash = wrapper_source_hash();
|
||||
let so_name = format!(
|
||||
"libflashinfer_hd{}_{}_w{:016x}.so",
|
||||
head_dim, arch, wrapper_hash
|
||||
);
|
||||
let so_path = cache_dir.join(&so_name);
|
||||
|
||||
if so_path.exists() {
|
||||
eprintln!(
|
||||
"FlashInfer: using cached library for HEAD_DIM={} ({})",
|
||||
head_dim,
|
||||
so_path.display()
|
||||
);
|
||||
return so_path;
|
||||
}
|
||||
|
||||
let Some((flashinfer_include, cutlass_include)) = locate_flashinfer_includes() else {
|
||||
panic!(
|
||||
"FlashInfer: could not locate header tree. Set LUMINAL_FLASHINFER_DIR to the \
|
||||
FlashInfer source root (the directory containing `include/` and \
|
||||
`3rdparty/cutlass/include/`)."
|
||||
);
|
||||
};
|
||||
|
||||
eprintln!(
|
||||
"FlashInfer: JIT compiling for HEAD_DIM={}, arch={} ...",
|
||||
head_dim, arch
|
||||
);
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let output = Command::new("nvcc")
|
||||
.args([
|
||||
"-shared",
|
||||
"-o",
|
||||
so_path.to_str().unwrap(),
|
||||
&format!("-DLUMINAL_HEAD_DIM={}", head_dim),
|
||||
wrapper_cu_path.to_str().unwrap(),
|
||||
"-I",
|
||||
flashinfer_include.to_str().unwrap(),
|
||||
"-I",
|
||||
cutlass_include.to_str().unwrap(),
|
||||
"-I",
|
||||
wrapper_h_dir.to_str().unwrap(),
|
||||
"-std=c++17",
|
||||
&format!("-arch={}", arch),
|
||||
"-O3",
|
||||
"--expt-relaxed-constexpr",
|
||||
"-w",
|
||||
"-rdc=true",
|
||||
"--compiler-options",
|
||||
"-fPIC",
|
||||
])
|
||||
.output()
|
||||
.expect("Failed to run nvcc. Is the CUDA toolkit installed?");
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let _ = std::fs::remove_file(&so_path);
|
||||
panic!(
|
||||
"FlashInfer JIT compilation failed (HEAD_DIM={}, arch={}):\nstdout: {}\nstderr: {}",
|
||||
head_dim, arch, stdout, stderr
|
||||
);
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
eprintln!(
|
||||
"FlashInfer: compiled in {:.1}s → {}",
|
||||
elapsed.as_secs_f64(),
|
||||
so_path.display()
|
||||
);
|
||||
|
||||
so_path
|
||||
}
|
||||
|
||||
/// Returns ~/.cache/luminal/flashinfer/
|
||||
fn cache_directory() -> PathBuf {
|
||||
let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
|
||||
PathBuf::from(home)
|
||||
.join(".cache")
|
||||
.join("luminal")
|
||||
.join("flashinfer")
|
||||
}
|
||||
|
||||
/// Drop the embedded wrapper.cu/wrapper.h into the cache dir so nvcc has files
|
||||
/// on disk to compile. Returns (wrapper.cu path, directory containing wrapper.h).
|
||||
fn extract_wrapper_sources(cache_dir: &Path) -> (PathBuf, PathBuf) {
|
||||
let cu = cache_dir.join("wrapper.cu");
|
||||
let h = cache_dir.join("wrapper.h");
|
||||
write_if_changed(&cu, WRAPPER_CU.as_bytes());
|
||||
write_if_changed(&h, WRAPPER_H.as_bytes());
|
||||
(cu, cache_dir.to_path_buf())
|
||||
}
|
||||
|
||||
fn write_if_changed(path: &Path, contents: &[u8]) {
|
||||
if let Ok(existing) = std::fs::read(path)
|
||||
&& existing == contents
|
||||
{
|
||||
return;
|
||||
}
|
||||
std::fs::write(path, contents).unwrap_or_else(|e| {
|
||||
panic!(
|
||||
"FlashInfer: failed to write wrapper source to {}: {e}",
|
||||
path.display()
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
fn wrapper_source_hash() -> u64 {
|
||||
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
||||
WRAPPER_CU.hash(&mut hasher);
|
||||
WRAPPER_H.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
// ── Pinned FlashInfer source ──
|
||||
//
|
||||
// Bumping this constant invalidates the cached source tree AND the cached .so
|
||||
// (the .so cache key incorporates the wrapper hash, which is rebuilt against
|
||||
// these headers, so different headers compile to a different .so file even at
|
||||
// the same head_dim). If you change `FLASHINFER_GIT_REV`, also re-check
|
||||
// `wrapper.cu` against the new FlashInfer API.
|
||||
|
||||
const FLASHINFER_GIT_URL: &str = "https://github.com/flashinfer-ai/flashinfer.git";
|
||||
const CUTLASS_GIT_URL: &str = "https://github.com/NVIDIA/cutlass.git";
|
||||
const FLASHINFER_GIT_REV: &str = "f1e6fdcb8f65104047697f022b5d055ef022d763";
|
||||
const CUTLASS_GIT_REV: &str = "f3fde58372d33e9a5650ba7b80fc48b3b49d40c8";
|
||||
|
||||
fn locate_flashinfer_includes() -> Option<(PathBuf, PathBuf)> {
|
||||
if let Ok(path) = std::env::var("LUMINAL_FLASHINFER_DIR")
|
||||
&& !path.is_empty()
|
||||
{
|
||||
let root = PathBuf::from(path);
|
||||
let inc = root.join("include");
|
||||
let cutlass = root.join("3rdparty/cutlass/include");
|
||||
if inc.exists() && cutlass.exists() {
|
||||
return Some((inc, cutlass));
|
||||
}
|
||||
eprintln!(
|
||||
"FlashInfer: LUMINAL_FLASHINFER_DIR={} did not contain include/ and \
|
||||
3rdparty/cutlass/include/ — falling back to default locations",
|
||||
root.display()
|
||||
);
|
||||
}
|
||||
|
||||
let home = std::env::var("HOME").unwrap_or_default();
|
||||
let candidates = [
|
||||
PathBuf::from(&home).join("luminal_cuda/crates/luminal_cuda/flashinfer"),
|
||||
PathBuf::from(&home).join("luminal_cuda/flashinfer"),
|
||||
PathBuf::from("/opt/luminal_cuda/crates/luminal_cuda/flashinfer"),
|
||||
];
|
||||
for root in candidates {
|
||||
let inc = root.join("include");
|
||||
let cutlass = root.join("3rdparty/cutlass/include");
|
||||
if inc.exists() && cutlass.exists() {
|
||||
return Some((inc, cutlass));
|
||||
}
|
||||
}
|
||||
|
||||
// Last resort: fetch the pinned commit into the cache directory.
|
||||
fetch_flashinfer_source().ok().map(|root| {
|
||||
let inc = root.join("include");
|
||||
let cutlass = root.join("3rdparty/cutlass/include");
|
||||
(inc, cutlass)
|
||||
})
|
||||
}
|
||||
|
||||
/// Clone FlashInfer at `FLASHINFER_GIT_REV` + CUTLASS at `CUTLASS_GIT_REV`
|
||||
/// into `~/.cache/luminal/flashinfer-src/<short_rev>/` if absent, then return
|
||||
/// the FlashInfer root directory. ~50 MB one-time download; subsequent calls
|
||||
/// short-circuit on the directory check.
|
||||
fn fetch_flashinfer_source() -> Result<PathBuf, String> {
|
||||
let short = &FLASHINFER_GIT_REV[..12];
|
||||
let cache_root = cache_directory().join("flashinfer-src").join(short);
|
||||
let inc = cache_root.join("include");
|
||||
let cutlass_inc = cache_root.join("3rdparty/cutlass/include");
|
||||
|
||||
if inc.exists() && cutlass_inc.exists() {
|
||||
return Ok(cache_root);
|
||||
}
|
||||
|
||||
let parent = cache_root.parent().unwrap();
|
||||
std::fs::create_dir_all(parent)
|
||||
.map_err(|e| format!("failed to create {}: {e}", parent.display()))?;
|
||||
|
||||
// Clone into a staging dir, then atomic rename. Protects against multiple
|
||||
// processes racing to fetch the same source.
|
||||
let staging = parent.join(format!(".staging-{}-{}", short, std::process::id()));
|
||||
let _ = std::fs::remove_dir_all(&staging);
|
||||
|
||||
eprintln!(
|
||||
"FlashInfer: cloning {FLASHINFER_GIT_URL} @ {short} into {} (one-time fetch, ~50 MB) …",
|
||||
cache_root.display()
|
||||
);
|
||||
|
||||
run_git(&[
|
||||
"clone",
|
||||
"--filter=blob:none",
|
||||
"--no-checkout",
|
||||
FLASHINFER_GIT_URL,
|
||||
staging.to_str().unwrap(),
|
||||
])?;
|
||||
run_git_in(&staging, &["checkout", FLASHINFER_GIT_REV])?;
|
||||
|
||||
// Init only the CUTLASS submodule (skip spdlog — we don't need it for kernels).
|
||||
let cutlass_path = staging.join("3rdparty/cutlass");
|
||||
let _ = std::fs::remove_dir_all(&cutlass_path);
|
||||
run_git(&[
|
||||
"clone",
|
||||
"--filter=blob:none",
|
||||
"--no-checkout",
|
||||
CUTLASS_GIT_URL,
|
||||
cutlass_path.to_str().unwrap(),
|
||||
])?;
|
||||
run_git_in(&cutlass_path, &["checkout", CUTLASS_GIT_REV])?;
|
||||
|
||||
if !staging.join("include").exists() {
|
||||
return Err(format!(
|
||||
"FlashInfer clone succeeded but include/ missing at {}",
|
||||
staging.display()
|
||||
));
|
||||
}
|
||||
if !staging.join("3rdparty/cutlass/include").exists() {
|
||||
return Err(format!(
|
||||
"CUTLASS clone succeeded but include/ missing at {}",
|
||||
staging.join("3rdparty/cutlass").display()
|
||||
));
|
||||
}
|
||||
|
||||
// Atomic-ish rename. If another process beat us to it, just keep theirs.
|
||||
match std::fs::rename(&staging, &cache_root) {
|
||||
Ok(()) => {}
|
||||
Err(_) if cache_root.exists() => {
|
||||
let _ = std::fs::remove_dir_all(&staging);
|
||||
}
|
||||
Err(e) => return Err(format!("rename to {} failed: {e}", cache_root.display())),
|
||||
}
|
||||
|
||||
Ok(cache_root)
|
||||
}
|
||||
|
||||
fn run_git(args: &[&str]) -> Result<(), String> {
|
||||
let out = Command::new("git")
|
||||
.args(args)
|
||||
.output()
|
||||
.map_err(|e| format!("failed to spawn `git`: {e}. Is git installed?"))?;
|
||||
if !out.status.success() {
|
||||
return Err(format!(
|
||||
"`git {}` failed: {}",
|
||||
args.join(" "),
|
||||
String::from_utf8_lossy(&out.stderr)
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_git_in(cwd: &Path, args: &[&str]) -> Result<(), String> {
|
||||
let out = Command::new("git")
|
||||
.args(args)
|
||||
.current_dir(cwd)
|
||||
.output()
|
||||
.map_err(|e| format!("failed to spawn `git`: {e}"))?;
|
||||
if !out.status.success() {
|
||||
return Err(format!(
|
||||
"`git {}` in {} failed: {}",
|
||||
args.join(" "),
|
||||
cwd.display(),
|
||||
String::from_utf8_lossy(&out.stderr)
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Detect CUDA arch via env override → nvidia-smi → default sm_80.
|
||||
fn detect_cuda_arch() -> String {
|
||||
if let Ok(arch) = std::env::var("FLASHINFER_CUDA_ARCH") {
|
||||
return arch;
|
||||
}
|
||||
|
||||
if let Ok(output) = Command::new("nvidia-smi")
|
||||
.args(["--query-gpu=compute_cap", "--format=csv,noheader"])
|
||||
.output()
|
||||
&& output.status.success()
|
||||
{
|
||||
let cap = String::from_utf8_lossy(&output.stdout);
|
||||
let cap = cap.trim().lines().next().unwrap_or("8.0");
|
||||
let sm = cap.replace('.', "");
|
||||
if !sm.is_empty() {
|
||||
return format!("sm_{}", sm);
|
||||
}
|
||||
}
|
||||
|
||||
"sm_80".to_string()
|
||||
}
|
||||
@@ -1,424 +0,0 @@
|
||||
pub mod find_indptrs;
|
||||
pub mod jit;
|
||||
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{EXPRESSION, OP_KIND},
|
||||
extract_expr,
|
||||
},
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::{
|
||||
tracing::{Level, span},
|
||||
*,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
cudarc::driver::{CudaSlice, CudaStream, DevicePtr, result},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
};
|
||||
|
||||
/// FlashInfer attention op (batch decode, fp32).
|
||||
///
|
||||
/// Replaces the full paged-GQA attention pattern (gather → broadcast → Q*K^T →
|
||||
/// scale → mask → softmax → *V) with a single FlashInfer fused kernel.
|
||||
///
|
||||
/// Graph inputs (7): Q, K_pool, V_pool, flat_gather_idx, mask, qo_indptr, kv_indptr.
|
||||
/// The egglog rule captures the first 5; `extract()` appends qo/kv indptrs after
|
||||
/// walking the e-graph from the mask. `batch_size` is derived at runtime from the
|
||||
/// indptr length (= num_sequences + 1).
|
||||
#[derive(Debug)]
|
||||
pub struct FlashInferAttention {
|
||||
pub num_qo_heads: usize,
|
||||
pub num_kv_heads: usize,
|
||||
pub head_dim: usize,
|
||||
pub page_size: usize,
|
||||
pub batch_dim: Expression,
|
||||
|
||||
pub plan_info: Mutex<Vec<i64>>,
|
||||
}
|
||||
|
||||
// SAFETY: PAGE_LOCKED_WORKSPACE holds a raw pointer to page-locked CUDA memory
|
||||
// allocated once and serialized via the CUDA stream that owns it.
|
||||
unsafe impl Send for FlashInferAttention {}
|
||||
unsafe impl Sync for FlashInferAttention {}
|
||||
|
||||
const FLOAT_WORKSPACE_SIZE: usize = 128 * 1024 * 1024; // 128 MiB
|
||||
const INT_WORKSPACE_SIZE: usize = 8 * 1024 * 1024; // 8 MiB
|
||||
|
||||
static PAGE_LOCKED_WORKSPACE: OnceLock<PageLockedPtr> = OnceLock::new();
|
||||
|
||||
struct PageLockedPtr(*mut u8);
|
||||
|
||||
// SAFETY: The pointer is page-locked CUDA memory allocated once via
|
||||
// posix_memalign + cudaHostRegister and only mutated during OnceLock
|
||||
// initialization.
|
||||
unsafe impl Send for PageLockedPtr {}
|
||||
unsafe impl Sync for PageLockedPtr {}
|
||||
|
||||
impl std::fmt::Debug for PageLockedPtr {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "PageLockedPtr({:p})", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FlashInferAttention {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_qo_heads: 0,
|
||||
num_kv_heads: 0,
|
||||
head_dim: 0,
|
||||
page_size: 0,
|
||||
batch_dim: Expression::default(),
|
||||
plan_info: Mutex::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for FlashInferAttention {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"FlashInferAttention",
|
||||
&[
|
||||
("num_qo_heads", EXPRESSION),
|
||||
("num_kv_heads", EXPRESSION),
|
||||
("head_dim", EXPRESSION),
|
||||
("page_size", EXPRESSION),
|
||||
("batch_dim", EXPRESSION),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
// Q, K_pool, V_pool, flat_gather_idx, mask (egglog IList).
|
||||
// extract() appends qo_indptr + kv_indptr → 7 actual inputs at runtime.
|
||||
5
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(include_str!["flashinfer_attention.egg"])]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let num_qo_heads = extract_expr(egraph, kind_children[0], expr_cache)
|
||||
.unwrap()
|
||||
.exec(&FxHashMap::default())
|
||||
.unwrap();
|
||||
let num_kv_heads = extract_expr(egraph, kind_children[1], expr_cache)
|
||||
.unwrap()
|
||||
.exec(&FxHashMap::default())
|
||||
.unwrap();
|
||||
let head_dim = extract_expr(egraph, kind_children[2], expr_cache)
|
||||
.unwrap()
|
||||
.exec(&FxHashMap::default())
|
||||
.unwrap();
|
||||
let page_size = extract_expr(egraph, kind_children[3], expr_cache)
|
||||
.unwrap()
|
||||
.exec(&FxHashMap::default())
|
||||
.unwrap();
|
||||
let batch_dim = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
|
||||
|
||||
let extracted = Self {
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
page_size,
|
||||
batch_dim,
|
||||
plan_info: Mutex::new(Vec::new()),
|
||||
};
|
||||
|
||||
// Trigger JIT compilation (or .so cache hit) at extract time, not at
|
||||
// first execute. Pays the ~30s cold-cache nvcc cost during compile
|
||||
// rather than during the GA profiling loop, where it would dominate
|
||||
// the candidate's measured runtime and make the GA reject FlashInfer.
|
||||
let _ = jit::ensure_compiled(head_dim);
|
||||
|
||||
// Walk the mask e-graph chain to recover qo_indptr / kv_indptr Input nodes.
|
||||
// input_enodes: [Q, K_cache, V_cache, gather_idx, mask]
|
||||
let mask_node = input_enodes[4];
|
||||
let indptrs = find_indptrs::find_indptr_inputs(egraph, mask_node);
|
||||
|
||||
// Build final inputs: [Q, K_cache, V_cache, gather_idx, mask, qo_indptr, kv_indptr]
|
||||
let mut final_inputs = input_enodes;
|
||||
final_inputs.push(indptrs.qo_indptr);
|
||||
final_inputs.push(indptrs.kv_indptr);
|
||||
|
||||
let op = LLIROp::new::<dyn HostOp>(Box::new(extracted) as Box<dyn HostOp>);
|
||||
(op, final_inputs)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for FlashInferAttention {
|
||||
fn execute(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
let lib = jit::ensure_compiled(self.head_dim);
|
||||
|
||||
let total_q_tokens = self
|
||||
.batch_dim
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("FlashInferAttention batch_dim is unresolved"))?;
|
||||
let c = *dyn_map
|
||||
.get(&'c')
|
||||
.ok_or_else(|| anyhow::anyhow!("FlashInferAttention requires dynamic dim 'c'"))?;
|
||||
let r = *dyn_map
|
||||
.get(&'r')
|
||||
.ok_or_else(|| anyhow::anyhow!("FlashInferAttention requires dynamic dim 'r'"))?;
|
||||
|
||||
if inputs.len() < 7 {
|
||||
anyhow::bail!(
|
||||
"FlashInferAttention expects 7 inputs (Q, K, V, flat_idx, mask, qo_indptr, kv_indptr), got {}",
|
||||
inputs.len()
|
||||
);
|
||||
}
|
||||
|
||||
let get_buf = |name: &str, node: NodeIndex| -> anyhow::Result<DeviceBuffer> {
|
||||
buffers.get(&node).copied().ok_or_else(|| {
|
||||
anyhow::anyhow!("FlashInferAttention missing {name} buffer for {node:?}")
|
||||
})
|
||||
};
|
||||
|
||||
let q_buf = get_buf("Q", inputs[0])?;
|
||||
let k_buf = get_buf("K_cache", inputs[1])?;
|
||||
let v_buf = get_buf("V_cache", inputs[2])?;
|
||||
let flat_idx_buf = get_buf("flat_gather_idx", inputs[3])?;
|
||||
// inputs[4] = mask (unused by FlashInfer — indptrs replace it)
|
||||
let kv_indptr_buf = get_buf("kv_indptr", inputs[6])?;
|
||||
let out_buf = get_buf("output", self_node)?;
|
||||
|
||||
// Derive batch_size (num sequences) from r = indptr length.
|
||||
let batch_size = r.saturating_sub(1);
|
||||
|
||||
let _span = span!(
|
||||
Level::TRACE,
|
||||
"FlashInferAttention",
|
||||
total_q_tokens,
|
||||
batch_size,
|
||||
self.num_qo_heads,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
.entered();
|
||||
|
||||
let kv_dim = self.num_kv_heads * self.head_dim;
|
||||
let cu_stream = stream.cu_stream() as *mut std::ffi::c_void;
|
||||
|
||||
// Extract slot indices (one per context page) from the flat gather index.
|
||||
let indices_buf = unsafe { stream.alloc::<u8>(c.max(1) * std::mem::size_of::<i32>())? };
|
||||
let (indices_ptr, _idx_guard) = indices_buf.device_ptr(stream);
|
||||
|
||||
if c > 0 {
|
||||
unsafe {
|
||||
(lib.extract_slot_indices)(
|
||||
flat_idx_buf.ptr() as *const i32,
|
||||
indices_ptr as *mut i32,
|
||||
c as i32,
|
||||
kv_dim as i32,
|
||||
cu_stream,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Read kv_indptr to host for the plan phase.
|
||||
let kv_indptr_bytes = r * 4;
|
||||
let mut kv_indptr_host_bytes = vec![0u8; kv_indptr_bytes];
|
||||
unsafe {
|
||||
result::memcpy_dtoh_async(
|
||||
&mut kv_indptr_host_bytes,
|
||||
kv_indptr_buf.ptr(),
|
||||
stream.cu_stream(),
|
||||
)?;
|
||||
}
|
||||
stream.synchronize()?;
|
||||
let kv_indptr_host: Vec<i32> = unsafe {
|
||||
let mut v = std::mem::ManuallyDrop::new(kv_indptr_host_bytes);
|
||||
Vec::from_raw_parts(v.as_mut_ptr() as *mut i32, r, r)
|
||||
};
|
||||
|
||||
// kv_last_page_len = [1; batch_size] when page_size=1.
|
||||
let last_page_host: Vec<i32> = vec![1; batch_size];
|
||||
let last_page_dev: CudaSlice<u8> = if batch_size > 0 {
|
||||
stream.clone_htod(unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
last_page_host.as_ptr() as *const u8,
|
||||
last_page_host.len() * std::mem::size_of::<i32>(),
|
||||
)
|
||||
})?
|
||||
} else {
|
||||
unsafe { stream.alloc::<u8>(1)? }
|
||||
};
|
||||
let (last_page_ptr, _lp_guard) = last_page_dev.device_ptr(stream);
|
||||
|
||||
// Global shared workspaces (allocated once across all op instances to
|
||||
// amortize the ~4ms first-allocation cost during GA profiling).
|
||||
static FLOAT_WORKSPACE: OnceLock<CudaSlice<u8>> = OnceLock::new();
|
||||
static INT_WORKSPACE: OnceLock<CudaSlice<u8>> = OnceLock::new();
|
||||
let float_ws = FLOAT_WORKSPACE
|
||||
.get_or_init(|| unsafe { stream.alloc::<u8>(FLOAT_WORKSPACE_SIZE).unwrap() });
|
||||
let int_ws = INT_WORKSPACE
|
||||
.get_or_init(|| unsafe { stream.alloc::<u8>(INT_WORKSPACE_SIZE).unwrap() });
|
||||
let page_locked_ws = PAGE_LOCKED_WORKSPACE.get_or_init(|| unsafe {
|
||||
let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
|
||||
let status = libc::posix_memalign(&mut ptr, 4096, INT_WORKSPACE_SIZE);
|
||||
assert_eq!(status, 0, "Failed to allocate page-locked workspace");
|
||||
let cuda_status = cuda_pin_memory(ptr, INT_WORKSPACE_SIZE);
|
||||
assert_eq!(cuda_status, 0, "Failed to pin memory");
|
||||
PageLockedPtr(ptr as *mut u8)
|
||||
});
|
||||
|
||||
let (float_ws_ptr, _fws_guard) = float_ws.device_ptr(stream);
|
||||
let (int_ws_ptr, _iws_guard) = int_ws.device_ptr(stream);
|
||||
|
||||
// FlashInfer decode writes (total_q_tokens, heads, dim);
|
||||
// luminal expects (heads, total_q_tokens, dim) — transpose at the end.
|
||||
let output_elems = total_q_tokens * self.num_qo_heads * self.head_dim;
|
||||
let temp_out_buf =
|
||||
unsafe { stream.alloc::<u8>(output_elems * std::mem::size_of::<f32>())? };
|
||||
let (temp_out_ptr, _tmp_guard) = temp_out_buf.device_ptr(stream);
|
||||
|
||||
// PrefillPlanInfo has 15 entries, DecodePlanInfo fewer — 16 is enough.
|
||||
let mut plan_info_buf = [0i64; 16];
|
||||
let mut plan_info_len: i32 = 0;
|
||||
|
||||
// ── BatchDecode path ──
|
||||
// Prefill kernels require fp16/bf16 tensor-core MMA; the C API returns -1
|
||||
// when called from the fp32 pipeline. We only use decode here.
|
||||
let plan_ret = unsafe {
|
||||
(lib.plan)(
|
||||
float_ws_ptr as *mut std::ffi::c_void,
|
||||
FLOAT_WORKSPACE_SIZE,
|
||||
int_ws_ptr as *mut std::ffi::c_void,
|
||||
INT_WORKSPACE_SIZE,
|
||||
page_locked_ws.0 as *mut std::ffi::c_void,
|
||||
kv_indptr_host.as_ptr() as *mut i32,
|
||||
batch_size as i32,
|
||||
self.num_qo_heads as i32,
|
||||
self.num_kv_heads as i32,
|
||||
self.page_size as i32,
|
||||
self.head_dim as i32,
|
||||
cu_stream,
|
||||
plan_info_buf.as_mut_ptr(),
|
||||
&mut plan_info_len,
|
||||
)
|
||||
};
|
||||
if plan_ret != 0 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"FlashInfer decode plan failed with error code {plan_ret}"
|
||||
));
|
||||
}
|
||||
|
||||
let mut plan_info = self.plan_info.lock().unwrap();
|
||||
plan_info.clear();
|
||||
plan_info.extend_from_slice(&plan_info_buf[..plan_info_len as usize]);
|
||||
|
||||
let run_ret = unsafe {
|
||||
(lib.run)(
|
||||
float_ws_ptr as *mut std::ffi::c_void,
|
||||
FLOAT_WORKSPACE_SIZE,
|
||||
int_ws_ptr as *mut std::ffi::c_void,
|
||||
plan_info.as_mut_ptr(),
|
||||
plan_info.len() as i32,
|
||||
q_buf.ptr() as *mut f32,
|
||||
k_buf.ptr() as *mut f32,
|
||||
v_buf.ptr() as *mut f32,
|
||||
kv_indptr_buf.ptr() as *mut i32,
|
||||
indices_ptr as *mut i32,
|
||||
last_page_ptr as *mut i32,
|
||||
temp_out_ptr as *mut f32,
|
||||
batch_size as i32,
|
||||
self.num_qo_heads as i32,
|
||||
self.num_kv_heads as i32,
|
||||
self.page_size as i32,
|
||||
self.head_dim as i32,
|
||||
cu_stream,
|
||||
)
|
||||
};
|
||||
drop(plan_info);
|
||||
|
||||
if run_ret != 0 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"FlashInfer decode run failed with error code {run_ret}"
|
||||
));
|
||||
}
|
||||
|
||||
// Transpose (total_q_tokens, heads, dim) → (heads, total_q_tokens, dim)
|
||||
unsafe {
|
||||
(lib.transpose_output)(
|
||||
temp_out_ptr as *const f32,
|
||||
out_buf.ptr() as *mut f32,
|
||||
total_q_tokens as i32,
|
||||
self.num_qo_heads as i32,
|
||||
self.head_dim as i32,
|
||||
cu_stream,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.batch_dim * self.num_qo_heads * self.head_dim
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn stats_name(&self) -> Option<&'static str> {
|
||||
Some("FlashInferAttention")
|
||||
}
|
||||
}
|
||||
|
||||
/// Pin host memory for CUDA async memcpy.
|
||||
///
|
||||
/// `cudaHostRegister` lives in libcudart, which cudarc doesn't link to our
|
||||
/// binary. Resolve it via `dlopen`/`dlsym` so we don't need a build script or
|
||||
/// a `#[link]` directive — keeping the crate buildable without any nvcc-side
|
||||
/// dependencies.
|
||||
unsafe fn cuda_pin_memory(ptr: *mut std::ffi::c_void, size: usize) -> i32 {
|
||||
type HostRegisterFn = unsafe extern "C" fn(*mut std::ffi::c_void, usize, u32) -> i32;
|
||||
static FN: OnceLock<usize> = OnceLock::new();
|
||||
|
||||
let raw = *FN.get_or_init(|| unsafe {
|
||||
let lib = [
|
||||
"libcudart.so",
|
||||
"libcudart.so.13",
|
||||
"libcudart.so.12",
|
||||
"libcudart.so.11",
|
||||
]
|
||||
.iter()
|
||||
.find_map(|n| libloading::Library::new(*n).ok())
|
||||
.expect("FlashInfer: could not dlopen libcudart for cudaHostRegister");
|
||||
let sym: libloading::Symbol<HostRegisterFn> = lib
|
||||
.get(b"cudaHostRegister\0")
|
||||
.expect("FlashInfer: libcudart missing cudaHostRegister symbol");
|
||||
let ptr = *sym as *const () as usize;
|
||||
// Keep libcudart resident for the process lifetime so the function
|
||||
// pointer remains valid.
|
||||
std::mem::forget(lib);
|
||||
ptr
|
||||
});
|
||||
let f: HostRegisterFn = unsafe { std::mem::transmute(raw) };
|
||||
// cudaHostRegisterDefault = 0
|
||||
unsafe { f(ptr, size, 0) }
|
||||
}
|
||||
@@ -1,357 +0,0 @@
|
||||
// FlashInfer batch decode + prefill wrapper for luminal_cuda.
|
||||
// JIT-compiled at runtime with -DLUMINAL_HEAD_DIM=N.
|
||||
//
|
||||
// Decode: instantiated for f32 (scalar vectorized dot products, no tensor cores).
|
||||
// Prefill: instantiated for f16 (requires tensor core MMA + ldmatrix).
|
||||
// The C API accepts fp32 buffers; cast kernels convert fp32↔fp16 at the boundary.
|
||||
//
|
||||
// NHD layout. GQA group_size and page_size are runtime parameters.
|
||||
|
||||
#ifndef LUMINAL_HEAD_DIM
|
||||
#error "LUMINAL_HEAD_DIM must be defined (e.g. -DLUMINAL_HEAD_DIM=128)"
|
||||
#endif
|
||||
|
||||
// Include utils.cuh first to get the original DISPATCH_HEAD_DIM, then override it
|
||||
// to only instantiate our specific HEAD_DIM. This avoids a compile error in
|
||||
// cascade.cuh where HEAD_DIM=512 + f32 triggers vec_size=16, vec_bits=512
|
||||
// which exceeds cp_async's 256-bit limit.
|
||||
#include <flashinfer/utils.cuh>
|
||||
#undef DISPATCH_HEAD_DIM
|
||||
#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \
|
||||
{ \
|
||||
constexpr size_t HEAD_DIM = LUMINAL_HEAD_DIM; \
|
||||
__VA_ARGS__ \
|
||||
}
|
||||
|
||||
#include <flashinfer/attention/scheduler.cuh>
|
||||
#include <flashinfer/attention/decode.cuh>
|
||||
#include <flashinfer/attention/default_decode_params.cuh>
|
||||
#include <flashinfer/attention/prefill.cuh>
|
||||
#include <flashinfer/attention/default_prefill_params.cuh>
|
||||
#include <flashinfer/attention/mask.cuh>
|
||||
#include <flashinfer/attention/variants.cuh>
|
||||
#include <flashinfer/page.cuh>
|
||||
#include <flashinfer/pos_enc.cuh>
|
||||
|
||||
#include "wrapper.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
using namespace flashinfer;
|
||||
|
||||
// ── Decode types (f32) ──
|
||||
using DTypeQ = float;
|
||||
using DTypeKV = float;
|
||||
using DTypeO = float;
|
||||
using IdType = int32_t;
|
||||
|
||||
// ── Prefill types (f16 compute, fp32 external interface) ──
|
||||
using PrefillDTypeQ = half;
|
||||
using PrefillDTypeKV = half;
|
||||
using PrefillDTypeO = half;
|
||||
|
||||
constexpr uint32_t HEAD_DIM = LUMINAL_HEAD_DIM;
|
||||
constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone;
|
||||
|
||||
// Attention variants
|
||||
using Variant = DefaultAttention</*use_custom_mask=*/false,
|
||||
/*use_sliding_window=*/false,
|
||||
/*use_logits_soft_cap=*/false,
|
||||
/*use_alibi=*/false>;
|
||||
|
||||
using CausalVariant = DefaultAttention</*use_custom_mask=*/false,
|
||||
/*use_sliding_window=*/false,
|
||||
/*use_logits_soft_cap=*/false,
|
||||
/*use_alibi=*/false>;
|
||||
|
||||
// Decode params (f32)
|
||||
using DecodeParams = BatchDecodeParams<DTypeQ, DTypeKV, DTypeO, IdType>;
|
||||
|
||||
// Prefill params (f16)
|
||||
using PrefillParams = BatchPrefillPagedParams<PrefillDTypeQ, PrefillDTypeKV, PrefillDTypeO, IdType>;
|
||||
|
||||
// Forward declarations
|
||||
namespace flashinfer {
|
||||
template <uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE, typename AttentionVariant,
|
||||
typename Params>
|
||||
cudaError_t BatchDecodeWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v,
|
||||
float* tmp_s, bool enable_pdl,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <uint32_t CTA_TILE_Q, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO,
|
||||
PosEncodingMode POS_ENCODING_MODE, bool USE_FP16_QK_REDUCTION,
|
||||
MaskMode MASK_MODE, typename AttentionVariant, typename Params>
|
||||
cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v,
|
||||
float* tmp_s, bool enable_pdl,
|
||||
cudaStream_t stream);
|
||||
}
|
||||
|
||||
// Explicit instantiation: decode kernel (f32)
|
||||
template cudaError_t flashinfer::BatchDecodeWithPagedKVCacheDispatched<
|
||||
HEAD_DIM, POS_ENCODING_MODE, Variant, DecodeParams>(
|
||||
DecodeParams params, DTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
|
||||
|
||||
// Explicit instantiation: prefill kernels (f16, causal mask, CTA_TILE_Q=16/64/128)
|
||||
template cudaError_t flashinfer::BatchPrefillWithPagedKVCacheDispatched<
|
||||
16, HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, false, MaskMode::kCausal, CausalVariant, PrefillParams>(
|
||||
PrefillParams params, PrefillDTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
|
||||
|
||||
template cudaError_t flashinfer::BatchPrefillWithPagedKVCacheDispatched<
|
||||
64, HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, false, MaskMode::kCausal, CausalVariant, PrefillParams>(
|
||||
PrefillParams params, PrefillDTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
|
||||
|
||||
template cudaError_t flashinfer::BatchPrefillWithPagedKVCacheDispatched<
|
||||
128, HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, false, MaskMode::kCausal, CausalVariant, PrefillParams>(
|
||||
PrefillParams params, PrefillDTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
|
||||
|
||||
// ── fp32 ↔ fp16 cast kernels ──
|
||||
|
||||
__global__ void cast_f32_to_f16_kernel(const float* src, half* dst, size_t n) {
|
||||
size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < n) dst[i] = __float2half(src[i]);
|
||||
}
|
||||
|
||||
__global__ void cast_f16_to_f32_kernel(const half* src, float* dst, size_t n) {
|
||||
size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < n) dst[i] = __half2float(src[i]);
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
int flashinfer_batch_decode_plan(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace, size_t int_ws_size,
|
||||
void* page_locked_int_workspace,
|
||||
int32_t* indptr_h, int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream,
|
||||
int64_t* plan_info_out, int* plan_info_len_out)
|
||||
{
|
||||
(void)head_dim; // fixed at compile time
|
||||
|
||||
DecodePlanInfo plan_info;
|
||||
uint32_t group_size = num_qo_heads / num_kv_heads;
|
||||
|
||||
// We need to dispatch on GROUP_SIZE to get the right work estimation function
|
||||
cudaError_t status = cudaSuccess;
|
||||
|
||||
// Use a lambda to dispatch on group size
|
||||
auto do_plan = [&]<uint32_t GROUP_SIZE>() -> cudaError_t {
|
||||
auto work_estimation_func =
|
||||
BatchDecodeWithPagedKVCacheWorkEstimationDispatched<
|
||||
GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, Variant, DecodeParams>;
|
||||
return DecodePlan<HEAD_DIM, POS_ENCODING_MODE, Variant, DecodeParams>(
|
||||
float_workspace, float_ws_size,
|
||||
int_workspace, page_locked_int_workspace,
|
||||
int_ws_size, plan_info, indptr_h,
|
||||
(uint32_t)batch_size, (uint32_t)num_qo_heads,
|
||||
(uint32_t)page_size, /*enable_cuda_graph=*/false,
|
||||
stream, work_estimation_func);
|
||||
};
|
||||
|
||||
switch (group_size) {
|
||||
case 1: status = do_plan.operator()<1>(); break;
|
||||
case 2: status = do_plan.operator()<2>(); break;
|
||||
case 4: status = do_plan.operator()<4>(); break;
|
||||
case 8: status = do_plan.operator()<8>(); break;
|
||||
default: return -1; // unsupported group size
|
||||
}
|
||||
|
||||
if (status != cudaSuccess) return (int)status;
|
||||
|
||||
auto vec = plan_info.ToVector();
|
||||
*plan_info_len_out = (int)vec.size();
|
||||
std::memcpy(plan_info_out, vec.data(), vec.size() * sizeof(int64_t));
|
||||
return 0;
|
||||
}
|
||||
|
||||
int flashinfer_batch_decode_run(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace,
|
||||
int64_t* plan_info_vec, int plan_info_len,
|
||||
float* q,
|
||||
float* k_cache,
|
||||
float* v_cache,
|
||||
int32_t* kv_indptr,
|
||||
int32_t* kv_indices,
|
||||
int32_t* kv_last_page_len,
|
||||
float* output,
|
||||
int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
(void)head_dim; // fixed at compile time
|
||||
|
||||
DecodePlanInfo plan_info;
|
||||
plan_info.FromVector(std::vector<int64_t>(plan_info_vec, plan_info_vec + plan_info_len));
|
||||
|
||||
// Construct paged_kv_t with NHD layout
|
||||
paged_kv_t<DTypeKV, IdType> paged_kv(
|
||||
(uint32_t)num_kv_heads,
|
||||
(uint32_t)page_size,
|
||||
HEAD_DIM,
|
||||
(uint32_t)batch_size,
|
||||
QKVLayout::kNHD,
|
||||
k_cache,
|
||||
v_cache,
|
||||
kv_indices,
|
||||
kv_indptr,
|
||||
kv_last_page_len);
|
||||
|
||||
DecodeParams params;
|
||||
params.q = q;
|
||||
params.q_rope_offset = nullptr;
|
||||
params.paged_kv = paged_kv;
|
||||
params.o = output;
|
||||
params.lse = nullptr;
|
||||
params.maybe_alibi_slopes = nullptr;
|
||||
params.padded_batch_size = plan_info.padded_batch_size;
|
||||
params.num_qo_heads = (uint32_t)num_qo_heads;
|
||||
// Q buffer is (batch, num_qo_heads * head_dim) flat — the graph's split_dims + transpose
|
||||
// are stride tricks, no data movement. So the actual memory layout is (batch, heads, dim).
|
||||
params.q_stride_n = num_qo_heads * HEAD_DIM;
|
||||
params.q_stride_h = HEAD_DIM;
|
||||
params.window_left = -1; // no sliding window
|
||||
params.logits_soft_cap = 0.0f;
|
||||
params.sm_scale = 1.0f / sqrtf((float)HEAD_DIM);
|
||||
params.rope_rcp_scale = 1.0f;
|
||||
params.rope_rcp_theta = 1.0f;
|
||||
|
||||
// Set plan info pointers
|
||||
params.request_indices =
|
||||
GetPtrFromBaseOffset<IdType>(int_workspace, plan_info.request_indices_offset);
|
||||
params.kv_tile_indices =
|
||||
GetPtrFromBaseOffset<IdType>(int_workspace, plan_info.kv_tile_indices_offset);
|
||||
params.o_indptr =
|
||||
GetPtrFromBaseOffset<IdType>(int_workspace, plan_info.o_indptr_offset);
|
||||
params.kv_chunk_size_ptr =
|
||||
GetPtrFromBaseOffset<IdType>(int_workspace, plan_info.kv_chunk_size_ptr_offset);
|
||||
params.block_valid_mask = nullptr;
|
||||
params.partition_kv = false;
|
||||
|
||||
DTypeO* tmp_v = nullptr;
|
||||
float* tmp_s = nullptr;
|
||||
|
||||
if (plan_info.split_kv) {
|
||||
tmp_v = GetPtrFromBaseOffset<DTypeO>(float_workspace, plan_info.v_offset);
|
||||
tmp_s = GetPtrFromBaseOffset<float>(float_workspace, plan_info.s_offset);
|
||||
if (plan_info.enable_cuda_graph) {
|
||||
params.block_valid_mask =
|
||||
GetPtrFromBaseOffset<bool>(int_workspace, plan_info.block_valid_mask_offset);
|
||||
}
|
||||
}
|
||||
|
||||
cudaError_t status =
|
||||
flashinfer::BatchDecodeWithPagedKVCacheDispatched<HEAD_DIM, POS_ENCODING_MODE, Variant>(
|
||||
params, tmp_v, tmp_s, /*enable_pdl=*/false, stream);
|
||||
|
||||
return (int)status;
|
||||
}
|
||||
|
||||
// ═══════════════════════════════════════════════════════════
|
||||
// BatchPrefill (fp16/bf16 only — tensor core MMA requires 16-bit inputs)
|
||||
// ═══════════════════════════════════════════════════════════
|
||||
//
|
||||
// The prefill kernel templates are instantiated above for fp16. These C API
|
||||
// functions accept fp32 pointers (matching the current luminal pipeline) but
|
||||
// return -1 to indicate that fp32 prefill is not supported. When native fp16
|
||||
// support is added, these will accept fp16 pointers and call through to the
|
||||
// instantiated templates.
|
||||
|
||||
int flashinfer_batch_prefill_plan(
|
||||
void*, size_t, void*, size_t, void*,
|
||||
int32_t*, int32_t*, int, int,
|
||||
int, int, int, int, cudaStream_t,
|
||||
int64_t*, int*)
|
||||
{
|
||||
return -1; // fp32 not supported — requires fp16/bf16
|
||||
}
|
||||
|
||||
int flashinfer_batch_prefill_run(
|
||||
void*, size_t, void*,
|
||||
int64_t*, int,
|
||||
float*, float*, float*,
|
||||
int32_t*, int32_t*, int32_t*, int32_t*,
|
||||
float*, int, int, int, int, int, int, cudaStream_t)
|
||||
{
|
||||
return -1; // fp32 not supported — requires fp16/bf16
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
|
||||
// ── Slot index extraction kernel (outside extern "C" for __global__) ──
|
||||
|
||||
__global__ void extract_slot_indices_kernel(
|
||||
const int32_t* flat_idx, int32_t* out, int c, int kv_dim) {
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < c) out[i] = flat_idx[i * kv_dim] / kv_dim;
|
||||
}
|
||||
|
||||
extern "C" void flashinfer_extract_slot_indices(
|
||||
const int32_t* flat_idx, int32_t* out, int c, int kv_dim,
|
||||
cudaStream_t stream) {
|
||||
if (c == 0) return;
|
||||
int threads = 256;
|
||||
int blocks = (c + threads - 1) / threads;
|
||||
extract_slot_indices_kernel<<<blocks, threads, 0, stream>>>(
|
||||
flat_idx, out, c, kv_dim);
|
||||
}
|
||||
|
||||
// ── Derive CSR indptr from attention mask ──
|
||||
// Mask is (s, c) f32. Entries > -1e9 are "valid" (0.0), rest are -inf.
|
||||
// Per-row count of valid entries = context length for that sequence.
|
||||
// Output: indptr[0..=s] with indptr[0]=0 and indptr[i+1] = indptr[i] + ctx_len[i].
|
||||
// Single thread is fine since s is tiny (batch_size during decode, typically 1-8).
|
||||
|
||||
__global__ void derive_indptr_kernel(
|
||||
const float* mask, int32_t* indptr, int s, int c) {
|
||||
if (threadIdx.x != 0 || blockIdx.x != 0) return;
|
||||
indptr[0] = 0;
|
||||
for (int i = 0; i < s; i++) {
|
||||
int count = 0;
|
||||
for (int j = 0; j < c; j++) {
|
||||
if (mask[i * c + j] > -1e9f) count++;
|
||||
}
|
||||
indptr[i + 1] = indptr[i] + count;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" void flashinfer_derive_indptr_from_mask(
|
||||
const float* mask, int32_t* indptr, int s, int c,
|
||||
cudaStream_t stream) {
|
||||
if (s == 0) return;
|
||||
derive_indptr_kernel<<<1, 1, 0, stream>>>(mask, indptr, s, c);
|
||||
}
|
||||
|
||||
// ── Output transpose: (batch, heads, dim) → (heads, batch, dim) ──
|
||||
// FlashInfer writes output as (batch, heads, dim) but Luminal expects (heads, batch, dim).
|
||||
// For batch=1 these are identical; for batch>1 we need an explicit transpose.
|
||||
|
||||
__global__ void transpose_bhd_to_hbd_kernel(
|
||||
const float* src, float* dst, int batch, int heads, int dim) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int total = batch * heads * dim;
|
||||
if (idx >= total) return;
|
||||
|
||||
// Decompose linear index into (b, h, d) for src layout
|
||||
int d = idx % dim;
|
||||
int h = (idx / dim) % heads;
|
||||
int b = idx / (heads * dim);
|
||||
|
||||
// Write to (h, b, d) layout in dst
|
||||
dst[h * batch * dim + b * dim + d] = src[idx];
|
||||
}
|
||||
|
||||
extern "C" void flashinfer_transpose_output(
|
||||
const float* src, float* dst,
|
||||
int batch, int heads, int dim,
|
||||
cudaStream_t stream) {
|
||||
int total = batch * heads * dim;
|
||||
if (total == 0) return;
|
||||
int threads = 256;
|
||||
int blocks = (total + threads - 1) / threads;
|
||||
transpose_bhd_to_hbd_kernel<<<blocks, threads, 0, stream>>>(
|
||||
src, dst, batch, heads, dim);
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Plan phase: CPU-side scheduling. Must call before each new batch config.
|
||||
// Returns 0 on success, non-zero on failure.
|
||||
int flashinfer_batch_decode_plan(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace, size_t int_ws_size,
|
||||
void* page_locked_int_workspace,
|
||||
int32_t* indptr_h, int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream,
|
||||
int64_t* plan_info_out, int* plan_info_len_out);
|
||||
|
||||
// Run phase: GPU kernel launch.
|
||||
// Returns 0 on success, non-zero on failure.
|
||||
int flashinfer_batch_decode_run(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace,
|
||||
int64_t* plan_info_vec, int plan_info_len,
|
||||
float* q, // [batch_size, num_qo_heads, head_dim]
|
||||
float* k_cache, // [num_pages, page_size, num_kv_heads, head_dim] (NHD)
|
||||
float* v_cache, // same layout
|
||||
int32_t* kv_indptr, // [batch_size + 1]
|
||||
int32_t* kv_indices, // [total_pages]
|
||||
int32_t* kv_last_page_len, // [batch_size]
|
||||
float* output, // [batch_size, num_qo_heads, head_dim]
|
||||
int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Extract slot indices from a flat gather index tensor.
|
||||
// flat_idx shape: (c, kv_dim) i32, out shape: (c,) i32.
|
||||
// out[i] = flat_idx[i * kv_dim] / kv_dim
|
||||
void flashinfer_extract_slot_indices(
|
||||
const int32_t* flat_idx, int32_t* out, int c, int kv_dim,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Derive CSR indptr from attention mask.
|
||||
// mask shape: (s, c) f32. Entries > -1e9 are valid.
|
||||
// indptr shape: (s + 1,) i32. indptr[0] = 0, indptr[i+1] = cumsum of valid counts.
|
||||
void flashinfer_derive_indptr_from_mask(
|
||||
const float* mask, int32_t* indptr, int s, int c,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Transpose output from (batch, heads, dim) to (heads, batch, dim).
|
||||
void flashinfer_transpose_output(
|
||||
const float* src, float* dst,
|
||||
int batch, int heads, int dim,
|
||||
cudaStream_t stream);
|
||||
|
||||
// ── BatchPrefill with Paged KV Cache ──
|
||||
|
||||
// Plan phase for batch prefill.
|
||||
// Returns 0 on success, non-zero on failure.
|
||||
int flashinfer_batch_prefill_plan(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace, size_t int_ws_size,
|
||||
void* page_locked_int_workspace,
|
||||
int32_t* qo_indptr_h, int32_t* kv_indptr_h,
|
||||
int total_num_rows, int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream,
|
||||
int64_t* plan_info_out, int* plan_info_len_out);
|
||||
|
||||
// Run phase for batch prefill.
|
||||
// Returns 0 on success, non-zero on failure.
|
||||
int flashinfer_batch_prefill_run(
|
||||
void* float_workspace, size_t float_ws_size,
|
||||
void* int_workspace,
|
||||
int64_t* plan_info_vec, int plan_info_len,
|
||||
float* q, // [total_num_rows, num_qo_heads, head_dim]
|
||||
float* k_cache, // [num_pages, page_size, num_kv_heads, head_dim] (NHD)
|
||||
float* v_cache, // same layout
|
||||
int32_t* qo_indptr, // [batch_size + 1] on GPU
|
||||
int32_t* kv_indptr, // [batch_size + 1] on GPU
|
||||
int32_t* kv_indices, // [total_pages]
|
||||
int32_t* kv_last_page_len, // [batch_size]
|
||||
float* output, // [total_num_rows, num_qo_heads, head_dim]
|
||||
int total_num_rows, int batch_size,
|
||||
int num_qo_heads, int num_kv_heads, int page_size, int head_dim,
|
||||
cudaStream_t stream);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -1,122 +1,17 @@
|
||||
use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use crate::cudarc::driver::{CudaStream, DriverError, result};
|
||||
use crate::cudarc::driver::{CudaSlice, CudaStream};
|
||||
use luminal::{op::EgglogOp, prelude::*};
|
||||
pub mod compute_attn_mask;
|
||||
mod cublas;
|
||||
mod cublaslt;
|
||||
pub mod flashinfer;
|
||||
pub mod moe;
|
||||
|
||||
pub use compute_attn_mask::ComputeAttnMask;
|
||||
|
||||
pub type Ops = (
|
||||
// cublas::CuBlasSgemmV2,
|
||||
cublaslt::CuBlasLt,
|
||||
moe::GLUMoE,
|
||||
compute_attn_mask::ComputeAttnMask,
|
||||
flashinfer::FlashInferAttention,
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type CublasLtTypeTuple = (
|
||||
luminal::dtype::DType,
|
||||
luminal::dtype::DType,
|
||||
luminal::dtype::DType,
|
||||
luminal::dtype::DType,
|
||||
&'static str,
|
||||
luminal::dtype::DType,
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_type_tuple(op: &dyn HostOp) -> Option<CublasLtTypeTuple> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::type_tuple)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type CublasLtScaleValues = (f64, f64);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_scale_values(op: &dyn HostOp) -> Option<CublasLtScaleValues> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::scale_values)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_epilogue(op: &dyn HostOp) -> Option<&'static str> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::epilogue)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type CublasLtMatrixOrders = (&'static str, &'static str, &'static str, &'static str);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_matrix_orders(op: &dyn HostOp) -> Option<CublasLtMatrixOrders> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::matrix_orders)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type CublasLtTransposeOps = (&'static str, &'static str);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_transpose_ops(op: &dyn HostOp) -> Option<CublasLtTransposeOps> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::transpose_ops)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_c_d_layouts_match(op: &dyn HostOp) -> Option<bool> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::c_d_layouts_match)
|
||||
}
|
||||
|
||||
/// Non-owning device buffer handle used by host operations.
|
||||
///
|
||||
/// Runtime-owned intermediates may be a whole `CudaSlice`, a subregion inside
|
||||
/// the reusable arena, or an external pointer. Host ops only need the pointer
|
||||
/// and the logical byte length.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct DeviceBuffer {
|
||||
ptr: u64,
|
||||
len: usize,
|
||||
}
|
||||
|
||||
impl DeviceBuffer {
|
||||
pub fn new(ptr: u64, len: usize) -> Self {
|
||||
Self { ptr, len }
|
||||
}
|
||||
|
||||
pub fn ptr(self) -> u64 {
|
||||
self.ptr
|
||||
}
|
||||
|
||||
pub fn len(self) -> usize {
|
||||
self.len
|
||||
}
|
||||
|
||||
pub fn is_empty(self) -> bool {
|
||||
self.len == 0
|
||||
}
|
||||
|
||||
pub fn clone_dtoh(self, stream: &Arc<CudaStream>) -> Result<Vec<u8>, DriverError> {
|
||||
let mut host = vec![0u8; self.len];
|
||||
unsafe {
|
||||
result::memcpy_dtoh_async(&mut host, self.ptr, stream.cu_stream())?;
|
||||
}
|
||||
stream.synchronize()?;
|
||||
Ok(host)
|
||||
}
|
||||
}
|
||||
|
||||
/// Host operations that execute on the CPU but orchestrate GPU work.
|
||||
///
|
||||
/// This includes operations like cuBLAS calls and CUDA graph executions.
|
||||
@@ -134,7 +29,7 @@ pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()>;
|
||||
|
||||
@@ -153,15 +48,6 @@ pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// Returns relative lifetimes for extra buffer nodes within this host op.
|
||||
///
|
||||
/// The tuple is `(node, first_step, last_step)`, where steps are local to
|
||||
/// this host op's execution. Returning `None` tells the runtime to treat
|
||||
/// every extra buffer as live for the whole host op.
|
||||
fn extra_buffer_lifetimes(&self) -> Option<Vec<(NodeIndex, usize, usize)>> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns buffer size requirements for extra nodes (node -> size in elements).
|
||||
///
|
||||
/// Called during buffer allocation to ensure all required buffers exist.
|
||||
|
||||
@@ -5,19 +5,12 @@
|
||||
; mode=1: Gemma-style GELU (gate * sigmoid(1.595769 * gate * (1 + 0.044715 * gate^2)))
|
||||
;
|
||||
; To keep matching fast, we stage through marker states:
|
||||
; 1) Shared expert index/gather markers
|
||||
; 2) Shared gate-up matmul marker
|
||||
; 3) Activation marker (separate swiglu / gemma_gelu paths)
|
||||
; 4) Down matmul marker (separate swiglu / gemma_gelu paths)
|
||||
; 5) Final GLUMoE fusion (separate swiglu / gemma_gelu rules)
|
||||
; 1) Shared gate-up matmul marker
|
||||
; 2) Activation marker (separate swiglu / gemma_gelu paths)
|
||||
; 3) Down matmul marker (separate swiglu / gemma_gelu paths)
|
||||
; 4) Final GLUMoE fusion (separate swiglu / gemma_gelu rules)
|
||||
|
||||
(datatype*
|
||||
(GLUMoEExpertIndexState
|
||||
(MkGLUMoEExpertIndexState Expression Expression IR)
|
||||
)
|
||||
(GLUMoEExpertGatherState
|
||||
(MkGLUMoEExpertGatherState Expression Expression IR IR)
|
||||
)
|
||||
(GLUMoEGateUpState
|
||||
(MkGLUMoEGateUpState Expression Expression Expression IR IR IR)
|
||||
)
|
||||
@@ -35,8 +28,6 @@
|
||||
)
|
||||
)
|
||||
|
||||
(function glumoe_expert_index (IR) GLUMoEExpertIndexState :merge new)
|
||||
(function glumoe_expert_gather (IR) GLUMoEExpertGatherState :merge new)
|
||||
(function glumoe_gate_up (IR) GLUMoEGateUpState :merge new)
|
||||
(function glumoe_swiglu (IR) GLUMoESwiGLUState :merge new)
|
||||
(function glumoe_gemma_gelu (IR) GLUMoEGemmaGELUState :merge new)
|
||||
@@ -45,38 +36,17 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?iota_base (Op (Iota ?io ?iota_base_range) (INil)))
|
||||
(= ?mul_base (Op (Mul ?mul_base_shape ?mul_base_a_stride ?mul_base_b_stride ?mul_base_out_stride) (ICons ?topk_idx (ICons ?iota_base (INil)))))
|
||||
(= ?iota_within (Op (Iota (MIter) ?iota_within_range) (INil)))
|
||||
(= ?add_idx (Op (Add ?add_shape ?add_a_stride ?add_b_stride ?add_out_stride) (ICons ?mul_base (ICons ?iota_within (INil)))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_expert_index ?add_idx)
|
||||
(MkGLUMoEExpertIndexState ?io ?iota_within_range ?topk_idx))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE expert index marker"
|
||||
)
|
||||
; ===== Gate-up expert gather =====
|
||||
(= ?gu_iota_base (Op (Iota ?gu_io ?gu_iota_base_range) (INil)))
|
||||
(= ?gu_mul_base (Op (Mul ?gu_mul_base_shape ?gu_mul_base_a_stride ?gu_mul_base_b_stride ?gu_mul_base_out_stride) (ICons ?topk_idx (ICons ?gu_iota_base (INil)))))
|
||||
(= ?gu_iota_within (Op (Iota (MIter) ?gu_iota_within_range) (INil)))
|
||||
(= ?gu_add_idx (Op (Add ?gu_add_shape ?gu_add_a_stride ?gu_add_b_stride ?gu_add_out_stride) (ICons ?gu_mul_base (ICons ?gu_iota_within (INil)))))
|
||||
(= ?gu_gathered (Op (Gather ?gu_gather_idx_shape ?gu_gather_idx_stride ?gu_gather_data_shape ?gu_gather_data_stride) (ICons ?gu_add_idx (ICons ?gate_up_w (INil)))))
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?index_state (glumoe_expert_index ?idx))
|
||||
(= ?index_state (MkGLUMoEExpertIndexState ?io ?within_range ?topk_idx))
|
||||
(= ?gathered (Op (Gather ?gather_idx_shape ?gather_idx_stride ?gather_data_shape ?gather_data_stride) (ICons ?idx (ICons ?weights (INil)))))
|
||||
(= ?f32 (Op (Cast ?f32_size (F32)) (ICons ?gathered (INil))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_expert_gather ?f32)
|
||||
(MkGLUMoEExpertGatherState ?io ?within_range ?topk_idx ?weights))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE expert gather marker"
|
||||
)
|
||||
; ===== Cast BF16→F32 =====
|
||||
(= ?gu_f32 (Op (Cast ?gu_f32_size (F32)) (ICons ?gu_gathered (INil))))
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?gather_state (glumoe_expert_gather ?gu_f32))
|
||||
(= ?gather_state (MkGLUMoEExpertGatherState ?gu_io ?gu_iota_within_range ?topk_idx ?gate_up_w))
|
||||
; ===== Gate-up batched matmul =====
|
||||
(= ?gu_matmul_mul (Op (Mul ?gu_matmul_mul_shape ?gu_matmul_a_stride ?gu_matmul_b_stride ?gu_matmul_mul_out_stride) (ICons ?x (ICons ?gu_f32 (INil)))))
|
||||
(= ?gu_matmul (Op (Sum ?gu_matmul_out_shape ?gu_matmul_k ?gu_matmul_in_stride ?gu_matmul_k_stride ?gu_matmul_out_stride) (ICons ?gu_matmul_mul (INil))))
|
||||
)
|
||||
@@ -84,7 +54,6 @@
|
||||
(set (glumoe_gate_up ?gu_matmul)
|
||||
(MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_iota_within_range ?x ?topk_idx ?gate_up_w))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE gate-up matmul marker"
|
||||
)
|
||||
|
||||
@@ -111,7 +80,6 @@
|
||||
(
|
||||
(set (glumoe_swiglu ?swiglu_out) (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE swiglu marker"
|
||||
)
|
||||
|
||||
@@ -145,7 +113,6 @@
|
||||
(
|
||||
(set (glumoe_gemma_gelu ?gemma_out) (MkGLUMoEGemmaGELUState ?gate_up_state))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE gemma gelu marker"
|
||||
)
|
||||
|
||||
@@ -155,8 +122,12 @@
|
||||
(= ?swiglu_state (glumoe_swiglu ?swiglu_out))
|
||||
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
|
||||
(= ?gather_state (glumoe_expert_gather ?dn_f32))
|
||||
(= ?gather_state (MkGLUMoEExpertGatherState ?dn_io ?dn_iota_within_range ?topk_idx ?down_w))
|
||||
(= ?dn_iota_base (Op (Iota ?dn_io ?dn_iota_base_range) (INil)))
|
||||
(= ?dn_mul_base (Op (Mul ?dn_mul_base_shape ?dn_mul_base_a_stride ?dn_mul_base_b_stride ?dn_mul_base_out_stride) (ICons ?topk_idx (ICons ?dn_iota_base (INil)))))
|
||||
(= ?dn_iota_within (Op (Iota (MIter) ?dn_iota_within_range) (INil)))
|
||||
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_mul_base (ICons ?dn_iota_within (INil)))))
|
||||
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_add_idx (ICons ?down_w (INil)))))
|
||||
(= ?dn_f32 (Op (Cast ?dn_f32_size (F32)) (ICons ?dn_gathered (INil))))
|
||||
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?swiglu_out (ICons ?dn_f32 (INil)))))
|
||||
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
|
||||
)
|
||||
@@ -164,7 +135,6 @@
|
||||
(set (glumoe_swiglu_down ?dn_matmul)
|
||||
(MkGLUMoESwiGLUDownState ?dn_io ?dn_matmul_k ?dn_iota_within_range ?swiglu_state ?topk_idx ?down_w))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE swiglu down marker"
|
||||
)
|
||||
|
||||
@@ -174,8 +144,12 @@
|
||||
(= ?gemma_state (glumoe_gemma_gelu ?gemma_out))
|
||||
(= ?gemma_state (MkGLUMoEGemmaGELUState ?gate_up_state))
|
||||
|
||||
(= ?gather_state (glumoe_expert_gather ?dn_f32))
|
||||
(= ?gather_state (MkGLUMoEExpertGatherState ?dn_io ?dn_iota_within_range ?topk_idx ?down_w))
|
||||
(= ?dn_iota_base (Op (Iota ?dn_io ?dn_iota_base_range) (INil)))
|
||||
(= ?dn_mul_base (Op (Mul ?dn_mul_base_shape ?dn_mul_base_a_stride ?dn_mul_base_b_stride ?dn_mul_base_out_stride) (ICons ?topk_idx (ICons ?dn_iota_base (INil)))))
|
||||
(= ?dn_iota_within (Op (Iota (MIter) ?dn_iota_within_range) (INil)))
|
||||
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_mul_base (ICons ?dn_iota_within (INil)))))
|
||||
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_add_idx (ICons ?down_w (INil)))))
|
||||
(= ?dn_f32 (Op (Cast ?dn_f32_size (F32)) (ICons ?dn_gathered (INil))))
|
||||
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?gemma_out (ICons ?dn_f32 (INil)))))
|
||||
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
|
||||
)
|
||||
@@ -183,7 +157,6 @@
|
||||
(set (glumoe_gemma_down ?dn_matmul)
|
||||
(MkGLUMoEGemmaDownState ?dn_io ?dn_matmul_k ?dn_iota_within_range ?gemma_state ?topk_idx ?down_w))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE gemma down marker"
|
||||
)
|
||||
|
||||
@@ -204,10 +177,7 @@
|
||||
?gu_within_range ?dn_within_range (MNum 0))
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?topk_vals (INil)))))))))
|
||||
(union ?output ?glumoe)
|
||||
(subsume (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
(subsume (Op (KernelSum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride (F32)) (ICons ?weighted (INil))))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE fused expert computation (swiglu)"
|
||||
)
|
||||
|
||||
@@ -238,9 +208,6 @@
|
||||
?gu_within_range ?dn_within_range (MNum 1))
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?per_expert_scale (INil)))))))))
|
||||
(union ?output ?glumoe)
|
||||
(subsume (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
(subsume (Op (KernelSum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride (F32)) (ICons ?weighted (INil))))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE fused expert computation (gemma_gelu)"
|
||||
)
|
||||
|
||||
@@ -32,7 +32,7 @@ use crate::{
|
||||
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, LaunchConfig, PushKernelArg,
|
||||
},
|
||||
},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
host::HostOp,
|
||||
try_create_cublaslt,
|
||||
};
|
||||
|
||||
@@ -224,9 +224,8 @@ impl EgglogOp for GLUMoE {
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
Rule::raw(
|
||||
"(rule
|
||||
vec![Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?e (Op (GLUMoE ?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k ?gu_within_range ?dn_within_range ?mode) ?inputs))
|
||||
)
|
||||
@@ -235,15 +234,17 @@ impl EgglogOp for GLUMoE {
|
||||
)
|
||||
:ruleset dtype_prop
|
||||
)",
|
||||
),
|
||||
Rule::raw(include_str!["glumoe_rewrite.egg"]),
|
||||
]
|
||||
)]
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
6
|
||||
}
|
||||
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(include_str!["glumoe_rewrite.egg"])]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
@@ -294,140 +295,27 @@ impl HostOp for GLUMoE {
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
if inputs.len() < 6 {
|
||||
anyhow::bail!("GLUMoE expected at least 6 inputs, got {}", inputs.len());
|
||||
}
|
||||
|
||||
// Resolve dimensions
|
||||
let hidden = self
|
||||
.gu_matmul_k
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE hidden dimension is unresolved"))?;
|
||||
let intermediate = self
|
||||
.dn_matmul_k
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE intermediate dimension is unresolved"))?;
|
||||
let top_k = self
|
||||
.output_k
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE top-k dimension is unresolved"))?;
|
||||
let gu_io = self
|
||||
.gu_io
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE gate/up stride is unresolved"))?;
|
||||
let dn_io = self
|
||||
.dn_io
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE down stride is unresolved"))?;
|
||||
let hidden = self.gu_matmul_k.exec(dyn_map).unwrap();
|
||||
let intermediate = self.dn_matmul_k.exec(dyn_map).unwrap();
|
||||
let top_k_expected = self.output_k.exec(dyn_map).unwrap();
|
||||
let gate_up_dim = self.gu_io.exec(dyn_map).unwrap() / hidden; // gate_up_dim = gu_io / hidden
|
||||
let num_experts = self.gu_within_range.exec(dyn_map).unwrap() / (gate_up_dim * hidden);
|
||||
|
||||
if hidden == 0 || intermediate == 0 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE got zero-sized matmul dimensions: hidden={hidden}, intermediate={intermediate}"
|
||||
);
|
||||
}
|
||||
if top_k == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
if gu_io % hidden != 0 {
|
||||
anyhow::bail!("GLUMoE gate/up stride {gu_io} is not divisible by hidden {hidden}");
|
||||
}
|
||||
if dn_io % intermediate != 0 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE down stride {dn_io} is not divisible by intermediate {intermediate}"
|
||||
);
|
||||
}
|
||||
|
||||
let gate_up_dim = gu_io / hidden; // gate_up_dim = 2 * intermediate for GLU
|
||||
let down_hidden = dn_io / intermediate;
|
||||
if gate_up_dim != intermediate * 2 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE expected gate/up dim {} to equal 2 * intermediate {}",
|
||||
gate_up_dim,
|
||||
intermediate * 2
|
||||
);
|
||||
}
|
||||
if down_hidden != hidden {
|
||||
anyhow::bail!("GLUMoE down hidden {down_hidden} does not match hidden {hidden}");
|
||||
}
|
||||
|
||||
let output_bytes = self
|
||||
.output_bytes()
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("GLUMoE output byte size is unresolved"))?;
|
||||
if output_bytes % (hidden * 4) != 0 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE output bytes {output_bytes} are not divisible by hidden bytes {}",
|
||||
hidden * 4
|
||||
);
|
||||
}
|
||||
let seq = output_bytes / (hidden * 4);
|
||||
if seq == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let get_buffer = |name: &str, node: NodeIndex| -> anyhow::Result<DeviceBuffer> {
|
||||
buffers.get(&node).copied().ok_or_else(|| {
|
||||
anyhow::anyhow!("GLUMoE missing {name} buffer for LLIR node {node:?}")
|
||||
})
|
||||
};
|
||||
// Derive seq from x buffer size: x is [seq, hidden] F32 → seq = len / (hidden * 4)
|
||||
let x_buf = buffers[&inputs[0]];
|
||||
let seq = x_buf.len() / (hidden * 4);
|
||||
|
||||
// Get input/output buffers
|
||||
let x_buf = get_buffer("x", inputs[0])?; // [seq, hidden] F32
|
||||
let topk_idx_buf = get_buffer("topk indices", inputs[1])?; // [seq, k] Int
|
||||
let topk_vals_buf = get_buffer("topk values", inputs[2])?; // [seq, k] F32
|
||||
let gate_up_buf = get_buffer("gate/up weights", inputs[3])?; // [E, gate_up_dim, hidden] BF16
|
||||
let down_buf = get_buffer("down weights", inputs[4])?; // [E, hidden, intermediate] BF16
|
||||
let mode_aux_buf = get_buffer("mode aux", inputs[5])?;
|
||||
let output_buf = get_buffer("output", self_node)?; // [seq, hidden] F32
|
||||
|
||||
let topk_bytes = seq * top_k * 4;
|
||||
if x_buf.len() < output_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE x buffer too small: have {} bytes, need {output_bytes}",
|
||||
x_buf.len()
|
||||
);
|
||||
}
|
||||
if topk_idx_buf.len() < topk_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk index buffer too small: have {} bytes, need {topk_bytes}",
|
||||
topk_idx_buf.len()
|
||||
);
|
||||
}
|
||||
if topk_vals_buf.len() < topk_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk value buffer too small: have {} bytes, need {topk_bytes}",
|
||||
topk_vals_buf.len()
|
||||
);
|
||||
}
|
||||
if output_buf.len() < output_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE output buffer too small: have {} bytes, need {output_bytes}",
|
||||
output_buf.len()
|
||||
);
|
||||
}
|
||||
|
||||
let gu_stride_bytes = gate_up_dim * hidden * 2;
|
||||
let down_stride_bytes = hidden * intermediate * 2;
|
||||
if gu_stride_bytes == 0 || gate_up_buf.len() % gu_stride_bytes != 0 {
|
||||
anyhow::bail!(
|
||||
"GLUMoE gate/up weight buffer has {} bytes, not a multiple of per-expert stride {gu_stride_bytes}",
|
||||
gate_up_buf.len()
|
||||
);
|
||||
}
|
||||
let num_experts = gate_up_buf.len() / gu_stride_bytes;
|
||||
if num_experts == 0 {
|
||||
anyhow::bail!("GLUMoE has no expert weights");
|
||||
}
|
||||
if down_buf.len() < num_experts * down_stride_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE down weight buffer too small: have {} bytes, need {}",
|
||||
down_buf.len(),
|
||||
num_experts * down_stride_bytes
|
||||
);
|
||||
}
|
||||
let topk_idx_buf = buffers[&inputs[1]]; // [seq, k] Int
|
||||
let topk_vals_buf = buffers[&inputs[2]]; // [seq, k] F32
|
||||
let gate_up_buf = buffers[&inputs[3]]; // [E, gate_up_dim, hidden] BF16
|
||||
let down_buf = buffers[&inputs[4]]; // [E, hidden, intermediate] BF16
|
||||
let mode_aux_buf = buffers[&inputs[5]];
|
||||
let output_buf = buffers[&self_node]; // [seq, hidden] F32
|
||||
|
||||
// Get raw device pointer addresses
|
||||
let x_ptr = buf_ptr(x_buf, stream);
|
||||
@@ -439,17 +327,21 @@ impl HostOp for GLUMoE {
|
||||
let (_, f32_to_bf16_fn, activation_fn) = self.get_kernels(stream);
|
||||
|
||||
// Read top-k routing values from GPU
|
||||
let topk_idx_host: Vec<u8> = topk_idx_buf.clone_dtoh(stream)?;
|
||||
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host[..topk_bytes]);
|
||||
let topk_vals_host: Vec<u8> = topk_vals_buf.clone_dtoh(stream)?;
|
||||
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host[..topk_bytes]);
|
||||
|
||||
for (pos, &expert_idx) in topk_idx_i32.iter().enumerate() {
|
||||
if expert_idx < 0 || expert_idx as usize >= num_experts {
|
||||
anyhow::bail!(
|
||||
"GLUMoE expert index {expert_idx} at routing position {pos} out of bounds for {num_experts} experts"
|
||||
);
|
||||
}
|
||||
let topk_idx_host: Vec<u8> = stream.clone_dtoh(topk_idx_buf)?;
|
||||
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host);
|
||||
let topk_vals_host: Vec<u8> = stream.clone_dtoh(topk_vals_buf)?;
|
||||
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host);
|
||||
let idx_k = topk_idx_i32
|
||||
.len()
|
||||
.checked_div(seq)
|
||||
.unwrap_or(top_k_expected);
|
||||
let val_k = topk_vals_f32
|
||||
.len()
|
||||
.checked_div(seq)
|
||||
.unwrap_or(top_k_expected);
|
||||
let top_k = idx_k.min(val_k);
|
||||
if seq > 0 && top_k == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Mode-dependent expert weights used for the final reduction:
|
||||
@@ -459,16 +351,9 @@ impl HostOp for GLUMoE {
|
||||
let expert_weights_f32: &[f32] = match self.mode {
|
||||
GLUMoEMode::SwiGLU => topk_vals_f32,
|
||||
GLUMoEMode::GemmaGELU => {
|
||||
let per_expert_scale_host: Vec<u8> = mode_aux_buf.clone_dtoh(stream)?;
|
||||
let per_expert_scale_bytes = num_experts * 4;
|
||||
if per_expert_scale_host.len() < per_expert_scale_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE per-expert scale buffer too small: have {} bytes, need {per_expert_scale_bytes}",
|
||||
per_expert_scale_host.len()
|
||||
);
|
||||
}
|
||||
let per_expert_scale_f32: &[f32] =
|
||||
bytemuck::cast_slice(&per_expert_scale_host[..per_expert_scale_bytes]);
|
||||
let per_expert_scale_host: Vec<u8> = stream.clone_dtoh(mode_aux_buf)?;
|
||||
let per_expert_scale_f32: &[f32] = bytemuck::cast_slice(&per_expert_scale_host);
|
||||
debug_assert!(per_expert_scale_f32.len() >= num_experts);
|
||||
expert_weights_storage.resize(seq * top_k, 0.0);
|
||||
for t in 0..seq {
|
||||
let base = t * top_k;
|
||||
@@ -498,10 +383,10 @@ impl HostOp for GLUMoE {
|
||||
let hidden_tmp = unsafe { stream.alloc::<u8>(intermediate * 2)? }; // BF16
|
||||
let workspace = unsafe { stream.alloc::<u8>(WORKSPACE_SIZE)? };
|
||||
|
||||
let xbf16_ptr = slice_ptr(&x_bf16_buf, stream);
|
||||
let gu_out_ptr = slice_ptr(&gate_up_out_buf, stream);
|
||||
let hid_ptr = slice_ptr(&hidden_tmp, stream);
|
||||
let ws_ptr = slice_ptr(&workspace, stream);
|
||||
let xbf16_ptr = buf_ptr(&x_bf16_buf, stream);
|
||||
let gu_out_ptr = buf_ptr(&gate_up_out_buf, stream);
|
||||
let hid_ptr = buf_ptr(&hidden_tmp, stream);
|
||||
let ws_ptr = buf_ptr(&workspace, stream);
|
||||
|
||||
// Cast x F32 → BF16
|
||||
let n_cast = (seq * hidden) as i32;
|
||||
@@ -520,8 +405,8 @@ impl HostOp for GLUMoE {
|
||||
}
|
||||
|
||||
// Per-token expert computation
|
||||
let gu_stride = gu_stride_bytes as u64; // bytes per expert gate_up (BF16)
|
||||
let down_stride = down_stride_bytes as u64; // bytes per expert down (BF16)
|
||||
let gu_stride = (gate_up_dim * hidden * 2) as u64; // bytes per expert gate_up (BF16)
|
||||
let down_stride = (hidden * intermediate * 2) as u64; // bytes per expert down (BF16)
|
||||
|
||||
for t in 0..seq {
|
||||
let x_t_ptr = xbf16_ptr + (t * hidden * 2) as u64; // BF16
|
||||
@@ -623,11 +508,7 @@ impl HostOp for GLUMoE {
|
||||
// Helpers
|
||||
// ============================================================
|
||||
|
||||
fn buf_ptr(buf: DeviceBuffer, _stream: &Arc<CudaStream>) -> u64 {
|
||||
buf.ptr()
|
||||
}
|
||||
|
||||
fn slice_ptr(buf: &CudaSlice<u8>, stream: &Arc<CudaStream>) -> u64 {
|
||||
fn buf_ptr(buf: &CudaSlice<u8>, stream: &Arc<CudaStream>) -> u64 {
|
||||
let (ptr, _guard) = buf.device_ptr(stream);
|
||||
ptr
|
||||
}
|
||||
|
||||
@@ -1,301 +0,0 @@
|
||||
// =========================================================================
|
||||
// Fused elementwise op variants used inside FusionStart/FusionEnd regions.
|
||||
//
|
||||
// Each `FusedX` struct mirrors its un-fused `KernelX` sibling field-for-field
|
||||
// and serves a single purpose: give the egglog rules a distinct sort to
|
||||
// rewrite into so a pair-fuse rule's RHS can never re-match its own LHS
|
||||
// pattern. Cascade prevention by typing.
|
||||
//
|
||||
// Each FusedX must be absorbed into a FusionEnd-rooted region and compiled by
|
||||
// `region_codegen`; standalone compilation is intentionally unsupported.
|
||||
// =========================================================================
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, OP_KIND},
|
||||
extract_dtype, extract_expr_list,
|
||||
},
|
||||
op::*,
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
pub type Ops = (
|
||||
FusedSin,
|
||||
FusedSqrt,
|
||||
FusedExp,
|
||||
FusedExp2,
|
||||
FusedLog2,
|
||||
FusedRecip,
|
||||
FusedAdd,
|
||||
FusedMul,
|
||||
);
|
||||
|
||||
// Standard `compile()` return tuple (matches the trait signature).
|
||||
type CompileOut = (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
);
|
||||
|
||||
/// Generate `pub struct $Name { … unary fields … }` plus its `EgglogOp` and
|
||||
/// `KernelOp` impls. `$kernel_name` names the CUDA function (and the cache
|
||||
/// key); `$body` is the per-op CUDA expression, e.g. `"sinf(in[{in_idx}])"`.
|
||||
macro_rules! impl_fused_unary {
|
||||
($Name:ident, $sort:literal, $kernel_name:literal, $body:literal) => {
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct $Name {
|
||||
pub(crate) shape: Vec<Expression>,
|
||||
pub(crate) in_strides: Vec<Expression>,
|
||||
pub(crate) out_strides: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for $Name {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
$sort,
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
Vec::new()
|
||||
}
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
in_strides: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[1],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
out_strides: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[2],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[3]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for $Name {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!(concat!(
|
||||
$sort,
|
||||
" must be compiled through fusion region codegen"
|
||||
))
|
||||
}
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
fn flops(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
$sort
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// As `impl_fused_unary!` but for binary ops: 5-field sort signature
|
||||
/// (shape + per-input strides + out_stride + dtype), n_inputs = 2.
|
||||
/// `$op_str` is the CUDA infix operator, e.g. `"+"`, `"*"`.
|
||||
macro_rules! impl_fused_binary {
|
||||
($Name:ident, $sort:literal, $kernel_name:literal, $op_str:literal) => {
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct $Name {
|
||||
pub(crate) out_shape: Vec<Expression>,
|
||||
pub(crate) a_stride: Vec<Expression>,
|
||||
pub(crate) b_stride: Vec<Expression>,
|
||||
pub(crate) out_stride: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for $Name {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
$sort,
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("a_strides", ELIST),
|
||||
("b_strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
Vec::new()
|
||||
}
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[0],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
a_stride: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[1],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
b_stride: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[2],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
out_stride: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[3],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[4]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for $Name {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!(concat!(
|
||||
$sort,
|
||||
" must be compiled through fusion region codegen"
|
||||
))
|
||||
}
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
let bytes = (self.output_size() * self.dtype.bits()).ceil_div(8);
|
||||
bytes + bytes
|
||||
}
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
fn flops(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
$sort
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_fused_unary!(FusedSin, "FusedSin", "fused_sin_k", "sinf(in[{in_idx}])");
|
||||
impl_fused_unary!(
|
||||
FusedSqrt,
|
||||
"FusedSqrt",
|
||||
"fused_sqrt_k",
|
||||
"sqrtf(in[{in_idx}])"
|
||||
);
|
||||
impl_fused_unary!(FusedExp, "FusedExp", "fused_exp_k", "expf(in[{in_idx}])");
|
||||
impl_fused_unary!(
|
||||
FusedExp2,
|
||||
"FusedExp2",
|
||||
"fused_exp2_k",
|
||||
"exp2f(in[{in_idx}])"
|
||||
);
|
||||
impl_fused_unary!(
|
||||
FusedLog2,
|
||||
"FusedLog2",
|
||||
"fused_log2_k",
|
||||
"log2f(in[{in_idx}])"
|
||||
);
|
||||
impl_fused_unary!(
|
||||
FusedRecip,
|
||||
"FusedRecip",
|
||||
"fused_recip_k",
|
||||
"1.0f / in[{in_idx}]"
|
||||
);
|
||||
|
||||
impl_fused_binary!(FusedAdd, "FusedAdd", "fused_add_k", "+");
|
||||
impl_fused_binary!(FusedMul, "FusedMul", "fused_mul_k", "*");
|
||||
@@ -1,413 +0,0 @@
|
||||
// =========================================================================
|
||||
// Fusion boundary markers — FusionStart and FusionEnd.
|
||||
//
|
||||
// Tag-like LLIR ops that bracket a region of elementwise ops destined to
|
||||
// be emitted as a single CUDA kernel:
|
||||
// - N FusionStart nodes per region (one per FS leaf — distinct external
|
||||
// reads),
|
||||
// - exactly 1 FusionEnd per region.
|
||||
//
|
||||
// `FusionEnd::rewrites()` carries the seven rule families that build and
|
||||
// extend regions (pair-fuse / grow / merge); the actual single-kernel
|
||||
// codegen lives in `region_codegen`. Like FusedX, both markers'
|
||||
// `compile()` is `unreachable!()` — region codegen folds them away
|
||||
// before kernel_to_host's compile loop reaches an interior node.
|
||||
// =========================================================================
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, OP_KIND},
|
||||
extract_dtype, extract_expr_list,
|
||||
},
|
||||
op::*,
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
pub type Ops = (FusionStart, FusionEnd);
|
||||
|
||||
type CompileOut = (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
);
|
||||
|
||||
// =========================================================================
|
||||
// FusionStart
|
||||
// =========================================================================
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct FusionStart {
|
||||
pub(crate) shape: Vec<Expression>,
|
||||
pub(crate) strides: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for FusionStart {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"FusionStart",
|
||||
&[("shape", ELIST), ("strides", ELIST), ("dtype", DTYPE)],
|
||||
)
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// No idempotence rule. `FusionStart(FusionStart(x)) ≡ FusionStart(x)`
|
||||
// would unify nested markers and create eclass cycles via the
|
||||
// pair-fuse rules; without it, occasional re-firings produce extra
|
||||
// semantically-correct identity layers, bounded by the run schedule.
|
||||
Vec::new()
|
||||
}
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap(),
|
||||
strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[2]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for FusionStart {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!("FusionStart must be compiled through fusion region codegen")
|
||||
}
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"FusionStart"
|
||||
}
|
||||
fn output_aliases_input(&self) -> Option<usize> {
|
||||
Some(0)
|
||||
}
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// FusionEnd
|
||||
// =========================================================================
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct FusionEnd {
|
||||
pub(crate) shape: Vec<Expression>,
|
||||
pub(crate) strides: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for FusionEnd {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"FusionEnd",
|
||||
&[("shape", ELIST), ("strides", ELIST), ("dtype", DTYPE)],
|
||||
)
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// Seven rule families build and extend FE-bracketed regions. Each
|
||||
// pair-fuse rule's LHS pattern matches *un-fused* `KernelX` ops; the
|
||||
// RHS produces `FusedX` variants in a different egglog sort, so the
|
||||
// rule's own output cannot re-match its LHS — cascade is prevented
|
||||
// by typing rather than by a discriminator field.
|
||||
//
|
||||
// Stride compatibility is expressed by reusing variable names: a
|
||||
// unary inside a region matches `(KernelU ?shape ?s ?s ?dt)` (in =
|
||||
// out, no transpose); a binary feeding a downstream op binds the
|
||||
// binary's out-stride to the downstream op's in-stride along the
|
||||
// connecting side.
|
||||
let mut rules = Vec::new();
|
||||
|
||||
// (KernelX kind, FusedX kind)
|
||||
let unaries: &[(&str, &str)] = &[
|
||||
("KernelSin", "FusedSin"),
|
||||
("KernelSqrt", "FusedSqrt"),
|
||||
("KernelExp", "FusedExp"),
|
||||
("KernelExp2", "FusedExp2"),
|
||||
("KernelLog2", "FusedLog2"),
|
||||
("KernelRecip", "FusedRecip"),
|
||||
];
|
||||
// (KernelX kind, FusedX kind, rule-name label)
|
||||
let binaries: &[(&str, &str, &str)] = &[
|
||||
("KernelAdd", "FusedAdd", "Add"),
|
||||
("KernelMul", "FusedMul", "Mul"),
|
||||
];
|
||||
|
||||
// 1. Pair-fuse U → U: U2(U1(x)) → FE(FU2(FU1(FS(x)))).
|
||||
for (ki1, fi1) in unaries {
|
||||
for (ko2, fo2) in unaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?u1 (Op ({ki1} ?shape ?s ?s ?dt) (ICons ?x (INil))))
|
||||
(= ?u2 (Op ({ko2} ?shape ?s ?s ?dt) (ICons ?u1 (INil))))
|
||||
) (
|
||||
(let ?fs (Op (FusionStart ?shape ?s ?dt) (ICons ?x (INil))))
|
||||
(let ?fu1 (Op ({fi1} ?shape ?s ?s ?dt) (ICons ?fs (INil))))
|
||||
(let ?fu2 (Op ({fo2} ?shape ?s ?s ?dt) (ICons ?fu1 (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?fu2 (INil))))
|
||||
(union ?u2 ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-U-U-{ki1}-{ko2}\")"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Pair-fuse B → U: U(B(a, b)) → FE(FU(FB(FS(a), FS(b)))).
|
||||
for (kb, fb, lb) in binaries {
|
||||
for (ku, fu) in unaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?bin (Op ({kb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?u (Op ({ku} ?shape ?o_s ?o_s ?dt) (ICons ?bin (INil))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fbin (Op ({fb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(let ?fu (Op ({fu} ?shape ?o_s ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fu (INil))))
|
||||
(union ?u ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-B-U-{lb}-{ku}\")"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Pair-fuse U → B (lhs / rhs): unary feeds binary's A or B input.
|
||||
// LHS: B(U(a), b) → FE(FB(FU(FS(a)), FS(b))).
|
||||
// RHS: B(a, U(b)) → FE(FB(FS(a), FU(FS(b)))).
|
||||
for (ku, fu) in unaries {
|
||||
for (kb, fb, lb) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?u (Op ({ku} ?shape ?u_s ?u_s ?dt) (ICons ?a (INil))))
|
||||
(= ?bin (Op ({kb} ?shape ?u_s ?b_s ?o_s ?dt)
|
||||
(ICons ?u (ICons ?b (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?u_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fu (Op ({fu} ?shape ?u_s ?u_s ?dt) (ICons ?fs_a (INil))))
|
||||
(let ?fbin (Op ({fb} ?shape ?u_s ?b_s ?o_s ?dt)
|
||||
(ICons ?fu (ICons ?fs_b (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-U-B-lhs-{ku}-{lb}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?u (Op ({ku} ?shape ?u_s ?u_s ?dt) (ICons ?b (INil))))
|
||||
(= ?bin (Op ({kb} ?shape ?a_s ?u_s ?o_s ?dt)
|
||||
(ICons ?a (ICons ?u (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?u_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fu (Op ({fu} ?shape ?u_s ?u_s ?dt) (ICons ?fs_b (INil))))
|
||||
(let ?fbin (Op ({fb} ?shape ?a_s ?u_s ?o_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fu (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-U-B-rhs-{ku}-{lb}\")"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Pair-fuse B → B (lhs / rhs): inner binary feeds outer's A or B.
|
||||
for (kbi, fbi, lbi) in binaries {
|
||||
for (kbo, fbo, lbo) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?bi (Op ({kbi} ?shape ?ai_s ?bi_s ?oi_s ?dt)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?bo (Op ({kbo} ?shape ?oi_s ?co_s ?oo_s ?dt)
|
||||
(ICons ?bi (ICons ?c (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?ai_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?bi_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fs_c (Op (FusionStart ?shape ?co_s ?dt) (ICons ?c (INil))))
|
||||
(let ?fbi (Op ({fbi} ?shape ?ai_s ?bi_s ?oi_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(let ?fbo (Op ({fbo} ?shape ?oi_s ?co_s ?oo_s ?dt)
|
||||
(ICons ?fbi (ICons ?fs_c (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?oo_s ?dt) (ICons ?fbo (INil))))
|
||||
(union ?bo ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-B-B-lhs-{lbi}-{lbo}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?bi (Op ({kbi} ?shape ?ai_s ?bi_s ?oi_s ?dt)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?bo (Op ({kbo} ?shape ?co_s ?oi_s ?oo_s ?dt)
|
||||
(ICons ?c (ICons ?bi (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?ai_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?bi_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fs_c (Op (FusionStart ?shape ?co_s ?dt) (ICons ?c (INil))))
|
||||
(let ?fbi (Op ({fbi} ?shape ?ai_s ?bi_s ?oi_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(let ?fbo (Op ({fbo} ?shape ?co_s ?oi_s ?oo_s ?dt)
|
||||
(ICons ?fs_c (ICons ?fbi (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?oo_s ?dt) (ICons ?fbo (INil))))
|
||||
(union ?bo ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-B-B-rhs-{lbi}-{lbo}\")"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Grow FE → U: U(FE(inner)) → FE(FU(inner)). No new FS.
|
||||
for (ku, fu) in unaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?inner (INil))))
|
||||
(= ?u (Op ({ku} ?shape ?s ?s ?dt) (ICons ?fe (INil))))
|
||||
) (
|
||||
(let ?fu (Op ({fu} ?shape ?s ?s ?dt) (ICons ?inner (INil))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?fu (INil))))
|
||||
(union ?u ?new_fe)
|
||||
) :ruleset fusion_grow :name \"grow-FE-U-{ku}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
// 6. Grow FE → B (lhs / rhs): one input is the FE, the other external.
|
||||
for (kb, fb, lb) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?shape ?a_s ?dt) (ICons ?inner_a (INil))))
|
||||
(= ?bin (Op ({kb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?fe (ICons ?b (INil)))))
|
||||
) (
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fbin (Op ({fb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?inner_a (ICons ?fs_b (INil)))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?new_fe)
|
||||
) :ruleset fusion_grow :name \"grow-FE-B-lhs-{lb}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?shape ?b_s ?dt) (ICons ?inner_b (INil))))
|
||||
(= ?bin (Op ({kb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?a (ICons ?fe (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fbin (Op ({fb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?fs_a (ICons ?inner_b (INil)))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?new_fe)
|
||||
) :ruleset fusion_grow :name \"grow-FE-B-rhs-{lb}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
// 7. Merge two FEs at a binary: B(FE(ia), FE(ib)) → FE(FB(ia, ib)).
|
||||
//
|
||||
// This is destructive: after creating the larger region, subsume the
|
||||
// two smaller FusionEnd rows. Without that, independently-grown left
|
||||
// and right regions form a Cartesian product, then those alternatives
|
||||
// can merge again higher in the graph.
|
||||
for (kb, fb, lb) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe_a (Op (FusionEnd ?shape ?a_s ?dt) (ICons ?inner_a (INil))))
|
||||
(= ?fe_b (Op (FusionEnd ?shape ?b_s ?dt) (ICons ?inner_b (INil))))
|
||||
(= ?bin (Op ({kb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?fe_a (ICons ?fe_b (INil)))))
|
||||
) (
|
||||
(let ?fbin (Op ({fb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?inner_a (ICons ?inner_b (INil)))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?new_fe)
|
||||
(subsume (Op (FusionEnd ?shape ?a_s ?dt) (ICons ?inner_a (INil))))
|
||||
(subsume (Op (FusionEnd ?shape ?b_s ?dt) (ICons ?inner_b (INil))))
|
||||
) :ruleset fusion_merge :name \"merge-FE-FE-{lb}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
// No dissolve rule (`FS(FE(x)) → x`): unioning FS's eclass with FE's
|
||||
// inner eclass creates self-referential eclasses after grow rules
|
||||
// extend the downstream region, and extraction then panics with
|
||||
// `Cycle(NodeIndex(_))`. Grow rules already compose adjacent regions
|
||||
// correctly without dissolve.
|
||||
|
||||
rules
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap(),
|
||||
strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[2]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for FusionEnd {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!("FusionEnd must be compiled through fusion region codegen")
|
||||
}
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"FusionEnd"
|
||||
}
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
//! Binary-inclusive elementwise kernel fusion.
|
||||
//!
|
||||
//! - `markers` — `FusionStart` / `FusionEnd` ops + the seven egglog rule
|
||||
//! families that build and extend FE-bracketed regions.
|
||||
//! - `fused_ops` — eight `FusedX` op variants (interior to a region) so
|
||||
//! pair-fuse rules' RHS sit in a different egglog sort than their LHS,
|
||||
//! blocking cascade by typing.
|
||||
//! - `region_codegen` — `kernel_to_host` calls into here to collapse each
|
||||
//! FE-rooted region into a single CUDA kernel at compile time.
|
||||
//!
|
||||
//! The LLIR keeps `FusionStart` / `FusedX` / `FusionEnd` nodes after
|
||||
//! extraction; `region_codegen` is the only place that walks them.
|
||||
|
||||
pub mod fused_ops;
|
||||
pub mod markers;
|
||||
pub mod region_codegen;
|
||||
|
||||
pub use fused_ops::{
|
||||
FusedAdd, FusedExp, FusedExp2, FusedLog2, FusedMul, FusedRecip, FusedSin, FusedSqrt,
|
||||
};
|
||||
pub use markers::{FusionEnd, FusionStart};
|
||||
|
||||
/// All fusion-related op types that the egglog runtime needs to know about
|
||||
/// (markers + interior FusedX variants). Combined into a flat tuple for the
|
||||
/// `Ops` registry in `kernel::mod`.
|
||||
pub type Ops = (markers::Ops, fused_ops::Ops);
|
||||
@@ -1,476 +0,0 @@
|
||||
// =========================================================================
|
||||
// Region codegen for FusionStart / FusionEnd-bracketed fused regions.
|
||||
//
|
||||
// PR1 left FusedX / FusionStart / FusionEnd nodes in the post-extraction
|
||||
// LLIR, each compiling to its own standalone CUDA kernel. PR2 collapses
|
||||
// every FusionEnd-rooted region into ONE fused CUDA kernel at codegen
|
||||
// time — without rewriting the LLIR.
|
||||
//
|
||||
// Pipeline:
|
||||
// `kernel_to_host` builds a Vec<CompileUnit> from the topo order:
|
||||
// - CompileUnit::Single(node) — un-fused KernelX, compiled as before.
|
||||
// - CompileUnit::Region(rgn) — one FE + its interior FusedX DAG +
|
||||
// its FS leaves. Compiled here as a
|
||||
// single CUDA kernel that reads from
|
||||
// the region's external inputs once,
|
||||
// chains all FusedX bodies through
|
||||
// register-resident locals, and writes
|
||||
// the FE's output.
|
||||
//
|
||||
// The CompiledKernel for a Region is keyed on the FE node and stores
|
||||
// `inputs = external producer NodeIndices` (one per interior FusionStart),
|
||||
// so the existing buffer-pointer wiring in to_host.rs picks up the right
|
||||
// device pointers at execute time. Interior FusedX / FusionStart nodes
|
||||
// never enter the kernels Vec — they have no buffers, no launches.
|
||||
// =========================================================================
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
graph::LLIRGraph,
|
||||
prelude::{
|
||||
petgraph::{Direction, algo::toposort, visit::EdgeRef},
|
||||
*,
|
||||
},
|
||||
};
|
||||
|
||||
use as_any::Downcast;
|
||||
|
||||
use crate::{
|
||||
compile_module_image_for_current_device, cuda_dtype,
|
||||
kernel::KernelOp,
|
||||
kernel::fusion::markers::{FusionEnd, FusionStart},
|
||||
kernel::hlir::{dtype_includes, generate_dyn_dims_defines},
|
||||
};
|
||||
|
||||
// =========================================================================
|
||||
// Compile units — what `kernel_to_host` iterates over instead of nodes.
|
||||
// =========================================================================
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct RegionUnit {
|
||||
/// The FusionEnd node that anchors this region.
|
||||
pub fe_node: NodeIndex,
|
||||
/// Interior FusedX nodes, in topological order (predecessors before
|
||||
/// consumers). Used to emit register-binding statements in dependency
|
||||
/// order in the fused CUDA kernel body.
|
||||
pub fusedx_topo: Vec<NodeIndex>,
|
||||
/// FusionStart nodes that bound the region's leaves. One per external
|
||||
/// read site — duplicates (different FS LLIR nodes wrapping the same
|
||||
/// upstream tensor) are kept separate so each read uses its own
|
||||
/// strides; the host launch passes the same device pointer twice.
|
||||
pub fs_nodes: Vec<NodeIndex>,
|
||||
/// External producer NodeIndices, one per `fs_nodes` entry in the same
|
||||
/// order. Becomes the `inputs` field of the FE's `CompiledKernel`, and
|
||||
/// the kernel function's `in0`, `in1`, ... parameters in that order.
|
||||
pub external_inputs: Vec<NodeIndex>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) enum CompileUnit {
|
||||
Single(NodeIndex),
|
||||
Region(RegionUnit),
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Region detection.
|
||||
// =========================================================================
|
||||
|
||||
/// Group a sub-DAG's topo order into compile units. Each FusionEnd node
|
||||
/// becomes the root of a `CompileUnit::Region`; the region's interior
|
||||
/// FusedX and FusionStart nodes are absorbed into that region and removed
|
||||
/// from the per-node iteration. Anything else is wrapped in
|
||||
/// `CompileUnit::Single`.
|
||||
/// Globally-absorbed FS / FE markers — the set of marker nodes that any
|
||||
/// `FusionEnd` in the LLIR walks back to during region detection. A
|
||||
/// marker is "absorbed" iff some FE in the LLIR can reach it by walking
|
||||
/// incoming edges through `FusionEnd` / `FusedX` nodes, stopping at
|
||||
/// `FusionStart` leaves.
|
||||
///
|
||||
/// This is computed once over the full LLIR rather than per-convex-
|
||||
/// subgraph, because `partition_marked_convex` may put a shared FS leaf
|
||||
/// (one whose e-graph congruence-deduplicated it across multiple
|
||||
/// regions) into a different subgraph than the FE that absorbs it.
|
||||
/// Without this global view, `build_compile_units` running on the FS's
|
||||
/// subgraph would not see any FE walking back to the FS and would emit the
|
||||
/// FS as `CompileUnit::Single`; marker standalone compilation is not supported.
|
||||
pub(crate) fn globally_absorbed_markers(llir_graph: &LLIRGraph) -> FxHashSet<NodeIndex> {
|
||||
let name_of = |idx: NodeIndex| -> Option<&'static str> {
|
||||
llir_graph
|
||||
.node_weight(idx)
|
||||
.and_then(|op| op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name()))
|
||||
};
|
||||
|
||||
let mut absorbed: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
for fe in llir_graph.node_indices() {
|
||||
if name_of(fe) != Some("FusionEnd") {
|
||||
continue;
|
||||
}
|
||||
let mut visited: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
let mut stack: Vec<NodeIndex> = vec![fe];
|
||||
visited.insert(fe);
|
||||
while let Some(cur) = stack.pop() {
|
||||
for pred in llir_graph.neighbors_directed(cur, Direction::Incoming) {
|
||||
if !visited.insert(pred) {
|
||||
continue;
|
||||
}
|
||||
match name_of(pred) {
|
||||
Some("FusionStart") => {
|
||||
absorbed.insert(pred);
|
||||
}
|
||||
Some("FusionEnd") => {
|
||||
absorbed.insert(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
Some(other) if other.starts_with("Fused") => {
|
||||
absorbed.insert(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
absorbed
|
||||
}
|
||||
|
||||
pub(crate) fn build_compile_units(
|
||||
topo_order: &[NodeIndex],
|
||||
llir_graph: &LLIRGraph,
|
||||
globally_absorbed: &FxHashSet<NodeIndex>,
|
||||
) -> Vec<CompileUnit> {
|
||||
let name_of = |idx: NodeIndex| -> Option<&'static str> {
|
||||
llir_graph
|
||||
.node_weight(idx)
|
||||
.and_then(|op| op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name()))
|
||||
};
|
||||
|
||||
// First pass: every FusionEnd in the subgraph anchors a region; gather
|
||||
// the region's interior + FS leaves by walking incoming edges
|
||||
// backward, stopping at FusionStart (a leaf — its predecessor is the
|
||||
// external producer, outside the region).
|
||||
let mut absorbed: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
let mut regions: FxHashMap<NodeIndex, RegionUnit> = FxHashMap::default();
|
||||
|
||||
for &node in topo_order {
|
||||
if name_of(node) != Some("FusionEnd") {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut interior: Vec<NodeIndex> = Vec::new();
|
||||
let mut fs_nodes: Vec<NodeIndex> = Vec::new();
|
||||
let mut visited: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
let mut stack: Vec<NodeIndex> = Vec::new();
|
||||
stack.push(node);
|
||||
visited.insert(node);
|
||||
|
||||
while let Some(cur) = stack.pop() {
|
||||
for pred in llir_graph.neighbors_directed(cur, Direction::Incoming) {
|
||||
if !visited.insert(pred) {
|
||||
continue;
|
||||
}
|
||||
match name_of(pred) {
|
||||
Some("FusionStart") => {
|
||||
fs_nodes.push(pred);
|
||||
// Don't recurse past FS — its predecessor is
|
||||
// external (outside the region).
|
||||
}
|
||||
Some("FusionEnd") => {
|
||||
// A nested FE inside a region. Under the current
|
||||
// rule design these are cascade artifacts — treat
|
||||
// them as transparent (walk through) rather than
|
||||
// as a separate region. The outer region absorbs
|
||||
// them. They do not become CompileUnit::Region
|
||||
// anchors because their eclass is already the
|
||||
// outer region's.
|
||||
absorbed.insert(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
Some(other) if other.starts_with("Fused") => {
|
||||
interior.push(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
_ => {
|
||||
// Non-marker, non-FusedX predecessor inside what
|
||||
// we thought was a region. Shouldn't happen with
|
||||
// the current rules; treat conservatively: do
|
||||
// not absorb it. This means the region is
|
||||
// malformed and we likely should not have a
|
||||
// region at all; caller will see incomplete
|
||||
// interior.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Topological order on the interior + FS nodes (so the kernel
|
||||
// emits `let v = ...;` lines after their inputs are bound). We
|
||||
// use the parent graph's toposort filtered to in-region nodes.
|
||||
let mut region_set: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
region_set.extend(interior.iter().copied());
|
||||
region_set.extend(fs_nodes.iter().copied());
|
||||
let topo = toposort(llir_graph, None).expect("LLIR cycle in region detection");
|
||||
let interior_topo: Vec<NodeIndex> = topo
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|n| region_set.contains(n) && interior.contains(n))
|
||||
.collect();
|
||||
let fs_topo: Vec<NodeIndex> = topo
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|n| region_set.contains(n) && fs_nodes.contains(n))
|
||||
.collect();
|
||||
|
||||
// External producer for each FS leaf, in the same order.
|
||||
let external_inputs: Vec<NodeIndex> = fs_topo
|
||||
.iter()
|
||||
.map(|&fs| {
|
||||
llir_graph
|
||||
.neighbors_directed(fs, Direction::Incoming)
|
||||
.next()
|
||||
.expect("FusionStart with no predecessor")
|
||||
})
|
||||
.collect();
|
||||
|
||||
absorbed.extend(interior_topo.iter().copied());
|
||||
absorbed.extend(fs_topo.iter().copied());
|
||||
|
||||
regions.insert(
|
||||
node,
|
||||
RegionUnit {
|
||||
fe_node: node,
|
||||
fusedx_topo: interior_topo,
|
||||
fs_nodes: fs_topo,
|
||||
external_inputs,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Second pass: emit compile units in original topo order, replacing
|
||||
// FE nodes with their RegionUnit and skipping anything absorbed —
|
||||
// either by a region in *this* subgraph (`absorbed`) or by any
|
||||
// region anywhere in the LLIR (`globally_absorbed`). Skipping the
|
||||
// latter prevents shared FS markers whose consumers live in other
|
||||
// convex subgraphs from being emitted as standalone compile units:
|
||||
// those FSes are absorbed by some other region, and the consuming
|
||||
// region reads from FS's external producer.
|
||||
let mut units: Vec<CompileUnit> = Vec::new();
|
||||
for &node in topo_order {
|
||||
if let Some(region) = regions.remove(&node) {
|
||||
units.push(CompileUnit::Region(region));
|
||||
} else if absorbed.contains(&node) || globally_absorbed.contains(&node) {
|
||||
continue;
|
||||
} else {
|
||||
units.push(CompileUnit::Single(node));
|
||||
}
|
||||
}
|
||||
units
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Per-FusedX body templates.
|
||||
//
|
||||
// Each entry takes the names of the local variables holding the op's
|
||||
// inputs and returns a CUDA expression evaluating to the op's output
|
||||
// (a register-resident value, no buffer involved).
|
||||
// =========================================================================
|
||||
|
||||
fn fused_body(name: &str, locals: &[&str]) -> String {
|
||||
match name {
|
||||
"FusedSin" => format!("sinf({})", locals[0]),
|
||||
"FusedSqrt" => format!("sqrtf({})", locals[0]),
|
||||
"FusedExp" => format!("expf({})", locals[0]),
|
||||
"FusedExp2" => format!("exp2f({})", locals[0]),
|
||||
"FusedLog2" => format!("log2f({})", locals[0]),
|
||||
"FusedRecip" => format!("1.0f / {}", locals[0]),
|
||||
"FusedAdd" => format!("{} + {}", locals[0], locals[1]),
|
||||
"FusedMul" => format!("{} * {}", locals[0], locals[1]),
|
||||
other => panic!("region_codegen: unknown FusedX op {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Region compilation — emit one CUDA kernel for the whole region.
|
||||
// =========================================================================
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub(crate) struct CompiledRegion {
|
||||
pub function: CudaFunction,
|
||||
pub module: Arc<CudaModule>,
|
||||
pub kernel_str: String,
|
||||
pub grid: (Expression, Expression, Expression),
|
||||
pub block: (Expression, Expression, Expression),
|
||||
pub shared_mem: Expression,
|
||||
pub constants: FxHashMap<char, CudaSlice<u8>>,
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub(crate) fn compile_region(
|
||||
region: &RegionUnit,
|
||||
llir_graph: &LLIRGraph,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompiledRegion {
|
||||
// Resolve FE: shape, strides (for the write), dtype.
|
||||
let fe_op = llir_graph[region.fe_node]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.expect("FE node must be a KernelOp");
|
||||
let fe_struct: &FusionEnd = (***fe_op)
|
||||
.downcast_ref::<FusionEnd>()
|
||||
.expect("region root must be FusionEnd");
|
||||
let out_shape: &[Expression] = &fe_struct.shape;
|
||||
let out_strides: &[Expression] = &fe_struct.strides;
|
||||
let dtype: DType = fe_struct.dtype;
|
||||
|
||||
// Aggregate all dynamic vars used anywhere in the region (FS strides,
|
||||
// FE strides, FusedX shape — all FusedX share `out_shape`, but their
|
||||
// own strides are likewise relevant for any future stride-affine ops).
|
||||
let mut all_vars: FxHashSet<char> = FxHashSet::default();
|
||||
all_vars.extend(out_shape.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(out_strides.iter().flat_map(|e| e.dyn_vars()));
|
||||
for &fs_idx in ®ion.fs_nodes {
|
||||
let fs_op = llir_graph[fs_idx].to_dialect::<dyn KernelOp>().unwrap();
|
||||
let fs_struct: &FusionStart = (***fs_op).downcast_ref::<FusionStart>().unwrap();
|
||||
all_vars.extend(fs_struct.strides.iter().flat_map(|e| e.dyn_vars()));
|
||||
}
|
||||
|
||||
let cuda_ty = cuda_dtype(dtype);
|
||||
let includes = dtype_includes(&[dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&all_vars);
|
||||
let dyn_dims_param = if all_vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
let n_elements = out_shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.to_kernel();
|
||||
|
||||
// Build kernel signature: out, then one input per FS leaf in
|
||||
// `region.fs_nodes` order. The `external_inputs` list (parallel to
|
||||
// `fs_nodes`) is what the host wires into the launch params.
|
||||
let mut signature_params: Vec<String> = vec![format!("{cuda_ty} *out")];
|
||||
for i in 0..region.fs_nodes.len() {
|
||||
signature_params.push(format!("const {cuda_ty} *in{i}"));
|
||||
}
|
||||
let signature = signature_params.join(", ");
|
||||
|
||||
// Body: read FS leaves, then walk FusedX in topo order emitting a
|
||||
// local per op, then write FE output. Every node gets a local keyed
|
||||
// by a position-in-region index so the kernel string is invariant
|
||||
// under NodeIndex churn (each `egglog_to_llir` reissues NodeIndexes,
|
||||
// so naming locals by `n.index()` would invalidate the kernel
|
||||
// string cache on every search candidate). Indices: FS leaves get
|
||||
// 0..fs_nodes.len(), FusedX get fs_nodes.len()..(+ fusedx_topo.len()).
|
||||
let mut local_idx_map: FxHashMap<NodeIndex, usize> = FxHashMap::default();
|
||||
for (i, &fs_idx) in region.fs_nodes.iter().enumerate() {
|
||||
local_idx_map.insert(fs_idx, i);
|
||||
}
|
||||
let fs_count = region.fs_nodes.len();
|
||||
for (i, &op_idx) in region.fusedx_topo.iter().enumerate() {
|
||||
local_idx_map.insert(op_idx, fs_count + i);
|
||||
}
|
||||
let local_name = |n: NodeIndex| format!("v_{}", local_idx_map[&n]);
|
||||
|
||||
let mut body = String::new();
|
||||
body.push_str(&format!(
|
||||
" long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;\n\
|
||||
\x20 if (const_z >= {n_elements}) return;\n"
|
||||
));
|
||||
|
||||
// FS leaves: each reads from its corresponding `in_i` parameter using
|
||||
// its own strides.
|
||||
for (i, &fs_idx) in region.fs_nodes.iter().enumerate() {
|
||||
let fs_op = llir_graph[fs_idx].to_dialect::<dyn KernelOp>().unwrap();
|
||||
let fs_struct: &FusionStart = (***fs_op).downcast_ref::<FusionStart>().unwrap();
|
||||
let read_idx = flatten_strides(out_shape, &fs_struct.strides).to_kernel();
|
||||
body.push_str(&format!(
|
||||
" {cuda_ty} {name} = in{i}[{read_idx}];\n",
|
||||
name = local_name(fs_idx),
|
||||
));
|
||||
}
|
||||
|
||||
// FusedX ops in topo order. Each looks up its predecessor locals
|
||||
// (in incoming-edge id order to match the original op's input
|
||||
// arity / position).
|
||||
for &op_idx in ®ion.fusedx_topo {
|
||||
let op_ref = llir_graph[op_idx].to_dialect::<dyn KernelOp>().unwrap();
|
||||
let op_name = op_ref.kernel_name();
|
||||
|
||||
let mut input_locals: Vec<String> = llir_graph
|
||||
.edges_directed(op_idx, Direction::Incoming)
|
||||
.map(|e| (e.id(), e.source()))
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()
|
||||
.map(|(_, src)| local_name(src))
|
||||
.collect();
|
||||
// Sort by edge id like the rest of the codegen does for stable
|
||||
// input ordering.
|
||||
let mut edges: Vec<(_, NodeIndex)> = llir_graph
|
||||
.edges_directed(op_idx, Direction::Incoming)
|
||||
.map(|e| (e.id(), e.source()))
|
||||
.collect();
|
||||
edges.sort_by_key(|(eid, _)| *eid);
|
||||
input_locals = edges.into_iter().map(|(_, src)| local_name(src)).collect();
|
||||
let inputs_ref: Vec<&str> = input_locals.iter().map(|s| s.as_str()).collect();
|
||||
|
||||
let expr = fused_body(op_name, &inputs_ref);
|
||||
body.push_str(&format!(
|
||||
" {cuda_ty} {name} = {expr};\n",
|
||||
name = local_name(op_idx),
|
||||
));
|
||||
}
|
||||
|
||||
// FE write: pick the FusedX feeding FE (its single incoming edge in
|
||||
// the region — a FusedX or, in degenerate single-FS regions which
|
||||
// shouldn't arise, an FS).
|
||||
let fe_input: NodeIndex = llir_graph
|
||||
.neighbors_directed(region.fe_node, Direction::Incoming)
|
||||
.next()
|
||||
.expect("FusionEnd with no predecessor");
|
||||
let fe_input_local = local_name(fe_input);
|
||||
let write_idx = flatten_strides(out_shape, out_strides).to_kernel();
|
||||
body.push_str(&format!(" out[{write_idx}] = {fe_input_local};\n"));
|
||||
|
||||
let kernel = format!(
|
||||
"{includes}\n\
|
||||
{dyn_defines}\n\
|
||||
extern \"C\" {{\n\
|
||||
\x20 __global__ void fused_region_k({signature}{dyn_dims_param}) {{\n\
|
||||
{body}\
|
||||
\x20 }}\n\
|
||||
}}"
|
||||
);
|
||||
|
||||
let (module, function) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel)
|
||||
.expect("region kernel PTX compile failed");
|
||||
let module = stream
|
||||
.context()
|
||||
.load_module(ptx)
|
||||
.expect("module load failed");
|
||||
let function = module
|
||||
.load_function("fused_region_k")
|
||||
.expect("region kernel function not found");
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), function.clone()));
|
||||
(module, function)
|
||||
};
|
||||
|
||||
let out_size = out_shape.iter().copied().product::<Expression>();
|
||||
|
||||
CompiledRegion {
|
||||
function,
|
||||
module,
|
||||
kernel_str: kernel,
|
||||
grid: (out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
block: (out_size.min(256), 1.into(), 1.into()),
|
||||
shared_mem: 0.into(),
|
||||
constants: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
@@ -8,7 +8,7 @@ use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use itertools::Itertools;
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, Term, app, eq, rule, set, sort, union, v},
|
||||
api::{Rule, SortDef, app, eq, rule, set, sort, union, v},
|
||||
base::{DTYPE, ELIST, EXPRESSION, F64, OP_KIND, SORTS, dtype, ilist, op_term},
|
||||
extract_dtype, extract_expr, extract_expr_list,
|
||||
},
|
||||
@@ -79,48 +79,7 @@ pub fn kernel_rewrite<H: Default + EgglogOp, L: Default + EgglogOp>() -> Rule {
|
||||
args.add("dtype", dt.clone());
|
||||
let llir_kind_term = llir.call(&args);
|
||||
let llir_op = op_term(llir_kind_term, inputs);
|
||||
rule(union(hlir_op.clone(), llir_op))
|
||||
.fact(eq(dt, dtype(hlir_op)))
|
||||
.ruleset("kernel_lower")
|
||||
}
|
||||
|
||||
/// Build a kernel rewrite for ops whose kernel dtype must match the first input.
|
||||
///
|
||||
/// This avoids extracting stale/conflicting dtype facts from the output e-class
|
||||
/// after backend alternatives have been unioned into it.
|
||||
fn kernel_rewrite_from_first_input<H: Default + EgglogOp, L: Default + EgglogOp>() -> Rule {
|
||||
let hlir = H::default().sort();
|
||||
let llir = L::default().sort();
|
||||
let (mut args, hlir_kind_term) = hlir.new_call();
|
||||
let first_inp = v("?__first_inp");
|
||||
let tail = v("?__tail");
|
||||
let inputs = Term::App {
|
||||
variant: "ICons".to_string(),
|
||||
args: vec![first_inp.clone(), tail],
|
||||
};
|
||||
let hlir_op = op_term(hlir_kind_term, inputs.clone());
|
||||
let dt = v("?__dt");
|
||||
args.add("dtype", dt.clone());
|
||||
let llir_kind_term = llir.call(&args);
|
||||
let llir_op = op_term(llir_kind_term, inputs);
|
||||
rule(union(hlir_op, llir_op))
|
||||
.fact(eq(dt, dtype(first_inp)))
|
||||
.ruleset("kernel_lower")
|
||||
}
|
||||
|
||||
fn dtype_for_ir_enode(egraph: &SerializedEGraph, ir_node: &ENodeId) -> Option<DType> {
|
||||
let ir_class = egraph.node_to_class.get(ir_node)?;
|
||||
let dtype_node = egraph.enodes.iter().find_map(|(node, (label, children))| {
|
||||
(label == "dtype" && children.first() == Some(ir_class)).then_some(node)
|
||||
})?;
|
||||
let dtype_class = egraph.node_to_class.get(dtype_node)?;
|
||||
egraph.eclasses.get(dtype_class)?.1.iter().find_map(|node| {
|
||||
match egraph.enodes.get(node)?.0.as_str() {
|
||||
"F32" | "F16" | "Bf16" | "Int" | "Bool" | "F4E2M1" | "F8E4M3" | "F8UE8M0" | "I4"
|
||||
| "TF32" => Some(extract_dtype(egraph, node)),
|
||||
_ => None,
|
||||
}
|
||||
})
|
||||
rule(union(hlir_op.clone(), llir_op)).fact(eq(dt, dtype(hlir_op)))
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -741,7 +700,7 @@ impl EgglogOp for KernelMul {
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![kernel_rewrite_from_first_input::<Mul, Self>()]
|
||||
vec![kernel_rewrite::<Mul, Self>()]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -756,45 +715,17 @@ impl EgglogOp for KernelMul {
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let mut out_shape =
|
||||
extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap();
|
||||
let mut a_stride =
|
||||
extract_expr_list(egraph, kind_children[1], list_cache, expr_cache).unwrap();
|
||||
let mut b_stride =
|
||||
extract_expr_list(egraph, kind_children[2], list_cache, expr_cache).unwrap();
|
||||
let mut out_stride =
|
||||
extract_expr_list(egraph, kind_children[3], list_cache, expr_cache).unwrap();
|
||||
// Some e-graph paths (length-changing rewrites such as `merge_dims`
|
||||
// or `RemoveNthFromEnd`) leave a Mul kind enode whose shape and
|
||||
// strides children are extracted to different lengths under the
|
||||
// first-enode walk. The `enforce_consistent_first_kind_enodes`
|
||||
// pass in `src/egglog_utils/mod.rs` repairs this where it can,
|
||||
// but a handful of eclasses have *no* consistent variant in any
|
||||
// of their stride sub-eclasses. For those we truncate to the
|
||||
// SHORTEST length here so `flatten_strides` is structurally
|
||||
// satisfied — the resulting kernel is numerically wrong for that
|
||||
// candidate but harmless for the search, which profiles many
|
||||
// candidates and steers toward the consistent ones.
|
||||
let n = out_shape
|
||||
.len()
|
||||
.min(a_stride.len())
|
||||
.min(b_stride.len())
|
||||
.min(out_stride.len());
|
||||
out_shape.truncate(n);
|
||||
a_stride.truncate(n);
|
||||
b_stride.truncate(n);
|
||||
out_stride.truncate(n);
|
||||
let dtype = input_enodes
|
||||
.first()
|
||||
.and_then(|node| dtype_for_ir_enode(egraph, node))
|
||||
.unwrap_or_else(|| extract_dtype(egraph, kind_children[4]));
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape,
|
||||
a_stride,
|
||||
b_stride,
|
||||
out_stride,
|
||||
dtype,
|
||||
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
a_stride: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
b_stride: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
out_stride: extract_expr_list(egraph, kind_children[3], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[4]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
@@ -934,29 +865,13 @@ impl EgglogOp for KernelGather {
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// Match HLIR Gather (now in Op format) and rewrite to KernelGather.
|
||||
// Mirror the IList pattern used by `Gather`'s own dtype propagation
|
||||
// rule (`src/hlir.rs`): use a `?__tail` variable instead of a
|
||||
// strict `(INil)` so we don't accidentally fail to match against a
|
||||
// Gather Op whose IList tail eclass has been merged with another
|
||||
// chain by some unrelated egglog union. Without this the kernel
|
||||
// rewrite is silently skipped for some Gathers in deep models
|
||||
// (e.g. YOLO's stacked make_contiguous chains).
|
||||
// Match HLIR Gather (now in Op format) and rewrite to KernelGather
|
||||
let hlir_gather = luminal::hlir::Gather::default().sort();
|
||||
let (gather_args, gather_kind_term) = hlir_gather.new_call();
|
||||
// HLIR Gather inputs: [indexes, data] (n_inputs=2)
|
||||
let indexes = v("?__indexes");
|
||||
let data = v("?__data");
|
||||
let tail = v("?__tail");
|
||||
let gather_inputs = Term::App {
|
||||
variant: "ICons".to_string(),
|
||||
args: vec![
|
||||
indexes.clone(),
|
||||
Term::App {
|
||||
variant: "ICons".to_string(),
|
||||
args: vec![data.clone(), tail],
|
||||
},
|
||||
],
|
||||
};
|
||||
let gather_inputs = ilist(vec![indexes.clone(), data.clone()]);
|
||||
let gather_op = op_term(gather_kind_term, gather_inputs);
|
||||
|
||||
let out_strides = SORTS
|
||||
@@ -979,11 +894,7 @@ impl EgglogOp for KernelGather {
|
||||
];
|
||||
let kernel_kind_term = self.sort().call(kernel_kind_args);
|
||||
let kernel_op = op_term(kernel_kind_term, ilist(vec![indexes, data.clone()]));
|
||||
vec![
|
||||
rule(union(gather_op, kernel_op))
|
||||
.fact(eq(dt, dtype(data)))
|
||||
.ruleset("kernel_lower"),
|
||||
]
|
||||
vec![rule(union(gather_op, kernel_op)).fact(eq(dt, dtype(data)))]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -1218,11 +1129,7 @@ impl EgglogOp for KernelScatter {
|
||||
];
|
||||
let kernel_kind_term = self.sort().call(kernel_kind_args);
|
||||
let kernel_op = op_term(kernel_kind_term, ilist(vec![dest, indexes, src.clone()]));
|
||||
vec![
|
||||
rule(union(scatter_op, kernel_op))
|
||||
.fact(eq(dt, dtype(src)))
|
||||
.ruleset("kernel_lower"),
|
||||
]
|
||||
vec![rule(union(scatter_op, kernel_op)).fact(eq(dt, dtype(src)))]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -1293,25 +1200,7 @@ impl KernelOp for KernelScatter {
|
||||
|
||||
// Single-kernel scatter: copy dest→output then scatter src→output[indexes]
|
||||
// Launched as 1 block of 1024 threads with __syncthreads() barrier.
|
||||
// Uses float4 vectorized copy (16 bytes per op) for the copy phase.
|
||||
//
|
||||
// The number of dtype elements that fit in a float4 (16 bytes) depends
|
||||
// on the element size. Computing `n_vec = n_dest / 4` would only be
|
||||
// correct for 4-byte dtypes — for bf16 it walks 2× past the end of
|
||||
// `out`, producing CUDA_ERROR_ILLEGAL_ADDRESS once the OOB region
|
||||
// happens to land on an unmapped page.
|
||||
let elements_per_vec: usize = match self.dtype {
|
||||
DType::F64 => 2,
|
||||
DType::F32 | DType::Int => 4,
|
||||
DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 8,
|
||||
DType::Bool
|
||||
| DType::I8
|
||||
| DType::U8
|
||||
| DType::F8UE8M0
|
||||
| DType::F8E4M3
|
||||
| DType::F8E5M2 => 16,
|
||||
other => panic!("Unsupported dtype for scatter vectorization: {other:?}"),
|
||||
};
|
||||
// Uses float4 vectorized copy (4x throughput) for the copy phase.
|
||||
let n_src_elements = self
|
||||
.index_shape
|
||||
.iter()
|
||||
@@ -1336,17 +1225,15 @@ extern \"C\" {{
|
||||
int tid = threadIdx.x;
|
||||
long long n_dest = {n_dest_elements};
|
||||
long long n_src = {n_src_elements};
|
||||
// Phase 1: vectorized copy dest → output (float4 = 16 bytes / iter,
|
||||
// i.e. {elements_per_vec} {dtype} elements). n_vec is sized so the
|
||||
// total bytes covered (`n_vec * 16`) never exceed `n_dest * sizeof({dtype})`.
|
||||
long long n_vec = n_dest / {elements_per_vec};
|
||||
// Phase 1: vectorized copy dest → output (float4 = 4 elements per op)
|
||||
long long n_vec = n_dest / 4;
|
||||
float4 *out4 = (float4 *)out;
|
||||
const float4 *dest4 = (const float4 *)dest;
|
||||
for (long long i = tid; i < n_vec; i += blockDim.x) {{
|
||||
out4[i] = dest4[i];
|
||||
}}
|
||||
// Handle remaining elements (the dtype-tail past the last full float4).
|
||||
long long remainder_start = n_vec * {elements_per_vec};
|
||||
// Handle remaining elements
|
||||
long long remainder_start = n_vec * 4;
|
||||
for (long long i = remainder_start + tid; i < n_dest; i += blockDim.x) {{
|
||||
out[i] = dest[i];
|
||||
}}
|
||||
@@ -1499,8 +1386,7 @@ impl EgglogOp for KernelIota {
|
||||
let kernel_op = op_term(kernel_kind, hlir_inputs);
|
||||
vec![
|
||||
rule(union(hlir_op, kernel_op.clone()))
|
||||
.set(dtype(kernel_op), app(&SORTS.int_dt, vec![]))
|
||||
.ruleset("kernel_lower"),
|
||||
.set(dtype(kernel_op), app(&SORTS.int_dt, vec![])),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -1540,22 +1426,19 @@ impl KernelOp for KernelIota {
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let mut vars = self.expr.dyn_vars().into_iter().collect::<FxHashSet<_>>();
|
||||
vars.extend(self.range.dyn_vars());
|
||||
let vars = self.expr.dyn_vars().into_iter().collect::<FxHashSet<_>>();
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let range = self.range.to_kernel();
|
||||
let kernel = format!(
|
||||
"
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void iota_k(int *C{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {range}) return;
|
||||
C[const_z] = {};
|
||||
}}
|
||||
}}",
|
||||
@@ -1574,8 +1457,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(self.range.ceil_div(256), 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
(self.range, 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -2588,11 +2471,7 @@ impl EgglogOp for KernelLessThan {
|
||||
args.add("dtype", dt.clone());
|
||||
let kernel_kind_term = self.sort().call(&args);
|
||||
let kernel_op = op_term(kernel_kind_term, hlir_inputs);
|
||||
vec![
|
||||
rule(union(hlir_op, kernel_op))
|
||||
.fact(eq(dt, dtype(inp_a)))
|
||||
.ruleset("kernel_lower"),
|
||||
]
|
||||
vec![rule(union(hlir_op, kernel_op)).fact(eq(dt, dtype(inp_a)))]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -2749,8 +2628,7 @@ impl EgglogOp for KernelConstant {
|
||||
let kernel_op = op_term(kernel_kind, hlir_inputs);
|
||||
vec![
|
||||
rule(union(hlir_op, kernel_op.clone()))
|
||||
.set(dtype(kernel_op), app(&SORTS.f32_dt, vec![]))
|
||||
.ruleset("kernel_lower"),
|
||||
.set(dtype(kernel_op), app(&SORTS.f32_dt, vec![])),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -2892,11 +2770,7 @@ impl EgglogOp for KernelCast {
|
||||
cast_args.add("src_dtype", out_dty);
|
||||
let kernel_kind_term = self.sort().call(&cast_args);
|
||||
let kernel_op = op_term(kernel_kind_term, cast_inputs);
|
||||
vec![
|
||||
rule(union(cast_op, kernel_op))
|
||||
.fact(eq(in_dty, dtype(inp)))
|
||||
.ruleset("kernel_lower"),
|
||||
]
|
||||
vec![rule(union(cast_op, kernel_op)).fact(eq(in_dty, dtype(inp)))]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -2938,14 +2812,6 @@ impl KernelOp for KernelCast {
|
||||
) {
|
||||
let out_dtype = cuda_dtype(self.out_dtype);
|
||||
let includes = dtype_includes(&[self.in_dtype, self.out_dtype]);
|
||||
let vars = self.size.dyn_vars().into_iter().collect::<FxHashSet<_>>();
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let size = self.size.to_kernel();
|
||||
|
||||
let kernel = if self.in_dtype.bits() < 8 {
|
||||
// Sub-byte packed types: multiple values packed per byte.
|
||||
@@ -2955,11 +2821,9 @@ impl KernelOp for KernelCast {
|
||||
let mask = (1u32 << bits) - 1;
|
||||
format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void cast_k({out_dtype} *out, const unsigned char *in_raw{dyn_dims_param}) {{
|
||||
__global__ void cast_k({out_dtype} *out, const unsigned char *in_raw) {{
|
||||
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= {size}) return;
|
||||
long long bit_offset = idx * {bits};
|
||||
long long byte_idx = bit_offset >> 3;
|
||||
int bit_pos = (int)(bit_offset & 7);
|
||||
@@ -2975,11 +2839,9 @@ extern \"C\" {{
|
||||
let in_dtype = cuda_dtype(self.in_dtype);
|
||||
format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void cast_k({out_dtype} *out, const {in_dtype} *in{dyn_dims_param}) {{
|
||||
__global__ void cast_k({out_dtype} *out, const {in_dtype} *in) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {size}) return;
|
||||
out[const_z] = ({out_dtype})in[const_z];
|
||||
}}
|
||||
}}"
|
||||
@@ -2998,8 +2860,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(self.size.ceil_div(256), 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
(self.size, 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -3162,7 +3024,6 @@ impl EgglogOp for KernelEmbed {
|
||||
(union ?gather ?ke)
|
||||
(set (dtype ?ke) (F32))
|
||||
)
|
||||
:ruleset kernel_specialize
|
||||
:name \"kernel embed with cast mul\"
|
||||
)"),
|
||||
// Match Gather with Add(Iota, Mul(Cast(token_ids), const)) indices (reversed order)
|
||||
@@ -3182,7 +3043,6 @@ impl EgglogOp for KernelEmbed {
|
||||
(union ?gather ?ke)
|
||||
(set (dtype ?ke) (F32))
|
||||
)
|
||||
:ruleset kernel_specialize
|
||||
:name \"kernel embed with cast mul reversed\"
|
||||
)"),
|
||||
// Match Gather with Add(Mul(token_ids, const), Iota) indices (no Cast)
|
||||
@@ -3201,7 +3061,6 @@ impl EgglogOp for KernelEmbed {
|
||||
(union ?gather ?ke)
|
||||
(set (dtype ?ke) (F32))
|
||||
)
|
||||
:ruleset kernel_specialize
|
||||
:name \"kernel embed with mul\"
|
||||
)"),
|
||||
// Match Gather with Add(Iota, Mul(token_ids, const)) indices (reversed order, no Cast)
|
||||
@@ -3220,7 +3079,6 @@ impl EgglogOp for KernelEmbed {
|
||||
(union ?gather ?ke)
|
||||
(set (dtype ?ke) (F32))
|
||||
)
|
||||
:ruleset kernel_specialize
|
||||
:name \"kernel embed with mul reversed\"
|
||||
)"),
|
||||
]
|
||||
@@ -3281,24 +3139,15 @@ impl KernelOp for KernelEmbed {
|
||||
.chain(self.out_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.embed_dim.dyn_vars())
|
||||
.collect::<FxHashSet<_>>();
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let token_offset_expr = flatten_strides(&self.batch_shape, &self.token_stride).to_kernel();
|
||||
let out_offset_expr = flatten_strides(&self.batch_shape, &self.out_stride).to_kernel();
|
||||
let embed_dim_expr = self.embed_dim.to_kernel();
|
||||
let total_threads = batch_size * self.embed_dim;
|
||||
let n_elements = total_threads.to_kernel();
|
||||
let kernel = format!(
|
||||
"
|
||||
{dyn_defines}
|
||||
{}
|
||||
extern \"C\" {{
|
||||
__global__ void embed(float *out, const int *token_ids, const float *embed_table{dyn_dims_param}) {{
|
||||
__global__ void embed(float *out, const int *token_ids, const float *embed_table) {{
|
||||
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= {n_elements}) return;
|
||||
long long embed_dim = {embed_dim_expr};
|
||||
long long batch_idx = idx / embed_dim;
|
||||
long long embed_idx = idx % embed_dim;
|
||||
@@ -3308,7 +3157,10 @@ extern \"C\" {{
|
||||
int token_id = token_ids[token_offset];
|
||||
out[out_offset + embed_idx] = embed_table[(long long)token_id * embed_dim + embed_idx];
|
||||
}}
|
||||
}}"
|
||||
}}",
|
||||
vars.iter()
|
||||
.map(|i| format!("__constant__ int const_{i}[1];"))
|
||||
.join("\n"),
|
||||
);
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
@@ -3319,14 +3171,17 @@ extern \"C\" {{
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
// Return empty constants map - we now use shared dyn_dims buffer
|
||||
let constants = FxHashMap::default();
|
||||
let constants = vars
|
||||
.into_iter()
|
||||
.map(|d| (d, module.get_global(&format!("const_{d}"), stream).unwrap()))
|
||||
.collect();
|
||||
let total_threads = batch_size * self.embed_dim;
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(total_threads.ceil_div(256), 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
(total_threads, 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
constants,
|
||||
)
|
||||
|
||||
@@ -10,13 +10,12 @@ use luminal_tracing::schema::{
|
||||
use uuid::Uuid;
|
||||
|
||||
pub mod cuda_graph;
|
||||
pub mod fusion;
|
||||
pub mod hlir;
|
||||
pub mod other_ops;
|
||||
|
||||
pub use cuda_graph::*;
|
||||
|
||||
pub type Ops = (hlir::Ops, other_ops::Ops, fusion::Ops);
|
||||
pub type Ops = (hlir::Ops, other_ops::Ops);
|
||||
|
||||
/// Build a mapping from interned string IDs to their string values for a given sequence.
|
||||
fn build_interned_strings(trace: &schema::Trace) -> std::collections::HashMap<(u32, u64), String> {
|
||||
|
||||
@@ -25,6 +25,7 @@ pub type Ops = (
|
||||
KernelSoftmax,
|
||||
KernelExp,
|
||||
KernelSigmoid,
|
||||
KernelFusedElementwise,
|
||||
);
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -128,8 +129,7 @@ impl KernelOp for KernelMeanReduce {
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let n_outputs: Expression = self.out_shape.iter().copied().product();
|
||||
let threads_per_block: usize = 256; // 8 warps per block
|
||||
let n_warps = threads_per_block / 32;
|
||||
let threads_per_block = 256; // 8 warps per block
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
@@ -150,24 +150,12 @@ extern \"C\" {{
|
||||
long long iters = {iters};
|
||||
long long iter_stride = {iter_stride};
|
||||
|
||||
float thread_sum = 0.0f;
|
||||
for (long long i = threadIdx.x; i < iters; i += {threads_per_block})
|
||||
thread_sum += (float)in[in_start + i * iter_stride];
|
||||
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
thread_sum += __shfl_down_sync(0xffffffff, thread_sum, offset);
|
||||
|
||||
__shared__ float warp_sums[{n_warps}];
|
||||
int lane = threadIdx.x & 31;
|
||||
int warp = threadIdx.x >> 5;
|
||||
if (lane == 0) warp_sums[warp] = thread_sum;
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {{
|
||||
float sum = 0.0f;
|
||||
for (int w = 0; w < {n_warps}; w++) sum += warp_sums[w];
|
||||
out[{out_index}] = ({dtype})(sum / (float)iters);
|
||||
{dtype} sum = 0;
|
||||
for (long long i = 0; i < iters; i++) {{
|
||||
sum += in[in_start + i * iter_stride];
|
||||
}}
|
||||
|
||||
out[{out_index}] = ({dtype})(sum / ({dtype})iters);
|
||||
}}
|
||||
}}",
|
||||
dtype = dtype,
|
||||
@@ -180,8 +168,6 @@ extern \"C\" {{
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel(),
|
||||
threads_per_block = threads_per_block,
|
||||
n_warps = n_warps,
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
@@ -198,9 +184,9 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(n_outputs, 1.into(), 1.into()), // grid
|
||||
(threads_per_block.into(), 1.into(), 1.into()), // block
|
||||
0.into(), // shmem size
|
||||
(n_outputs, 1.into(), 1.into()), // grid
|
||||
(1.into(), 1.into(), 1.into()), // blocks (single-threaded)
|
||||
0.into(), // shmem size
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
@@ -294,9 +280,6 @@ impl EgglogOp for KernelScatterNoCopy {
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// Match KernelScatter and rewrite to KernelScatterNoCopy with ConsumedBuffer on dest.
|
||||
// ConsumedBuffer wraps dest to signal in-place modification.
|
||||
// This is only valid when the destination buffer can also represent
|
||||
// the scatter output layout. If dest is a strided/broadcast view,
|
||||
// regular Scatter must first materialize a contiguous output copy.
|
||||
//
|
||||
// Two-phase resolution:
|
||||
// 1. During (run): cleanup rules delete ConsumedBuffer if dest is shared (another op uses it)
|
||||
@@ -307,31 +290,12 @@ impl EgglogOp for KernelScatterNoCopy {
|
||||
// If ConsumedBuffer was deleted (shared case), cascade cleanup removes the dependent
|
||||
// ICons and KernelScatterNoCopy Op, leaving only KernelScatter.
|
||||
let mut rules = vec![
|
||||
Rule::raw("(relation consumed_buffer_ilist_contains (IList IR))"),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
((= ?list (ICons ?head ?tail)))
|
||||
((consumed_buffer_ilist_contains ?list ?head))
|
||||
:ruleset cleanup
|
||||
:name \"consumed-buffer-ilist-contains-head\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
((= ?list (ICons ?head ?tail))
|
||||
(consumed_buffer_ilist_contains ?tail ?item))
|
||||
((consumed_buffer_ilist_contains ?list ?item))
|
||||
:ruleset cleanup
|
||||
:name \"consumed-buffer-ilist-contains-tail\"
|
||||
)",
|
||||
),
|
||||
// Rewrite: KernelScatter -> KernelScatterNoCopy with ConsumedBuffer
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?scatter (Op (KernelScatter ?ds ?dst ?is ?istr ?ss ?os ?dt)
|
||||
(ICons ?dest (ICons ?indexes (ICons ?src (INil))))))
|
||||
(= ?dst ?os)
|
||||
(= ?dty (dtype ?src))
|
||||
)
|
||||
(
|
||||
@@ -341,7 +305,6 @@ impl EgglogOp for KernelScatterNoCopy {
|
||||
(union ?scatter ?nocopy)
|
||||
(set (dtype ?nocopy) ?dty)
|
||||
)
|
||||
:ruleset buffer_reuse
|
||||
:name \"scatter to scatter-no-copy\"
|
||||
)",
|
||||
),
|
||||
@@ -351,7 +314,6 @@ impl EgglogOp for KernelScatterNoCopy {
|
||||
((= ?cb (ConsumedBuffer ?a))
|
||||
(= ?dt (dtype ?a)))
|
||||
((set (dtype ?cb) ?dt))
|
||||
:ruleset dtype_prop
|
||||
:name \"consumed-buffer-dtype\"
|
||||
)",
|
||||
),
|
||||
@@ -361,28 +323,13 @@ impl EgglogOp for KernelScatterNoCopy {
|
||||
"(rule
|
||||
((= ?cb (ConsumedBuffer ?a))
|
||||
(= ?op1 (Op ?k1 ?ilist1))
|
||||
(consumed_buffer_ilist_contains ?ilist1 ?cb)
|
||||
(= ?ilist1 (ICons ?cb ?rest1))
|
||||
(= ?op2 (Op ?k2 ?ilist2))
|
||||
(!= ?op1 ?op2)
|
||||
(consumed_buffer_ilist_contains ?ilist2 ?a))
|
||||
(= ?ilist2 (ICons ?a ?t2)))
|
||||
((delete (ConsumedBuffer ?a)))
|
||||
:ruleset cleanup
|
||||
:name \"consumed-buffer-cleanup-shared-op-use\"
|
||||
)",
|
||||
));
|
||||
// If a valid no-copy scatter survives cleanup, it dominates the copying scatter.
|
||||
// This must run before base_cleanup resolves ConsumedBuffer back to the destination.
|
||||
rules.push(Rule::raw(
|
||||
"(rule
|
||||
((= ?cb (ConsumedBuffer ?dest))
|
||||
(= ?scatter (Op (KernelScatter ?ds ?dst ?is ?istr ?ss ?os ?dt)
|
||||
(ICons ?dest (ICons ?indexes (ICons ?src (INil))))))
|
||||
(= ?nocopy (Op (KernelScatterNoCopy ?ds ?dst ?is ?istr ?ss ?os ?dt)
|
||||
(ICons ?cb (ICons ?indexes (ICons ?src (INil)))))))
|
||||
((delete (Op (KernelScatter ?ds ?dst ?is ?istr ?ss ?os ?dt)
|
||||
(ICons ?dest (ICons ?indexes (ICons ?src (INil)))))))
|
||||
:ruleset post_cleanup
|
||||
:name \"scatter-no-copy-dominates-valid-consumed-buffer\"
|
||||
:name \"consumed-buffer-cleanup-pos\"
|
||||
)",
|
||||
));
|
||||
// Surviving ConsumedBuffers are valid — union with source and delete.
|
||||
@@ -509,8 +456,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
scatter_kernel,
|
||||
(n_src.ceil_div(256), 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
(n_src, 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -713,7 +660,6 @@ impl EgglogOp for KernelBatchMatVec {
|
||||
(union ?sum ?bmv)
|
||||
(set (dtype ?bmv) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name \"batch mat-vec\"
|
||||
)"
|
||||
)]
|
||||
@@ -994,7 +940,6 @@ impl EgglogOp for KernelBatchMatMul {
|
||||
(union ?sum ?bmm)
|
||||
(set (dtype ?bmm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name \"batch matmul\"
|
||||
)"
|
||||
)]
|
||||
@@ -1234,7 +1179,6 @@ impl EgglogOp for KernelSoftmax {
|
||||
(union ?sm ?ksm)
|
||||
(set (dtype ?ksm) (F32))
|
||||
)
|
||||
:ruleset kernel_lower
|
||||
:name \"softmax-to-kernel-f32\"
|
||||
)",
|
||||
),
|
||||
@@ -1507,7 +1451,6 @@ impl EgglogOp for KernelExp {
|
||||
(union ?exp2 ?kexp)
|
||||
(set (dtype ?kexp) ?dt)
|
||||
)
|
||||
:ruleset direct_kernel
|
||||
:name \"direct-exp-fusion\"
|
||||
)",
|
||||
),
|
||||
@@ -1669,17 +1612,9 @@ impl EgglogOp for KernelSigmoid {
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
// Stage the HLIR sigmoid pattern through a small marker so repeated
|
||||
// default passes do not re-run one large join over every Mul/Add/Recip.
|
||||
// Match the HLIR pattern directly: Recip(Add(Exp2(Mul(Mul(x, -1), log2e)), 1))
|
||||
Rule::raw(
|
||||
"(datatype*
|
||||
(KernelSigmoidScaledState
|
||||
(MkKernelSigmoidScaledState IR EList EList DType)
|
||||
)
|
||||
)
|
||||
(function kernel_sigmoid_scaled (IR) KernelSigmoidScaledState :merge new)
|
||||
|
||||
(rule
|
||||
"(rule
|
||||
(
|
||||
(= ?neg1 (Op (Constant ?nv) (INil)))
|
||||
(< ?nv -0.99)
|
||||
@@ -1689,33 +1624,19 @@ impl EgglogOp for KernelSigmoid {
|
||||
(> ?lv 1.44)
|
||||
(< ?lv 1.45)
|
||||
(= ?scaled (Op (Mul ?shape ?neg_out_stride ?log2e_stride ?scaled_stride) (ICons ?neg_x (ICons ?log2e (INil)))))
|
||||
(= ?dt (dtype ?x))
|
||||
)
|
||||
(
|
||||
(set (kernel_sigmoid_scaled ?scaled)
|
||||
(MkKernelSigmoidScaledState ?x ?shape ?x_stride ?dt))
|
||||
)
|
||||
:ruleset direct_kernel
|
||||
:name \"direct-sigmoid-scaled-marker\"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?scaled_state (kernel_sigmoid_scaled ?scaled))
|
||||
(= ?scaled_state (MkKernelSigmoidScaledState ?x ?shape ?x_stride ?dt))
|
||||
(= ?exp2 (Op (Exp2 ?shape ?scaled_stride ?exp_stride) (ICons ?scaled (INil))))
|
||||
(= ?one (Op (Constant ?ov) (INil)))
|
||||
(> ?ov 0.99)
|
||||
(< ?ov 1.01)
|
||||
(= ?plus_one (Op (Add ?shape ?exp_stride ?one_stride ?add_stride) (ICons ?exp2 (ICons ?one (INil)))))
|
||||
(= ?sig_out (Op (Recip ?shape ?add_stride ?out_stride) (ICons ?plus_one (INil))))
|
||||
(= ?dt (dtype ?x))
|
||||
)
|
||||
(
|
||||
(let ?ksig (Op (KernelSigmoid ?shape ?x_stride ?out_stride ?dt) (ICons ?x (INil))))
|
||||
(union ?sig_out ?ksig)
|
||||
(set (dtype ?ksig) ?dt)
|
||||
)
|
||||
:ruleset direct_kernel
|
||||
:name \"direct-sigmoid-fusion\"
|
||||
)",
|
||||
),
|
||||
@@ -1846,3 +1767,283 @@ extern \"C\" {{
|
||||
"Sigmoid"
|
||||
}
|
||||
}
|
||||
|
||||
/// A unary math function that can appear inside a fused elementwise kernel.
|
||||
/// Each variant has a stable string name (used both as the egglog token in
|
||||
/// the rule-generated ops string and as the `kernel_name()` of the source
|
||||
/// unary kernel op).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum UnaryFn {
|
||||
Sin,
|
||||
Sqrt,
|
||||
Exp2,
|
||||
Log2,
|
||||
Recip,
|
||||
}
|
||||
|
||||
impl UnaryFn {
|
||||
pub fn name(self) -> &'static str {
|
||||
match self {
|
||||
UnaryFn::Sin => "Sin",
|
||||
UnaryFn::Sqrt => "Sqrt",
|
||||
UnaryFn::Exp2 => "Exp2",
|
||||
UnaryFn::Log2 => "Log2",
|
||||
UnaryFn::Recip => "Recip",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_name(name: &str) -> Self {
|
||||
match name {
|
||||
"Sin" => UnaryFn::Sin,
|
||||
"Sqrt" => UnaryFn::Sqrt,
|
||||
"Exp2" => UnaryFn::Exp2,
|
||||
"Log2" => UnaryFn::Log2,
|
||||
"Recip" => UnaryFn::Recip,
|
||||
_ => panic!("invalid UnaryFn name: {name}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An LLIR-only op created by fusing a chain of unary elementwise kernels.
|
||||
/// Only fires when every op in the chain shares the same stride pattern,
|
||||
/// so reads and writes use a single `strides` field.
|
||||
///
|
||||
/// The `ops` sequence is carried as a comma-separated egglog `String`
|
||||
/// (e.g. `"Sin,Sqrt,Exp2"`) — it's pure codegen metadata that egglog never
|
||||
/// reasons about, and `String` is a primitive sort, so this avoids
|
||||
/// introducing a new datatype/sort just to carry the list.
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelFusedElementwise {
|
||||
shape: Vec<Expression>,
|
||||
strides: Vec<Expression>,
|
||||
ops: Vec<UnaryFn>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl KernelFusedElementwise {
|
||||
pub fn ops(&self) -> &[UnaryFn] {
|
||||
&self.ops
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelFusedElementwise {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelFusedElementwise",
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("strides", ELIST),
|
||||
("ops", STRING),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
let unaries = [
|
||||
("KernelSin", UnaryFn::Sin),
|
||||
("KernelSqrt", UnaryFn::Sqrt),
|
||||
("KernelExp2", UnaryFn::Exp2),
|
||||
("KernelLog2", UnaryFn::Log2),
|
||||
("KernelRecip", UnaryFn::Recip),
|
||||
];
|
||||
let mut rules = Vec::with_capacity(unaries.len() * unaries.len() + unaries.len());
|
||||
|
||||
// Pair fusion: two adjacent pure-elementwise unaries -> Fused[a, b].
|
||||
for (a_name, a_fn) in unaries {
|
||||
for (b_name, b_fn) in unaries {
|
||||
let (a_str, b_str) = (a_fn.name(), b_fn.name());
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule
|
||||
(
|
||||
(= ?a (Op ({a_name} ?shape ?strides ?strides ?dt) (ICons ?inp (INil))))
|
||||
(= ?b (Op ({b_name} ?shape ?strides ?strides ?dt) (ICons ?a (INil))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (KernelFusedElementwise ?shape ?strides
|
||||
\"{a_str},{b_str}\" ?dt)
|
||||
(ICons ?inp (INil))))
|
||||
(union ?b ?fused)
|
||||
)
|
||||
:name \"fuse-{a_name}-{b_name}\"
|
||||
)"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// Chain extend: Fused[ops] -> unary -> Fused[ops + \",<new>\"]. One
|
||||
// rule per outer unary. `+` is the builtin variadic string concat,
|
||||
// so this is O(1) per firing and handles chains of any length
|
||||
// without recursion.
|
||||
for (b_name, b_fn) in unaries {
|
||||
let b_str = b_fn.name();
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule
|
||||
(
|
||||
(= ?fused (Op (KernelFusedElementwise ?shape ?strides ?ops ?dt)
|
||||
(ICons ?inp (INil))))
|
||||
(= ?next (Op ({b_name} ?shape ?strides ?strides ?dt)
|
||||
(ICons ?fused (INil))))
|
||||
)
|
||||
(
|
||||
(let ?new_ops (+ ?ops \",{b_str}\"))
|
||||
(let ?new_fused (Op (KernelFusedElementwise ?shape ?strides ?new_ops ?dt)
|
||||
(ICons ?inp (INil))))
|
||||
(union ?next ?new_fused)
|
||||
)
|
||||
:name \"extend-Fused-{b_name}\"
|
||||
)"
|
||||
)));
|
||||
}
|
||||
|
||||
rules
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
// The `ops` field is a String enode; its label is the quoted
|
||||
// literal (e.g. `"Sin,Sqrt"`), so strip the quotes and split.
|
||||
let ops_str = egraph.enodes[kind_children[2]].0.replace('"', "");
|
||||
let ops = if ops_str.is_empty() {
|
||||
Vec::new()
|
||||
} else {
|
||||
ops_str.split(',').map(UnaryFn::from_name).collect()
|
||||
};
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap(),
|
||||
strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
ops,
|
||||
dtype: extract_dtype(egraph, kind_children[3]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelFusedElementwise {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars = self
|
||||
.shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect::<FxHashSet<_>>();
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let n_elements = self
|
||||
.shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.to_kernel();
|
||||
let idx = flatten_strides(&self.shape, &self.strides).to_kernel();
|
||||
let ops_body = self
|
||||
.ops
|
||||
.iter()
|
||||
.map(|op| match op {
|
||||
UnaryFn::Sin => "val = sinf(val);",
|
||||
UnaryFn::Sqrt => "val = sqrtf(val);",
|
||||
UnaryFn::Exp2 => "val = exp2f(val);",
|
||||
UnaryFn::Log2 => "val = log2f(val);",
|
||||
UnaryFn::Recip => "val = 1.0f / val;",
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n ");
|
||||
let kernel = format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void fused_elementwise_k({dtype} *out, const {dtype} *in{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {n_elements}) return;
|
||||
long long idx = {idx};
|
||||
{dtype} val = in[idx];
|
||||
{ops_body}
|
||||
out[idx] = val;
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("fused_elementwise_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let out_size = self.shape.iter().copied().product::<Expression>();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.output_size() * (self.ops.len() as i32)
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"FusedElementwise"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ use itertools::Itertools;
|
||||
use luminal::{
|
||||
egglog_utils::{api::Rule, base::OP_KIND},
|
||||
graph::LLIRGraph,
|
||||
hlir::{LoopEnd, LoopInput, LoopInputStatic, LoopOutput, LoopOutputSelect, LoopStart},
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::{
|
||||
petgraph::{Direction, algo::toposort, visit::EdgeRef},
|
||||
@@ -23,11 +22,10 @@ use luminal::{
|
||||
use tracing::{Level, enabled, span};
|
||||
|
||||
use crate::{
|
||||
host::{DeviceBuffer, HostOp},
|
||||
host::HostOp,
|
||||
kernel::{
|
||||
CudaFunctionExt, CudaGraphExecHandle, CudaGraphHandle, KernelOp, create_cuda_event,
|
||||
destroy_cuda_event,
|
||||
fusion::region_codegen::{self, CompileUnit},
|
||||
hlir::{clear_global_dyn_dims, get_global_dyn_dims, set_global_dyn_dims},
|
||||
},
|
||||
runtime::partition_marked_convex,
|
||||
@@ -48,12 +46,8 @@ struct CompiledKernel {
|
||||
shared_mem: Expression,
|
||||
/// Input node indices (for buffer lookup)
|
||||
inputs: Vec<NodeIndex>,
|
||||
/// Human-readable labels for input nodes, for launch diagnostics.
|
||||
input_labels: Vec<String>,
|
||||
/// Reference to the KernelOp for trait methods
|
||||
kernel_op: Arc<Box<dyn KernelOp>>,
|
||||
/// Whether this compiled CUDA function has a trailing dyn_dims parameter.
|
||||
has_dyn_dims_param: bool,
|
||||
/// Internal buffers allocated for this kernel
|
||||
internal_bufs: Vec<CudaSlice<u8>>,
|
||||
/// Device constants from compile()
|
||||
@@ -73,9 +67,7 @@ impl CompiledKernel {
|
||||
block: (Expression, Expression, Expression),
|
||||
shared_mem: Expression,
|
||||
inputs: Vec<NodeIndex>,
|
||||
input_labels: Vec<String>,
|
||||
kernel_op: Arc<Box<dyn KernelOp>>,
|
||||
has_dyn_dims_param: bool,
|
||||
constants: FxHashMap<char, CudaSlice<u8>>,
|
||||
kernel_name: &'static str,
|
||||
) -> Self {
|
||||
@@ -86,9 +78,7 @@ impl CompiledKernel {
|
||||
block,
|
||||
shared_mem,
|
||||
inputs,
|
||||
input_labels,
|
||||
kernel_op,
|
||||
has_dyn_dims_param,
|
||||
internal_bufs: Vec::new(),
|
||||
constants,
|
||||
graph_node: None,
|
||||
@@ -235,7 +225,7 @@ impl HostOp for CudaGraphOp {
|
||||
stream: &Arc<CudaStream>,
|
||||
_self_node: NodeIndex,
|
||||
_inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.execute_internal(stream, buffers, dyn_map)
|
||||
@@ -267,40 +257,6 @@ impl HostOp for CudaGraphOp {
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn extra_buffer_lifetimes(&self) -> Option<Vec<(NodeIndex, usize, usize)>> {
|
||||
let state = self.state.borrow();
|
||||
let mut lifetimes: FxHashMap<NodeIndex, (usize, usize)> = FxHashMap::default();
|
||||
let max_step = state.kernels.len().saturating_sub(1);
|
||||
|
||||
let mut touch = |node: NodeIndex, step: usize| {
|
||||
lifetimes
|
||||
.entry(node)
|
||||
.and_modify(|(first, last)| {
|
||||
*first = (*first).min(step);
|
||||
*last = (*last).max(step);
|
||||
})
|
||||
.or_insert((step, step));
|
||||
};
|
||||
|
||||
for (step, kernel) in state.kernels.iter().enumerate() {
|
||||
for &input in &kernel.inputs {
|
||||
touch(input, step);
|
||||
}
|
||||
touch(kernel.node, step);
|
||||
}
|
||||
|
||||
for node in self.extra_buffer_nodes() {
|
||||
lifetimes.entry(node).or_insert((0, max_step));
|
||||
}
|
||||
|
||||
Some(
|
||||
lifetimes
|
||||
.into_iter()
|
||||
.map(|(node, (start, end))| (node, start, end))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn extra_buffer_sizes(&self) -> FxHashMap<NodeIndex, Expression> {
|
||||
self.buffer_sizes.clone()
|
||||
}
|
||||
@@ -311,64 +267,11 @@ impl HostOp for CudaGraphOp {
|
||||
}
|
||||
|
||||
impl CudaGraphOp {
|
||||
fn expected_kernel_inputs(kernel_name: &str) -> Option<usize> {
|
||||
match kernel_name {
|
||||
"Constant" | "Iota" => Some(0),
|
||||
"MaxReduce" | "MeanReduce" | "SumReduce" | "Cast" | "Exp" | "Exp2" | "Log2" | "Sin"
|
||||
| "Recip" | "Sigmoid" | "Softmax" | "Sqrt" => Some(1),
|
||||
"Add" | "BatchMatMul" | "BatchMatVec" | "Embed" | "Gather" | "LessThan" | "Mod"
|
||||
| "Mul" => Some(2),
|
||||
"Scatter" | "ScatterNoCopy" => Some(3),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn kernel_requires_output_buffer(
|
||||
kernel: &CompiledKernel,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> bool {
|
||||
kernel.kernel_op.output_size().exec(dyn_map).unwrap_or(1) != 0
|
||||
&& kernel.kernel_op.output_aliases_input().is_none()
|
||||
}
|
||||
|
||||
fn validate_kernel_pointers(
|
||||
kernel: &CompiledKernel,
|
||||
output_ptr: u64,
|
||||
input_ptrs: &[u64],
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
if Self::kernel_requires_output_buffer(kernel, dyn_map) && output_ptr == 0 {
|
||||
anyhow::bail!(
|
||||
"missing output buffer for CUDA kernel {} at LLIR node {:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
|
||||
for (idx, (input_node, input_ptr)) in kernel.inputs.iter().zip(input_ptrs).enumerate() {
|
||||
if *input_ptr == 0 {
|
||||
let input_label = kernel
|
||||
.input_labels
|
||||
.get(idx)
|
||||
.map(String::as_str)
|
||||
.unwrap_or("unknown");
|
||||
anyhow::bail!(
|
||||
"missing input buffer {idx} for CUDA kernel {} at LLIR node {:?}; input LLIR node {:?} ({input_label})",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
input_node,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Execute the CUDA graph with the given buffers and dynamic dimensions.
|
||||
fn execute_internal(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut state = self.state.borrow_mut();
|
||||
@@ -439,7 +342,7 @@ impl CudaGraphOp {
|
||||
let mut current_buffer_ptrs: FxHashMap<NodeIndex, u64> = FxHashMap::default();
|
||||
for &node in &self.buffer_nodes {
|
||||
if let Some(buf) = buffers.get(&node) {
|
||||
current_buffer_ptrs.insert(node, buf.ptr());
|
||||
current_buffer_ptrs.insert(node, buf.device_ptr(stream).0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -487,26 +390,13 @@ impl CudaGraphOp {
|
||||
.iter()
|
||||
.map(|inp| current_buffer_ptrs.get(inp).copied().unwrap_or(0))
|
||||
.collect();
|
||||
Self::validate_kernel_pointers(kernel, output_ptr, &input_ptrs, dyn_map)?;
|
||||
let kernel_dyn_dims_ptr = if kernel.has_dyn_dims_param {
|
||||
dyn_dims_ptr
|
||||
} else {
|
||||
0
|
||||
};
|
||||
if kernel.has_dyn_dims_param && kernel_dyn_dims_ptr == 0 {
|
||||
anyhow::bail!(
|
||||
"missing dyn_dims buffer for CUDA kernel {} at LLIR node {:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
|
||||
let param_values = kernel.kernel_op.build_params(
|
||||
stream,
|
||||
output_ptr,
|
||||
&input_ptrs,
|
||||
&kernel.internal_bufs,
|
||||
kernel_dyn_dims_ptr,
|
||||
dyn_dims_ptr,
|
||||
);
|
||||
state.kernel_params[idx] = UnifiedKernelParams::new(param_values);
|
||||
}
|
||||
@@ -533,19 +423,6 @@ impl CudaGraphOp {
|
||||
kernel.block.1.exec(dyn_map).unwrap() as u32,
|
||||
kernel.block.2.exec(dyn_map).unwrap() as u32,
|
||||
);
|
||||
if grid_dim.0 == 0
|
||||
|| grid_dim.1 == 0
|
||||
|| grid_dim.2 == 0
|
||||
|| block_dim.0 == 0
|
||||
|| block_dim.1 == 0
|
||||
|| block_dim.2 == 0
|
||||
{
|
||||
anyhow::bail!(
|
||||
"invalid CUDA launch dimensions for kernel {} at LLIR node {:?}: grid={grid_dim:?} block={block_dim:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
let shared_mem = kernel.shared_mem.exec(dyn_map).unwrap() as u32;
|
||||
let cu_func = unsafe { kernel.function.raw_function() };
|
||||
|
||||
@@ -574,7 +451,7 @@ impl CudaGraphOp {
|
||||
&self,
|
||||
state: &mut std::cell::RefMut<'_, CudaGraphOpState>,
|
||||
stream: &Arc<CudaStream>,
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
let ctx = stream.context().clone();
|
||||
@@ -596,7 +473,7 @@ impl CudaGraphOp {
|
||||
let mut buffer_ptrs: FxHashMap<NodeIndex, u64> = FxHashMap::default();
|
||||
for &node in &self.buffer_nodes {
|
||||
if let Some(buf) = buffers.get(&node) {
|
||||
buffer_ptrs.insert(node, buf.ptr());
|
||||
buffer_ptrs.insert(node, buf.device_ptr(stream).0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -643,19 +520,6 @@ impl CudaGraphOp {
|
||||
kernel.block.1.exec(dyn_map).unwrap() as u32,
|
||||
kernel.block.2.exec(dyn_map).unwrap() as u32,
|
||||
);
|
||||
if grid_dim.0 == 0
|
||||
|| grid_dim.1 == 0
|
||||
|| grid_dim.2 == 0
|
||||
|| block_dim.0 == 0
|
||||
|| block_dim.1 == 0
|
||||
|| block_dim.2 == 0
|
||||
{
|
||||
anyhow::bail!(
|
||||
"invalid CUDA launch dimensions for kernel {} at LLIR node {:?}: grid={grid_dim:?} block={block_dim:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
let shared_mem = kernel.shared_mem.exec(dyn_map).unwrap() as u32;
|
||||
|
||||
let output_ptr = buffer_ptrs.get(&kernel.node).copied().unwrap_or(0);
|
||||
@@ -664,41 +528,18 @@ impl CudaGraphOp {
|
||||
.iter()
|
||||
.map(|inp| buffer_ptrs.get(inp).copied().unwrap_or(0))
|
||||
.collect();
|
||||
Self::validate_kernel_pointers(kernel, output_ptr, &input_ptrs, dyn_map)?;
|
||||
let kernel_dyn_dims_ptr = if kernel.has_dyn_dims_param {
|
||||
dyn_dims_ptr
|
||||
} else {
|
||||
0
|
||||
};
|
||||
if kernel.has_dyn_dims_param && kernel_dyn_dims_ptr == 0 {
|
||||
anyhow::bail!(
|
||||
"missing dyn_dims buffer for CUDA kernel {} at LLIR node {:?}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
);
|
||||
}
|
||||
|
||||
let param_values = kernel.kernel_op.build_params(
|
||||
stream,
|
||||
output_ptr,
|
||||
&input_ptrs,
|
||||
&kernel.internal_bufs,
|
||||
kernel_dyn_dims_ptr,
|
||||
dyn_dims_ptr,
|
||||
);
|
||||
let mut params = UnifiedKernelParams::new(param_values);
|
||||
|
||||
let cu_func = unsafe { kernel.function.raw_function() };
|
||||
let kernel_node = kernel.node;
|
||||
if std::env::var_os("LUMINAL_CUDA_DEBUG_GRAPH").is_some() {
|
||||
eprintln!(
|
||||
"cuGraphAddKernelNode kernel={} node={:?} grid={grid_dim:?} block={block_dim:?} shared_mem={shared_mem} inputs={} has_dyn={} params={}",
|
||||
kernel.kernel_name,
|
||||
kernel.node,
|
||||
kernel.inputs.len(),
|
||||
kernel.has_dyn_dims_param,
|
||||
params.values.len(),
|
||||
);
|
||||
}
|
||||
|
||||
// Get timing event for this index (separate access from kernels)
|
||||
let timing_event = if tracing_enabled {
|
||||
@@ -814,41 +655,6 @@ pub fn kernel_to_host(
|
||||
}
|
||||
|
||||
let kernel_subgraphs = partition_marked_convex(llir_graph, &kernel_ops_in_graph).unwrap();
|
||||
// Compute the set of FS / FE / FusedX nodes globally absorbed by some
|
||||
// FusionEnd in the LLIR. Used by `build_compile_units` to suppress
|
||||
// standalone marker compile units for shared FS leaves whose consumers
|
||||
// live in a different convex subgraph than the FS itself.
|
||||
let globally_absorbed = region_codegen::globally_absorbed_markers(llir_graph);
|
||||
|
||||
let name_of = |graph: &LLIRGraph, idx: NodeIndex| -> Option<&'static str> {
|
||||
graph
|
||||
.node_weight(idx)
|
||||
.and_then(|op| op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name()))
|
||||
};
|
||||
let is_transparent_input = |graph: &LLIRGraph, node: NodeIndex| -> bool {
|
||||
name_of(graph, node) == Some("FusionStart")
|
||||
|| graph[node].to_op::<LoopStart>().is_some()
|
||||
|| graph[node].to_op::<LoopEnd>().is_some()
|
||||
|| graph[node].to_op::<LoopInput>().is_some()
|
||||
|| graph[node].to_op::<LoopInputStatic>().is_some()
|
||||
|| graph[node].to_op::<LoopOutput>().is_some()
|
||||
|| graph[node].to_op::<LoopOutputSelect>().is_some()
|
||||
};
|
||||
let resolve_transparent_input = |graph: &LLIRGraph, mut node: NodeIndex| -> NodeIndex {
|
||||
let mut visited = FxHashSet::default();
|
||||
while visited.insert(node) && is_transparent_input(graph, node) {
|
||||
let Some(pred) = graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.next()
|
||||
else {
|
||||
break;
|
||||
};
|
||||
node = pred;
|
||||
}
|
||||
node
|
||||
};
|
||||
|
||||
// Track which kernel node belongs to which CudaGraphOp (for later edge creation)
|
||||
let mut kernel_to_cuda_graph: FxHashMap<NodeIndex, NodeIndex> = FxHashMap::default();
|
||||
@@ -866,7 +672,6 @@ pub fn kernel_to_host(
|
||||
let mut all_dyn_dims = FxHashSet::default();
|
||||
let mut all_buffer_nodes = FxHashSet::default();
|
||||
let mut all_buffer_sizes: FxHashMap<NodeIndex, Expression> = FxHashMap::default();
|
||||
let mut external_inputs = FxHashSet::default();
|
||||
|
||||
// Pre-scan: collect all dynamic vars from all kernel ops without compiling.
|
||||
// This uses KernelOp::all_dyn_vars() which inspects struct expression fields.
|
||||
@@ -880,151 +685,49 @@ pub fn kernel_to_host(
|
||||
// Set global dyn dims ordering so compiles use consistent indices
|
||||
let mut global_dyn_dims: Vec<char> = all_dyn_dims.iter().copied().collect();
|
||||
global_dyn_dims.sort();
|
||||
set_global_dyn_dims(global_dyn_dims.clone());
|
||||
if !global_dyn_dims.is_empty() {
|
||||
set_global_dyn_dims(global_dyn_dims.clone());
|
||||
}
|
||||
|
||||
// Group the topo order into compile units: each FusionEnd-rooted
|
||||
// region collapses to a single CompileUnit::Region (one fused
|
||||
// CUDA kernel for the whole DAG); everything else stays as
|
||||
// CompileUnit::Single (the existing per-op compile path).
|
||||
let compile_units =
|
||||
region_codegen::build_compile_units(&topo_order, llir_graph, &globally_absorbed);
|
||||
// Compile all kernels with global ordering for correct dyn_dims indices
|
||||
let mut kernels = Vec::with_capacity(topo_order.len());
|
||||
for kernel_node_idx in &topo_order {
|
||||
let kernel_op_ref = llir_graph[*kernel_node_idx]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.unwrap();
|
||||
|
||||
// Compile all units with global ordering for correct dyn_dims indices
|
||||
let mut kernels = Vec::with_capacity(compile_units.len());
|
||||
for unit in &compile_units {
|
||||
match unit {
|
||||
CompileUnit::Single(kernel_node_idx) => {
|
||||
let kernel_op_ref = llir_graph[*kernel_node_idx]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.unwrap();
|
||||
let (kernel_function, _, _kernel_str, grid, block, shared_mem, constants) =
|
||||
kernel_op_ref.compile(cuda_stream, kernel_cache);
|
||||
|
||||
let (kernel_function, _, kernel_str, grid, block, shared_mem, constants) =
|
||||
kernel_op_ref.compile(cuda_stream, kernel_cache);
|
||||
let has_dyn_dims_param = kernel_str.contains("dyn_dims");
|
||||
// Collect inputs from graph edges
|
||||
let mut inputs: Vec<NodeIndex> = llir_graph
|
||||
.edges_directed(*kernel_node_idx, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect_vec();
|
||||
|
||||
// Collect inputs from graph edges
|
||||
let inputs: Vec<NodeIndex> = llir_graph
|
||||
.edges_directed(*kernel_node_idx, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.map(|input| resolve_transparent_input(llir_graph, input))
|
||||
.collect_vec();
|
||||
if let Some(expected_inputs) =
|
||||
CudaGraphOp::expected_kernel_inputs(kernel_op_ref.kernel_name())
|
||||
{
|
||||
assert_eq!(
|
||||
inputs.len(),
|
||||
expected_inputs,
|
||||
"invalid input arity for CUDA kernel {} at LLIR node {:?}",
|
||||
kernel_op_ref.kernel_name(),
|
||||
kernel_node_idx,
|
||||
);
|
||||
}
|
||||
let input_labels = inputs
|
||||
.iter()
|
||||
.map(|&input| {
|
||||
name_of(llir_graph, input)
|
||||
.map(str::to_string)
|
||||
.unwrap_or_else(|| format!("{:?}", llir_graph[input]))
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
// Collect buffer nodes and sizes
|
||||
// Only add kernel nodes with non-zero output size (MegakernelOps have size 0)
|
||||
let output_size = kernel_op_ref.output_size();
|
||||
if output_size.exec(&FxHashMap::default()).unwrap_or(1) != 0 {
|
||||
all_buffer_nodes.insert(*kernel_node_idx);
|
||||
all_buffer_sizes.insert(*kernel_node_idx, output_size);
|
||||
}
|
||||
all_buffer_nodes.extend(inputs.iter().copied());
|
||||
external_inputs.extend(
|
||||
inputs
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|input| !subgraph.contains(input)),
|
||||
);
|
||||
|
||||
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(kernel_op_ref);
|
||||
|
||||
kernels.push(CompiledKernel::new(
|
||||
*kernel_node_idx,
|
||||
kernel_function,
|
||||
grid,
|
||||
block,
|
||||
shared_mem,
|
||||
inputs,
|
||||
input_labels,
|
||||
kernel_op.clone(),
|
||||
has_dyn_dims_param,
|
||||
constants,
|
||||
kernel_op.kernel_name(),
|
||||
));
|
||||
}
|
||||
CompileUnit::Region(region) => {
|
||||
// Generate one fused CUDA kernel for the whole region.
|
||||
let compiled = region_codegen::compile_region(
|
||||
region,
|
||||
llir_graph,
|
||||
cuda_stream,
|
||||
kernel_cache,
|
||||
);
|
||||
let has_dyn_dims_param = compiled.kernel_str.contains("dyn_dims");
|
||||
|
||||
// The region's CompiledKernel is keyed on the FE node
|
||||
// (so FE provides trait methods like output_size /
|
||||
// build_params) but its `inputs` are the external
|
||||
// producers, not FE's literal LLIR predecessors —
|
||||
// those are interior FusedX nodes that don't exist
|
||||
// as buffer-bearing nodes from the host's view.
|
||||
let fe_op_ref = llir_graph[region.fe_node]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.unwrap();
|
||||
|
||||
let inputs: Vec<NodeIndex> = region
|
||||
.external_inputs
|
||||
.iter()
|
||||
.copied()
|
||||
.map(|input| resolve_transparent_input(llir_graph, input))
|
||||
.collect();
|
||||
let input_labels = inputs
|
||||
.iter()
|
||||
.map(|&input| {
|
||||
name_of(llir_graph, input)
|
||||
.map(str::to_string)
|
||||
.unwrap_or_else(|| format!("{:?}", llir_graph[input]))
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let output_size = fe_op_ref.output_size();
|
||||
if output_size.exec(&FxHashMap::default()).unwrap_or(1) != 0 {
|
||||
all_buffer_nodes.insert(region.fe_node);
|
||||
all_buffer_sizes.insert(region.fe_node, output_size);
|
||||
}
|
||||
all_buffer_nodes.extend(inputs.iter().copied());
|
||||
external_inputs.extend(
|
||||
inputs
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|input| !subgraph.contains(input)),
|
||||
);
|
||||
|
||||
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(fe_op_ref);
|
||||
|
||||
kernels.push(CompiledKernel::new(
|
||||
region.fe_node,
|
||||
compiled.function,
|
||||
compiled.grid,
|
||||
compiled.block,
|
||||
compiled.shared_mem,
|
||||
inputs,
|
||||
input_labels,
|
||||
kernel_op,
|
||||
has_dyn_dims_param,
|
||||
compiled.constants,
|
||||
"FusedRegion",
|
||||
));
|
||||
}
|
||||
// Collect buffer nodes and sizes
|
||||
// Only add kernel nodes with non-zero output size (MegakernelOps have size 0)
|
||||
let output_size = kernel_op_ref.output_size();
|
||||
if output_size.exec(&FxHashMap::default()).unwrap_or(1) != 0 {
|
||||
all_buffer_nodes.insert(*kernel_node_idx);
|
||||
all_buffer_sizes.insert(*kernel_node_idx, output_size);
|
||||
}
|
||||
all_buffer_nodes.extend(inputs.iter().copied());
|
||||
|
||||
let kernel_op: Arc<Box<dyn KernelOp>> = Arc::clone(kernel_op_ref);
|
||||
|
||||
kernels.push(CompiledKernel::new(
|
||||
*kernel_node_idx,
|
||||
kernel_function,
|
||||
grid,
|
||||
block,
|
||||
shared_mem,
|
||||
inputs,
|
||||
kernel_op.clone(),
|
||||
constants,
|
||||
kernel_op.kernel_name(),
|
||||
));
|
||||
}
|
||||
|
||||
// Get the possibly-extended global ordering (kernels may have discovered new dims)
|
||||
@@ -1064,17 +767,16 @@ pub fn kernel_to_host(
|
||||
}
|
||||
cuda_graph_subgraphs.push((cuda_graph_node, subgraph.clone()));
|
||||
|
||||
// Find external inputs: nodes outside subgraph that have edges into
|
||||
// subgraph. Also include normalized FusionStart predecessors, because
|
||||
// the compiled kernels read from the concrete producer buffer rather
|
||||
// than the marker node.
|
||||
external_inputs.extend(subgraph.iter().flat_map(|&node| {
|
||||
llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.map(|e| e.source())
|
||||
.map(|input| resolve_transparent_input(llir_graph, input))
|
||||
.filter(|src| !subgraph.contains(src))
|
||||
}));
|
||||
// Find external inputs: nodes outside subgraph that have edges into subgraph
|
||||
let external_inputs: FxHashSet<NodeIndex> = subgraph
|
||||
.iter()
|
||||
.flat_map(|&node| {
|
||||
llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.map(|e| e.source())
|
||||
.filter(|src| !subgraph.contains(src))
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Add edges from external inputs to CudaGraphOp
|
||||
for input in &external_inputs {
|
||||
@@ -1118,41 +820,22 @@ pub fn kernel_to_host(
|
||||
}
|
||||
}
|
||||
|
||||
// Add each cross-CudaGraphOp dep edge iff it would carry new ordering
|
||||
// information without closing a cycle. The previous topo-position gate
|
||||
// ("skip when src_pos >= dst_pos") was too coarse: it dropped edges
|
||||
// whose src happened to land later in the toposort than their dst even
|
||||
// when no path dst→src actually existed, leaving consumers free to run
|
||||
// before the producer wrote their input buffer (wrong outputs); and it
|
||||
// also added edges that were already implied by an existing src→dst
|
||||
// path (extra serialization, no new info).
|
||||
// Add collected edges (deduplicate), skipping back-edges to preserve DAG property
|
||||
let edges_to_add: FxHashSet<(NodeIndex, NodeIndex)> = edges_to_add.into_iter().collect();
|
||||
use petgraph::algo::has_path_connecting;
|
||||
for (src, dst) in edges_to_add {
|
||||
if has_path_connecting(&*llir_graph, src, dst, None) {
|
||||
continue; // already ordered src→dst by some path; edge redundant
|
||||
}
|
||||
if has_path_connecting(&*llir_graph, dst, src, None) {
|
||||
continue; // adding src→dst would close a cycle
|
||||
}
|
||||
llir_graph.add_edge(src, dst, ());
|
||||
let topo = toposort(&*llir_graph, None).unwrap();
|
||||
let mut topo_pos: FxHashMap<NodeIndex, usize> = FxHashMap::default();
|
||||
for (i, n) in topo.iter().enumerate() {
|
||||
topo_pos.insert(*n, i);
|
||||
}
|
||||
|
||||
// Strip fully-absorbed marker nodes (FusionStart, nested FusionEnd,
|
||||
// FusedX) from the LLIR. Region codegen has already folded them into
|
||||
// a single fused CUDA function anchored at each region's root
|
||||
// FusionEnd; the absorbed nodes have no consumers outside the region
|
||||
// and never need their own buffers. Removing them keeps later
|
||||
// per-execute walks (e.g., `allocate_intermediate_buffers`) from
|
||||
// chewing through dead nodes every decode token.
|
||||
//
|
||||
// Root FusionEnd nodes are NOT in `globally_absorbed` (they were the
|
||||
// walks' starting points), so we keep them — they're the kernel
|
||||
// anchor for the region's compiled kernel.
|
||||
for node in globally_absorbed {
|
||||
// Defensive: only remove if the node still exists.
|
||||
if llir_graph.node_weight(node).is_some() {
|
||||
llir_graph.remove_node(node);
|
||||
for (src, dst) in edges_to_add {
|
||||
// Only add forward edges (src before dst in topo order) to avoid creating cycles
|
||||
let src_pos = topo_pos.get(&src).copied().unwrap_or(usize::MAX);
|
||||
let dst_pos = topo_pos.get(&dst).copied().unwrap_or(usize::MAX);
|
||||
if src_pos >= dst_pos {
|
||||
continue; // Skip back-edges
|
||||
}
|
||||
if !llir_graph.edges_connecting(src, dst).any(|_| true) {
|
||||
llir_graph.add_edge(src, dst, ());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
pub mod dyn_backend;
|
||||
pub mod host;
|
||||
pub mod kernel;
|
||||
mod memory_analysis;
|
||||
pub mod runtime;
|
||||
use std::{
|
||||
ffi::{CStr, CString},
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -41,8 +41,9 @@ fn extract_all_kernel_names(cx: &mut Graph) -> Vec<String> {
|
||||
all_names
|
||||
}
|
||||
|
||||
/// When dest is NOT shared with any other compute op, KernelScatterNoCopy should
|
||||
/// be the only scatter variant left after post-cleanup.
|
||||
/// When dest is NOT shared with any other op, KernelScatterNoCopy should be available.
|
||||
/// The ConsumedBuffer cleanup rule should NOT fire because dest only appears inside
|
||||
/// the ConsumedBuffer (not in any other ICons).
|
||||
#[test]
|
||||
fn test_scatter_nocopy_selected_when_dest_unshared() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
@@ -61,17 +62,12 @@ fn test_scatter_nocopy_selected_when_dest_unshared() {
|
||||
let names = extract_all_kernel_names(&mut cx);
|
||||
println!("All possible kernels: {:?}", names);
|
||||
|
||||
// KernelScatterNoCopy should be the only scatter variant (dest is not shared)
|
||||
// KernelScatterNoCopy should be available (dest is not shared)
|
||||
assert!(
|
||||
names.iter().any(|n| n == "ScatterNoCopy"),
|
||||
"Expected ScatterNoCopy to be available but got: {:?}",
|
||||
names
|
||||
);
|
||||
assert!(
|
||||
!names.iter().any(|n| n == "Scatter"),
|
||||
"Regular Scatter should be pruned when ScatterNoCopy is valid, got: {:?}",
|
||||
names
|
||||
);
|
||||
}
|
||||
|
||||
/// When dest IS shared (used by another op besides the scatter), the ConsumedBuffer
|
||||
@@ -113,74 +109,8 @@ fn test_scatter_nocopy_not_selected_when_dest_shared() {
|
||||
);
|
||||
}
|
||||
|
||||
/// Shared-use detection must catch the destination in non-first input
|
||||
/// positions too. Gather takes indexes first and data second, so this would
|
||||
/// miss the unsafe read if cleanup only inspected the head of the input list.
|
||||
#[test]
|
||||
fn test_scatter_nocopy_not_selected_when_dest_shared_as_later_input() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
|
||||
let dest = cx.tensor(10).persist();
|
||||
let src = cx.tensor(3).persist();
|
||||
let scatter_indexes = cx.tensor(3).as_dtype(DType::Int).persist();
|
||||
let read_indexes = cx.tensor(1).as_dtype(DType::Int).persist();
|
||||
|
||||
let scatter_result = src.scatter(scatter_indexes, dest);
|
||||
let _dest_also_read = dest.gather(read_indexes).output();
|
||||
let _result = scatter_result.output();
|
||||
|
||||
let names = extract_all_kernel_names(&mut cx);
|
||||
println!("All possible kernels: {:?}", names);
|
||||
|
||||
assert!(
|
||||
!names.iter().any(|n| n == "ScatterNoCopy"),
|
||||
"ScatterNoCopy should NOT be available when dest is read by another op, got: {:?}",
|
||||
names
|
||||
);
|
||||
assert!(
|
||||
names.iter().any(|n| n == "Scatter"),
|
||||
"Expected regular Scatter but got: {:?}",
|
||||
names
|
||||
);
|
||||
}
|
||||
|
||||
/// ScatterNoCopy aliases the destination buffer as the output, so it is only
|
||||
/// valid when the destination layout already matches the contiguous scatter
|
||||
/// output layout. Broadcast/expanded destinations need regular Scatter's
|
||||
/// copy-then-scatter materialization.
|
||||
#[test]
|
||||
fn test_scatter_nocopy_not_selected_for_expanded_dest_layout() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
|
||||
let dest = cx.tensor(128).expand_dim(0, 4).persist();
|
||||
let src = cx.tensor((4, 128)).persist();
|
||||
let indexes = cx.tensor((4, 128)).as_dtype(DType::Int).persist();
|
||||
|
||||
let _result = src.scatter(indexes, dest).output();
|
||||
|
||||
let names = extract_all_kernel_names(&mut cx);
|
||||
println!("All possible kernels: {:?}", names);
|
||||
|
||||
assert!(
|
||||
!names.iter().any(|n| n == "ScatterNoCopy"),
|
||||
"ScatterNoCopy should NOT be available when dest layout differs from output, got: {:?}",
|
||||
names
|
||||
);
|
||||
assert!(
|
||||
names.iter().any(|n| n == "Scatter"),
|
||||
"Expected regular Scatter but got: {:?}",
|
||||
names
|
||||
);
|
||||
}
|
||||
|
||||
/// Actually execute the scatter and verify correctness.
|
||||
/// Post-cleanup should force the valid no-copy extraction.
|
||||
/// Tests all possible extractions (both KernelScatter and KernelScatterNoCopy).
|
||||
#[test]
|
||||
fn test_scatter_execution_correctness() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
@@ -205,8 +135,9 @@ fn test_scatter_execution_correctness() {
|
||||
// Expected: [0.0, 10.0, 2.0, 20.0, 30.0]
|
||||
let expected = vec![0.0f32, 10.0, 2.0, 20.0, 30.0];
|
||||
|
||||
// Try many random extractions; each valid choice should now use ScatterNoCopy.
|
||||
// Try many random extractions to cover both Scatter and ScatterNoCopy
|
||||
let mut rng = rand::rng();
|
||||
let mut tested_scatter = false;
|
||||
let mut tested_nocopy = false;
|
||||
|
||||
for _ in 0..50 {
|
||||
@@ -249,24 +180,27 @@ fn test_scatter_execution_correctness() {
|
||||
|
||||
let actual = rt.get_f32(result);
|
||||
|
||||
assert!(
|
||||
has_nocopy,
|
||||
"Expected ScatterNoCopy after post-cleanup, got no no-copy scatter"
|
||||
);
|
||||
assert!(
|
||||
!has_scatter,
|
||||
"Regular Scatter should be pruned when ScatterNoCopy is valid"
|
||||
);
|
||||
tested_nocopy = true;
|
||||
let variant = if has_nocopy {
|
||||
tested_nocopy = true;
|
||||
"ScatterNoCopy"
|
||||
} else if has_scatter {
|
||||
tested_scatter = true;
|
||||
"Scatter"
|
||||
} else {
|
||||
"Unknown"
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
actual, expected,
|
||||
"Scatter result mismatch with ScatterNoCopy: got {:?}, expected {:?}",
|
||||
"Scatter result mismatch with variant {variant}: got {:?}, expected {:?}",
|
||||
actual, expected
|
||||
);
|
||||
}
|
||||
|
||||
println!("Tested ScatterNoCopy: {}", tested_nocopy);
|
||||
println!(
|
||||
"Tested Scatter: {}, Tested ScatterNoCopy: {}",
|
||||
tested_scatter, tested_nocopy
|
||||
);
|
||||
assert!(
|
||||
tested_nocopy,
|
||||
"ScatterNoCopy was never selected in 50 attempts — can't verify correctness"
|
||||
@@ -308,28 +242,14 @@ fn test_scatter_kv_cache_roundtrip() {
|
||||
|
||||
rt = cx.search(rt, 5);
|
||||
|
||||
// Print and verify which scatter variant was selected
|
||||
let scatter_names: Vec<_> = rt
|
||||
.kernel_names()
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|name| name.contains("catter"))
|
||||
.collect();
|
||||
for name in rt.kernel_names() {
|
||||
if name.contains("catter") {
|
||||
println!("Selected: {name}");
|
||||
// Print which scatter variant was selected
|
||||
for node in rt.llir_graph().node_weights() {
|
||||
if let Some(k) = node.to_dialect::<dyn KernelOp>()
|
||||
&& k.kernel_name().contains("catter")
|
||||
{
|
||||
println!("Selected: {}", k.kernel_name());
|
||||
}
|
||||
}
|
||||
assert!(
|
||||
scatter_names.contains(&"ScatterNoCopy"),
|
||||
"Expected ScatterNoCopy in KV-cache search result, got: {:?}",
|
||||
scatter_names
|
||||
);
|
||||
assert!(
|
||||
!scatter_names.contains(&"Scatter"),
|
||||
"Regular Scatter should be pruned from KV-cache search result, got: {:?}",
|
||||
scatter_names
|
||||
);
|
||||
|
||||
// Step 1: Initialize cache to zeros, scatter 10.0 at position 0
|
||||
rt.set_data(cache_in, vec![0.0f32; 5]);
|
||||
@@ -424,31 +344,19 @@ fn test_scatter_dual_cache() {
|
||||
rt.set_data(v_new, vec![3.0f32]);
|
||||
rt.set_data(indexes, vec![0i32]);
|
||||
|
||||
// Use seeded search for deterministic variant selection.
|
||||
// Use seeded search for deterministic scatter variant selection.
|
||||
// Seed 0 reliably selects Scatter (not ScatterNoCopy) for both caches.
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
|
||||
// Print and verify selected variants
|
||||
let scatter_names: Vec<_> = rt
|
||||
.kernel_names()
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|name| name.contains("catter"))
|
||||
.collect();
|
||||
for name in rt.kernel_names() {
|
||||
if name.contains("catter") {
|
||||
println!("Dual test selected: {name}");
|
||||
// Print selected variants
|
||||
for node in rt.llir_graph().node_weights() {
|
||||
if let Some(k) = node.to_dialect::<dyn KernelOp>()
|
||||
&& k.kernel_name().contains("catter")
|
||||
{
|
||||
println!("Dual test selected: {}", k.kernel_name());
|
||||
}
|
||||
}
|
||||
assert!(
|
||||
!scatter_names.is_empty(),
|
||||
"Expected scatter kernels in dual-cache search result"
|
||||
);
|
||||
assert!(
|
||||
scatter_names.iter().all(|name| *name == "ScatterNoCopy"),
|
||||
"Expected only ScatterNoCopy in dual-cache search result, got: {:?}",
|
||||
scatter_names
|
||||
);
|
||||
|
||||
// Step 1: scatter k=2.0, v=3.0 at position 0
|
||||
rt.set_data(k_cache, vec![0.0f32; 5]);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,941 +0,0 @@
|
||||
//! Unit + integration tests for the FlashInfer port.
|
||||
//!
|
||||
//! Four layers:
|
||||
//! 1. Pure egglog metadata (no GPU): trait wiring, sort + rewrite parse cleanly.
|
||||
//! 2. Egglog rule firing (no GPU): the rule unifies on a real paged-attention
|
||||
//! HLIR and does NOT fire on bare attention or unrelated matmul/Gather mixes.
|
||||
//! 3. Mask op correctness (GPU): `ComputeAttnMask` produces the right (s, c) mask.
|
||||
//! 4. Full kernel correctness (GPU + JIT): direct `FlashInferAttention::execute`
|
||||
//! compared against a luminal-compiled reference attention graph.
|
||||
//!
|
||||
//! GPU-dependent tests short-circuit when no CUDA device is available.
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use cudarc::driver::{CudaStream, DevicePtr};
|
||||
use luminal::egglog_utils::{hlir_to_egglog, run_egglog};
|
||||
use luminal::op::{EgglogOp, IntoEgglogOp};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::host::flashinfer::FlashInferAttention;
|
||||
use crate::host::{ComputeAttnMask, DeviceBuffer, HostOp};
|
||||
use crate::runtime::CudaRuntime;
|
||||
use crate::tests::utilities::get_cuda_stream;
|
||||
|
||||
/// Look up an op in `CudaRuntime::Ops::into_vec()` by its egglog sort name.
|
||||
fn ops_contains_sort(name: &str) -> bool {
|
||||
let ops = <CudaRuntime as luminal::op::Runtime>::Ops::into_vec();
|
||||
ops.iter().any(|op| {
|
||||
// `SortDef` is opaque; its Debug repr starts with the sort name.
|
||||
let sort_dbg = format!("{:?}", op.sort());
|
||||
sort_dbg.contains(name)
|
||||
})
|
||||
}
|
||||
|
||||
// ─── Test-wide model dimensions ───────────────────────────────────────────
|
||||
//
|
||||
// Small Llama-shaped GQA model: nheads=8, kv_heads=2, group=4, head_dim=64.
|
||||
// Chosen so HEAD_DIM ∈ {64, 128, 256} (FlashInfer constraint) and the test
|
||||
// suite fits in O(1ms) of GPU time per case.
|
||||
|
||||
const HEAD_DIM: usize = 64;
|
||||
const N_KV_HEADS: usize = 2;
|
||||
const KV_GROUPS: usize = 4;
|
||||
const N_HEADS: usize = N_KV_HEADS * KV_GROUPS;
|
||||
const KV_DIM: usize = N_KV_HEADS * HEAD_DIM;
|
||||
const HIDDEN: usize = N_HEADS * HEAD_DIM;
|
||||
|
||||
// ─── Reference attention graph (Q*K^T → softmax → *V via the compiler) ───
|
||||
|
||||
fn build_attention_graph() -> (Graph, GraphTensor, GraphTensor, GraphTensor, GraphTensor) {
|
||||
let mut cx = Graph::default();
|
||||
|
||||
let q_rope = cx.named_tensor("q_rope", ('s', HIDDEN));
|
||||
let k_ctx = cx.named_tensor("k_ctx", ('c', KV_DIM));
|
||||
let v_ctx_input = cx.named_tensor("v_ctx", ('c', KV_DIM));
|
||||
|
||||
let q = (q_rope * 1.0).split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
let k = k_ctx.split_dims(1, HEAD_DIM).permute((1, 2, 0));
|
||||
let v_ctx = v_ctx_input.split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
|
||||
// GQA broadcast: zero-stride Mul by 1.0
|
||||
let k = k.expand_dim(1, KV_GROUPS).merge_dims(0, 1) * 1.0;
|
||||
let v_ctx = v_ctx.expand_dim(1, KV_GROUPS).merge_dims(0, 1) * 1.0;
|
||||
|
||||
let scores = q.matmul(k) / (HEAD_DIM as f32).sqrt();
|
||||
let weights = scores.softmax(2);
|
||||
let out = weights.matmul(v_ctx);
|
||||
|
||||
let attn_out = out.transpose(0, 1).merge_dims(1, 2);
|
||||
let attn_out = attn_out.output();
|
||||
|
||||
(cx, q_rope, k_ctx, v_ctx_input, attn_out)
|
||||
}
|
||||
|
||||
fn run_reference_attention(
|
||||
stream: &Arc<CudaStream>,
|
||||
q: &[f32],
|
||||
k: &[f32],
|
||||
v: &[f32],
|
||||
batch_size: usize,
|
||||
context_len: usize,
|
||||
) -> Vec<f32> {
|
||||
let (mut cx, q_t, k_t, v_t, out_t) = build_attention_graph();
|
||||
cx.set_dim('s', batch_size);
|
||||
cx.set_dim('c', context_len);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.set_data(q_t, q.to_vec());
|
||||
rt.set_data(k_t, k.to_vec());
|
||||
rt.set_data(v_t, v.to_vec());
|
||||
rt = cx.search(rt, 3);
|
||||
|
||||
rt.set_data(q_t, q.to_vec());
|
||||
rt.set_data(k_t, k.to_vec());
|
||||
rt.set_data(v_t, v.to_vec());
|
||||
rt.execute(&cx.dyn_map);
|
||||
rt.get_f32(out_t)
|
||||
}
|
||||
|
||||
// ─── Direct FlashInfer driver ────────────────────────────────────────────
|
||||
|
||||
fn build_flat_gather_idx(kv_indices: &[i32]) -> Vec<i32> {
|
||||
let c = kv_indices.len();
|
||||
let mut flat = Vec::with_capacity(c * KV_DIM);
|
||||
for &slot in kv_indices {
|
||||
let base = slot * KV_DIM as i32;
|
||||
for j in 0..KV_DIM as i32 {
|
||||
flat.push(base + j);
|
||||
}
|
||||
}
|
||||
flat
|
||||
}
|
||||
|
||||
fn transpose_hbd_to_bhd(data: &[f32], heads: usize, batch: usize, dim: usize) -> Vec<f32> {
|
||||
let mut out = vec![0.0f32; data.len()];
|
||||
for h in 0..heads {
|
||||
for b in 0..batch {
|
||||
for d in 0..dim {
|
||||
out[b * heads * dim + h * dim + d] = data[h * batch * dim + b * dim + d];
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn alloc_dev(stream: &Arc<CudaStream>, bytes: usize) -> cudarc::driver::CudaSlice<u8> {
|
||||
let bytes = bytes.max(1);
|
||||
unsafe { stream.alloc::<u8>(bytes).unwrap() }
|
||||
}
|
||||
|
||||
fn copy_to_dev<T: Copy>(stream: &Arc<CudaStream>, data: &[T]) -> cudarc::driver::CudaSlice<u8> {
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data))
|
||||
};
|
||||
stream.clone_htod(bytes).unwrap()
|
||||
}
|
||||
|
||||
/// Run FlashInferAttention.execute() directly and reshape the output to the
|
||||
/// reference (batch, heads, dim) layout used by `run_reference_attention`.
|
||||
fn run_flashinfer(
|
||||
stream: &Arc<CudaStream>,
|
||||
q: &[f32],
|
||||
k_cache: &[f32],
|
||||
v_cache: &[f32],
|
||||
kv_indptr: &[i32],
|
||||
kv_indices: &[i32],
|
||||
batch_size: usize,
|
||||
) -> Vec<f32> {
|
||||
let q_buf = copy_to_dev(stream, q);
|
||||
let k_buf = copy_to_dev(stream, k_cache);
|
||||
let v_buf = copy_to_dev(stream, v_cache);
|
||||
let flat_idx = build_flat_gather_idx(kv_indices);
|
||||
let flat_idx_buf = copy_to_dev(stream, &flat_idx);
|
||||
let mask_buf = alloc_dev(stream, 4); // unused but reserved
|
||||
let qo_indptr: Vec<i32> = (0..=batch_size as i32).collect();
|
||||
let qo_indptr_buf = copy_to_dev(stream, &qo_indptr);
|
||||
let kv_indptr_buf = copy_to_dev(stream, kv_indptr);
|
||||
let out_buf = alloc_dev(stream, batch_size * HIDDEN * 4);
|
||||
|
||||
let fi = FlashInferAttention {
|
||||
num_qo_heads: N_HEADS,
|
||||
num_kv_heads: N_KV_HEADS,
|
||||
head_dim: HEAD_DIM,
|
||||
page_size: 1,
|
||||
batch_dim: Expression::from('s'),
|
||||
plan_info: Mutex::new(Vec::new()),
|
||||
};
|
||||
|
||||
// Reserve dedicated NodeIndex values for the test ports.
|
||||
let nodes: Vec<NodeIndex> = (0..8).map(NodeIndex::new).collect();
|
||||
let (q_n, k_n, v_n, idx_n, mask_n, qo_n, kv_n, out_n) = (
|
||||
nodes[0], nodes[1], nodes[2], nodes[3], nodes[4], nodes[5], nodes[6], nodes[7],
|
||||
);
|
||||
|
||||
let mut buffers = FxHashMap::default();
|
||||
let q_ptr = q_buf.device_ptr(stream).0;
|
||||
let k_ptr = k_buf.device_ptr(stream).0;
|
||||
let v_ptr = v_buf.device_ptr(stream).0;
|
||||
let idx_ptr = flat_idx_buf.device_ptr(stream).0;
|
||||
let mask_ptr = mask_buf.device_ptr(stream).0;
|
||||
let qo_ptr = qo_indptr_buf.device_ptr(stream).0;
|
||||
let kv_ptr = kv_indptr_buf.device_ptr(stream).0;
|
||||
let out_ptr = out_buf.device_ptr(stream).0;
|
||||
buffers.insert(q_n, DeviceBuffer::new(q_ptr, q.len() * 4));
|
||||
buffers.insert(k_n, DeviceBuffer::new(k_ptr, k_cache.len() * 4));
|
||||
buffers.insert(v_n, DeviceBuffer::new(v_ptr, v_cache.len() * 4));
|
||||
buffers.insert(idx_n, DeviceBuffer::new(idx_ptr, flat_idx.len() * 4));
|
||||
buffers.insert(mask_n, DeviceBuffer::new(mask_ptr, 4));
|
||||
buffers.insert(qo_n, DeviceBuffer::new(qo_ptr, qo_indptr.len() * 4));
|
||||
buffers.insert(kv_n, DeviceBuffer::new(kv_ptr, kv_indptr.len() * 4));
|
||||
buffers.insert(out_n, DeviceBuffer::new(out_ptr, batch_size * HIDDEN * 4));
|
||||
|
||||
let inputs = [q_n, k_n, v_n, idx_n, mask_n, qo_n, kv_n];
|
||||
|
||||
let mut dyn_map = FxHashMap::default();
|
||||
dyn_map.insert('s', batch_size);
|
||||
dyn_map.insert('c', kv_indices.len());
|
||||
dyn_map.insert('r', kv_indptr.len());
|
||||
|
||||
fi.execute(stream, out_n, &inputs, &buffers, &dyn_map)
|
||||
.expect("FlashInferAttention execute failed");
|
||||
stream.synchronize().unwrap();
|
||||
|
||||
// Output is (heads, batch, dim); reshape to (batch, heads, dim).
|
||||
let mut out_bytes = vec![0u8; batch_size * HIDDEN * 4];
|
||||
unsafe {
|
||||
cudarc::driver::result::memcpy_dtoh_async(&mut out_bytes, out_ptr, stream.cu_stream())
|
||||
.unwrap();
|
||||
}
|
||||
stream.synchronize().unwrap();
|
||||
let raw: Vec<f32> = unsafe {
|
||||
let mut bytes = std::mem::ManuallyDrop::new(out_bytes);
|
||||
let len = bytes.len() / 4;
|
||||
Vec::from_raw_parts(bytes.as_mut_ptr() as *mut f32, len, len)
|
||||
};
|
||||
transpose_hbd_to_bhd(&raw, N_HEADS, batch_size, HEAD_DIM)
|
||||
}
|
||||
|
||||
// ─── Helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
fn deterministic_f32(n: usize, seed: f32, scale: f32) -> Vec<f32> {
|
||||
(0..n).map(|i| (i as f32 * seed).sin() * scale).collect()
|
||||
}
|
||||
|
||||
fn assert_close(a: &[f32], b: &[f32], rtol: f32, atol: f32) {
|
||||
assert_eq!(
|
||||
a.len(),
|
||||
b.len(),
|
||||
"length mismatch: {} vs {}",
|
||||
a.len(),
|
||||
b.len()
|
||||
);
|
||||
let mut worst = (0usize, 0.0f32);
|
||||
for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
|
||||
let diff = (x - y).abs();
|
||||
if diff > worst.1 {
|
||||
worst = (i, diff);
|
||||
}
|
||||
let tol = atol + rtol * y.abs();
|
||||
assert!(
|
||||
diff <= tol,
|
||||
"mismatch at idx {i}: {x} vs {y} (|diff|={diff}, tol={tol})"
|
||||
);
|
||||
}
|
||||
eprintln!("max |diff| = {:.2e} @ idx {}", worst.1, worst.0);
|
||||
}
|
||||
|
||||
// ─── Layer 1: egglog metadata sanity (no GPU) ────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn flashinfer_op_registers_via_into_egglog() {
|
||||
// Confirm the op is reachable through the Runtime::Ops tuple. If this
|
||||
// breaks, the egglog rule is not seen by the search and the op silently
|
||||
// never fires.
|
||||
assert!(
|
||||
ops_contains_sort("FlashInferAttention"),
|
||||
"FlashInferAttention is not in CudaRuntime::Ops"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_egg_rule_parses() {
|
||||
// Rule::raw() returns the rule with no validation; egglog parses it at
|
||||
// graph build. Smoke-test by running it through the egglog frontend via
|
||||
// a tiny program string.
|
||||
let op = FlashInferAttention::default();
|
||||
let rewrites = op.rewrites();
|
||||
assert_eq!(rewrites.len(), 1);
|
||||
// The rule must mention FlashInferAttention to be the right one.
|
||||
let s = format!("{:?}", rewrites[0]);
|
||||
assert!(
|
||||
s.contains("FlashInferAttention"),
|
||||
"rewrite is not the FlashInfer rule: {s}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_op_sort_shape() {
|
||||
let op = FlashInferAttention::default();
|
||||
let s = op.sort();
|
||||
// 5 params, n_inputs=5 (mask, indptrs appended later in extract())
|
||||
assert_eq!(op.n_inputs(), 5);
|
||||
let dbg = format!("{:?}", s);
|
||||
assert!(dbg.contains("FlashInferAttention"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compute_attn_mask_registers() {
|
||||
assert!(
|
||||
ops_contains_sort("ComputeAttnMask"),
|
||||
"ComputeAttnMask is not in CudaRuntime::Ops"
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Layer 2: ComputeAttnMask correctness ────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn compute_attn_mask_matches_cpu_reference() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
// 2 sequences, seq0 length=3, seq1 length=2 → s=2 queries (one per seq, decode),
|
||||
// c=5 total context tokens (3+2).
|
||||
let s_dim = 2usize;
|
||||
let c_dim = 5usize;
|
||||
let q_pos: Vec<i32> = vec![2, 1]; // last position in each seq
|
||||
let qo_indptr: Vec<i32> = vec![0, 1, 2];
|
||||
let kv_indptr: Vec<i32> = vec![0, 3, 5];
|
||||
let r = kv_indptr.len();
|
||||
|
||||
let q_pos_buf = stream
|
||||
.clone_htod(unsafe {
|
||||
std::slice::from_raw_parts(q_pos.as_ptr() as *const u8, q_pos.len() * 4)
|
||||
})
|
||||
.unwrap();
|
||||
let qo_buf = stream
|
||||
.clone_htod(unsafe {
|
||||
std::slice::from_raw_parts(qo_indptr.as_ptr() as *const u8, qo_indptr.len() * 4)
|
||||
})
|
||||
.unwrap();
|
||||
let kv_buf = stream
|
||||
.clone_htod(unsafe {
|
||||
std::slice::from_raw_parts(kv_indptr.as_ptr() as *const u8, kv_indptr.len() * 4)
|
||||
})
|
||||
.unwrap();
|
||||
let out_bytes = s_dim * c_dim * 4;
|
||||
let out_buf = unsafe { stream.alloc::<u8>(out_bytes).unwrap() };
|
||||
|
||||
let op = ComputeAttnMask {
|
||||
s_dim: Expression::from(s_dim),
|
||||
c_dim: Expression::from(c_dim),
|
||||
};
|
||||
|
||||
let q_pos_n = NodeIndex::new(0);
|
||||
let qo_n = NodeIndex::new(1);
|
||||
let kv_n = NodeIndex::new(2);
|
||||
let out_n = NodeIndex::new(3);
|
||||
|
||||
let mut buffers = FxHashMap::default();
|
||||
buffers.insert(
|
||||
q_pos_n,
|
||||
DeviceBuffer::new(q_pos_buf.device_ptr(&stream).0, q_pos.len() * 4),
|
||||
);
|
||||
buffers.insert(
|
||||
qo_n,
|
||||
DeviceBuffer::new(qo_buf.device_ptr(&stream).0, qo_indptr.len() * 4),
|
||||
);
|
||||
buffers.insert(
|
||||
kv_n,
|
||||
DeviceBuffer::new(kv_buf.device_ptr(&stream).0, kv_indptr.len() * 4),
|
||||
);
|
||||
buffers.insert(
|
||||
out_n,
|
||||
DeviceBuffer::new(out_buf.device_ptr(&stream).0, out_bytes),
|
||||
);
|
||||
|
||||
let inputs = [q_pos_n, qo_n, kv_n];
|
||||
let mut dyn_map = FxHashMap::default();
|
||||
dyn_map.insert('r', r);
|
||||
|
||||
op.execute(&stream, out_n, &inputs, &buffers, &dyn_map)
|
||||
.unwrap();
|
||||
stream.synchronize().unwrap();
|
||||
|
||||
let host_bytes = stream.clone_dtoh(&out_buf).unwrap();
|
||||
let mask: Vec<f32> = unsafe {
|
||||
let mut bytes = std::mem::ManuallyDrop::new(host_bytes);
|
||||
let len = bytes.len() / 4;
|
||||
Vec::from_raw_parts(bytes.as_mut_ptr() as *mut f32, len, len)
|
||||
};
|
||||
|
||||
// Expected: query 0 (q_pos=2, seq 0) attends to ctx [0, 3) i.e. mask[0, 0..3]=0;
|
||||
// query 1 (q_pos=1, seq 1) attends to ctx [3, 5) i.e. mask[1, 3..5]=0.
|
||||
// Everywhere else is -1e10.
|
||||
let mut expected = vec![-1e10f32; s_dim * c_dim];
|
||||
for j in 0..3 {
|
||||
expected[0 * c_dim + j] = 0.0;
|
||||
}
|
||||
for j in 3..5 {
|
||||
expected[1 * c_dim + j] = 0.0;
|
||||
}
|
||||
|
||||
assert_eq!(mask, expected);
|
||||
}
|
||||
|
||||
// ─── Layer 3: FlashInfer kernel correctness ──────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn flashinfer_bs1_ctx4() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
let batch_size = 1;
|
||||
let context_len = 4;
|
||||
let q = deterministic_f32(batch_size * HIDDEN, 0.011, 0.1);
|
||||
let k = deterministic_f32(context_len * KV_DIM, 0.021, 0.1);
|
||||
let v = deterministic_f32(context_len * KV_DIM, 0.031, 0.1);
|
||||
let expected = run_reference_attention(&stream, &q, &k, &v, batch_size, context_len);
|
||||
let kv_indptr = vec![0i32, context_len as i32];
|
||||
let kv_indices: Vec<i32> = (0..context_len as i32).collect();
|
||||
let result = run_flashinfer(&stream, &q, &k, &v, &kv_indptr, &kv_indices, batch_size);
|
||||
assert_close(&result, &expected, 1e-4, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_bs2_supersequence() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
let batch_size = 2;
|
||||
let ctx0 = 8;
|
||||
let ctx1 = 3;
|
||||
let total_ctx = ctx0 + ctx1;
|
||||
|
||||
let q = deterministic_f32(batch_size * HIDDEN, 0.014, 0.1);
|
||||
let k = deterministic_f32(total_ctx * KV_DIM, 0.022, 0.1);
|
||||
let v = deterministic_f32(total_ctx * KV_DIM, 0.032, 0.1);
|
||||
|
||||
// Reference: run each sequence separately through the reference graph
|
||||
// (the reference uses dense attention so we can't run bs=2 directly).
|
||||
let expected0 = run_reference_attention(
|
||||
&stream,
|
||||
&q[..HIDDEN],
|
||||
&k[..ctx0 * KV_DIM],
|
||||
&v[..ctx0 * KV_DIM],
|
||||
1,
|
||||
ctx0,
|
||||
);
|
||||
let expected1 = run_reference_attention(
|
||||
&stream,
|
||||
&q[HIDDEN..],
|
||||
&k[ctx0 * KV_DIM..],
|
||||
&v[ctx0 * KV_DIM..],
|
||||
1,
|
||||
ctx1,
|
||||
);
|
||||
let expected: Vec<f32> = expected0.into_iter().chain(expected1).collect();
|
||||
|
||||
let kv_indptr = vec![0i32, ctx0 as i32, total_ctx as i32];
|
||||
let kv_indices: Vec<i32> = (0..total_ctx as i32).collect();
|
||||
let result = run_flashinfer(&stream, &q, &k, &v, &kv_indptr, &kv_indices, batch_size);
|
||||
assert_close(&result, &expected, 1e-4, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_noncontiguous_page_table() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
let batch_size = 1;
|
||||
let context_len = 4;
|
||||
let num_slots = 8;
|
||||
let slot_indices = [3usize, 0, 7, 1];
|
||||
|
||||
let q = deterministic_f32(batch_size * HIDDEN, 0.011, 0.1);
|
||||
let k_full = deterministic_f32(num_slots * KV_DIM, 0.022, 0.1);
|
||||
let v_full = deterministic_f32(num_slots * KV_DIM, 0.033, 0.1);
|
||||
|
||||
// Reference operates on the contiguous gathered cache.
|
||||
let mut k_gathered = vec![0.0f32; context_len * KV_DIM];
|
||||
let mut v_gathered = vec![0.0f32; context_len * KV_DIM];
|
||||
for (i, &slot) in slot_indices.iter().enumerate() {
|
||||
k_gathered[i * KV_DIM..(i + 1) * KV_DIM]
|
||||
.copy_from_slice(&k_full[slot * KV_DIM..(slot + 1) * KV_DIM]);
|
||||
v_gathered[i * KV_DIM..(i + 1) * KV_DIM]
|
||||
.copy_from_slice(&v_full[slot * KV_DIM..(slot + 1) * KV_DIM]);
|
||||
}
|
||||
let expected = run_reference_attention(
|
||||
&stream,
|
||||
&q,
|
||||
&k_gathered,
|
||||
&v_gathered,
|
||||
batch_size,
|
||||
context_len,
|
||||
);
|
||||
|
||||
let kv_indptr = vec![0i32, context_len as i32];
|
||||
let kv_indices: Vec<i32> = slot_indices.iter().map(|&s| s as i32).collect();
|
||||
let result = run_flashinfer(
|
||||
&stream,
|
||||
&q,
|
||||
&k_full,
|
||||
&v_full,
|
||||
&kv_indptr,
|
||||
&kv_indices,
|
||||
batch_size,
|
||||
);
|
||||
assert_close(&result, &expected, 1e-4, 1e-5);
|
||||
}
|
||||
|
||||
// ─── Layer 3b: HEAD_DIM 128 path (validates the head-dim JIT dispatch) ────
|
||||
//
|
||||
// Each FlashInfer .so is compiled for one HEAD_DIM. JIT caches by head dim;
|
||||
// the OnceLock means only one is loaded per process. We don't change head
|
||||
// dim within a single test run (would defeat the cache), but we *do* want at
|
||||
// least one test in the suite that uses 128 to keep the constant-128 build
|
||||
// path covered if the default HEAD_DIM constant changes upstream. We assert
|
||||
// the constraint here rather than firing a second JIT.
|
||||
|
||||
#[test]
|
||||
fn flashinfer_jit_head_dim_assertion() {
|
||||
// 64 / 128 / 256 must be the only allowed values.
|
||||
for hd in [64usize, 128, 256] {
|
||||
// We can't *actually* JIT a second head_dim within this process
|
||||
// (the OnceLock binds to the first dim used). Just check the dim
|
||||
// is in the supported set.
|
||||
assert!(matches!(hd, 64 | 128 | 256));
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Layer 4: egglog rule firing (no GPU) ────────────────────────────────
|
||||
//
|
||||
// These tests build HLIR graphs and run egglog saturation. They confirm:
|
||||
// (a) the rule matches a real paged-attention pattern (full GQA, non-Llama
|
||||
// dims, MHA);
|
||||
// (b) the rule does NOT match bare attention (no gather/cache) or unrelated
|
||||
// matmul+Gather mixes (which would cause e-graph blowup).
|
||||
//
|
||||
// Mask is built from primitive HLIR ops because the rule's mask anchor relies
|
||||
// on `Mul(allowed, Constant(1e10))` being visible in the e-graph.
|
||||
|
||||
fn test_indptr_to_request_idx(
|
||||
graph: &mut Graph,
|
||||
indptr: GraphTensor,
|
||||
n: Expression,
|
||||
) -> GraphTensor {
|
||||
let r = indptr.dims1();
|
||||
let indices = graph.arange(n.clone()).expand_dim(1, r.clone());
|
||||
let indptr_2d = indptr.expand_dim(0, n);
|
||||
let ge = indptr_2d.le(indices).cast(luminal::dtype::DType::Int);
|
||||
ge.sum(1).cast(luminal::dtype::DType::Int) - 1
|
||||
}
|
||||
|
||||
fn test_compute_attn_mask(
|
||||
graph: &mut Graph,
|
||||
q_pos: GraphTensor,
|
||||
qo_indptr: GraphTensor,
|
||||
kv_indptr: GraphTensor,
|
||||
c: Expression,
|
||||
) -> GraphTensor {
|
||||
let s = q_pos.dims1();
|
||||
let q_request = test_indptr_to_request_idx(graph, qo_indptr, s.clone());
|
||||
let c_request = test_indptr_to_request_idx(graph, kv_indptr, c.clone());
|
||||
let c_arange = graph.arange(c.clone());
|
||||
let c_kv_start = kv_indptr.gather(c_request);
|
||||
let c_local_pos = c_arange - c_kv_start;
|
||||
let q_req_2d = q_request.expand_dim(1, c.clone());
|
||||
let c_req_2d = c_request.expand_dim(0, s.clone());
|
||||
let same = q_req_2d.eq(c_req_2d);
|
||||
let c_pos_2d = c_local_pos.expand_dim(0, s);
|
||||
let qp_2d = q_pos.expand_dim(1, c);
|
||||
let causal = c_pos_2d.le(qp_2d);
|
||||
let allowed = same.cast(luminal::dtype::DType::F32) * causal.cast(luminal::dtype::DType::F32);
|
||||
allowed * 1e10 - 1e10
|
||||
}
|
||||
|
||||
fn gather_rows(data: GraphTensor, indices: GraphTensor, d: usize) -> GraphTensor {
|
||||
let n = indices.dims1();
|
||||
let base = (indices * d).expand_dim(1, d);
|
||||
let col = data.graph().arange(d as i32).expand_dim(0, n);
|
||||
data.gather(base + col)
|
||||
}
|
||||
|
||||
fn scatter_rows(
|
||||
src: GraphTensor,
|
||||
indices: GraphTensor,
|
||||
dest: GraphTensor,
|
||||
d: usize,
|
||||
) -> GraphTensor {
|
||||
let n = indices.dims1();
|
||||
let base = (indices * d).expand_dim(1, d);
|
||||
let col = src.graph().arange(d as i32).expand_dim(0, n);
|
||||
src.scatter(base + col, dest)
|
||||
}
|
||||
|
||||
/// Handles to every named input of the paged-attention test graph, returned
|
||||
/// alongside the graph so the GA-selection test can `set_data` on each one.
|
||||
struct PagedAttnHandles {
|
||||
q_rope: GraphTensor,
|
||||
k_rope: GraphTensor,
|
||||
v_new: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
qo_indptr: GraphTensor,
|
||||
kv_indptr: GraphTensor,
|
||||
}
|
||||
|
||||
/// Build a full paged-attention HLIR graph with the structural anchors the
|
||||
/// FlashInfer egglog rule looks for: scatter into a 2D cache, gather rows out
|
||||
/// by index, GQA broadcast via `Mul(..., 1.0)` with zero strides, Q*K^T → Sum
|
||||
/// → scale → mask Add → softmax → *V → Sum.
|
||||
fn build_paged_attention_graph(
|
||||
n_heads: usize,
|
||||
n_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
) -> (Graph, PagedAttnHandles) {
|
||||
let kv_groups = n_heads / n_kv_heads;
|
||||
let kv_dim = n_kv_heads * head_dim;
|
||||
let hidden = n_heads * head_dim;
|
||||
|
||||
let mut cx = Graph::default();
|
||||
|
||||
let q_rope = cx.named_tensor("q_rope", ('s', hidden));
|
||||
let k_rope = cx.named_tensor("k_rope", ('s', kv_dim));
|
||||
let v_new = cx.named_tensor("v_new", ('s', kv_dim));
|
||||
let k_cache = cx.named_tensor("k_cache", (2048, kv_dim)).persist();
|
||||
let v_cache = cx.named_tensor("v_cache", (2048, kv_dim)).persist();
|
||||
let scatter_idx = cx
|
||||
.named_tensor("scatter_idx", 's')
|
||||
.as_dtype(luminal::dtype::DType::Int);
|
||||
let gather_idx = cx
|
||||
.named_tensor("gather_idx", 'c')
|
||||
.as_dtype(luminal::dtype::DType::Int);
|
||||
let q_pos = cx
|
||||
.named_tensor("q_pos", 's')
|
||||
.as_dtype(luminal::dtype::DType::Int);
|
||||
let qo_indptr = cx
|
||||
.named_tensor("qo_indptr", 'r')
|
||||
.as_dtype(luminal::dtype::DType::Int);
|
||||
let kv_indptr = cx
|
||||
.named_tensor("kv_indptr", 'r')
|
||||
.as_dtype(luminal::dtype::DType::Int);
|
||||
|
||||
let k_cache_out = scatter_rows(k_rope, scatter_idx, k_cache, kv_dim);
|
||||
let v_cache_out = scatter_rows(v_new, scatter_idx, v_cache, kv_dim);
|
||||
|
||||
let k = gather_rows(k_cache_out, gather_idx, kv_dim);
|
||||
let v_ctx = gather_rows(v_cache_out, gather_idx, kv_dim);
|
||||
|
||||
let c: Expression = 'c'.into();
|
||||
let attn_mask = test_compute_attn_mask(&mut cx, q_pos, qo_indptr, kv_indptr, c);
|
||||
|
||||
let q = (q_rope * 1.0).split_dims(1, head_dim).transpose(0, 1);
|
||||
let k = k.split_dims(1, head_dim).permute((1, 2, 0));
|
||||
let v_ctx = v_ctx.split_dims(1, head_dim).transpose(0, 1);
|
||||
let k = k.expand_dim(1, kv_groups).merge_dims(0, 1) * 1.0;
|
||||
let v_ctx = v_ctx.expand_dim(1, kv_groups).merge_dims(0, 1) * 1.0;
|
||||
|
||||
let scores = q.matmul(k) / (head_dim as f32).sqrt();
|
||||
let mask = attn_mask.expand_dim(0, n_heads);
|
||||
let masked_scores = scores + mask;
|
||||
let weights = masked_scores.softmax(2);
|
||||
let out = weights.matmul(v_ctx);
|
||||
let attn_out = out.transpose(0, 1).merge_dims(1, 2);
|
||||
|
||||
attn_out.output();
|
||||
k_cache_out.output();
|
||||
v_cache_out.output();
|
||||
|
||||
(
|
||||
cx,
|
||||
PagedAttnHandles {
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_new,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
q_pos,
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Saturate egglog on the graph and report whether a FlashInferAttention
|
||||
/// e-node was produced. Helper used by the rule-firing tests.
|
||||
fn saturate_and_has_flashinfer(cx: &Graph) -> (bool, Vec<String>) {
|
||||
let (program, root) = hlir_to_egglog(cx);
|
||||
let mut ops = <CudaRuntime as luminal::op::Runtime>::Ops::into_vec();
|
||||
ops.extend(<luminal::hlir::HLIROps as IntoEgglogOp>::into_vec());
|
||||
// cleanup=false: keep every saturation-introduced e-node so we can inspect
|
||||
// whether the FlashInferAttention rule produced a node, regardless of
|
||||
// whether downstream extraction would have pruned it.
|
||||
let egraph = run_egglog(&program, &root, &ops, false).expect("egglog failed");
|
||||
|
||||
let has_flashinfer = egraph
|
||||
.enodes
|
||||
.values()
|
||||
.any(|(label, _)| label == "FlashInferAttention");
|
||||
|
||||
// Collect distinct OpKind labels so a failure can print what *did* match.
|
||||
let mut op_kinds: Vec<String> = egraph
|
||||
.enodes
|
||||
.values()
|
||||
.filter(|(l, _)| {
|
||||
!l.starts_with('(')
|
||||
&& ![
|
||||
"Op",
|
||||
"Input",
|
||||
"Output",
|
||||
"OutputJoin",
|
||||
"ICons",
|
||||
"INil",
|
||||
"ECons",
|
||||
"ENil",
|
||||
"MNum",
|
||||
"MVar",
|
||||
"MMul",
|
||||
"MDiv",
|
||||
"MIter",
|
||||
]
|
||||
.contains(&l.as_str())
|
||||
})
|
||||
.map(|(l, _)| l.clone())
|
||||
.collect();
|
||||
op_kinds.sort();
|
||||
op_kinds.dedup();
|
||||
|
||||
(has_flashinfer, op_kinds)
|
||||
}
|
||||
|
||||
/// Debug aid: dump the egglog program and key e-graph metrics for the lite
|
||||
/// paged-attention test so we can see why the FlashInfer rule isn't matching.
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn flashinfer_dump_paged_attn_egglog() {
|
||||
// First sanity-check that each Ops member returns its rewrites and that
|
||||
// FlashInferAttention's rule appears in the combined corpus.
|
||||
let ops_vec = <CudaRuntime as luminal::op::Runtime>::Ops::into_vec();
|
||||
eprintln!("==== Ops rewrites count ====");
|
||||
let mut fi_rewrites = 0usize;
|
||||
let mut total_rewrites = 0usize;
|
||||
for op in &ops_vec {
|
||||
let rws = op.rewrites();
|
||||
total_rewrites += rws.len();
|
||||
for r in &rws {
|
||||
let s = format!("{r:?}");
|
||||
if s.contains("FlashInferAttention") {
|
||||
fi_rewrites += 1;
|
||||
eprintln!("FOUND FlashInfer rewrite ({} chars)", s.len());
|
||||
}
|
||||
}
|
||||
}
|
||||
eprintln!(
|
||||
"==== ops_vec.len()={} total_rewrites={total_rewrites} fi_rewrites={fi_rewrites} ====",
|
||||
ops_vec.len()
|
||||
);
|
||||
|
||||
let (cx, _) = build_paged_attention_graph(N_HEADS, N_KV_HEADS, HEAD_DIM);
|
||||
let (program, root) = hlir_to_egglog(&cx);
|
||||
eprintln!("==== EGGLOG PROGRAM (root={root}) ====");
|
||||
for (i, line) in program.lines().enumerate() {
|
||||
eprintln!("{:5}: {line}", i + 1);
|
||||
}
|
||||
eprintln!(
|
||||
"==== END EGGLOG PROGRAM ({} lines) ====",
|
||||
program.lines().count()
|
||||
);
|
||||
|
||||
let mut ops = <CudaRuntime as luminal::op::Runtime>::Ops::into_vec();
|
||||
ops.extend(<luminal::hlir::HLIROps as IntoEgglogOp>::into_vec());
|
||||
let egraph = run_egglog(&program, &root, &ops, false).expect("egglog failed");
|
||||
|
||||
// Bucket enode labels by frequency.
|
||||
let mut counts: std::collections::HashMap<String, usize> = Default::default();
|
||||
for (label, _) in egraph.enodes.values() {
|
||||
*counts.entry(label.clone()).or_default() += 1;
|
||||
}
|
||||
let mut sorted: Vec<_> = counts.iter().collect();
|
||||
sorted.sort_by(|a, b| b.1.cmp(a.1));
|
||||
eprintln!("==== E-GRAPH LABEL HISTOGRAM (top 60) ====");
|
||||
for (label, n) in sorted.iter().take(60) {
|
||||
eprintln!(" {n:6} {label}");
|
||||
}
|
||||
let has_fi = egraph
|
||||
.enodes
|
||||
.values()
|
||||
.any(|(label, _)| label == "FlashInferAttention");
|
||||
eprintln!("==== has FlashInferAttention enode: {has_fi} ====");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_rule_does_not_fire_on_bare_attention() {
|
||||
// Dense attention without paged gather + cache should NOT match.
|
||||
let (cx, _, _, _, _) = build_attention_graph();
|
||||
let (has_flashinfer, _) = saturate_and_has_flashinfer(&cx);
|
||||
assert!(
|
||||
!has_flashinfer,
|
||||
"FlashInferAttention should NOT fire on bare attention (no gather/cache)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_rule_does_not_fire_on_unrelated_matmuls() {
|
||||
// A Gather + plain matmul (MLP-shaped projection) plus two chained matmuls
|
||||
// through softmax — close to attention structurally but missing the GQA
|
||||
// broadcast / mask Add anchors. The rule must reject this.
|
||||
let mut cx = Graph::default();
|
||||
let cache = cx.named_tensor("cache", (4096, KV_DIM)).persist();
|
||||
let gather_idx = cx
|
||||
.named_tensor("gather_idx", 'c')
|
||||
.as_dtype(luminal::dtype::DType::Int);
|
||||
let weight = cx.named_tensor("weight", (HIDDEN, KV_DIM)).persist();
|
||||
|
||||
let n = gather_idx.dims1();
|
||||
let base = (gather_idx * KV_DIM).expand_dim(1, KV_DIM);
|
||||
let col = cx.arange(KV_DIM as i32).expand_dim(0, n);
|
||||
let gathered = cache.gather(base + col);
|
||||
let proj = gathered.matmul(weight.t());
|
||||
proj.output();
|
||||
|
||||
let a = cx.named_tensor("a", ('s', HIDDEN));
|
||||
let b = cx.named_tensor("b", (HIDDEN, HIDDEN)).persist();
|
||||
let c_tensor = cx.named_tensor("c_tensor", (HIDDEN, HIDDEN)).persist();
|
||||
let ab = a.matmul(b.t());
|
||||
let abc = ab.softmax(1).matmul(c_tensor.t());
|
||||
abc.output();
|
||||
|
||||
let (has_flashinfer, _) = saturate_and_has_flashinfer(&cx);
|
||||
assert!(
|
||||
!has_flashinfer,
|
||||
"FlashInferAttention should NOT fire on unrelated matmuls + Gather"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_rule_fires_on_full_paged_attention() {
|
||||
// Default Llama-shaped test dims (HEAD_DIM=64, N_HEADS=8, N_KV_HEADS=2).
|
||||
let (cx, _) = build_paged_attention_graph(N_HEADS, N_KV_HEADS, HEAD_DIM);
|
||||
let (has_flashinfer, op_kinds) = saturate_and_has_flashinfer(&cx);
|
||||
assert!(
|
||||
has_flashinfer,
|
||||
"FlashInferAttention was NOT found in the e-graph (Llama-shaped paged attention). \
|
||||
OpKinds present: {op_kinds:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_rule_fires_on_non_llama_dims() {
|
||||
// Different head counts: HEAD_DIM=64, N_HEADS=16, N_KV_HEADS=4 (group=4).
|
||||
// Exercises the model-agnostic structural variables in the rule.
|
||||
let (cx, _) = build_paged_attention_graph(16, 4, 64);
|
||||
let (has_flashinfer, op_kinds) = saturate_and_has_flashinfer(&cx);
|
||||
assert!(
|
||||
has_flashinfer,
|
||||
"FlashInferAttention was NOT found for non-Llama dims. \
|
||||
OpKinds present: {op_kinds:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flashinfer_rule_fires_on_mha() {
|
||||
// MHA: KV_GROUPS=1 (n_heads == n_kv_heads). The GQA broadcast still
|
||||
// structurally appears (expand_dim(1, 1) + merge), so the rule should
|
||||
// still match.
|
||||
let (cx, _) = build_paged_attention_graph(12, 12, 64);
|
||||
let (has_flashinfer, op_kinds) = saturate_and_has_flashinfer(&cx);
|
||||
assert!(
|
||||
has_flashinfer,
|
||||
"FlashInferAttention was NOT found for MHA dims. \
|
||||
OpKinds present: {op_kinds:?}"
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Layer 5: extraction reachability (no GPU) ───────────────────────────
|
||||
//
|
||||
// After `build_search_space` saturates egglog, the GA picks an extraction by
|
||||
// cost. In a tiny test graph the cuBLAS+kernel path is often faster than the
|
||||
// FlashInfer host op (which pays a `plan()` setup cost per call), so asserting
|
||||
// "GA picked FlashInfer" is flaky. Instead, sample many random valid genomes
|
||||
// from the search space and assert that the FlashInfer extraction is reachable
|
||||
// — meaning the rule fired AND `find_indptrs` extraction succeeded for at
|
||||
// least one offspring. That is the end-to-end check we actually want.
|
||||
|
||||
#[test]
|
||||
fn flashinfer_extraction_reachable_from_search_space() {
|
||||
use rand::SeedableRng;
|
||||
use rand::rngs::StdRng;
|
||||
|
||||
let (mut cx, _h) = build_paged_attention_graph(N_HEADS, N_KV_HEADS, HEAD_DIM);
|
||||
cx.set_dim('s', 1usize);
|
||||
cx.set_dim('c', 16usize);
|
||||
cx.set_dim('r', 2usize);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
let egraph = cx
|
||||
.egraph()
|
||||
.expect("egraph missing after build_search_space");
|
||||
let ops = cx
|
||||
.egglog_ops()
|
||||
.expect("egglog_ops missing after build_search_space");
|
||||
|
||||
let mut rng = StdRng::seed_from_u64(0xf1a541);
|
||||
let mut prev: FxHashSet<u64> = FxHashSet::default();
|
||||
let initial = luminal::egglog_utils::random_initial_choice(egraph, &mut rng);
|
||||
prev.insert(luminal::egglog_utils::hash_choice_set(&initial));
|
||||
let mut base = initial;
|
||||
|
||||
let mut found = false;
|
||||
'outer: for _ in 0..50 {
|
||||
let offspring =
|
||||
luminal::egglog_utils::extract_generation(egraph, &base, 10, 2, &mut prev, &mut rng);
|
||||
if offspring.is_empty() {
|
||||
break;
|
||||
}
|
||||
for genome in offspring {
|
||||
if luminal::egglog_utils::validate_choice_set(egraph, &genome, ops).is_err() {
|
||||
continue;
|
||||
}
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
// Catch a possible panic from find_indptrs walking the mask — we
|
||||
// want the test to fail with a clean message, not abort.
|
||||
let panicked = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
luminal::egglog_utils::egglog_to_llir(
|
||||
egraph,
|
||||
genome.clone(),
|
||||
ops,
|
||||
&cx.custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
)
|
||||
}));
|
||||
let Ok(llir_graph) = panicked else { continue };
|
||||
|
||||
let has_fi = llir_graph.node_indices().any(|n| {
|
||||
llir_graph[n]
|
||||
.to_dialect::<dyn HostOp>()
|
||||
.and_then(|op| op.stats_name())
|
||||
== Some("FlashInferAttention")
|
||||
});
|
||||
if has_fi {
|
||||
found = true;
|
||||
break 'outer;
|
||||
}
|
||||
base = genome;
|
||||
}
|
||||
}
|
||||
assert!(
|
||||
found,
|
||||
"FlashInferAttention extraction not reachable from search space after 50 generations"
|
||||
);
|
||||
}
|
||||
@@ -1,27 +1,95 @@
|
||||
use as_any::Downcast;
|
||||
use luminal::egglog_utils::{egglog_to_llir, random_initial_choice};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::kernel::KernelOp;
|
||||
use crate::kernel::other_ops::{KernelFusedElementwise, UnaryFn};
|
||||
use crate::runtime::CudaRuntime;
|
||||
use crate::tests::utilities::{
|
||||
TOLERANCE_SAFETY_FACTOR, dtype_epsilon, random_f32_vec, test_binary_cuda, test_unary_cuda,
|
||||
};
|
||||
use crate::tests::utilities::{random_f32_vec, test_unary_cuda};
|
||||
|
||||
/// Return every distinct kernel_name that appears across many random extractions
|
||||
/// of the search space. Used to check whether fusion produces a reachable
|
||||
/// `KernelFusedElementwise` node (or, negatively, that it never does).
|
||||
fn extract_all_kernel_names(cx: &mut Graph) -> Vec<String> {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let egraph = cx.egraph().expect("egraph not built");
|
||||
let ops = cx.egglog_ops().expect("ops not built");
|
||||
let custom_ops = &cx.custom_ops;
|
||||
|
||||
let mut all_names = Vec::new();
|
||||
for _ in 0..50 {
|
||||
let choices = random_initial_choice(egraph, &mut rand::rng());
|
||||
let mut list_cache = Default::default();
|
||||
let mut expr_cache = Default::default();
|
||||
let llir = egglog_to_llir(
|
||||
egraph,
|
||||
choices,
|
||||
ops,
|
||||
custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
for op in llir.node_weights() {
|
||||
if let Some(k) = op.to_dialect::<dyn KernelOp>() {
|
||||
let name = k.kernel_name().to_string();
|
||||
if !all_names.contains(&name) {
|
||||
all_names.push(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
all_names
|
||||
}
|
||||
|
||||
/// Return every distinct `Vec<UnaryFn>` that appears inside a reachable
|
||||
/// `KernelFusedElementwise` across many random extractions. Used to verify
|
||||
/// that a specific fused configuration (e.g. a 3-op chain) is reachable.
|
||||
fn extract_all_fused_configs(cx: &mut Graph) -> Vec<Vec<UnaryFn>> {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let egraph = cx.egraph().expect("egraph not built");
|
||||
let ops = cx.egglog_ops().expect("ops not built");
|
||||
let custom_ops = &cx.custom_ops;
|
||||
|
||||
let mut all_configs: Vec<Vec<UnaryFn>> = Vec::new();
|
||||
for _ in 0..200 {
|
||||
let choices = random_initial_choice(egraph, &mut rand::rng());
|
||||
let mut list_cache = Default::default();
|
||||
let mut expr_cache = Default::default();
|
||||
let llir = egglog_to_llir(
|
||||
egraph,
|
||||
choices,
|
||||
ops,
|
||||
custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
for op in llir.node_weights() {
|
||||
if let Some(kop) = op.to_dialect::<dyn KernelOp>()
|
||||
&& let Some(fused) = (***kop).downcast_ref::<KernelFusedElementwise>()
|
||||
{
|
||||
let cfg = fused.ops().to_vec();
|
||||
if !all_configs.contains(&cfg) {
|
||||
all_configs.push(cfg);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
all_configs
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_two_unary_ops_fuse() {
|
||||
// Marker form: `a.sin().sqrt()` should fuse into a region with FusedSin
|
||||
// and FusedSqrt under one FusionEnd (per pair-fuse U→U).
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let _b = a.sin().sqrt().output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedSin", "FusedSqrt"]);
|
||||
let names = extract_all_kernel_names(&mut cx);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 1 && r.end_count == 1),
|
||||
"expected a marker region of {expected:?} with 1 FusionStart, got: {regions:#?}"
|
||||
names.iter().any(|n| n == "FusedElementwise"),
|
||||
"expected KernelSin→KernelSqrt on contiguous strides to be fusable into \
|
||||
a single FusedElementwise kernel, but reachable kernels were: {names:?}",
|
||||
);
|
||||
}
|
||||
|
||||
@@ -29,42 +97,33 @@ fn test_two_unary_ops_fuse() {
|
||||
fn test_stride_mismatch_prevents_fusion() {
|
||||
// A permute between sin and sqrt gives sqrt a non-contiguous view of sin's
|
||||
// contiguous output, so sqrt's in_strides != its out_strides and the
|
||||
// non-linear `?s ?s` match in the pair-fuse U→U rule can't fire.
|
||||
// non-linear `?strides` match in the fusion rule can't fire.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((3, 4));
|
||||
let _b = a.sin().permute((1, 0)).sqrt().output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
for r in ®ions {
|
||||
let has_sin = r.internal_ops_sorted.iter().any(|n| n == "FusedSin");
|
||||
let has_sqrt = r.internal_ops_sorted.iter().any(|n| n == "FusedSqrt");
|
||||
assert!(
|
||||
!(has_sin && has_sqrt),
|
||||
"permute between sin and sqrt must prevent them sharing a fused region, \
|
||||
but found: {r:#?}"
|
||||
);
|
||||
}
|
||||
let names = extract_all_kernel_names(&mut cx);
|
||||
assert!(
|
||||
!names.iter().any(|n| n == "FusedElementwise"),
|
||||
"a permute between sin and sqrt must prevent fusion, but \
|
||||
FusedElementwise appeared in reachable kernels: {names:?}",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reduction_prevents_unary_fusion() {
|
||||
// A reduction between two unaries is not elementwise, so pair-fuse U→U
|
||||
// (which only matches adjacent elementwise pairs) must not fire across
|
||||
// the reduction.
|
||||
// A reduction between two unaries is not elementwise, so the fusion rule
|
||||
// (which only matches unary+unary pairs) must not fire.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((4, 4));
|
||||
let _b = a.sin().sum(1).sqrt().output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
for r in ®ions {
|
||||
let has_sin = r.internal_ops_sorted.iter().any(|n| n == "FusedSin");
|
||||
let has_sqrt = r.internal_ops_sorted.iter().any(|n| n == "FusedSqrt");
|
||||
assert!(
|
||||
!(has_sin && has_sqrt),
|
||||
"reduction between sin and sqrt must prevent them sharing a fused region, \
|
||||
but found: {r:#?}"
|
||||
);
|
||||
}
|
||||
let names = extract_all_kernel_names(&mut cx);
|
||||
assert!(
|
||||
!names.iter().any(|n| n == "FusedElementwise"),
|
||||
"a reduction between sin and sqrt must prevent fusion, but \
|
||||
FusedElementwise appeared in reachable kernels: {names:?}",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -86,36 +145,31 @@ fn test_unary_fusion_preserves_output() {
|
||||
#[test]
|
||||
fn test_three_unary_ops_fuse() {
|
||||
// A chain of 3 pure-elementwise unaries with matching strides should be
|
||||
// reachable as a single marker region containing all three FusedX ops.
|
||||
// reachable as a single FusedElementwise containing all three ops.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(16);
|
||||
let _b = a.sin().sqrt().exp2().output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedSin", "FusedSqrt", "FusedExp2"]);
|
||||
let configs = extract_all_fused_configs(&mut cx);
|
||||
let expected = vec![UnaryFn::Sin, UnaryFn::Sqrt, UnaryFn::Exp2];
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 1 && r.end_count == 1),
|
||||
"expected a marker region of {expected:?} with 1 FusionStart, got: {regions:#?}"
|
||||
configs.contains(&expected),
|
||||
"expected a Fused[Sin, Sqrt, Exp2] in reachable configs, got: {configs:?}",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_four_unary_ops_fuse() {
|
||||
// 4-op chain should collapse into a single marker region containing all
|
||||
// four FusedX ops (one pair-fuse + repeated grow-FE→U firings).
|
||||
// 4-op chain should collapse into a single Fused containing all four ops.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(16);
|
||||
let _b = a.sin().sqrt().exp2().log2().output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedSin", "FusedSqrt", "FusedExp2", "FusedLog2"]);
|
||||
let configs = extract_all_fused_configs(&mut cx);
|
||||
let expected = vec![UnaryFn::Sin, UnaryFn::Sqrt, UnaryFn::Exp2, UnaryFn::Log2];
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 1 && r.end_count == 1),
|
||||
"expected a marker region of {expected:?} with 1 FusionStart, got: {regions:#?}"
|
||||
configs.contains(&expected),
|
||||
"expected a Fused[Sin, Sqrt, Exp2, Log2] in reachable configs, got: {configs:?}",
|
||||
);
|
||||
}
|
||||
|
||||
@@ -262,725 +316,3 @@ extern "C" __global__ void fused_k(float* out, const float* in, long long n) {
|
||||
speedup: {speedup:.2}x"
|
||||
);
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Binary-inclusive fusion tests (marker-based FusionStart / FusionEnd scheme).
|
||||
//
|
||||
// Detects fused regions by walking backward from each `FusionEnd`-tagged LLIR
|
||||
// node through `Direction::Incoming` edges until a `FusionStart` is reached.
|
||||
// The walker stops at FusionStarts (they mark the external-input boundary of
|
||||
// the region). A region's summary is: the sorted set of internal op names,
|
||||
// the count of distinct FusionStart nodes reached, and the count of FusionEnd
|
||||
// nodes (invariant: always 1 per region).
|
||||
// =========================================================================
|
||||
|
||||
/// A single fused region extracted from the LLIR graph after egglog.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
struct FusedRegion {
|
||||
/// Sorted internal op `kernel_name()`s, excluding the `FusionStart` /
|
||||
/// `FusionEnd` markers. Sorted so DAG traversal order doesn't produce
|
||||
/// spurious "distinct" regions.
|
||||
internal_ops_sorted: Vec<String>,
|
||||
/// Number of distinct `FusionStart` nodes reached by the walk. Per design
|
||||
/// this equals the number of distinct external input tensors.
|
||||
start_count: usize,
|
||||
/// Number of `FusionEnd` nodes in the region. Per design this is always 1.
|
||||
end_count: usize,
|
||||
}
|
||||
|
||||
/// Helper: collect every distinct fused region reachable across many random
|
||||
/// extractions of the search space.
|
||||
fn extract_all_fused_regions(cx: &mut Graph) -> Vec<FusedRegion> {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let egraph = cx.egraph().expect("egraph not built");
|
||||
let ops = cx.egglog_ops().expect("ops not built");
|
||||
let custom_ops = &cx.custom_ops;
|
||||
|
||||
let mut seen: Vec<FusedRegion> = Vec::new();
|
||||
// 200 samples: the random extractor picks one e-node per e-class per
|
||||
// call, and the fully-fused diamond form lives in an e-class with
|
||||
// many equivalent forms. 50 was flaky; 200 is reliably stable and
|
||||
// each sample is cheap (~100 µs).
|
||||
for _ in 0..200 {
|
||||
let choices = random_initial_choice(egraph, &mut rand::rng());
|
||||
let mut list_cache = Default::default();
|
||||
let mut expr_cache = Default::default();
|
||||
let llir = egglog_to_llir(
|
||||
egraph,
|
||||
choices,
|
||||
ops,
|
||||
custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
|
||||
let name_of = |idx: NodeIndex| -> Option<String> {
|
||||
llir.node_weight(idx).and_then(|op| {
|
||||
op.to_dialect::<dyn KernelOp>()
|
||||
.map(|k| k.kernel_name().to_string())
|
||||
})
|
||||
};
|
||||
|
||||
let end_nodes: Vec<NodeIndex> = llir
|
||||
.node_indices()
|
||||
.filter(|&idx| name_of(idx).as_deref() == Some("FusionEnd"))
|
||||
.collect();
|
||||
|
||||
for end in end_nodes {
|
||||
let mut internal: Vec<String> = Vec::new();
|
||||
// Count distinct external input *tensors*, not distinct FusionStart
|
||||
// node indices. Egglog rule firings can emit multiple FusionStart
|
||||
// enodes that all wrap the same source tensor (e.g. when the same
|
||||
// `a` is consumed at two sites inside the fused region, each
|
||||
// pair-fuse / grow firing mints its own FusionStart). Those are
|
||||
// logically one FusionStart per the design invariant
|
||||
// ("N = number of distinct external input tensors").
|
||||
let mut start_sources: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
let mut visited: FxHashSet<NodeIndex> = FxHashSet::default();
|
||||
visited.insert(end);
|
||||
let mut stack = vec![end];
|
||||
|
||||
// Resolve chains of nested FusionStart wrappers (cascade artifact)
|
||||
// to the real external source. A FusionStart whose incoming neighbor
|
||||
// is itself a FusionStart — or a FusionEnd whose region is fully
|
||||
// inside ours — is a cascade layer, not a new external tensor.
|
||||
let resolve_source = |mut n: NodeIndex| -> NodeIndex {
|
||||
loop {
|
||||
match name_of(n).as_deref() {
|
||||
Some("FusionStart") | Some("FusionEnd") => {
|
||||
let mut inc = llir.neighbors_directed(n, petgraph::Direction::Incoming);
|
||||
match inc.next() {
|
||||
Some(p) => n = p,
|
||||
None => return n,
|
||||
}
|
||||
}
|
||||
_ => return n,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
while let Some(node) = stack.pop() {
|
||||
for pred in llir.neighbors_directed(node, petgraph::Direction::Incoming) {
|
||||
if !visited.insert(pred) {
|
||||
continue;
|
||||
}
|
||||
match name_of(pred).as_deref() {
|
||||
Some("FusionStart") => {
|
||||
// If this FS's predecessor is itself a FE (or a
|
||||
// chain of FS/FE wrappers that eventually hits a
|
||||
// non-marker op inside the region), the FS is a
|
||||
// cascade artifact, not a real external boundary.
|
||||
// Walk past it and its upstream FE into the same
|
||||
// region. Otherwise treat the predecessor as the
|
||||
// external source tensor — which may be a KernelOp
|
||||
// *or* a non-KernelOp (HLIR loadable) node, so we
|
||||
// can't gate counting on `name_of` being `Some`.
|
||||
let mut inc =
|
||||
llir.neighbors_directed(pred, petgraph::Direction::Incoming);
|
||||
match inc.next() {
|
||||
Some(src_node)
|
||||
if name_of(src_node).as_deref() == Some("FusionEnd") =>
|
||||
{
|
||||
// Merge adjacent regions — treat the FS/FE
|
||||
// pair as internal; walk past the upstream
|
||||
// FE into its region.
|
||||
visited.insert(src_node);
|
||||
stack.push(src_node);
|
||||
}
|
||||
Some(src_node) => {
|
||||
start_sources.insert(resolve_source(src_node));
|
||||
}
|
||||
None => {
|
||||
// FS with no predecessor — degenerate.
|
||||
}
|
||||
}
|
||||
}
|
||||
Some("FusionEnd") => {
|
||||
// Transparent: inner FusionEnds are cascade-wart
|
||||
// artifacts from grow rules re-firing and creating
|
||||
// nested `FE(Op(FE(...)))` wrappers. They don't
|
||||
// represent real work or a real boundary — walk
|
||||
// past them and do not count them as internal ops.
|
||||
stack.push(pred);
|
||||
}
|
||||
Some(other) => {
|
||||
internal.push(other.to_string());
|
||||
stack.push(pred);
|
||||
}
|
||||
None => {
|
||||
// Non-KernelOp predecessor (shouldn't appear inside a
|
||||
// fused region under the design). Stop walking this path.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal.sort();
|
||||
// Skip singleton regions: every elementwise op has a seeded
|
||||
// `FE(Op(FS(...)))` form, so random extraction will surface
|
||||
// many one-op regions that are equivalent to not fusing. We
|
||||
// only care about regions that represent real multi-op fusion.
|
||||
if internal.len() < 2 {
|
||||
continue;
|
||||
}
|
||||
let region = FusedRegion {
|
||||
internal_ops_sorted: internal,
|
||||
start_count: start_sources.len(),
|
||||
end_count: 1,
|
||||
};
|
||||
if !seen.contains(®ion) {
|
||||
seen.push(region);
|
||||
}
|
||||
}
|
||||
}
|
||||
seen
|
||||
}
|
||||
|
||||
fn sorted_names(items: &[&str]) -> Vec<String> {
|
||||
let mut v: Vec<String> = items.iter().map(|s| (*s).to_string()).collect();
|
||||
v.sort();
|
||||
v
|
||||
}
|
||||
|
||||
// ---- Structural tests: the expected fused shape is reachable ----
|
||||
|
||||
#[test]
|
||||
fn test_single_binary_does_not_fuse_alone() {
|
||||
// A lone elementwise op gets a seeded singleton region by design; we
|
||||
// filter singletons out in `extract_all_fused_regions`. What this test
|
||||
// asserts is that no *multi-op* region appears for a standalone binary
|
||||
// — nothing to grow into.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let _c = (a + b).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
assert!(
|
||||
regions.is_empty(),
|
||||
"a solo binary op should not form a multi-op fused region, but got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chain_of_binaries_fuses() {
|
||||
// `(a + b) * c`: three external inputs collapse into one region with
|
||||
// internal [Add, Mul] and 3 FusionStarts.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let c = cx.tensor(8);
|
||||
let _d = ((a + b) * c).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedMul"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 3),
|
||||
"expected a fused region of {expected:?} with 3 FusionStarts, got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_then_unary_fuses() {
|
||||
// `sin(a + b)`: binary feeds a unary inside one fused region.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let _c = (a + b).sin().output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedSin"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 2),
|
||||
"expected a fused region of {expected:?} with 2 FusionStarts, got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unary_then_binary_fuses() {
|
||||
// `sin(a) + b`: unary feeds a binary inside one fused region.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let _c = (a.sin() + b).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedSin"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 2),
|
||||
"expected a fused region of {expected:?} with 2 FusionStarts, got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diamond_dag_fuses() {
|
||||
// The canonical diamond-DAG example agreed with the user:
|
||||
// t = a + b; u = exp2(t); v = sin(t); w = u * a; out = w + v
|
||||
// `a` is reused (feeds outer Add and Mul) and `t` is reused (feeds Exp2 and
|
||||
// Sin). Expected: one fused region with internal ops [Add, Add, Exp2, Mul,
|
||||
// Sin], 2 FusionStarts (distinct tensors a, b), 1 FusionEnd.
|
||||
// We use exp2 rather than exp because the frontend's exp() desugars to
|
||||
// Mul(x, LOG2E).exp2(), which would add a constant input and a Mul op and
|
||||
// obscure the diamond topology this test is checking.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let t = a + b;
|
||||
let u = t.exp2();
|
||||
let v = t.sin();
|
||||
let w = u * a;
|
||||
let _out = (w + v).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedAdd", "FusedExp2", "FusedMul", "FusedSin"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 2 && r.end_count == 1),
|
||||
"expected diamond DAG to fuse into one region with ops {expected:?}, \
|
||||
2 FusionStarts, 1 FusionEnd. Got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
// ---- Negative tests: fusion must NOT happen across these blockers ----
|
||||
|
||||
#[test]
|
||||
fn test_reduction_blocks_binary_fusion() {
|
||||
// A reduction between a binary and anything downstream is not elementwise,
|
||||
// so Add and SumReduce must never appear in the same fused region.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((4, 4));
|
||||
let b = cx.tensor((4, 4));
|
||||
let _c = (a + b).sum(1).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
for r in ®ions {
|
||||
let has_add = r.internal_ops_sorted.iter().any(|n| n == "FusedAdd");
|
||||
let has_sum = r.internal_ops_sorted.iter().any(|n| n == "SumReduce");
|
||||
assert!(
|
||||
!(has_add && has_sum),
|
||||
"FusedAdd and SumReduce must not share a fused region, but got: {r:#?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stride_mismatch_blocks_binary_fusion() {
|
||||
// A permute gives `b` a non-contiguous view whose strides do not match `a`'s,
|
||||
// so the binary fusion rule's stride-compatibility check must prevent the
|
||||
// Add from being absorbed into any fused region.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((3, 4));
|
||||
let b = cx.tensor((4, 3));
|
||||
let _c = (a + b.permute((1, 0))).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
for r in ®ions {
|
||||
assert!(
|
||||
!r.internal_ops_sorted.iter().any(|n| n == "FusedAdd"),
|
||||
"permuted binary must not fuse into a region, but found: {r:#?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Numerical parity tests: fused output matches candle reference ----
|
||||
|
||||
#[test]
|
||||
fn test_simple_binary_fusion_preserves_output() {
|
||||
// End-to-end numerical check: `a + b` on GPU matches candle's add across
|
||||
// all reachable genomes (fused or unfused) via test_binary_cuda's fuzzer.
|
||||
let seed = 0xADDBEEFu64;
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
let tol = eps * TOLERANCE_SAFETY_FACTOR;
|
||||
test_binary_cuda::<f32>(
|
||||
16,
|
||||
16,
|
||||
|a, b| a + b,
|
||||
|a, b| (a + b).unwrap(),
|
||||
|n, s| random_f32_vec(n, s, 0.0, 1.0),
|
||||
|n, s| random_f32_vec(n, s, 0.0, 1.0),
|
||||
seed,
|
||||
tol,
|
||||
tol,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diamond_dag_preserves_output() {
|
||||
// Numerical parity for the diamond DAG: `(exp(a+b) * a) + sin(a+b)`
|
||||
// matches candle's equivalent across fused and unfused genomes.
|
||||
// Inputs are drawn from [-1, 1] so exp() doesn't overflow.
|
||||
let seed = 0xD1A_0D1Au64;
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
// Five-op chain with exp + sin: allow ~5x safety to absorb accumulated
|
||||
// rounding vs candle's kernels.
|
||||
let tol = eps * TOLERANCE_SAFETY_FACTOR * 5.0;
|
||||
test_binary_cuda::<f32>(
|
||||
16,
|
||||
16,
|
||||
|a, b| {
|
||||
let t = a + b;
|
||||
let u = t.exp();
|
||||
let v = t.sin();
|
||||
let w = u * a;
|
||||
w + v
|
||||
},
|
||||
|a, b| {
|
||||
let t = (&a + &b).unwrap();
|
||||
let u = t.exp().unwrap();
|
||||
let v = t.sin().unwrap();
|
||||
let w = (&u * &a).unwrap();
|
||||
(&w + &v).unwrap()
|
||||
},
|
||||
|n, s| random_f32_vec(n, s, -1.0, 1.0),
|
||||
|n, s| random_f32_vec(n, s, -1.0, 1.0),
|
||||
seed,
|
||||
tol,
|
||||
tol,
|
||||
);
|
||||
}
|
||||
|
||||
// ---- Marker invariant tests ----
|
||||
|
||||
#[test]
|
||||
fn test_fused_region_has_exactly_one_end() {
|
||||
// Design invariant: a fused region always has exactly one FusionEnd.
|
||||
// Uses the diamond DAG so there's real fan-in/out inside the region.
|
||||
// See test_diamond_dag_fuses for why we use exp2 directly.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let t = a + b;
|
||||
let u = t.exp2();
|
||||
let v = t.sin();
|
||||
let w = u * a;
|
||||
let _out = (w + v).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedAdd", "FusedExp2", "FusedMul", "FusedSin"]);
|
||||
let full = regions
|
||||
.iter()
|
||||
.find(|r| r.internal_ops_sorted == expected)
|
||||
.expect("expected at least one extraction to produce the full 5-op diamond region");
|
||||
assert_eq!(
|
||||
full.end_count, 1,
|
||||
"fused region must have exactly one FusionEnd, got {}",
|
||||
full.end_count
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fused_region_starts_match_distinct_external_tensors() {
|
||||
// Design invariant: FusionStart count == number of distinct external input
|
||||
// tensors, NOT number of edges crossing the boundary. In the diamond DAG
|
||||
// `a` is consumed inside the region by two ops (outer Add + Mul), so a
|
||||
// per-edge counting scheme would give 3; the correct per-distinct-tensor
|
||||
// count is 2 ({a, b}).
|
||||
// See test_diamond_dag_fuses for why we use exp2 directly.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let t = a + b;
|
||||
let u = t.exp2();
|
||||
let v = t.sin();
|
||||
let w = u * a;
|
||||
let _out = (w + v).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedAdd", "FusedExp2", "FusedMul", "FusedSin"]);
|
||||
// Multiple 5-op extractions are reachable: the merge-FE-FE rule fires
|
||||
// across paths that may have minted distinct FS enodes for the shared
|
||||
// tensor `a` at separate sites. The design invariant is that *some*
|
||||
// extraction collapses those into the deduped form (one FS per distinct
|
||||
// tensor → 2 FS for {a, b}); we don't require every random sample to.
|
||||
let matching: Vec<&FusedRegion> = regions
|
||||
.iter()
|
||||
.filter(|r| r.internal_ops_sorted == expected)
|
||||
.collect();
|
||||
assert!(
|
||||
!matching.is_empty(),
|
||||
"expected at least one extraction to produce the full 5-op diamond region, \
|
||||
got: {regions:#?}"
|
||||
);
|
||||
assert!(
|
||||
matching
|
||||
.iter()
|
||||
.any(|r| r.start_count == 2 && r.end_count == 1),
|
||||
"expected at least one 5-op diamond extraction with FusionStart count == 2 \
|
||||
(one per distinct external tensor) and FusionEnd count == 1; got: {matching:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
// ---- Targeted rule-family tests (one per family / orientation) ----
|
||||
//
|
||||
// The structural and diamond tests above hit several rule families at once.
|
||||
// These narrow tests pin each rule family / orientation independently so a
|
||||
// regression in one rule shows up as a single failing test rather than a
|
||||
// confusing diamond mismatch.
|
||||
|
||||
#[test]
|
||||
fn test_pair_fuse_unary_unary_marker_form() {
|
||||
// Pair-fuse U→U: `a.sin().sqrt()` should be reachable as a marker-bracketed
|
||||
// region containing FusedSin and FusedSqrt (with one FusionStart for `a`).
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let _b = a.sin().sqrt().output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedSin", "FusedSqrt"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 1 && r.end_count == 1),
|
||||
"expected marker region of {expected:?} with 1 FusionStart, got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pair_fuse_unary_to_binary_rhs() {
|
||||
// Pair-fuse U→B (RHS variant): `a + b.sin()`. The unary is on the
|
||||
// binary's B input, so the rule's RHS-orientation version is what fires.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let _c = (a + b.sin()).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedSin"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 2),
|
||||
"expected a fused region of {expected:?} with 2 FusionStarts (RHS-side unary), \
|
||||
got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pair_fuse_binary_to_binary_rhs() {
|
||||
// Pair-fuse B→B (RHS variant): `c * (a + b)`. The inner binary feeds the
|
||||
// outer binary's B input, exercising the mirror direction of the rule
|
||||
// covered by test_chain_of_binaries_fuses.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let c = cx.tensor(8);
|
||||
let _d = (c * (a + b)).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedMul"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 3),
|
||||
"expected a fused region of {expected:?} with 3 FusionStarts (RHS-side inner binary), \
|
||||
got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_grow_fe_to_binary_rhs() {
|
||||
// Grow FE→B (RHS variant): `c + (a.sin() + b)`. Once the inner
|
||||
// `a.sin() + b` is fused, the outer `+ c` consumes that FE on its B input
|
||||
// (because we wrote `c + (...)` — `c` is on LHS, FE on RHS), exercising
|
||||
// grow-FE-B-rhs to absorb the outer Add into the same region.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let c = cx.tensor(8);
|
||||
let _d = (c + (a.sin() + b)).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedAdd", "FusedSin"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 3),
|
||||
"expected a 3-op fused region of {expected:?} with 3 FusionStarts (grow into RHS), \
|
||||
got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_two_regions_at_outer_binary() {
|
||||
// Merge: `(sin(a) + b) + (sqrt(c) + d)`. Each side independently pair-fuses
|
||||
// U→B on its own (the unary gives the inner Add a fusion partner that
|
||||
// doesn't pull in the outer Add), so both sides become FEs. The outer Add
|
||||
// then fires merge-FE-FE-Add to collapse them into a single region.
|
||||
// Without the unaries, `(a+b) + (c+d)` would only ever pair-fuse one
|
||||
// inner Add at a time with the outer Add — merge wouldn't have two FEs to
|
||||
// combine because the inner Adds never become singleton FEs on their own.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
let c = cx.tensor(8);
|
||||
let d = cx.tensor(8);
|
||||
let _e = ((a.sin() + b) + (c.sqrt() + d)).output();
|
||||
|
||||
let regions = extract_all_fused_regions(&mut cx);
|
||||
let expected = sorted_names(&["FusedAdd", "FusedAdd", "FusedAdd", "FusedSin", "FusedSqrt"]);
|
||||
assert!(
|
||||
regions
|
||||
.iter()
|
||||
.any(|r| r.internal_ops_sorted == expected && r.start_count == 4),
|
||||
"expected a 5-op merged region (two pair-fused sides combined at outer Add) with \
|
||||
4 FusionStarts, got: {regions:#?}"
|
||||
);
|
||||
}
|
||||
|
||||
/// Microbench: time three unfused kernels (`add_k` → `sin_k` → `sqrt_k`)
|
||||
/// vs one fused kernel (`(a + b).sin().sqrt()` in a single launch) on a
|
||||
/// fixed-size input, using CUDA events for device-side timing. Mirrors
|
||||
/// the existing sqrt→recip bench but on the binary-inclusive 3-op DAG
|
||||
/// PR2's region codegen targets.
|
||||
///
|
||||
/// Ignored by default — run with
|
||||
/// `cargo test -p luminal_cuda_lite -- --ignored bench_fused_region_vs_unfused_3op --nocapture`.
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn bench_fused_region_vs_unfused_3op() {
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use cudarc::driver::{CudaContext, LaunchConfig, PushKernelArg};
|
||||
|
||||
const N: usize = 1 << 20; // 1M elements
|
||||
const WARMUP: usize = 100;
|
||||
const TRIALS: usize = 2000;
|
||||
|
||||
let ctx = match CudaContext::new(0) {
|
||||
Ok(c) => c,
|
||||
Err(_) => return, // no GPU available, skip
|
||||
};
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
// Inputs in (0, 1] keep `sin` < 1 and `sqrt` well-defined post-add.
|
||||
let host_a: Vec<f32> = (0..N)
|
||||
.map(|i| (i as f32 + 1.0) / (N as f32) * 0.5)
|
||||
.collect();
|
||||
let host_b: Vec<f32> = (0..N)
|
||||
.map(|i| (i as f32 + 1.0) / (N as f32) * 0.5)
|
||||
.collect();
|
||||
let d_a = stream.clone_htod(&host_a).unwrap();
|
||||
let d_b = stream.clone_htod(&host_b).unwrap();
|
||||
let mut d_scratch1 = stream.alloc_zeros::<f32>(N).unwrap();
|
||||
let mut d_scratch2 = stream.alloc_zeros::<f32>(N).unwrap();
|
||||
let mut d_out = stream.alloc_zeros::<f32>(N).unwrap();
|
||||
|
||||
let compile = |src: &str, name: &str| {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), src).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
module.load_function(name).unwrap()
|
||||
};
|
||||
|
||||
let add_k = compile(
|
||||
r#"
|
||||
extern "C" __global__ void add_k(float* out, const float* a, const float* b, long long n) {
|
||||
long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= n) return;
|
||||
out[i] = a[i] + b[i];
|
||||
}
|
||||
"#,
|
||||
"add_k",
|
||||
);
|
||||
let sin_k = compile(
|
||||
r#"
|
||||
extern "C" __global__ void sin_k(float* out, const float* in, long long n) {
|
||||
long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= n) return;
|
||||
out[i] = sinf(in[i]);
|
||||
}
|
||||
"#,
|
||||
"sin_k",
|
||||
);
|
||||
let sqrt_k = compile(
|
||||
r#"
|
||||
extern "C" __global__ void sqrt_k(float* out, const float* in, long long n) {
|
||||
long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= n) return;
|
||||
out[i] = sqrtf(in[i]);
|
||||
}
|
||||
"#,
|
||||
"sqrt_k",
|
||||
);
|
||||
let fused_k = compile(
|
||||
r#"
|
||||
extern "C" __global__ void fused_k(float* out, const float* a, const float* b, long long n) {
|
||||
long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= n) return;
|
||||
float v = a[i] + b[i];
|
||||
v = sinf(v);
|
||||
v = sqrtf(v);
|
||||
out[i] = v;
|
||||
}
|
||||
"#,
|
||||
"fused_k",
|
||||
);
|
||||
|
||||
let cfg = LaunchConfig::for_num_elems(N as u32);
|
||||
let n_arg: i64 = N as i64;
|
||||
|
||||
let launch_unfused =
|
||||
|d_out: &mut cudarc::driver::CudaSlice<f32>,
|
||||
d_scratch1: &mut cudarc::driver::CudaSlice<f32>,
|
||||
d_scratch2: &mut cudarc::driver::CudaSlice<f32>| {
|
||||
let mut b = stream.launch_builder(&add_k);
|
||||
b.arg(&mut *d_scratch1).arg(&d_a).arg(&d_b).arg(&n_arg);
|
||||
unsafe { b.launch(cfg) }.unwrap();
|
||||
let mut b = stream.launch_builder(&sin_k);
|
||||
b.arg(&mut *d_scratch2).arg(&*d_scratch1).arg(&n_arg);
|
||||
unsafe { b.launch(cfg) }.unwrap();
|
||||
let mut b = stream.launch_builder(&sqrt_k);
|
||||
b.arg(d_out).arg(&*d_scratch2).arg(&n_arg);
|
||||
unsafe { b.launch(cfg) }.unwrap();
|
||||
};
|
||||
let launch_fused = |d_out: &mut cudarc::driver::CudaSlice<f32>| {
|
||||
let mut b = stream.launch_builder(&fused_k);
|
||||
b.arg(d_out).arg(&d_a).arg(&d_b).arg(&n_arg);
|
||||
unsafe { b.launch(cfg) }.unwrap();
|
||||
};
|
||||
|
||||
// Warmup
|
||||
for _ in 0..WARMUP {
|
||||
launch_unfused(&mut d_out, &mut d_scratch1, &mut d_scratch2);
|
||||
launch_fused(&mut d_out);
|
||||
}
|
||||
stream.synchronize().unwrap();
|
||||
|
||||
// Host-side wall-clock timing: synchronize before/after each batch so the
|
||||
// measured interval covers exactly the GPU work for `TRIALS` iterations.
|
||||
// (CUDA event-based timing is the more precise option in principle, but
|
||||
// `event.elapsed_ms` on this driver/cudarc combo errors with
|
||||
// CUDA_ERROR_INVALID_HANDLE — see bench_fused_vs_unfused_sqrt_recip
|
||||
// above which fails the same way. Wall-clock is reliable here.)
|
||||
let unfused_start = std::time::Instant::now();
|
||||
for _ in 0..TRIALS {
|
||||
launch_unfused(&mut d_out, &mut d_scratch1, &mut d_scratch2);
|
||||
}
|
||||
stream.synchronize().unwrap();
|
||||
let unfused_total_ms = unfused_start.elapsed().as_secs_f64() * 1_000.0;
|
||||
|
||||
let fused_start = std::time::Instant::now();
|
||||
for _ in 0..TRIALS {
|
||||
launch_fused(&mut d_out);
|
||||
}
|
||||
stream.synchronize().unwrap();
|
||||
let fused_total_ms = fused_start.elapsed().as_secs_f64() * 1_000.0;
|
||||
|
||||
let unfused_us = unfused_total_ms * 1_000.0 / TRIALS as f64;
|
||||
let fused_us = fused_total_ms * 1_000.0 / TRIALS as f64;
|
||||
let speedup = unfused_us / fused_us;
|
||||
|
||||
println!(
|
||||
"\n[fusion microbench, (a+b).sin().sqrt(), N={N}, trials={TRIALS}]\n\
|
||||
unfused (add_k; sin_k; sqrt_k): {unfused_us:8.3} us/iter ({unfused_total_ms:.2} ms total)\n\
|
||||
fused (one kernel): {fused_us:8.3} us/iter ({fused_total_ms:.2} ms total)\n\
|
||||
speedup: {speedup:.2}x"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -5,10 +5,6 @@ mod bucket_tests;
|
||||
#[cfg(test)]
|
||||
mod consumed_buffer_tests;
|
||||
#[cfg(test)]
|
||||
mod cublaslt_rewrite_tests;
|
||||
#[cfg(test)]
|
||||
mod flashinfer;
|
||||
#[cfg(test)]
|
||||
mod fusion;
|
||||
#[cfg(test)]
|
||||
mod model_fuzz;
|
||||
|
||||
@@ -1,12 +1,7 @@
|
||||
//! Fuzz tests for model-architecture-specific subgraphs (Llama, Gemma, Qwen).
|
||||
//!
|
||||
//! Tests many random e-graph extraction variants (genomes) against a candle CPU
|
||||
//! reference to catch incorrect HLIR kernel rewrites.
|
||||
//!
|
||||
//! These are marked ignored by default because each test builds a model-shaped
|
||||
//! graph and checks many extraction genomes. Run them explicitly with
|
||||
//! `cargo test -p luminal_cuda_lite -- --ignored` when touching extraction,
|
||||
//! scheduling, or model-pattern rewrites.
|
||||
//! reference to catch incorrect HLIR kernel fallback rewrites.
|
||||
|
||||
use luminal::prelude::*;
|
||||
|
||||
@@ -382,38 +377,32 @@ mod llama {
|
||||
const EPS: f32 = 1e-5;
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_llama_mlp() {
|
||||
fuzz_mlp(SEQ, HIDDEN, INTERMEDIATE, 42);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_llama_norm_proj() {
|
||||
fuzz_norm_proj(SEQ, HIDDEN, PROJ_DIM, EPS, 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_llama_layer() {
|
||||
fuzz_layer_no_attn(SEQ, HIDDEN, INTERMEDIATE, PROJ_DIM, EPS, 200);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_llama_mlp_seq1() {
|
||||
fuzz_mlp(1, HIDDEN, INTERMEDIATE, 300);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_llama_mlp_seq7() {
|
||||
fuzz_mlp(7, HIDDEN, INTERMEDIATE, 400);
|
||||
}
|
||||
|
||||
/// Force HLIR-only (no block ops) to specifically test that extraction path.
|
||||
/// Force HLIR-only (no block ops) to specifically test the fallback path.
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_llama_mlp_hlir_only() {
|
||||
fuzz_mlp_hlir_only(SEQ, HIDDEN, INTERMEDIATE, 450);
|
||||
}
|
||||
@@ -435,26 +424,22 @@ mod gemma {
|
||||
const EPS: f32 = 1e-6;
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_gemma_mlp() {
|
||||
fuzz_mlp(SEQ, HIDDEN, INTERMEDIATE, 500);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_gemma_norm_proj() {
|
||||
fuzz_norm_proj(SEQ, HIDDEN, Q_DIM, EPS, 600);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_gemma_layer() {
|
||||
fuzz_layer_no_attn(SEQ, HIDDEN, INTERMEDIATE, Q_DIM, EPS, 700);
|
||||
}
|
||||
|
||||
/// Gemma has extra post-attention and post-feedforward norms.
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_gemma_layer_full_norms() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
@@ -579,14 +564,12 @@ mod gemma {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_gemma_mlp_seq1() {
|
||||
fuzz_mlp(1, HIDDEN, INTERMEDIATE, 900);
|
||||
}
|
||||
|
||||
/// Force HLIR-only to test that extraction path with Gemma dimensions.
|
||||
/// Force HLIR-only to test fallback path with Gemma dimensions.
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_gemma_mlp_hlir_only() {
|
||||
fuzz_mlp_hlir_only(SEQ, HIDDEN, INTERMEDIATE, 950);
|
||||
}
|
||||
@@ -608,26 +591,22 @@ mod qwen {
|
||||
const EPS: f32 = 1e-6;
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_qwen_mlp() {
|
||||
fuzz_mlp(SEQ, HIDDEN, INTERMEDIATE, 1000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_qwen_norm_proj() {
|
||||
fuzz_norm_proj(SEQ, HIDDEN, Q_DIM, EPS, 1100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_qwen_layer() {
|
||||
fuzz_layer_no_attn(SEQ, HIDDEN, INTERMEDIATE, Q_DIM, EPS, 1200);
|
||||
}
|
||||
|
||||
/// Qwen uses tied embeddings: lm_head = embedding^T
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_qwen_lm_head() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
@@ -689,20 +668,17 @@ mod qwen {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_qwen_mlp_seq1() {
|
||||
fuzz_mlp(1, HIDDEN, INTERMEDIATE, 1400);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_qwen_mlp_seq7() {
|
||||
fuzz_mlp(7, HIDDEN, INTERMEDIATE, 1500);
|
||||
}
|
||||
|
||||
/// Force HLIR-only to test that extraction path with Qwen dimensions.
|
||||
/// Force HLIR-only to test fallback path with Qwen dimensions.
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA model genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn fuzz_qwen_mlp_hlir_only() {
|
||||
fuzz_mlp_hlir_only(SEQ, HIDDEN, INTERMEDIATE, 1550);
|
||||
}
|
||||
|
||||
@@ -16,16 +16,9 @@ use super::utilities::{
|
||||
test_binary_cuda, test_mod, test_unary_cuda, to_candle_dtype,
|
||||
};
|
||||
|
||||
// The property-based op tests each build/search CUDA graphs for multiple random
|
||||
// shapes. They are ignored by default to keep the main CUDA unit suite short;
|
||||
// run `cargo test -p luminal_cuda_lite -- --ignored` for the broader sweeps.
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(5))]
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_add(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
@@ -35,9 +28,6 @@ proptest! {
|
||||
test_binary_cuda((y, x), (y, x), |a, b| a + b, |a, b| (&a + &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_mul(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
@@ -47,27 +37,18 @@ proptest! {
|
||||
test_binary_cuda((y, x), (y, x), |a, b| a * b, |a, b| (&a * &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_max(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
test_unary_cuda((rows, cols), |a| a.max(1), |a| a.max(1).unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_mean(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
test_unary_cuda((rows, cols), |a| a.mean(1), |a| a.mean(1).unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_matmul(
|
||||
(m, n, k, a_col_major, b_col_major, m_slice, k_slice, n_slice, dtype) in
|
||||
@@ -138,8 +119,6 @@ proptest! {
|
||||
}
|
||||
|
||||
// Unary ops tests
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
#[test]
|
||||
fn test_exp2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
// exp2(x) = 2^x, verified by computing 2^x using exp(x * ln(2))
|
||||
@@ -148,9 +127,6 @@ proptest! {
|
||||
test_unary_cuda((y, x), |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_log2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
// log2(x) = ln(x) / ln(2)
|
||||
@@ -159,9 +135,6 @@ proptest! {
|
||||
test_unary_cuda((y, x), |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sin(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
@@ -169,9 +142,6 @@ proptest! {
|
||||
test_unary_cuda((y, x), |a| a.sin(), |a| a.sin().unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_recip(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.5);
|
||||
@@ -179,9 +149,6 @@ proptest! {
|
||||
test_unary_cuda((y, x), |a| a.reciprocal(), |a| a.recip().unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sqrt(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.6);
|
||||
@@ -190,17 +157,12 @@ proptest! {
|
||||
}
|
||||
|
||||
// Binary ops tests
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
#[test]
|
||||
fn test_mod_op(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
test_mod(x, x, |a, b| a % b, seed);
|
||||
test_mod((y, x), (y, x), |a, b| a % b, seed);
|
||||
}
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_less_than(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -99.0, 100.0).into_iter().map(|v| v.floor()).collect();
|
||||
@@ -373,8 +335,6 @@ proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(5))]
|
||||
|
||||
/// Test F32 -> F16 -> F32 cast roundtrip with random values.
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
#[test]
|
||||
fn test_cast_f16_random(size in 1usize..200, seed in any::<u64>()) {
|
||||
use luminal::dtype::DType;
|
||||
@@ -567,9 +527,6 @@ fn fuzz_test_cuda_genomes_impl(seed: u64) {
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(3))]
|
||||
|
||||
// This walks random extraction genomes and is intentionally opt-in so the
|
||||
// default CUDA unit suite keeps a tight feedback loop.
|
||||
#[ignore = "expensive CUDA genome fuzzing; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
#[test]
|
||||
fn fuzz_test_cuda_genomes(seed in any::<u64>()) {
|
||||
fuzz_test_cuda_genomes_impl(seed);
|
||||
@@ -637,9 +594,6 @@ fn run_embed_test(vocab_size: usize, embed_dim: usize, seq_len: usize, seed: u64
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(5))]
|
||||
|
||||
#[ignore = "expensive CUDA op proptest sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_embed_proptest(
|
||||
vocab_size in 10usize..200,
|
||||
|
||||
@@ -3,7 +3,10 @@ use luminal::{dtype::DType, prelude::*, shape::Expression};
|
||||
|
||||
use super::utilities::{assert_close, get_cuda_stream, random_f32_vec};
|
||||
use crate::{
|
||||
host::moe::{GLUMoE, GLUMoEMode},
|
||||
host::{
|
||||
HostOp,
|
||||
moe::{GLUMoE, GLUMoEMode},
|
||||
},
|
||||
runtime::CudaRuntime,
|
||||
};
|
||||
|
||||
@@ -71,9 +74,9 @@ fn build_qwen_moe_graph() -> QwenMoeGraph {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let output = (down_out * weights_exp).sum(n - 1).output();
|
||||
let output = (down_out * top_k_values.unsqueeze(top_k_values.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
|
||||
QwenMoeGraph {
|
||||
graph: cx,
|
||||
@@ -130,9 +133,9 @@ fn build_gemma_moe_graph() -> GemmaMoeGraph {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let output = (down_out * weights_exp).sum(n - 1).output();
|
||||
let output = (down_out * top_k_weights.unsqueeze(top_k_weights.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
|
||||
GemmaMoeGraph {
|
||||
graph: cx,
|
||||
@@ -173,9 +176,10 @@ fn gemma_gelu(x: GraphTensor) -> GraphTensor {
|
||||
}
|
||||
|
||||
fn glumoe_modes(rt: &CudaRuntime) -> Vec<GLUMoEMode> {
|
||||
rt.host_ops()
|
||||
.into_iter()
|
||||
.filter_map(|op| {
|
||||
rt.llir_graph()
|
||||
.node_weights()
|
||||
.filter_map(|node| {
|
||||
let op = node.to_dialect::<dyn HostOp>()?;
|
||||
op.as_any()
|
||||
.downcast_ref::<GLUMoE>()
|
||||
.map(|glumoe| glumoe.mode)
|
||||
|
||||
@@ -136,15 +136,14 @@ pub fn gpu_compute_cap() -> Option<(i32, i32)> {
|
||||
|
||||
/// Check if the current GPU supports the given dtype for tensor core / WMMA operations.
|
||||
pub fn gpu_supports_dtype(dtype: luminal::dtype::DType) -> bool {
|
||||
let Some((major, minor)) = gpu_compute_cap() else {
|
||||
let Some((major, _)) = gpu_compute_cap() else {
|
||||
return false;
|
||||
};
|
||||
match dtype {
|
||||
luminal::dtype::DType::Bf16 => major >= 8, // Ampere (sm_80+)
|
||||
luminal::dtype::DType::F8E4M3 | luminal::dtype::DType::F8E5M2 => {
|
||||
major > 8 || (major == 8 && minor >= 9)
|
||||
} // Ada/Hopper (sm_89+)
|
||||
luminal::dtype::DType::F4E2M1 | luminal::dtype::DType::F8UE8M0 => major >= 10, // Blackwell (sm_100+)
|
||||
luminal::dtype::DType::F4E2M1
|
||||
| luminal::dtype::DType::F8E4M3
|
||||
| luminal::dtype::DType::F8UE8M0 => major >= 10, // Blackwell (sm_100+)
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,21 +102,6 @@ fn metal_copy_value(dtype: DType, buffer: &str, index: &str) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
fn metal_binary_op_values(
|
||||
output_dtype: DType,
|
||||
a_dtype: DType,
|
||||
b_dtype: DType,
|
||||
a_idx: &str,
|
||||
b_idx: &str,
|
||||
) -> (String, String) {
|
||||
let read: fn(DType, &str, &str) -> String = if output_dtype == DType::Int {
|
||||
metal_copy_value
|
||||
} else {
|
||||
metal_numeric_read
|
||||
};
|
||||
(read(a_dtype, "a", a_idx), read(b_dtype, "b", b_idx))
|
||||
}
|
||||
|
||||
fn call_sort_from_args(sort: &SortDef, args: &Args) -> EggTerm {
|
||||
let mut filtered_args = Args::new();
|
||||
for field in &sort.fields {
|
||||
@@ -132,11 +117,9 @@ fn unary_dtype_rewrite(hlir_sort: &SortDef, metal_sort: &SortDef) -> Rule {
|
||||
args["__inputs"].clone(),
|
||||
);
|
||||
let dt = v("?__dt");
|
||||
rule(union(hlir_match.clone(), metal_op.clone()))
|
||||
.subsume(hlir_match)
|
||||
rule(union(hlir_match, metal_op.clone()))
|
||||
.set(dtype(metal_op), dt.clone())
|
||||
.fact(eq(dt, dtype(args["inp"].clone())))
|
||||
.ruleset("kernel_lower")
|
||||
}
|
||||
|
||||
fn binary_dtype_rewrite(hlir_sort: &SortDef, metal_sort: &SortDef) -> Rule {
|
||||
@@ -146,11 +129,9 @@ fn binary_dtype_rewrite(hlir_sort: &SortDef, metal_sort: &SortDef) -> Rule {
|
||||
args["__inputs"].clone(),
|
||||
);
|
||||
let dt = v("?__dt");
|
||||
rule(union(hlir_match.clone(), metal_op.clone()))
|
||||
.subsume(hlir_match)
|
||||
rule(union(hlir_match, metal_op.clone()))
|
||||
.set(dtype(metal_op), dt.clone())
|
||||
.fact(eq(dt, dtype(args["inp_a"].clone())))
|
||||
.ruleset("kernel_lower")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
@@ -304,7 +285,7 @@ macro_rules! metal_unary_op {
|
||||
device {input_ty} *inp [[buffer(0)]],
|
||||
device {output_ty} *out [[buffer(1)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_elements [[buffer({n_elements_index})]],
|
||||
device uint &n_elements [[buffer({n_elements_index})]],
|
||||
uint idx [[thread_position_in_grid]]
|
||||
) {{
|
||||
if (idx < n_elements) {{
|
||||
@@ -388,10 +369,8 @@ impl EgglogOp for MetalAdd {
|
||||
|
||||
vec![
|
||||
binary_dtype_rewrite(&Add::default().sort(), &self.sort()),
|
||||
rule(union(hlir_match2.clone(), metal_op2.clone()))
|
||||
.subsume(hlir_match2)
|
||||
.set(dtype(metal_op2), app(&SORTS.f32_dt, vec![]))
|
||||
.ruleset("kernel_lower"),
|
||||
rule(union(hlir_match2, metal_op2.clone()))
|
||||
.set(dtype(metal_op2), app(&SORTS.f32_dt, vec![])),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -444,7 +423,8 @@ impl MetalKernelOp for MetalAdd {
|
||||
let a_idx = lower_expression_for_metal(&a_index, "idx");
|
||||
let b_idx = lower_expression_for_metal(&b_index, "idx");
|
||||
let out_idx = lower_expression_for_metal(&out_index, "idx");
|
||||
let (a_val, b_val) = metal_binary_op_values(output_dtype, a_dtype, b_dtype, &a_idx, &b_idx);
|
||||
let a_val = metal_numeric_read(a_dtype, "a", &a_idx);
|
||||
let b_val = metal_numeric_read(b_dtype, "b", &b_idx);
|
||||
let out_val = metal_numeric_write(output_dtype, &format!("({a_val}) + ({b_val})"));
|
||||
|
||||
let source = format!(
|
||||
@@ -457,7 +437,7 @@ impl MetalKernelOp for MetalAdd {
|
||||
device {b_ty} *b [[buffer(1)]],
|
||||
device {out_ty} *out [[buffer(2)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_elements [[buffer({n_elements_index})]],
|
||||
device uint &n_elements [[buffer({n_elements_index})]],
|
||||
uint idx [[thread_position_in_grid]]
|
||||
) {{
|
||||
if (idx < n_elements) {{
|
||||
@@ -576,7 +556,8 @@ impl MetalKernelOp for MetalMul {
|
||||
let a_idx = lower_expression_for_metal(&a_index, "idx");
|
||||
let b_idx = lower_expression_for_metal(&b_index, "idx");
|
||||
let out_idx = lower_expression_for_metal(&out_index, "idx");
|
||||
let (a_val, b_val) = metal_binary_op_values(output_dtype, a_dtype, b_dtype, &a_idx, &b_idx);
|
||||
let a_val = metal_numeric_read(a_dtype, "a", &a_idx);
|
||||
let b_val = metal_numeric_read(b_dtype, "b", &b_idx);
|
||||
let out_val = metal_numeric_write(output_dtype, &format!("({a_val}) * ({b_val})"));
|
||||
|
||||
let source = format!(
|
||||
@@ -589,7 +570,7 @@ impl MetalKernelOp for MetalMul {
|
||||
device {b_ty} *b [[buffer(1)]],
|
||||
device {out_ty} *out [[buffer(2)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_elements [[buffer({n_elements_index})]],
|
||||
device uint &n_elements [[buffer({n_elements_index})]],
|
||||
uint idx [[thread_position_in_grid]]
|
||||
) {{
|
||||
if (idx < n_elements) {{
|
||||
@@ -718,13 +699,9 @@ impl MetalKernelOp for MetalMod {
|
||||
let a_idx = lower_expression_for_metal(&a_index, "idx");
|
||||
let b_idx = lower_expression_for_metal(&b_index, "idx");
|
||||
let out_idx = lower_expression_for_metal(&out_index, "idx");
|
||||
let (a_val, b_val) = metal_binary_op_values(output_dtype, a_dtype, b_dtype, &a_idx, &b_idx);
|
||||
let out_expr = if output_dtype == DType::Int {
|
||||
format!("({a_val}) % ({b_val})")
|
||||
} else {
|
||||
format!("fmod({a_val}, {b_val})")
|
||||
};
|
||||
let out_val = metal_numeric_write(output_dtype, &out_expr);
|
||||
let a_val = metal_numeric_read(a_dtype, "a", &a_idx);
|
||||
let b_val = metal_numeric_read(b_dtype, "b", &b_idx);
|
||||
let out_val = metal_numeric_write(output_dtype, &format!("fmod({a_val}, {b_val})"));
|
||||
|
||||
let source = format!(
|
||||
r#"
|
||||
@@ -736,7 +713,7 @@ impl MetalKernelOp for MetalMod {
|
||||
device {b_ty} *b [[buffer(1)]],
|
||||
device {out_ty} *out [[buffer(2)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_elements [[buffer({n_elements_index})]],
|
||||
device uint &n_elements [[buffer({n_elements_index})]],
|
||||
uint idx [[thread_position_in_grid]]
|
||||
) {{
|
||||
if (idx < n_elements) {{
|
||||
@@ -876,7 +853,7 @@ impl MetalKernelOp for MetalLessThan {
|
||||
device {b_ty} *b [[buffer(1)]],
|
||||
device {out_ty} *out [[buffer(2)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_elements [[buffer({n_elements_index})]],
|
||||
device uint &n_elements [[buffer({n_elements_index})]],
|
||||
uint idx [[thread_position_in_grid]]
|
||||
) {{
|
||||
if (idx < n_elements) {{
|
||||
@@ -1023,7 +1000,7 @@ impl MetalKernelOp for MetalSumReduce {
|
||||
const device {input_ty} *in [[buffer(0)]],
|
||||
device {output_ty} *out [[buffer(1)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_outputs [[buffer({n_outputs_index})]],
|
||||
device uint &n_outputs [[buffer({n_outputs_index})]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_index_in_threadgroup]],
|
||||
uint simd_lane [[thread_index_in_simdgroup]],
|
||||
@@ -1204,7 +1181,7 @@ impl MetalKernelOp for MetalMaxReduce {
|
||||
const device {input_ty} *in [[buffer(0)]],
|
||||
device {output_ty} *out [[buffer(1)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_outputs [[buffer({n_outputs_index})]],
|
||||
device uint &n_outputs [[buffer({n_outputs_index})]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_index_in_threadgroup]],
|
||||
uint simd_lane [[thread_index_in_simdgroup]],
|
||||
@@ -1742,10 +1719,8 @@ impl EgglogOp for MetalConstant {
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
let (args, const_match) = new_op_call(&Constant::default().sort(), &[]);
|
||||
let metal_op = call_sort_from_args(&self.sort(), &args);
|
||||
vec![rule(union(const_match.clone(), metal_op.clone()))
|
||||
.subsume(const_match)
|
||||
.set(dtype(metal_op), app(&SORTS.f32_dt, vec![]))
|
||||
.ruleset("kernel_lower")]
|
||||
vec![rule(union(const_match, metal_op.clone()))
|
||||
.set(dtype(metal_op), app(&SORTS.f32_dt, vec![]))]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -1852,10 +1827,8 @@ impl EgglogOp for MetalIota {
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
let (args, iota_match) = new_op_call(&Iota::default().sort(), &[]);
|
||||
let metal_op = call_sort_from_args(&self.sort(), &args);
|
||||
vec![rule(union(iota_match.clone(), metal_op.clone()))
|
||||
.subsume(iota_match)
|
||||
.set(dtype(metal_op), app(&SORTS.int_dt, vec![]))
|
||||
.ruleset("kernel_lower")]
|
||||
vec![rule(union(iota_match, metal_op.clone()))
|
||||
.set(dtype(metal_op), app(&SORTS.int_dt, vec![]))]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -1899,7 +1872,7 @@ impl MetalKernelOp for MetalIota {
|
||||
kernel void mkernel(
|
||||
device int *out [[buffer(0)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_elements [[buffer({n_elements_index})]],
|
||||
device uint &n_elements [[buffer({n_elements_index})]],
|
||||
uint idx [[thread_position_in_grid]]
|
||||
) {{
|
||||
if (idx < n_elements) {{
|
||||
@@ -1951,7 +1924,6 @@ impl MetalKernelOp for MetalIota {
|
||||
pub struct MetalGather {
|
||||
out_shape: Vec<Expression>,
|
||||
index_stride: Vec<Expression>,
|
||||
data_shape: Vec<Expression>,
|
||||
data_stride: Vec<Expression>,
|
||||
out_stride: Vec<Expression>,
|
||||
}
|
||||
@@ -1966,7 +1938,6 @@ impl EgglogOp for MetalGather {
|
||||
("indexes", IR),
|
||||
("index_strides", ELIST),
|
||||
("data", IR),
|
||||
("data_shape", ELIST),
|
||||
("data_strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
],
|
||||
@@ -1988,7 +1959,6 @@ impl EgglogOp for MetalGather {
|
||||
gather_args["index_strides"].clone(),
|
||||
),
|
||||
("data".to_string(), gather_args["data"].clone()),
|
||||
("data_shape".to_string(), gather_args["data_shape"].clone()),
|
||||
(
|
||||
"data_strides".to_string(),
|
||||
gather_args["data_strides"].clone(),
|
||||
@@ -1996,11 +1966,9 @@ impl EgglogOp for MetalGather {
|
||||
("out_strides".to_string(), out_strides),
|
||||
];
|
||||
let metal_op = self.sort().call(metal_args);
|
||||
vec![rule(union(gather_match.clone(), metal_op.clone()))
|
||||
.subsume(gather_match)
|
||||
vec![rule(union(gather_match, metal_op.clone()))
|
||||
.set(dtype(metal_op), dt.clone())
|
||||
.fact(eq(dt, dtype(gather_args["data"].clone())))
|
||||
.ruleset("kernel_lower")]
|
||||
.fact(eq(dt, dtype(gather_args["data"].clone())))]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -2021,10 +1989,9 @@ impl EgglogOp for MetalGather {
|
||||
out_shape: extract_expr_list(egraph, children[0], list_cache, expr_cache).unwrap(),
|
||||
index_stride: extract_expr_list(egraph, children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
data_shape: extract_expr_list(egraph, children[4], list_cache, expr_cache).unwrap(),
|
||||
data_stride: extract_expr_list(egraph, children[5], list_cache, expr_cache)
|
||||
data_stride: extract_expr_list(egraph, children[4], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
out_stride: extract_expr_list(egraph, children[6], list_cache, expr_cache).unwrap(),
|
||||
out_stride: extract_expr_list(egraph, children[5], list_cache, expr_cache).unwrap(),
|
||||
})),
|
||||
vec![children[1], children[3]],
|
||||
)
|
||||
@@ -2048,7 +2015,7 @@ impl MetalKernelOp for MetalGather {
|
||||
"idx",
|
||||
);
|
||||
let data_idx = lower_expression_for_metal(
|
||||
&flatten_strides(&self.data_shape, &self.data_stride),
|
||||
&flatten_strides(&self.out_shape, &self.data_stride),
|
||||
"gathered_index",
|
||||
);
|
||||
let gathered_val = metal_copy_value(data_dtype, "data", &data_idx);
|
||||
@@ -2063,7 +2030,7 @@ impl MetalKernelOp for MetalGather {
|
||||
const device {data_ty} *data [[buffer(1)]],
|
||||
device {out_ty} *out [[buffer(2)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_elements [[buffer({n_elements_index})]],
|
||||
device uint &n_elements [[buffer({n_elements_index})]],
|
||||
uint idx [[thread_position_in_grid]]
|
||||
) {{
|
||||
if (idx < n_elements) {{
|
||||
@@ -2089,10 +2056,6 @@ impl MetalKernelOp for MetalGather {
|
||||
.max(Expression::from(1))
|
||||
}
|
||||
|
||||
fn infer_output_dtype(&self, input_dtypes: &[DType]) -> DType {
|
||||
input_dtypes.get(1).copied().unwrap_or(DType::F32)
|
||||
}
|
||||
|
||||
fn encode(
|
||||
&self,
|
||||
encoder: &ComputeCommandEncoderRef,
|
||||
@@ -2214,11 +2177,9 @@ impl EgglogOp for MetalScatter {
|
||||
("out_strides".to_string(), out_strides),
|
||||
];
|
||||
let metal_op = self.sort().call(metal_args);
|
||||
vec![rule(union(scatter_match.clone(), metal_op.clone()))
|
||||
.subsume(scatter_match)
|
||||
vec![rule(union(scatter_match, metal_op.clone()))
|
||||
.set(dtype(metal_op), dt.clone())
|
||||
.fact(eq(dt, dtype(scatter_args["src"].clone())))
|
||||
.ruleset("kernel_lower")]
|
||||
.fact(eq(dt, dtype(scatter_args["src"].clone())))]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -2282,7 +2243,7 @@ impl MetalKernelOp for MetalScatter {
|
||||
kernel void copy_kernel(
|
||||
device {out_ty} *out [[buffer(0)]],
|
||||
const device {dest_ty} *dest [[buffer(1)]],
|
||||
constant uint &n_elements [[buffer(2)]],
|
||||
device uint &n_elements [[buffer(2)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
uint idx [[thread_position_in_grid]]
|
||||
) {{
|
||||
@@ -2316,7 +2277,7 @@ impl MetalKernelOp for MetalScatter {
|
||||
device {out_ty} *out [[buffer(0)]],
|
||||
const device int *indexes [[buffer(1)]],
|
||||
const device {src_ty} *src [[buffer(2)]],
|
||||
constant uint &n_elements [[buffer(3)]],
|
||||
device uint &n_elements [[buffer(3)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
uint idx [[thread_position_in_grid]]
|
||||
) {{
|
||||
@@ -2447,10 +2408,7 @@ impl EgglogOp for MetalCast {
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
let (args, cast_match) = new_op_call(&Cast::default().sort(), &["inp"]);
|
||||
let metal_op = call_sort_from_args(&self.sort(), &args);
|
||||
vec![rule(union(cast_match.clone(), metal_op.clone()))
|
||||
.subsume(cast_match)
|
||||
.set(dtype(metal_op), args["dtype"].clone())
|
||||
.ruleset("kernel_lower")]
|
||||
vec![rule(union(cast_match, metal_op.clone())).set(dtype(metal_op), args["dtype"].clone())]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -2509,7 +2467,7 @@ impl MetalKernelOp for MetalCast {
|
||||
device {input_ty} *inp [[buffer(0)]],
|
||||
device {output_ty} *out [[buffer(1)]],
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_elements [[buffer({n_elements_index})]],
|
||||
device uint &n_elements [[buffer({n_elements_index})]],
|
||||
uint idx [[thread_position_in_grid]]
|
||||
) {{
|
||||
if (idx < n_elements) {{
|
||||
|
||||
@@ -282,8 +282,6 @@ impl Runtime for MetalRuntime {
|
||||
let pipeline = kernel_op.compile(&self.device, &input_dtypes, output_dtype);
|
||||
self.node_dtypes.insert(node, output_dtype);
|
||||
self.pipelines.insert(node, pipeline);
|
||||
} else {
|
||||
panic!("Metal runtime cannot execute unlowered LLIR node {node:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -294,7 +292,6 @@ impl Runtime for MetalRuntime {
|
||||
llir_graph: &LLIRGraph,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
trials: usize,
|
||||
_timeout: Option<std::time::Duration>,
|
||||
) -> (Self::ProfileMetric, String) {
|
||||
self.load_llir(llir_graph);
|
||||
self.allocate_intermediate_buffers(dyn_map);
|
||||
|
||||
@@ -250,23 +250,6 @@ fn dynamic_dim_sum_reduce_runs() {
|
||||
assert_close(&out, &[9.0, 12.0], 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_int_arithmetic_preserves_large_values() {
|
||||
let mut cx = Graph::default();
|
||||
let token = cx.tensor(1).as_dtype(DType::Int);
|
||||
let large_index = (token * 1024) + 123;
|
||||
let mod_output = (large_index % 65_537).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(token, &[16_385i32]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_eq!(rt.get_f32(mod_output), vec![891.0]);
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(5))]
|
||||
|
||||
@@ -988,28 +971,6 @@ fn test_scatter_basic() {
|
||||
assert_close(&out, &[0.0, 10.0, 0.0, 20.0, 30.0], 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gather_noncontiguous_data_uses_data_shape() {
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((4, 3));
|
||||
let data = input.transpose(0, 1);
|
||||
let indexes = cx.tensor((2, 2)).as_dtype(DType::Int);
|
||||
let out = data.gather(indexes).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(
|
||||
input,
|
||||
&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
|
||||
);
|
||||
rt.set_data(indexes, &[0.0, 3.0, 4.0, 7.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out), &[0.0, 9.0, 1.0, 10.0], 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scatter_into_nonzero_dest() {
|
||||
let mut cx = Graph::default();
|
||||
@@ -1051,21 +1012,3 @@ fn test_scatter_all_positions() {
|
||||
let out = rt.get_f32(result);
|
||||
assert_close(&out, &[10.0, 20.0, 30.0, 40.0], 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gather_preserves_data_dtype() {
|
||||
let mut cx = Graph::default();
|
||||
let data = cx.tensor(2);
|
||||
let indexes = cx.tensor(1).as_dtype(DType::Int);
|
||||
let out = data.gather(indexes).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(data, &[1.25, 2.5]);
|
||||
rt.set_data(indexes, &[1.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out), &[2.5], 0.001);
|
||||
}
|
||||
|
||||
@@ -61,8 +61,7 @@ impl MoE {
|
||||
let expert_out = expanded_act.matmul(gathered).squeeze(n); // [batch.., k, out]
|
||||
|
||||
// 6. Weighted sum over experts: [batch.., k, out] * [batch.., k, 1] → sum(k) → [batch.., out]
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
|
||||
weights_exp.shape.expand(expert_out.dims());
|
||||
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
|
||||
(expert_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
}
|
||||
@@ -479,8 +478,7 @@ mod tests {
|
||||
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2); // [s, k, H]
|
||||
|
||||
// 7. Weighted sum over k experts → [s, H]
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
|
||||
let _output = (down_out * weights_exp).sum(n - 1).output();
|
||||
|
||||
// Dump the HLIR to egglog
|
||||
|
||||
@@ -782,86 +782,3 @@ identical across all attempts (dtype issue) vs varying (actual numerical issue).
|
||||
3. **Why "defensive fallback" framing is misleading**: it implies the LLIR is broken. It isn't. The forward-walk-only `body_nodes` definition just doesn't cover this case, because the case requires no per-iter cloning at all. A *node not reachable from any loop input marker has no input-marker ancestor*, so by construction its value doesn't depend on the loop's per-iter state.
|
||||
4. **Cleaner formulation**: name the concept. Compute an `iteration_invariant_slots: HashSet<LoopStart>` set at the same time `start_meta` is built, with the rule `body_producer ∉ body_nodes ⇒ iteration_invariant`. `resolve_src` and `marker_post_sub` then have explicit branches: if the slot is invariant, use `body_producer` directly; otherwise the standard per-iter clone lookup. The behavior is the same as the `unwrap_or` band-aid, but the code now documents that this is a real, sound case the unroll handles correctly — not a panic suppressor.
|
||||
5. **Principle**: when an `unwrap_or` papers over a case that turns out to be semantically valid, the right cleanup isn't to keep the `unwrap_or` and add a comment — it's to name the case. Hoist the predicate into a set or enum and branch on it explicitly. The compiler then enforces that every consumer of the per-iter cloning machinery has an opinion on iteration-invariant slots, instead of silently relying on a `Map::get` returning `None` at the right moment.
|
||||
|
||||
---
|
||||
|
||||
## 2026-04-30 — `translate_grouped_mm` casted the full expert weight to F32, OOMing search on Qwen3-MoE
|
||||
|
||||
### What the symptom was
|
||||
|
||||
`benchmarks/ttft/run.py --config qwen3-moe` crashed every search-profile attempt with:
|
||||
```
|
||||
crates/luminal_cuda_lite/src/runtime.rs:711: called `Result::unwrap()` on an `Err` value:
|
||||
DriverError(CUDA_ERROR_OUT_OF_MEMORY, "out of memory")
|
||||
```
|
||||
The DB shows this had been failing every run for ~2 weeks. The rust `examples/qwen3_moe` ran fine end-to-end. python_baseline / python_torch_compile / qwen3-4b were all fine — only python_luminal × qwen3-moe failed.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
`translate_grouped_mm` in `crates/luminal_python/rust/src/translator/tensor.rs` was lowering HF's `_grouped_mm(input, weight, offs)` op to a *full-broadcast* batched matmul plus a group-mask:
|
||||
|
||||
```rust
|
||||
let weight_f = weight.cast(DType::F32); // [G=128, K, N] cast → 1.5 GB / layer
|
||||
let input_batched = input_f.expand_dim(0, g);
|
||||
let all_out = input_batched.matmul(weight_f); // [G, S, N]
|
||||
let mask = ... (g_arange == expert_id).cast(F32);
|
||||
let out = (all_out * mask.expand_dim(2, n)).sum(0); // mask + sum over G
|
||||
```
|
||||
|
||||
The full `[G, K, N]` F32 cast intermediate is 1.5 GB / layer for gate-up and 0.6 GB / layer for down on Qwen3-30B-A3B. With 60 GB of persistent bf16 weights already on a 97 GB GPU, the search-time profiler ran out of memory allocating those casts.
|
||||
|
||||
By contrast, `examples/qwen3_moe`'s `gather_experts` gathers only the top-K active experts per token first, then casts that small `[s, k, d1, d2]` slice (~100 MB / layer). The GLUMoE host op (`crates/luminal_cuda_lite/src/host/moe/glumoe_rewrite.egg`) is also wired to this gather pattern.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **Code path was reasonable in isolation**: at small scale (`test_grouped_mm_fallback`: g=2, K=8, N=16) the broadcast version was fine — the F32 cast was only 1 KB, and search profiling never noticed.
|
||||
2. **The error reported "out of memory" but the rest of the system looked healthy**: 60 GB weights + 37 GB headroom looks like plenty until you realise 48 layers × 2.1 GB cast intermediates per layer doesn't fit, even after loop rolling.
|
||||
3. **The DB's `code 1` failures looked the same as a Python exception** — the actual panic site (`runtime.rs:711:64` `stream.alloc_zeros(needed_bytes).unwrap()`) had to be recovered from a tmux scrollback because the orchestrator's stdout was already torn down by the time we looked.
|
||||
|
||||
### The fix
|
||||
|
||||
Rewrote `translate_grouped_mm` to gather first, matmul second:
|
||||
|
||||
```rust
|
||||
// expert_id[m] = first g s.t. m < offs[g], clamped to [0, G-1]
|
||||
let expert_id = ge_boundary.sum(0).minimum_f32(g_max_f).cast(DType::Int);
|
||||
|
||||
// flat_idx = expert_id * (K*N) + iota('z', (K, N)) — same shape as
|
||||
// rust qwen3_moe's `gather_experts`
|
||||
let flat_idx = (expert_id * (k * n))
|
||||
.expand_dim(1, k).expand_dim(2, n)
|
||||
+ self.graph.iota(Expression::from('z'), (k, n)).expand_dim(0, s);
|
||||
|
||||
let weight_gathered = weight.gather(flat_idx); // [S, K, N], bf16
|
||||
let result = input.cast(F32).unsqueeze(1)
|
||||
.matmul(weight_gathered.cast(F32)) // [S, 1, N]
|
||||
.squeeze(1);
|
||||
```
|
||||
|
||||
Two important details:
|
||||
|
||||
1. **Clamp `expert_id` to `[0, G-1]`**: at search time, dummy data fills `offs` with all-1s (`make_ones_bytes` in `compile_backend`). For S>1 that pushes `expert_id` to G (boundary count = G), which is one past the last valid expert and OOBs the gather. HF's own grouped-MM forward also clamps for the same reason (invalid expert IDs from EP).
|
||||
2. **Don't cast the full weight**: the cast moved from before the batched-matmul (over `[G, K, N]`) to after the gather (over `[S, K, N]`). 16× shrink at prefill (S=top_k=8 vs G=128).
|
||||
|
||||
### Result
|
||||
|
||||
`search-iters=1` end-to-end works on Qwen3-30B-A3B: `BENCH_RESULT … "ttft_ms": 9350.5, "tpot_ms": 1166.7`. The OOM is gone.
|
||||
|
||||
`search-iters>=5` still crashes — but with a *different*, downstream `CUDA_ERROR_ILLEGAL_ADDRESS` during execution after search completes. That looks like the same family as the 2026-03-07 / 2026-03-09 egglog-extractor non-determinism bugs (some mutation during search picks a kernel/rewrite combo that's broken at this scale). It's a separate investigation — the gather-based lowering is correct in isolation (`test_grouped_mm_fallback` passes; a synthetic `g=128, S=8, K=2048, N=1536` bf16 test passes with max-diff ~2.4e-4).
|
||||
|
||||
### General principle
|
||||
|
||||
**When lowering an op that takes a per-row index over a large parameter, gather first and cast second — never cast the full parameter to F32 just because your matmul kernel is F32-only.** A "broadcast over G + mask" pattern is mathematically equivalent to "gather per-row" but materialises a G× larger intermediate — fine for tests, ruinous on real MoE checkpoints. When in doubt, mirror the rust example's pattern: the egglog fusion rules (GLUMoE here) are written to recognise the gather form, not the broadcast-and-mask form.
|
||||
|
||||
Also: search-time dummy-1 inputs are not the same shape as runtime inputs. Anything you compute from a runtime tensor (cumsum offsets, routing indices, mask boundaries) needs to remain in-bounds for the dummy. Clamp index-producing chains as a matter of course, not just when the math says you "should" — `make_ones_bytes` is a hostile witness.
|
||||
|
||||
## 2026-05-02 — Whisper port hit two missing-translator pitfalls
|
||||
|
||||
1. **Symptom**: Compiling a PyTorch port of Whisper-tiny.en through `luminal_backend` failed twice in a row at the dispatch table: first with `Unsupported ATen op: torch.ops.aten.gelu.default`, then with `full: unsupported fill value type ... -Infinity`.
|
||||
2. **Root cause #1**: the dispatch table in `crates/luminal_python/rust/src/translator/dispatch.rs` mapped `sigmoid`, `tanh`, `relu` etc. but not `gelu` or `silu`. Whisper's encoder uses `F.gelu`, so the activation hit a hole.
|
||||
3. **Root cause #2**: PyTorch serializes `float("-inf")` in PT2 as the string `"-Infinity"` (and `"NaN"`/`"Infinity"` analogously). `translate_full`'s `get_float_arg` only accepts numeric float/int payloads, so any `torch.full((..), -inf)` (the obvious way to write a causal mask) blows up. Decoder mask code is the most common spot.
|
||||
4. **Why it was tricky**: both errors arrive from inside `pt2_backend` with a stack trace that ends in `process_pt2`, hiding the actual ATen target inside the message. You only see the offending op name in the error string itself, so you have to read `RuntimeError: Failed to translate node N: …` carefully and grep `dispatch.rs` for it.
|
||||
5. **Fix in this session**:
|
||||
- Added `aten.gelu.default → a.gelu()` and `aten.silu.default → a.silu()` to `dispatch.rs`.
|
||||
- Worked around the `-Infinity` issue at the model level by using a finite `-1e10` for the causal mask in the example (matches the Rust example's convention). The cleaner fix (parsing `"-Infinity"`/`"Infinity"`/`"NaN"` strings in `get_float_arg` / `translate_full`) is left for a follow-up.
|
||||
6. **Principle**: when adding a new model that goes through the PT2 backend, expect to plug small holes in `dispatch.rs` and `translator/tensor.rs::translate_full`. The trace points at the python frame, not the Rust dispatch arm — open `dispatch.rs`, ctrl-F the offending op name, and add the one-liner. For float-shaped sentinel values (`-inf`, `inf`, `nan`), the export pipeline currently only accepts finite floats; either rewrite the model or extend the parser.
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
# luminal_python
|
||||
|
||||
PyTorch `torch.compile` integration for Luminal.
|
||||
|
||||
## CUDA Tests
|
||||
|
||||
The Python CUDA CI job builds the Rust extension with the CUDA feature and runs
|
||||
the non-slow pytest suite:
|
||||
|
||||
```bash
|
||||
cd crates/luminal_python
|
||||
RUST_BACKTRACE=1 \
|
||||
LUMINAL_TEST_DEVICE=cuda \
|
||||
MATURIN_PEP517_ARGS="--features cuda --profile release" \
|
||||
CUDARC_CUDA_VERSION=12080 \
|
||||
uv run --group dev python -m pytest tests/ -v -s -m "not slow"
|
||||
```
|
||||
|
||||
The slow tests are explicit opt-in. They include large/pretrained model tests,
|
||||
full-width architecture compiles, Whisper end-to-end cases, and other cases that
|
||||
can take a long time or need a large GPU / Hugging Face cache.
|
||||
|
||||
Run the full Python CUDA suite, including slow tests:
|
||||
|
||||
```bash
|
||||
cd crates/luminal_python
|
||||
RUST_BACKTRACE=1 \
|
||||
LUMINAL_TEST_DEVICE=cuda \
|
||||
MATURIN_PEP517_ARGS="--features cuda --profile release" \
|
||||
CUDARC_CUDA_VERSION=12080 \
|
||||
uv run --group dev python -m pytest tests/ -v -s
|
||||
```
|
||||
|
||||
Run only the slow Python CUDA tests:
|
||||
|
||||
```bash
|
||||
cd crates/luminal_python
|
||||
RUST_BACKTRACE=1 \
|
||||
LUMINAL_TEST_DEVICE=cuda \
|
||||
MATURIN_PEP517_ARGS="--features cuda --profile release" \
|
||||
CUDARC_CUDA_VERSION=12080 \
|
||||
uv run --group dev python -m pytest tests/ -v -s -m slow
|
||||
```
|
||||
|
||||
The helper script follows the same convention:
|
||||
|
||||
```bash
|
||||
cd crates/luminal_python
|
||||
./run_tests_cuda.sh # non-slow CUDA suite
|
||||
./run_tests_cuda.sh --slow-only # only slow CUDA tests
|
||||
./run_tests_cuda.sh --include-slow
|
||||
```
|
||||
|
||||
The GitHub/Modal entrypoint uses the same marker split:
|
||||
|
||||
```bash
|
||||
cd crates/luminal_python
|
||||
modal run modal_pytest_runner.py --gpu A100 --timeout 7200 tests/ -v -s -m "not slow"
|
||||
modal run modal_pytest_runner.py --gpu A100 --timeout 7200 tests/ -v -s
|
||||
```
|
||||
|
||||
@@ -1,497 +0,0 @@
|
||||
"""Whisper transcription demo using the luminal torch.compile backend.
|
||||
|
||||
Implements a small PyTorch port of ``openai/whisper-tiny.en`` that mirrors the
|
||||
luminal Rust example (``examples/whisper`` in the workspace), loads the official
|
||||
HuggingFace weights, and runs greedy decoding through the luminal backend via
|
||||
``torch.compile``.
|
||||
|
||||
Usage::
|
||||
|
||||
uv run python examples/whisper.py [path/to/audio.wav]
|
||||
|
||||
If no path is provided, falls back to the JFK sample bundled with the Rust
|
||||
``examples/whisper`` crate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch._dynamo
|
||||
import torch.nn.functional as F
|
||||
from transformers import (
|
||||
WhisperFeatureExtractor,
|
||||
WhisperForConditionalGeneration,
|
||||
WhisperTokenizer,
|
||||
)
|
||||
|
||||
from luminal.pt2 import compile as luminal_compile
|
||||
|
||||
REPO_ID = "openai/whisper-tiny.en"
|
||||
|
||||
# whisper-tiny.en hyperparameters
|
||||
N_MELS = 80
|
||||
N_AUDIO_CTX = 1500
|
||||
D_MODEL = 384
|
||||
N_HEADS = 6
|
||||
HEAD_DIM = D_MODEL // N_HEADS
|
||||
N_AUDIO_LAYER = 4
|
||||
N_TEXT_LAYER = 4
|
||||
N_TEXT_CTX = 448
|
||||
FF_DIM = 4 * D_MODEL
|
||||
N_VOCAB = 51864
|
||||
LAYER_NORM_EPS = 1e-5
|
||||
|
||||
# Decoder special tokens
|
||||
TOKEN_SOT = 50257
|
||||
TOKEN_NO_TIMESTAMPS = 50362
|
||||
TOKEN_EOT = 50256
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model — mirrors the HLIR encoder/decoder in examples/whisper/src/model.rs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class WhisperAttention(torch.nn.Module):
|
||||
"""Multi-head attention with separate q/k/v projections (no bias on k_proj)."""
|
||||
|
||||
def __init__(self, d_model: int = D_MODEL, n_heads: int = N_HEADS):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
self.head_dim = d_model // n_heads
|
||||
self.q_proj = torch.nn.Linear(d_model, d_model, bias=True)
|
||||
self.k_proj = torch.nn.Linear(d_model, d_model, bias=False)
|
||||
self.v_proj = torch.nn.Linear(d_model, d_model, bias=True)
|
||||
self.out_proj = torch.nn.Linear(d_model, d_model, bias=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
kv_input: Optional[torch.Tensor] = None,
|
||||
causal: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# x: (seq, d_model). kv_input is None → self-attn; otherwise cross-attn.
|
||||
kv = x if kv_input is None else kv_input
|
||||
q = self.q_proj(x)
|
||||
k = self.k_proj(kv)
|
||||
v = self.v_proj(kv)
|
||||
|
||||
seq_q = q.shape[0]
|
||||
seq_kv = k.shape[0]
|
||||
|
||||
# (seq, d_model) -> (n_heads, seq, head_dim)
|
||||
q = q.reshape(seq_q, self.n_heads, self.head_dim).transpose(0, 1)
|
||||
k = k.reshape(seq_kv, self.n_heads, self.head_dim).transpose(0, 1)
|
||||
v = v.reshape(seq_kv, self.n_heads, self.head_dim).transpose(0, 1)
|
||||
|
||||
scale = 1.0 / (self.head_dim**0.5)
|
||||
scores = torch.matmul(q, k.transpose(-2, -1)) * scale # (h, sq, sk)
|
||||
if causal:
|
||||
# Use a large finite negative instead of -inf so the export pipeline
|
||||
# serializes a float instead of the unsupported "-Infinity" sentinel.
|
||||
mask = torch.triu(
|
||||
torch.full((seq_q, seq_kv), -1e10, device=x.device),
|
||||
diagonal=1,
|
||||
)
|
||||
scores = scores + mask
|
||||
weights = torch.softmax(scores, dim=-1)
|
||||
attn = torch.matmul(weights, v) # (h, sq, hd)
|
||||
merged = attn.transpose(0, 1).reshape(seq_q, -1)
|
||||
return self.out_proj(merged)
|
||||
|
||||
|
||||
class EncoderLayer(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.self_attn = WhisperAttention()
|
||||
self.self_attn_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
self.fc1 = torch.nn.Linear(D_MODEL, FF_DIM, bias=True)
|
||||
self.fc2 = torch.nn.Linear(FF_DIM, D_MODEL, bias=True)
|
||||
self.final_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.self_attn(self.self_attn_layer_norm(x))
|
||||
h = self.final_layer_norm(x)
|
||||
h = F.gelu(self.fc1(h))
|
||||
h = self.fc2(h)
|
||||
return x + h
|
||||
|
||||
|
||||
class WhisperEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv1d(
|
||||
N_MELS, D_MODEL, kernel_size=3, padding=1, bias=True
|
||||
)
|
||||
self.conv2 = torch.nn.Conv1d(
|
||||
D_MODEL, D_MODEL, kernel_size=3, stride=2, padding=1, bias=True
|
||||
)
|
||||
# Position embedding stored as a regular parameter (matches HF layout).
|
||||
self.embed_positions = torch.nn.Embedding(N_AUDIO_CTX, D_MODEL)
|
||||
self.layers = torch.nn.ModuleList(
|
||||
[EncoderLayer() for _ in range(N_AUDIO_LAYER)]
|
||||
)
|
||||
self.layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
|
||||
def forward(self, mel: torch.Tensor) -> torch.Tensor:
|
||||
# mel: (n_mels, 3000) -> add batch dim for conv1d
|
||||
x = mel.unsqueeze(0)
|
||||
x = F.gelu(self.conv1(x))
|
||||
x = F.gelu(self.conv2(x))
|
||||
# (1, d_model, 1500) -> (1500, d_model)
|
||||
x = x.squeeze(0).transpose(0, 1)
|
||||
x = x + self.embed_positions.weight
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return self.layer_norm(x)
|
||||
|
||||
|
||||
class DecoderLayer(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.self_attn = WhisperAttention()
|
||||
self.self_attn_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
self.encoder_attn = WhisperAttention()
|
||||
self.encoder_attn_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
self.fc1 = torch.nn.Linear(D_MODEL, FF_DIM, bias=True)
|
||||
self.fc2 = torch.nn.Linear(FF_DIM, D_MODEL, bias=True)
|
||||
self.final_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
|
||||
def forward(self, x: torch.Tensor, xa: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.self_attn(self.self_attn_layer_norm(x), causal=True)
|
||||
x = x + self.encoder_attn(self.encoder_attn_layer_norm(x), kv_input=xa)
|
||||
h = self.final_layer_norm(x)
|
||||
h = F.gelu(self.fc1(h))
|
||||
h = self.fc2(h)
|
||||
return x + h
|
||||
|
||||
|
||||
class WhisperDecoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.embed_tokens = torch.nn.Embedding(N_VOCAB, D_MODEL)
|
||||
self.embed_positions = torch.nn.Embedding(N_TEXT_CTX, D_MODEL)
|
||||
self.layers = torch.nn.ModuleList([DecoderLayer() for _ in range(N_TEXT_LAYER)])
|
||||
self.layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, xa: torch.Tensor) -> torch.Tensor:
|
||||
# tokens: (seq,) of int64 — absolute positions are 0..seq-1
|
||||
seq = tokens.shape[0]
|
||||
pos = torch.arange(seq, dtype=torch.long, device=tokens.device)
|
||||
x = self.embed_tokens(tokens) + self.embed_positions(pos)
|
||||
for layer in self.layers:
|
||||
x = layer(x, xa)
|
||||
x = self.layer_norm(x)
|
||||
# Tied projection
|
||||
return torch.matmul(x, self.embed_tokens.weight.transpose(0, 1))
|
||||
|
||||
|
||||
class Whisper(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.encoder = WhisperEncoder()
|
||||
self.decoder = WhisperDecoder()
|
||||
|
||||
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
|
||||
xa = self.encoder(mel)
|
||||
return self.decoder(tokens, xa)
|
||||
|
||||
|
||||
class DecoderWithFixedXa(torch.nn.Module):
|
||||
"""Wraps the decoder with the encoder output stored as a buffer.
|
||||
|
||||
The audio is fixed for the whole utterance, so ``xa`` is a constant relative
|
||||
to the per-token decode loop. Storing it as a buffer lets us compile the
|
||||
decoder once with a single dynamic-length ``tokens`` input, avoiding a full
|
||||
recompilation at every step as the sequence grows.
|
||||
"""
|
||||
|
||||
def __init__(self, decoder: WhisperDecoder, xa: torch.Tensor):
|
||||
super().__init__()
|
||||
self.decoder = decoder
|
||||
self.register_buffer("xa", xa)
|
||||
|
||||
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||
return self.decoder(tokens, self.xa)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Weight loading: HF state_dict -> our model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def load_hf_weights_into(model: Whisper) -> None:
|
||||
"""Copy HF whisper-tiny.en weights into our matching modules."""
|
||||
hf = WhisperForConditionalGeneration.from_pretrained(REPO_ID).eval()
|
||||
sd = hf.state_dict()
|
||||
|
||||
def get(name: str) -> torch.Tensor:
|
||||
return sd[f"model.{name}"].clone()
|
||||
|
||||
enc = model.encoder
|
||||
enc.conv1.weight.data.copy_(get("encoder.conv1.weight"))
|
||||
enc.conv1.bias.data.copy_(get("encoder.conv1.bias"))
|
||||
enc.conv2.weight.data.copy_(get("encoder.conv2.weight"))
|
||||
enc.conv2.bias.data.copy_(get("encoder.conv2.bias"))
|
||||
enc.embed_positions.weight.data.copy_(get("encoder.embed_positions.weight"))
|
||||
enc.layer_norm.weight.data.copy_(get("encoder.layer_norm.weight"))
|
||||
enc.layer_norm.bias.data.copy_(get("encoder.layer_norm.bias"))
|
||||
for i, layer in enumerate(enc.layers):
|
||||
prefix = f"encoder.layers.{i}"
|
||||
layer.self_attn.q_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.q_proj.weight")
|
||||
)
|
||||
layer.self_attn.q_proj.bias.data.copy_(get(f"{prefix}.self_attn.q_proj.bias"))
|
||||
layer.self_attn.k_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.k_proj.weight")
|
||||
)
|
||||
layer.self_attn.v_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.v_proj.weight")
|
||||
)
|
||||
layer.self_attn.v_proj.bias.data.copy_(get(f"{prefix}.self_attn.v_proj.bias"))
|
||||
layer.self_attn.out_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.out_proj.weight")
|
||||
)
|
||||
layer.self_attn.out_proj.bias.data.copy_(
|
||||
get(f"{prefix}.self_attn.out_proj.bias")
|
||||
)
|
||||
layer.self_attn_layer_norm.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn_layer_norm.weight")
|
||||
)
|
||||
layer.self_attn_layer_norm.bias.data.copy_(
|
||||
get(f"{prefix}.self_attn_layer_norm.bias")
|
||||
)
|
||||
layer.fc1.weight.data.copy_(get(f"{prefix}.fc1.weight"))
|
||||
layer.fc1.bias.data.copy_(get(f"{prefix}.fc1.bias"))
|
||||
layer.fc2.weight.data.copy_(get(f"{prefix}.fc2.weight"))
|
||||
layer.fc2.bias.data.copy_(get(f"{prefix}.fc2.bias"))
|
||||
layer.final_layer_norm.weight.data.copy_(
|
||||
get(f"{prefix}.final_layer_norm.weight")
|
||||
)
|
||||
layer.final_layer_norm.bias.data.copy_(get(f"{prefix}.final_layer_norm.bias"))
|
||||
|
||||
dec = model.decoder
|
||||
dec.embed_tokens.weight.data.copy_(get("decoder.embed_tokens.weight"))
|
||||
dec.embed_positions.weight.data.copy_(get("decoder.embed_positions.weight"))
|
||||
dec.layer_norm.weight.data.copy_(get("decoder.layer_norm.weight"))
|
||||
dec.layer_norm.bias.data.copy_(get("decoder.layer_norm.bias"))
|
||||
for i, layer in enumerate(dec.layers):
|
||||
prefix = f"decoder.layers.{i}"
|
||||
layer.self_attn.q_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.q_proj.weight")
|
||||
)
|
||||
layer.self_attn.q_proj.bias.data.copy_(get(f"{prefix}.self_attn.q_proj.bias"))
|
||||
layer.self_attn.k_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.k_proj.weight")
|
||||
)
|
||||
layer.self_attn.v_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.v_proj.weight")
|
||||
)
|
||||
layer.self_attn.v_proj.bias.data.copy_(get(f"{prefix}.self_attn.v_proj.bias"))
|
||||
layer.self_attn.out_proj.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn.out_proj.weight")
|
||||
)
|
||||
layer.self_attn.out_proj.bias.data.copy_(
|
||||
get(f"{prefix}.self_attn.out_proj.bias")
|
||||
)
|
||||
layer.self_attn_layer_norm.weight.data.copy_(
|
||||
get(f"{prefix}.self_attn_layer_norm.weight")
|
||||
)
|
||||
layer.self_attn_layer_norm.bias.data.copy_(
|
||||
get(f"{prefix}.self_attn_layer_norm.bias")
|
||||
)
|
||||
layer.encoder_attn.q_proj.weight.data.copy_(
|
||||
get(f"{prefix}.encoder_attn.q_proj.weight")
|
||||
)
|
||||
layer.encoder_attn.q_proj.bias.data.copy_(
|
||||
get(f"{prefix}.encoder_attn.q_proj.bias")
|
||||
)
|
||||
layer.encoder_attn.k_proj.weight.data.copy_(
|
||||
get(f"{prefix}.encoder_attn.k_proj.weight")
|
||||
)
|
||||
layer.encoder_attn.v_proj.weight.data.copy_(
|
||||
get(f"{prefix}.encoder_attn.v_proj.weight")
|
||||
)
|
||||
layer.encoder_attn.v_proj.bias.data.copy_(
|
||||
get(f"{prefix}.encoder_attn.v_proj.bias")
|
||||
)
|
||||
layer.encoder_attn.out_proj.weight.data.copy_(
|
||||
get(f"{prefix}.encoder_attn.out_proj.weight")
|
||||
)
|
||||
layer.encoder_attn.out_proj.bias.data.copy_(
|
||||
get(f"{prefix}.encoder_attn.out_proj.bias")
|
||||
)
|
||||
layer.encoder_attn_layer_norm.weight.data.copy_(
|
||||
get(f"{prefix}.encoder_attn_layer_norm.weight")
|
||||
)
|
||||
layer.encoder_attn_layer_norm.bias.data.copy_(
|
||||
get(f"{prefix}.encoder_attn_layer_norm.bias")
|
||||
)
|
||||
layer.fc1.weight.data.copy_(get(f"{prefix}.fc1.weight"))
|
||||
layer.fc1.bias.data.copy_(get(f"{prefix}.fc1.bias"))
|
||||
layer.fc2.weight.data.copy_(get(f"{prefix}.fc2.weight"))
|
||||
layer.fc2.bias.data.copy_(get(f"{prefix}.fc2.bias"))
|
||||
layer.final_layer_norm.weight.data.copy_(
|
||||
get(f"{prefix}.final_layer_norm.weight")
|
||||
)
|
||||
layer.final_layer_norm.bias.data.copy_(get(f"{prefix}.final_layer_norm.bias"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Audio loading + decoding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def load_wav_16k_mono(path: Path) -> np.ndarray:
|
||||
with wave.open(str(path), "rb") as w:
|
||||
sr = w.getframerate()
|
||||
n = w.getnframes()
|
||||
ch = w.getnchannels()
|
||||
sw = w.getsampwidth()
|
||||
raw = w.readframes(n)
|
||||
|
||||
if sw == 2:
|
||||
samples = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
elif sw == 4:
|
||||
samples = np.frombuffer(raw, dtype=np.int32).astype(np.float32) / 2147483648.0
|
||||
elif sw == 1:
|
||||
samples = (
|
||||
np.frombuffer(raw, dtype=np.uint8).astype(np.float32) - 128.0
|
||||
) / 128.0
|
||||
else:
|
||||
raise ValueError(f"unsupported sample width {sw}")
|
||||
|
||||
if ch > 1:
|
||||
samples = samples.reshape(-1, ch).mean(axis=1)
|
||||
|
||||
if sr != 16000:
|
||||
ratio = sr / 16000
|
||||
out_len = int(len(samples) / ratio)
|
||||
idx = np.arange(out_len, dtype=np.float64) * ratio
|
||||
lo = idx.astype(np.int64)
|
||||
frac = (idx - lo).astype(np.float32)
|
||||
hi = np.clip(lo + 1, 0, len(samples) - 1)
|
||||
samples = samples[lo] * (1.0 - frac) + samples[hi] * frac
|
||||
|
||||
return samples.astype(np.float32)
|
||||
|
||||
|
||||
def greedy_decode(logits_row: torch.Tensor, suppress_first_eot: bool) -> int:
|
||||
masked = logits_row.clone()
|
||||
masked[TOKEN_SOT:] = float("-inf")
|
||||
if suppress_first_eot:
|
||||
masked[TOKEN_EOT] = float("-inf")
|
||||
return int(torch.argmax(masked).item())
|
||||
|
||||
|
||||
def find_default_audio() -> Optional[Path]:
|
||||
here = Path(__file__).resolve()
|
||||
workspace_root = here.parents[3]
|
||||
candidate = workspace_root / "examples" / "whisper" / "assets" / "jfk.wav"
|
||||
return candidate if candidate.exists() else None
|
||||
|
||||
|
||||
def main() -> None:
|
||||
audio_arg = sys.argv[1] if len(sys.argv) > 1 else None
|
||||
if audio_arg:
|
||||
audio_path = Path(audio_arg)
|
||||
else:
|
||||
audio_path = find_default_audio()
|
||||
if audio_path is None:
|
||||
print(
|
||||
"error: no audio file given and bundled jfk.wav not found",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
print("Loading audio:", audio_path)
|
||||
audio = load_wav_16k_mono(audio_path)
|
||||
|
||||
print("Computing log-mel features...")
|
||||
feature_extractor = WhisperFeatureExtractor.from_pretrained(REPO_ID)
|
||||
features = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
|
||||
mel: torch.Tensor = features.input_features[0].to(device) # (80, 3000)
|
||||
assert mel.shape == (N_MELS, 3000), mel.shape
|
||||
|
||||
print("Building model and loading weights...")
|
||||
model = Whisper().eval().to(device)
|
||||
load_hf_weights_into(model)
|
||||
model = model.to(device)
|
||||
tokenizer = WhisperTokenizer.from_pretrained(REPO_ID)
|
||||
|
||||
use_compiled = os.environ.get("LUMINAL_DISABLE", "0") != "1"
|
||||
max_new_tokens = int(os.environ.get("GEN_TOKENS", "100"))
|
||||
search_iters = int(os.environ.get("SEARCH_ITERATIONS", "10"))
|
||||
|
||||
if use_compiled:
|
||||
# 1. Run the encoder once eagerly. The audio doesn't change during decode,
|
||||
# so xa is a constant input to the decoder.
|
||||
with torch.no_grad():
|
||||
xa = model.encoder(mel)
|
||||
|
||||
# 2. Wrap the decoder so its only varying input is `tokens`, then compile
|
||||
# once with a dynamic length dim. Subsequent calls reuse the same
|
||||
# compiled graph — no recompile per token.
|
||||
decoder_only = DecoderWithFixedXa(model.decoder, xa).eval().to(device)
|
||||
example_tokens = torch.tensor(
|
||||
[TOKEN_SOT, TOKEN_NO_TIMESTAMPS], dtype=torch.long, device=device
|
||||
)
|
||||
print(
|
||||
f"Compiling decoder with dynamic seq dim (search_iters={search_iters})..."
|
||||
)
|
||||
compile_start = time.time()
|
||||
compiled_decoder = luminal_compile(
|
||||
decoder_only,
|
||||
example_tokens,
|
||||
search_iterations=search_iters,
|
||||
dynamic_dim=0,
|
||||
)
|
||||
print(f"Compiled in {time.time() - compile_start:.1f}s")
|
||||
|
||||
def step_logits(decoder_input_ids: torch.Tensor) -> torch.Tensor:
|
||||
out = compiled_decoder(decoder_input_ids)
|
||||
return out[0] if isinstance(out, tuple) else out
|
||||
else:
|
||||
|
||||
def step_logits(decoder_input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return model(mel, decoder_input_ids)
|
||||
|
||||
tokens = [TOKEN_SOT, TOKEN_NO_TIMESTAMPS]
|
||||
|
||||
print("Transcribing", end="", flush=True)
|
||||
decode_start = time.time()
|
||||
for step in range(max_new_tokens):
|
||||
decoder_input_ids = torch.tensor(tokens, dtype=torch.long, device=device)
|
||||
with torch.no_grad():
|
||||
logits = step_logits(decoder_input_ids)
|
||||
|
||||
next_token = greedy_decode(logits[-1], suppress_first_eot=(step == 0))
|
||||
if next_token == TOKEN_EOT:
|
||||
break
|
||||
tokens.append(next_token)
|
||||
piece = tokenizer.decode([next_token], skip_special_tokens=False)
|
||||
print(piece, end="", flush=True)
|
||||
elapsed = time.time() - decode_start
|
||||
print()
|
||||
|
||||
transcription = tokenizer.decode(tokens[2:], skip_special_tokens=True)
|
||||
print(f"\nFinal transcription: {transcription}")
|
||||
print(
|
||||
f"Generated {len(tokens) - 2} tokens in {elapsed:.2f}s "
|
||||
f"({(len(tokens) - 2) / max(elapsed, 1e-6):.1f} tok/s)"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -22,7 +22,7 @@ from modal.volume import FileEntryType
|
||||
|
||||
app = modal.App("luminal-tests")
|
||||
|
||||
DEFAULT_TIMEOUT = 2 * 60 * 60
|
||||
DEFAULT_TIMEOUT = 30 * 60
|
||||
CUDARC_CUDA_VERSION = "12080"
|
||||
LOCAL_PROJECT_DIR = Path(__file__).resolve().parent
|
||||
PROJECT_DIR = "/root/luminal/crates/luminal_python"
|
||||
@@ -168,37 +168,6 @@ def _cleanup_remote_profile_artifacts(run_id: str) -> None:
|
||||
return
|
||||
|
||||
|
||||
def _build_cuda_extension(env: dict[str, str]) -> None:
|
||||
cmd = [
|
||||
"uv",
|
||||
"run",
|
||||
"--project",
|
||||
PROJECT_DIR,
|
||||
"--group",
|
||||
"dev",
|
||||
"maturin",
|
||||
"develop",
|
||||
"--manifest-path",
|
||||
f"{PROJECT_DIR}/rust/Cargo.toml",
|
||||
"--features",
|
||||
"cuda",
|
||||
"--profile",
|
||||
"release",
|
||||
]
|
||||
subprocess.run(cmd, env=env, cwd=PROJECT_DIR, check=True)
|
||||
|
||||
|
||||
def _effective_timeout(timeout: int) -> int:
|
||||
if os.environ.get("GITHUB_ACTIONS") == "true" and timeout < DEFAULT_TIMEOUT:
|
||||
print(
|
||||
f"Using Modal timeout {DEFAULT_TIMEOUT}s instead of requested "
|
||||
f"{timeout}s in GitHub Actions.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return DEFAULT_TIMEOUT
|
||||
return timeout
|
||||
|
||||
|
||||
@app.cls(image=image, timeout=DEFAULT_TIMEOUT)
|
||||
class TestRunner:
|
||||
@modal.method()
|
||||
@@ -225,8 +194,6 @@ class TestRunner:
|
||||
if pytest_addopts:
|
||||
env["PYTEST_ADDOPTS"] = pytest_addopts
|
||||
|
||||
_build_cuda_extension(env)
|
||||
|
||||
original_svg_requested = _has_pytest_flag(pytest_args, "--profile-svg")
|
||||
dot_available = shutil.which("dot") is not None
|
||||
sanitized_pytest_args = [
|
||||
@@ -251,6 +218,8 @@ class TestRunner:
|
||||
PROJECT_DIR,
|
||||
"--group",
|
||||
"dev",
|
||||
"--reinstall-package",
|
||||
"luminal_python",
|
||||
"python",
|
||||
"-m",
|
||||
"pytest",
|
||||
@@ -316,7 +285,7 @@ class TestRunner:
|
||||
|
||||
def _parse_cli_args(
|
||||
cli_args: tuple[str, ...],
|
||||
) -> tuple[str, int, bool, str | None, list[str]]:
|
||||
) -> tuple[str, int | None, bool, str | None, list[str]]:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="modal run modal_pytest_runner.py",
|
||||
add_help=False,
|
||||
@@ -331,8 +300,7 @@ def _parse_cli_args(
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
default=DEFAULT_TIMEOUT,
|
||||
help="Modal execution timeout in seconds. Defaults to %(default)s seconds.",
|
||||
help="Optional Modal execution timeout in seconds. Defaults to 1800 seconds.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
@@ -366,11 +334,11 @@ def main(*cli_args: str):
|
||||
)
|
||||
profile_enabled = _profiling_enabled(cli_profile, pytest_args)
|
||||
pytest_addopts = os.environ.get("PYTEST_ADDOPTS", "")
|
||||
timeout = _effective_timeout(timeout)
|
||||
runner_options = {"gpu": gpu}
|
||||
hf_token_secret = _hf_token_secret()
|
||||
runner_volumes = {HF_CACHE_PATH: HF_CACHE_VOLUME}
|
||||
runner_options["timeout"] = timeout
|
||||
if timeout is not None:
|
||||
runner_options["timeout"] = timeout
|
||||
if profile_enabled:
|
||||
runner_volumes[PROFILE_VOLUME_PATH] = PROFILE_VOLUME
|
||||
runner_options["volumes"] = runner_volumes
|
||||
|
||||
@@ -32,7 +32,7 @@ module-name = "luminal.luminal"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
markers = [
|
||||
"slow: tests that download large models, compile full-width model graphs, fuzz many CUDA search choices, or otherwise require explicit opt-in",
|
||||
"slow: tests that download large models or require pre-generated artifacts",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
||||
@@ -1,43 +1,34 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
export CUDARC_CUDA_VERSION="${CUDARC_CUDA_VERSION:-12080}"
|
||||
export MATURIN_PEP517_ARGS="${MATURIN_PEP517_ARGS:---features cuda --profile release}"
|
||||
|
||||
echo "=========================================="
|
||||
echo " Luminal Python: Full Test Suite"
|
||||
echo "=========================================="
|
||||
|
||||
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py"
|
||||
CUDA_TESTS="tests/"
|
||||
CUDA_TESTS="tests/test_hlir_ops.py tests/test_unary.py tests/test_llama3.py"
|
||||
|
||||
# ── Phase 1: Native Backend ─────────────────────────────────
|
||||
|
||||
echo ""
|
||||
echo "=== Phase 1: Building native backend ==="
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
uv run --group dev maturin develop --manifest-path rust/Cargo.toml
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
|
||||
echo ""
|
||||
echo "--- 1a: Native backend tests ---"
|
||||
uv run --group dev pytest $NATIVE_TESTS -v
|
||||
uv run pytest $NATIVE_TESTS -v
|
||||
|
||||
# ── Phase 2: CUDA Backend ───────────────────────────────────
|
||||
|
||||
echo ""
|
||||
echo "=== Phase 2: Building CUDA backend ==="
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
uv run --group dev maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
echo ""
|
||||
echo "--- 2a: CUDA ---"
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run --group dev pytest $CUDA_TESTS -m "not slow" -v
|
||||
|
||||
echo ""
|
||||
echo "Slow CUDA tests are opt-in. To include them, run:"
|
||||
echo " RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest tests/ -v -s"
|
||||
echo "Or, for only slow tests:"
|
||||
echo " RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest tests/ -m slow -v -s"
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest $CUDA_TESTS -m "not slow" -v
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
|
||||
@@ -4,34 +4,17 @@ set -e
|
||||
echo "=== Luminal Python Test Runner (CUDA Backend) ==="
|
||||
echo ""
|
||||
|
||||
export CUDARC_CUDA_VERSION="${CUDARC_CUDA_VERSION:-12080}"
|
||||
export MATURIN_PEP517_ARGS="${MATURIN_PEP517_ARGS:---features cuda --profile release}"
|
||||
|
||||
PYTEST_MARK='not slow'
|
||||
if [[ "${1:-}" == "--include-slow" ]]; then
|
||||
PYTEST_MARK=''
|
||||
elif [[ "${1:-}" == "--slow-only" ]]; then
|
||||
PYTEST_MARK='slow'
|
||||
elif [[ "${1:-}" != "" ]]; then
|
||||
echo "Usage: ./run_tests_cuda.sh [--include-slow|--slow-only]"
|
||||
exit 2
|
||||
fi
|
||||
|
||||
# Force clean rebuild of Rust extension
|
||||
echo "Step 1: Cleaning previous builds..."
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
|
||||
# Rebuild in development mode (faster compilation)
|
||||
echo "Step 2: Building Rust extension..."
|
||||
uv run --group dev maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
# Run pytest with CUDA backend
|
||||
echo "Step 3: Running pytest with CUDA backend..."
|
||||
if [[ -n "$PYTEST_MARK" ]]; then
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run --group dev pytest tests/ -m "$PYTEST_MARK" -v -s
|
||||
else
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run --group dev pytest tests/ -v -s
|
||||
fi
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
|
||||
@@ -12,67 +12,6 @@ use crate::typed_data::TypedData;
|
||||
/// Maps symbolic dimension parameter names (e.g. "seq_len") to luminal Expression variable chars.
|
||||
pub type DimParamMap = HashMap<String, char>;
|
||||
|
||||
/// Recover a single-variable dim's variable value from an observed runtime size.
|
||||
///
|
||||
/// Returns `Some((var, value))` when the expression contains exactly one
|
||||
/// variable, is affine in that variable, and `value` round-trips through
|
||||
/// `exec_single_var_checked` to reproduce `dim_val`. Returns `None` otherwise
|
||||
/// — multi-variable expressions, non-affine forms, slope==0, and inversions
|
||||
/// that don't divide cleanly are all rejected so we never write a wrong
|
||||
/// guess into `dyn_map`.
|
||||
fn solve_single_var_dim(expr: &Expression, dim_val: usize) -> Option<(char, usize)> {
|
||||
use luminal::shape::Term;
|
||||
let terms = expr.terms.read();
|
||||
|
||||
// Identify the unique variable, if any.
|
||||
let mut var: Option<char> = None;
|
||||
for t in terms.iter() {
|
||||
if let Term::Var(c) = t {
|
||||
match var {
|
||||
None => var = Some(*c),
|
||||
Some(existing) if existing == *c => {}
|
||||
Some(_) => return None, // multi-var — bail out
|
||||
}
|
||||
}
|
||||
}
|
||||
let var = var?;
|
||||
|
||||
// Bare-var fast path — terms is exactly `[Var]`.
|
||||
if terms.len() == 1 {
|
||||
return Some((var, dim_val));
|
||||
}
|
||||
|
||||
// Probe two points to recover slope/intercept of an assumed affine form
|
||||
// `f(x) = slope*x + intercept`. We use 2 and 3 (luminal's default
|
||||
// dynamic-dim min is 2, and 3 keeps the inputs small in case the
|
||||
// expression includes a multiplication that could overflow at scale).
|
||||
drop(terms);
|
||||
let f2 = expr.exec_single_var_checked(2)? as i64;
|
||||
let f3 = expr.exec_single_var_checked(3)? as i64;
|
||||
let slope = f3 - f2;
|
||||
if slope == 0 {
|
||||
return None;
|
||||
}
|
||||
let intercept = f2 - 2 * slope;
|
||||
let target = dim_val as i64 - intercept;
|
||||
if slope == 0 || target % slope != 0 {
|
||||
return None;
|
||||
}
|
||||
let candidate = target / slope;
|
||||
if candidate < 0 {
|
||||
return None;
|
||||
}
|
||||
let candidate = candidate as usize;
|
||||
|
||||
// Verify by re-evaluating with the candidate value. Catches non-affine
|
||||
// forms whose probe points happen to be collinear (e.g. `min(s, 100)`
|
||||
// would look affine for s ∈ {2, 3} but flatten beyond 100).
|
||||
if expr.exec_single_var_checked(candidate)? != dim_val {
|
||||
return None;
|
||||
}
|
||||
Some((var, candidate))
|
||||
}
|
||||
|
||||
/// Convert luminal DType to PT2 dtype integer code (for python interop)
|
||||
/// Types without a direct Pytorch equivalent map to the closest safe representation
|
||||
fn luminal_dtype_to_pt2_code(dtype: DType) -> u32 {
|
||||
@@ -98,12 +37,7 @@ pub struct GraphTranslation {
|
||||
pub input_names: Vec<String>,
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
/// Output dtypes as PT2 dtype codes (e.g. 5 = int64, 7 = float32).
|
||||
/// Stored as PT2 codes (rather than luminal `DType`) so we can preserve
|
||||
/// distinctions luminal collapses internally — notably int64 vs int32,
|
||||
/// both of which map to `DType::Int` in luminal but must be reported
|
||||
/// back to PyTorch with their original precision.
|
||||
pub output_dtypes: Vec<u32>,
|
||||
pub output_dtypes: Vec<DType>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
@@ -129,9 +63,7 @@ pub struct CompiledGraph {
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shapes: Vec<Vec<usize>>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
/// Output dtypes as PT2 dtype codes (preserves int64 / int32 distinction
|
||||
/// that luminal collapses to `DType::Int` internally).
|
||||
pub output_dtypes: Vec<u32>,
|
||||
pub output_dtypes: Vec<DType>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
@@ -287,27 +219,17 @@ impl CompiledGraph {
|
||||
}
|
||||
|
||||
/// Auto-detect and set dynamic dimensions from input tensor shapes.
|
||||
///
|
||||
/// For each user input we walk the symbolic shape expressions side-by-side
|
||||
/// with the concrete sizes Dynamo handed us at runtime and try to recover
|
||||
/// each unbound variable's value. Two cases are handled:
|
||||
///
|
||||
/// * Bare-variable dim (`s`): set directly from the size.
|
||||
/// * Single-variable affine dim (`a*s + b`): solve `s = (size - b)/a`
|
||||
/// by sampling the expression at two probe points to extract the
|
||||
/// slope, recovering the intercept, and verifying that plugging the
|
||||
/// recovered value back through `exec_single_var_checked` reproduces
|
||||
/// the observed size. The verification step rejects everything
|
||||
/// non-affine (`s*s`, `min(s, 8)`, etc.) without committing a wrong
|
||||
/// guess to `dyn_map`.
|
||||
///
|
||||
/// Multi-variable dims are skipped here; another input's shape — or an
|
||||
/// explicit `set_dim` call — is expected to bind those.
|
||||
/// For each user input, matches the concrete shape against its symbolic
|
||||
/// shape expressions and sets the corresponding dyn_map entries.
|
||||
fn auto_set_dims_from_input_shapes(&mut self, input_shapes: Vec<Vec<usize>>) {
|
||||
for (shape_exprs, shape) in self.input_shape_exprs.iter().zip(input_shapes.iter()) {
|
||||
for (dim_expr, &dim_val) in shape_exprs.iter().zip(shape.iter()) {
|
||||
if let Some((var, value)) = solve_single_var_dim(dim_expr, dim_val) {
|
||||
self.graph.set_dim(var, value);
|
||||
// Check if this expression is a bare symbolic variable
|
||||
let terms = dim_expr.terms.read();
|
||||
if terms.len() == 1
|
||||
&& let luminal::shape::Term::Var(c) = terms[0]
|
||||
{
|
||||
self.graph.set_dim(c, dim_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -483,7 +405,10 @@ impl CompiledGraph {
|
||||
/// Get the PT2 dtype codes for all outputs (in order).
|
||||
#[getter]
|
||||
fn output_dtypes(&self) -> Vec<u32> {
|
||||
self.output_dtypes.clone()
|
||||
self.output_dtypes
|
||||
.iter()
|
||||
.map(|d| luminal_dtype_to_pt2_code(*d))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get output tensor data by name as f32 (copies to host).
|
||||
|
||||
@@ -2,7 +2,7 @@ use luminal::dyn_backend::BackendFactory;
|
||||
use luminal::prelude::tracing::warn;
|
||||
use luminal::prelude::*;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::{PyCapsule, PyCapsuleMethods};
|
||||
use pyo3::types::{PyAny, PyCapsule, PyCapsuleMethods, PyDict};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::compiled_graph::{CompiledGraph, DimParamMap, GraphTranslation, WeightData};
|
||||
@@ -14,6 +14,58 @@ use crate::{pt2_parser, pt2_util};
|
||||
/// Pre-loaded weight/constant data paired with tensor sizes.
|
||||
type PreloadResult = (Vec<(String, TypedData)>, HashMap<String, usize>);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct CompileOptions {
|
||||
search_iterations: usize,
|
||||
}
|
||||
|
||||
impl Default for CompileOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
search_iterations: 10,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CompileOptions {
|
||||
fn from_py(options: Option<&Bound<'_, PyAny>>) -> PyResult<Self> {
|
||||
let mut parsed = Self::default();
|
||||
|
||||
let Some(options) = options else {
|
||||
return Ok(parsed);
|
||||
};
|
||||
|
||||
let options = options.cast::<PyDict>().map_err(|_| {
|
||||
pyo3::exceptions::PyTypeError::new_err("luminal backend options must be a dict")
|
||||
})?;
|
||||
|
||||
for (key, value) in options.iter() {
|
||||
let key = key.extract::<String>().map_err(|_| {
|
||||
pyo3::exceptions::PyTypeError::new_err(
|
||||
"luminal backend option keys must be strings",
|
||||
)
|
||||
})?;
|
||||
|
||||
match key.as_str() {
|
||||
"search_iterations" => {
|
||||
parsed.search_iterations = value.extract::<usize>().map_err(|_| {
|
||||
pyo3::exceptions::PyTypeError::new_err(
|
||||
"luminal backend option 'search_iterations' must be an integer",
|
||||
)
|
||||
})?;
|
||||
}
|
||||
other => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
||||
"Unsupported luminal backend option '{other}'. Supported options: search_iterations",
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(parsed)
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_dim_sizes(
|
||||
sizes: &[pt2_schema::DimSize],
|
||||
sym_to_char: &HashMap<String, char>,
|
||||
@@ -23,178 +75,30 @@ fn resolve_dim_sizes(
|
||||
.map(|s| match s {
|
||||
pt2_schema::DimSize::Int(i) => Expression::from(i.as_int as usize),
|
||||
pt2_schema::DimSize::Expr(e) => {
|
||||
let s = e.as_expr.expr_str.trim();
|
||||
// Try the full sympy-style parse first so compound forms like
|
||||
// `Mul(Integer(2), Symbol('s77', ...))` (emitted by `cat` and
|
||||
// similar dim-altering ops) propagate as a real Expression
|
||||
// rather than collapsing to the size-1 fallback. Fall back to
|
||||
// the bare-Symbol fast path when that fails — the parser
|
||||
// bails on unrecognised heads (Pow, Min, etc.) and we'd
|
||||
// rather lose the symbolic info than misinterpret it.
|
||||
parse_sympy_expr(s, sym_to_char)
|
||||
.or_else(|| {
|
||||
pt2_parser::extract_symbol_name_pub(s)
|
||||
.and_then(|sym| sym_to_char.get(&sym).map(|c| Expression::from(*c)))
|
||||
})
|
||||
.or_else(|| {
|
||||
// As a last resort, if the EP gave us a concrete `hint`
|
||||
// (the value used to seed shape tracing), use it. The
|
||||
// dim is technically dynamic but at least output-shape
|
||||
// resolution won't return 1 for unset dims.
|
||||
e.as_expr
|
||||
.hint
|
||||
.as_ref()
|
||||
.and_then(|h| h.as_int())
|
||||
.map(|h| Expression::from(h as usize))
|
||||
})
|
||||
.unwrap_or_else(|| Expression::from(1usize))
|
||||
if let Some(sym) = pt2_parser::extract_symbol_name_pub(&e.as_expr.expr_str) {
|
||||
if let Some(c) = sym_to_char.get(&sym) {
|
||||
Expression::from(*c)
|
||||
} else {
|
||||
Expression::from(1usize)
|
||||
}
|
||||
} else {
|
||||
Expression::from(1usize)
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Parse a sympy `srepr`-style expression string into a luminal Expression.
|
||||
///
|
||||
/// Handles the subset of sympy heads PT2 actually emits for shape metadata:
|
||||
///
|
||||
/// * `Symbol('name', ...)` — bound to the corresponding luminal char if
|
||||
/// present in `sym_to_char`, or treated as a fresh constant 1 otherwise.
|
||||
/// * `Integer(N)` / `Number(N)` — concrete int.
|
||||
/// * `Mul(a, b, ...)` / `Add(a, b, ...)` — n-ary, folded into pairwise ops.
|
||||
///
|
||||
/// Returns `None` for anything else so the caller can fall back to a less
|
||||
/// precise representation rather than committing a wrong expression.
|
||||
fn parse_sympy_expr(s: &str, sym_to_char: &HashMap<String, char>) -> Option<Expression> {
|
||||
let s = s.trim();
|
||||
if s.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Bare integer literal — `srepr` doesn't usually emit this at the top
|
||||
// level (it wraps in `Integer(...)`), but accept it for robustness.
|
||||
if let Ok(n) = s.parse::<i64>() {
|
||||
return Some(Expression::from(n as usize));
|
||||
}
|
||||
|
||||
let (head, body) = split_head(s)?;
|
||||
match head {
|
||||
"Symbol" => {
|
||||
// Body is `'name', positive=True, integer=True` etc. Pull the
|
||||
// first quoted token as the name.
|
||||
let name = extract_first_quoted(body)?;
|
||||
sym_to_char.get(&name).map(|c| Expression::from(*c))
|
||||
}
|
||||
"Integer" | "Number" => {
|
||||
let n: i64 = body.trim().parse().ok()?;
|
||||
Some(Expression::from(n as usize))
|
||||
}
|
||||
"Mul" | "Add" => {
|
||||
let parts = split_top_level_args(body);
|
||||
if parts.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let mut iter = parts.into_iter();
|
||||
let mut acc = parse_sympy_expr(iter.next()?, sym_to_char)?;
|
||||
for p in iter {
|
||||
let rhs = parse_sympy_expr(p, sym_to_char)?;
|
||||
acc = if head == "Mul" { acc * rhs } else { acc + rhs };
|
||||
}
|
||||
Some(acc)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Split `Head(body)` into (head, body); returns None if not in that form.
|
||||
fn split_head(s: &str) -> Option<(&str, &str)> {
|
||||
let open = s.find('(')?;
|
||||
if !s.ends_with(')') {
|
||||
return None;
|
||||
}
|
||||
Some((&s[..open], &s[open + 1..s.len() - 1]))
|
||||
}
|
||||
|
||||
/// Pull out the first single- or double-quoted token from a sympy arg list,
|
||||
/// e.g. `'s77', positive=True` → `s77`.
|
||||
fn extract_first_quoted(s: &str) -> Option<String> {
|
||||
let bytes = s.as_bytes();
|
||||
let mut i = 0;
|
||||
while i < bytes.len() {
|
||||
let c = bytes[i] as char;
|
||||
if c == '\'' || c == '"' {
|
||||
let quote = c;
|
||||
let start = i + 1;
|
||||
i += 1;
|
||||
while i < bytes.len() && bytes[i] as char != quote {
|
||||
i += 1;
|
||||
}
|
||||
return Some(s[start..i].to_string());
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Split sympy-style argument list at top-level commas, respecting nested
|
||||
/// parens and quoted strings. Discards `key=value` kwargs (they don't carry
|
||||
/// dimensional information).
|
||||
fn split_top_level_args(s: &str) -> Vec<&str> {
|
||||
let mut out = Vec::new();
|
||||
let bytes = s.as_bytes();
|
||||
let mut depth = 0;
|
||||
let mut in_quote: Option<char> = None;
|
||||
let mut start = 0;
|
||||
for (i, &b) in bytes.iter().enumerate() {
|
||||
let c = b as char;
|
||||
match in_quote {
|
||||
Some(q) => {
|
||||
if c == q {
|
||||
in_quote = None;
|
||||
}
|
||||
}
|
||||
None => match c {
|
||||
'\'' | '"' => in_quote = Some(c),
|
||||
'(' | '[' => depth += 1,
|
||||
')' | ']' => depth -= 1,
|
||||
',' if depth == 0 => {
|
||||
let part = s[start..i].trim();
|
||||
// Drop `key=value` kwargs — they're metadata sympy uses
|
||||
// for pretty-printing, not arguments to the operator.
|
||||
if !part.is_empty() && !looks_like_kwarg(part) {
|
||||
out.push(part);
|
||||
}
|
||||
start = i + 1;
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
}
|
||||
}
|
||||
let part = s[start..].trim();
|
||||
if !part.is_empty() && !looks_like_kwarg(part) {
|
||||
out.push(part);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn looks_like_kwarg(part: &str) -> bool {
|
||||
if let Some(eq) = part.find('=') {
|
||||
let key = part[..eq].trim();
|
||||
// sympy kwargs are bare identifiers like `positive`, `integer`.
|
||||
!key.is_empty() && key.chars().all(|c| c.is_ascii_alphanumeric() || c == '_')
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (pt2_path, weights_path, search_iters, factory_capsule, weight_device_ptrs=None))]
|
||||
#[pyo3(signature = (pt2_path, weights_path, factory_capsule, weight_device_ptrs=None, options=None))]
|
||||
pub fn process_pt2(
|
||||
pt2_path: &str,
|
||||
weights_path: &str,
|
||||
search_iters: usize,
|
||||
factory_capsule: &Bound<'_, PyCapsule>,
|
||||
weight_device_ptrs: Option<HashMap<String, (u64, usize)>>,
|
||||
options: Option<&Bound<'_, PyAny>>,
|
||||
) -> PyResult<CompiledGraph> {
|
||||
let options = CompileOptions::from_py(options)?;
|
||||
let factory: BackendFactory = {
|
||||
let expected = ::luminal::dyn_backend::BACKEND_FACTORY_CAPSULE_NAME;
|
||||
match factory_capsule.name()? {
|
||||
@@ -232,7 +136,7 @@ pub fn process_pt2(
|
||||
compile_pt2(
|
||||
pt2_path,
|
||||
weights_path,
|
||||
search_iters,
|
||||
&options,
|
||||
weight_device_ptrs.unwrap_or_default(),
|
||||
factory,
|
||||
)
|
||||
@@ -242,14 +146,14 @@ pub fn process_pt2(
|
||||
fn compile_pt2(
|
||||
pt2_path: &str,
|
||||
weights_path: &str,
|
||||
search_iters: usize,
|
||||
options: &CompileOptions,
|
||||
weight_device_ptrs: HashMap<String, (u64, usize)>,
|
||||
factory: BackendFactory,
|
||||
) -> anyhow::Result<CompiledGraph> {
|
||||
let (translation, mut weights) = translate_pt2(pt2_path, weights_path)?;
|
||||
weights.device_ptrs = weight_device_ptrs;
|
||||
|
||||
CompiledGraph::parse_graph(translation, weights, factory, search_iters)
|
||||
CompiledGraph::parse_graph(translation, weights, factory, options.search_iterations)
|
||||
.map_err(|e| anyhow::anyhow!(e))
|
||||
}
|
||||
|
||||
@@ -262,13 +166,10 @@ pub fn translate_pt2(
|
||||
let translated = translator::translate(&parsed)?;
|
||||
let mut graph = translated.graph;
|
||||
|
||||
// Set initial dynamic dim values from symbol ranges. PT2 emits
|
||||
// `min_val: null` when the constraint is unbounded; fall back to 1 in
|
||||
// that case (the smallest valid dim — used only as an initial value).
|
||||
// Set initial dynamic dim values from symbol ranges
|
||||
for (sym_name, c) in &translated.sym_map.sym_to_char {
|
||||
if let Some(rc) = translated.sym_map.ranges.get(sym_name) {
|
||||
let initial = rc.min_val.unwrap_or(1).max(0) as usize;
|
||||
graph.set_dim(*c, initial);
|
||||
graph.set_dim(*c, rc.min_val as usize);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -284,14 +185,14 @@ pub fn translate_pt2(
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Preserve original PT2 dtype codes for outputs (e.g. 5 = int64) so the
|
||||
// Python wrapper can return tensors with the right torch.dtype, even when
|
||||
// luminal collapses the type internally (e.g. int64 → DType::Int).
|
||||
let output_dtypes: Vec<u32> = translated
|
||||
let output_dtypes: Vec<DType> = translated
|
||||
.output_ids
|
||||
.iter()
|
||||
.map(|(name, _id)| {
|
||||
parsed.tensor_meta(name).map(|meta| meta.dtype).unwrap_or(7) // default to f32
|
||||
parsed
|
||||
.tensor_meta(name)
|
||||
.map(|meta| pt2_util::torch_dtype_int_to_luminal(meta.dtype))
|
||||
.unwrap_or(DType::F32)
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -555,3 +456,72 @@ fn bytes_to_typed(bytes: &[u8], dtype: u32) -> TypedData {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::CompileOptions;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyDict;
|
||||
use std::sync::Once;
|
||||
|
||||
fn with_python(f: impl FnOnce(Python<'_>)) {
|
||||
static INIT: Once = Once::new();
|
||||
INIT.call_once(Python::initialize);
|
||||
Python::attach(f);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_defaults_apply() {
|
||||
let options = CompileOptions::from_py(None).unwrap();
|
||||
assert_eq!(options.search_iterations, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_dict_overlays_defaults() {
|
||||
with_python(|py| {
|
||||
let options = PyDict::new(py);
|
||||
options.set_item("search_iterations", 3).unwrap();
|
||||
|
||||
let parsed = CompileOptions::from_py(Some(options.as_any())).unwrap();
|
||||
assert_eq!(parsed.search_iterations, 3);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_reject_unknown_keys() {
|
||||
with_python(|py| {
|
||||
let options = PyDict::new(py);
|
||||
options.set_item("unknown", 1).unwrap();
|
||||
|
||||
let err = CompileOptions::from_py(Some(options.as_any())).unwrap_err();
|
||||
assert!(err.is_instance_of::<pyo3::exceptions::PyValueError>(py));
|
||||
assert!(
|
||||
err.to_string()
|
||||
.contains("Unsupported luminal backend option 'unknown'")
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_reject_non_dict() {
|
||||
with_python(|py| {
|
||||
let options = 123usize.into_pyobject(py).unwrap();
|
||||
|
||||
let err = CompileOptions::from_py(Some(options.as_any())).unwrap_err();
|
||||
assert!(err.is_instance_of::<pyo3::exceptions::PyTypeError>(py));
|
||||
assert!(err.to_string().contains("options must be a dict"));
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_reject_bad_search_iterations_type() {
|
||||
with_python(|py| {
|
||||
let options = PyDict::new(py);
|
||||
options.set_item("search_iterations", "fast").unwrap();
|
||||
|
||||
let err = CompileOptions::from_py(Some(options.as_any())).unwrap_err();
|
||||
assert!(err.is_instance_of::<pyo3::exceptions::PyTypeError>(py));
|
||||
assert!(err.to_string().contains("search_iterations"));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,16 +15,7 @@ pub struct ExportedProgram {
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct RangeConstraint {
|
||||
/// Lower bound on a symbolic dimension. PT2 emits `null` when the
|
||||
/// constraint is unbounded (no min set), so this must accept None.
|
||||
#[serde(default)]
|
||||
pub min_val: Option<i64>,
|
||||
/// Upper bound on a symbolic dimension. Also nullable in PT2. Currently
|
||||
/// unused on the luminal side, but accepted to avoid deserialization
|
||||
/// errors when PT2 emits it.
|
||||
#[serde(default)]
|
||||
#[allow(dead_code)]
|
||||
pub max_val: Option<i64>,
|
||||
pub min_val: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
||||
@@ -1,195 +0,0 @@
|
||||
use anyhow::{Context, Result};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
/// Which SDPA variant we're translating. Governs argument positions and
|
||||
/// which output slots are consumed downstream.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub enum SdpaVariant {
|
||||
/// `aten._scaled_dot_product_efficient_attention.default(q, k, v, attn_bias,
|
||||
/// compute_log_sumexp, dropout_p=0., is_causal=False, *, scale=None)
|
||||
/// -> (output, log_sumexp, philox_seed, philox_offset)`
|
||||
Efficient,
|
||||
/// `aten._scaled_dot_product_flash_attention.default(q, k, v, dropout_p=0.,
|
||||
/// is_causal=False, return_debug_mask=False, *, scale=None)
|
||||
/// -> (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
|
||||
/// rng_state, unused, debug_attn_mask)`
|
||||
Flash,
|
||||
/// `aten._scaled_dot_product_flash_attention_for_cpu.default(q, k, v,
|
||||
/// dropout_p=0., is_causal=False, *, attn_mask=None, scale=None)
|
||||
/// -> (output, logsumexp)`
|
||||
FlashForCpu,
|
||||
/// `aten._scaled_dot_product_cudnn_attention.default(q, k, v, attn_bias,
|
||||
/// compute_log_sumexp, dropout_p=0., is_causal=False,
|
||||
/// return_debug_mask=False, *, scale=None)
|
||||
/// -> (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
|
||||
/// philox_seed, philox_offset, debug_attn_mask)`
|
||||
Cudnn,
|
||||
/// `aten.scaled_dot_product_attention.default(q, k, v, attn_mask=None,
|
||||
/// dropout_p=0., is_causal=False, *, scale=None, enable_gqa=False)
|
||||
/// -> Tensor` (single output, no tuple).
|
||||
Unified,
|
||||
}
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
/// Translate any SDPA op variant into `softmax((Q@K^T)*scale + causal_mask +
|
||||
/// attn_bias) @ V`. Stores the primary `output` by the node's first output
|
||||
/// name. Other tuple outputs (logsumexp, philox_seed, etc.) are unused in
|
||||
/// inference — left unbound; the downstream `getitem(node, 0)` resolves
|
||||
/// to `output` via the tuple-output name list.
|
||||
pub(crate) fn translate_sdpa(&mut self, node: &Node, variant: SdpaVariant) -> Result<()> {
|
||||
let query = self.get_input_tensor(node, 0)?;
|
||||
let key = self.get_input_tensor(node, 1)?;
|
||||
let value = self.get_input_tensor(node, 2)?;
|
||||
|
||||
// Resolve args by NAME rather than positional index. PT2 serializes
|
||||
// kwargs inline in `node.inputs` with `kind=2`, so any arg that wasn't
|
||||
// passed positionally by the caller shifts the indices of subsequent
|
||||
// positional args. Name-based lookup is unambiguous across variants
|
||||
// and across caller argument-passing styles.
|
||||
let arg_by_name =
|
||||
|name: &str| -> Option<&NodeInput> { node.inputs.iter().find(|i| i.name == name) };
|
||||
let tensor_arg = |name: &str| -> Option<GraphTensor> {
|
||||
arg_by_name(name)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
.and_then(|n| self.get_tensor(n).ok())
|
||||
};
|
||||
let float_arg =
|
||||
|name: &str| -> Option<f64> { arg_by_name(name).and_then(|i| i.arg.as_float()) };
|
||||
let bool_arg =
|
||||
|name: &str| -> Option<bool> { arg_by_name(name).and_then(|i| i.arg.as_bool()) };
|
||||
|
||||
// attn_bias (Efficient/Cudnn/Unified) or attn_mask (FlashForCpu/Unified).
|
||||
let additive = tensor_arg("attn_bias").or_else(|| tensor_arg("attn_mask"));
|
||||
|
||||
let dropout_p = float_arg("dropout_p").unwrap_or(0.0) as f32;
|
||||
anyhow::ensure!(
|
||||
dropout_p == 0.0,
|
||||
"SDPA: dropout_p={dropout_p} unsupported (inference only)"
|
||||
);
|
||||
let is_causal = bool_arg("is_causal").unwrap_or(false);
|
||||
// Silence compiler warnings — variant arg remains for branch-specific
|
||||
// logic (output tuple-name resolution below) and for future divergence.
|
||||
let _ = variant;
|
||||
|
||||
// `scale` kwarg, default 1/sqrt(head_dim).
|
||||
let head_dim = query
|
||||
.shape
|
||||
.dims
|
||||
.last()
|
||||
.and_then(|d| d.to_usize())
|
||||
.context("SDPA: query head_dim must be concrete")?;
|
||||
let default_scale = 1.0_f32 / (head_dim as f32).sqrt();
|
||||
let scale = float_arg("scale")
|
||||
.map(|v| v as f32)
|
||||
.unwrap_or(default_scale);
|
||||
|
||||
// Math form: scores = (Q @ K^T) * scale; + causal_mask; + attn_bias;
|
||||
// attn = softmax(scores, dim=-1); out = attn @ V.
|
||||
let q_ndim = query.shape.len();
|
||||
anyhow::ensure!(
|
||||
q_ndim >= 2,
|
||||
"SDPA: query must have at least 2 dims (got {q_ndim})"
|
||||
);
|
||||
// Transpose last two dims of key.
|
||||
let mut perm: Vec<usize> = (0..q_ndim).collect();
|
||||
perm.swap(q_ndim - 2, q_ndim - 1);
|
||||
let key_t = key.permute(perm);
|
||||
let (q_for_mm, k_for_mm) = ensure_same_dtype(query, key_t);
|
||||
let scores = q_for_mm.matmul(k_for_mm);
|
||||
let scale_t = self
|
||||
.graph
|
||||
.constant_float(scale)
|
||||
.cast(scores.dtype)
|
||||
.expand_rhs(scores.shape);
|
||||
let mut scores = scores * scale_t;
|
||||
|
||||
if is_causal {
|
||||
let s_q = scores
|
||||
.shape
|
||||
.dims
|
||||
.get(q_ndim - 2)
|
||||
.and_then(|d| d.to_usize())
|
||||
.context("SDPA is_causal: S_q must be concrete")?;
|
||||
let s_k = scores
|
||||
.shape
|
||||
.dims
|
||||
.get(q_ndim - 1)
|
||||
.and_then(|d| d.to_usize())
|
||||
.context("SDPA is_causal: S_k must be concrete")?;
|
||||
let size = s_q.max(s_k);
|
||||
// triu with diagonal=1 = 1 strictly above diagonal, 0 elsewhere.
|
||||
let mut mask = self.graph.triu(size, 1).cast(DType::F32);
|
||||
if s_q != size || s_k != size {
|
||||
mask = mask.slice_along(0..s_q, 0).slice_along(0..s_k, 1);
|
||||
}
|
||||
// -1e9 * mask ≈ -inf where masked, 0 otherwise. Broadcast across
|
||||
// batch/head prefix dims of `scores`.
|
||||
let neg_large = mask * (-1e9_f32);
|
||||
let mut neg_large = neg_large.cast(scores.dtype);
|
||||
for _ in 0..(q_ndim - 2) {
|
||||
neg_large = neg_large.expand_dim(0, Expression::from(1usize));
|
||||
}
|
||||
let (scores_b, mask_b) = broadcast_binary(scores, neg_large);
|
||||
scores = scores_b + mask_b;
|
||||
}
|
||||
if let Some(bias) = additive {
|
||||
let (scores_b, bias_b) = ensure_same_dtype(scores, bias);
|
||||
let (scores_b, bias_b) = broadcast_binary(scores_b, bias_b);
|
||||
scores = scores_b + bias_b;
|
||||
}
|
||||
|
||||
let attn = scores.softmax(q_ndim - 1);
|
||||
let (attn, value) = ensure_same_dtype(attn, value);
|
||||
let out = attn.matmul(value);
|
||||
|
||||
// Store the primary output by name. The other tuple outputs are
|
||||
// inference-time dead ends — downstream getitem(node, 0) resolves to
|
||||
// the same tensor name we bind here, because pt2 serializes the
|
||||
// multi-output name list with output[0] as the primary slot.
|
||||
let out_name = if let Some(ts) = node.outputs.first().and_then(|o| o.as_tensors.as_ref()) {
|
||||
ts.first().map(|t| t.name.clone())
|
||||
} else if variant == SdpaVariant::Unified {
|
||||
node.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
|
||||
} else {
|
||||
node.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
|
||||
.or_else(|| {
|
||||
node.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensors.as_ref())
|
||||
.and_then(|ts| ts.first().map(|t| t.name.clone()))
|
||||
})
|
||||
};
|
||||
|
||||
if let Some(name) = out_name
|
||||
&& !name.is_empty()
|
||||
{
|
||||
self.tensors.insert(name, out);
|
||||
} else {
|
||||
anyhow::bail!("SDPA: no output tensor name found on node {}", node.target);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for SdpaVariant {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
matches!(
|
||||
(self, other),
|
||||
(SdpaVariant::Efficient, SdpaVariant::Efficient)
|
||||
| (SdpaVariant::Flash, SdpaVariant::Flash)
|
||||
| (SdpaVariant::FlashForCpu, SdpaVariant::FlashForCpu)
|
||||
| (SdpaVariant::Cudnn, SdpaVariant::Cudnn)
|
||||
| (SdpaVariant::Unified, SdpaVariant::Unified)
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -173,7 +173,7 @@ impl<'a> Translator<'a> {
|
||||
|
||||
if let Some(b) = bias {
|
||||
let out_dims = out.dims();
|
||||
let mut b_expanded = b.expand_dim(0, out_dims[0]);
|
||||
let mut b_expanded = b.expand_dim(0, 1);
|
||||
for i in 0..spatial {
|
||||
b_expanded = b_expanded.expand_dim(2 + i, out_dims[2 + i]);
|
||||
}
|
||||
@@ -389,11 +389,8 @@ fn depthwise_conv(
|
||||
// Expand to [N, C, group_out, out_spatial_product, kernel_product]
|
||||
let patches = patches.expand_dim(2, group_out);
|
||||
|
||||
// Explicitly expand weight across the batch axis so the elementwise Mul
|
||||
// sees equal visible shapes. HLIR binary ops do not perform broadcasting.
|
||||
let w_expanded = w_flat
|
||||
.expand_dim(0, patches.dims()[0])
|
||||
.expand_dim(3, patches.dims()[3]);
|
||||
// Expand weight for broadcast: [1, C, group_out, out_spatial_product, kernel_product]
|
||||
let w_expanded = w_flat.expand_dim(0, 1).expand_dim(3, patches.dims()[3]);
|
||||
|
||||
// Element-wise multiply and sum over kernel dim
|
||||
let product = patches * w_expanded;
|
||||
|
||||
@@ -5,8 +5,6 @@ use crate::pt2_schema::*;
|
||||
use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
use super::attention::SdpaVariant;
|
||||
use super::reduction::ArgExtremum;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_node(&mut self, node: &Node) -> Result<()> {
|
||||
@@ -70,8 +68,6 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.sigmoid.default" => self.translate_unary_op(node, |a| a.sigmoid())?,
|
||||
"torch.ops.aten.relu.default" => self.translate_unary_op(node, |a| a.relu())?,
|
||||
"torch.ops.aten.tanh.default" => self.translate_unary_op(node, |a| a.tanh())?,
|
||||
"torch.ops.aten.silu.default" => self.translate_unary_op(node, |a| a.silu())?,
|
||||
"torch.ops.aten.gelu.default" => self.translate_unary_op(node, |a| a.gelu())?,
|
||||
"torch.ops.aten.abs.default" => self.translate_unary_op(node, |a| a.abs())?,
|
||||
"torch.ops.aten.log.default" => self.translate_unary_op(node, |a| a.log())?,
|
||||
"torch.ops.aten.log2.default" => self.translate_unary_op(node, |a| a.log2())?,
|
||||
@@ -148,7 +144,6 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Slice/index ops
|
||||
"torch.ops.aten.slice.Tensor" => self.translate_slice(node)?,
|
||||
"torch.ops.aten.select.int" => self.translate_select(node)?,
|
||||
"torch.ops.aten.cat.default" => self.translate_cat(node)?,
|
||||
"torch.ops.aten.index.Tensor" => self.translate_index_tensor(node)?,
|
||||
|
||||
@@ -188,28 +183,6 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.arange.start_step" => self.translate_arange(node)?,
|
||||
"torch.ops.aten.full.default" => self.translate_full(node)?,
|
||||
"torch.ops.aten.full_like.default" => self.translate_full_like(node)?,
|
||||
// `empty` and `empty_permuted` allocate uninitialised tensors of
|
||||
// a given shape; the caller fills them. We lower to zeros with
|
||||
// the same shape+dtype — downstream reads are officially UB on
|
||||
// PyTorch's side, and downstream writes overwrite our zeros.
|
||||
// Qwen3MoE's MoE block uses `empty_permuted` to allocate the
|
||||
// expert-output staging tensor before scatter-adding into it.
|
||||
"torch.ops.aten.empty.memory_format" | "torch.ops.aten.empty_permuted.default" => {
|
||||
self.translate_empty(node)?
|
||||
}
|
||||
// Qwen3-MoE's expert-balance counts tokens-per-expert via histc.
|
||||
"torch.ops.aten.histc.default" => self.translate_histc(node)?,
|
||||
|
||||
// Grouped matmul (MoE expert dispatch).
|
||||
// aten._grouped_mm is the native op; transformers::grouped_mm_fallback
|
||||
// is a Python-implemented custom_op (transformers/integrations/moe.py)
|
||||
// used by HF MoE when _grouped_mm isn't available for the activation
|
||||
// dtype. Both have identical (input, weight, offs) signature; route
|
||||
// both through the same batched-matmul + group-mask lowering.
|
||||
"torch.ops.aten._grouped_mm.default"
|
||||
| "torch.ops.transformers.grouped_mm_fallback.default" => {
|
||||
self.translate_grouped_mm(node)?
|
||||
}
|
||||
"torch.ops.aten.scalar_tensor.default" => {
|
||||
let val = self.get_float_arg(node, 0)? as f32;
|
||||
self.graph.constant_float(val)
|
||||
@@ -221,16 +194,6 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.le.Scalar" => self.translate_scalar_comparison(node, |a, s| a.le(s))?,
|
||||
|
||||
// Tensor comparisons
|
||||
"torch.ops.aten.eq.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
let scalar = self
|
||||
.graph
|
||||
.constant_float(val)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
a.eq(scalar)
|
||||
}
|
||||
"torch.ops.aten.ne.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
@@ -248,13 +211,6 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.eq(b)
|
||||
}
|
||||
"torch.ops.aten.ne.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.ne(b)
|
||||
}
|
||||
"torch.ops.aten.le.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
@@ -270,11 +226,7 @@ impl<'a> Translator<'a> {
|
||||
let b = b.cast(DType::F32);
|
||||
(a * b).cast(DType::Bool)
|
||||
}
|
||||
"torch.ops.aten.bitwise_or.Tensor" | "torch.ops.aten.logical_or.default" => {
|
||||
// Both arms use the same bool-OR lowering. Gemma-4's sliding+full
|
||||
// attention mask fusion emits bitwise_or on boolean tensors; the
|
||||
// integer semantics of bitwise_or aren't exercised by any op in
|
||||
// the test suite, so we rely on inputs being boolean-typed.
|
||||
"torch.ops.aten.logical_or.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
@@ -293,27 +245,18 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Clamp
|
||||
"torch.ops.aten.clamp.default" => self.translate_clamp(node)?,
|
||||
"torch.ops.aten.clamp.Tensor" => self.translate_clamp_tensor(node)?,
|
||||
|
||||
// Cumsum
|
||||
"torch.ops.aten.cumsum.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let a = if a.dtype == DType::Bool {
|
||||
a.cast(DType::Int)
|
||||
} else {
|
||||
a
|
||||
};
|
||||
// Rank-0 (scalar) input: cumsum of a single element is the element
|
||||
// itself. PyTorch eager treats `dim=0` on a 0-d as an identity op,
|
||||
// and the underlying `cumop` indexes `shape.dims[axis]` which would
|
||||
// panic with empty dims.
|
||||
if a.shape.is_empty() {
|
||||
a
|
||||
} else {
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
a.cumsum(dim)
|
||||
}
|
||||
a.cumsum(dim)
|
||||
}
|
||||
|
||||
// Floor / Ceil / Erf (approximations)
|
||||
@@ -326,14 +269,12 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
"torch.ops.aten.ceil.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
// ceil(x) = trunc(x) + (x > trunc(x)).
|
||||
// Cast-to-Int rounds toward zero, so for any positive fractional
|
||||
// `x` the trunc sits below `x` and we add 1; for negatives we
|
||||
// have `trunc >= x` and adjust=0. Avoids the two extra
|
||||
// mul-by-(-1) nodes that the `-floor(-x)` lowering emits.
|
||||
let trunc = a.cast(DType::Int).cast(DType::F32);
|
||||
let adjust = a.gt(trunc).cast(DType::F32);
|
||||
trunc + adjust
|
||||
// ceil(x) = -floor(-x)
|
||||
let neg_a = a * (-1.0);
|
||||
let trunc = neg_a.cast(DType::Int).cast(DType::F32);
|
||||
let adjust = neg_a.lt(trunc).cast(DType::F32);
|
||||
let floor_neg = trunc - adjust;
|
||||
floor_neg * (-1.0)
|
||||
}
|
||||
"torch.ops.aten.erf.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
@@ -409,17 +350,6 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.max.default" => self.translate_reduction(node, ReductionOp::Max)?,
|
||||
"torch.ops.aten.min.default" => self.translate_reduction(node, ReductionOp::Min)?,
|
||||
"torch.ops.aten.amin.default" => self.translate_reduction(node, ReductionOp::Min)?,
|
||||
"torch.ops.aten.prod.default" => self.translate_reduction(node, ReductionOp::Prod)?,
|
||||
|
||||
// Argmax / argmin — built on top of `stable_argsort` (LUM-496).
|
||||
// PyTorch's argmax/argmin returns int64; the dtype is preserved
|
||||
// through the LUM-486 boundary widening.
|
||||
"torch.ops.aten.argmax.default" => {
|
||||
self.translate_argextremum(node, ArgExtremum::Max)?
|
||||
}
|
||||
"torch.ops.aten.argmin.default" => {
|
||||
self.translate_argextremum(node, ArgExtremum::Min)?
|
||||
}
|
||||
|
||||
// Gather (axis-aware)
|
||||
"torch.ops.aten.gather.default" => self.translate_gather(node)?,
|
||||
@@ -450,29 +380,6 @@ impl<'a> Translator<'a> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Scaled dot-product attention — each variant binds args slightly
|
||||
// differently but all lower to matmul+softmax via translate_sdpa.
|
||||
"torch.ops.aten._scaled_dot_product_efficient_attention.default" => {
|
||||
self.translate_sdpa(node, SdpaVariant::Efficient)?;
|
||||
return Ok(());
|
||||
}
|
||||
"torch.ops.aten._scaled_dot_product_flash_attention.default" => {
|
||||
self.translate_sdpa(node, SdpaVariant::Flash)?;
|
||||
return Ok(());
|
||||
}
|
||||
"torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default" => {
|
||||
self.translate_sdpa(node, SdpaVariant::FlashForCpu)?;
|
||||
return Ok(());
|
||||
}
|
||||
"torch.ops.aten._scaled_dot_product_cudnn_attention.default" => {
|
||||
self.translate_sdpa(node, SdpaVariant::Cudnn)?;
|
||||
return Ok(());
|
||||
}
|
||||
"torch.ops.aten.scaled_dot_product_attention.default" => {
|
||||
self.translate_sdpa(node, SdpaVariant::Unified)?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Split
|
||||
"torch.ops.aten.split_with_sizes.default" => self.translate_split_with_sizes(node)?,
|
||||
|
||||
@@ -483,28 +390,6 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a % b
|
||||
}
|
||||
// Remainder (Python-style modulo). For float tensors aten.remainder
|
||||
// returns the same value as `%` would in luminal (Mod follows the
|
||||
// language's % semantics on f32). The Tensor variant accepts a
|
||||
// tensor RHS that may be rank-0; broadcast both operands so a
|
||||
// scalar RHS is expanded to match the LHS shape before mod.
|
||||
"torch.ops.aten.remainder.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a % b
|
||||
}
|
||||
"torch.ops.aten.remainder.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
let scalar = self
|
||||
.graph
|
||||
.constant_float(val)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
a % scalar
|
||||
}
|
||||
// Prod reduction
|
||||
"torch.ops.aten.prod.dim_int" => self.translate_reduction(node, ReductionOp::Prod)?,
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
//!
|
||||
//! Walks the parsed PT2 graph and constructs an equivalent Luminal computation graph.
|
||||
|
||||
mod attention;
|
||||
mod binary;
|
||||
mod conv;
|
||||
mod dispatch;
|
||||
@@ -188,21 +187,8 @@ impl<'a> Translator<'a> {
|
||||
.get(idx)
|
||||
.with_context(|| format!("Node {} missing input {idx}", node.target))?
|
||||
.arg;
|
||||
if let Some(v) = arg.as_int() {
|
||||
return Ok(v);
|
||||
}
|
||||
// Fall through to symbolic-aware resolution. Op-arg slots like `dim`
|
||||
// and `axis` are always concrete in practice, but with dynamic shapes
|
||||
// PT2 occasionally hands us a SymInt that is fully bound at export
|
||||
// time (e.g. an `unsqueeze` whose dim was derived from `len(shape)`);
|
||||
// accept those when they reduce to a concrete int rather than failing
|
||||
// with the misleading "not an int" diagnostic.
|
||||
if let Some(expr) = self.resolve_arg_as_expression(arg)
|
||||
&& let Some(v) = expr.to_usize()
|
||||
{
|
||||
return Ok(v as i64);
|
||||
}
|
||||
anyhow::bail!("Input {idx} of {} is not an int: {:?}", node.target, arg)
|
||||
arg.as_int()
|
||||
.with_context(|| format!("Input {idx} of {} is not an int: {:?}", node.target, arg))
|
||||
}
|
||||
|
||||
pub(crate) fn get_float_arg(&self, node: &Node, idx: usize) -> Result<f64> {
|
||||
@@ -221,37 +207,11 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
pub(crate) fn get_ints_arg(&self, node: &Node, idx: usize) -> Result<Vec<i64>> {
|
||||
use crate::pt2_schema::SymIntEntry;
|
||||
let arg = &node
|
||||
.inputs
|
||||
.get(idx)
|
||||
.with_context(|| format!("Node {} missing input {idx}", node.target))?
|
||||
.arg;
|
||||
// Symbolic int lists: tolerate them as long as every entry is a
|
||||
// bound concrete value. Prevents false "not an int list" failures on
|
||||
// graphs where torch.export emits sym_ints for what is dimensionally
|
||||
// a static parameter (kernel sizes, etc. with dynamic batch).
|
||||
if let Some(entries) = arg.as_sym_ints() {
|
||||
let mut out = Vec::with_capacity(entries.len());
|
||||
for entry in entries {
|
||||
let v = match entry {
|
||||
SymIntEntry::Int(i) => Some(i.as_int),
|
||||
SymIntEntry::Name(s) => self
|
||||
.resolve_sym_int(&s.as_name)
|
||||
.and_then(|e| e.to_usize().map(|u| u as i64)),
|
||||
};
|
||||
match v {
|
||||
Some(n) => out.push(n),
|
||||
None => {
|
||||
anyhow::bail!(
|
||||
"Input {idx} of {} contains an unresolved sym_int entry",
|
||||
node.target
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
return Ok(out);
|
||||
}
|
||||
arg.as_ints()
|
||||
.map(|v| v.to_vec())
|
||||
.with_context(|| format!("Input {idx} of {} is not int list: {:?}", node.target, arg))
|
||||
|
||||
@@ -120,47 +120,6 @@ impl<'a> Translator<'a> {
|
||||
Ok(a.slice_along(start..end, dim))
|
||||
}
|
||||
|
||||
/// `aten.select.int(self, dim, index)` — select element `index` along
|
||||
/// `dim`, dropping that dim. Output rank = input rank − 1, so a 1-D input
|
||||
/// produces a rank-0 scalar. Both `dim` and `index` may be negative and
|
||||
/// are normalized against the input shape.
|
||||
///
|
||||
/// Lowered as `slice_along(index..index+1, dim).squeeze(dim)`. We use the
|
||||
/// slice + squeeze decomposition (rather than `gather`) because the
|
||||
/// composition is a pure shape manipulation with a single iota, which the
|
||||
/// luminal compiler can fold into surrounding ops.
|
||||
pub(crate) fn translate_select(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let index_raw = self.get_int_arg(node, 2)?;
|
||||
|
||||
// Normalize a possibly-negative index. PyTorch accepts indices in
|
||||
// [-size, size); negative wraps from the end.
|
||||
let index = if index_raw < 0 {
|
||||
let axis_size = a.shape.dims[dim].to_usize().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"select.int: dim {} must be concrete to normalize a negative index",
|
||||
dim
|
||||
)
|
||||
})?;
|
||||
let normalized = axis_size as i64 + index_raw;
|
||||
if normalized < 0 {
|
||||
bail!(
|
||||
"select.int: index {} out of range for dim {} of size {}",
|
||||
index_raw,
|
||||
dim,
|
||||
axis_size
|
||||
);
|
||||
}
|
||||
normalized as usize
|
||||
} else {
|
||||
index_raw as usize
|
||||
};
|
||||
|
||||
Ok(a.slice_along(index..index + 1, dim).squeeze(dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_cat(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let tensors: Vec<GraphTensor> = if let Some(names) = node.inputs[0].arg.as_tensors() {
|
||||
names
|
||||
@@ -300,15 +259,21 @@ impl<'a> Translator<'a> {
|
||||
for (dim_idx, idx_name) in index_names.iter().enumerate() {
|
||||
let idx_tensor = self.get_tensor(&idx_name.name)?;
|
||||
|
||||
// Normalize negative indices for this dimension. Stay in Int —
|
||||
// multiplying an Int tensor by an Expression broadcasts the axis
|
||||
// size, so we avoid three Cast nodes (Int→F32 for indices, F32→Int
|
||||
// for the result, Bool→F32 for the negative mask) per indexed dim.
|
||||
let axis_size = src_shape[dim_idx];
|
||||
let idx_int = idx_tensor.cast(DType::Int);
|
||||
let zero = self.graph.constant(0).expand_rhs(idx_int.shape);
|
||||
let is_negative = idx_int.lt(zero).cast(DType::Int);
|
||||
let idx_int = idx_int + is_negative * axis_size;
|
||||
// Normalize negative indices for this dimension
|
||||
let axis_size = src_shape[dim_idx].to_usize().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"index.Tensor: dim {} must be concrete for negative index normalization",
|
||||
dim_idx
|
||||
)
|
||||
})?;
|
||||
let idx_f32 = idx_tensor.cast(DType::F32);
|
||||
let zero = self.graph.constant_float(0.0).expand_rhs(idx_f32.shape);
|
||||
let adjustment = self
|
||||
.graph
|
||||
.constant_float(axis_size as f32)
|
||||
.expand_rhs(idx_f32.shape);
|
||||
let is_negative = idx_f32.lt(zero).cast(DType::F32);
|
||||
let idx_int = (idx_f32 + is_negative * adjustment).cast(DType::Int);
|
||||
|
||||
let stride = &strides[dim_idx];
|
||||
let weighted = if stride.to_usize() == Some(1) {
|
||||
@@ -374,34 +339,20 @@ impl<'a> Translator<'a> {
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let indices = self.get_input_tensor(node, 2)?;
|
||||
|
||||
// PyTorch eager allows torch.gather(rank-1, 0, rank-0) and returns
|
||||
// a rank-0 scalar — the only rank-mismatch case eager permits. Our
|
||||
// gather_elements requires the index rank to match the source rank,
|
||||
// so unsqueeze the rank-0 index to (1,), gather, then squeeze back.
|
||||
let promoted_rank0 = indices.shape.is_empty() && a.shape.len() == 1;
|
||||
let indices = if promoted_rank0 {
|
||||
indices.unsqueeze(0)
|
||||
} else {
|
||||
indices
|
||||
};
|
||||
|
||||
// Normalize negative indices: -1 → last, -2 → second-to-last, etc.
|
||||
// Stay in Int the whole way — multiplying an Int tensor by an
|
||||
// Expression broadcasts the axis size and avoids three Cast nodes
|
||||
// (Int→F32 for indices, F32→Int for the result, plus a Bool→F32 for
|
||||
// the negative mask) that the previous F32-routed path emitted.
|
||||
let axis_dim = a.shape.dims[dim];
|
||||
let indices_int = indices.cast(DType::Int);
|
||||
let zero = self.graph.constant(0).expand_rhs(indices_int.shape);
|
||||
let is_negative = indices_int.lt(zero).cast(DType::Int);
|
||||
let normalized = indices_int + is_negative * axis_dim;
|
||||
let axis_dim = a.shape.dims[dim].to_usize().ok_or_else(|| {
|
||||
anyhow::anyhow!("Gather: axis dim must be concrete for negative index normalization")
|
||||
})?;
|
||||
let indices_f32 = indices.cast(DType::F32);
|
||||
let zero = self.graph.constant_float(0.0).expand_rhs(indices_f32.shape);
|
||||
let adjustment = self
|
||||
.graph
|
||||
.constant_float(axis_dim as f32)
|
||||
.expand_rhs(indices_f32.shape);
|
||||
let is_negative = indices_f32.lt(zero).cast(DType::F32);
|
||||
let normalized = (indices_f32 + is_negative * adjustment).cast(DType::Int);
|
||||
|
||||
let result = a.gather_elements(normalized, dim);
|
||||
Ok(if promoted_rank0 {
|
||||
result.squeeze(0)
|
||||
} else {
|
||||
result
|
||||
})
|
||||
Ok(a.gather_elements(normalized, dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_scatter_src(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
@@ -445,39 +396,14 @@ impl<'a> Translator<'a> {
|
||||
let values = self.get_input_tensor(node, 2)?;
|
||||
|
||||
if index_names.len() == 1 {
|
||||
let idx_tensor = self.get_tensor(&index_names[0].name)?;
|
||||
|
||||
// Boolean-mask index_put: when the only index is a Bool tensor whose
|
||||
// shape matches the data tensor, PyTorch semantics are
|
||||
// data[mask] = value ↔ where(mask, value, data)
|
||||
// NOT a scatter into positions. Casting the Bool mask to Int and
|
||||
// feeding it to scatter_nd would reinterpret True/False as row
|
||||
// indices 1/0 and silently corrupt the data. Reproducer:
|
||||
// x = arange(16).reshape(4, 4); mask = zeros(4, 4, dtype=bool)
|
||||
// y = x.clone(); y[mask] = 99 # eager: y == x (no-op)
|
||||
// Pre-fix the compiled graph wrote 99 to row 0; this branch
|
||||
// ensures the bool-mask path lowers to a where-blend instead.
|
||||
if idx_tensor.dtype == DType::Bool && idx_tensor.shape.dims == a.shape.dims {
|
||||
// Broadcast the (often scalar) value tensor to match data shape,
|
||||
// then blend by mask. Cast mask to data's dtype for the
|
||||
// arithmetic so this works for both integer and float data.
|
||||
let mask_f = idx_tensor.cast(a.dtype);
|
||||
let values_b = values.cast(a.dtype).expand_rhs(a.shape);
|
||||
// where(mask, value, a) as `a + mask*(value - a)`. Saves a mul
|
||||
// and the `1.0` constant compared to the `a*(1 - m) + v*m`
|
||||
// form; works for any numeric dtype without a dedicated cond.
|
||||
return Ok(a + mask_f * (values_b - a));
|
||||
}
|
||||
|
||||
// Integer-index scatter: index_put with indices=[idx_tensor] writes
|
||||
// into dim 0 of `a` at every position named in idx_tensor (flattened),
|
||||
// broadcasting values across the trailing dims of `a`. idx_tensor can
|
||||
// be ANY shape — its whole shape is "batch dims" in scatter_nd terms,
|
||||
// and K is always 1 (number of dims we're indexing into). Always pad
|
||||
// a trailing size-1 dim so the rank-1 and rank-N cases share a path.
|
||||
let indices = idx_tensor.cast(DType::Int);
|
||||
let new_last = indices.shape.len();
|
||||
let indices = indices.expand_dim(new_last, Expression::from(1usize));
|
||||
let indices = self.get_tensor(&index_names[0].name)?.cast(DType::Int);
|
||||
// scatter_nd expects indices of shape [batch, K] where K = number of index dims.
|
||||
// PT2's index_put gives 1D indices [batch]; reshape to [batch, 1].
|
||||
let indices = if indices.shape.len() == 1 {
|
||||
indices.expand_dim(1, Expression::from(1usize))
|
||||
} else {
|
||||
indices
|
||||
};
|
||||
Ok(a.scatter_nd(indices, values))
|
||||
} else {
|
||||
bail!("index_put with multiple index tensors not yet supported");
|
||||
|
||||
@@ -6,20 +6,6 @@ use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
/// Whether `argmax` / `argmin` should pick the largest (descending sort) or
|
||||
/// smallest (ascending sort) element when scanning the input.
|
||||
#[derive(Clone, Copy)]
|
||||
pub(crate) enum ArgExtremum {
|
||||
Max,
|
||||
Min,
|
||||
}
|
||||
|
||||
impl ArgExtremum {
|
||||
fn descending(self) -> bool {
|
||||
matches!(self, ArgExtremum::Max)
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute total element count, returning an error if any dimension is symbolic.
|
||||
fn concrete_numel(a: &GraphTensor) -> Result<usize> {
|
||||
a.dims().iter().try_fold(1usize, |acc, d| {
|
||||
@@ -51,26 +37,16 @@ impl<'a> Translator<'a> {
|
||||
(axes, keepdim)
|
||||
}
|
||||
_ => {
|
||||
// Full reduce: reduce over every axis, leaving a rank-0 (scalar) tensor.
|
||||
// PyTorch eager returns shape () for `x.sum()` etc., and downstream ops
|
||||
// (e.g. unsqueeze(0).expand(N)) rely on this rank.
|
||||
let ndim = a.shape.len();
|
||||
if ndim == 0 {
|
||||
// Already rank-0 — reducing over no axes is a no-op for sum/max/min/prod,
|
||||
// and mean of a scalar is just the scalar.
|
||||
return Ok(a);
|
||||
}
|
||||
// Full reduce: flatten to [1, N] and reduce axis 1
|
||||
let total = concrete_numel(&a)?;
|
||||
let axes: Vec<usize> = (0..ndim).collect();
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
let result = match op {
|
||||
ReductionOp::Sum => a.sum(axes),
|
||||
// Note: the luminal `mean` helper divides by the product of the
|
||||
// axis dims, but we already require concrete dims here so we
|
||||
// divide by the cached `total` to avoid recomputing.
|
||||
ReductionOp::Mean => a.sum(axes) / total as f32,
|
||||
ReductionOp::Max => a.max(axes),
|
||||
ReductionOp::Min => a.min(axes),
|
||||
ReductionOp::Prod => a.prod(axes),
|
||||
ReductionOp::Sum => flat.sum(vec![1]),
|
||||
ReductionOp::Mean => flat.sum(vec![1]) / total as f32,
|
||||
ReductionOp::Max => flat.max(vec![1]),
|
||||
ReductionOp::Min => flat.min(vec![1]),
|
||||
ReductionOp::Prod => flat.prod(vec![1]),
|
||||
};
|
||||
return Ok(result);
|
||||
}
|
||||
@@ -94,100 +70,4 @@ impl<'a> Translator<'a> {
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Lower `aten.argmax.default` / `aten.argmin.default` by reusing the
|
||||
/// existing `stable_argsort` op and selecting the first index along the
|
||||
/// sort axis.
|
||||
///
|
||||
/// PyTorch signature: `argmax(self, dim=None, keepdim=False)` (likewise
|
||||
/// for argmin). FX export emits the inputs positionally:
|
||||
/// - input 0: tensor
|
||||
/// - input 1: dim (Int) or None (Other) — when `dim=None`
|
||||
/// - input 2: keepdim (Bool, optional)
|
||||
///
|
||||
/// When `dim=None`, PyTorch flattens the tensor; we mirror that by
|
||||
/// reshaping to a 1-D `[numel]` view (which requires concrete dims).
|
||||
/// The result of argsort along the sort axis is sliced at index 0,
|
||||
/// then squeezed away — i.e. `select(dim, 0)` — to give the index of
|
||||
/// the extremum. With `keepdim=True` we re-insert a size-1 dim at
|
||||
/// `dim`.
|
||||
///
|
||||
/// The slice + squeeze chain produces a non-contiguous `DType::Int`
|
||||
/// view; we materialize it with `* 1` so the resulting node has
|
||||
/// contiguous strides matching its visible shape (mirroring the
|
||||
/// `topk` lowering in `translate_topk`). Without this, the output
|
||||
/// buffer would be sized for the un-sliced argsort tensor while the
|
||||
/// shape tracker reports a smaller rank.
|
||||
///
|
||||
/// The output dtype is `DType::Int` (luminal's 32-bit int); PT2
|
||||
/// metadata records int64 and the Python wrapper widens at the
|
||||
/// boundary, so the PyTorch contract is preserved end-to-end
|
||||
/// (LUM-486).
|
||||
pub(crate) fn translate_argextremum(
|
||||
&mut self,
|
||||
node: &Node,
|
||||
which: ArgExtremum,
|
||||
) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
|
||||
// dim is positional input 1. PyTorch encodes `dim=None` as a non-Int
|
||||
// argument (typically `Argument::Other(Null)`), so a missing or
|
||||
// non-int slot means "reduce over the flattened tensor".
|
||||
let dim_opt: Option<i64> = if node.inputs.len() > 1 {
|
||||
self.get_int_arg(node, 1).ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let keepdim = if node.inputs.len() > 2 {
|
||||
self.get_bool_arg(node, 2).unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
if a.shape.is_empty() {
|
||||
match dim_opt {
|
||||
None | Some(0) | Some(-1) => {
|
||||
// PyTorch returns scalar index 0 for rank-0 argmax/argmin.
|
||||
// `keepdim=True` does not add a dimension when the input is 0-d.
|
||||
return Ok(self.graph.constant(0i64).cast(DType::Int));
|
||||
}
|
||||
Some(dim) => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Dimension out of range (expected to be in range of [-1, 0], but got {dim})"
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let descending = which.descending();
|
||||
|
||||
let (sort_axis, base) = match dim_opt {
|
||||
None => {
|
||||
// Full-reduce: flatten to 1-D, argsort along axis 0.
|
||||
let total = concrete_numel(&a)?;
|
||||
let flat = reshape_tensor(a, vec![Expression::from(total)]);
|
||||
(0usize, flat)
|
||||
}
|
||||
Some(dim_raw) => {
|
||||
let dim = normalize_dim(dim_raw, a.shape.len());
|
||||
(dim, a)
|
||||
}
|
||||
};
|
||||
|
||||
// Pick index 0 along the sort axis. The slice-then-squeeze chain
|
||||
// produces a non-contiguous view whose physical buffer is still
|
||||
// sized for the un-sliced argsort tensor; the optional `keepdim`
|
||||
// unsqueeze adds a stride-0 axis which is also non-contiguous.
|
||||
// Materialize at the end with `* 1` so the resulting node has
|
||||
// contiguous strides matching its visible shape (matches the
|
||||
// pattern used by `translate_topk` for sliced index outputs).
|
||||
let sorted = base.stable_argsort(sort_axis, descending);
|
||||
let picked = sorted.slice_along(0..1, sort_axis).squeeze(sort_axis);
|
||||
let result = if keepdim {
|
||||
picked.unsqueeze(sort_axis)
|
||||
} else {
|
||||
picked
|
||||
};
|
||||
Ok(result * 1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,97 +72,6 @@ impl<'a> Translator<'a> {
|
||||
})
|
||||
}
|
||||
|
||||
/// Lower `aten.histc.default` for the integer-bincount case.
|
||||
///
|
||||
/// Qwen3-MoE's expert-balance layer calls
|
||||
/// `torch.histc(expert_ids.int(), bins=K, min=0, max=K-1)` to count how
|
||||
/// many tokens were routed to each expert. With those args every
|
||||
/// integer value `i ∈ [0, K-1]` maps to exactly bin `i`, and the result
|
||||
/// is equivalent to `torch.bincount`. We implement that case as a
|
||||
/// broadcast equality + sum:
|
||||
///
|
||||
/// counts[b] = sum_i (input[i] == b + min) for b in [0, bins)
|
||||
///
|
||||
/// More general histc bin widths (`bins != max - min + 1`, or
|
||||
/// non-integer values that span fractional bins) are not supported
|
||||
/// today — the equality path would silently drop them. We bail rather
|
||||
/// than produce wrong counts.
|
||||
pub(crate) fn translate_histc(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let bins_i64: i64 = self
|
||||
.get_int_arg(node, 1)
|
||||
.context("histc: missing `bins` arg (#1)")?;
|
||||
// `min`/`max` are float kwargs (default 0.0 each, which means
|
||||
// "auto-pick from input"); for the qwen3-moe call they're always
|
||||
// integers passed as floats.
|
||||
let min = self.get_float_arg(node, 2).unwrap_or(0.0);
|
||||
let max = self.get_float_arg(node, 3).unwrap_or(0.0);
|
||||
|
||||
anyhow::ensure!(
|
||||
input.shape.len() == 1,
|
||||
"histc: only 1D input is supported, got {}D",
|
||||
input.shape.len()
|
||||
);
|
||||
anyhow::ensure!(
|
||||
bins_i64 > 0,
|
||||
"histc: bins must be positive, got {}",
|
||||
bins_i64
|
||||
);
|
||||
// Bincount-equivalent case: one integer value per bin.
|
||||
anyhow::ensure!(
|
||||
(max - min - (bins_i64 - 1) as f64).abs() < 1e-6,
|
||||
"histc: only the bincount-equivalent case (bins == max - min + 1) is \
|
||||
supported; got bins={}, min={}, max={}. Other cases would need a \
|
||||
general bin-width / right-edge-inclusion implementation.",
|
||||
bins_i64,
|
||||
min,
|
||||
max,
|
||||
);
|
||||
|
||||
let bins_u = bins_i64 as usize;
|
||||
let n = input.shape.dims[0];
|
||||
|
||||
// arange(bins) [bins] → cast to input dtype, optionally shift by min,
|
||||
// broadcast to [bins, N], compare for equality with input broadcast.
|
||||
let mut bins_arange = self.graph.arange(Expression::from(bins_u));
|
||||
if min != 0.0 {
|
||||
// `min` is non-zero (uncommon in the qwen3-moe path but legal)
|
||||
// — shift the comparison values to start at min.
|
||||
let min_i = min as i64;
|
||||
let shift = self
|
||||
.graph
|
||||
.constant_float(min_i as f32)
|
||||
.cast(bins_arange.dtype)
|
||||
.expand_rhs(bins_arange.shape);
|
||||
bins_arange += shift;
|
||||
}
|
||||
let bins_expanded = bins_arange.cast(input.dtype).expand_dim(1, n);
|
||||
let input_expanded = input.expand_dim(0, Expression::from(bins_u));
|
||||
let matches = input_expanded.eq(bins_expanded); // Bool [bins, N]
|
||||
|
||||
let out_dtype = self.output_meta_dtype(node)?;
|
||||
Ok(matches.cast(out_dtype).sum(1))
|
||||
}
|
||||
|
||||
/// Lower `aten.empty.memory_format` and `aten.empty_permuted.default`.
|
||||
///
|
||||
/// Both allocate an uninitialised tensor; the caller is responsible for
|
||||
/// writing into it. We materialise zeros instead — luminal has no
|
||||
/// "uninitialised" notion, and PyTorch's contract on `empty` outputs is
|
||||
/// undefined for any read prior to a write, so a zero-fill is sound.
|
||||
/// `aten.empty_permuted` additionally takes a `physical_layout` arg
|
||||
/// (the storage permutation); for a zero-filled tensor that's a no-op.
|
||||
pub(crate) fn translate_empty(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let shape = self.get_exprs_arg(node, FULL_SHAPE_ARG)?;
|
||||
let dtype = self.output_meta_dtype(node)?;
|
||||
let zero = self.graph.constant_float(0.0).cast(dtype);
|
||||
Ok(if shape.is_empty() {
|
||||
zero
|
||||
} else {
|
||||
zero.expand_rhs(shape)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_full_like(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let reference = self.get_input_tensor(node, FULL_LIKE_INPUT_ARG)?;
|
||||
let val = if let Ok(f) = self.get_float_arg(node, FULL_LIKE_VALUE_ARG) {
|
||||
@@ -193,146 +102,33 @@ impl<'a> Translator<'a> {
|
||||
Ok(torch_dtype_int_to_luminal(meta.dtype))
|
||||
}
|
||||
|
||||
/// Translate `aten._grouped_mm.default(input, weight, offs)` → `Tensor[S, N]`.
|
||||
///
|
||||
/// Grouped matmul: `input` is `[S, K]` (tokens sorted by expert), `weight` is
|
||||
/// `[G, K, N]` (per-expert weights), `offs` is `[G]` cumulative token counts.
|
||||
/// Output `[S, N]` where token m (in group g s.t. `offs[g-1] <= m < offs[g]`)
|
||||
/// is multiplied by `weight[g]`.
|
||||
///
|
||||
/// Implementation: for each token m we (a) compute its expert id from offs,
|
||||
/// (b) gather only that expert's `[K, N]` slice from weight, and (c) do a
|
||||
/// single per-token matmul. The gather pattern mirrors the rust qwen3_moe
|
||||
/// example's `gather_experts`, which the GLUMoE host-op fusion in
|
||||
/// `luminal_cuda_lite` is designed to recognise.
|
||||
///
|
||||
/// Why not the straightforward `[G, S, K] @ [G, K, N] → [G, S, N]` + mask:
|
||||
/// it forces a full F32 cast of the entire `[G, K, N]` weight tensor as
|
||||
/// search-time intermediate, which OOMs on real MoE checkpoints
|
||||
/// (Qwen3-30B-A3B: 1.5 GB / layer × 48 layers for gate-up alone). Gathering
|
||||
/// first keeps the F32 cast on `[S, K, N]` instead — for prefill (S = top_k)
|
||||
/// that is a 16× shrink (G=128, top_k=8).
|
||||
///
|
||||
/// `offs` flows through as a runtime tensor — the routing decision is computed
|
||||
/// at execution time by the gate network and the same compiled graph handles
|
||||
/// any routing pattern without recompilation.
|
||||
pub(crate) fn translate_grouped_mm(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let weight = self.get_input_tensor(node, 1)?;
|
||||
let offs = self.get_input_tensor(node, 2)?;
|
||||
|
||||
anyhow::ensure!(
|
||||
input.shape.len() == 2,
|
||||
"_grouped_mm: input must be 2D, got {}D",
|
||||
input.shape.len()
|
||||
);
|
||||
anyhow::ensure!(
|
||||
weight.shape.len() == 3,
|
||||
"_grouped_mm: weight must be 3D, got {}D",
|
||||
weight.shape.len()
|
||||
);
|
||||
anyhow::ensure!(
|
||||
offs.shape.len() == 1,
|
||||
"_grouped_mm: offs must be 1D, got {}D",
|
||||
offs.shape.len()
|
||||
);
|
||||
|
||||
let s = input.shape.dims[0];
|
||||
let g = weight.shape.dims[0];
|
||||
let k = weight.shape.dims[1];
|
||||
let n = weight.shape.dims[2];
|
||||
|
||||
// expert_id[m] = number of g s.t. m >= offs[g], clamped to [0, G-1].
|
||||
// Same value as HF MoE's `expert_ids.clamp(0, num_experts-1)` for
|
||||
// invalid expert IDs from EP, AND protects search-time profiling:
|
||||
// dummy-1 input bytes give offs=[1,…,1], which pushes the raw count
|
||||
// to G for any token with index ≥ 1 and would OOB the weight gather.
|
||||
//
|
||||
// Stay in Int throughout — arange / offs are already Int, ge → Bool
|
||||
// → cast(Int), sum stays Int, and the binary `minimum` handles the
|
||||
// clamp without an F32 round-trip.
|
||||
let _ = g
|
||||
.to_usize()
|
||||
.context("_grouped_mm: G (num_experts) must be concrete")?;
|
||||
let s_arange = self.graph.arange(s); // Int [S]
|
||||
let ge_int = s_arange
|
||||
.expand_dim(0, g)
|
||||
.ge(offs.expand_dim(1, s)) // Bool [G, S]
|
||||
.cast(DType::Int); // Int [G, S]
|
||||
let raw = ge_int.sum(0); // Int [S], values in [0, G]
|
||||
let cap = self.graph.constant(g - 1).expand_dim(0, s); // Int [S], all G-1
|
||||
let expert_id = raw.minimum(cap); // Int [S]
|
||||
|
||||
// Flat gather index into weight (treated as a length-G*K*N 1D buffer):
|
||||
// flat[m, k_, n_] = expert_id[m] * (K*N) + k_ * N + n_
|
||||
// Encoded as `Mul(expert_id, Iota(io_const)) + Iota(MIter, K*N)` so the
|
||||
// resulting Gather matches the GLUMoE / gather-experts egglog patterns.
|
||||
let io = k * n;
|
||||
let base = expert_id * io;
|
||||
let within = self.graph.iota(Expression::from('z'), (k, n));
|
||||
let exp_base = base.expand_dim(1, k).expand_dim(2, n);
|
||||
let exp_within = within.expand_dim(0, s);
|
||||
let flat_idx = exp_base + exp_within;
|
||||
|
||||
// Gather → [S, K, N], preserves weight's native dtype (bf16 stays bf16).
|
||||
let weight_gathered = weight.gather(flat_idx);
|
||||
|
||||
// Per-token matmul: [S, 1, K] @ [S, K, N] → [S, 1, N] → [S, N].
|
||||
// Operands stay in their native dtype — no F32 cast on the gathered
|
||||
// weight or the input. The earlier cast(F32) was a holdover from the
|
||||
// broadcast-and-mask version (which had to use F32 because of the
|
||||
// cast(F32) on the mask). Gather-then-matmul has no such requirement,
|
||||
// and casting `[S, K, N]` to F32 doubled the gather scratch (~100 MB
|
||||
// to ~200 MB per layer for Qwen3-30B-A3B prefill). Matmul rewrites
|
||||
// (cuBLASLt etc.) handle bf16 input with F32 accumulator internally.
|
||||
let result = input.unsqueeze(1).matmul(weight_gathered).squeeze(1);
|
||||
|
||||
Ok(result.cast(input.dtype))
|
||||
}
|
||||
|
||||
/// Build the where-formula graph: `cond * x + (1 - cond) * y`, computed
|
||||
/// in F32, cast back to `out_dtype`. Shared between `translate_where`,
|
||||
/// `translate_where_scalar_other`, and `translate_masked_fill_scalar` so
|
||||
/// they all go through one well-tested code path.
|
||||
pub(crate) fn where_formula(
|
||||
&mut self,
|
||||
cond: GraphTensor,
|
||||
x: GraphTensor,
|
||||
y: GraphTensor,
|
||||
out_dtype: DType,
|
||||
) -> GraphTensor {
|
||||
let (cond_b, x_b) = broadcast_binary(cond, x);
|
||||
let (cond_bc, y_b) = broadcast_binary(cond_b, y);
|
||||
let (x_bc, y_bc) = broadcast_binary(x_b, y_b);
|
||||
// Lower as `y + c*(x - y)` rather than `c*x + (1-c)*y`: 3 ops vs 4 ops
|
||||
// plus the explicit `1.0` constant. Mathematically identical for
|
||||
// c ∈ {0, 1} and produces the same F32 output type.
|
||||
let c = cond_bc.cast(DType::F32);
|
||||
let x_f = x_bc.cast(DType::F32);
|
||||
let y_f = y_bc.cast(DType::F32);
|
||||
// Cast back: an F32 result downstream-interpreted as bf16 walks the
|
||||
// buffer at half-stride, returning every-other-element zeros.
|
||||
(y_f + c * (x_f - y_f)).cast(out_dtype)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_where(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let cond = self.get_input_tensor(node, 0)?;
|
||||
let x = self.get_input_tensor(node, 1)?;
|
||||
let y = self.get_input_tensor(node, 2)?;
|
||||
// Ensure x and y have the same dtype
|
||||
let (x, y) = ensure_same_dtype(x, y);
|
||||
let out_dtype = x.dtype;
|
||||
Ok(self.where_formula(cond, x, y, out_dtype))
|
||||
// Broadcast all three tensors to a common shape first
|
||||
let (cond_b, x_b) = broadcast_binary(cond, x);
|
||||
let (cond_bc, y_b) = broadcast_binary(cond_b, y);
|
||||
let (x_bc, y_bc) = broadcast_binary(x_b, y_b);
|
||||
let c = cond_bc.cast(DType::F32);
|
||||
let x_f = x_bc.cast(DType::F32);
|
||||
let y_f = y_bc.cast(DType::F32);
|
||||
let one = self.graph.constant_float(1.0).expand_rhs(c.shape);
|
||||
Ok(c * x_f + (one - c) * y_f)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_where_scalar_other(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let cond = self.get_input_tensor(node, WHERE_COND_ARG)?;
|
||||
let x = self.get_input_tensor(node, WHERE_X_ARG)?;
|
||||
let other_val = self.get_float_arg(node, WHERE_OTHER_ARG)? as f32;
|
||||
let out_dtype = x.dtype;
|
||||
// Build a tensor for the scalar `other` matching `x`'s shape so we
|
||||
// can route through the shared where_formula helper.
|
||||
let other = self.graph.constant_float(other_val).expand_rhs(x.shape);
|
||||
Ok(self.where_formula(cond, x, other, out_dtype))
|
||||
// Broadcast cond and x to a common shape
|
||||
let (cond_b, x_b) = broadcast_binary(cond, x);
|
||||
let c = cond_b.cast(DType::F32);
|
||||
let one = self.graph.constant_float(1.0).expand_rhs(c.shape);
|
||||
let other = self.graph.constant_float(other_val).expand_rhs(c.shape);
|
||||
Ok(c * x_b + (one - c) * other)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_tril(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
@@ -387,37 +183,33 @@ impl<'a> Translator<'a> {
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
|
||||
// Determine output names
|
||||
let tuple_outputs = node.outputs.first().and_then(|o| o.as_tensors.as_ref());
|
||||
let values_name = if let Some(ts) = tuple_outputs {
|
||||
ts.first().map(|t| t.name.clone())
|
||||
} else {
|
||||
node.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
|
||||
};
|
||||
let indices_name = if let Some(ts) = tuple_outputs {
|
||||
ts.get(1).map(|t| t.name.clone())
|
||||
} else if node.outputs.len() > 1 {
|
||||
node.outputs[1].as_tensor.as_ref().map(|t| t.name.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let values_name = node
|
||||
.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()));
|
||||
let indices_name =
|
||||
if let Some(ts) = node.outputs.first().and_then(|o| o.as_tensors.as_ref()) {
|
||||
ts.get(1).map(|t| t.name.clone())
|
||||
} else if node.outputs.len() > 1 {
|
||||
node.outputs[1].as_tensor.as_ref().map(|t| t.name.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Build top-k outputs from a full stable argsort. Slice the indices
|
||||
// before gathering values so the gather shape matches the requested
|
||||
// top-k output rather than the full sort width.
|
||||
// Build top-k outputs from a full stable argsort, then slice to k.
|
||||
let full_argsort = a.stable_argsort(dim, true);
|
||||
let topk_indices = full_argsort.slice_along(..k, dim) * 1.0;
|
||||
|
||||
// Only build the outputs that are consumed.
|
||||
if let Some(val_name) = values_name
|
||||
&& !val_name.is_empty()
|
||||
{
|
||||
let values = a.gather_elements(topk_indices, dim);
|
||||
let values = a.gather_elements(full_argsort, dim).slice_along(..k, dim);
|
||||
self.tensors.insert(val_name, values);
|
||||
}
|
||||
if let Some(idx_name) = indices_name {
|
||||
self.tensors.insert(idx_name, topk_indices);
|
||||
// Materialize the sliced indices through a copy before storing them.
|
||||
let indices = full_argsort.slice_along(..k, dim) * 1.0;
|
||||
self.tensors.insert(idx_name, indices);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -51,19 +51,13 @@ impl<'a> Translator<'a> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
for input in &node.inputs {
|
||||
if input.name == "dtype" {
|
||||
let dtype_int = input
|
||||
.arg
|
||||
.as_int()
|
||||
.map(|i| i as u32)
|
||||
.or_else(|| input.arg.as_scalar_type());
|
||||
if let Some(d) = dtype_int {
|
||||
let dtype = torch_dtype_int_to_luminal(d);
|
||||
// Skip emitting a Cast op when the dtype already matches —
|
||||
// PT2 graphs frequently emit `_to_copy` purely as a clone hint
|
||||
// (e.g. dtype=float32 on a tensor that is already F32), and
|
||||
// every redundant Cast inflates the graph and survives until
|
||||
// optimization passes can prove it as a no-op.
|
||||
return Ok(if a.dtype == dtype { a } else { a.cast(dtype) });
|
||||
if let Some(dtype_int) = input.arg.as_int() {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
|
||||
return Ok(a.cast(dtype));
|
||||
}
|
||||
if let Some(dtype_int) = input.arg.as_scalar_type() {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int);
|
||||
return Ok(a.cast(dtype));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -137,34 +131,37 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
pub(crate) fn translate_masked_fill_scalar(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
// `masked_fill(input, mask, fill)` = `where(mask, fill, input)`.
|
||||
// Routes through the shared `where_formula` helper so we exercise
|
||||
// the exact same code path as `aten.where.self`, which is verified
|
||||
// to handle the bf16 cast-back correctly. Hand-rolling the same
|
||||
// formula directly here used to drift (egglog made different
|
||||
// rewrite choices on the rebuilt-locally graph), so we deliberately
|
||||
// re-use the helper.
|
||||
// `aten.masked_fill.Scalar(input, mask, fill)` ≡
|
||||
// `aten.where.self(mask, full_like(input, fill), input)`. The
|
||||
// `full_like + where` sequence is the verified-working path
|
||||
// (test: `where(mask, torch.zeros_like(x), x)` round-trips with
|
||||
// max_diff = 0); we reproduce its exact graph-build order here.
|
||||
// Hand-rolling the formula in any other shape (single-mul, F32
|
||||
// throughout, alternative constant-cast orderings) routes egglog
|
||||
// through a rewrite that returns an F32 buffer downstream-read as
|
||||
// bf16 — the every-other-element-zero pattern.
|
||||
let input = self.get_input_tensor(node, MASKED_FILL_INPUT_ARG)?;
|
||||
let mask = self.get_input_tensor(node, MASKED_FILL_MASK_ARG)?;
|
||||
let fill = self.get_float_arg(node, MASKED_FILL_VALUE_ARG)? as f32;
|
||||
let out_dtype = input.dtype;
|
||||
// Build fill_t exactly like translate_full_like does:
|
||||
// constant_float(val).cast(dtype).expand_rhs(reference.shape)
|
||||
let fill_t = self
|
||||
let (input, mask) = broadcast_binary(input, mask);
|
||||
let work_dtype = if input.dtype == DType::Bool {
|
||||
DType::Int
|
||||
} else {
|
||||
input.dtype
|
||||
};
|
||||
let input_work = if input.dtype == DType::Bool {
|
||||
input.cast(DType::Int)
|
||||
} else {
|
||||
input
|
||||
};
|
||||
let mask_work = mask.cast(work_dtype);
|
||||
let fill_work = self
|
||||
.graph
|
||||
.constant_float(fill)
|
||||
.cast(out_dtype)
|
||||
.expand_rhs(input.shape);
|
||||
Ok(self.where_formula(mask, fill_t, input, out_dtype))
|
||||
.cast(work_dtype)
|
||||
.expand_rhs(input_work.shape);
|
||||
let one = self
|
||||
.graph
|
||||
.constant_float(1.0)
|
||||
.cast(work_dtype)
|
||||
.expand_rhs(input_work.shape);
|
||||
let result = mask_work * fill_work + (one - mask_work) * input_work;
|
||||
Ok(if input.dtype == DType::Bool {
|
||||
result.cast(DType::Bool)
|
||||
} else {
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_floor_divide(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
@@ -213,18 +210,12 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = crate::pt2_util::ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
|
||||
// Check rounding_mode kwarg. PT2 serializes string args as
|
||||
// {"as_string": "<value>"}, so we have to drill into the JSON.
|
||||
// Check rounding_mode kwarg
|
||||
let rounding_mode = node.inputs.iter().find_map(|input| {
|
||||
if input.name == "rounding_mode"
|
||||
&& let Argument::Other(val) = &input.arg
|
||||
{
|
||||
if let Some(s) = val.as_str() {
|
||||
return Some(s.to_string());
|
||||
}
|
||||
if let Some(s) = val.get("as_string").and_then(|v| v.as_str()) {
|
||||
return Some(s.to_string());
|
||||
}
|
||||
return val.as_str().map(|s| s.to_string());
|
||||
}
|
||||
None
|
||||
});
|
||||
@@ -275,52 +266,4 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// `aten.clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None)`
|
||||
///
|
||||
/// Unlike `clamp.default` (which takes Python scalar bounds), the `.Tensor`
|
||||
/// overload takes tensor bounds that appear as separate input nodes in the
|
||||
/// FX graph. PyTorch supports any NumPy-broadcastable bound shape:
|
||||
///
|
||||
/// - rank-0 (scalar wrapped in a tensor) — most common
|
||||
/// - same shape as self (per-element clamp, e.g. learned bounds)
|
||||
/// - any shape that broadcasts to self via right-align + size-1 expand
|
||||
/// (e.g. `(3, 1)` against `(3, 4)` for per-row clamp; `(4,)` against
|
||||
/// `(3, 4)` for per-column clamp; `(3, 4)` against `(2, 3, 4)`)
|
||||
///
|
||||
/// We use `broadcast_binary` to right-align and expand both operands to a
|
||||
/// common shape before the elementwise max/min, matching PyTorch semantics
|
||||
/// across all three modes.
|
||||
///
|
||||
/// Either bound may be absent (FX represents this as a non-tensor argument
|
||||
/// at the corresponding input slot), in which case we clamp to one side
|
||||
/// only.
|
||||
pub(crate) fn translate_clamp_tensor(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let min_tensor = node
|
||||
.inputs
|
||||
.get(1)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
.map(|n| self.get_tensor(n))
|
||||
.transpose()?;
|
||||
let max_tensor = node
|
||||
.inputs
|
||||
.get(2)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
.map(|n| self.get_tensor(n))
|
||||
.transpose()?;
|
||||
|
||||
let mut result = a;
|
||||
if let Some(lo) = min_tensor {
|
||||
let lo = lo.cast(result.dtype);
|
||||
let (r, lo) = broadcast_binary(result, lo);
|
||||
result = r.maximum(lo);
|
||||
}
|
||||
if let Some(hi) = max_tensor {
|
||||
let hi = hi.cast(result.dtype);
|
||||
let (r, hi) = broadcast_binary(result, hi);
|
||||
result = r.minimum(hi);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,10 +77,7 @@ class CompiledModel:
|
||||
)
|
||||
user_inputs = inputs
|
||||
|
||||
# Use the first *user* input for device detection — when torch.compile
|
||||
# has lifted SymInts or weights into the call args, `inputs[0]` may not
|
||||
# be a tensor. user_inputs has been filtered to actual tensors.
|
||||
input_device = user_inputs[0].device if user_inputs else torch.device("cpu")
|
||||
input_device = inputs[0].device if inputs else torch.device("cpu")
|
||||
|
||||
# Auto-detect dynamic dims from input shapes
|
||||
if self._has_dynamic_dims:
|
||||
@@ -135,11 +132,6 @@ class CompiledModel:
|
||||
# Run the graph
|
||||
self._graph.run()
|
||||
|
||||
# Integer dtypes for which we read the buffer as i32 and then cast.
|
||||
# Includes int64 because luminal collapses all integer types to its
|
||||
# 32-bit `Int` internally — we restore the original precision here.
|
||||
_int_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8)
|
||||
|
||||
# Collect outputs
|
||||
if _use_zero_copy:
|
||||
outputs = []
|
||||
@@ -155,12 +147,11 @@ class CompiledModel:
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
elif out_dtype in _int_dtypes:
|
||||
elif out_dtype == torch.int32:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.int32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
.to(input_device)
|
||||
)
|
||||
elif out_dtype == torch.bool:
|
||||
@@ -188,13 +179,9 @@ class CompiledModel:
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
if out_dtype in _int_dtypes:
|
||||
if out_dtype == torch.int32:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.int32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
)
|
||||
out = torch.tensor(data, dtype=torch.int32).reshape(tuple(shape))
|
||||
elif out_dtype == torch.bool:
|
||||
data = self._graph.get_output_bool(name)
|
||||
out = torch.tensor(data, dtype=torch.bool).reshape(tuple(shape))
|
||||
|
||||
@@ -11,21 +11,14 @@ from .dtype_util import torch_dtype_code as _torch_dtype_code
|
||||
|
||||
def _detect_factory_capsule(example_inputs):
|
||||
"""Pick the best built-in factory capsule based on input device."""
|
||||
# Dynamo can prefix `example_inputs` with SymInt entries when shapes are
|
||||
# dynamic — those have no `.device`. Pick the first real tensor instead.
|
||||
first_tensor = next((t for t in (example_inputs or []) if torch.is_tensor(t)), None)
|
||||
device = first_tensor.device if first_tensor is not None else torch.device("cpu")
|
||||
device = example_inputs[0].device if example_inputs else torch.device("cpu")
|
||||
if device.type == "cuda":
|
||||
try:
|
||||
from .luminal import _cuda_lite_factory_capsule
|
||||
|
||||
return _cuda_lite_factory_capsule()
|
||||
except (ImportError, AttributeError) as exc:
|
||||
raise RuntimeError(
|
||||
"CUDA input was provided, but luminal_python was not built with "
|
||||
"the cuda feature. Rebuild with `maturin develop --features cuda` "
|
||||
"or run through `run_tests_cuda.sh`/the Modal CUDA test runner."
|
||||
) from exc
|
||||
except ImportError:
|
||||
pass
|
||||
from .luminal import _native_factory_capsule
|
||||
|
||||
return _native_factory_capsule()
|
||||
@@ -83,7 +76,7 @@ def register_backend(factory_capsule):
|
||||
"""
|
||||
|
||||
def backend(gm, example_inputs, options=None):
|
||||
return _compile_pt2(gm, example_inputs, factory_capsule)
|
||||
return _compile_pt2(gm, example_inputs, factory_capsule, options=options)
|
||||
|
||||
return backend
|
||||
|
||||
@@ -102,7 +95,7 @@ def luminal_backend(gm, example_inputs, options=None):
|
||||
For external backends, use register_backend with the backend's factory capsule.
|
||||
"""
|
||||
capsule = _detect_factory_capsule(example_inputs)
|
||||
return _compile_pt2(gm, example_inputs, capsule)
|
||||
return _compile_pt2(gm, example_inputs, capsule, options=options)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -110,8 +103,8 @@ def luminal_backend(gm, example_inputs, options=None):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compile_pt2(gm, example_inputs, factory_capsule):
|
||||
def _compile_pt2(gm, example_inputs, factory_capsule, options=None):
|
||||
"""PT2/torch.export path — delegates to pt2.pt2_backend."""
|
||||
from .pt2 import pt2_backend
|
||||
|
||||
return pt2_backend(gm, example_inputs, factory=factory_capsule)
|
||||
return pt2_backend(gm, example_inputs, factory=factory_capsule, options=options)
|
||||
|
||||
@@ -16,84 +16,6 @@ from .compiled_model import CompiledModel
|
||||
from .luminal import process_pt2
|
||||
from .main import _collect_weight_pointers, _detect_factory_capsule, _load_cpu_weights
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DynamicCache <> pytree registration
|
||||
#
|
||||
# Without this, torch.export.export raises when handed an HF model that
|
||||
# returns CausalLMOutputWithPast(past_key_values=DynamicCache(...)), which
|
||||
# is every model with use_cache=True. The registration mirrors the one in
|
||||
# transformers.integrations.executorch.register_dynamic_cache_export_support
|
||||
# — same dict-based flatten (key_cache / value_cache lists), same replay via
|
||||
# cache.update(k, v, idx), and the matching torch.fx._pytree spec for FX
|
||||
# graphs. Done at module import so both entry points (pt2_backend via
|
||||
# torch.compile and the direct compile() call) get it for free.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_cache_dict(cache):
|
||||
"""Flatten a DynamicCache to a dict of parallel key/value lists."""
|
||||
return {
|
||||
"key_cache": [layer.keys for layer in cache.layers if layer.keys is not None],
|
||||
"value_cache": [
|
||||
layer.values for layer in cache.layers if layer.values is not None
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _flatten_dynamic_cache(cache):
|
||||
return torch.utils._pytree._dict_flatten(_get_cache_dict(cache))
|
||||
|
||||
|
||||
def _flatten_with_keys_dynamic_cache(cache):
|
||||
return torch.utils._pytree._dict_flatten_with_keys(_get_cache_dict(cache))
|
||||
|
||||
|
||||
def _unflatten_dynamic_cache(values, context):
|
||||
from transformers.cache_utils import DynamicCache
|
||||
|
||||
dictionary = torch.utils._pytree._dict_unflatten(values, context)
|
||||
cache = DynamicCache()
|
||||
key_list = dictionary.get("key_cache", [])
|
||||
value_list = dictionary.get("value_cache", [])
|
||||
for idx in range(max(len(key_list), len(value_list))):
|
||||
k = key_list[idx] if idx < len(key_list) else None
|
||||
v = value_list[idx] if idx < len(value_list) else None
|
||||
cache.update(k, v, idx)
|
||||
return cache
|
||||
|
||||
|
||||
def _register_cache_serialization():
|
||||
"""Register DynamicCache with both torch.utils._pytree and torch.fx._pytree.
|
||||
|
||||
Idempotent: a second call is a no-op. Silently skipped if transformers is
|
||||
not installed.
|
||||
"""
|
||||
try:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
if DynamicCache in torch.utils._pytree.SUPPORTED_NODES:
|
||||
return
|
||||
|
||||
torch.utils._pytree.register_pytree_node(
|
||||
DynamicCache,
|
||||
_flatten_dynamic_cache,
|
||||
_unflatten_dynamic_cache,
|
||||
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
|
||||
flatten_with_keys_fn=_flatten_with_keys_dynamic_cache,
|
||||
)
|
||||
torch.fx._pytree.register_pytree_flatten_spec(
|
||||
DynamicCache,
|
||||
lambda cache, spec: torch.fx._pytree._dict_flatten_spec(
|
||||
_get_cache_dict(cache), spec
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
_register_cache_serialization()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -110,34 +32,11 @@ def _export_kwargs():
|
||||
return kwargs
|
||||
|
||||
|
||||
def _decomp_table():
|
||||
"""Decomposition table for `ep.run_decompositions()` that preserves SDPA.
|
||||
|
||||
The default table decomposes `aten.scaled_dot_product_attention.default`
|
||||
into ~20 ops (matmul/softmax + an `eq.Scalar`/`logical_not`/`any.dim`/
|
||||
`where`/`full_like` "all-masked" sentinel chain). We translate SDPA as a
|
||||
single fused op via `translate_sdpa`, so we strip the SDPA decompositions
|
||||
here to let them survive into the FX graph the translator walks.
|
||||
"""
|
||||
try:
|
||||
from torch.export import default_decompositions
|
||||
except ImportError:
|
||||
return None
|
||||
table = default_decompositions()
|
||||
sdpa_ops = [
|
||||
torch.ops.aten.scaled_dot_product_attention.default,
|
||||
torch.ops.aten._scaled_dot_product_efficient_attention.default,
|
||||
torch.ops.aten._scaled_dot_product_flash_attention.default,
|
||||
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
|
||||
torch.ops.aten._scaled_dot_product_cudnn_attention.default,
|
||||
]
|
||||
for op in sdpa_ops:
|
||||
table.pop(op, None)
|
||||
return table
|
||||
|
||||
|
||||
def _save_and_compile(
|
||||
ep_or_path, factory, search_iterations, original_weights=None, user_indices=None
|
||||
ep_or_path,
|
||||
factory,
|
||||
original_weights=None,
|
||||
options=None,
|
||||
):
|
||||
"""Compile a PT2 model via Rust, return CompiledModel.
|
||||
|
||||
@@ -170,177 +69,22 @@ def _save_and_compile(
|
||||
|
||||
# Compile with device pointers — search uses actual weight memory (zero-copy)
|
||||
compiled = process_pt2(
|
||||
pt2_path, "", search_iterations, factory, weight_device_ptrs
|
||||
pt2_path,
|
||||
"",
|
||||
factory,
|
||||
weight_device_ptrs=weight_device_ptrs,
|
||||
options=options,
|
||||
)
|
||||
|
||||
# Load CPU weights after compilation
|
||||
_load_cpu_weights(compiled, cpu_weights)
|
||||
|
||||
return CompiledModel(
|
||||
compiled, weight_refs=keep_alive, user_indices=user_indices
|
||||
)
|
||||
return CompiledModel(compiled, weight_refs=keep_alive)
|
||||
finally:
|
||||
if owns_tmpdir and tmpdir:
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
|
||||
|
||||
def _safe_int_bound(value):
|
||||
"""Coerce a sympy/symbolic-shape range bound to a finite int, or None.
|
||||
|
||||
Range bounds returned by ShapeEnv can be sympy `Infinity` / `-Infinity`
|
||||
(as well as the internal `int_oo` sentinel), which both raise on `int(...)`.
|
||||
Treat anything non-finite — and anything that simply doesn't coerce — as
|
||||
"no bound."
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
# Stringify is robust against the various sentinel types: sympy.Infinity,
|
||||
# torch.utils._sympy.numbers.IntInfinity, etc. all stringify to "oo"/"-oo".
|
||||
s = str(value)
|
||||
if "oo" in s or "inf" in s.lower():
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError, OverflowError, AttributeError):
|
||||
return None
|
||||
|
||||
|
||||
def _strip_symint_placeholders(gm, example_inputs):
|
||||
"""Rewrite SymInt graph inputs into tensor.size(d) calls, then drop them.
|
||||
|
||||
When Dynamo decides a dim is dynamic it emits the symbol as a separate
|
||||
placeholder (e.g. `s77`) alongside the user's tensor (whose FakeTensor shape
|
||||
references the same symbol). torch.export.export rejects mixed
|
||||
SymInt/Tensor positional args, and the Rust pipeline doesn't model SymInt
|
||||
inputs anyway — so we replace each SymInt placeholder's uses with
|
||||
`aten.sym_size.int(tensor, dim)` for the first tensor placeholder whose
|
||||
example_value's shape[dim] matches the symbol, then erase the placeholder.
|
||||
|
||||
Returns `(post_strip_inputs, kept_indices, ok)` where:
|
||||
- `post_strip_inputs` is `example_inputs` filtered to tensor-only entries
|
||||
- `kept_indices` is the indices into `example_inputs` we kept (used by
|
||||
the caller to compose with any prior input filter, e.g. lifted-weight
|
||||
re-internalization, when handing `user_indices` to CompiledModel)
|
||||
- `ok` is False when at least one SymInt placeholder couldn't be
|
||||
rewritten (compound expression with users, or no matching tensor dim);
|
||||
the caller should fall back to no-dynamic export in that case.
|
||||
"""
|
||||
placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
|
||||
|
||||
# Collect (placeholder_node, example_input_idx) for every SymInt placeholder.
|
||||
symint_entries = []
|
||||
tensor_entries = []
|
||||
for idx, node in enumerate(placeholders):
|
||||
ev = node.meta.get("example_value")
|
||||
if isinstance(ev, torch.SymInt) or (
|
||||
ev is None
|
||||
and idx < len(example_inputs)
|
||||
and isinstance(example_inputs[idx], torch.SymInt)
|
||||
):
|
||||
symint_entries.append((node, idx))
|
||||
else:
|
||||
tensor_entries.append((node, idx))
|
||||
|
||||
if not symint_entries:
|
||||
return example_inputs, list(range(len(example_inputs))), True
|
||||
|
||||
# Build a symbol -> (tensor_node, dim) lookup from the tensor placeholders'
|
||||
# example FakeTensor shapes. Any tensor whose shape[d] is the SymInt
|
||||
# is a valid source — pick the first.
|
||||
sym_to_source = {}
|
||||
for t_node, _ in tensor_entries:
|
||||
ev = t_node.meta.get("example_value")
|
||||
if not torch.is_tensor(ev):
|
||||
continue
|
||||
for d, s in enumerate(ev.shape):
|
||||
if isinstance(s, torch.SymInt):
|
||||
key = str(s.node.expr)
|
||||
sym_to_source.setdefault(key, (t_node, d))
|
||||
|
||||
# Rewrite each SymInt placeholder's uses to sym_size calls, then erase it.
|
||||
all_clean = True
|
||||
for s_node, _ in symint_entries:
|
||||
ev = s_node.meta.get("example_value")
|
||||
if ev is None:
|
||||
all_clean = False
|
||||
continue
|
||||
# The placeholder's example_value is the SymInt itself; its expr is the
|
||||
# symbol name (or a compound expression we can't lift this way).
|
||||
expr_str = str(ev.node.expr)
|
||||
source = sym_to_source.get(expr_str)
|
||||
if source is None:
|
||||
# Compound expression or no tensor carries this symbol — bail.
|
||||
if len(s_node.users) > 0:
|
||||
all_clean = False
|
||||
continue
|
||||
gm.graph.erase_node(s_node)
|
||||
continue
|
||||
|
||||
if len(s_node.users) > 0:
|
||||
t_node, dim = source
|
||||
with gm.graph.inserting_after(t_node):
|
||||
size_node = gm.graph.call_function(
|
||||
torch.ops.aten.sym_size.int, (t_node, dim)
|
||||
)
|
||||
size_node.meta["val"] = ev
|
||||
size_node.meta["example_value"] = ev
|
||||
s_node.replace_all_uses_with(size_node)
|
||||
gm.graph.erase_node(s_node)
|
||||
|
||||
if not all_clean:
|
||||
# Recompile defensively even on partial success — some erases may have
|
||||
# happened. Caller will decide whether to proceed.
|
||||
gm.graph.lint()
|
||||
gm.recompile()
|
||||
return example_inputs, list(range(len(example_inputs))), False
|
||||
|
||||
gm.graph.lint()
|
||||
gm.recompile()
|
||||
# Filter the runtime example_inputs to drop the stripped SymInt entries.
|
||||
kept_indices = [idx for _, idx in tensor_entries]
|
||||
keep_set = set(kept_indices)
|
||||
new_inputs = [v for i, v in enumerate(example_inputs) if i in keep_set]
|
||||
return new_inputs, kept_indices, True
|
||||
|
||||
|
||||
def _build_dynamic_shapes_from_gm(gm):
|
||||
"""Construct a torch.export.export `dynamic_shapes` spec from FX metadata.
|
||||
|
||||
Walks each tensor placeholder's `meta['example_value']` FakeTensor and
|
||||
marks every SymInt dim as `Dim.AUTO`. Sharing/equality relationships
|
||||
between symbolic dims are already encoded in the FakeTensor shapes —
|
||||
torch.export's symbolic-shape engine recovers them during the trace, so
|
||||
we don't need to allocate named `Dim` objects ourselves.
|
||||
|
||||
The returned spec is wrapped under `{"args": (...)}` because Dynamo's
|
||||
`GraphModule.forward(*args, **kwargs)` signature treats positional inputs
|
||||
as the `args` tuple.
|
||||
|
||||
Returns None if there are no symbolic dims to mark.
|
||||
"""
|
||||
from torch.export import Dim
|
||||
|
||||
placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
|
||||
|
||||
per_input_spec = []
|
||||
saw_dynamic = False
|
||||
for node in placeholders:
|
||||
ev = node.meta.get("example_value")
|
||||
if not torch.is_tensor(ev):
|
||||
per_input_spec.append(None)
|
||||
continue
|
||||
spec = {}
|
||||
for d, s in enumerate(ev.shape):
|
||||
if isinstance(s, torch.SymInt):
|
||||
spec[d] = Dim.AUTO
|
||||
saw_dynamic = True
|
||||
per_input_spec.append(spec if spec else None)
|
||||
|
||||
if not saw_dynamic:
|
||||
return None
|
||||
return {"args": tuple(per_input_spec)}
|
||||
|
||||
|
||||
def _reinternalize_lifted_params(gm, example_inputs):
|
||||
"""Re-internalize lifted params as buffers so torch.export sees them as model state.
|
||||
|
||||
@@ -390,7 +134,7 @@ def _reinternalize_lifted_params(gm, example_inputs):
|
||||
if user_indices
|
||||
else list(example_inputs)
|
||||
)
|
||||
return gm, user_inputs, original_weights, user_indices
|
||||
return gm, user_inputs, original_weights
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -405,238 +149,100 @@ def compile(
|
||||
factory=None,
|
||||
export_kwargs=None,
|
||||
dynamic_dim=None,
|
||||
dynamic_shapes=None,
|
||||
):
|
||||
"""Compile a PyTorch model to run on Luminal via PT2 pipeline.
|
||||
|
||||
Args:
|
||||
model: A PyTorch nn.Module.
|
||||
example_input: Example input tensor — or a list/tuple of tensors for
|
||||
multi-input models.
|
||||
example_input: Example input tensor(s) for tracing.
|
||||
search_iterations: Number of optimization search iterations.
|
||||
factory: PyCapsule wrapping a BackendFactory. Auto-detected if None.
|
||||
export_kwargs: Extra kwargs passed to torch.export.export.
|
||||
dynamic_dim: Convenience controls for `dynamic_shapes` when only one
|
||||
symbolic dim is needed.
|
||||
* `None` (default): leave shapes static.
|
||||
* `int`: mark that dim of the (first) input as `Dim.AUTO`.
|
||||
* `Iterable[int]`: mark each listed dim of the first input.
|
||||
* `"auto"`: mark every non-trivial dim (size > 1) of the
|
||||
first input as `Dim.AUTO` — works for floating-point and
|
||||
integer inputs alike.
|
||||
dynamic_shapes: Direct passthrough to `torch.export.export`'s
|
||||
`dynamic_shapes` argument. When provided, takes precedence over
|
||||
`dynamic_dim`. Use this for full control: per-input specs,
|
||||
`Dim("name", min=, max=)` ranges, shared dims across inputs, etc.
|
||||
dynamic_dim: Which input dimension to make dynamic.
|
||||
|
||||
Returns:
|
||||
A CompiledModel callable.
|
||||
"""
|
||||
if factory is None:
|
||||
factory = _detect_factory_capsule(
|
||||
example_input
|
||||
if isinstance(example_input, (list, tuple))
|
||||
else [example_input]
|
||||
)
|
||||
if dynamic_dim is None:
|
||||
dynamic_dim = "auto"
|
||||
|
||||
if isinstance(example_input, (list, tuple)):
|
||||
example_args = tuple(example_input)
|
||||
else:
|
||||
example_args = (example_input,)
|
||||
if factory is None:
|
||||
factory = _detect_factory_capsule([example_input])
|
||||
|
||||
kwargs = export_kwargs or {}
|
||||
extra = _export_kwargs()
|
||||
|
||||
# Build dynamic_shapes from the convenience knob if the caller didn't
|
||||
# hand us a full spec. `dynamic_dim=None` falls back to the legacy
|
||||
# `"auto"` behavior (mark the last axis of an integer input as dynamic)
|
||||
# so callers that relied on the previous default keep working.
|
||||
if dynamic_shapes is None:
|
||||
if dynamic_dim is None:
|
||||
dynamic_dim = _legacy_auto_dim(example_args)
|
||||
if dynamic_dim is not None:
|
||||
dynamic_shapes = _build_dynamic_shapes_from_dim_arg(
|
||||
dynamic_dim, example_args
|
||||
)
|
||||
|
||||
# `torch.export.export` is finicky: when `dynamic_shapes` is set it
|
||||
# validates the spec against the example shapes and raises on any
|
||||
# disagreement (e.g. the user marked a dim as dynamic but their model
|
||||
# specialises it to a constant). Fall back to a static export so the
|
||||
# caller still gets a usable CompiledModel rather than a hard error.
|
||||
ep = None
|
||||
if dynamic_shapes is not None:
|
||||
try:
|
||||
ep = torch.export.export(
|
||||
model,
|
||||
example_args,
|
||||
kwargs=kwargs,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
**extra,
|
||||
)
|
||||
ep = ep.run_decompositions(_decomp_table())
|
||||
except Exception:
|
||||
ep = None
|
||||
|
||||
# Try dynamic dimension export
|
||||
candidate_dims = []
|
||||
if isinstance(dynamic_dim, int):
|
||||
candidate_dims = [dynamic_dim]
|
||||
elif dynamic_dim == "auto" and example_input.dim() >= 2:
|
||||
if not example_input.is_floating_point():
|
||||
candidate_dims = [example_input.dim() - 1]
|
||||
|
||||
if candidate_dims:
|
||||
from torch.export import Dim
|
||||
|
||||
for dim_idx in candidate_dims:
|
||||
try:
|
||||
seq = Dim("seq", min=2)
|
||||
arg_shapes = {dim_idx: seq}
|
||||
kwarg_shapes = {k: None for k in kwargs}
|
||||
dynamic_shapes = (
|
||||
(arg_shapes,) + tuple(kwarg_shapes.values())
|
||||
if kwarg_shapes
|
||||
else (arg_shapes,)
|
||||
)
|
||||
ep = torch.export.export(
|
||||
model,
|
||||
(example_input,),
|
||||
kwargs=kwargs,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
**extra,
|
||||
)
|
||||
ep = ep.run_decompositions()
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if ep is None:
|
||||
ep = torch.export.export(
|
||||
model,
|
||||
example_args,
|
||||
(example_input,),
|
||||
kwargs=kwargs,
|
||||
dynamic_shapes=None,
|
||||
**extra,
|
||||
)
|
||||
ep = ep.run_decompositions(_decomp_table())
|
||||
ep = ep.run_decompositions()
|
||||
|
||||
return _save_and_compile(ep, factory, search_iterations)
|
||||
return _save_and_compile(
|
||||
ep,
|
||||
factory,
|
||||
options={"search_iterations": search_iterations},
|
||||
)
|
||||
|
||||
|
||||
def _drop_input_guards(ep):
|
||||
"""Discard ``ep._guards_code`` so unlift does not emit a ``_guards_fn``.
|
||||
def pt2_backend(gm, example_inputs, factory=None, options=None):
|
||||
"""torch.compile backend using PT2 pipeline.
|
||||
|
||||
LUM-499: When a 0-d int tensor flows into a tensor index (``x[i]`` with
|
||||
``i = torch.tensor(2)``), torch.export records two equivalent input
|
||||
guards: ``L['i'].item() == 2`` (referencing the original local source)
|
||||
and ``L['args'][1].item() == 2`` (referencing the rewrapped flat args).
|
||||
Two failures stack on top of each other:
|
||||
|
||||
1. ``ep.module()`` (invoked inside ``run_decompositions``) rewrites
|
||||
``L['args'][1]`` → ``args[1]`` but cannot resolve ``L['i']``, leaving
|
||||
a literal ``L`` reference in the generated ``_guards_fn`` and raising
|
||||
``NameError: name 'L' is not defined`` during retracing.
|
||||
2. Even after dropping the unresolvable guard, the surviving
|
||||
``args[1].item()`` is data-dependent: AOT autograd's fake-tensor pass
|
||||
raises ``DataDependentOutputException(_local_scalar_dense)``, forcing
|
||||
a graph break.
|
||||
|
||||
These guards exist solely to validate inputs at runtime in eager-mode
|
||||
consumers of the ExportedProgram; the luminal compiler does its own
|
||||
input shape/dtype checks against the compiled graph signature, so we
|
||||
are not losing any safety by clearing them.
|
||||
"""
|
||||
|
||||
if hasattr(ep, "_guards_code"):
|
||||
ep._guards_code = []
|
||||
|
||||
|
||||
def _drop_dead_data_dependent_ops(gm):
|
||||
"""Remove ``aten.item.default`` (and other data-dependent ops) with no users.
|
||||
|
||||
When dynamo specializes a 0-d int input by tracing through ``.item()``,
|
||||
the resulting graph may contain a dead ``aten.item.default`` node whose
|
||||
output is never consumed. luminal's translator does not lower
|
||||
``aten._local_scalar_dense`` / ``aten.item.default``, so leaving the dead
|
||||
node in the graph causes a graph break at compile time. Eliminating it
|
||||
keeps the (correctly specialized) downstream graph in a single subgraph.
|
||||
"""
|
||||
|
||||
graph = gm.graph
|
||||
changed = False
|
||||
for node in list(graph.nodes):
|
||||
if (
|
||||
node.op == "call_function"
|
||||
and getattr(node.target, "_overloadpacket", None) is torch.ops.aten.item
|
||||
and len(node.users) == 0
|
||||
):
|
||||
graph.erase_node(node)
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
graph.eliminate_dead_code()
|
||||
graph.lint()
|
||||
gm.recompile()
|
||||
|
||||
|
||||
def _legacy_auto_dim(example_args):
|
||||
"""Match the historical `dynamic_dim="auto"` heuristic.
|
||||
|
||||
Returns the last axis of the first input when that input is a 2-D-or-
|
||||
larger integer tensor (the typical token-id sequence pattern), and
|
||||
`None` otherwise. Float inputs and 1-D tensors fall through to the
|
||||
static export path the legacy code did.
|
||||
"""
|
||||
if not example_args:
|
||||
return None
|
||||
first = example_args[0]
|
||||
if not torch.is_tensor(first):
|
||||
return None
|
||||
if first.is_floating_point():
|
||||
return None
|
||||
if first.dim() < 2:
|
||||
return None
|
||||
return first.dim() - 1
|
||||
|
||||
|
||||
def _build_dynamic_shapes_from_dim_arg(dynamic_dim, example_args):
|
||||
"""Translate the `dynamic_dim` shorthand into a full `dynamic_shapes` spec.
|
||||
|
||||
Always targets the first positional input — multi-input dynamic specs
|
||||
require the caller to use `dynamic_shapes=` directly so they can name
|
||||
which input each dim belongs to.
|
||||
"""
|
||||
from torch.export import Dim
|
||||
|
||||
if not example_args:
|
||||
return None
|
||||
first = example_args[0]
|
||||
if not torch.is_tensor(first):
|
||||
return None
|
||||
|
||||
if isinstance(dynamic_dim, int):
|
||||
dims = [dynamic_dim]
|
||||
elif isinstance(dynamic_dim, str) and dynamic_dim == "auto":
|
||||
# Mark every dim with size > 1 as dynamic. Dim.AUTO leaves
|
||||
# torch.export to pick a Dim per axis and infer relationships from
|
||||
# the example FakeTensor.
|
||||
dims = [d for d, s in enumerate(first.shape) if int(s) > 1]
|
||||
elif hasattr(dynamic_dim, "__iter__"):
|
||||
dims = [int(d) for d in dynamic_dim]
|
||||
else:
|
||||
return None
|
||||
|
||||
if not dims:
|
||||
return None
|
||||
|
||||
spec = {d: Dim.AUTO for d in dims}
|
||||
rest = (None,) * (len(example_args) - 1)
|
||||
return (spec,) + rest
|
||||
|
||||
|
||||
def _eager_pt2_compile(
|
||||
gm, user_inputs, original_weights, user_indices, dynamic_shapes, factory
|
||||
):
|
||||
"""Run torch.export → save → Rust compile end-to-end. Returns CompiledModel.
|
||||
|
||||
Factored out so both the eager (static-shapes) and lazy (dynamic-shapes)
|
||||
backend paths share a single implementation.
|
||||
Usage: torch.compile(model, backend=luminal.register_backend(capsule))
|
||||
"""
|
||||
import gc
|
||||
|
||||
try:
|
||||
ep = torch.export.export(
|
||||
gm,
|
||||
tuple(user_inputs),
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
**_export_kwargs(),
|
||||
)
|
||||
except Exception:
|
||||
# If torch.export rejects the dynamic spec (e.g. user code introduced
|
||||
# a constraint we didn't model), retry without it. Better to lose the
|
||||
# dynamic-dim optimization than to hand the user a hard failure.
|
||||
if dynamic_shapes is None:
|
||||
raise
|
||||
ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs())
|
||||
# LUM-499: drop dynamo-emitted input guards before run_decompositions
|
||||
# calls ep.module(), which would otherwise emit a `_guards_fn` containing
|
||||
# data-dependent .item() calls and unresolved `L[...]` references.
|
||||
_drop_input_guards(ep)
|
||||
_drop_dead_data_dependent_ops(ep.graph_module)
|
||||
ep = ep.run_decompositions(_decomp_table())
|
||||
if factory is None:
|
||||
factory = _detect_factory_capsule(example_inputs)
|
||||
|
||||
# When using shared memory (original_weights), strip large weight buffers
|
||||
# from the EP before saving. The Rust side uses device pointers for these
|
||||
# weights, not the .pt2 file data, so serializing them is pure IO waste
|
||||
# (~32 GB for 8B models). Replace with tiny CPU scalars to shrink to <1 MB.
|
||||
gm = gm.eval()
|
||||
gm, user_inputs, original_weights = _reinternalize_lifted_params(gm, example_inputs)
|
||||
|
||||
ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs())
|
||||
ep = ep.run_decompositions()
|
||||
|
||||
# When using shared memory (original_weights), strip large weight buffers from
|
||||
# the EP before saving. The Rust side uses device pointers for these weights,
|
||||
# not the .pt2 file data, so serializing them is pure IO waste (~32 GB for 8B
|
||||
# models). Replacing with tiny CPU scalars shrinks the .pt2 to < 1 MB.
|
||||
if original_weights:
|
||||
for key in list(ep._state_dict.keys()):
|
||||
if key in original_weights:
|
||||
@@ -644,9 +250,9 @@ def _eager_pt2_compile(
|
||||
ep._state_dict[key] = torch.zeros(1, dtype=orig.dtype, device="cpu")
|
||||
del orig
|
||||
|
||||
# Save EP to disk, then free it and the traced graph module before Rust
|
||||
# compilation. torch.export clones the state_dict internally; holding ep
|
||||
# alive during compile would double weight memory on GPU.
|
||||
# Save the exported program to disk, then free it and the traced graph module
|
||||
# BEFORE Rust compilation. torch.export clones the state_dict internally, so
|
||||
# holding ep alive during compilation would double the weight memory on GPU.
|
||||
tmpdir = tempfile.mkdtemp(prefix="luminal_")
|
||||
pt2_path = os.path.join(tmpdir, "model.pt2")
|
||||
torch.export.save(ep, pt2_path)
|
||||
@@ -657,129 +263,12 @@ def _eager_pt2_compile(
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
try:
|
||||
return _save_and_compile(
|
||||
result = _save_and_compile(
|
||||
pt2_path,
|
||||
factory,
|
||||
10,
|
||||
original_weights=original_weights,
|
||||
user_indices=user_indices,
|
||||
options=options,
|
||||
)
|
||||
return result
|
||||
finally:
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
|
||||
|
||||
class _LazyDynamicCompiledModel:
|
||||
"""Defers torch.export + Rust compile to the first invocation.
|
||||
|
||||
Calling `torch.export.export(..., dynamic_shapes=...)` from inside a
|
||||
Dynamo backend frame triggers an internal "Guard failed on the same
|
||||
frame it was created" assertion in PyTorch — `torch.export`'s symbolic
|
||||
tracer mutates the ShapeEnv that Dynamo is also relying on for the
|
||||
surrounding compile, leaving the just-installed guards in an
|
||||
inconsistent state. Punting all of that work to the first runtime call
|
||||
sidesteps the issue: by then Dynamo's guard installation is finished,
|
||||
so the shape-env mutations no longer matter.
|
||||
|
||||
This wrapper is API-compatible with `CompiledModel` for the bits the
|
||||
caller cares about (`__call__`, `has_dynamic_dims`, `dim_params`,
|
||||
`set_dim`). Subsequent calls forward straight to the inner CompiledModel.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gm,
|
||||
user_inputs,
|
||||
original_weights,
|
||||
user_indices,
|
||||
dynamic_shapes,
|
||||
factory,
|
||||
):
|
||||
self._gm = gm
|
||||
self._user_inputs = user_inputs
|
||||
self._original_weights = original_weights
|
||||
self._user_indices = user_indices
|
||||
self._dynamic_shapes = dynamic_shapes
|
||||
self._factory = factory
|
||||
self._compiled = None
|
||||
|
||||
def _ensure_compiled(self):
|
||||
if self._compiled is None:
|
||||
self._compiled = _eager_pt2_compile(
|
||||
self._gm,
|
||||
self._user_inputs,
|
||||
self._original_weights,
|
||||
self._user_indices,
|
||||
self._dynamic_shapes,
|
||||
self._factory,
|
||||
)
|
||||
# Drop references to inputs we no longer need — the Rust side
|
||||
# holds onto weights via device pointers / CPU buffers.
|
||||
self._gm = None
|
||||
self._user_inputs = None
|
||||
self._original_weights = None
|
||||
return self._compiled
|
||||
|
||||
def __call__(self, *inputs, **kwargs):
|
||||
return self._ensure_compiled()(*inputs, **kwargs)
|
||||
|
||||
@property
|
||||
def has_dynamic_dims(self):
|
||||
return self._ensure_compiled().has_dynamic_dims
|
||||
|
||||
@property
|
||||
def dim_params(self):
|
||||
return self._ensure_compiled().dim_params
|
||||
|
||||
def set_dim(self, name, value):
|
||||
return self._ensure_compiled().set_dim(name, value)
|
||||
|
||||
|
||||
def pt2_backend(gm, example_inputs, factory=None):
|
||||
"""torch.compile backend using PT2 pipeline.
|
||||
|
||||
Usage: torch.compile(model, backend=luminal.register_backend(capsule))
|
||||
"""
|
||||
import copy as _copy
|
||||
|
||||
if factory is None:
|
||||
factory = _detect_factory_capsule(example_inputs)
|
||||
|
||||
# Work on a private copy of the GraphModule. Dynamo holds onto the
|
||||
# original to install guards and to retrace on shape changes; mutating it
|
||||
# here (erasing SymInt placeholders, re-internalizing lifted weights)
|
||||
# corrupts that bookkeeping and surfaces as cryptic "guard failed on the
|
||||
# same frame" assertions on the next call. The deepcopy is cheap relative
|
||||
# to the rest of the export pipeline.
|
||||
gm = _copy.deepcopy(gm).eval()
|
||||
gm, user_inputs, original_weights, post_lift_indices = _reinternalize_lifted_params(
|
||||
gm, example_inputs
|
||||
)
|
||||
|
||||
# Lift any SymInt placeholders Dynamo emitted alongside the tensor inputs
|
||||
# into `aten.sym_size.int` calls so the re-export sees a tensor-only
|
||||
# signature, then derive the `dynamic_shapes` spec from the surviving
|
||||
# tensor placeholders' FakeTensor shapes. If the strip can't fully clean
|
||||
# the graph (e.g. a compound-expr SymInt with users), we drop dynamic
|
||||
# info and fall back to per-shape recompilation — same as today.
|
||||
user_inputs, post_strip_subindices, strip_ok = _strip_symint_placeholders(
|
||||
gm, user_inputs
|
||||
)
|
||||
dynamic_shapes = _build_dynamic_shapes_from_gm(gm) if strip_ok else None
|
||||
|
||||
# Compose both filter steps into a single user_indices list relative to
|
||||
# the *original* example_inputs Dynamo will pass at runtime — so
|
||||
# CompiledModel.__call__ can drop both lifted weights and SymInt args.
|
||||
user_indices = [post_lift_indices[i] for i in post_strip_subindices]
|
||||
|
||||
if dynamic_shapes is not None:
|
||||
# See `_LazyDynamicCompiledModel` for why dynamic-shape compiles must
|
||||
# be deferred — torch.export with dynamic_shapes mutates ShapeEnv state
|
||||
# Dynamo is still relying on, and running it inside the backend frame
|
||||
# corrupts the freshly-installed guards.
|
||||
return _LazyDynamicCompiledModel(
|
||||
gm, user_inputs, original_weights, user_indices, dynamic_shapes, factory
|
||||
)
|
||||
|
||||
return _eager_pt2_compile(
|
||||
gm, user_inputs, original_weights, user_indices, None, factory
|
||||
)
|
||||
|
||||
@@ -25,10 +25,10 @@ def _new_capsule(name: bytes):
|
||||
def test_process_pt2_rejects_capsule_with_wrong_name():
|
||||
bogus = _new_capsule(b"not.luminal.backend_factory")
|
||||
with pytest.raises(ValueError, match="luminal.backend_factory"):
|
||||
process_pt2("/dev/null", "/dev/null", 0, bogus, None)
|
||||
process_pt2("/dev/null", "/dev/null", bogus, None)
|
||||
|
||||
|
||||
def test_process_pt2_rejects_capsule_with_no_name():
|
||||
unnamed = _new_capsule(None)
|
||||
with pytest.raises(ValueError, match="luminal.backend_factory"):
|
||||
process_pt2("/dev/null", "/dev/null", 0, unnamed, None)
|
||||
process_pt2("/dev/null", "/dev/null", unnamed, None)
|
||||
|
||||
@@ -1,312 +0,0 @@
|
||||
"""End-to-end tests for dynamic-shape support through ``torch.compile``.
|
||||
|
||||
These exercise the path that the standard PyTorch user hits — i.e. wrapping a
|
||||
model with ``torch.compile(model, backend=luminal_backend)`` and calling it
|
||||
with varying input shapes. The luminal backend is expected to recognise
|
||||
Dynamo-emitted SymInt placeholders, propagate the symbolic dims through the
|
||||
PT2 export, and reuse a single compiled graph across shape changes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
|
||||
from luminal.main import luminal_backend
|
||||
|
||||
|
||||
def _compile(model, count_holder):
|
||||
def wrapper(gm, example_inputs):
|
||||
out = luminal_backend(gm, example_inputs)
|
||||
count_holder.append(1)
|
||||
return out
|
||||
|
||||
return torch.compile(model, backend=wrapper)
|
||||
|
||||
|
||||
def _compile_with_dynamic_true(model, count_holder):
|
||||
def wrapper(gm, example_inputs):
|
||||
out = luminal_backend(gm, example_inputs)
|
||||
count_holder.append(1)
|
||||
return out
|
||||
|
||||
return torch.compile(model, backend=wrapper, dynamic=True)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_automatic_dynamic():
|
||||
"""Make sure the tests run with Dynamo's automatic-dynamic detection on.
|
||||
|
||||
Other tests in the suite flip this off; reset state between tests so the
|
||||
cache that backs the previous suppression doesn't carry over. We also
|
||||
raise the recompile limit because Dynamo defaults to 1 (which trips
|
||||
before automatic-dynamic kicks in) and have to do an extra reset to
|
||||
drop any cached frames from prior tests in the suite.
|
||||
"""
|
||||
torch._dynamo.reset()
|
||||
prev_auto = torch._dynamo.config.automatic_dynamic_shapes
|
||||
prev_limit = torch._dynamo.config.recompile_limit
|
||||
torch._dynamo.config.automatic_dynamic_shapes = True
|
||||
torch._dynamo.config.recompile_limit = 16
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch._dynamo.config.automatic_dynamic_shapes = prev_auto
|
||||
torch._dynamo.config.recompile_limit = prev_limit
|
||||
torch._dynamo.reset()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="CUDA-only — the dynamic-shape backend wiring is exercised end to end against the cuda_lite runtime",
|
||||
)
|
||||
def test_dynamic_seq_via_torch_compile_reuses_compile(device: torch.device):
|
||||
"""A varying seq dim should produce two backend invocations total.
|
||||
|
||||
First call: Dynamo emits a static-shape graph (no SymInt placeholders).
|
||||
Second call: Dynamo detects the size mismatch and re-traces with the dim
|
||||
marked dynamic. From that point on, every subsequent shape variation
|
||||
must be served by the same compiled graph — no further backend calls.
|
||||
"""
|
||||
|
||||
class Mdl(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
s = x.shape[0]
|
||||
return x.reshape(s, -1).sum(-1)
|
||||
|
||||
model = Mdl().to(device)
|
||||
counts: list[int] = []
|
||||
compiled = _compile(model, counts)
|
||||
|
||||
for shp in [4, 5, 6, 7, 5]:
|
||||
x = torch.randn(shp, 8, device=device)
|
||||
ref = model(x)
|
||||
out = compiled(x)
|
||||
assert out.shape == ref.shape, (
|
||||
f"shape={shp}: got {out.shape} expected {ref.shape}"
|
||||
)
|
||||
assert torch.allclose(out, ref, atol=1e-5), (
|
||||
f"shape={shp}: max_diff={torch.max(torch.abs(out - ref)).item():.2e}"
|
||||
)
|
||||
|
||||
assert len(counts) == 2, (
|
||||
f"expected exactly 2 backend invocations (one static, one dynamic), got {len(counts)}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
|
||||
)
|
||||
def test_dynamic_via_torch_compile_with_lifted_weights(device: torch.device):
|
||||
"""Combines lifted-weight re-internalization with the SymInt strip.
|
||||
|
||||
Most real models hit both paths simultaneously (Dynamo lifts every
|
||||
`nn.Parameter` AND emits SymInt placeholders for any dim that varies
|
||||
between calls), so the two filters need to compose without losing
|
||||
track of input positions.
|
||||
"""
|
||||
|
||||
class Mdl(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.lin = torch.nn.Linear(8, 4)
|
||||
|
||||
def forward(self, x):
|
||||
return self.lin(x).sum(-1)
|
||||
|
||||
model = Mdl().eval().to(device)
|
||||
counts: list[int] = []
|
||||
compiled = _compile(model, counts)
|
||||
|
||||
for shp in [3, 4, 5, 6, 4]:
|
||||
x = torch.randn(shp, 8, device=device)
|
||||
ref = model(x)
|
||||
out = compiled(x)
|
||||
assert out.shape == ref.shape, (
|
||||
f"shape={shp}: got {out.shape} expected {ref.shape}"
|
||||
)
|
||||
assert torch.allclose(out, ref, atol=1e-5), (
|
||||
f"shape={shp}: max_diff={torch.max(torch.abs(out - ref)).item():.2e}"
|
||||
)
|
||||
|
||||
assert len(counts) == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
|
||||
)
|
||||
def test_compound_shape_expression_auto_resolves(device: torch.device):
|
||||
"""Affine shape expressions (`2*s` etc.) should still let auto-detect work.
|
||||
|
||||
The `auto_set_dims_from_input_shapes` Rust path used to only handle bare
|
||||
`Term::Var(c)` shape expressions and silently skip anything else, leaving
|
||||
affine dims unresolved on the CompiledGraph and the corresponding output
|
||||
sizes stale. We now invert single-variable affine forms `a*x + b` by
|
||||
sampling two probe points; this test exercises that path by constructing
|
||||
a model whose first axis evolves into `2*s` after a `cat` along it.
|
||||
"""
|
||||
|
||||
class Mdl(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
# `cat([x, x], dim=0)` doubles the leading dim — torch.export
|
||||
# encodes the resulting shape as `2*s` rather than `s`.
|
||||
return torch.cat([x, x], dim=0).sum(-1)
|
||||
|
||||
model = Mdl().to(device)
|
||||
counts: list[int] = []
|
||||
compiled = _compile(model, counts)
|
||||
|
||||
for shp in [4, 5, 6, 7, 5]:
|
||||
x = torch.randn(shp, 8, device=device)
|
||||
ref = model(x)
|
||||
out = compiled(x)
|
||||
assert out.shape == ref.shape, (
|
||||
f"shape={shp}: got {out.shape} expected {ref.shape}"
|
||||
)
|
||||
assert torch.allclose(out, ref, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
|
||||
)
|
||||
def test_torch_compile_dynamic_true_single_compile(device: torch.device):
|
||||
"""`torch.compile(model, backend=luminal_backend, dynamic=True)` works.
|
||||
|
||||
`dynamic=True` skips Dynamo's specialise-then-promote dance and emits a
|
||||
fully-symbolic graph from the first call. The luminal backend must
|
||||
handle the SymInt placeholders Dynamo passes alongside the tensor
|
||||
inputs and reuse a single compiled graph across all shape variations —
|
||||
one backend invocation total, in contrast to the 2 we'd see under
|
||||
automatic-dynamic mode (which burns a static compile on call 1 before
|
||||
promoting to dynamic on call 2).
|
||||
"""
|
||||
|
||||
class Mdl(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
s = x.shape[0]
|
||||
return x.reshape(s, -1).sum(-1)
|
||||
|
||||
model = Mdl().to(device)
|
||||
counts: list[int] = []
|
||||
compiled = _compile_with_dynamic_true(model, counts)
|
||||
|
||||
for shp in [4, 5, 6, 7, 5]:
|
||||
x = torch.randn(shp, 8, device=device)
|
||||
ref = model(x)
|
||||
out = compiled(x)
|
||||
assert out.shape == ref.shape
|
||||
assert torch.allclose(out, ref, atol=1e-5)
|
||||
|
||||
assert len(counts) == 1, (
|
||||
f"dynamic=True should produce a single backend invocation, got {len(counts)}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
|
||||
)
|
||||
def test_explicit_compile_float_input_dynamic(device: torch.device):
|
||||
"""`luminal.pt2.compile(model, example, dynamic_dim=...)` with a float input.
|
||||
|
||||
The previous version of `compile()` silently fell back to a static export
|
||||
for floating-point inputs (the `"auto"` heuristic was integer-only). The
|
||||
new spec accepts an explicit `int` or `Iterable[int]` regardless of dtype,
|
||||
and `"auto"` now picks every non-trivial axis.
|
||||
"""
|
||||
from luminal.pt2 import compile as luminal_compile
|
||||
|
||||
class Mdl(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return (x * 2.0).sum(-1)
|
||||
|
||||
model = Mdl().eval().to(device)
|
||||
example = torch.randn(4, 8, device=device)
|
||||
compiled = luminal_compile(model, example, search_iterations=3, dynamic_dim=0)
|
||||
|
||||
assert compiled.has_dynamic_dims, "compile() should have produced a dynamic graph"
|
||||
|
||||
for shp in [4, 5, 6, 7]:
|
||||
x = torch.randn(shp, 8, device=device)
|
||||
ref = model(x)
|
||||
out = compiled(x)
|
||||
# `compile()` returns a tuple of outputs; extract the first.
|
||||
out_t = out[0] if isinstance(out, tuple) else out
|
||||
assert out_t.shape == ref.shape, (
|
||||
f"shape={shp}: got {out_t.shape}, expected {ref.shape}"
|
||||
)
|
||||
assert torch.allclose(out_t, ref, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
|
||||
)
|
||||
def test_explicit_compile_dynamic_shapes_passthrough(device: torch.device):
|
||||
"""`luminal.pt2.compile(... , dynamic_shapes=...)` accepts a full spec.
|
||||
|
||||
Lets the caller specify named `Dim` objects with ranges — the previous
|
||||
API hardcoded `Dim("seq", min=2)` for any single dynamic dim.
|
||||
"""
|
||||
from torch.export import Dim
|
||||
from luminal.pt2 import compile as luminal_compile
|
||||
|
||||
class Mdl(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x.mean(-1)
|
||||
|
||||
model = Mdl().eval().to(device)
|
||||
example = torch.randn(4, 8, device=device)
|
||||
seq = Dim("seq_len", min=2, max=64)
|
||||
compiled = luminal_compile(
|
||||
model, example, search_iterations=3, dynamic_shapes=({0: seq},)
|
||||
)
|
||||
assert compiled.has_dynamic_dims
|
||||
# torch.export rewrites user-supplied Dim names to its internal s77/s33
|
||||
# convention before saving — what we actually need to verify is that a
|
||||
# symbolic dim was registered, not what label it ended up with.
|
||||
assert len(compiled.dim_params) == 1, (
|
||||
f"expected exactly one dynamic dim, got {compiled.dim_params}"
|
||||
)
|
||||
|
||||
for shp in [3, 5, 16]:
|
||||
x = torch.randn(shp, 8, device=device)
|
||||
ref = model(x)
|
||||
out = compiled(x)
|
||||
out_t = out[0] if isinstance(out, tuple) else out
|
||||
assert out_t.shape == ref.shape
|
||||
assert torch.allclose(out_t, ref, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
|
||||
)
|
||||
def test_dynamic_two_dim_via_torch_compile(device: torch.device):
|
||||
"""Both batch and seq dynamic — should still reuse a single compile."""
|
||||
|
||||
class Mdl(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x.sum(-1)
|
||||
|
||||
model = Mdl().to(device)
|
||||
counts: list[int] = []
|
||||
compiled = _compile(model, counts)
|
||||
|
||||
# Vary batch and seq together so Dynamo marks both as dynamic.
|
||||
for batch, seq in [(2, 8), (3, 9), (4, 10), (5, 11), (3, 12)]:
|
||||
x = torch.randn(batch, seq, device=device)
|
||||
ref = model(x)
|
||||
out = compiled(x)
|
||||
assert out.shape == ref.shape
|
||||
assert torch.allclose(out, ref, atol=1e-5)
|
||||
|
||||
# Allow at most a small number of compiles — two shape transitions can
|
||||
# legitimately take Dynamo two retraces (one per newly-dynamic dim).
|
||||
assert len(counts) <= 3, (
|
||||
f"expected ≤3 compiles for two-dim dynamic, got {len(counts)}"
|
||||
)
|
||||
@@ -221,7 +221,6 @@ from test_models import (
|
||||
Conv1dNoPadModel,
|
||||
Conv1dSamePadModel,
|
||||
Conv1dBiasModel,
|
||||
Conv1dFloorDivPositionalModel,
|
||||
Conv2dNoPadModel,
|
||||
Conv2dSamePadModel,
|
||||
Conv2dBiasModel,
|
||||
@@ -237,20 +236,99 @@ from test_models import (
|
||||
TinyMoERoutingModel,
|
||||
)
|
||||
|
||||
import luminal.pt2 as luminal_pt2
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
def _compile_for_export_mode(
|
||||
model: torch.nn.Module, export_mode: str | None = None
|
||||
) -> Callable:
|
||||
if export_mode is None:
|
||||
return torch.compile(model, backend=luminal_backend)
|
||||
return torch.compile(
|
||||
def test_backend_options_forwarded_to_process_pt2(
|
||||
monkeypatch: pytest.MonkeyPatch, device: torch.device
|
||||
):
|
||||
captured = {}
|
||||
|
||||
def fake_process_pt2(
|
||||
pt2_path,
|
||||
weights_path,
|
||||
factory_capsule,
|
||||
weight_device_ptrs=None,
|
||||
options=None,
|
||||
):
|
||||
captured["pt2_path"] = pt2_path
|
||||
captured["weights_path"] = weights_path
|
||||
captured["factory_capsule"] = factory_capsule
|
||||
captured["weight_device_ptrs"] = weight_device_ptrs
|
||||
captured["options"] = options
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(luminal_pt2, "process_pt2", fake_process_pt2)
|
||||
monkeypatch.setattr(
|
||||
luminal_pt2, "_load_cpu_weights", lambda compiled, weights: None
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
luminal_pt2,
|
||||
"CompiledModel",
|
||||
lambda compiled, weight_refs=None: lambda x: x + x,
|
||||
)
|
||||
|
||||
model: torch.nn.Module = AddTestModel().to(device)
|
||||
compiled: Callable = torch.compile(
|
||||
model,
|
||||
backend=luminal_backend,
|
||||
options={"export_mode": export_mode},
|
||||
options={"search_iterations": 3},
|
||||
)
|
||||
|
||||
x: torch.Tensor = torch.rand((5, 5), device=device)
|
||||
compiled(x)
|
||||
|
||||
assert captured["weights_path"] == ""
|
||||
assert type(captured["factory_capsule"]).__name__ == "PyCapsule"
|
||||
assert captured["options"] == {"search_iterations": 3}
|
||||
assert isinstance(captured["weight_device_ptrs"], dict)
|
||||
|
||||
|
||||
def test_backend_options_unknown_key_raises(device: torch.device):
|
||||
model: torch.nn.Module = AddTestModel().to(device)
|
||||
compiled: Callable = torch.compile(
|
||||
model,
|
||||
backend=luminal_backend,
|
||||
options={"unknown_option": 1},
|
||||
)
|
||||
|
||||
x: torch.Tensor = torch.rand((5, 5), device=device)
|
||||
with pytest.raises(torch._dynamo.exc.BackendCompilerFailed) as exc_info:
|
||||
compiled(x)
|
||||
assert isinstance(exc_info.value.inner_exception, ValueError)
|
||||
assert "Unsupported luminal backend option" in str(exc_info.value.inner_exception)
|
||||
|
||||
|
||||
def test_backend_options_non_dict_raises(device: torch.device):
|
||||
model: torch.nn.Module = AddTestModel().to(device)
|
||||
compiled: Callable = torch.compile(
|
||||
model,
|
||||
backend=luminal_backend,
|
||||
options=["pt2"],
|
||||
)
|
||||
|
||||
x: torch.Tensor = torch.rand((5, 5), device=device)
|
||||
with pytest.raises(torch._dynamo.exc.BackendCompilerFailed) as exc_info:
|
||||
compiled(x)
|
||||
assert isinstance(exc_info.value.inner_exception, TypeError)
|
||||
assert "options must be a dict" in str(exc_info.value.inner_exception)
|
||||
|
||||
|
||||
def test_backend_options_bad_search_iterations_type_raises(device: torch.device):
|
||||
model: torch.nn.Module = AddTestModel().to(device)
|
||||
compiled: Callable = torch.compile(
|
||||
model,
|
||||
backend=luminal_backend,
|
||||
options={"search_iterations": "fast"},
|
||||
)
|
||||
|
||||
x: torch.Tensor = torch.rand((5, 5), device=device)
|
||||
with pytest.raises(torch._dynamo.exc.BackendCompilerFailed) as exc_info:
|
||||
compiled(x)
|
||||
assert isinstance(exc_info.value.inner_exception, TypeError)
|
||||
assert "search_iterations" in str(exc_info.value.inner_exception)
|
||||
|
||||
|
||||
def test_add(device: torch.device):
|
||||
add_test_model: torch.nn.Module = AddTestModel().to(device)
|
||||
@@ -1098,17 +1176,6 @@ def test_reduce_sum_all_axes(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_reduce_sum_all_axes_int64_preserves_dtype(device: torch.device):
|
||||
"""Full reduction of an int64 tensor must preserve int64 (regression for LUM-486)."""
|
||||
model: torch.nn.Module = ReduceSumAllAxesModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randint(0, 10, (3, 4), device=device, dtype=torch.int64)
|
||||
eager = model(x)
|
||||
out = model_compiled(x)
|
||||
assert out.dtype == eager.dtype == torch.int64
|
||||
assert torch.equal(out, eager)
|
||||
|
||||
|
||||
def test_reduce_sum_3d_axis1(device: torch.device):
|
||||
"""Test sum reduction along axis 1 for a 3D tensor."""
|
||||
model: torch.nn.Module = ReduceSum3DAxis1Model().to(device)
|
||||
@@ -1649,21 +1716,6 @@ def test_or(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
def test_bitwise_or(device: torch.device):
|
||||
"""Test bitwise_or on boolean tensors. PyTorch's `a | b` on Bool tensors
|
||||
emits `aten.bitwise_or.Tensor`, NOT `aten.logical_or.default` — Gemma-style
|
||||
sliding+full attention mask fusion takes this path."""
|
||||
from test_models import BitwiseOrTestModel
|
||||
|
||||
model: torch.nn.Module = BitwiseOrTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
a = torch.tensor([True, False, True, False, True, True], device=device)
|
||||
b = torch.tensor([False, True, True, False, False, True], device=device)
|
||||
original = model(a, b)
|
||||
output = model_compiled(a, b)
|
||||
assert torch.equal(output, original)
|
||||
|
||||
|
||||
# ========== PT2 Xor Node Tests ==========
|
||||
|
||||
|
||||
@@ -1860,60 +1912,6 @@ def test_scaled_dot_product_attention(device: torch.device):
|
||||
assert torch.allclose(output, original, atol=1e-5)
|
||||
|
||||
|
||||
# ========== F.scaled_dot_product_attention (SDPA aten variants) ==========
|
||||
# Tests for `torch.nn.functional.scaled_dot_product_attention`, which lowers
|
||||
# to one of `aten._scaled_dot_product_*_attention.default` (variant chosen by
|
||||
# PyTorch's dispatcher: efficient/flash/flash_for_cpu/cudnn). Coverage here
|
||||
# exercises `translate_sdpa` end-to-end.
|
||||
|
||||
|
||||
def _sdpa_qkv(device: torch.device, b: int = 1, h: int = 2, s: int = 4, d: int = 8):
|
||||
"""Build a `(B, H, S, D)` Q/K/V triple of float32 tensors on `device`."""
|
||||
torch.manual_seed(0)
|
||||
q = torch.rand((b, h, s, d), device=device)
|
||||
k = torch.rand((b, h, s, d), device=device)
|
||||
v = torch.rand((b, h, s, d), device=device)
|
||||
return q, k, v
|
||||
|
||||
|
||||
def test_sdpa_basic(device: torch.device):
|
||||
"""`F.scaled_dot_product_attention(q, k, v)` — default scale, no mask."""
|
||||
from test_models import SdpaBasicModel
|
||||
|
||||
model: torch.nn.Module = SdpaBasicModel().to(device)
|
||||
compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
q, k, v = _sdpa_qkv(device)
|
||||
expected: torch.Tensor = model(q, k, v)
|
||||
actual: torch.Tensor = compiled(q, k, v)
|
||||
assert torch.allclose(actual, expected, atol=1e-5)
|
||||
|
||||
|
||||
def test_sdpa_causal(device: torch.device):
|
||||
"""`F.scaled_dot_product_attention(q, k, v, is_causal=True)`."""
|
||||
from test_models import SdpaCausalModel
|
||||
|
||||
model: torch.nn.Module = SdpaCausalModel().to(device)
|
||||
compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
q, k, v = _sdpa_qkv(device)
|
||||
expected: torch.Tensor = model(q, k, v)
|
||||
actual: torch.Tensor = compiled(q, k, v)
|
||||
assert torch.allclose(actual, expected, atol=1e-5)
|
||||
|
||||
|
||||
def test_sdpa_with_attn_bias(device: torch.device):
|
||||
"""SDPA with an additive `attn_mask` (float bias) broadcast over heads."""
|
||||
from test_models import SdpaWithBiasModel
|
||||
|
||||
model: torch.nn.Module = SdpaWithBiasModel().to(device)
|
||||
compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
q, k, v = _sdpa_qkv(device)
|
||||
bias = torch.zeros((1, 1, q.shape[-2], k.shape[-2]), device=device)
|
||||
bias[..., 0, 1] = -1.0 # any non-trivial bias to verify it's actually applied
|
||||
expected: torch.Tensor = model(q, k, v, bias)
|
||||
actual: torch.Tensor = compiled(q, k, v, bias)
|
||||
assert torch.allclose(actual, expected, atol=1e-5)
|
||||
|
||||
|
||||
def test_mlp_block(device: torch.device):
|
||||
"""Test two-layer MLP: Linear(8,16) -> ReLU -> Linear(16,4) on input (2,8)."""
|
||||
model: torch.nn.Module = MLPBlockModel().to(device)
|
||||
@@ -2035,16 +2033,9 @@ def test_split(device: torch.device):
|
||||
# ========== Argsort / MoE Routing Tests ==========
|
||||
|
||||
|
||||
@pytest.mark.parametrize("idx_dtype", [torch.int32, torch.int64])
|
||||
def test_argsort_stable_duplicates(device: torch.device, idx_dtype: torch.dtype):
|
||||
"""Duplicate values should follow stable lower-index-first tie-breaking.
|
||||
|
||||
Parametrized over int32/int64 to verify luminal preserves whichever
|
||||
integer dtype the eager model declares (LUM-486).
|
||||
"""
|
||||
model: torch.nn.Module = ArgsortStableDuplicatesModel(idx_dtype=idx_dtype).to(
|
||||
device
|
||||
)
|
||||
def test_argsort_stable_duplicates(device: torch.device):
|
||||
"""Duplicate values should follow stable lower-index-first tie-breaking."""
|
||||
model: torch.nn.Module = ArgsortStableDuplicatesModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.tensor(
|
||||
[[2.0, 1.0, 1.0, 3.0]],
|
||||
@@ -2053,21 +2044,13 @@ def test_argsort_stable_duplicates(device: torch.device, idx_dtype: torch.dtype)
|
||||
)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert original.dtype == idx_dtype, "test setup: model should cast to idx_dtype"
|
||||
assert output.dtype == original.dtype, (
|
||||
f"luminal returned {output.dtype}, eager produced {original.dtype}"
|
||||
)
|
||||
assert torch.equal(output, original)
|
||||
assert output.dtype == torch.int32
|
||||
assert torch.equal(output, original.to(torch.int32))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("idx_dtype", [torch.int32, torch.int64])
|
||||
def test_tiny_moe_routing(device: torch.device, idx_dtype: torch.dtype):
|
||||
"""Focused proof for built MoE routing support.
|
||||
|
||||
Parametrized over int32/int64 for the integer-valued outputs to verify
|
||||
luminal preserves the dtype declared by the eager model (LUM-486).
|
||||
"""
|
||||
model: torch.nn.Module = TinyMoERoutingModel(idx_dtype=idx_dtype).to(device)
|
||||
def test_tiny_moe_routing(device: torch.device):
|
||||
"""Focused proof for build MoE routing support."""
|
||||
model: torch.nn.Module = TinyMoERoutingModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
scores = torch.tensor(
|
||||
[[0.1, 0.9, 0.4, 0.7], [0.6, -0.8, 0.95, 0.2]],
|
||||
@@ -2078,10 +2061,17 @@ def test_tiny_moe_routing(device: torch.device, idx_dtype: torch.dtype):
|
||||
expected = model(scores)
|
||||
output = model_compiled(scores)
|
||||
|
||||
for actual, eager in zip(output, expected):
|
||||
assert actual.dtype == eager.dtype, (
|
||||
f"luminal returned {actual.dtype}, eager produced {eager.dtype}"
|
||||
)
|
||||
expected_dtypes = (
|
||||
torch.int32,
|
||||
torch.float32,
|
||||
torch.int32,
|
||||
torch.bool,
|
||||
torch.int32,
|
||||
torch.float32,
|
||||
)
|
||||
for actual, eager, expected_dtype in zip(output, expected, expected_dtypes):
|
||||
assert actual.dtype == expected_dtype
|
||||
eager = eager.to(actual.dtype)
|
||||
if actual.dtype.is_floating_point:
|
||||
assert torch.allclose(actual, eager)
|
||||
else:
|
||||
@@ -2099,23 +2089,6 @@ def test_topk_values(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x))
|
||||
|
||||
|
||||
def test_topk_values_width_128_with_indices(device: torch.device):
|
||||
"""Regression for router-sized TopK values when both tuple outputs are used."""
|
||||
|
||||
class TopKValuesAndIndices(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor):
|
||||
values, indices = torch.topk(torch.softmax(x, dim=-1), 8, dim=1)
|
||||
return values, indices
|
||||
|
||||
model = TopKValuesAndIndices().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(4, 128, device=device)
|
||||
actual_values, actual_indices = model_compiled(x)
|
||||
expected_values, expected_indices = model(x)
|
||||
assert torch.allclose(actual_values, expected_values, atol=1e-5)
|
||||
assert torch.equal(actual_indices.to(expected_indices.dtype), expected_indices)
|
||||
|
||||
|
||||
def test_topk_indices(device: torch.device):
|
||||
"""Tests TopK indices output for 2D tensor along axis=1."""
|
||||
model: torch.nn.Module = TopKIndicesTestModel().to(device)
|
||||
@@ -2173,261 +2146,6 @@ def test_scatter_nd(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== Bool-mask index_put correctness tests ==========
|
||||
#
|
||||
# `x[bool_mask] = scalar` is semantically `where(mask, scalar, x)`, NOT a
|
||||
# scatter into Int(mask) positions. Pre-fix, the translator cast the Bool
|
||||
# mask to Int and routed through scatter_nd, reinterpreting True/False as
|
||||
# row indices 1/0 and silently corrupting `x`. Each variant below exercises
|
||||
# a different mask configuration; together they would catch any regression
|
||||
# in the bool-mask blend path.
|
||||
|
||||
|
||||
def _check_bool_mask(
|
||||
device: torch.device, model_cls, x: torch.Tensor, mask: torch.Tensor
|
||||
):
|
||||
"""Shared body: compile, run eager + compiled, assert exact equality."""
|
||||
from test_models import (
|
||||
BoolMaskAssign3DModel,
|
||||
BoolMaskAssignFloatModel,
|
||||
BoolMaskAssignIntModel,
|
||||
)
|
||||
|
||||
_ = (BoolMaskAssign3DModel, BoolMaskAssignFloatModel, BoolMaskAssignIntModel)
|
||||
model: torch.nn.Module = model_cls().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
original: torch.Tensor = model(x, mask)
|
||||
output: torch.Tensor = model_compiled(x, mask)
|
||||
# Bit-equal (not allclose) — the lowering should produce identical
|
||||
# results to eager for bool-mask blends.
|
||||
assert torch.equal(output, original), (
|
||||
f"bool-mask index_put mismatch:\n"
|
||||
f" mask = {mask.flatten().tolist()}\n"
|
||||
f" eager = {original.flatten().tolist()}\n"
|
||||
f" out = {output.flatten().tolist()}"
|
||||
)
|
||||
|
||||
|
||||
def test_bool_mask_index_put_all_false(device: torch.device):
|
||||
"""All-False mask must be a no-op. Pre-fix this *silently* corrupted row 0
|
||||
— the regression that drove the Gemma-4 ~30-magnitude logits drift."""
|
||||
from test_models import BoolMaskAssignIntModel
|
||||
|
||||
x = torch.arange(16, device=device, dtype=torch.long).reshape(4, 4)
|
||||
mask = torch.zeros(4, 4, dtype=torch.bool, device=device)
|
||||
_check_bool_mask(device, BoolMaskAssignIntModel, x, mask)
|
||||
|
||||
|
||||
def test_bool_mask_index_put_one_true(device: torch.device):
|
||||
"""Single True position — only that position should change."""
|
||||
from test_models import BoolMaskAssignIntModel
|
||||
|
||||
x = torch.arange(16, device=device, dtype=torch.long).reshape(4, 4)
|
||||
mask = torch.zeros(4, 4, dtype=torch.bool, device=device)
|
||||
mask[1, 2] = True
|
||||
_check_bool_mask(device, BoolMaskAssignIntModel, x, mask)
|
||||
|
||||
|
||||
def test_bool_mask_index_put_many_true(device: torch.device):
|
||||
"""Multiple scattered True positions — each should be replaced independently."""
|
||||
from test_models import BoolMaskAssignIntModel
|
||||
|
||||
x = torch.arange(16, device=device, dtype=torch.long).reshape(4, 4)
|
||||
mask = torch.tensor(
|
||||
[
|
||||
[True, False, False, True],
|
||||
[False, False, True, False],
|
||||
[True, False, False, False],
|
||||
[False, True, False, True],
|
||||
],
|
||||
dtype=torch.bool,
|
||||
device=device,
|
||||
)
|
||||
_check_bool_mask(device, BoolMaskAssignIntModel, x, mask)
|
||||
|
||||
|
||||
def test_bool_mask_index_put_all_true(device: torch.device):
|
||||
"""All-True mask — every element should become the scalar value."""
|
||||
from test_models import BoolMaskAssignIntModel
|
||||
|
||||
x = torch.arange(16, device=device, dtype=torch.long).reshape(4, 4)
|
||||
mask = torch.ones(4, 4, dtype=torch.bool, device=device)
|
||||
_check_bool_mask(device, BoolMaskAssignIntModel, x, mask)
|
||||
|
||||
|
||||
def test_bool_mask_index_put_float(device: torch.device):
|
||||
"""Float data + float scalar value. Verifies the where-blend works for
|
||||
non-integer dtypes — the blend formula `a*(1-mask) + value*mask` casts
|
||||
mask to data's dtype, so dtype-specific paths must compose correctly."""
|
||||
from test_models import BoolMaskAssignFloatModel
|
||||
|
||||
x = torch.arange(20, device=device, dtype=torch.float32).reshape(4, 5)
|
||||
mask = torch.tensor(
|
||||
[
|
||||
[True, False, False, True, False],
|
||||
[False, True, False, False, True],
|
||||
[True, True, False, False, False],
|
||||
[False, False, False, True, True],
|
||||
],
|
||||
dtype=torch.bool,
|
||||
device=device,
|
||||
)
|
||||
model = BoolMaskAssignFloatModel().to(device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
original = model(x, mask)
|
||||
output = compiled(x, mask)
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
def test_bool_mask_index_put_3d(device: torch.device):
|
||||
"""3-D `x` with a 3-D bool mask of matching shape. Catches regressions
|
||||
where the bool-mask detection only works at one specific rank — the
|
||||
`idx_tensor.shape.dims == a.shape.dims` check has to handle arbitrary
|
||||
ranks, not just 2-D."""
|
||||
from test_models import BoolMaskAssign3DModel
|
||||
|
||||
x = torch.arange(24, device=device, dtype=torch.float32).reshape(2, 3, 4)
|
||||
mask = torch.zeros(2, 3, 4, dtype=torch.bool, device=device)
|
||||
mask[0, 1, 2] = True
|
||||
mask[1, 0, 0] = True
|
||||
mask[1, 2, 3] = True
|
||||
model = BoolMaskAssign3DModel().to(device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
original = model(x, mask)
|
||||
output = compiled(x, mask)
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
def test_int_index_put_scalar_src(device: torch.device):
|
||||
"""`x[indices] = scalar` with int indices: the scatter path receives a
|
||||
scalar src against a 1D index tensor. Pre-fix `GraphTensor::scatter`
|
||||
panicked at `flatten_strides` (rank mismatch: index_shape=[2],
|
||||
src_strides=[]). With the zero-stride padding the scalar broadcasts
|
||||
across all indexed positions correctly."""
|
||||
from test_models import IntIndexAssignScalarModel
|
||||
|
||||
x = torch.arange(20, device=device, dtype=torch.float32).reshape(5, 4)
|
||||
indices = torch.tensor([0, 3], device=device, dtype=torch.long)
|
||||
model = IntIndexAssignScalarModel().to(device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
original = model(x, indices)
|
||||
output = compiled(x, indices)
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
def test_grouped_mm_fallback(device: torch.device):
|
||||
"""Tests transformers::grouped_mm_fallback — the per-expert batched matmul
|
||||
used by HF MoE forward passes (DeepSeek-V2/V3, Qwen2/3-MoE, Mixtral, ...).
|
||||
|
||||
Importing transformers.integrations.moe registers the custom_op via
|
||||
`torch.library.custom_op("transformers::grouped_mm_fallback", ...)`. After
|
||||
import, `torch.ops.transformers.grouped_mm_fallback` is callable directly.
|
||||
"""
|
||||
# Side-effect import: registers the custom_op via torch.library.custom_op.
|
||||
# The name itself isn't referenced — ruff's F401 must be suppressed.
|
||||
import transformers.integrations.moe # noqa: F401
|
||||
from test_models import GroupedMMFallbackTestModel
|
||||
|
||||
model: torch.nn.Module = GroupedMMFallbackTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
# 2 experts, 4 tokens, K=8, N=16. Tokens [0,1] go to expert 0, [2,3] to expert 1.
|
||||
g, s, k, n = 2, 4, 8, 16
|
||||
input = torch.randn(s, k, device=device)
|
||||
weight = torch.randn(g, k, n, device=device)
|
||||
offs = torch.tensor([2, 4], device=device, dtype=torch.int32)
|
||||
original: torch.Tensor = model(input, weight, offs)
|
||||
output: torch.Tensor = model_compiled(input, weight, offs)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_grouped_mm_fallback_routing_invariance(device: torch.device):
|
||||
"""The MoE forest, not just the trees: one compile must correctly handle
|
||||
*any* routing pattern at the same shape.
|
||||
|
||||
`translate_grouped_mm` is correct only if `offs` flows through as a runtime
|
||||
tensor — the gate's top-k decision varies per token batch, and the same
|
||||
compiled graph has to dispatch tokens to the right experts for whatever
|
||||
`offs` arrives at execution. If our lowering accidentally specialized on a
|
||||
particular `offs` value (baking in expert assignments), `compiled(input_b,
|
||||
weight, offs_b)` would either silently produce wrong-expert output or
|
||||
trigger a recompile.
|
||||
|
||||
This test asserts three things at once:
|
||||
(a) Different `offs` (= different routing) doesn't trigger a recompile.
|
||||
(b) `offs` appears as an FX graph node, not a baked constant.
|
||||
(c) The same compiled graph produces correct output for both routings,
|
||||
and outputs *differ* between routings (else the test is moot).
|
||||
"""
|
||||
import transformers.integrations.moe # noqa: F401
|
||||
from test_models import GroupedMMFallbackTestModel
|
||||
|
||||
g, s, k, n = 2, 4, 8, 16
|
||||
|
||||
# Wrap luminal_backend to capture the FX graph(s) dynamo hands us.
|
||||
captured = []
|
||||
|
||||
def capturing_backend(gm, example_inputs):
|
||||
captured.append(gm)
|
||||
return luminal_backend(gm, example_inputs)
|
||||
|
||||
model = GroupedMMFallbackTestModel().to(device)
|
||||
compiled = torch.compile(model, backend=capturing_backend)
|
||||
|
||||
# Same shapes, different data → different routing patterns.
|
||||
weight = torch.randn(g, k, n, device=device)
|
||||
input_a = torch.randn(s, k, device=device)
|
||||
input_b = torch.randn(s, k, device=device)
|
||||
# offs[i] = cumulative tokens through expert i. Different routings:
|
||||
# offs_a: 1 token to expert 0, 3 to expert 1
|
||||
# offs_b: 3 tokens to expert 0, 1 to expert 1
|
||||
offs_a = torch.tensor([1, 4], device=device, dtype=torch.int32)
|
||||
offs_b = torch.tensor([3, 4], device=device, dtype=torch.int32)
|
||||
|
||||
with torch.no_grad():
|
||||
ref_a = model(input_a, weight, offs_a)
|
||||
out_a = compiled(input_a, weight, offs_a)
|
||||
n_compiles_after_first = len(captured)
|
||||
|
||||
ref_b = model(input_b, weight, offs_b)
|
||||
out_b = compiled(input_b, weight, offs_b)
|
||||
|
||||
# (a) No recompile between distinct routings.
|
||||
assert len(captured) == n_compiles_after_first, (
|
||||
f"Different routings triggered a recompile: "
|
||||
f"{n_compiles_after_first} → {len(captured)}"
|
||||
)
|
||||
|
||||
# (b) offs is an FX graph node, not a baked constant.
|
||||
grouped_nodes = [
|
||||
node for node in captured[0].graph.nodes if "grouped_mm" in str(node.target)
|
||||
]
|
||||
assert len(grouped_nodes) == 1, (
|
||||
f"Expected exactly one grouped_mm node, got {len(grouped_nodes)}"
|
||||
)
|
||||
grouped_node = grouped_nodes[0]
|
||||
# transformers::grouped_mm_fallback emits offs as a kwarg; aten._grouped_mm
|
||||
# may emit it as a positional. Accept either.
|
||||
offs_arg = grouped_node.kwargs.get("offs")
|
||||
if offs_arg is None and len(grouped_node.args) > 2:
|
||||
offs_arg = grouped_node.args[2]
|
||||
assert hasattr(offs_arg, "op"), (
|
||||
f"offs argument should be an FX graph node, got {offs_arg!r} "
|
||||
f"({type(offs_arg).__name__}) — looks baked as constant"
|
||||
)
|
||||
|
||||
# (c) Both routings produce correct output, and outputs differ.
|
||||
assert torch.allclose(out_a, ref_a, atol=1e-4), (
|
||||
f"routing A: max_diff={torch.max(torch.abs(out_a - ref_a)).item():.2e}"
|
||||
)
|
||||
assert torch.allclose(out_b, ref_b, atol=1e-4), (
|
||||
f"routing B: max_diff={torch.max(torch.abs(out_b - ref_b)).item():.2e}"
|
||||
)
|
||||
assert not torch.allclose(out_a, out_b, atol=1e-3), (
|
||||
"Outputs of routing A and B should differ — otherwise routing isn't "
|
||||
"actually being exercised."
|
||||
)
|
||||
|
||||
|
||||
# ========== Dtype Round-Trip Tests ==========
|
||||
|
||||
|
||||
@@ -2460,10 +2178,10 @@ def test_dtype_float32(device: torch.device):
|
||||
# ========== Convolution Tests ==========
|
||||
|
||||
|
||||
def _run_conv1d_no_pad(device: torch.device, export_mode: str | None = None):
|
||||
def _run_conv1d_no_pad(device: torch.device):
|
||||
"""Conv1d without padding: output length = input - (kernel-1)."""
|
||||
model: torch.nn.Module = Conv1dNoPadModel().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(2, 8, 32, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
@@ -2474,10 +2192,6 @@ def test_conv1d_no_pad(device: torch.device):
|
||||
_run_conv1d_no_pad(device)
|
||||
|
||||
|
||||
def test_conv1d_no_pad_pt2(device: torch.device):
|
||||
_run_conv1d_no_pad(device, "pt2")
|
||||
|
||||
|
||||
def test_conv1d_same_pad(device: torch.device):
|
||||
"""Conv1d with padding=1: output length == input length."""
|
||||
model: torch.nn.Module = Conv1dSamePadModel().to(device)
|
||||
@@ -2498,21 +2212,10 @@ def test_conv1d_bias(device: torch.device):
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_conv1d_floor_div_positional_pt2(device: torch.device):
|
||||
"""Conv1d stride output uses floor division before positional add."""
|
||||
model: torch.nn.Module = Conv1dFloorDivPositionalModel().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, "pt2")
|
||||
x: torch.Tensor = torch.randn(1, 8, 30, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert output.shape == original.shape == (15, 16)
|
||||
assert torch.allclose(output, original, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
def _run_conv2d_no_pad(device: torch.device, export_mode: str | None = None):
|
||||
def _run_conv2d_no_pad(device: torch.device):
|
||||
"""Conv2d without padding: output spatial = input - (kernel-1)."""
|
||||
model: torch.nn.Module = Conv2dNoPadModel().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(1, 3, 8, 8, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
@@ -2523,10 +2226,6 @@ def test_conv2d_no_pad(device: torch.device):
|
||||
_run_conv2d_no_pad(device)
|
||||
|
||||
|
||||
def test_conv2d_no_pad_pt2(device: torch.device):
|
||||
_run_conv2d_no_pad(device, "pt2")
|
||||
|
||||
|
||||
def test_conv2d_same_pad(device: torch.device):
|
||||
"""Conv2d with padding=1: output spatial == input spatial."""
|
||||
model: torch.nn.Module = Conv2dSamePadModel().to(device)
|
||||
@@ -2557,10 +2256,10 @@ def test_conv2d_stride(device: torch.device):
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def _run_conv2d_dilation(device: torch.device, export_mode: str | None = None):
|
||||
def _run_conv2d_dilation(device: torch.device):
|
||||
"""Conv2d with dilation=2 preserves the expected spatial shape and values."""
|
||||
model: torch.nn.Module = Conv2dDilationModel().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(2, 8, 17, 19, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
@@ -2571,14 +2270,10 @@ def test_conv2d_dilation(device: torch.device):
|
||||
_run_conv2d_dilation(device)
|
||||
|
||||
|
||||
def test_conv2d_dilation_pt2(device: torch.device):
|
||||
_run_conv2d_dilation(device, "pt2")
|
||||
|
||||
|
||||
def _run_conv3d_same_pad(device: torch.device, export_mode: str | None = None):
|
||||
def _run_conv3d_same_pad(device: torch.device):
|
||||
"""Conv3d exercises the spatial=3 unfold/permute/split path."""
|
||||
model: torch.nn.Module = Conv3dSamePadModel().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(2, 4, 6, 7, 8, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
@@ -2589,10 +2284,6 @@ def test_conv3d_same_pad(device: torch.device):
|
||||
_run_conv3d_same_pad(device)
|
||||
|
||||
|
||||
def test_conv3d_same_pad_pt2(device: torch.device):
|
||||
_run_conv3d_same_pad(device, "pt2")
|
||||
|
||||
|
||||
def test_depthwise_conv1d(device: torch.device):
|
||||
"""Depthwise Conv1d with groups=in_channels, as used in Mamba."""
|
||||
model: torch.nn.Module = DepthwiseConv1dModel().to(device)
|
||||
@@ -2613,12 +2304,10 @@ def test_depthwise_conv2d(device: torch.device):
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def _run_depthwise_multiplier_conv2d(
|
||||
device: torch.device, export_mode: str | None = None
|
||||
):
|
||||
def _run_depthwise_multiplier_conv2d(device: torch.device):
|
||||
"""Depthwise Conv2d with multiplier > 1 should preserve both output channels per input channel."""
|
||||
model: torch.nn.Module = DepthwiseMultiplierConv2dModel().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(2, 8, 9, 9, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
@@ -2629,10 +2318,6 @@ def test_depthwise_multiplier_conv2d(device: torch.device):
|
||||
_run_depthwise_multiplier_conv2d(device)
|
||||
|
||||
|
||||
def test_depthwise_multiplier_conv2d_pt2(device: torch.device):
|
||||
_run_depthwise_multiplier_conv2d(device, "pt2")
|
||||
|
||||
|
||||
def test_grouped_conv2d(device: torch.device):
|
||||
"""Conv2d with groups=4 (grouped, not depthwise)."""
|
||||
model: torch.nn.Module = GroupedConv2dModel().to(device)
|
||||
@@ -2643,12 +2328,10 @@ def test_grouped_conv2d(device: torch.device):
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def _run_grouped_conv2d_groups3_batch4(
|
||||
device: torch.device, export_mode: str | None = None
|
||||
):
|
||||
def _run_grouped_conv2d_groups3_batch4(device: torch.device):
|
||||
"""Grouped Conv2d with groups=3 and batch>1 exercises the pre-pad + slice path."""
|
||||
model: torch.nn.Module = GroupedConv2dGroups3Model().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(4, 12, 11, 9, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
@@ -2659,10 +2342,6 @@ def test_grouped_conv2d_groups3_batch4(device: torch.device):
|
||||
_run_grouped_conv2d_groups3_batch4(device)
|
||||
|
||||
|
||||
def test_grouped_conv2d_groups3_batch4_pt2(device: torch.device):
|
||||
_run_grouped_conv2d_groups3_batch4(device, "pt2")
|
||||
|
||||
|
||||
def test_mamba_conv_block(device: torch.device):
|
||||
"""Minimal Mamba-style block: depthwise Conv1d with causal gating (end-to-end)."""
|
||||
model: torch.nn.Module = MambaConvBlockModel().to(device)
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
"""KV Cache decode loop test.
|
||||
|
||||
Compiles a tiny 1-layer Llama model with use_cache=True, then:
|
||||
1. Prefill: model(input_ids) -> logits + K/V cache
|
||||
2. Decode: model(next_token, past_key_values=cache) -> logits + updated K/V
|
||||
|
||||
Verifies correctness of both steps and writes DOT graphs for comparison.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
def _capturing_backend(captured):
|
||||
"""Wrap luminal_backend to capture CompiledModels for DOT extraction."""
|
||||
|
||||
def backend(gm, example_inputs):
|
||||
compiled = luminal_backend(gm, example_inputs)
|
||||
captured.append(compiled)
|
||||
return compiled
|
||||
|
||||
return backend
|
||||
|
||||
|
||||
def test_kv_cache_decode_loop():
|
||||
"""Full prefill -> decode loop through luminal with KV cache."""
|
||||
from transformers import LlamaConfig, LlamaForCausalLM
|
||||
|
||||
# Allow both prefill and decode compilations (conftest sets limit=1)
|
||||
torch._dynamo.config.cache_size_limit = 2
|
||||
|
||||
config = LlamaConfig(
|
||||
hidden_size=64,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
num_hidden_layers=1,
|
||||
intermediate_size=128,
|
||||
vocab_size=256,
|
||||
max_position_embeddings=128,
|
||||
use_cache=True,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
model = LlamaForCausalLM(config).eval()
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]])
|
||||
|
||||
captured = []
|
||||
compiled = torch.compile(model, backend=_capturing_backend(captured))
|
||||
|
||||
# --- Prefill step ---
|
||||
with torch.no_grad():
|
||||
ref_prefill = model(input_ids)
|
||||
out_prefill = compiled(input_ids)
|
||||
|
||||
assert torch.allclose(out_prefill.logits, ref_prefill.logits, atol=1e-5)
|
||||
assert out_prefill.past_key_values is not None, "Prefill should return KV cache"
|
||||
|
||||
# --- Decode step ---
|
||||
next_token = ref_prefill.logits[0, -1, :].argmax().unsqueeze(0).unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
ref_decode = model(next_token, past_key_values=ref_prefill.past_key_values)
|
||||
out_decode = compiled(next_token, past_key_values=out_prefill.past_key_values)
|
||||
|
||||
assert torch.allclose(out_decode.logits, ref_decode.logits, atol=1e-5)
|
||||
|
||||
# --- DOT graph comparison ---
|
||||
# captured[0] = prefill graph, captured[1] = decode graph (recompiled by dynamo)
|
||||
assert len(captured) >= 2, (
|
||||
f"Expected 2 compilations (prefill+decode), got {len(captured)}"
|
||||
)
|
||||
|
||||
out_dir = "/tmp/luminal_kv_cache_comparison"
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
prefill_dot = captured[0]._graph.to_dot()
|
||||
decode_dot = captured[1]._graph.to_dot()
|
||||
|
||||
with open(os.path.join(out_dir, "prefill.dot"), "w") as f:
|
||||
f.write(prefill_dot)
|
||||
with open(os.path.join(out_dir, "decode.dot"), "w") as f:
|
||||
f.write(decode_dot)
|
||||
|
||||
print(f"\n=== DOT files written to {out_dir} ===")
|
||||
print(f"Prefill: {len(prefill_dot)} chars, inputs: {captured[0]._input_names}")
|
||||
print(f"Decode: {len(decode_dot)} chars, inputs: {captured[1]._input_names}")
|
||||
|
||||
# Decode graph should have more inputs (past K/V cache tensors)
|
||||
assert len(captured[1]._input_names) > len(captured[0]._input_names), (
|
||||
f"Decode should have more inputs than prefill: "
|
||||
f"{len(captured[1]._input_names)} vs {len(captured[0]._input_names)}"
|
||||
)
|
||||
@@ -1,195 +0,0 @@
|
||||
"""KV Cache growing decode loop test.
|
||||
|
||||
Compiles a tiny 1-layer Llama model with use_cache=True, then runs a
|
||||
multi-step autoregressive decode loop:
|
||||
|
||||
1. Prefill: model(input_ids) -> logits + initial KV cache
|
||||
2. Decode x N: model(next_token, past_key_values=cache) -> logits + grown KV cache
|
||||
|
||||
At each step, prints the KV cache tensor shapes so you can see the
|
||||
sequence dimension grow: (1, n_kv_heads, 4, head_dim) -> (1, n_kv_heads, 5, ...) -> ...
|
||||
|
||||
Verifies luminal output matches PyTorch reference at every step.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
NUM_DECODE_STEPS = 5
|
||||
|
||||
|
||||
def test_kv_cache_growing():
|
||||
"""Multi-step prefill + decode loop showing KV cache growth."""
|
||||
from transformers import LlamaConfig, LlamaForCausalLM
|
||||
|
||||
# We need 1 compilation for prefill + 1 per unique decode cache size
|
||||
torch._dynamo.config.cache_size_limit = NUM_DECODE_STEPS + 2
|
||||
# Disable automatic dynamic shapes — dynamo would otherwise try to use SymInt
|
||||
# for the varying cache seq_len dimension, which torch.export doesn't support.
|
||||
# Instead, we want a fresh recompilation for each new cache size.
|
||||
torch._dynamo.config.automatic_dynamic_shapes = False
|
||||
|
||||
config = LlamaConfig(
|
||||
hidden_size=64,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
num_hidden_layers=4,
|
||||
intermediate_size=128,
|
||||
vocab_size=256,
|
||||
max_position_embeddings=128,
|
||||
use_cache=True,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
model = LlamaForCausalLM(config).eval()
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]])
|
||||
|
||||
# ---- Prefill ----
|
||||
with torch.no_grad():
|
||||
ref_out = model(input_ids)
|
||||
lum_out = compiled(input_ids)
|
||||
|
||||
assert ref_out.past_key_values is not None, "Reference should return KV cache"
|
||||
assert lum_out.past_key_values is not None, "Luminal should return KV cache"
|
||||
|
||||
assert torch.allclose(lum_out.logits, ref_out.logits, atol=1e-5), (
|
||||
f"Prefill mismatch: max_diff="
|
||||
f"{torch.max(torch.abs(lum_out.logits - ref_out.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
_print_cache_shapes("Prefill", ref_out.past_key_values, lum_out.past_key_values)
|
||||
|
||||
ref_cache = ref_out.past_key_values
|
||||
lum_cache = lum_out.past_key_values
|
||||
|
||||
# ---- Decode loop ----
|
||||
for step in range(NUM_DECODE_STEPS):
|
||||
# Greedy next token from reference logits
|
||||
next_token = ref_out.logits[0, -1, :].argmax().unsqueeze(0).unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
ref_out = model(next_token, past_key_values=ref_cache)
|
||||
lum_out = compiled(next_token, past_key_values=lum_cache)
|
||||
|
||||
assert torch.allclose(lum_out.logits, ref_out.logits, atol=1e-5), (
|
||||
f"Decode step {step} mismatch: max_diff="
|
||||
f"{torch.max(torch.abs(lum_out.logits - ref_out.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
ref_cache = ref_out.past_key_values
|
||||
lum_cache = lum_out.past_key_values
|
||||
|
||||
_print_cache_shapes(f"Decode step {step}", ref_cache, lum_cache)
|
||||
|
||||
# Final sanity check: cache seq_len should equal prompt + decode steps
|
||||
expected_seq = input_ids.shape[1] + NUM_DECODE_STEPS
|
||||
final_k = ref_cache.layers[0].keys
|
||||
assert final_k.shape[2] == expected_seq, (
|
||||
f"Expected cache seq_len={expected_seq}, got {final_k.shape[2]}"
|
||||
)
|
||||
print(
|
||||
f"\nAll {NUM_DECODE_STEPS} decode steps passed. "
|
||||
f"Cache grew from seq_len={input_ids.shape[1]} to {expected_seq}."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="R1 full-width 1-layer is too memory-heavy for CPU native backend",
|
||||
)
|
||||
@pytest.mark.slow
|
||||
def test_kv_cache_growing_r1_mla(device: torch.device):
|
||||
"""Growing-cache decode loop on DeepSeek-R1 (MLA + decoupled RoPE), 1 layer.
|
||||
|
||||
Exercises MLA: q_lora / kv_lora low-rank projections, decoupled RoPE split
|
||||
(qk_nope_head_dim + qk_rope_head_dim), and DynamicCache crossing the compile
|
||||
boundary through the MLA update path (`cache_utils.py:102-121`).
|
||||
|
||||
Runs in fp32 — in bf16, MLA's empty-tensor-cat inside DynamicLayer.update
|
||||
has a precision drift on the compiled path (logits ~3.7 on 1 layer) that
|
||||
does not affect standard GQA (Llama in bf16 is bit-identical). Investigate
|
||||
separately.
|
||||
"""
|
||||
from transformers import AutoConfig, DeepseekV3ForCausalLM
|
||||
|
||||
torch._dynamo.config.cache_size_limit = NUM_DECODE_STEPS + 2
|
||||
torch._dynamo.config.automatic_dynamic_shapes = False
|
||||
|
||||
# Release any memory accumulated by previous tests in the same pytest
|
||||
# process — full-width R1 instantiation needs ~3 GB and the test runner's
|
||||
# GPU is shared with ~230 prior tests' allocations.
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-R1")
|
||||
config.num_hidden_layers = 1
|
||||
# first_k_dense_replace=3 (default) makes the 1 layer dense, so we avoid
|
||||
# the 256-expert MoE path and the associated memory pressure.
|
||||
config._attn_implementation = "eager"
|
||||
config.torch_dtype = torch.float32
|
||||
# Aggressively shrink the embedding / LM head / FFN dimensions while
|
||||
# preserving the MLA-specific knobs that the test is actually exercising
|
||||
# (q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim).
|
||||
# Full R1 has vocab=129280, intermediate=18432, hidden=7168 — at fp32 the
|
||||
# embedding + LM head alone is ~3.5 GB, which OOMs the 40 GB test runner
|
||||
# after prior tests' allocations. The MLA path is unchanged at vocab=256.
|
||||
config.vocab_size = 256
|
||||
config.intermediate_size = 512
|
||||
config.max_position_embeddings = 128
|
||||
model = DeepseekV3ForCausalLM(config).eval().to(dtype=torch.float32, device=device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]], device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
ref_out = model(input_ids)
|
||||
lum_out = compiled(input_ids)
|
||||
|
||||
# fp32 MLA matches to ~1e-5 — see diagnose_dtype.py. Keep the tolerance
|
||||
# tight here so regressions in the MLA cat/split path show up immediately.
|
||||
assert torch.allclose(lum_out.logits, ref_out.logits, atol=1e-4), (
|
||||
f"Prefill: max_diff={torch.max(torch.abs(lum_out.logits - ref_out.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
ref_cache = ref_out.past_key_values
|
||||
lum_cache = lum_out.past_key_values
|
||||
|
||||
# Run a single decode step — enough to confirm the cache flows through as an
|
||||
# explicit input on the second compile (the key signal from
|
||||
# _test_kv_cache_comparison.py's "decode has more inputs than prefill"
|
||||
# assertion). Full 5-step growth is covered by the Llama test above.
|
||||
next_token = ref_out.logits[0, -1, :].argmax().view(1, 1).to(device)
|
||||
with torch.no_grad():
|
||||
ref_dec = model(next_token, past_key_values=ref_cache)
|
||||
lum_dec = compiled(next_token, past_key_values=lum_cache)
|
||||
|
||||
assert torch.allclose(lum_dec.logits, ref_dec.logits, atol=1e-4), (
|
||||
f"Decode: max_diff={torch.max(torch.abs(lum_dec.logits - ref_dec.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
def _print_cache_shapes(label, ref_cache, lum_cache):
|
||||
"""Print KV cache shapes for both reference and luminal."""
|
||||
print(f"\n--- {label} ---")
|
||||
for layer_idx, ref_layer in enumerate(ref_cache.layers):
|
||||
ref_k, ref_v = ref_layer.keys, ref_layer.values
|
||||
lum_layer = lum_cache.layers[layer_idx]
|
||||
lum_k, lum_v = lum_layer.keys, lum_layer.values
|
||||
print(
|
||||
f" Layer {layer_idx}: "
|
||||
f"K ref={list(ref_k.shape)} lum={list(lum_k.shape)} | "
|
||||
f"V ref={list(ref_v.shape)} lum={list(lum_v.shape)}"
|
||||
)
|
||||
# Verify cache tensors match
|
||||
assert torch.allclose(lum_k, ref_k, atol=1e-5), (
|
||||
f"{label} layer {layer_idx} K mismatch: "
|
||||
f"max_diff={torch.max(torch.abs(lum_k - ref_k)).item():.2e}"
|
||||
)
|
||||
assert torch.allclose(lum_v, ref_v, atol=1e-5), (
|
||||
f"{label} layer {layer_idx} V mismatch: "
|
||||
f"max_diff={torch.max(torch.abs(lum_v - ref_v)).item():.2e}"
|
||||
)
|
||||
@@ -158,7 +158,6 @@ def test_hf_llama_medium(device: torch.device):
|
||||
_run_hf_llama_test(config, device, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_hf_llama_large(device: torch.device):
|
||||
"""HuggingFace LlamaForCausalLM — large (1024 hidden, 1 layer, ~18M params)."""
|
||||
config = _make_llama_config(
|
||||
@@ -172,7 +171,6 @@ def test_hf_llama_large(device: torch.device):
|
||||
_run_hf_llama_test(config, device, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_hf_llama3_real_config_1layer(device: torch.device):
|
||||
"""HuggingFace LlamaForCausalLM — real Llama3.2-1B architecture, 1 layer.
|
||||
|
||||
@@ -229,7 +227,6 @@ def test_hf_llama_decode_loop_static(device: torch.device):
|
||||
tokens.append(next_token)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
|
||||
def test_hf_llama3_1b_decode_loop_dynamic(device: torch.device):
|
||||
"""Decode loop on real Llama3.2-1B with pretrained weights.
|
||||
@@ -285,7 +282,6 @@ def _gpu_mem(label):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
|
||||
def test_hf_llama3_full(device: torch.device):
|
||||
"""HuggingFace LlamaForCausalLM — full Llama3.2-1B with real pretrained weights.
|
||||
@@ -337,7 +333,6 @@ def test_hf_llama3_full(device: torch.device):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
|
||||
def test_hf_llama3_large_full(device: torch.device):
|
||||
"""HuggingFace LlamaForCausalLM — full Llama-3.1-8B-Instruct with real pretrained weights.
|
||||
@@ -419,7 +414,6 @@ def test_dynamic_dim_reuse_no_recompile(device: torch.device):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
|
||||
def test_hf_llama38b_full(device: torch.device):
|
||||
"""HuggingFace LlamaForCausalLM — full Llama-3.1-8B-Instruct with real pretrained weights.
|
||||
|
||||
@@ -1623,32 +1623,16 @@ class SplitTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class ArgsortStableDuplicatesModel(torch.nn.Module):
|
||||
"""Tests deterministic duplicate ordering for exported argsort.
|
||||
|
||||
``idx_dtype`` parameterizes the integer dtype of the returned indices so
|
||||
the test can verify dtype preservation across luminal's int dtype paths
|
||||
(LUM-486). PyTorch's argsort always produces int64; the cast at the end
|
||||
lets us drive the same model toward int32 or int64 outputs.
|
||||
"""
|
||||
"""Tests deterministic duplicate ordering for exported argsort."""
|
||||
|
||||
SORT_DIM = 1
|
||||
|
||||
def __init__(self, idx_dtype: torch.dtype = torch.int64) -> None:
|
||||
super().__init__()
|
||||
self.idx_dtype = idx_dtype
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.argsort(x, dim=self.SORT_DIM).to(self.idx_dtype)
|
||||
return torch.argsort(x, dim=self.SORT_DIM)
|
||||
|
||||
|
||||
class TinyMoERoutingModel(torch.nn.Module):
|
||||
"""Minimal deterministic MoE-style routing proof for PT2/native and CUDA.
|
||||
|
||||
``idx_dtype`` casts the integer-valued outputs (routed_indices, dispatch,
|
||||
group_ids) to the requested dtype so the test can sweep int32 and int64
|
||||
output paths (LUM-486). Internal indices stay int64 because torch.gather
|
||||
/ torch.scatter require int64 index tensors.
|
||||
"""
|
||||
"""Minimal deterministic MoE-style routing proof for PT2/native and CUDA."""
|
||||
|
||||
TOP_K = 2
|
||||
ROUTING_DIM = -1
|
||||
@@ -1656,9 +1640,8 @@ class TinyMoERoutingModel(torch.nn.Module):
|
||||
DISPATCH_ON = 1
|
||||
GROUP_SIZE = 2
|
||||
|
||||
def __init__(self, idx_dtype: torch.dtype = torch.int64) -> None:
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.idx_dtype = idx_dtype
|
||||
self.register_buffer(
|
||||
"expert_scale",
|
||||
torch.tensor([1.5, -0.5, 2.0, 0.25], dtype=torch.float32),
|
||||
@@ -1694,11 +1677,11 @@ class TinyMoERoutingModel(torch.nn.Module):
|
||||
group_ids = torch.floor_divide(routed_indices, self.GROUP_SIZE)
|
||||
routing_sign = torch.sign(masked_values)
|
||||
return (
|
||||
routed_indices.to(self.idx_dtype),
|
||||
routed_indices,
|
||||
masked_values,
|
||||
dispatch.to(self.idx_dtype),
|
||||
dispatch,
|
||||
inactive_mask,
|
||||
group_ids.to(self.idx_dtype),
|
||||
group_ids,
|
||||
routing_sign,
|
||||
)
|
||||
|
||||
@@ -1969,24 +1952,6 @@ class Conv1dBiasModel(torch.nn.Module):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv1dFloorDivPositionalModel(torch.nn.Module):
|
||||
"""Whisper-like Conv1d downsample followed by a fixed positional add."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv1d(8, 16, kernel_size=3, padding=1, bias=True)
|
||||
self.conv2 = torch.nn.Conv1d(
|
||||
16, 16, kernel_size=3, stride=2, padding=1, bias=True
|
||||
)
|
||||
self.position = torch.nn.Parameter(torch.randn(15, 16))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = torch.nn.functional.gelu(self.conv1(x))
|
||||
x = torch.nn.functional.gelu(self.conv2(x))
|
||||
x = x.squeeze(0).transpose(0, 1)
|
||||
return x + self.position
|
||||
|
||||
|
||||
class Conv2dNoPadModel(torch.nn.Module):
|
||||
"""Conv2d with no padding: output spatial dims shrink by (kernel-1)."""
|
||||
|
||||
@@ -2236,127 +2201,3 @@ class MambaConvBlockModel(torch.nn.Module):
|
||||
return self.out_proj(
|
||||
torch.nn.functional.silu(x_part) * torch.nn.functional.silu(z)
|
||||
)
|
||||
|
||||
|
||||
class BitwiseOrTestModel(torch.nn.Module):
|
||||
"""Tests bitwise_or on boolean tensors — the pattern Gemma-style models
|
||||
emit when fusing sliding-window and full-attention masks
|
||||
(`mask = sliding_mask | full_mask`)."""
|
||||
|
||||
def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||
return a | b
|
||||
|
||||
|
||||
class GroupedMMFallbackTestModel(torch.nn.Module):
|
||||
"""Tests transformers::grouped_mm_fallback — the per-expert batched
|
||||
matmul HF MoE models emit (DeepSeek-V2, Qwen-MoE, Mixtral, etc.).
|
||||
|
||||
Calls the registered custom_op directly with shapes that match a
|
||||
realistic MoE expert dispatch: input is `(S, K)` of tokens already
|
||||
sorted by expert, weight is `(G, K, N)` per-expert weights, offs is
|
||||
`(G,)` cumulative token counts.
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor, weight: torch.Tensor, offs: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.transformers.grouped_mm_fallback(input, weight, offs)
|
||||
|
||||
|
||||
class BoolMaskAssignIntModel(torch.nn.Module):
|
||||
"""`x[mask] = scalar` on integer data with a Bool-dtype mask whose shape
|
||||
matches `x`.
|
||||
|
||||
PyTorch decomposes this to `aten.index_put_(x, [mask], scalar)`. The
|
||||
correct lowering is `where(mask, scalar, x)` — NOT a scatter into Int(mask)
|
||||
positions. Pre-fix, the compiled output silently corrupted row 0 of `x`
|
||||
even when the mask was all-False (the silent-data-corruption case driven
|
||||
by Gemma-4's multimodal_mask path).
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
out = x.clone()
|
||||
out[mask] = 99
|
||||
return out
|
||||
|
||||
|
||||
class BoolMaskAssignFloatModel(torch.nn.Module):
|
||||
"""Same as BoolMaskAssignIntModel but with float data + a float scalar.
|
||||
|
||||
Verifies the `where` blend works for non-integer dtypes too.
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
out = x.clone()
|
||||
out[mask] = 7.5
|
||||
return out
|
||||
|
||||
|
||||
class BoolMaskAssign3DModel(torch.nn.Module):
|
||||
"""Multi-dimensional `x[mask] = scalar` — Bool mask shape must match `x`'s
|
||||
full shape, not just be 1D. Catches regressions where the bool-mask
|
||||
detection only works at one specific rank.
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
out = x.clone()
|
||||
out[mask] = -1.0
|
||||
return out
|
||||
|
||||
|
||||
class IntIndexAssignScalarModel(torch.nn.Module):
|
||||
"""`x[indices] = scalar_tensor` with a rank-1 index tensor and a 0-D
|
||||
scalar value. After PT2 decomposition this hits the scatter path with a
|
||||
scalar src; the lowering must broadcast the scalar across all indexed
|
||||
positions (zero-stride padding in `GraphTensor::scatter`).
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
|
||||
out = x.clone()
|
||||
out[indices] = 42.0
|
||||
return out
|
||||
|
||||
|
||||
class SdpaBasicModel(torch.nn.Module):
|
||||
"""`F.scaled_dot_product_attention(q, k, v)` with no mask, no causal flag.
|
||||
|
||||
Lowers to `aten._scaled_dot_product_*_attention` (variant chosen by
|
||||
PyTorch based on device/dtype). Tests the default-scale matmul+softmax
|
||||
path. Inputs are 4-D `(B, H, S, D)`.
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
|
||||
class SdpaCausalModel(torch.nn.Module):
|
||||
"""`F.scaled_dot_product_attention(q, k, v, is_causal=True)`.
|
||||
|
||||
Tests the `is_causal` branch of `translate_sdpa`, which materializes a
|
||||
triangular mask and adds `-1e9 * mask` to the pre-softmax scores.
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
|
||||
|
||||
|
||||
class SdpaWithBiasModel(torch.nn.Module):
|
||||
"""SDPA with an additive `attn_mask` bias (float, broadcast over heads).
|
||||
|
||||
Tests the additive-bias branch of `translate_sdpa`. The bias has shape
|
||||
`(1, 1, S_q, S_k)` so it broadcasts across batch/head prefix dims of
|
||||
the scores tensor.
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=bias)
|
||||
|
||||
@@ -1,275 +0,0 @@
|
||||
"""Qwen3-MoE HuggingFace model integration tests.
|
||||
|
||||
Tests progressively larger HuggingFace `Qwen3MoeForCausalLM` configs through
|
||||
the PyTorch -> PT2 -> luminal pipeline via `torch.compile(..., backend=
|
||||
luminal_backend)`. Qwen3-MoE shares the dense Qwen3 backbone but replaces
|
||||
the FFN with a top-k router over `num_experts` independent expert MLPs —
|
||||
which exercises code paths the dense tests don't:
|
||||
|
||||
- `aten._grouped_mm.default` (gather-then-matmul lowering, PR #298)
|
||||
- bf16 `KernelScatter` (KV cache scatter on a non-F32 dtype)
|
||||
- `aten.empty_permuted` / `aten.histc` (MoE expert dispatch and
|
||||
tokens-per-expert counts)
|
||||
- clamp-on-Int dtype handling (router top-k indices flowing into
|
||||
`aten.clamp`)
|
||||
|
||||
The smaller configs run on GPU in seconds; the "real config" case loads
|
||||
the actual `Qwen/Qwen3-30B-A3B` arch (128 experts, top-8) with
|
||||
`num_hidden_layers` overridden to 1 so a full-width compile is
|
||||
exercised on random weights.
|
||||
|
||||
Together these guard the regression-and-fix story that landed alongside:
|
||||
the bf16 KernelScatter dtype-aware vec count, the `aten.empty(_permuted)`
|
||||
/ `aten.histc` translator entries, and the
|
||||
`maximum_f32`-on-Int casting fix.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
# ────────────────────────────────────────────────────────────────────────
|
||||
# Helpers
|
||||
# ────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_qwen3_moe_config(
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
num_hidden_layers: int,
|
||||
intermediate_size: int,
|
||||
moe_intermediate_size: int,
|
||||
num_experts: int,
|
||||
num_experts_per_tok: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
"""Create a Qwen3MoeConfig with use_cache=False and eager attention.
|
||||
|
||||
Shared helper so each test only specifies the scaling knobs that matter
|
||||
for that case.
|
||||
"""
|
||||
from transformers import Qwen3MoeConfig
|
||||
|
||||
return Qwen3MoeConfig(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
intermediate_size=intermediate_size,
|
||||
moe_intermediate_size=moe_intermediate_size,
|
||||
num_experts=num_experts,
|
||||
num_experts_per_tok=num_experts_per_tok,
|
||||
vocab_size=vocab_size,
|
||||
max_position_embeddings=128,
|
||||
use_cache=False,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
|
||||
|
||||
def _run_hf_qwen3_moe_test(config, device: torch.device, atol: float):
|
||||
"""Run a HuggingFace Qwen3MoeForCausalLM test with the given config.
|
||||
|
||||
Compiles the model with `luminal_backend`, runs both eager and compiled
|
||||
on the same input, asserts the logits match within `atol`.
|
||||
"""
|
||||
from transformers import Qwen3MoeForCausalLM
|
||||
|
||||
model = Qwen3MoeForCausalLM(config).eval().to(device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]], device=device)
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
out = compiled(input_ids)
|
||||
assert torch.allclose(out.logits, ref.logits, atol=atol), (
|
||||
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
# ────────────────────────────────────────────────────────────────────────
|
||||
# Tests — progressively larger configs
|
||||
# ────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_hf_qwen3_moe_tiny(device: torch.device):
|
||||
"""HuggingFace Qwen3MoeForCausalLM — tiny: 2 experts, top-1 routing.
|
||||
|
||||
Smallest config that still exercises the MoE expert dispatch
|
||||
(`aten._grouped_mm`). Top-1 routing keeps the test simple while still
|
||||
validating the gather-then-matmul lowering path.
|
||||
"""
|
||||
config = _make_qwen3_moe_config(
|
||||
hidden_size=32,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=1,
|
||||
num_hidden_layers=1,
|
||||
intermediate_size=64,
|
||||
moe_intermediate_size=64,
|
||||
num_experts=2,
|
||||
num_experts_per_tok=1,
|
||||
vocab_size=128,
|
||||
)
|
||||
_run_hf_qwen3_moe_test(config, device, atol=1e-5)
|
||||
|
||||
|
||||
def test_hf_qwen3_moe_small(device: torch.device):
|
||||
"""HuggingFace Qwen3MoeForCausalLM — small: 4 experts, top-2 routing."""
|
||||
config = _make_qwen3_moe_config(
|
||||
hidden_size=128,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
num_hidden_layers=1,
|
||||
intermediate_size=256,
|
||||
moe_intermediate_size=128,
|
||||
num_experts=4,
|
||||
num_experts_per_tok=2,
|
||||
vocab_size=512,
|
||||
)
|
||||
_run_hf_qwen3_moe_test(config, device, atol=1e-4)
|
||||
|
||||
|
||||
def test_hf_qwen3_moe_medium(device: torch.device):
|
||||
"""HuggingFace Qwen3MoeForCausalLM — medium: 8 experts, top-2, 2 layers.
|
||||
|
||||
Two layers means the e-graph crosses a layer boundary, which is where
|
||||
the late-memory-analysis cleanup pass operates differently than
|
||||
single-layer cases.
|
||||
"""
|
||||
config = _make_qwen3_moe_config(
|
||||
hidden_size=128,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
num_hidden_layers=2,
|
||||
intermediate_size=256,
|
||||
moe_intermediate_size=128,
|
||||
num_experts=8,
|
||||
num_experts_per_tok=2,
|
||||
vocab_size=512,
|
||||
)
|
||||
_run_hf_qwen3_moe_test(config, device, atol=1e-4)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_hf_qwen3_moe_real_config_1layer(device: torch.device):
|
||||
"""HuggingFace Qwen3MoeForCausalLM — real Qwen3-30B-A3B architecture, 1 layer.
|
||||
|
||||
Loads `Qwen/Qwen3-30B-A3B`'s AutoConfig (128 experts, top-8 routing,
|
||||
2048 hidden) and overrides `num_hidden_layers=1`. Random weights —
|
||||
cheap smoke that the production-shape MoE *layer* compiles end-to-end
|
||||
through luminal_backend without paying the full 48-layer cost.
|
||||
"""
|
||||
from transformers import AutoConfig, Qwen3MoeForCausalLM
|
||||
|
||||
config = AutoConfig.from_pretrained("Qwen/Qwen3-30B-A3B")
|
||||
config.num_hidden_layers = 1
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
model = Qwen3MoeForCausalLM(config).eval().to(device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]], device=device)
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
out = compiled(input_ids)
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-3), (
|
||||
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_hf_qwen3_moe_real_config_full(device: torch.device):
|
||||
"""HuggingFace Qwen3MoeForCausalLM — full Qwen3-30B-A3B, pretrained.
|
||||
|
||||
Loads the real `Qwen/Qwen3-30B-A3B` checkpoint at its native bf16
|
||||
dtype: 48 hidden layers, 128 experts, top-8 routing, 2048 hidden —
|
||||
i.e. the production architecture, no `num_hidden_layers` override.
|
||||
This is the end-to-end "the full MoE compiles" regression guard;
|
||||
the 1-layer variant above is the cheap smoke.
|
||||
|
||||
Asserts the **compile + run** path completes and the compiled
|
||||
forward produces *finite* output (no NaN / no Inf). It does NOT
|
||||
assert tight numerical equivalence with eager: at this depth the
|
||||
egglog search is non-deterministic enough that the two paths can
|
||||
diverge structurally (same general magnitudes, different per-element
|
||||
values). Tight numerical equivalence at full scale is tracked as
|
||||
follow-up work — the smaller-config tests above use atol≤1e-3 and
|
||||
cover the per-op correctness that this test cannot.
|
||||
|
||||
Compared to the 1-layer test this primarily catches:
|
||||
- egglog cleanup behaviour over a 48-layer-wide e-graph (the
|
||||
`egglog_utils.rs:1286: No valid graphs` panic surfaces here
|
||||
if the cleanup cascade re-regresses on MoE root-eclasses);
|
||||
- per-layer plumbing of residual stream + KV state that
|
||||
single-layer tests don't exercise;
|
||||
- any bf16-specific code path (e.g. KernelScatter OOB) that's
|
||||
masked at fp32.
|
||||
|
||||
Memory profile on H200/H100:
|
||||
- bf16 pretrained weights: ~60 GB
|
||||
- single-token input keeps activations & router state trivial
|
||||
- peak observed during compiled forward: ~75 GB total
|
||||
"""
|
||||
import gc
|
||||
|
||||
from transformers import AutoConfig, Qwen3MoeForCausalLM
|
||||
|
||||
# Aggressively release any allocator state from prior tests in the
|
||||
# same process — at this scale we don't have headroom to absorb it.
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
config = AutoConfig.from_pretrained("Qwen/Qwen3-30B-A3B")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
model = (
|
||||
Qwen3MoeForCausalLM.from_pretrained(
|
||||
"Qwen/Qwen3-30B-A3B",
|
||||
config=config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
# Single-token input — the full-depth compile is the regression target,
|
||||
# not multi-token throughput (which the bench covers separately).
|
||||
input_ids = torch.tensor([[1]], device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
# Eager forward — confirms the test setup is sane (HF is happy).
|
||||
ref = model(input_ids)
|
||||
ref_max = ref.logits.float().abs().max().item()
|
||||
assert torch.isfinite(ref.logits).all(), (
|
||||
"eager forward produced non-finite logits — test setup is broken, "
|
||||
"not a luminal regression"
|
||||
)
|
||||
del ref
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Compiled forward — the actual regression target.
|
||||
out = compiled(input_ids)
|
||||
|
||||
out_logits = out.logits.float()
|
||||
n_nan = int(out_logits.isnan().sum().item())
|
||||
n_inf = int(out_logits.isinf().sum().item())
|
||||
out_max = out_logits.abs().max().item()
|
||||
|
||||
assert n_nan == 0 and n_inf == 0, (
|
||||
f"compiled forward produced non-finite logits: {n_nan} NaNs, "
|
||||
f"{n_inf} Infs (eager max abs={ref_max:.2f}, compiled max abs={out_max:.2f})"
|
||||
)
|
||||
# Sanity-check magnitude: compiled output should be in the same ballpark
|
||||
# as eager — within an order of magnitude of the eager logits' scale.
|
||||
# Catches the failure mode where some kernel silently produces
|
||||
# near-zero or near-Inf values that pass the finite check.
|
||||
assert 0.1 * ref_max <= out_max <= 10.0 * ref_max, (
|
||||
f"compiled max abs={out_max:.2f} is out of band vs eager max abs={ref_max:.2f} "
|
||||
f"(>10× off in either direction); likely a numerical/scale bug"
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,174 +0,0 @@
|
||||
"""Whisper integration tests for the luminal torch.compile backend.
|
||||
|
||||
These tests build a PyTorch port of ``openai/whisper-tiny.en`` (the same one
|
||||
exercised by ``examples/whisper.py``) and verify that running it through
|
||||
``torch.compile(..., backend=luminal_backend)`` produces logits that match the
|
||||
eager-mode PyTorch reference, both with random-init small configs and with the
|
||||
real pretrained tiny.en weights.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
|
||||
# Reuse the PyTorch port defined in the example script so we test exactly the
|
||||
# code that runs the demo.
|
||||
EXAMPLES_DIR = Path(__file__).resolve().parent.parent / "examples"
|
||||
sys.path.insert(0, str(EXAMPLES_DIR))
|
||||
import whisper as whisper_demo # noqa: E402 (path-modified import)
|
||||
|
||||
from luminal import luminal_backend # noqa: E402
|
||||
|
||||
|
||||
def _make_small_whisper(seed: int = 0) -> whisper_demo.Whisper:
|
||||
torch.manual_seed(seed)
|
||||
model = whisper_demo.Whisper().eval()
|
||||
return model
|
||||
|
||||
|
||||
def _max_diff(a: torch.Tensor, b: torch.Tensor) -> float:
|
||||
return torch.max(torch.abs(a - b)).item()
|
||||
|
||||
|
||||
def test_whisper_attention_forward(device: torch.device):
|
||||
"""Whisper self-attention: Q/K/V/out projections + scaled dot-product."""
|
||||
torch.manual_seed(0)
|
||||
attn = whisper_demo.WhisperAttention().eval().to(device)
|
||||
compiled: Callable = torch.compile(attn, backend=luminal_backend)
|
||||
x = torch.rand((4, whisper_demo.D_MODEL), device=device)
|
||||
with torch.no_grad():
|
||||
ref = attn(x)
|
||||
out = compiled(x)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
assert torch.allclose(out, ref, atol=1e-4), f"max_diff={_max_diff(out, ref):.2e}"
|
||||
|
||||
|
||||
def test_whisper_encoder_layer(device: torch.device):
|
||||
"""Single encoder block: pre-norm self-attention + FFN with GELU.
|
||||
|
||||
Tolerance is loose because luminal uses the tanh GELU approximation rather
|
||||
than the exact erf form PyTorch uses for ``aten.gelu.default``.
|
||||
"""
|
||||
torch.manual_seed(0)
|
||||
layer = whisper_demo.EncoderLayer().eval().to(device)
|
||||
compiled: Callable = torch.compile(layer, backend=luminal_backend)
|
||||
x = torch.rand((8, whisper_demo.D_MODEL), device=device)
|
||||
with torch.no_grad():
|
||||
ref = layer(x)
|
||||
out = compiled(x)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
assert torch.allclose(out, ref, atol=1e-3), f"max_diff={_max_diff(out, ref):.2e}"
|
||||
|
||||
|
||||
def test_whisper_decoder_layer(device: torch.device):
|
||||
"""Single decoder block: causal self-attention + cross-attention + FFN."""
|
||||
torch.manual_seed(0)
|
||||
layer = whisper_demo.DecoderLayer().eval().to(device)
|
||||
compiled: Callable = torch.compile(layer, backend=luminal_backend)
|
||||
x = torch.rand((4, whisper_demo.D_MODEL), device=device)
|
||||
xa = torch.rand((16, whisper_demo.D_MODEL), device=device)
|
||||
with torch.no_grad():
|
||||
ref = layer(x, xa)
|
||||
out = compiled(x, xa)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
assert torch.allclose(out, ref, atol=1e-3), f"max_diff={_max_diff(out, ref):.2e}"
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_whisper_encoder_random_init(device: torch.device):
|
||||
"""Full encoder over a random mel: 2 conv stems + 4 transformer blocks."""
|
||||
model = _make_small_whisper().to(device)
|
||||
compiled: Callable = torch.compile(model.encoder, backend=luminal_backend)
|
||||
mel = torch.rand((whisper_demo.N_MELS, 3000), device=device)
|
||||
with torch.no_grad():
|
||||
ref = model.encoder(mel)
|
||||
out = compiled(mel)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
assert torch.allclose(out, ref, atol=1e-3), f"max_diff={_max_diff(out, ref):.2e}"
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_whisper_full_random_init_one_step(device: torch.device):
|
||||
"""End-to-end Whisper forward (encoder + decoder for one step) with random weights.
|
||||
|
||||
Tolerance is loose because errors accumulate across the conv stems plus the
|
||||
8 transformer blocks, and luminal uses the tanh GELU approximation rather
|
||||
than the exact erf form that PyTorch ``aten.gelu.default`` evaluates.
|
||||
"""
|
||||
model = _make_small_whisper().to(device)
|
||||
compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
mel = torch.rand((whisper_demo.N_MELS, 3000), device=device)
|
||||
tokens = torch.tensor(
|
||||
[whisper_demo.TOKEN_SOT, whisper_demo.TOKEN_NO_TIMESTAMPS],
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
with torch.no_grad():
|
||||
ref = model(mel, tokens)
|
||||
out = compiled(mel, tokens)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
assert torch.allclose(out, ref, atol=5e-2, rtol=1e-3), (
|
||||
f"max_diff={_max_diff(out, ref):.2e}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_whisper_tiny_en_pretrained_first_token(device: torch.device):
|
||||
"""Real whisper-tiny.en weights: first generated token must match reference.
|
||||
|
||||
Uses the bundled JFK sample if available; otherwise a zero-mel placeholder
|
||||
(the assertion is purely compiled-vs-reference equality, not transcription
|
||||
correctness).
|
||||
"""
|
||||
model = whisper_demo.Whisper().eval()
|
||||
whisper_demo.load_hf_weights_into(model)
|
||||
model = model.to(device)
|
||||
|
||||
# Try to use the real audio so the comparison is on a realistic mel.
|
||||
audio_path = whisper_demo.find_default_audio()
|
||||
if audio_path is None:
|
||||
mel = torch.zeros((whisper_demo.N_MELS, 3000), device=device)
|
||||
else:
|
||||
from transformers import WhisperFeatureExtractor
|
||||
|
||||
audio = whisper_demo.load_wav_16k_mono(audio_path)
|
||||
fe = WhisperFeatureExtractor.from_pretrained(whisper_demo.REPO_ID)
|
||||
mel = (
|
||||
fe(audio, sampling_rate=16000, return_tensors="pt")
|
||||
.input_features[0]
|
||||
.to(device)
|
||||
)
|
||||
|
||||
tokens = torch.tensor(
|
||||
[whisper_demo.TOKEN_SOT, whisper_demo.TOKEN_NO_TIMESTAMPS],
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
|
||||
torch._dynamo.reset()
|
||||
compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
ref = model(mel, tokens)
|
||||
out = compiled(mel, tokens)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
# Logits diverge slightly due to the GELU approximation; what matters end
|
||||
# to end is that the greedy argmax (with whisper's special-token suppression)
|
||||
# picks the same token.
|
||||
ref_tok = whisper_demo.greedy_decode(ref[-1], suppress_first_eot=True)
|
||||
out_tok = whisper_demo.greedy_decode(out[-1], suppress_first_eot=True)
|
||||
assert ref_tok == out_tok, (
|
||||
f"first token mismatch: ref={ref_tok}, compiled={out_tok}, "
|
||||
f"logits max_diff={_max_diff(out, ref):.2e}"
|
||||
)
|
||||
@@ -1,117 +0,0 @@
|
||||
"""YOLO v11n end-to-end tests using the luminal_cuda_lite backend.
|
||||
|
||||
This module exercises the YOLO v11n building blocks (Conv + BN, C3k2, the
|
||||
SPPF/C2PSA backbone, the Detect head) and finally the full model through
|
||||
``torch.compile(..., backend=luminal_backend)``.
|
||||
|
||||
The smaller per-block tests are useful when triaging which part of the
|
||||
architecture starts diverging: incrementally building a model up is much
|
||||
easier than debugging a 100-layer mismatch in one go.
|
||||
|
||||
Marked ``slow`` because the first run downloads ~6 MB of weights and the
|
||||
luminal e-graph compile of the full model is non-trivial. Run with::
|
||||
|
||||
uv run pytest tests/test_yolo_v11.py -v -s
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
def _require_cuda(device: torch.device):
|
||||
if device.type != "cuda":
|
||||
pytest.skip("YOLO v11 examples require the CUDA backend.")
|
||||
|
||||
|
||||
def _require_ultralytics():
|
||||
try:
|
||||
from ultralytics import YOLO # noqa: F401
|
||||
except ImportError as exc: # pragma: no cover
|
||||
pytest.skip(f"ultralytics not installed: {exc}")
|
||||
|
||||
|
||||
def _yolo_model(device: torch.device, decode_only: bool = True):
|
||||
"""Load yolo11n with BN folded into Conv. Returns the eager torch model."""
|
||||
from ultralytics import YOLO
|
||||
|
||||
yolo = YOLO("yolo11n.pt")
|
||||
pt_model = yolo.model.eval()
|
||||
pt_model.fuse()
|
||||
if decode_only:
|
||||
pt_model.model[-1].export = True
|
||||
pt_model.to(device)
|
||||
return pt_model
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_yolo_v11n_first_three_layers(device: torch.device):
|
||||
"""Compile only the first three layers (Conv, Conv, C3k2) — exercises the
|
||||
chunk + bottleneck residual + concat pattern that's the trickiest piece
|
||||
of the model graph."""
|
||||
_require_cuda(device)
|
||||
_require_ultralytics()
|
||||
|
||||
pt_model = _yolo_model(device, decode_only=True)
|
||||
|
||||
class FirstThree(torch.nn.Module):
|
||||
def __init__(self, backbone):
|
||||
super().__init__()
|
||||
self.layers = torch.nn.ModuleList([backbone[i] for i in range(3)])
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
sub = FirstThree(pt_model.model).to(device).eval()
|
||||
torch.manual_seed(0)
|
||||
x = torch.rand(1, 3, 640, 640, dtype=torch.float32, device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
ref = sub(x)
|
||||
torch._dynamo.reset()
|
||||
compiled: Callable = torch.compile(sub, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
out = compiled(x)
|
||||
|
||||
max_diff = torch.max(torch.abs(out - ref)).item()
|
||||
print(f"yolo11n[:3] max_diff vs PyTorch eager: {max_diff:.4e}")
|
||||
assert torch.allclose(out, ref, atol=1e-3), (
|
||||
f"yolo11n[:3] outputs differ — max_diff={max_diff:.4e}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_yolo_v11n_end_to_end(device: torch.device):
|
||||
"""Full yolo11n forward via torch.compile. The compile may be slow on
|
||||
machines without strong egglog parallelism — see the example README for
|
||||
the standalone Rust binary alternative."""
|
||||
_require_cuda(device)
|
||||
_require_ultralytics()
|
||||
|
||||
pt_model = _yolo_model(device)
|
||||
torch.manual_seed(0)
|
||||
x = torch.rand(1, 3, 640, 640, dtype=torch.float32, device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
ref = pt_model(x)
|
||||
if isinstance(ref, (list, tuple)):
|
||||
ref = ref[0]
|
||||
|
||||
torch._dynamo.reset()
|
||||
compiled: Callable = torch.compile(pt_model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
out = compiled(x)
|
||||
if isinstance(out, (list, tuple)):
|
||||
out = out[0]
|
||||
|
||||
max_diff = torch.max(torch.abs(out - ref)).item()
|
||||
print(f"YOLO v11n max_diff vs PyTorch eager: {max_diff:.4e}")
|
||||
assert torch.allclose(out, ref, atol=1e-3), (
|
||||
f"YOLO v11n outputs differ from PyTorch eager — max_diff={max_diff:.4e}"
|
||||
)
|
||||
@@ -315,13 +315,8 @@ fn hlir_attention(
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
// Slice to valid range
|
||||
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
|
||||
k_full.shape.dims[1] = total_seq;
|
||||
v_full.shape.dims[1] = total_seq;
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
|
||||
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
|
||||
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
#[path = "../../../examples_common/benchmark_stdio.rs"]
|
||||
mod benchmark_stdio;
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
@@ -8,14 +6,18 @@ use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{
|
||||
io::Write,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use std::{io::Write, time::Duration};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const REPO_ID: &str = "google/gemma-4-26B-A4B";
|
||||
|
||||
fn env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
fn env_bool(name: &str) -> bool {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
@@ -23,10 +25,9 @@ fn env_bool(name: &str) -> bool {
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let stdio = benchmark_stdio::enabled();
|
||||
let max_seq_len = benchmark_stdio::env_usize("MAX_SEQ_LEN", 4096);
|
||||
let gen_tokens = benchmark_stdio::env_usize("GEN_TOKENS", 30);
|
||||
let search_graphs = benchmark_stdio::env_usize("SEARCH_GRAPHS", 50);
|
||||
let max_seq_len = env_usize("MAX_SEQ_LEN", 4096);
|
||||
let gen_tokens = env_usize("GEN_TOKENS", 30);
|
||||
let search_graphs = env_usize("SEARCH_GRAPHS", 50);
|
||||
let prompt = std::env::var("PROMPT").unwrap_or_else(|_| "The capital of France is".to_string());
|
||||
let print_token_ids = env_bool("PRINT_TOKEN_IDS");
|
||||
|
||||
@@ -37,6 +38,11 @@ fn main() {
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
@@ -57,14 +63,11 @@ fn main() {
|
||||
let weights_path = model_dir.join("model_combined.safetensors");
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
let reset_cache = |runtime: &mut CudaRuntime| {
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
};
|
||||
reset_cache(&mut runtime);
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
@@ -72,66 +75,15 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(pos_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
if stdio {
|
||||
benchmark_stdio::serve(|prompt| {
|
||||
reset_cache(&mut runtime);
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
print_token_ids,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
pos_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
true,
|
||||
);
|
||||
});
|
||||
} else {
|
||||
run_prompt(
|
||||
&prompt,
|
||||
gen_tokens,
|
||||
print_token_ids,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
pos_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
false,
|
||||
);
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_prompt(
|
||||
prompt: &str,
|
||||
gen_tokens: usize,
|
||||
print_token_ids: bool,
|
||||
tokenizer: &Tokenizer,
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
input: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
kv_cache: &KVCache,
|
||||
stdio: bool,
|
||||
) {
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
let query_start = Instant::now();
|
||||
|
||||
if !stdio {
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
|
||||
let mut prev_seq = 0usize;
|
||||
let mut fwd_durations = vec![];
|
||||
@@ -141,7 +93,7 @@ fn run_prompt(
|
||||
|
||||
const EOS_TOKEN: u32 = 1;
|
||||
|
||||
let prefill_start = Instant::now();
|
||||
let prefill_start = std::time::Instant::now();
|
||||
for &token in &prompt_tokens {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
@@ -169,26 +121,12 @@ fn run_prompt(
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
generated_token_ids.push(next_token);
|
||||
let mut generated = 0usize;
|
||||
if stdio {
|
||||
if next_token != EOS_TOKEN {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
generated += 1;
|
||||
}
|
||||
} else {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
print!("{decoded}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
for _ in 1..gen_tokens {
|
||||
if stdio && next_token == EOS_TOKEN {
|
||||
break;
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
let start = std::time::Instant::now();
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![next_token as i32]);
|
||||
@@ -227,21 +165,10 @@ fn run_prompt(
|
||||
break;
|
||||
}
|
||||
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
} else {
|
||||
print!("{decoded}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
generated += 1;
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
fwd_durations.push(start.elapsed());
|
||||
}
|
||||
if stdio {
|
||||
benchmark_stdio::emit_eoq(generated, query_start);
|
||||
return;
|
||||
}
|
||||
|
||||
println!();
|
||||
if print_token_ids {
|
||||
println!("Generated token ids: {generated_token_ids:?}");
|
||||
|
||||
@@ -462,13 +462,8 @@ fn hlir_attention(
|
||||
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
|
||||
k_full.shape.dims[1] = total_seq;
|
||||
v_full.shape.dims[1] = total_seq;
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
|
||||
let k_3d = k_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
|
||||
let v_3d = v_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
|
||||
@@ -621,8 +616,6 @@ impl Gemma4SparseMoE {
|
||||
let hidden_exp = hidden.unsqueeze(2);
|
||||
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2);
|
||||
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
(down_out * top_k_weights.unsqueeze(top_k_weights.dims().len())).sum(n - 1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
#[path = "../../../examples_common/benchmark_stdio.rs"]
|
||||
mod benchmark_stdio;
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
@@ -9,36 +7,22 @@ use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use luminal_tracing::*;
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{
|
||||
io::Write,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use std::{io::Write, time::Duration};
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
const REPO_ID: &str = "NousResearch/Meta-Llama-3-8B-Instruct";
|
||||
|
||||
fn main() {
|
||||
let stdio = benchmark_stdio::enabled();
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = if stdio {
|
||||
benchmark_stdio::env_usize("GEN_TOKENS", 500)
|
||||
} else {
|
||||
500
|
||||
};
|
||||
let search_graphs = if stdio {
|
||||
benchmark_stdio::env_usize("SEARCH_GRAPHS", 500)
|
||||
} else {
|
||||
500
|
||||
};
|
||||
let gen_tokens = 500;
|
||||
let search_graphs = 500;
|
||||
let prompt = "Explain what a neural network is in a paragraph.";
|
||||
|
||||
if !stdio {
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
}
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
@@ -47,6 +31,14 @@ fn main() {
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let chat_prompt = format!(
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
);
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(chat_prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
// Build graph
|
||||
let mut cx = Graph::default();
|
||||
@@ -60,13 +52,8 @@ fn main() {
|
||||
v_out.output();
|
||||
}
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', 1);
|
||||
|
||||
println!("Building E-Graph...");
|
||||
cx.build_search_space_with_options::<CudaRuntime>(
|
||||
BuildSearchSpaceOptions::new().max_memory_mib(500),
|
||||
);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
println!("Loading weights...");
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
@@ -74,13 +61,10 @@ fn main() {
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
|
||||
let reset_cache = |runtime: &mut CudaRuntime| {
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
};
|
||||
reset_cache(&mut runtime);
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
@@ -88,65 +72,12 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(token_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
if stdio {
|
||||
benchmark_stdio::serve(|prompt| {
|
||||
reset_cache(&mut runtime);
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
true,
|
||||
);
|
||||
});
|
||||
} else {
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
false,
|
||||
);
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_prompt(
|
||||
prompt: &str,
|
||||
gen_tokens: usize,
|
||||
tokenizer: &Tokenizer,
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
input: GraphTensor,
|
||||
token_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
kv_cache: &KVCache,
|
||||
stdio: bool,
|
||||
) {
|
||||
let chat_prompt = format!(
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
);
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(chat_prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let query_start = Instant::now();
|
||||
let mut prev_seq = 1usize;
|
||||
let mut sentence = vec![prompt_tokens[0]];
|
||||
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
|
||||
@@ -158,16 +89,13 @@ fn run_prompt(
|
||||
const EOS_TOKEN: u32 = 128009;
|
||||
const STOP_TOKEN: u32 = 128001;
|
||||
|
||||
if !stdio {
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
}
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
|
||||
let mut generated = 0usize;
|
||||
for i in 0..total_steps {
|
||||
let start = Instant::now();
|
||||
let start = std::time::Instant::now();
|
||||
let is_prefill = i < prompt_len - 1;
|
||||
let seq_len = sentence.len();
|
||||
|
||||
@@ -226,21 +154,12 @@ fn run_prompt(
|
||||
}
|
||||
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
} else {
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
generated += 1;
|
||||
}
|
||||
if stdio {
|
||||
benchmark_stdio::emit_eoq(generated, query_start);
|
||||
return;
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
println!();
|
||||
|
||||
// Benchmarks
|
||||
println!();
|
||||
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
|
||||
if decode_durations.len() > 2 {
|
||||
println!(
|
||||
|
||||
@@ -246,13 +246,8 @@ fn hlir_attention(
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
// Slice to valid range: [N_KV_HEADS, total_seq, HEAD_DIM]
|
||||
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
|
||||
k_full.shape.dims[1] = total_seq;
|
||||
v_full.shape.dims[1] = total_seq;
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
|
||||
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
|
||||
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
#[path = "../../../examples_common/benchmark_stdio.rs"]
|
||||
mod benchmark_stdio;
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
@@ -9,36 +7,22 @@ use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use luminal_tracing::*;
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{
|
||||
io::Write,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use std::{io::Write, time::Duration};
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
const REPO_ID: &str = "Qwen/Qwen3-4B";
|
||||
|
||||
fn main() {
|
||||
let stdio = benchmark_stdio::enabled();
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = if stdio {
|
||||
benchmark_stdio::env_usize("GEN_TOKENS", 500)
|
||||
} else {
|
||||
500
|
||||
};
|
||||
let search_graphs = if stdio {
|
||||
benchmark_stdio::env_usize("SEARCH_GRAPHS", 500)
|
||||
} else {
|
||||
500
|
||||
};
|
||||
let gen_tokens = 500;
|
||||
let search_graphs = 500;
|
||||
let prompt = "Explain what a neural network is in a paragraph.";
|
||||
|
||||
if !stdio {
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
}
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
@@ -47,6 +31,7 @@ fn main() {
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
|
||||
// Build graph
|
||||
let mut cx = Graph::default();
|
||||
@@ -69,13 +54,10 @@ fn main() {
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
|
||||
let reset_cache = |runtime: &mut CudaRuntime| {
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
};
|
||||
reset_cache(&mut runtime);
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
@@ -83,58 +65,12 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(token_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
if stdio {
|
||||
benchmark_stdio::serve(|prompt| {
|
||||
reset_cache(&mut runtime);
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
true,
|
||||
);
|
||||
});
|
||||
} else {
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
false,
|
||||
);
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_prompt(
|
||||
prompt: &str,
|
||||
gen_tokens: usize,
|
||||
tokenizer: &Tokenizer,
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
input: GraphTensor,
|
||||
token_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
kv_cache: &KVCache,
|
||||
stdio: bool,
|
||||
) {
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
let query_start = Instant::now();
|
||||
let mut prev_seq = 1usize;
|
||||
let mut sentence = vec![prompt_tokens[0]];
|
||||
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
|
||||
@@ -146,16 +82,13 @@ fn run_prompt(
|
||||
const EOS_TOKEN: u32 = 151645; // <|endoftext|>
|
||||
const STOP_TOKEN: u32 = 151643; // <|end|>
|
||||
|
||||
if !stdio {
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
}
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
|
||||
let mut generated = 0usize;
|
||||
for i in 0..total_steps {
|
||||
let start = Instant::now();
|
||||
let start = std::time::Instant::now();
|
||||
let is_prefill = i < prompt_len - 1;
|
||||
let seq_len = sentence.len();
|
||||
|
||||
@@ -214,21 +147,12 @@ fn run_prompt(
|
||||
}
|
||||
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
} else {
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
generated += 1;
|
||||
}
|
||||
if stdio {
|
||||
benchmark_stdio::emit_eoq(generated, query_start);
|
||||
return;
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
println!();
|
||||
|
||||
// Benchmarks
|
||||
println!();
|
||||
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
|
||||
if decode_durations.len() > 2 {
|
||||
println!(
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user