mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
16 Commits
rust-examp
...
flashinfer
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
62e86f9dc5 | ||
|
|
75e4e6be0a | ||
|
|
4cd47ffa45 | ||
|
|
db72cf505c | ||
|
|
766db93b08 | ||
|
|
4e93f02725 | ||
|
|
25393a9fdd | ||
|
|
81ea750e6b | ||
|
|
f94335b1b8 | ||
|
|
f62e3c50d0 | ||
|
|
eeeabd7c20 | ||
|
|
0f02466f3d | ||
|
|
156fac518e | ||
|
|
a3df68bd43 | ||
|
|
7a95e56a8b | ||
|
|
e558ce6849 |
@@ -1,3 +1,6 @@
|
||||
[alias]
|
||||
examples = "run --release --bin examples-perf --"
|
||||
|
||||
[target.aarch64-unknown-linux-gnu]
|
||||
rustflags = [
|
||||
"-Ctarget-feature=+fp16,+fhm"
|
||||
|
||||
2
.github/workflows/test-core.yml
vendored
2
.github/workflows/test-core.yml
vendored
@@ -21,4 +21,4 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run tests
|
||||
run: cargo test --release --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --verbose
|
||||
run: cargo test --release -p luminal -p luminal_nn -p luminal_tracing -p luminal_python --verbose
|
||||
|
||||
67
.github/workflows/test-full-cuda.yml
vendored
Normal file
67
.github/workflows/test-full-cuda.yml
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
name: Test Full CUDA
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
branches: ["main"]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
rust_cuda_ignored_tests:
|
||||
if: >-
|
||||
github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request_target'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'full-modal-ready'))
|
||||
name: Rust CUDA Ignored Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 300
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: pip install modal
|
||||
- name: Run ignored CUDA Rust tests on Modal
|
||||
env:
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
GPU_TYPE: H100
|
||||
MODAL_TIMEOUT: "14400"
|
||||
CARGO_TEST_ARGS: "--ignored --test-threads=1"
|
||||
run: modal run ci/modal_cargo_test.py
|
||||
|
||||
python_cuda_slow_tests:
|
||||
if: >-
|
||||
github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request_target'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'full-modal-ready'))
|
||||
name: Python CUDA Slow Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 300
|
||||
defaults:
|
||||
run:
|
||||
working-directory: crates/luminal_python
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: pip install modal
|
||||
- name: Run slow pytest CUDA tests on Modal
|
||||
env:
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: modal run modal_pytest_runner.py --gpu A100-80GB --timeout 14400 tests/ -v -s -m slow
|
||||
17
.github/workflows/test-metal.yml
vendored
17
.github/workflows/test-metal.yml
vendored
@@ -17,3 +17,20 @@ jobs:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run Metal crate tests
|
||||
run: rustup update; cargo test --release -p luminal_metal --verbose -- --test-threads=1
|
||||
|
||||
llama_1b_metal_example:
|
||||
name: Llama 1B Metal Example
|
||||
runs-on: macos-14-xlarge
|
||||
timeout-minutes: 120
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Print runner hardware
|
||||
run: system_profiler SPHardwareDataType SPDisplaysDataType
|
||||
- name: Cache Hugging Face models
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/huggingface
|
||||
key: llama-1b-metal-hf-${{ runner.os }}-${{ runner.arch }}-v1
|
||||
- name: Run Llama 1B Metal example and validate output
|
||||
run: rustup update; python3 ci/metal_llama_1b_example.py
|
||||
|
||||
@@ -21,8 +21,7 @@ let b = cx.tensor((1, 4));
|
||||
let c = a.matmul(b).output();
|
||||
|
||||
// Compile
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
let mut rt = cx.compile(NativeRuntime::default(), CompileOptions::default());
|
||||
|
||||
// Set input tensors
|
||||
rt.set_data(a, vec![1.0, 2.0, 3.0]);
|
||||
|
||||
185
ci/examples_perf.py
Normal file
185
ci/examples_perf.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from example_output import validate_output
|
||||
|
||||
|
||||
DEFAULT_EXAMPLES = ["llama", "gemma", "qwen", "qwen3_moe", "gemma4_moe", "whisper"]
|
||||
|
||||
EXAMPLE_CARGO_ARGS = {
|
||||
"llama": ["run", "--release", "-p", "llama"],
|
||||
"gemma": ["run", "--release", "-p", "gemma"],
|
||||
"qwen": ["run", "--release", "-p", "qwen", "--features", "cuda"],
|
||||
"qwen3_moe": ["run", "--release", "-p", "qwen3_moe"],
|
||||
"gemma4_moe": ["run", "--release", "-p", "gemma4_moe"],
|
||||
"whisper": ["run", "--release", "-p", "whisper"],
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Metrics:
|
||||
ttft_ms: float | None = None
|
||||
tpot_ms: float | None = None
|
||||
tps: float | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExampleResult:
|
||||
name: str
|
||||
ok: bool
|
||||
metrics: Metrics = field(default_factory=Metrics)
|
||||
wall_s: float = 0.0
|
||||
error: str | None = None
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = [arg for arg in sys.argv[1:] if arg != "--"]
|
||||
if any(arg in {"-h", "--help"} for arg in args):
|
||||
print_help()
|
||||
return
|
||||
if "--list" in args:
|
||||
print("\n".join(DEFAULT_EXAMPLES))
|
||||
return
|
||||
|
||||
examples = args or DEFAULT_EXAMPLES
|
||||
results = [run_example(example) for example in examples]
|
||||
print_table(results)
|
||||
if any(not result.ok for result in results):
|
||||
raise SystemExit(1)
|
||||
|
||||
|
||||
def print_help() -> None:
|
||||
print(
|
||||
"Run validated Luminal examples, validate textual output, and summarize perf.\n"
|
||||
"\n"
|
||||
"Usage:\n"
|
||||
" cargo examples\n"
|
||||
" cargo examples llama qwen whisper\n"
|
||||
"\n"
|
||||
"Options:\n"
|
||||
" --list Print the default validated examples\n"
|
||||
" -h, --help\n"
|
||||
"\n"
|
||||
f"The default set matches the Modal examples CI: {', '.join(DEFAULT_EXAMPLES)}."
|
||||
)
|
||||
|
||||
|
||||
def run_example(example: str) -> ExampleResult:
|
||||
cargo_args = EXAMPLE_CARGO_ARGS.get(example)
|
||||
if cargo_args is None:
|
||||
known = ", ".join(DEFAULT_EXAMPLES)
|
||||
return ExampleResult(example, False, error=f"unknown example; known examples: {known}")
|
||||
|
||||
print(f"\n=== Running {example} ===")
|
||||
print(f"$ cargo {' '.join(cargo_args)}")
|
||||
started = time.monotonic()
|
||||
env = os.environ.copy()
|
||||
env.setdefault("CUDARC_CUDA_VERSION", "12080")
|
||||
process = subprocess.Popen(
|
||||
["cargo", *cargo_args],
|
||||
cwd=repo_root(),
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert process.stdout is not None
|
||||
|
||||
chunks: list[bytes] = []
|
||||
while True:
|
||||
chunk = process.stdout.read1(4096)
|
||||
if not chunk:
|
||||
break
|
||||
sys.stdout.buffer.write(chunk)
|
||||
sys.stdout.buffer.flush()
|
||||
chunks.append(chunk)
|
||||
|
||||
return_code = process.wait()
|
||||
output = b"".join(chunks).decode("utf-8", errors="replace")
|
||||
wall_s = time.monotonic() - started
|
||||
metrics = parse_metrics(output)
|
||||
|
||||
if return_code:
|
||||
return ExampleResult(
|
||||
example,
|
||||
False,
|
||||
metrics=metrics,
|
||||
wall_s=wall_s,
|
||||
error=f"process exited with code {return_code}",
|
||||
)
|
||||
|
||||
try:
|
||||
validate_output(example, output)
|
||||
except Exception as exc:
|
||||
return ExampleResult(example, False, metrics=metrics, wall_s=wall_s, error=str(exc))
|
||||
|
||||
return ExampleResult(example, True, metrics=metrics, wall_s=wall_s)
|
||||
|
||||
|
||||
def repo_root() -> str:
|
||||
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def parse_metrics(output: str) -> Metrics:
|
||||
metrics = Metrics()
|
||||
for line in output.splitlines():
|
||||
if "TTFT:" in line:
|
||||
metrics.ttft_ms = parse_number_after(line, "TTFT:")
|
||||
if "TPOT:" in line:
|
||||
metrics.tpot_ms = parse_number_after(line, "TPOT:")
|
||||
if "tok/s" in line:
|
||||
metrics.tps = parse_tok_per_second(line)
|
||||
if metrics.tps is None and metrics.tpot_ms:
|
||||
metrics.tps = 1000.0 / metrics.tpot_ms
|
||||
return metrics
|
||||
|
||||
|
||||
def parse_number_after(line: str, marker: str) -> float | None:
|
||||
tail = line.split(marker, 1)[1].lstrip()
|
||||
chars = []
|
||||
for char in tail:
|
||||
if char.isdigit() or char == ".":
|
||||
chars.append(char)
|
||||
else:
|
||||
break
|
||||
if not chars:
|
||||
return None
|
||||
return float("".join(chars))
|
||||
|
||||
|
||||
def parse_tok_per_second(line: str) -> float | None:
|
||||
head = line.split("tok/s", 1)[0].rstrip(" (")
|
||||
parts = head.split()
|
||||
if not parts:
|
||||
return None
|
||||
try:
|
||||
return float(parts[-1])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def print_table(results: list[ExampleResult]) -> None:
|
||||
print("\nSummary")
|
||||
print(f"{'example':<14} {'status':<8} {'TTFT ms':>10} {'TPOT ms':>10} {'tok/s':>10} {'wall s':>10}")
|
||||
print("-" * 68)
|
||||
for result in results:
|
||||
status = "ok" if result.ok else "failed"
|
||||
print(
|
||||
f"{result.name:<14} {status:<8} "
|
||||
f"{format_metric(result.metrics.ttft_ms):>10} "
|
||||
f"{format_metric(result.metrics.tpot_ms):>10} "
|
||||
f"{format_metric(result.metrics.tps):>10} "
|
||||
f"{result.wall_s:>10.1f}"
|
||||
)
|
||||
if result.error:
|
||||
print(f" error: {result.error}")
|
||||
|
||||
|
||||
def format_metric(value: float | None) -> str:
|
||||
return "-" if value is None else f"{value:.2f}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
48
ci/metal_llama_1b_example.py
Normal file
48
ci/metal_llama_1b_example.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def run_and_capture(command: list[str], *, cwd: str, env: dict[str, str]) -> str:
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert process.stdout is not None
|
||||
|
||||
chunks = []
|
||||
while True:
|
||||
chunk = process.stdout.read1(4096)
|
||||
if not chunk:
|
||||
break
|
||||
sys.stdout.buffer.write(chunk)
|
||||
sys.stdout.buffer.flush()
|
||||
chunks.append(chunk)
|
||||
|
||||
return_code = process.wait()
|
||||
output = b"".join(chunks).decode("utf-8", errors="replace")
|
||||
if return_code:
|
||||
raise subprocess.CalledProcessError(return_code, command, output=output)
|
||||
return output
|
||||
|
||||
|
||||
def main():
|
||||
repo_root = os.environ.get("GITHUB_WORKSPACE", os.getcwd())
|
||||
sys.path.insert(0, os.path.join(repo_root, "ci"))
|
||||
from example_output import validate_output
|
||||
|
||||
output = run_and_capture(
|
||||
["cargo", "run", "--release", "-p", "luminal_metal", "--example", "llama_1b"],
|
||||
cwd=repo_root,
|
||||
env=os.environ.copy(),
|
||||
)
|
||||
if "TTFT:" not in output or "TPOT:" not in output:
|
||||
raise AssertionError("Llama 1B Metal example did not complete generation")
|
||||
validate_output("llama", output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,8 +1,10 @@
|
||||
import modal
|
||||
import subprocess
|
||||
import os
|
||||
import shlex
|
||||
|
||||
gpu_type = os.environ.get("GPU_TYPE", "T4")
|
||||
modal_timeout = int(os.environ.get("MODAL_TIMEOUT", "7200"))
|
||||
CUDARC_CUDA_VERSION = "12080"
|
||||
|
||||
app = modal.App("luminal-ci-cargo-test")
|
||||
@@ -28,7 +30,7 @@ cuda_image = (
|
||||
@app.function(
|
||||
image=cuda_image,
|
||||
gpu=gpu_type,
|
||||
timeout=7200, # 2 hours
|
||||
timeout=modal_timeout,
|
||||
)
|
||||
def run_cargo_test():
|
||||
"""Run cargo test for luminal_cuda_lite on a Modal GPU."""
|
||||
@@ -43,17 +45,20 @@ def run_cargo_test():
|
||||
)
|
||||
compute_cap = result.stdout.strip().replace(".", "")
|
||||
|
||||
test_args = shlex.split(os.environ.get("CARGO_TEST_ARGS", "--test-threads=1"))
|
||||
cmd = [
|
||||
"cargo",
|
||||
"test",
|
||||
"--release",
|
||||
"-p",
|
||||
"luminal_cuda_lite",
|
||||
"--verbose",
|
||||
"--",
|
||||
*test_args,
|
||||
]
|
||||
print("Running:", " ".join(cmd), flush=True)
|
||||
subprocess.run(
|
||||
[
|
||||
"cargo",
|
||||
"test",
|
||||
"--release",
|
||||
"-p",
|
||||
"luminal_cuda_lite",
|
||||
"--verbose",
|
||||
"--",
|
||||
"--test-threads=1",
|
||||
],
|
||||
cmd,
|
||||
cwd=WORKDIR,
|
||||
env={
|
||||
**os.environ,
|
||||
|
||||
@@ -39,7 +39,7 @@ fn run_metal_pattern_benchmark(
|
||||
let mut cx = Graph::default();
|
||||
pattern.build_graph(&mut cx, *size);
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let mut rng = rand::rng();
|
||||
@@ -50,7 +50,7 @@ fn run_metal_pattern_benchmark(
|
||||
}
|
||||
}
|
||||
|
||||
let mut rt = cx.search(rt, 5);
|
||||
let mut rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
|
||||
let mut bench_metrics = None;
|
||||
|
||||
@@ -41,7 +41,7 @@ struct PreparedBench {
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn prepare_and_search(cx: &mut Graph, input_sizes: &[(NodeIndex, usize)]) -> Option<PreparedBench> {
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let mut rng = rand::rng();
|
||||
@@ -50,7 +50,7 @@ fn prepare_and_search(cx: &mut Graph, input_sizes: &[(NodeIndex, usize)]) -> Opt
|
||||
rt.set_data(*node, &data);
|
||||
}
|
||||
|
||||
let rt = cx.search(rt, 5);
|
||||
let rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
|
||||
Some(PreparedBench {
|
||||
rt,
|
||||
|
||||
@@ -41,7 +41,7 @@ mod metal_backend {
|
||||
const NAME: &'static str = "Metal";
|
||||
|
||||
fn build_search_space(cx: &mut Graph) {
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,9 +29,21 @@ impl DynBackend for CudaLiteDynBackend {
|
||||
fn get_output_f32(&self, node: NodeIndex) -> Vec<f32> {
|
||||
self.runtime.get_f32(node)
|
||||
}
|
||||
fn get_output_f16(&self, node: NodeIndex) -> Vec<half::f16> {
|
||||
self.runtime.get_f16(node)
|
||||
}
|
||||
fn get_output_bf16(&self, node: NodeIndex) -> Vec<half::bf16> {
|
||||
self.runtime.get_bf16(node)
|
||||
}
|
||||
fn get_output_i32(&self, node: NodeIndex) -> Vec<i32> {
|
||||
self.runtime.get_i32(node)
|
||||
}
|
||||
fn get_output_i64(&self, node: NodeIndex) -> Vec<i64> {
|
||||
self.runtime.get_i64(node)
|
||||
}
|
||||
fn get_output_f64(&self, node: NodeIndex) -> Vec<f64> {
|
||||
self.runtime.get_f64(node)
|
||||
}
|
||||
fn get_output_bool(&self, node: NodeIndex) -> Vec<bool> {
|
||||
self.runtime.get_bool(node)
|
||||
}
|
||||
|
||||
@@ -1,258 +0,0 @@
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{EXPRESSION, OP_KIND, STRING},
|
||||
extract_expr,
|
||||
},
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::{
|
||||
tracing::{Level, span, trace},
|
||||
*,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
cudarc::{
|
||||
cublas::{
|
||||
CudaBlas,
|
||||
sys::{cublasOperation_t, cublasSetStream_v2, cublasSgemm_v2, cublasStatus_t},
|
||||
},
|
||||
driver::CudaStream,
|
||||
},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
};
|
||||
|
||||
/// Global shared cuBLAS handle to avoid per-operation workspace allocation
|
||||
static SHARED_CUBLAS: OnceLock<Arc<CudaBlas>> = OnceLock::new();
|
||||
|
||||
/// Parse cuBLAS operation from egglog string (e.g., "\"T\"" -> CUBLAS_OP_T)
|
||||
pub fn parse_cublas_op(s: &str) -> cublasOperation_t {
|
||||
// Strip quotes if present (egglog strings are stored with quotes)
|
||||
let stripped = s.trim_matches('"');
|
||||
match stripped {
|
||||
"T" => cublasOperation_t::CUBLAS_OP_T,
|
||||
"N" => cublasOperation_t::CUBLAS_OP_N,
|
||||
"C" => cublasOperation_t::CUBLAS_OP_C,
|
||||
other => panic!("Unknown cuBLAS operation: '{other}' (original: '{s}')"),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub struct CuBlasSgemmV2 {
|
||||
m: Expression,
|
||||
n: Expression,
|
||||
k: Expression,
|
||||
a_layout: cublasOperation_t,
|
||||
b_layout: cublasOperation_t,
|
||||
lda: Expression,
|
||||
ldb: Expression,
|
||||
ldc: Expression,
|
||||
/// Lazily initialized cuBLAS handle - created on first execute
|
||||
cublas: OnceLock<Arc<CudaBlas>>,
|
||||
}
|
||||
|
||||
// Useless default for IntoEgglogOp
|
||||
impl Default for CuBlasSgemmV2 {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
m: Expression::default(),
|
||||
n: Expression::default(),
|
||||
k: Expression::default(),
|
||||
a_layout: cublasOperation_t::CUBLAS_OP_N, // IGNORE NOT REAL
|
||||
b_layout: cublasOperation_t::CUBLAS_OP_T, // IGNORE NOT REAL
|
||||
lda: Expression::default(),
|
||||
ldb: Expression::default(),
|
||||
ldc: Expression::default(),
|
||||
cublas: OnceLock::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for CuBlasSgemmV2 {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"cublasSgemmV2",
|
||||
&[
|
||||
("m", EXPRESSION),
|
||||
("n", EXPRESSION),
|
||||
("k", EXPRESSION),
|
||||
("a_layout", STRING),
|
||||
("b_layout", STRING),
|
||||
("lda", EXPRESSION),
|
||||
("ldb", EXPRESSION),
|
||||
("ldc", EXPRESSION),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
Rule::raw(include_str!["sgemm_v2_RmRm_rewrite.egg"]), // row row
|
||||
Rule::raw(include_str!["sgemm_v2_RmCm_rewrite.egg"]), // row col
|
||||
Rule::raw(include_str!["sgemm_v2_CmRm_rewrite.egg"]), // col row
|
||||
Rule::raw(include_str!["sgemm_v2_CmCm_rewrite.egg"]), // col col
|
||||
]
|
||||
}
|
||||
|
||||
#[allow(unused_variables)]
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
// Extract dimensions from egglog
|
||||
let m = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
|
||||
let n = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
|
||||
let k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
|
||||
|
||||
// Extract layout strings from egglog
|
||||
let a_layout_str = &egraph.enodes[kind_children[3]].0;
|
||||
let b_layout_str = &egraph.enodes[kind_children[4]].0;
|
||||
let a_layout = parse_cublas_op(a_layout_str);
|
||||
let b_layout = parse_cublas_op(b_layout_str);
|
||||
|
||||
// Extract leading dimensions from egglog
|
||||
let lda = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
|
||||
let ldb = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
|
||||
let ldc = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
|
||||
|
||||
let extracted_state = Self {
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
a_layout,
|
||||
b_layout,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
cublas: OnceLock::new(),
|
||||
};
|
||||
trace!(?extracted_state);
|
||||
|
||||
let extracted = LLIROp::new::<dyn HostOp>(Box::new(extracted_state) as Box<dyn HostOp>);
|
||||
|
||||
(extracted, input_enodes)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for CuBlasSgemmV2 {
|
||||
fn execute(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
// GEMM parameters
|
||||
let m = self.m.exec(dyn_map).unwrap() as i32;
|
||||
let n = self.n.exec(dyn_map).unwrap() as i32;
|
||||
let k = self.k.exec(dyn_map).unwrap() as i32;
|
||||
let a_layout = self.a_layout;
|
||||
let b_layout = self.b_layout;
|
||||
let lda = self.lda.exec(dyn_map).unwrap() as i32;
|
||||
let ldb = self.ldb.exec(dyn_map).unwrap() as i32;
|
||||
let ldc = self.ldc.exec(dyn_map).unwrap() as i32;
|
||||
|
||||
let alpha = 1.0f32;
|
||||
let beta = 0.0f32;
|
||||
|
||||
// Get buffers: output is self_node, inputs are from graph edges
|
||||
let c_buf = buffers[&self_node];
|
||||
let a_buf = buffers[&inputs[0]];
|
||||
let b_buf = buffers[&inputs[1]];
|
||||
|
||||
// Get device pointers
|
||||
let a_ptr = a_buf.ptr();
|
||||
let b_ptr = b_buf.ptr();
|
||||
let c_ptr = c_buf.ptr();
|
||||
|
||||
// Debug: Check buffer sizes
|
||||
trace!(
|
||||
"buffer_validation {}=={},{}=={},{}=={}",
|
||||
a_buf.len(),
|
||||
m * k * 4,
|
||||
b_buf.len(),
|
||||
k * n * 4,
|
||||
c_buf.len(),
|
||||
m * n * 4
|
||||
);
|
||||
let _sgemm_span = span!(
|
||||
Level::TRACE,
|
||||
"cuBLAS_SGEMM_V2",
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
beta,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
?a_layout,
|
||||
?b_layout,
|
||||
)
|
||||
.entered();
|
||||
|
||||
// Use shared cuBLAS handle to avoid per-operation workspace allocation
|
||||
let cublas = SHARED_CUBLAS.get_or_init(|| Arc::new(CudaBlas::new(stream.clone()).unwrap()));
|
||||
|
||||
// Set the stream for this operation (cuBLAS handle can work with any stream)
|
||||
// The CUstream types from cublas::sys and driver::sys are compatible, just cast
|
||||
unsafe {
|
||||
cublasSetStream_v2(*cublas.handle(), stream.cu_stream() as _);
|
||||
}
|
||||
|
||||
let status = unsafe {
|
||||
cublasSgemm_v2(
|
||||
*cublas.handle(),
|
||||
a_layout,
|
||||
b_layout,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
&alpha as *const f32,
|
||||
a_ptr as *const f32,
|
||||
lda,
|
||||
b_ptr as *const f32,
|
||||
ldb,
|
||||
&beta as *const f32,
|
||||
c_ptr as *mut f32,
|
||||
ldc,
|
||||
)
|
||||
};
|
||||
stream.synchronize().unwrap();
|
||||
|
||||
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
|
||||
return Err(anyhow::anyhow!(
|
||||
"cuBLAS SGEMM TN failed with status: {:?}",
|
||||
status
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.m * self.n
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
// CuBlasSgemmV2 is F32 only (Sgemm = Single precision)
|
||||
self.output_size() * 4
|
||||
}
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
; Column-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(let ?sgemm (Op (cublasSgemmV2
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?k ; lda = k (column-major B[k,n])
|
||||
?m ; ldb = m (column-major A[m,k])
|
||||
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm column-major × column-major"
|
||||
)
|
||||
@@ -1,73 +0,0 @@
|
||||
; Column-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Row-major B[k,n] ≡ column-major B^T[n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × row-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(let ?sgemm (Op (cublasSgemmV2
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
|
||||
?m ; ldb = m (column-major A[m,k])
|
||||
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm column-major × row-major"
|
||||
)
|
||||
@@ -1,73 +0,0 @@
|
||||
; Row-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major A^T[k,m] with lda=k
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(let ?sgemm (Op (cublasSgemmV2
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major, need B^T)
|
||||
"N" ; transb = No transpose
|
||||
?k ; lda = k (column-major B[k,n])
|
||||
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
|
||||
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm row-major × column-major"
|
||||
)
|
||||
@@ -1,73 +0,0 @@
|
||||
; Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major [k,m] with lda=k
|
||||
; Row-major B[k,n] ≡ column-major [n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major [n,m] with ldc=n
|
||||
;
|
||||
; cuBLAS computes: C_col[n,m] = B_col[n,k] × A_col[k,m]
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major C = A × B with cuBLAS (column-major):
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(let ?sgemm (Op (cublasSgemmV2
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose
|
||||
"N" ; transb = No transpose
|
||||
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
|
||||
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
|
||||
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm row-major"
|
||||
)
|
||||
@@ -11,11 +11,13 @@
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
; Match the generic matmul produced from Mul -> Sum.
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
@@ -77,8 +79,12 @@
|
||||
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
|
||||
@@ -11,11 +11,13 @@
|
||||
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
; Match the generic matmul produced from Mul -> Sum.
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
@@ -77,8 +79,12 @@
|
||||
; B row-major per batch: b_n_stride=MIter, b_m_stride=0
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
|
||||
@@ -11,11 +11,13 @@
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
; Match the generic matmul produced from Mul -> Sum.
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
@@ -77,8 +79,12 @@
|
||||
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
|
||||
@@ -11,11 +11,13 @@
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
; Match the generic matmul produced from Mul -> Sum.
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
@@ -79,8 +81,12 @@
|
||||
; Leading dimensions may differ from k/n when batch slices are non-contiguous.
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Output shape: [batch, m, n]
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
|
||||
@@ -25,8 +25,12 @@
|
||||
(ICons ?input_scale (INil))))
|
||||
(= ?a (Op (Cast ?a_size ?a_dtype) (ICons ?scaled_activation (INil))))
|
||||
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
(= ?scale_product (Op (Mul (ENil) (ENil) (ENil) (ENil))
|
||||
(ICons ?input_scale (ICons ?weight_scale (INil)))))
|
||||
@@ -96,8 +100,12 @@
|
||||
(ICons ?input_scale (INil))))
|
||||
(= ?a (Op (Cast ?a_size ?a_dtype) (ICons ?scaled_activation (INil))))
|
||||
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
(= ?scale_product (Op (Mul (ENil) (ENil) (ENil) (ENil))
|
||||
(ICons ?input_scale (ICons ?weight_scale (INil)))))
|
||||
@@ -368,8 +376,12 @@
|
||||
(ICons ?input_scale (INil))))
|
||||
(= ?a (Op (Cast ?a_size ?a_dtype) (ICons ?scaled_activation (INil))))
|
||||
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
(= ?scale_product (Op (Mul (ENil) (ENil) (ENil) (ENil))
|
||||
(ICons ?input_scale (ICons ?weight_scale (INil)))))
|
||||
@@ -440,8 +452,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
@@ -489,8 +505,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
@@ -538,8 +558,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
@@ -587,8 +611,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
@@ -650,8 +678,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
@@ -713,8 +745,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
|
||||
@@ -5,8 +5,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
@@ -54,8 +58,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
@@ -103,8 +111,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
@@ -152,8 +164,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
@@ -201,8 +217,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
@@ -264,8 +284,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
@@ -327,8 +351,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
@@ -390,8 +418,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
|
||||
#[cfg(test)]
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
use half::{bf16, f16};
|
||||
use luminal::{
|
||||
@@ -15,6 +18,8 @@ use luminal::{
|
||||
},
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
use crate::kernel::CudaGraphHandle;
|
||||
use crate::{
|
||||
cudarc::{
|
||||
cublas::sys::cublasOperation_t,
|
||||
@@ -33,12 +38,22 @@ use crate::{
|
||||
cublasLtMatrixLayoutSetAttribute, cublasLtOrder_t, cudaDataType,
|
||||
},
|
||||
},
|
||||
driver::{CudaStream, DevicePtr},
|
||||
driver::{CudaSlice, CudaStream, DevicePtr},
|
||||
},
|
||||
host::{DeviceBuffer, HostOp, cublas::parse_cublas_op},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
try_create_cublaslt,
|
||||
};
|
||||
|
||||
fn parse_cublas_op(s: &str) -> cublasOperation_t {
|
||||
let stripped = s.trim_matches('"');
|
||||
match stripped {
|
||||
"T" => cublasOperation_t::CUBLAS_OP_T,
|
||||
"N" => cublasOperation_t::CUBLAS_OP_N,
|
||||
"C" => cublasOperation_t::CUBLAS_OP_C,
|
||||
other => panic!("Unknown cuBLAS operation: '{other}' (original: '{s}')"),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub struct CuBlasLt {
|
||||
@@ -189,36 +204,50 @@ impl EgglogOp for CuBlasLt {
|
||||
Rule::raw(include_str!["cublaslt_beta_rewrite.egg"]),
|
||||
Rule::raw(include_str!["cublaslt_epilogue_rewrite.egg"]),
|
||||
Rule::raw(include_str!["cublaslt_row_order_rewrite.egg"]),
|
||||
// Delete the matmul-broadcast Mul eclass when the consuming Sum
|
||||
// eclass has a `cublaslt` or `KernelBatchMatMul` alternative. The
|
||||
// cuBLASLt / batched-matmul rewrite rules only union those enodes
|
||||
// into the Sum eclass after the broadcast pattern check passes,
|
||||
// so their presence is the matmul-broadcast signal — no further
|
||||
// stride-form check needed.
|
||||
//
|
||||
// Delete the HLIR `Mul` and its generic fusion-region alternative
|
||||
// from the Mul eclass. Emptying that eclass lets the empty-eclass
|
||||
// cascade prune the downstream Sum / KernelSum fallback. cuBLAS,
|
||||
// TileMatmulFullSplit, KernelBatchMatVec, and KernelBatchMatMul all
|
||||
// take original (a, b) inputs rather than the Mul eclass, so they
|
||||
// survive the cascade and remain as the matmul output alternative.
|
||||
// cuBLASLt now specializes GenericMatmul, so cleanup should prune
|
||||
// the matmul output alternatives directly. Do not delete the
|
||||
// broadcast Mul here; it may still have non-matmul consumers.
|
||||
Rule::raw("(rule
|
||||
((= ?mul (Op (Mul ?shape ?as ?bs ?os) ?inputs))
|
||||
(= ?mul (Op (FusionEnd ?fshape ?fos ?fdt) ?finputs))
|
||||
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
((= ?sum (Op (Sum ?shape ?k ?sis ?sks ?sos) ?inputs))
|
||||
(= ?sum (Op (cublaslt ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?ci)))
|
||||
((delete (Op (Mul ?shape ?as ?bs ?os) ?inputs))
|
||||
(delete (Op (FusionEnd ?fshape ?fos ?fdt) ?finputs)))
|
||||
((delete (Op (Sum ?shape ?k ?sis ?sks ?sos) ?inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"delete-sum-when-cublaslt-exists\"
|
||||
)"),
|
||||
Rule::raw("(rule
|
||||
((= ?mul (Op (Mul ?shape ?as ?bs ?os) ?inputs))
|
||||
(= ?mul (Op (FusionEnd ?fshape ?fos ?fdt) ?finputs))
|
||||
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (KernelBatchMatMul ?bos ?bk ?bas ?baks ?bbs ?bbks ?bouts ?bdt) ?bi)))
|
||||
((delete (Op (Mul ?shape ?as ?bs ?os) ?inputs))
|
||||
(delete (Op (FusionEnd ?fshape ?fos ?fdt) ?finputs)))
|
||||
((= ?sum (Op (KernelSum ?shape ?k ?sis ?sks ?sos ?dt) ?inputs))
|
||||
(= ?sum (Op (cublaslt ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?ci)))
|
||||
((delete (Op (KernelSum ?shape ?k ?sis ?sks ?sos ?dt) ?inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"delete-kernel-sum-when-cublaslt-exists\"
|
||||
)"),
|
||||
Rule::raw("(rule
|
||||
((= ?sum (Op (Sum ?shape ?k ?sis ?sks ?sos) ?inputs))
|
||||
(= ?sum (Op (cublaslt_scaled ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?ci)))
|
||||
((delete (Op (Sum ?shape ?k ?sis ?sks ?sos) ?inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"delete-sum-when-scaled-cublaslt-exists\"
|
||||
)"),
|
||||
Rule::raw("(rule
|
||||
((= ?sum (Op (KernelSum ?shape ?k ?sis ?sks ?sos ?dt) ?inputs))
|
||||
(= ?sum (Op (cublaslt_scaled ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?ci)))
|
||||
((delete (Op (KernelSum ?shape ?k ?sis ?sks ?sos ?dt) ?inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"delete-kernel-sum-when-scaled-cublaslt-exists\"
|
||||
)"),
|
||||
Rule::raw("(rule
|
||||
((= ?sum (Op (GenericMatmul ?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt) ?generic_inputs))
|
||||
(= ?sum (Op (cublaslt ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?cublas_inputs)))
|
||||
((delete (Op (GenericMatmul ?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt) ?generic_inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"prefer-cublaslt-over-generic-matmul\"
|
||||
)"),
|
||||
Rule::raw("(rule
|
||||
((= ?sum (Op (GenericMatmul ?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt) ?generic_inputs))
|
||||
(= ?sum (Op (cublaslt_scaled ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?cublas_inputs)))
|
||||
((delete (Op (GenericMatmul ?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt) ?generic_inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"prefer-scaled-cublaslt-over-generic-matmul\"
|
||||
)"),
|
||||
]
|
||||
}
|
||||
@@ -557,8 +586,8 @@ fn epilogue_name(epilogue: cublasLtEpilogue_t) -> &'static str {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum LtScalar {
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub(crate) enum LtScalar {
|
||||
F64(f64),
|
||||
F32(f32),
|
||||
F16(f16),
|
||||
@@ -598,16 +627,16 @@ impl LtScalar {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct LtMatmulProblem {
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) struct LtMatmulProblem {
|
||||
m: u64,
|
||||
n: u64,
|
||||
k: u64,
|
||||
batch_count: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct LtMatrixSpec {
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) struct LtMatrixSpec {
|
||||
dtype: cudaDataType,
|
||||
rows: u64,
|
||||
cols: u64,
|
||||
@@ -616,8 +645,8 @@ struct LtMatrixSpec {
|
||||
order: cublasLtOrder_t,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct LtComputeSpec {
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub(crate) struct LtComputeSpec {
|
||||
compute_type: cublasComputeType_t,
|
||||
scale_dtype: cudaDataType,
|
||||
alpha: LtScalar,
|
||||
@@ -625,8 +654,8 @@ struct LtComputeSpec {
|
||||
epilogue: cublasLtEpilogue_t,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct LtMatmulSpec {
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub(crate) struct LtMatmulSpec {
|
||||
problem: LtMatmulProblem,
|
||||
trans_a: cublasOperation_t,
|
||||
trans_b: cublasOperation_t,
|
||||
@@ -638,8 +667,8 @@ struct LtMatmulSpec {
|
||||
workspace_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct LtMatmulPointers {
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) struct LtMatmulPointers {
|
||||
a: u64,
|
||||
b: u64,
|
||||
c: u64,
|
||||
@@ -649,7 +678,35 @@ struct LtMatmulPointers {
|
||||
b_scale: Option<u64>,
|
||||
}
|
||||
|
||||
struct LtRawDescriptors {
|
||||
impl LtMatmulPointers {
|
||||
pub(crate) fn changed_fields(self, other: Self) -> Vec<&'static str> {
|
||||
let mut fields = Vec::new();
|
||||
if self.a != other.a {
|
||||
fields.push("a");
|
||||
}
|
||||
if self.b != other.b {
|
||||
fields.push("b");
|
||||
}
|
||||
if self.c != other.c {
|
||||
fields.push("c");
|
||||
}
|
||||
if self.d != other.d {
|
||||
fields.push("d");
|
||||
}
|
||||
if self.bias != other.bias {
|
||||
fields.push("bias");
|
||||
}
|
||||
if self.a_scale != other.a_scale {
|
||||
fields.push("a_scale");
|
||||
}
|
||||
if self.b_scale != other.b_scale {
|
||||
fields.push("b_scale");
|
||||
}
|
||||
fields
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct LtRawDescriptors {
|
||||
matmul_desc: cublasLtMatmulDesc_t,
|
||||
a_desc: cublasLtMatrixLayout_t,
|
||||
b_desc: cublasLtMatrixLayout_t,
|
||||
@@ -658,6 +715,23 @@ struct LtRawDescriptors {
|
||||
preference: cublasLtMatmulPreference_t,
|
||||
}
|
||||
|
||||
static CUBLASLT_HEURISTIC_CACHE: OnceLock<
|
||||
Mutex<Vec<(LtMatmulSpec, cublasLtMatmulHeuristicResult_t)>>,
|
||||
> = OnceLock::new();
|
||||
|
||||
#[cfg(test)]
|
||||
static CUBLASLT_PREPARE_COUNT: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn reset_cublaslt_prepare_count_for_test() {
|
||||
CUBLASLT_PREPARE_COUNT.store(0, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_prepare_count_for_test() -> usize {
|
||||
CUBLASLT_PREPARE_COUNT.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
impl Default for LtRawDescriptors {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@@ -696,6 +770,121 @@ impl Drop for LtRawDescriptors {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct PreparedCuBlasLtMatmul {
|
||||
cublaslt: Arc<CudaBlasLT>,
|
||||
spec: LtMatmulSpec,
|
||||
resources: LtRawDescriptors,
|
||||
heuristic: cublasLtMatmulHeuristicResult_t,
|
||||
_workspace: CudaSlice<u8>,
|
||||
workspace_ptr: u64,
|
||||
_a_scale: Option<CudaSlice<f32>>,
|
||||
default_a_scale_ptr: Option<u64>,
|
||||
_b_scale: Option<CudaSlice<f32>>,
|
||||
default_b_scale_ptr: Option<u64>,
|
||||
_c_scale: Option<CudaSlice<f32>>,
|
||||
_d_scale: Option<CudaSlice<f32>>,
|
||||
}
|
||||
|
||||
impl PreparedCuBlasLtMatmul {
|
||||
fn update_descriptor_pointers(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
ptrs: LtMatmulPointers,
|
||||
) -> anyhow::Result<()> {
|
||||
stream.context().bind_to_thread()?;
|
||||
if let Some(bias_ptr) = ptrs.bias {
|
||||
set_scalar_scale_pointer(
|
||||
self.resources.matmul_desc,
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_POINTER,
|
||||
bias_ptr,
|
||||
)?;
|
||||
}
|
||||
if cuda_dtype_needs_tensorwide_scale(self.spec.a.dtype) {
|
||||
let ptr = ptrs.a_scale.or(self.default_a_scale_ptr).ok_or_else(|| {
|
||||
anyhow::anyhow!("cuBLASLt matmul is missing required A scale pointer")
|
||||
})?;
|
||||
set_scalar_scale_pointer(
|
||||
self.resources.matmul_desc,
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
|
||||
ptr,
|
||||
)?;
|
||||
}
|
||||
if cuda_dtype_needs_tensorwide_scale(self.spec.b.dtype) {
|
||||
let ptr = ptrs.b_scale.or(self.default_b_scale_ptr).ok_or_else(|| {
|
||||
anyhow::anyhow!("cuBLASLt matmul is missing required B scale pointer")
|
||||
})?;
|
||||
set_scalar_scale_pointer(
|
||||
self.resources.matmul_desc,
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
|
||||
ptr,
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn enqueue(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
ptrs: LtMatmulPointers,
|
||||
) -> anyhow::Result<()> {
|
||||
self.update_descriptor_pointers(stream, ptrs)?;
|
||||
let alpha_ptr = self.spec.compute.alpha.as_ptr();
|
||||
let beta_ptr = self.spec.compute.beta.as_ptr();
|
||||
unsafe {
|
||||
cublasLtMatmul(
|
||||
*self.cublaslt.handle(),
|
||||
self.resources.matmul_desc,
|
||||
alpha_ptr,
|
||||
ptrs.a as *const std::ffi::c_void,
|
||||
self.resources.a_desc,
|
||||
ptrs.b as *const std::ffi::c_void,
|
||||
self.resources.b_desc,
|
||||
beta_ptr,
|
||||
ptrs.c as *const std::ffi::c_void,
|
||||
self.resources.c_desc,
|
||||
ptrs.d as *mut std::ffi::c_void,
|
||||
self.resources.d_desc,
|
||||
&self.heuristic.algo,
|
||||
self.workspace_ptr as *mut std::ffi::c_void,
|
||||
self.spec.workspace_size,
|
||||
stream.cu_stream() as *mut _,
|
||||
)
|
||||
.result()?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub(crate) struct CuBlasLtCaptureSignature {
|
||||
pub(crate) spec: LtMatmulSpec,
|
||||
pub(crate) ptrs: LtMatmulPointers,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub(crate) struct CuBlasLtPrepareKey {
|
||||
spec: LtMatmulSpec,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub(crate) struct CuBlasLtResolvedGraphCall {
|
||||
pub(crate) spec: LtMatmulSpec,
|
||||
pub(crate) ptrs: LtMatmulPointers,
|
||||
}
|
||||
|
||||
impl CuBlasLtResolvedGraphCall {
|
||||
pub(crate) fn signature(self) -> CuBlasLtCaptureSignature {
|
||||
CuBlasLtCaptureSignature {
|
||||
spec: self.spec,
|
||||
ptrs: self.ptrs,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn prepare_key(self) -> CuBlasLtPrepareKey {
|
||||
CuBlasLtPrepareKey { spec: self.spec }
|
||||
}
|
||||
}
|
||||
|
||||
fn create_matrix_layout(
|
||||
desc: &mut cublasLtMatrixLayout_t,
|
||||
spec: LtMatrixSpec,
|
||||
@@ -772,12 +961,15 @@ fn set_scalar_scale_pointer(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_cublaslt_matmul(
|
||||
pub(crate) fn prepare_cublaslt_matmul(
|
||||
stream: &Arc<CudaStream>,
|
||||
cublaslt: &Arc<CudaBlasLT>,
|
||||
spec: &LtMatmulSpec,
|
||||
ptrs: LtMatmulPointers,
|
||||
) -> anyhow::Result<()> {
|
||||
) -> anyhow::Result<PreparedCuBlasLtMatmul> {
|
||||
#[cfg(test)]
|
||||
CUBLASLT_PREPARE_COUNT.fetch_add(1, Ordering::SeqCst);
|
||||
|
||||
if spec.problem.m == 0 || spec.problem.n == 0 || spec.problem.k == 0 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"cuBLASLT matmul got zero-sized dimensions: m={}, n={}, k={}",
|
||||
@@ -789,17 +981,17 @@ fn run_cublaslt_matmul(
|
||||
|
||||
let mut resources = LtRawDescriptors::default();
|
||||
let mut heuristic: cublasLtMatmulHeuristicResult_t = unsafe { std::mem::zeroed() };
|
||||
let mut algo_count: i32 = 0;
|
||||
|
||||
let workspace = unsafe { stream.alloc::<u8>(spec.workspace_size)? };
|
||||
let (workspace_ptr, _workspace_guard) = workspace.device_ptr(stream);
|
||||
let (workspace_ptr, workspace_guard) = workspace.device_ptr(stream);
|
||||
drop(workspace_guard);
|
||||
|
||||
let a_scale = if cuda_dtype_needs_tensorwide_scale(spec.a.dtype) && ptrs.a_scale.is_none() {
|
||||
let a_scale = if cuda_dtype_needs_tensorwide_scale(spec.a.dtype) {
|
||||
Some(stream.clone_htod(&[1.0f32])?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let b_scale = if cuda_dtype_needs_tensorwide_scale(spec.b.dtype) && ptrs.b_scale.is_none() {
|
||||
let b_scale = if cuda_dtype_needs_tensorwide_scale(spec.b.dtype) {
|
||||
Some(stream.clone_htod(&[1.0f32])?)
|
||||
} else {
|
||||
None
|
||||
@@ -855,29 +1047,27 @@ fn run_cublaslt_matmul(
|
||||
}
|
||||
}
|
||||
|
||||
let (a_scale_ptr, _a_scale_guard) = if let Some(ptr) = ptrs.a_scale {
|
||||
(Some(ptr), None)
|
||||
} else if let Some(scale) = &a_scale {
|
||||
let (default_a_scale_ptr, a_scale_guard) = if let Some(scale) = &a_scale {
|
||||
let (ptr, guard) = scale.device_ptr(stream);
|
||||
(Some(ptr), Some(guard))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let (b_scale_ptr, _b_scale_guard) = if let Some(ptr) = ptrs.b_scale {
|
||||
(Some(ptr), None)
|
||||
} else if let Some(scale) = &b_scale {
|
||||
let a_scale_ptr = ptrs.a_scale.or(default_a_scale_ptr);
|
||||
let (default_b_scale_ptr, b_scale_guard) = if let Some(scale) = &b_scale {
|
||||
let (ptr, guard) = scale.device_ptr(stream);
|
||||
(Some(ptr), Some(guard))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let (c_scale_ptr, _c_scale_guard) = if let Some(scale) = &c_scale {
|
||||
let b_scale_ptr = ptrs.b_scale.or(default_b_scale_ptr);
|
||||
let (c_scale_ptr, c_scale_guard) = if let Some(scale) = &c_scale {
|
||||
let (ptr, guard) = scale.device_ptr(stream);
|
||||
(Some(ptr), Some(guard))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let (d_scale_ptr, _d_scale_guard) = if let Some(scale) = &d_scale {
|
||||
let (d_scale_ptr, d_scale_guard) = if let Some(scale) = &d_scale {
|
||||
let (ptr, guard) = scale.device_ptr(stream);
|
||||
(Some(ptr), Some(guard))
|
||||
} else {
|
||||
@@ -911,6 +1101,7 @@ fn run_cublaslt_matmul(
|
||||
ptr,
|
||||
)?;
|
||||
}
|
||||
drop((a_scale_guard, b_scale_guard, c_scale_guard, d_scale_guard));
|
||||
|
||||
create_matrix_layout(&mut resources.a_desc, spec.a)?;
|
||||
create_matrix_layout(&mut resources.b_desc, spec.b)?;
|
||||
@@ -928,58 +1119,148 @@ fn run_cublaslt_matmul(
|
||||
}
|
||||
}
|
||||
|
||||
unsafe {
|
||||
cublasLtMatmulPreferenceCreate(&mut resources.preference).result()?;
|
||||
cublasLtMatmulPreferenceSetAttribute(
|
||||
resources.preference,
|
||||
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&spec.workspace_size as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<usize>(),
|
||||
)
|
||||
.result()?;
|
||||
let heuristic_cache = CUBLASLT_HEURISTIC_CACHE.get_or_init(|| Mutex::new(Vec::new()));
|
||||
let cached_heuristic = {
|
||||
let cache = heuristic_cache.lock().unwrap();
|
||||
cache
|
||||
.iter()
|
||||
.find(|(cached_spec, _)| cached_spec == spec)
|
||||
.map(|(_, heuristic)| unsafe { std::ptr::read(heuristic) })
|
||||
};
|
||||
if let Some(cached) = cached_heuristic {
|
||||
heuristic = cached;
|
||||
} else {
|
||||
let mut algo_count: i32 = 0;
|
||||
unsafe {
|
||||
cublasLtMatmulPreferenceCreate(&mut resources.preference).result()?;
|
||||
cublasLtMatmulPreferenceSetAttribute(
|
||||
resources.preference,
|
||||
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&spec.workspace_size as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<usize>(),
|
||||
)
|
||||
.result()?;
|
||||
|
||||
cublasLtMatmulAlgoGetHeuristic(
|
||||
*cublaslt.handle(),
|
||||
resources.matmul_desc,
|
||||
resources.a_desc,
|
||||
resources.b_desc,
|
||||
resources.c_desc,
|
||||
resources.d_desc,
|
||||
resources.preference,
|
||||
1,
|
||||
&mut heuristic,
|
||||
&mut algo_count,
|
||||
)
|
||||
.result()?;
|
||||
cublasLtMatmulAlgoGetHeuristic(
|
||||
*cublaslt.handle(),
|
||||
resources.matmul_desc,
|
||||
resources.a_desc,
|
||||
resources.b_desc,
|
||||
resources.c_desc,
|
||||
resources.d_desc,
|
||||
resources.preference,
|
||||
1,
|
||||
&mut heuristic,
|
||||
&mut algo_count,
|
||||
)
|
||||
.result()?;
|
||||
|
||||
if algo_count == 0 {
|
||||
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
|
||||
if algo_count == 0 {
|
||||
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
|
||||
}
|
||||
}
|
||||
|
||||
let alpha_ptr = spec.compute.alpha.as_ptr();
|
||||
let beta_ptr = spec.compute.beta.as_ptr();
|
||||
cublasLtMatmul(
|
||||
*cublaslt.handle(),
|
||||
resources.matmul_desc,
|
||||
alpha_ptr,
|
||||
ptrs.a as *const std::ffi::c_void,
|
||||
resources.a_desc,
|
||||
ptrs.b as *const std::ffi::c_void,
|
||||
resources.b_desc,
|
||||
beta_ptr,
|
||||
ptrs.c as *const std::ffi::c_void,
|
||||
resources.c_desc,
|
||||
ptrs.d as *mut std::ffi::c_void,
|
||||
resources.d_desc,
|
||||
&heuristic.algo,
|
||||
workspace_ptr as *mut std::ffi::c_void,
|
||||
spec.workspace_size,
|
||||
stream.cu_stream() as *mut _,
|
||||
)
|
||||
.result()?;
|
||||
heuristic_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.push((*spec, unsafe { std::ptr::read(&heuristic) }));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok(PreparedCuBlasLtMatmul {
|
||||
cublaslt: cublaslt.clone(),
|
||||
spec: *spec,
|
||||
resources,
|
||||
heuristic,
|
||||
_workspace: workspace,
|
||||
workspace_ptr,
|
||||
_a_scale: a_scale,
|
||||
default_a_scale_ptr,
|
||||
_b_scale: b_scale,
|
||||
default_b_scale_ptr,
|
||||
_c_scale: c_scale,
|
||||
_d_scale: d_scale,
|
||||
})
|
||||
}
|
||||
|
||||
fn run_cublaslt_matmul(
|
||||
stream: &Arc<CudaStream>,
|
||||
cublaslt: &Arc<CudaBlasLT>,
|
||||
spec: &LtMatmulSpec,
|
||||
ptrs: LtMatmulPointers,
|
||||
) -> anyhow::Result<()> {
|
||||
let prepared = prepare_cublaslt_matmul(stream, cublaslt, spec, ptrs)?;
|
||||
prepared.enqueue(stream, ptrs)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_graph_capture_supported(stream: &Arc<CudaStream>) -> bool {
|
||||
fn probe(stream: &Arc<CudaStream>) -> anyhow::Result<()> {
|
||||
let capture_stream = stream.context().new_stream()?;
|
||||
let cublaslt = try_create_cublaslt(stream.clone())
|
||||
.map_err(|message| anyhow::anyhow!("cuBLASLt unavailable: {message}"))?;
|
||||
|
||||
let a_buf = stream.clone_htod(&[1.0f32])?;
|
||||
let b_buf = stream.clone_htod(&[1.0f32])?;
|
||||
let d_buf = unsafe { stream.alloc::<f32>(1)? };
|
||||
let (a, a_guard) = a_buf.device_ptr(stream);
|
||||
let (b, b_guard) = b_buf.device_ptr(stream);
|
||||
let (d, d_guard) = d_buf.device_ptr(stream);
|
||||
drop((a_guard, b_guard, d_guard));
|
||||
|
||||
let matrix = LtMatrixSpec {
|
||||
dtype: cudaDataType::CUDA_R_32F,
|
||||
rows: 1,
|
||||
cols: 1,
|
||||
ld: 1,
|
||||
batch_stride: 1,
|
||||
order: cublasLtOrder_t::CUBLASLT_ORDER_ROW,
|
||||
};
|
||||
let spec = LtMatmulSpec {
|
||||
problem: LtMatmulProblem {
|
||||
m: 1,
|
||||
n: 1,
|
||||
k: 1,
|
||||
batch_count: 1,
|
||||
},
|
||||
trans_a: cublasOperation_t::CUBLAS_OP_N,
|
||||
trans_b: cublasOperation_t::CUBLAS_OP_N,
|
||||
a: matrix,
|
||||
b: matrix,
|
||||
c: matrix,
|
||||
d: matrix,
|
||||
compute: LtComputeSpec {
|
||||
compute_type: cublasComputeType_t::CUBLAS_COMPUTE_32F,
|
||||
scale_dtype: cudaDataType::CUDA_R_32F,
|
||||
alpha: LtScalar::F32(1.0),
|
||||
beta: LtScalar::F32(0.0),
|
||||
epilogue: cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
|
||||
},
|
||||
workspace_size: 1024 * 1024,
|
||||
};
|
||||
let ptrs = LtMatmulPointers {
|
||||
a,
|
||||
b,
|
||||
c: d,
|
||||
d,
|
||||
bias: None,
|
||||
a_scale: None,
|
||||
b_scale: None,
|
||||
};
|
||||
let prepared = prepare_cublaslt_matmul(stream, &cublaslt, &spec, ptrs)?;
|
||||
|
||||
let mut graph = CudaGraphHandle::new(stream.context().clone())?;
|
||||
let entry = graph.add_empty_node(&[])?;
|
||||
capture_stream.join(stream)?;
|
||||
graph.begin_capture_to_graph(&capture_stream, &[entry])?;
|
||||
let enqueue_result = prepared.enqueue(&capture_stream, ptrs);
|
||||
let end_result = graph.end_capture(&capture_stream);
|
||||
enqueue_result?;
|
||||
end_result?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
let supported = probe(stream).is_ok();
|
||||
let _ = stream.synchronize();
|
||||
supported
|
||||
}
|
||||
|
||||
fn resolve_cublaslt_pointers(
|
||||
@@ -1102,6 +1383,151 @@ impl CuBlasLt {
|
||||
Ok(created)
|
||||
}
|
||||
|
||||
pub(crate) fn graph_inputs(&self) -> usize {
|
||||
self.n_inputs()
|
||||
}
|
||||
|
||||
pub(crate) fn resolve_for_graph(
|
||||
&self,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<CuBlasLtResolvedGraphCall> {
|
||||
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
|
||||
let m = resolve(&self.m).exec(dyn_map).unwrap() as u64;
|
||||
let n = resolve(&self.n).exec(dyn_map).unwrap() as u64;
|
||||
let k = resolve(&self.k).exec(dyn_map).unwrap() as u64;
|
||||
let a_layout = self.a_layout;
|
||||
let b_layout = self.b_layout;
|
||||
let lda = resolve(&self.lda).exec(dyn_map).unwrap() as i64;
|
||||
let ldb = resolve(&self.ldb).exec(dyn_map).unwrap() as i64;
|
||||
let ldc = resolve(&self.ldc).exec(dyn_map).unwrap() as i64;
|
||||
let ldd = resolve(&self.ldd).exec(dyn_map).unwrap() as i64;
|
||||
let batch_count = resolve(&self.batch_count).exec(dyn_map).unwrap() as i32;
|
||||
let stride_a = resolve(&self.stride_a).exec(dyn_map).unwrap() as i64;
|
||||
let stride_b = resolve(&self.stride_b).exec(dyn_map).unwrap() as i64;
|
||||
let stride_c = resolve(&self.stride_c).exec(dyn_map).unwrap() as i64;
|
||||
let stride_d = resolve(&self.stride_d).exec(dyn_map).unwrap() as i64;
|
||||
|
||||
let a_cuda_dtype = dtype_to_cuda_dtype(self.a_dtype);
|
||||
let b_cuda_dtype = dtype_to_cuda_dtype(self.b_dtype);
|
||||
let c_cuda_dtype = dtype_to_cuda_dtype(self.c_dtype);
|
||||
let d_cuda_dtype = dtype_to_cuda_dtype(self.d_dtype);
|
||||
let scale_cuda_dtype = dtype_to_cuda_dtype(self.scale_dtype);
|
||||
let element_size = (self.d_dtype.bits() / 8) as u64;
|
||||
assert!(
|
||||
element_size > 0,
|
||||
"cuBLAS LT does not support sub-byte dtype {}",
|
||||
self.d_dtype
|
||||
);
|
||||
|
||||
let alpha = LtScalar::from_f64(self.scale_dtype, self.alpha)?;
|
||||
let beta = LtScalar::from_f64(self.scale_dtype, self.beta)?;
|
||||
|
||||
let ptrs = resolve_cublaslt_pointers(
|
||||
self_node,
|
||||
inputs,
|
||||
buffers,
|
||||
self.beta,
|
||||
self.epilogue,
|
||||
self.a_scale_input,
|
||||
self.b_scale_input,
|
||||
)?;
|
||||
|
||||
let (a_rows, a_cols) = if a_layout == cublasOperation_t::CUBLAS_OP_N {
|
||||
(m, k)
|
||||
} else {
|
||||
(k, m)
|
||||
};
|
||||
let (b_rows, b_cols) = if b_layout == cublasOperation_t::CUBLAS_OP_N {
|
||||
(k, n)
|
||||
} else {
|
||||
(n, k)
|
||||
};
|
||||
let lda = clamp_ld_for_order(lda, a_rows, a_cols, self.a_order);
|
||||
let ldb = clamp_ld_for_order(ldb, b_rows, b_cols, self.b_order);
|
||||
let ldc = clamp_ld_for_order(ldc, m, n, self.c_order);
|
||||
let ldd = clamp_ld_for_order(ldd, m, n, self.d_order);
|
||||
|
||||
let _span = span!(
|
||||
Level::TRACE,
|
||||
"cuBLASLT_resolve_graph",
|
||||
m, n, k, lda, ldb, ldc, ldd, batch_count, ?a_layout, ?b_layout,
|
||||
?self.a_order, ?self.b_order, ?self.c_order, ?self.d_order,
|
||||
?self.a_dtype, ?self.b_dtype, ?self.c_dtype, ?self.d_dtype,
|
||||
?self.compute_type, ?self.scale_dtype, self.alpha, self.beta,
|
||||
?self.epilogue,
|
||||
)
|
||||
.entered();
|
||||
|
||||
const WORKSPACE_SIZE: usize = 32 * 1024 * 1024;
|
||||
let c_spec = LtMatrixSpec {
|
||||
dtype: c_cuda_dtype,
|
||||
rows: m,
|
||||
cols: n,
|
||||
ld: ldc,
|
||||
batch_stride: stride_c,
|
||||
order: self.c_order,
|
||||
};
|
||||
let d_spec = LtMatrixSpec {
|
||||
dtype: d_cuda_dtype,
|
||||
rows: m,
|
||||
cols: n,
|
||||
ld: ldd,
|
||||
batch_stride: stride_d,
|
||||
order: self.d_order,
|
||||
};
|
||||
let spec = LtMatmulSpec {
|
||||
problem: LtMatmulProblem {
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
batch_count,
|
||||
},
|
||||
trans_a: a_layout,
|
||||
trans_b: b_layout,
|
||||
a: LtMatrixSpec {
|
||||
dtype: a_cuda_dtype,
|
||||
rows: a_rows,
|
||||
cols: a_cols,
|
||||
ld: lda,
|
||||
batch_stride: stride_a,
|
||||
order: self.a_order,
|
||||
},
|
||||
b: LtMatrixSpec {
|
||||
dtype: b_cuda_dtype,
|
||||
rows: b_rows,
|
||||
cols: b_cols,
|
||||
ld: ldb,
|
||||
batch_stride: stride_b,
|
||||
order: self.b_order,
|
||||
},
|
||||
c: c_spec,
|
||||
d: d_spec,
|
||||
compute: LtComputeSpec {
|
||||
compute_type: self.compute_type,
|
||||
scale_dtype: scale_cuda_dtype,
|
||||
alpha,
|
||||
beta,
|
||||
epilogue: self.epilogue,
|
||||
},
|
||||
workspace_size: WORKSPACE_SIZE,
|
||||
};
|
||||
|
||||
Ok(CuBlasLtResolvedGraphCall { spec, ptrs })
|
||||
}
|
||||
|
||||
pub(crate) fn prepare_resolved_for_graph(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
resolved: CuBlasLtResolvedGraphCall,
|
||||
) -> anyhow::Result<PreparedCuBlasLtMatmul> {
|
||||
let _span = span!(Level::TRACE, "cuBLASLT_prepare_graph").entered();
|
||||
let cublaslt = self.get_cublaslt(stream)?;
|
||||
prepare_cublaslt_matmul(stream, &cublaslt, &resolved.spec, resolved.ptrs)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn type_tuple(&self) -> (DType, DType, DType, DType, &'static str, DType) {
|
||||
(
|
||||
|
||||
@@ -2,13 +2,11 @@ use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use crate::cudarc::driver::{CudaStream, DriverError, result};
|
||||
use luminal::{op::EgglogOp, prelude::*};
|
||||
mod cublas;
|
||||
mod cublaslt;
|
||||
pub(crate) mod cublaslt;
|
||||
pub mod flashinfer;
|
||||
pub mod moe;
|
||||
|
||||
pub type Ops = (
|
||||
// cublas::CuBlasSgemmV2,
|
||||
cublaslt::CuBlasLt,
|
||||
cublaslt::CuBlasLtScaled,
|
||||
moe::GLUMoE,
|
||||
@@ -169,6 +167,15 @@ pub trait HostOp: Debug + as_any::AsAny + EgglogOp {
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns pairs of extra buffer nodes that must not share arena storage.
|
||||
///
|
||||
/// This refines `extra_buffer_lifetimes` for host ops with internal DAGs:
|
||||
/// two buffers may have disjoint positions in one topological order while
|
||||
/// still being unordered by real dependencies, so CUDA could overlap them.
|
||||
fn extra_buffer_conflicts(&self) -> Option<Vec<(NodeIndex, NodeIndex)>> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns buffer size requirements for extra nodes (node -> size in elements).
|
||||
///
|
||||
/// Called during buffer allocation to ensure all required buffers exist.
|
||||
|
||||
@@ -1,60 +1,533 @@
|
||||
//! Direct conv2d_bias kernel — fuses unfold + matmul + bias into one
|
||||
//! CUDA kernel with no `(H_out*W_out, C_in*K*K)` intermediate matrix.
|
||||
//! CUDA conv2d-with-bias backend rewrite.
|
||||
//!
|
||||
//! This is exposed as a luminal `CustomOp`, not a standard egglog-rewritten
|
||||
//! `KernelOp`, because the conv has no useful fusion opportunities with
|
||||
//! surrounding ops in the graphs it's used in (the VAE's resnet blocks),
|
||||
//! and pattern-matching the unfold+permute+merge_dims+matmul+bias chain
|
||||
//! reliably from egglog is significantly more work than just bypassing
|
||||
//! the egglog rewrite path entirely.
|
||||
//!
|
||||
//! The kernel is one-thread-per-output: each thread computes
|
||||
//! `out[co, ho, wo] = bias[co] + sum_{ci,ki,kj} input[ci, ho*S+ki-P, wo*S+kj-P] * weight[co, ci, ki, kj]`
|
||||
//! with bounds checks on the spatial dims for padding. This is far from
|
||||
//! peak FLOPs (no shared-memory tiling, no warp-level reduction over K)
|
||||
//! but it's correct and the memory footprint is just the input + weight +
|
||||
//! bias + output buffers — no `(M, K)` or `(M, N, K)` intermediate, so it
|
||||
//! scales linearly with the actual conv FLOPs rather than blowing up at
|
||||
//! large H/W like the unfold-based formulation.
|
||||
//! `KernelConv2D` is selected by egglog from pure HLIR conv graphs and lowers
|
||||
//! to a one-thread-per-output CUDA kernel. It avoids materializing unfold/im2col
|
||||
//! intermediates while keeping model code free of custom ops.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::prelude::FxHashMap;
|
||||
use luminal::{
|
||||
dtype::DType, graph::Graph, op::CustomOp, op::LLIROp, prelude::GraphTensor, shape::Expression,
|
||||
dtype::DType,
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, EXPRESSION, OP_KIND},
|
||||
extract_dtype, extract_expr, extract_expr_list,
|
||||
},
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::FxHashSet,
|
||||
shape::{Expression, flatten_strides},
|
||||
};
|
||||
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use crate::kernel::KernelOp;
|
||||
use crate::kernel::{KernelOp, hlir::generate_dyn_dims_defines};
|
||||
|
||||
/// Direct conv2d-with-bias kernel. All shape/kernel params are static
|
||||
/// (baked into the CUDA source via #defines), so each conv shape gets
|
||||
/// its own compiled kernel. Inputs (in order): input `(C_in, H_in, W_in)`,
|
||||
/// weight `(C_out, C_in*K*K)` (i.e. flattened `(C_out, C_in, K, K)`), bias
|
||||
/// `(C_out,)`. Output: `(C_out, H_out, W_out)`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Conv2DKernel {
|
||||
pub c_in: usize,
|
||||
pub h_in: usize,
|
||||
pub w_in: usize,
|
||||
pub c_out: usize,
|
||||
pub kernel: usize,
|
||||
pub stride: usize,
|
||||
pub padding: usize,
|
||||
pub h_out: usize,
|
||||
pub w_out: usize,
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelConv2D {
|
||||
out_shape: Vec<Expression>,
|
||||
input_shape: Vec<Expression>,
|
||||
input_stride: Vec<Expression>,
|
||||
weight_co_stride: Expression,
|
||||
weight_inner_stride: Expression,
|
||||
bias_c_stride: Expression,
|
||||
out_stride: Vec<Expression>,
|
||||
kernel_h: Expression,
|
||||
kernel_w: Expression,
|
||||
stride_h: Expression,
|
||||
stride_w: Expression,
|
||||
dilation_h: Expression,
|
||||
dilation_w: Expression,
|
||||
pad_h: Expression,
|
||||
pad_w: Expression,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl Conv2DKernel {
|
||||
fn output_elements(&self) -> usize {
|
||||
self.c_out * self.h_out * self.w_out
|
||||
impl EgglogOp for KernelConv2D {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelConv2D",
|
||||
&[
|
||||
("out_shape", ELIST),
|
||||
("input_shape", ELIST),
|
||||
("input_stride", ELIST),
|
||||
("weight_co_stride", EXPRESSION),
|
||||
("weight_inner_stride", EXPRESSION),
|
||||
("bias_c_stride", EXPRESSION),
|
||||
("out_stride", ELIST),
|
||||
("kernel_h", EXPRESSION),
|
||||
("kernel_w", EXPRESSION),
|
||||
("stride_h", EXPRESSION),
|
||||
("stride_w", EXPRESSION),
|
||||
("dilation_h", EXPRESSION),
|
||||
("dilation_w", EXPRESSION),
|
||||
("pad_h", EXPRESSION),
|
||||
("pad_w", EXPRESSION),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
// 1x1 convs in Flux2's VAE are represented without `unfold`:
|
||||
//
|
||||
// input.permute([H,W,C]).merge(H,W)
|
||||
// -> matmul(weight.t())
|
||||
// -> split/permute back to [C_out,H,W]
|
||||
// -> + channel bias
|
||||
//
|
||||
// The lowered form is still the same Mul -> KernelSum -> Add
|
||||
// matmul skeleton, but the lhs FusionStart reads directly from the
|
||||
// original input instead of a KernelGather window tensor.
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?out (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(= ?add_elem (Op (CudaBinaryElementwise \"Add\" ?out_shape ?sum_add_stride ?bias_add_stride ?out_stride (F32)) (ICons ?sum_fs (ICons ?bias_fs (INil)))))
|
||||
(= ?sum_fs (Op (FusionStart ?out_shape ?sum_add_stride (F32)) (ICons ?sum (INil))))
|
||||
(= ?bias_fs (Op (FusionStart ?out_shape ?bias_add_stride (F32)) (ICons ?bias (INil))))
|
||||
|
||||
(= ?sum (Op (KernelSum ?matmul_out_shape ?c_in ?sum_in_stride ?k_stride ?sum_out_stride (F32)) (ICons ?mul_fe (INil))))
|
||||
(= ?mul_fe (Op (FusionEnd ?mul_shape ?mul_out_stride (F32)) (ICons ?mul_elem (INil))))
|
||||
(= ?mul_elem (Op (CudaBinaryElementwise \"Mul\" ?mul_shape ?input_1x1_stride ?weight_stride ?mul_out_stride (F32)) (ICons ?input_fs (ICons ?weight_fs (INil)))))
|
||||
(= ?input_fs (Op (FusionStart ?mul_shape ?input_1x1_stride (F32)) (ICons ?input (INil))))
|
||||
(= ?weight_fs (Op (FusionStart ?mul_shape ?weight_stride (F32)) (ICons ?weight (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?c_out (ECons ?h_out (ECons ?w_out (ENil)))))
|
||||
(= ?matmul_out_shape (ECons ?m (ECons ?c_out (ENil))))
|
||||
(= ?mul_shape (ECons ?m (ECons ?c_out (ECons ?c_in (ENil)))))
|
||||
(= ?input_1x1_stride (ECons ?flat_stride (ECons (MNum 0) (ECons ?input_c_stride (ENil)))))
|
||||
(= ?flat_stride (MIter))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
(= ?sum_in_stride (ECons ?sum_m_stride (ECons ?sum_c_stride (ENil))))
|
||||
(= ?sum_out_stride (ECons ?sum_out_m_stride (ECons ?sum_out_c_stride (ENil))))
|
||||
(= ?sum_add_stride (ECons ?sum_add_c_stride (ECons ?sum_add_h_stride (ECons ?sum_add_w_stride (ENil)))))
|
||||
(= ?weight_co_stride (nth_from_end ?weight_stride 1))
|
||||
(= ?weight_inner_stride (nth_from_end ?weight_stride 0))
|
||||
(= (MNum 0) (nth_from_end ?weight_stride 2))
|
||||
(= ?bias_add_stride (ECons ?bias_c_stride (ECons (MNum 0) (ECons (MNum 0) (ENil)))))
|
||||
)
|
||||
(
|
||||
(let ?conv (Op (KernelConv2D
|
||||
?out_shape
|
||||
(ECons ?c_in (ECons ?h_out (ECons ?w_out (ENil))))
|
||||
(ECons ?input_c_stride (ECons (MMul ?w_out ?flat_stride) (ECons ?flat_stride (ENil))))
|
||||
?weight_co_stride
|
||||
?weight_inner_stride
|
||||
?bias_c_stride
|
||||
?out_stride
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F32))
|
||||
(ICons ?input (ICons ?weight (ICons ?bias (INil))))))
|
||||
(union ?out ?conv)
|
||||
(subsume (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(set (dtype ?conv) (F32))
|
||||
)
|
||||
:ruleset kernel_lower
|
||||
:name \"kernel conv2d 1x1 from cuda lowered matmul bias\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?out (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(= ?add_elem (Op (CudaBinaryElementwise \"Add\" ?out_shape ?bias_add_stride ?sum_add_stride ?out_stride (F32)) (ICons ?bias_fs (ICons ?sum_fs (INil)))))
|
||||
(= ?sum_fs (Op (FusionStart ?out_shape ?sum_add_stride (F32)) (ICons ?sum (INil))))
|
||||
(= ?bias_fs (Op (FusionStart ?out_shape ?bias_add_stride (F32)) (ICons ?bias (INil))))
|
||||
|
||||
(= ?sum (Op (KernelSum ?matmul_out_shape ?c_in ?sum_in_stride ?k_stride ?sum_out_stride (F32)) (ICons ?mul_fe (INil))))
|
||||
(= ?mul_fe (Op (FusionEnd ?mul_shape ?mul_out_stride (F32)) (ICons ?mul_elem (INil))))
|
||||
(= ?mul_elem (Op (CudaBinaryElementwise \"Mul\" ?mul_shape ?input_1x1_stride ?weight_stride ?mul_out_stride (F32)) (ICons ?input_fs (ICons ?weight_fs (INil)))))
|
||||
(= ?input_fs (Op (FusionStart ?mul_shape ?input_1x1_stride (F32)) (ICons ?input (INil))))
|
||||
(= ?weight_fs (Op (FusionStart ?mul_shape ?weight_stride (F32)) (ICons ?weight (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?c_out (ECons ?h_out (ECons ?w_out (ENil)))))
|
||||
(= ?matmul_out_shape (ECons ?m (ECons ?c_out (ENil))))
|
||||
(= ?mul_shape (ECons ?m (ECons ?c_out (ECons ?c_in (ENil)))))
|
||||
(= ?input_1x1_stride (ECons ?flat_stride (ECons (MNum 0) (ECons ?input_c_stride (ENil)))))
|
||||
(= ?flat_stride (MIter))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
(= ?sum_in_stride (ECons ?sum_m_stride (ECons ?sum_c_stride (ENil))))
|
||||
(= ?sum_out_stride (ECons ?sum_out_m_stride (ECons ?sum_out_c_stride (ENil))))
|
||||
(= ?sum_add_stride (ECons ?sum_add_c_stride (ECons ?sum_add_h_stride (ECons ?sum_add_w_stride (ENil)))))
|
||||
(= ?weight_co_stride (nth_from_end ?weight_stride 1))
|
||||
(= ?weight_inner_stride (nth_from_end ?weight_stride 0))
|
||||
(= (MNum 0) (nth_from_end ?weight_stride 2))
|
||||
(= ?bias_add_stride (ECons ?bias_c_stride (ECons (MNum 0) (ECons (MNum 0) (ENil)))))
|
||||
)
|
||||
(
|
||||
(let ?conv (Op (KernelConv2D
|
||||
?out_shape
|
||||
(ECons ?c_in (ECons ?h_out (ECons ?w_out (ENil))))
|
||||
(ECons ?input_c_stride (ECons (MMul ?w_out ?flat_stride) (ECons ?flat_stride (ENil))))
|
||||
?weight_co_stride
|
||||
?weight_inner_stride
|
||||
?bias_c_stride
|
||||
?out_stride
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F32))
|
||||
(ICons ?input (ICons ?weight (ICons ?bias (INil))))))
|
||||
(union ?out ?conv)
|
||||
(subsume (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(set (dtype ?conv) (F32))
|
||||
)
|
||||
:ruleset kernel_lower
|
||||
:name \"kernel conv2d 1x1 from cuda lowered bias matmul\"
|
||||
)",
|
||||
),
|
||||
// Match the same conv after generic CUDA lowering has normalized
|
||||
// the elementwise pieces into fusion regions:
|
||||
//
|
||||
// KernelGather(input windows)
|
||||
// -> CudaBinaryElementwise("Mul", weight)
|
||||
// -> KernelSum(reduce K)
|
||||
// -> CudaBinaryElementwise("Add", bias)
|
||||
//
|
||||
// This is the form that survives long enough for CUDA search in
|
||||
// real models. The KernelConv2D op consumes the pre-gather input
|
||||
// and avoids materializing both the im2col window tensor and the
|
||||
// elementwise product tensor.
|
||||
//
|
||||
// TODO(egglog-shapes): the current e-graph does not reliably prove
|
||||
// the derived arithmetic equalities for this chain after CUDA
|
||||
// normalization:
|
||||
// * `M == H_out * W_out`
|
||||
// * `K == C_in * KH * KW`
|
||||
// * separately-derived but structurally identical stride
|
||||
// expressions, e.g. the Mul output stride and KernelSum input
|
||||
// stride, belong to the same e-class.
|
||||
// Keep the rewrite anchored on the stable conv layout facts the
|
||||
// graph does carry today: six-axis unfold window shape, flattened
|
||||
// `[M, C_out, K]` product, reduction over `K`, the three-axis
|
||||
// `[C_out, H_out, W_out]` output view, and channel-only bias
|
||||
// broadcast. Once expression/list canonicalization can prove those
|
||||
// equalities, tighten this rule and its regression tests.
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?out (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(= ?add_elem (Op (CudaBinaryElementwise \"Add\" ?out_shape ?sum_add_stride ?bias_add_stride ?out_stride (F32)) (ICons ?sum_fs (ICons ?bias_fs (INil)))))
|
||||
(= ?sum_fs (Op (FusionStart ?out_shape ?sum_add_stride (F32)) (ICons ?sum (INil))))
|
||||
(= ?bias_fs (Op (FusionStart ?out_shape ?bias_add_stride (F32)) (ICons ?bias (INil))))
|
||||
|
||||
(= ?sum (Op (KernelSum ?matmul_out_shape ?k_dim ?sum_in_stride ?k_stride ?sum_out_stride (F32)) (ICons ?mul_fe (INil))))
|
||||
(= ?mul_fe (Op (FusionEnd ?mul_shape ?mul_out_stride (F32)) (ICons ?mul_elem (INil))))
|
||||
(= ?mul_elem (Op (CudaBinaryElementwise \"Mul\" ?mul_shape ?patch_stride ?weight_stride ?mul_out_stride (F32)) (ICons ?patch_fs (ICons ?weight_fs (INil)))))
|
||||
(= ?patch_fs (Op (FusionStart ?mul_shape ?patch_stride (F32)) (ICons ?patches (INil))))
|
||||
(= ?weight_fs (Op (FusionStart ?mul_shape ?weight_stride (F32)) (ICons ?weight (INil))))
|
||||
(= ?patches (Op (KernelGather ?idx_shape ?idx_stride ?input_shape ?input_stride ?gather_out_stride (F32)) (ICons ?indices (ICons ?input (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?c_out (ECons ?h_out (ECons ?w_out (ENil)))))
|
||||
(= ?input_shape (ECons ?c_in (ECons ?h_in (ECons ?w_in (ENil)))))
|
||||
(= ?idx_shape (ECons ?c_in (ECons ?h_out (ECons ?w_out (ECons (MNum 1) (ECons ?kernel_h (ECons ?kernel_w (ENil))))))))
|
||||
(= ?matmul_out_shape (ECons ?m (ECons ?c_out (ENil))))
|
||||
(= ?mul_shape (ECons ?m (ECons ?c_out (ECons ?k_dim (ENil)))))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
(= ?sum_in_stride (ECons ?sum_m_stride (ECons ?sum_c_stride (ENil))))
|
||||
(= ?sum_out_stride (ECons ?sum_out_m_stride (ECons ?sum_out_c_stride (ENil))))
|
||||
(= ?sum_add_stride (ECons ?sum_add_c_stride (ECons ?sum_add_h_stride (ECons ?sum_add_w_stride (ENil)))))
|
||||
(= ?weight_co_stride (nth_from_end ?weight_stride 1))
|
||||
(= ?weight_inner_stride (nth_from_end ?weight_stride 0))
|
||||
(= (MNum 0) (nth_from_end ?weight_stride 2))
|
||||
(= ?bias_add_stride (ECons ?bias_c_stride (ECons (MNum 0) (ECons (MNum 0) (ENil)))))
|
||||
)
|
||||
(
|
||||
(let ?conv (Op (KernelConv2D
|
||||
?out_shape
|
||||
?input_shape
|
||||
?input_stride
|
||||
?weight_co_stride
|
||||
?weight_inner_stride
|
||||
?bias_c_stride
|
||||
?out_stride
|
||||
?kernel_h
|
||||
?kernel_w
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F32))
|
||||
(ICons ?input (ICons ?weight (ICons ?bias (INil))))))
|
||||
(union ?out ?conv)
|
||||
(subsume (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(set (dtype ?conv) (F32))
|
||||
)
|
||||
:ruleset kernel_lower
|
||||
:name \"kernel conv2d from cuda lowered unfold matmul bias\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?out (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(= ?add_elem (Op (CudaBinaryElementwise \"Add\" ?out_shape ?bias_add_stride ?sum_add_stride ?out_stride (F32)) (ICons ?bias_fs (ICons ?sum_fs (INil)))))
|
||||
(= ?sum_fs (Op (FusionStart ?out_shape ?sum_add_stride (F32)) (ICons ?sum (INil))))
|
||||
(= ?bias_fs (Op (FusionStart ?out_shape ?bias_add_stride (F32)) (ICons ?bias (INil))))
|
||||
|
||||
(= ?sum (Op (KernelSum ?matmul_out_shape ?k_dim ?sum_in_stride ?k_stride ?sum_out_stride (F32)) (ICons ?mul_fe (INil))))
|
||||
(= ?mul_fe (Op (FusionEnd ?mul_shape ?mul_out_stride (F32)) (ICons ?mul_elem (INil))))
|
||||
(= ?mul_elem (Op (CudaBinaryElementwise \"Mul\" ?mul_shape ?patch_stride ?weight_stride ?mul_out_stride (F32)) (ICons ?patch_fs (ICons ?weight_fs (INil)))))
|
||||
(= ?patch_fs (Op (FusionStart ?mul_shape ?patch_stride (F32)) (ICons ?patches (INil))))
|
||||
(= ?weight_fs (Op (FusionStart ?mul_shape ?weight_stride (F32)) (ICons ?weight (INil))))
|
||||
(= ?patches (Op (KernelGather ?idx_shape ?idx_stride ?input_shape ?input_stride ?gather_out_stride (F32)) (ICons ?indices (ICons ?input (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?c_out (ECons ?h_out (ECons ?w_out (ENil)))))
|
||||
(= ?input_shape (ECons ?c_in (ECons ?h_in (ECons ?w_in (ENil)))))
|
||||
(= ?idx_shape (ECons ?c_in (ECons ?h_out (ECons ?w_out (ECons (MNum 1) (ECons ?kernel_h (ECons ?kernel_w (ENil))))))))
|
||||
(= ?matmul_out_shape (ECons ?m (ECons ?c_out (ENil))))
|
||||
(= ?mul_shape (ECons ?m (ECons ?c_out (ECons ?k_dim (ENil)))))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
(= ?sum_in_stride (ECons ?sum_m_stride (ECons ?sum_c_stride (ENil))))
|
||||
(= ?sum_out_stride (ECons ?sum_out_m_stride (ECons ?sum_out_c_stride (ENil))))
|
||||
(= ?sum_add_stride (ECons ?sum_add_c_stride (ECons ?sum_add_h_stride (ECons ?sum_add_w_stride (ENil)))))
|
||||
(= ?weight_co_stride (nth_from_end ?weight_stride 1))
|
||||
(= ?weight_inner_stride (nth_from_end ?weight_stride 0))
|
||||
(= (MNum 0) (nth_from_end ?weight_stride 2))
|
||||
(= ?bias_add_stride (ECons ?bias_c_stride (ECons (MNum 0) (ECons (MNum 0) (ENil)))))
|
||||
)
|
||||
(
|
||||
(let ?conv (Op (KernelConv2D
|
||||
?out_shape
|
||||
?input_shape
|
||||
?input_stride
|
||||
?weight_co_stride
|
||||
?weight_inner_stride
|
||||
?bias_c_stride
|
||||
?out_stride
|
||||
?kernel_h
|
||||
?kernel_w
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F32))
|
||||
(ICons ?input (ICons ?weight (ICons ?bias (INil))))))
|
||||
(union ?out ?conv)
|
||||
(subsume (Op (FusionEnd ?out_shape ?out_stride (F32)) (ICons ?add_elem (INil))))
|
||||
(set (dtype ?conv) (F32))
|
||||
)
|
||||
:ruleset kernel_lower
|
||||
:name \"kernel conv2d from cuda lowered bias unfold matmul\"
|
||||
)",
|
||||
),
|
||||
// Match the im2col-style HLIR conv used by Flux2:
|
||||
//
|
||||
// input.unfold([1, kh, kw], [1, 1, 1], [1, 1, 1])
|
||||
// -> squeeze/permute/merge view
|
||||
// -> matmul(weight.t())
|
||||
// -> split/permute view
|
||||
// -> + bias.expand_dim(1, h_out).expand_dim(2, w_out)
|
||||
//
|
||||
// The kernel consumes the pre-unfold input directly. That input may
|
||||
// already be a padded HLIR tensor, so the rewrite is still correct
|
||||
// for Flux2's padded convs while removing the large patch matrix.
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?add (Op (Add ?out_shape ?sum_add_stride ?bias_add_stride ?add_out_stride) (ICons ?sum (ICons ?bias (INil)))))
|
||||
(= ?sum (Op (Sum ?matmul_out_shape ?k_dim ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?mul (Op (Mul ?mul_shape ?patch_stride ?weight_stride ?mul_out_stride) (ICons ?patches (ICons ?weight (INil)))))
|
||||
(= ?patches (Op (Gather ?idx_shape ?idx_stride ?input_shape ?input_stride) (ICons ?indices (ICons ?input (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?c_out (ECons ?h_out (ECons ?w_out (ENil)))))
|
||||
(= ?input_shape (ECons ?c_in (ECons ?h_in (ECons ?w_in (ENil)))))
|
||||
(= ?idx_shape (ECons ?c_in (ECons ?h_out (ECons ?w_out (ECons (MNum 1) (ECons ?kernel_h (ECons ?kernel_w (ENil))))))))
|
||||
(= ?matmul_out_shape (ECons ?m (ECons ?c_out (ENil))))
|
||||
|
||||
; This rewrite is for stride=1, dilation=1 over the
|
||||
; tensor passed to unfold. Padded HLIR inputs are already
|
||||
; represented as their own tensor, so padding is 0 here.
|
||||
(= ?h_out (MAdd (MSub ?h_in ?kernel_h) (MNum 1)))
|
||||
(= ?w_out (MAdd (MSub ?w_in ?kernel_w) (MNum 1)))
|
||||
(= ?m (MMul ?h_out ?w_out))
|
||||
(= ?k_dim (MMul ?c_in (MMul ?kernel_h ?kernel_w)))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?weight_co_stride (nth_from_end ?weight_stride 1))
|
||||
(= ?weight_inner_stride (nth_from_end ?weight_stride 0))
|
||||
(= (MNum 0) (nth_from_end ?weight_stride 2))
|
||||
|
||||
(= ?bias_add_stride (ECons ?bias_c_stride (ECons (MNum 0) (ECons (MNum 0) (ENil)))))
|
||||
|
||||
(= (F32) (dtype ?input))
|
||||
(= (F32) (dtype ?weight))
|
||||
(= (F32) (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?conv (Op (KernelConv2D
|
||||
?out_shape
|
||||
?input_shape
|
||||
?input_stride
|
||||
?weight_co_stride
|
||||
?weight_inner_stride
|
||||
?bias_c_stride
|
||||
?add_out_stride
|
||||
?kernel_h
|
||||
?kernel_w
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F32))
|
||||
(ICons ?input (ICons ?weight (ICons ?bias (INil))))))
|
||||
(union ?add ?conv)
|
||||
(subsume (Op (Add ?out_shape ?sum_add_stride ?bias_add_stride ?add_out_stride) (ICons ?sum (ICons ?bias (INil)))))
|
||||
(set (dtype ?conv) (F32))
|
||||
)
|
||||
:ruleset kernel_specialize
|
||||
:name \"kernel conv2d from unfold matmul bias\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?add (Op (Add ?out_shape ?bias_add_stride ?sum_add_stride ?add_out_stride) (ICons ?bias (ICons ?sum (INil)))))
|
||||
(= ?sum (Op (Sum ?matmul_out_shape ?k_dim ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?mul (Op (Mul ?mul_shape ?patch_stride ?weight_stride ?mul_out_stride) (ICons ?patches (ICons ?weight (INil)))))
|
||||
(= ?patches (Op (Gather ?idx_shape ?idx_stride ?input_shape ?input_stride) (ICons ?indices (ICons ?input (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?c_out (ECons ?h_out (ECons ?w_out (ENil)))))
|
||||
(= ?input_shape (ECons ?c_in (ECons ?h_in (ECons ?w_in (ENil)))))
|
||||
(= ?idx_shape (ECons ?c_in (ECons ?h_out (ECons ?w_out (ECons (MNum 1) (ECons ?kernel_h (ECons ?kernel_w (ENil))))))))
|
||||
(= ?matmul_out_shape (ECons ?m (ECons ?c_out (ENil))))
|
||||
|
||||
(= ?h_out (MAdd (MSub ?h_in ?kernel_h) (MNum 1)))
|
||||
(= ?w_out (MAdd (MSub ?w_in ?kernel_w) (MNum 1)))
|
||||
(= ?m (MMul ?h_out ?w_out))
|
||||
(= ?k_dim (MMul ?c_in (MMul ?kernel_h ?kernel_w)))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?weight_co_stride (nth_from_end ?weight_stride 1))
|
||||
(= ?weight_inner_stride (nth_from_end ?weight_stride 0))
|
||||
(= (MNum 0) (nth_from_end ?weight_stride 2))
|
||||
|
||||
(= ?bias_add_stride (ECons ?bias_c_stride (ECons (MNum 0) (ECons (MNum 0) (ENil)))))
|
||||
|
||||
(= (F32) (dtype ?input))
|
||||
(= (F32) (dtype ?weight))
|
||||
(= (F32) (dtype ?bias))
|
||||
)
|
||||
(
|
||||
(let ?conv (Op (KernelConv2D
|
||||
?out_shape
|
||||
?input_shape
|
||||
?input_stride
|
||||
?weight_co_stride
|
||||
?weight_inner_stride
|
||||
?bias_c_stride
|
||||
?add_out_stride
|
||||
?kernel_h
|
||||
?kernel_w
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(F32))
|
||||
(ICons ?input (ICons ?weight (ICons ?bias (INil))))))
|
||||
(union ?add ?conv)
|
||||
(subsume (Op (Add ?out_shape ?bias_add_stride ?sum_add_stride ?add_out_stride) (ICons ?bias (ICons ?sum (INil)))))
|
||||
(set (dtype ?conv) (F32))
|
||||
)
|
||||
:ruleset kernel_specialize
|
||||
:name \"kernel conv2d from bias unfold matmul\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?add (Op (Add ?shape ?as ?bs ?os) ?inputs))
|
||||
(= ?add (Op (KernelConv2D ?out_shape ?input_shape ?input_stride ?wco ?wi ?bc ?out_stride ?kh ?kw ?sh ?sw ?dh ?dw ?ph ?pw ?dt) ?conv_inputs))
|
||||
)
|
||||
((delete (Op (Add ?shape ?as ?bs ?os) ?inputs)))
|
||||
:ruleset cleanup
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?fe (Op (FusionEnd ?shape ?os ?dt) ?inputs))
|
||||
(= ?fe (Op (KernelConv2D ?out_shape ?input_shape ?input_stride ?wco ?wi ?bc ?out_stride ?kh ?kw ?sh ?sw ?dh ?dw ?ph ?pw ?conv_dt) ?conv_inputs))
|
||||
)
|
||||
((delete (Op (FusionEnd ?shape ?os ?dt) ?inputs)))
|
||||
:ruleset cleanup
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a luminal::egglog_utils::NodeId],
|
||||
input_enodes: Vec<&'a luminal::egglog_utils::NodeId>,
|
||||
list_cache: &mut FxHashMap<&'a luminal::egglog_utils::NodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a luminal::egglog_utils::NodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a luminal::egglog_utils::NodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
input_shape: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
input_stride: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
weight_co_stride: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
|
||||
weight_inner_stride: extract_expr(egraph, kind_children[4], expr_cache).unwrap(),
|
||||
bias_c_stride: extract_expr(egraph, kind_children[5], expr_cache).unwrap(),
|
||||
out_stride: extract_expr_list(egraph, kind_children[6], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
kernel_h: extract_expr(egraph, kind_children[7], expr_cache).unwrap(),
|
||||
kernel_w: extract_expr(egraph, kind_children[8], expr_cache).unwrap(),
|
||||
stride_h: extract_expr(egraph, kind_children[9], expr_cache).unwrap(),
|
||||
stride_w: extract_expr(egraph, kind_children[10], expr_cache).unwrap(),
|
||||
dilation_h: extract_expr(egraph, kind_children[11], expr_cache).unwrap(),
|
||||
dilation_w: extract_expr(egraph, kind_children[12], expr_cache).unwrap(),
|
||||
pad_h: extract_expr(egraph, kind_children[13], expr_cache).unwrap(),
|
||||
pad_w: extract_expr(egraph, kind_children[14], expr_cache).unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[15]),
|
||||
}) as Box<dyn KernelOp>),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const THREADS_PER_BLOCK: usize = 256;
|
||||
|
||||
impl KernelOp for Conv2DKernel {
|
||||
impl KernelOp for KernelConv2D {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
@@ -68,74 +541,135 @@ impl KernelOp for Conv2DKernel {
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let total = self.output_elements();
|
||||
let grid = total.div_ceil(THREADS_PER_BLOCK);
|
||||
assert_eq!(self.dtype, DType::F32, "KernelConv2D currently emits F32");
|
||||
|
||||
let vars: FxHashSet<char> = self
|
||||
.out_shape
|
||||
.iter()
|
||||
.chain(&self.input_shape)
|
||||
.chain(&self.input_stride)
|
||||
.chain(&self.out_stride)
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.weight_co_stride.dyn_vars())
|
||||
.chain(self.weight_inner_stride.dyn_vars())
|
||||
.chain(self.bias_c_stride.dyn_vars())
|
||||
.chain(self.kernel_h.dyn_vars())
|
||||
.chain(self.kernel_w.dyn_vars())
|
||||
.chain(self.stride_h.dyn_vars())
|
||||
.chain(self.stride_w.dyn_vars())
|
||||
.chain(self.dilation_h.dyn_vars())
|
||||
.chain(self.dilation_w.dyn_vars())
|
||||
.chain(self.pad_h.dyn_vars())
|
||||
.chain(self.pad_w.dyn_vars())
|
||||
.collect();
|
||||
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
let c_out = self.out_shape[0].to_kernel();
|
||||
let h_out = self.out_shape[1].to_kernel();
|
||||
let w_out = self.out_shape[2].to_kernel();
|
||||
let c_in = self.input_shape[0].to_kernel();
|
||||
let h_in = self.input_shape[1].to_kernel();
|
||||
let w_in = self.input_shape[2].to_kernel();
|
||||
let weight_co_stride = self
|
||||
.weight_co_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel();
|
||||
let weight_inner_stride = self
|
||||
.weight_inner_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel();
|
||||
let bias_c_stride = self
|
||||
.bias_c_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel();
|
||||
let kh = self.kernel_h.to_kernel();
|
||||
let kw = self.kernel_w.to_kernel();
|
||||
let stride_h = self.stride_h.to_kernel();
|
||||
let stride_w = self.stride_w.to_kernel();
|
||||
let dilation_h = self.dilation_h.to_kernel();
|
||||
let dilation_w = self.dilation_w.to_kernel();
|
||||
let pad_h = self.pad_h.to_kernel();
|
||||
let pad_w = self.pad_w.to_kernel();
|
||||
let out_idx = flatten_strides(&self.out_shape, &self.out_stride).to_kernel();
|
||||
let input_idx = flatten_strides(&self.input_shape, &self.input_stride)
|
||||
.to_kernel()
|
||||
.replace("const_z", "input_linear");
|
||||
let n_outputs: Expression = self.out_shape.iter().copied().product();
|
||||
|
||||
let kernel = format!(
|
||||
"
|
||||
extern \"C\" __global__ void conv2d_bias_kernel(
|
||||
float* __restrict__ out,
|
||||
const float* __restrict__ input,
|
||||
const float* __restrict__ weight,
|
||||
const float* __restrict__ bias
|
||||
) {{
|
||||
const int TOTAL = {total};
|
||||
const int CIN = {c_in};
|
||||
const int H = {h_in};
|
||||
const int W = {w_in};
|
||||
const int HOUT = {h_out};
|
||||
const int WOUT = {w_out};
|
||||
const int K = {k};
|
||||
const int S = {s};
|
||||
const int P = {p};
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void generic_conv2d_bias(
|
||||
float* __restrict__ out,
|
||||
const float* __restrict__ input,
|
||||
const float* __restrict__ weight,
|
||||
const float* __restrict__ bias{dyn_dims_param}
|
||||
) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const long long total = {total};
|
||||
if (const_z >= total) return;
|
||||
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= TOTAL) return;
|
||||
int hw = HOUT * WOUT;
|
||||
int co = idx / hw;
|
||||
int rem = idx - co * hw;
|
||||
int ho = rem / WOUT;
|
||||
int wo = rem - ho * WOUT;
|
||||
const long long COUT = {c_out};
|
||||
const long long HOUT = {h_out};
|
||||
const long long WOUT = {w_out};
|
||||
const long long CIN = {c_in};
|
||||
const long long HIN = {h_in};
|
||||
const long long WIN = {w_in};
|
||||
const long long KH = {kh};
|
||||
const long long KW = {kw};
|
||||
const long long SH = {stride_h};
|
||||
const long long SW = {stride_w};
|
||||
const long long DH = {dilation_h};
|
||||
const long long DW = {dilation_w};
|
||||
const long long PH = {pad_h};
|
||||
const long long PW = {pad_w};
|
||||
const long long W_CO_STRIDE = {weight_co_stride};
|
||||
const long long W_INNER_STRIDE = {weight_inner_stride};
|
||||
const long long BIAS_C_STRIDE = {bias_c_stride};
|
||||
|
||||
float acc = bias[co];
|
||||
int weight_co_base = co * (CIN * K * K);
|
||||
for (int ci = 0; ci < CIN; ci++) {{
|
||||
int input_ci_base = ci * (H * W);
|
||||
int weight_ci_base = weight_co_base + ci * (K * K);
|
||||
#pragma unroll
|
||||
for (int ki = 0; ki < K; ki++) {{
|
||||
int hi = ho * S + ki - P;
|
||||
if (hi < 0 || hi >= H) continue;
|
||||
int input_row_base = input_ci_base + hi * W;
|
||||
int weight_row_base = weight_ci_base + ki * K;
|
||||
#pragma unroll
|
||||
for (int kj = 0; kj < K; kj++) {{
|
||||
int wj = wo * S + kj - P;
|
||||
if (wj < 0 || wj >= W) continue;
|
||||
acc += input[input_row_base + wj] * weight[weight_row_base + kj];
|
||||
long long co = const_z / (HOUT * WOUT);
|
||||
long long rem = const_z - co * HOUT * WOUT;
|
||||
long long oh = rem / WOUT;
|
||||
long long ow = rem - oh * WOUT;
|
||||
|
||||
float acc = bias[co * BIAS_C_STRIDE];
|
||||
for (long long ci = 0; ci < CIN; ++ci) {{
|
||||
for (long long r = 0; r < KH; ++r) {{
|
||||
long long ih = oh * SH + r * DH - PH;
|
||||
if (ih < 0 || ih >= HIN) continue;
|
||||
for (long long s = 0; s < KW; ++s) {{
|
||||
long long iw = ow * SW + s * DW - PW;
|
||||
if (iw < 0 || iw >= WIN) continue;
|
||||
long long input_linear = (ci * HIN + ih) * WIN + iw;
|
||||
long long input_idx = {input_idx};
|
||||
long long inner = (ci * KH + r) * KW + s;
|
||||
long long weight_idx = co * W_CO_STRIDE + inner * W_INNER_STRIDE;
|
||||
acc += input[input_idx] * weight[weight_idx];
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
out[{out_idx}] = acc;
|
||||
}}
|
||||
out[idx] = acc;
|
||||
}}
|
||||
",
|
||||
total = total,
|
||||
c_in = self.c_in,
|
||||
h_in = self.h_in,
|
||||
w_in = self.w_in,
|
||||
h_out = self.h_out,
|
||||
w_out = self.w_out,
|
||||
k = self.kernel,
|
||||
s = self.stride,
|
||||
p = self.padding,
|
||||
}}",
|
||||
total = n_outputs.to_kernel(),
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("conv2d_bias_kernel").unwrap();
|
||||
let func = module.load_function("generic_conv2d_bias").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
@@ -144,37 +678,45 @@ extern \"C\" __global__ void conv2d_bias_kernel(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(
|
||||
Expression::from(grid),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
(
|
||||
Expression::from(THREADS_PER_BLOCK),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
Expression::from(0usize),
|
||||
(n_outputs.ceil_div(256), 1.into(), 1.into()),
|
||||
(n_outputs.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
Expression::from(self.output_elements())
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn all_dyn_vars(&self) -> FxHashSet<char> {
|
||||
self.out_shape
|
||||
.iter()
|
||||
.chain(&self.input_shape)
|
||||
.chain(&self.input_stride)
|
||||
.chain(&self.out_stride)
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.weight_co_stride.dyn_vars())
|
||||
.chain(self.weight_inner_stride.dyn_vars())
|
||||
.chain(self.bias_c_stride.dyn_vars())
|
||||
.chain(self.kernel_h.dyn_vars())
|
||||
.chain(self.kernel_w.dyn_vars())
|
||||
.chain(self.stride_h.dyn_vars())
|
||||
.chain(self.stride_w.dyn_vars())
|
||||
.chain(self.dilation_h.dyn_vars())
|
||||
.chain(self.dilation_w.dyn_vars())
|
||||
.chain(self.pad_h.dyn_vars())
|
||||
.chain(self.pad_w.dyn_vars())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
DType::F32
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// Per output: C_in * K * K input loads + same many weight loads + 1 bias load.
|
||||
let per_out = self.c_in * self.kernel * self.kernel * 2 + 1;
|
||||
Expression::from(self.output_elements() * per_out * 4)
|
||||
let c_in = self.input_shape[0];
|
||||
self.output_size() * self.kernel_h * self.kernel_w * c_in * 2 * 4 + self.output_size() * 4
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
@@ -182,108 +724,15 @@ extern \"C\" __global__ void conv2d_bias_kernel(
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// 2 * C_in * K * K mul-adds per output, plus the bias add = +1.
|
||||
let per_out = self.c_in * self.kernel * self.kernel * 2 + 1;
|
||||
Expression::from(self.output_elements() * per_out)
|
||||
let c_in = self.input_shape[0];
|
||||
self.output_size() * self.kernel_h * self.kernel_w * c_in * 2
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Conv2DBias"
|
||||
"GenericConv2D"
|
||||
}
|
||||
}
|
||||
|
||||
/// luminal `CustomOp` that wraps `Conv2DKernel`. Lets us drop the kernel
|
||||
/// straight into an HLIR graph via `cx.custom_op(...)` without going
|
||||
/// through egglog rewrites.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Conv2DCustom(pub Conv2DKernel);
|
||||
|
||||
impl CustomOp for Conv2DCustom {
|
||||
fn to_llir_op(&self) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
|
||||
}
|
||||
}
|
||||
|
||||
/// 2D conv-with-bias on a `(C_in, H, W)` F32 input tensor, with weights
|
||||
/// stored as `(C_out, C_in*K*K)` and bias as `(C_out,)`. Stride/padding/kernel
|
||||
/// are static. Output: `(C_out, H_out, W_out)`.
|
||||
///
|
||||
/// This is a thin wrapper over [`Conv2DKernel`] that hides the
|
||||
/// `cx.custom_op` plumbing. All inputs MUST be `DType::F32` and contiguous
|
||||
/// row-major; pass `tensor * 1.0_f32` first if you have a strided view.
|
||||
pub fn conv2d_bias(
|
||||
input: GraphTensor,
|
||||
weight: GraphTensor,
|
||||
bias: GraphTensor,
|
||||
kernel: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
) -> GraphTensor {
|
||||
assert_eq!(input.dtype, DType::F32, "conv2d_bias requires F32 input");
|
||||
assert_eq!(weight.dtype, DType::F32, "conv2d_bias requires F32 weight");
|
||||
assert_eq!(bias.dtype, DType::F32, "conv2d_bias requires F32 bias");
|
||||
|
||||
let dims = input.dims();
|
||||
assert_eq!(dims.len(), 3, "conv2d_bias expects (C_in, H, W) input");
|
||||
let c_in = dims[0].to_usize().expect("C_in must be a static dim");
|
||||
let h_in = dims[1].to_usize().expect("H must be a static dim");
|
||||
let w_in = dims[2].to_usize().expect("W must be a static dim");
|
||||
|
||||
let w_dims = weight.dims();
|
||||
assert_eq!(
|
||||
w_dims.len(),
|
||||
2,
|
||||
"conv2d_bias expects weight (C_out, C_in*K*K)"
|
||||
);
|
||||
let c_out = w_dims[0].to_usize().expect("C_out must be a static dim");
|
||||
let w_kk = w_dims[1]
|
||||
.to_usize()
|
||||
.expect("weight inner dim must be static");
|
||||
assert_eq!(
|
||||
w_kk,
|
||||
c_in * kernel * kernel,
|
||||
"weight inner dim {w_kk} != C_in*K*K = {}",
|
||||
c_in * kernel * kernel,
|
||||
);
|
||||
|
||||
let b_dims = bias.dims();
|
||||
assert_eq!(b_dims.len(), 1, "conv2d_bias expects bias (C_out,)");
|
||||
assert_eq!(
|
||||
b_dims[0].to_usize().expect("bias dim must be static"),
|
||||
c_out
|
||||
);
|
||||
|
||||
assert!(
|
||||
h_in + 2 * padding >= kernel,
|
||||
"padded H_in ({}) is smaller than kernel ({})",
|
||||
h_in + 2 * padding,
|
||||
kernel,
|
||||
);
|
||||
assert!(
|
||||
w_in + 2 * padding >= kernel,
|
||||
"padded W_in ({}) is smaller than kernel ({})",
|
||||
w_in + 2 * padding,
|
||||
kernel,
|
||||
);
|
||||
let h_out = (h_in + 2 * padding - kernel) / stride + 1;
|
||||
let w_out = (w_in + 2 * padding - kernel) / stride + 1;
|
||||
|
||||
let kern = Conv2DKernel {
|
||||
c_in,
|
||||
h_in,
|
||||
w_in,
|
||||
c_out,
|
||||
kernel,
|
||||
stride,
|
||||
padding,
|
||||
h_out,
|
||||
w_out,
|
||||
};
|
||||
let cx: &mut Graph = unsafe { &mut *input.graph_ref };
|
||||
cx.custom_op(
|
||||
Conv2DCustom(kern),
|
||||
vec![input, weight, bias],
|
||||
(c_out, h_out, w_out),
|
||||
DType::F32,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,10 @@ use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{
|
||||
CudaContext, CudaFunction, CudaStream, DriverError,
|
||||
sys::{self, CUevent, CUfunction, CUgraph, CUgraphExec, CUgraphNode},
|
||||
sys::{
|
||||
self, CUevent, CUfunction, CUgraph, CUgraphExec, CUgraphExecUpdateResult,
|
||||
CUgraphExecUpdateResultInfo, CUgraphNode, CUstreamCaptureMode,
|
||||
},
|
||||
};
|
||||
|
||||
/// A CUDA graph that can be modified and instantiated.
|
||||
@@ -69,6 +72,176 @@ impl CudaGraphHandle {
|
||||
}
|
||||
}
|
||||
|
||||
/// Updates a kernel node in the mutable source graph.
|
||||
pub unsafe fn set_kernel_node_params(
|
||||
&mut self,
|
||||
node: CUgraphNode,
|
||||
func: CUfunction,
|
||||
grid_dim: (u32, u32, u32),
|
||||
block_dim: (u32, u32, u32),
|
||||
shared_mem_bytes: u32,
|
||||
kernel_params: *mut *mut c_void,
|
||||
) -> Result<(), DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let params = sys::CUDA_KERNEL_NODE_PARAMS {
|
||||
func,
|
||||
gridDimX: grid_dim.0,
|
||||
gridDimY: grid_dim.1,
|
||||
gridDimZ: grid_dim.2,
|
||||
blockDimX: block_dim.0,
|
||||
blockDimY: block_dim.1,
|
||||
blockDimZ: block_dim.2,
|
||||
sharedMemBytes: shared_mem_bytes,
|
||||
kernelParams: kernel_params,
|
||||
extra: std::ptr::null_mut(),
|
||||
kern: std::ptr::null_mut(),
|
||||
ctx: std::ptr::null_mut(),
|
||||
};
|
||||
|
||||
unsafe { sys::cuGraphKernelNodeSetParams_v2(node, ¶ms).result() }
|
||||
}
|
||||
|
||||
/// Adds an empty dependency node to the graph.
|
||||
pub fn add_empty_node(
|
||||
&mut self,
|
||||
dependencies: &[CUgraphNode],
|
||||
) -> Result<CUgraphNode, DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let mut node = MaybeUninit::uninit();
|
||||
unsafe {
|
||||
sys::cuGraphAddEmptyNode(
|
||||
node.as_mut_ptr(),
|
||||
self.cu_graph,
|
||||
dependencies.as_ptr(),
|
||||
dependencies.len(),
|
||||
)
|
||||
.result()?;
|
||||
Ok(node.assume_init())
|
||||
}
|
||||
}
|
||||
|
||||
/// Destroys a node in the mutable graph.
|
||||
pub unsafe fn destroy_node(&mut self, node: CUgraphNode) -> Result<(), DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
unsafe { sys::cuGraphDestroyNode(node).result() }
|
||||
}
|
||||
|
||||
/// Adds dependency edges to the mutable graph.
|
||||
pub fn add_dependencies(
|
||||
&mut self,
|
||||
from: &[CUgraphNode],
|
||||
to: &[CUgraphNode],
|
||||
) -> Result<(), DriverError> {
|
||||
assert_eq!(from.len(), to.len());
|
||||
self.ctx.bind_to_thread()?;
|
||||
unsafe {
|
||||
sys::cuGraphAddDependencies(self.cu_graph, from.as_ptr(), to.as_ptr(), from.len())
|
||||
}
|
||||
.result()
|
||||
}
|
||||
|
||||
/// Removes dependency edges from the mutable graph.
|
||||
pub fn remove_dependencies(
|
||||
&mut self,
|
||||
from: &[CUgraphNode],
|
||||
to: &[CUgraphNode],
|
||||
) -> Result<(), DriverError> {
|
||||
assert_eq!(from.len(), to.len());
|
||||
self.ctx.bind_to_thread()?;
|
||||
unsafe {
|
||||
sys::cuGraphRemoveDependencies(self.cu_graph, from.as_ptr(), to.as_ptr(), from.len())
|
||||
}
|
||||
.result()
|
||||
}
|
||||
|
||||
/// Returns all nodes currently in the graph.
|
||||
pub fn nodes(&self) -> Result<Vec<CUgraphNode>, DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let mut count = 0usize;
|
||||
unsafe {
|
||||
sys::cuGraphGetNodes(self.cu_graph, std::ptr::null_mut(), &mut count).result()?;
|
||||
}
|
||||
if count == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut nodes = vec![std::ptr::null_mut(); count];
|
||||
unsafe {
|
||||
sys::cuGraphGetNodes(self.cu_graph, nodes.as_mut_ptr(), &mut count).result()?;
|
||||
}
|
||||
nodes.truncate(count);
|
||||
Ok(nodes)
|
||||
}
|
||||
|
||||
/// Returns the direct dependencies of a node.
|
||||
pub fn dependencies(&self, node: CUgraphNode) -> Result<Vec<CUgraphNode>, DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let mut count = 0usize;
|
||||
unsafe {
|
||||
sys::cuGraphNodeGetDependencies(node, std::ptr::null_mut(), &mut count).result()?;
|
||||
}
|
||||
if count == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut deps = vec![std::ptr::null_mut(); count];
|
||||
unsafe {
|
||||
sys::cuGraphNodeGetDependencies(node, deps.as_mut_ptr(), &mut count).result()?;
|
||||
}
|
||||
deps.truncate(count);
|
||||
Ok(deps)
|
||||
}
|
||||
|
||||
/// Returns the direct dependents of a node.
|
||||
pub fn dependent_nodes(&self, node: CUgraphNode) -> Result<Vec<CUgraphNode>, DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let mut count = 0usize;
|
||||
unsafe {
|
||||
sys::cuGraphNodeGetDependentNodes(node, std::ptr::null_mut(), &mut count).result()?;
|
||||
}
|
||||
if count == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut deps = vec![std::ptr::null_mut(); count];
|
||||
unsafe {
|
||||
sys::cuGraphNodeGetDependentNodes(node, deps.as_mut_ptr(), &mut count).result()?;
|
||||
}
|
||||
deps.truncate(count);
|
||||
Ok(deps)
|
||||
}
|
||||
|
||||
/// Begins stream capture that appends captured work into this graph.
|
||||
pub fn begin_capture_to_graph(
|
||||
&mut self,
|
||||
stream: &CudaStream,
|
||||
dependencies: &[CUgraphNode],
|
||||
) -> Result<(), DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
unsafe {
|
||||
sys::cuStreamBeginCaptureToGraph(
|
||||
stream.cu_stream(),
|
||||
self.cu_graph,
|
||||
dependencies.as_ptr(),
|
||||
std::ptr::null(),
|
||||
dependencies.len(),
|
||||
CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_RELAXED,
|
||||
)
|
||||
.result()
|
||||
}
|
||||
}
|
||||
|
||||
/// Ends stream capture previously started by begin_capture_to_graph.
|
||||
pub fn end_capture(&mut self, stream: &CudaStream) -> Result<(), DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let mut graph = MaybeUninit::uninit();
|
||||
unsafe {
|
||||
sys::cuStreamEndCapture(stream.cu_stream(), graph.as_mut_ptr()).result()?;
|
||||
let captured = graph.assume_init();
|
||||
if captured != self.cu_graph && !captured.is_null() {
|
||||
sys::cuGraphDestroy(captured).result()?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Adds an event record node to the graph for timing.
|
||||
pub fn add_event_record_node(
|
||||
&mut self,
|
||||
@@ -155,6 +328,25 @@ impl CudaGraphExecHandle {
|
||||
unsafe { sys::cuGraphExecKernelNodeSetParams_v2(self.cu_graph_exec, node, ¶ms) }
|
||||
.result()
|
||||
}
|
||||
|
||||
/// Attempts to update this executable graph from an already-mutated source graph.
|
||||
pub fn update_from_graph(&mut self, graph: &CudaGraphHandle) -> Result<(), DriverError> {
|
||||
self.ctx.bind_to_thread()?;
|
||||
let mut result = CUgraphExecUpdateResultInfo {
|
||||
result: CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_SUCCESS,
|
||||
errorNode: std::ptr::null_mut(),
|
||||
errorFromNode: std::ptr::null_mut(),
|
||||
};
|
||||
unsafe {
|
||||
sys::cuGraphExecUpdate_v2(self.cu_graph_exec, graph.cu_graph, &mut result).result()?;
|
||||
}
|
||||
if result.result != CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_SUCCESS {
|
||||
return Err(DriverError(
|
||||
sys::CUresult::CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE,
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CudaGraphExecHandle {
|
||||
@@ -480,6 +672,38 @@ mod tests {
|
||||
assert_eq!(result[0], 6.0f32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_graph_empty_node_dependency_reconnect() {
|
||||
let Ok(ctx) = CudaContext::new(0) else { return };
|
||||
let mut graph = CudaGraphHandle::new(ctx).unwrap();
|
||||
|
||||
let entry = graph.add_empty_node(&[]).unwrap();
|
||||
let middle = graph.add_empty_node(&[entry]).unwrap();
|
||||
let exit = graph.add_empty_node(&[middle]).unwrap();
|
||||
|
||||
let nodes = graph.nodes().unwrap();
|
||||
assert!(nodes.contains(&entry));
|
||||
assert!(nodes.contains(&middle));
|
||||
assert!(nodes.contains(&exit));
|
||||
assert_eq!(graph.dependencies(middle).unwrap(), vec![entry]);
|
||||
assert_eq!(graph.dependent_nodes(middle).unwrap(), vec![exit]);
|
||||
|
||||
graph.add_dependencies(&[entry], &[exit]).unwrap();
|
||||
let exit_deps = graph.dependencies(exit).unwrap();
|
||||
assert!(exit_deps.contains(&entry));
|
||||
assert!(exit_deps.contains(&middle));
|
||||
|
||||
graph.remove_dependencies(&[middle], &[exit]).unwrap();
|
||||
let exit_deps = graph.dependencies(exit).unwrap();
|
||||
assert_eq!(exit_deps.len(), 1);
|
||||
assert!(exit_deps.contains(&entry));
|
||||
|
||||
unsafe {
|
||||
graph.destroy_node(middle).unwrap();
|
||||
}
|
||||
assert!(!graph.nodes().unwrap().contains(&middle));
|
||||
}
|
||||
|
||||
// CUDA Graph Tests
|
||||
|
||||
#[test]
|
||||
@@ -498,8 +722,8 @@ mod tests {
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result1 = rt.get_f32(c);
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -530,8 +754,8 @@ mod tests {
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
let mut results = Vec::new();
|
||||
for _ in 0..5 {
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -568,8 +792,8 @@ mod tests {
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.set_dim('s', size);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let expected: Vec<f32> = data_a
|
||||
.iter()
|
||||
@@ -610,8 +834,8 @@ mod tests {
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let expected: Vec<f32> = data_a.iter().zip(&data_b).map(|(a, b)| a + b).collect();
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
@@ -641,8 +865,8 @@ mod tests {
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
for _ in 0..10 {
|
||||
rt.execute(&cx.dyn_map);
|
||||
}
|
||||
@@ -674,8 +898,8 @@ mod tests {
|
||||
let data_b = random_f32_vec(initial_size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
|
||||
// Initial execution
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -89,6 +89,21 @@ impl EgglogOp for CudaUnaryElementwise {
|
||||
)));
|
||||
}
|
||||
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?sqrt (Op (Sqrt ?shape ?x_stride ?sqrt_stride) (ICons ?x (INil))))
|
||||
(= ?recip (Op (Recip ?shape ?sqrt_stride ?out_stride) (ICons ?sqrt (INil))))
|
||||
(= ?dt (dtype ?recip))
|
||||
) (
|
||||
(let ?fs (Op (FusionStart ?shape ?x_stride ?dt) (ICons ?x (INil))))
|
||||
(let ?elem (Op (CudaUnaryElementwise \"Rsqrt\" ?shape ?x_stride ?out_stride ?dt)
|
||||
(ICons ?fs (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_stride ?dt) (ICons ?elem (INil))))
|
||||
(union ?recip ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
) :ruleset kernel_lower :name \"cuda-elem-rsqrt-from-sqrt-recip\")",
|
||||
));
|
||||
|
||||
rules.push(Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
|
||||
@@ -309,6 +309,61 @@ impl EgglogOp for FusionEnd {
|
||||
// `Cycle(NodeIndex(_))`. Grow rules already compose adjacent regions
|
||||
// correctly without dissolve.
|
||||
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (CudaUnaryElementwise ?op ?inner_shape ?inner_in_s ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_shape ?inner_shape)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-unary-shape\")",
|
||||
));
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (CudaUnaryElementwise ?op ?inner_shape ?inner_in_s ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_s ?inner_s)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-unary-strides\")",
|
||||
));
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (CudaBinaryElementwise ?op ?inner_shape ?a_s ?b_s ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_shape ?inner_shape)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-binary-shape\")",
|
||||
));
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (CudaBinaryElementwise ?op ?inner_shape ?a_s ?b_s ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_s ?inner_s)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-binary-strides\")",
|
||||
));
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (FusionEnd ?inner_shape ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_shape ?inner_shape)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-nested-shape\")",
|
||||
));
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (FusionEnd ?inner_shape ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_s ?inner_s)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-nested-strides\")",
|
||||
));
|
||||
|
||||
rules
|
||||
}
|
||||
|
||||
|
||||
@@ -358,6 +358,7 @@ fn elementwise_body(op: &str, locals: &[&str], dtype: DType) -> String {
|
||||
match op {
|
||||
"Sin" => format!("sinf({})", a()),
|
||||
"Sqrt" => format!("sqrtf({})", a()),
|
||||
"Rsqrt" => format!("rsqrtf({})", a()),
|
||||
"Exp" => format!("expf({})", a()),
|
||||
"Exp2" => format!("exp2f({})", a()),
|
||||
"Log2" => format!("log2f({})", a()),
|
||||
|
||||
319
crates/luminal_cuda_lite/src/kernel/generic_matmul.rs
Normal file
319
crates/luminal_cuda_lite/src/kernel/generic_matmul.rs
Normal file
@@ -0,0 +1,319 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
compile_module_image_for_current_device, cuda_dtype,
|
||||
kernel::{
|
||||
KernelOp,
|
||||
hlir::{dtype_includes, generate_dyn_dims_defines},
|
||||
},
|
||||
};
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, EXPRESSION, OP_KIND},
|
||||
extract_dtype, extract_expr, extract_expr_list,
|
||||
},
|
||||
op::*,
|
||||
prelude::*,
|
||||
shape::flatten_strides,
|
||||
};
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct GenericMatmul {
|
||||
out_shape: Vec<Expression>,
|
||||
mul_shape: Vec<Expression>,
|
||||
k: Expression,
|
||||
lhs_strides: Vec<Expression>,
|
||||
rhs_strides: Vec<Expression>,
|
||||
sum_input_strides: Vec<Expression>,
|
||||
sum_iter_stride: Expression,
|
||||
out_strides: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for GenericMatmul {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"GenericMatmul",
|
||||
&[
|
||||
("out_shape", ELIST),
|
||||
("mul_shape", ELIST),
|
||||
("k", EXPRESSION),
|
||||
("lhs_strides", ELIST),
|
||||
("rhs_strides", ELIST),
|
||||
("sum_input_strides", ELIST),
|
||||
("sum_iter_stride", EXPRESSION),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?lhs_strides ?rhs_strides ?mul_out_strides)
|
||||
(ICons ?lhs (ICons ?rhs (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides)
|
||||
(ICons ?mul (INil))))
|
||||
(= ?dt (dtype ?sum))
|
||||
)
|
||||
(
|
||||
(let ?generic (Op (GenericMatmul
|
||||
?out_shape
|
||||
?mul_shape
|
||||
?k
|
||||
?lhs_strides
|
||||
?rhs_strides
|
||||
?sum_input_strides
|
||||
?sum_iter_stride
|
||||
?out_strides
|
||||
?dt)
|
||||
(ICons ?lhs (ICons ?rhs (INil)))))
|
||||
(union ?sum ?generic)
|
||||
(set (dtype ?generic) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name \"generic-matmul-cuda-mul-sum\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?lhs_strides ?rhs_strides ?mul_out_strides)
|
||||
(ICons ?lhs (ICons ?rhs (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides)
|
||||
(ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt)
|
||||
?generic_inputs))
|
||||
)
|
||||
(
|
||||
(delete (Op (Sum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides)
|
||||
(ICons ?mul (INil))))
|
||||
)
|
||||
:ruleset cleanup
|
||||
:name \"delete-sum-when-generic-matmul-exists\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?kernel_sum (Op (KernelSum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides ?dt)
|
||||
?sum_inputs))
|
||||
(= ?kernel_sum (Op (GenericMatmul
|
||||
?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt)
|
||||
?generic_inputs))
|
||||
)
|
||||
((delete (Op (KernelSum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides ?dt)
|
||||
?sum_inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"delete-kernel-sum-when-generic-matmul-exists\"
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
mul_shape: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
k: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
|
||||
lhs_strides: extract_expr_list(egraph, kind_children[3], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
rhs_strides: extract_expr_list(egraph, kind_children[4], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
sum_input_strides: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[5],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
sum_iter_stride: extract_expr(egraph, kind_children[6], expr_cache).unwrap(),
|
||||
out_strides: extract_expr_list(egraph, kind_children[7], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[8]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for GenericMatmul {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars = self.all_dyn_vars();
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
let n_outputs = self.output_size();
|
||||
let sum_base_idx = flatten_strides(&self.out_shape, &self.sum_input_strides).to_kernel();
|
||||
let iter_offset = self.sum_iter_stride.to_kernel().replace("const_z", "i");
|
||||
let lhs_idx = flatten_strides(&self.mul_shape, &self.lhs_strides)
|
||||
.to_kernel()
|
||||
.replace("const_z", "mul_idx");
|
||||
let rhs_idx = flatten_strides(&self.mul_shape, &self.rhs_strides)
|
||||
.to_kernel()
|
||||
.replace("const_z", "mul_idx");
|
||||
let out_idx = flatten_strides(&self.out_shape, &self.out_strides).to_kernel();
|
||||
let k = self.k.to_kernel();
|
||||
|
||||
let kernel = format!(
|
||||
"{includes}
|
||||
#define WARP_SIZE 32
|
||||
#define THREADS_PER_BLOCK 256
|
||||
#define FULL_MASK 0xffffffff
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void generic_matmul({dtype} *out, const {dtype} *lhs, const {dtype} *rhs{dyn_dims_param}) {{
|
||||
__shared__ float warp_sums[THREADS_PER_BLOCK / WARP_SIZE];
|
||||
long long const_z = blockIdx.x;
|
||||
if (const_z >= {n_outputs}) return;
|
||||
|
||||
int tid = threadIdx.x;
|
||||
int lane_id = tid % WARP_SIZE;
|
||||
int warp_id = tid / WARP_SIZE;
|
||||
|
||||
long long base_idx = {sum_base_idx};
|
||||
long long iters = {k};
|
||||
|
||||
float partial = 0.0f;
|
||||
for (long long i = tid; i < iters; i += THREADS_PER_BLOCK) {{
|
||||
long long mul_idx = base_idx + {iter_offset};
|
||||
partial += static_cast<float>(lhs[{lhs_idx}]) * static_cast<float>(rhs[{rhs_idx}]);
|
||||
}}
|
||||
|
||||
#pragma unroll
|
||||
for (int s = WARP_SIZE / 2; s > 0; s >>= 1) {{
|
||||
partial += __shfl_down_sync(FULL_MASK, partial, s);
|
||||
}}
|
||||
|
||||
if (lane_id == 0) {{
|
||||
warp_sums[warp_id] = partial;
|
||||
}}
|
||||
__syncthreads();
|
||||
|
||||
if (warp_id == 0) {{
|
||||
float block_sum = tid < (THREADS_PER_BLOCK / WARP_SIZE) ? warp_sums[tid] : 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int s = (THREADS_PER_BLOCK / WARP_SIZE) / 2; s > 0; s >>= 1) {{
|
||||
block_sum += __shfl_down_sync(FULL_MASK, block_sum, s);
|
||||
}}
|
||||
|
||||
if (tid == 0) {{
|
||||
out[{out_idx}] = ({dtype})block_sum;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}",
|
||||
n_outputs = n_outputs.to_kernel(),
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("generic_matmul").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(n_outputs, 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
32.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.max(Expression::from(1))
|
||||
}
|
||||
|
||||
fn all_dyn_vars(&self) -> FxHashSet<char> {
|
||||
self.out_shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.mul_shape.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.k.dyn_vars())
|
||||
.chain(self.lhs_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.rhs_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.sum_input_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.sum_iter_stride.dyn_vars())
|
||||
.chain(self.out_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
(self.output_size() * self.k * self.dtype.bits() * 2).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.output_size() * self.k * 2
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"GenericMatmul"
|
||||
}
|
||||
}
|
||||
@@ -987,7 +987,7 @@ extern \"C\" {{
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
let elem_size: Expression = match self.dtype {
|
||||
DType::F64 => 8,
|
||||
DType::F64 | DType::I64 => 8,
|
||||
DType::F32 | DType::Int => 4,
|
||||
DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 2,
|
||||
DType::Bool
|
||||
@@ -1021,7 +1021,7 @@ extern \"C\" {{
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
let data_elem_size: Expression = match self.dtype {
|
||||
DType::F64 => 8,
|
||||
DType::F64 | DType::I64 => 8,
|
||||
DType::F32 | DType::Int => 4,
|
||||
DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 2,
|
||||
DType::Bool
|
||||
|
||||
@@ -12,20 +12,28 @@ use uuid::Uuid;
|
||||
pub mod conv2d;
|
||||
pub mod cuda_graph;
|
||||
pub mod fusion;
|
||||
pub mod generic_matmul;
|
||||
pub mod hlir;
|
||||
pub mod matmul2d;
|
||||
pub mod other_ops;
|
||||
pub mod rope;
|
||||
|
||||
pub use conv2d::{Conv2DCustom, Conv2DKernel, conv2d_bias};
|
||||
pub use conv2d::KernelConv2D;
|
||||
pub use cuda_graph::*;
|
||||
pub use generic_matmul::GenericMatmul;
|
||||
pub use matmul2d::{
|
||||
Matmul2DCustom, Matmul2DKernel, linear_bias, linear_no_bias_bf16_w, matmul_2d, matmul_2d_t,
|
||||
matmul_3d, matmul_3d_t,
|
||||
};
|
||||
pub use rope::{RoPECustom, RoPEKernel, apply_rope};
|
||||
|
||||
pub type Ops = (hlir::Ops, other_ops::Ops, fusion::Ops);
|
||||
pub type Ops = (
|
||||
hlir::Ops,
|
||||
other_ops::Ops,
|
||||
conv2d::KernelConv2D,
|
||||
GenericMatmul,
|
||||
fusion::Ops,
|
||||
);
|
||||
|
||||
/// Build a mapping from interned string IDs to their string values for a given sequence.
|
||||
fn build_interned_strings(trace: &schema::Trace) -> std::collections::HashMap<(u32, u64), String> {
|
||||
@@ -296,4 +304,6 @@ luminal::impl_into_ops!(KernelOp);
|
||||
|
||||
// Kernel to host op compilation
|
||||
mod to_host;
|
||||
#[cfg(test)]
|
||||
pub(crate) use to_host::CudaGraphDebugSummary;
|
||||
pub use to_host::{CudaGraphOp, kernel_to_host};
|
||||
|
||||
@@ -17,13 +17,7 @@ use luminal::{
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
pub type Ops = (
|
||||
KernelMeanReduce,
|
||||
KernelBatchMatVec,
|
||||
KernelBatchMatMul,
|
||||
KernelScatterNoCopy,
|
||||
KernelSoftmax,
|
||||
);
|
||||
pub type Ops = (KernelMeanReduce, KernelScatterNoCopy, KernelSoftmax);
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
|
||||
@@ -532,7 +526,7 @@ extern \"C\" {{
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
let elem_size: Expression = match self.dtype {
|
||||
DType::F64 => 8,
|
||||
DType::F64 | DType::I64 => 8,
|
||||
DType::F32 | DType::Int => 4,
|
||||
DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 2,
|
||||
DType::Bool
|
||||
@@ -566,7 +560,7 @@ extern \"C\" {{
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
let data_elem_size: Expression = match self.dtype {
|
||||
DType::F64 => 8,
|
||||
DType::F64 | DType::I64 => 8,
|
||||
DType::F32 | DType::Int => 4,
|
||||
DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 2,
|
||||
DType::Bool
|
||||
@@ -585,7 +579,7 @@ extern \"C\" {{
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
let data_elem_size: Expression = match self.dtype {
|
||||
DType::F64 => 8,
|
||||
DType::F64 | DType::I64 => 8,
|
||||
DType::F32 | DType::Int => 4,
|
||||
DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 2,
|
||||
DType::Bool
|
||||
@@ -619,569 +613,6 @@ extern \"C\" {{
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// KernelBatchMatVec: Fused batched matrix-vector product for attention
|
||||
// Matches: Mul(broadcast) + Sum pattern for [B, 1, K] x [B, K, N] -> [B, 1, N]
|
||||
// or [B, M, K] x [B, K, N] -> [B, M, N] with small M
|
||||
// Replaces the broadcast elementwise Mul + single-threaded KernelSumReduce pipeline
|
||||
// =============================================================================
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelBatchMatVec {
|
||||
// Output shape: the final reduced shape [B..., M, N]
|
||||
out_shape: Vec<Expression>,
|
||||
// K: the reduction dimension (was the Sum iters)
|
||||
k_dim: Expression,
|
||||
// Strides for input A (with K dim removed)
|
||||
a_stride: Vec<Expression>,
|
||||
a_k_stride: Expression,
|
||||
// Strides for input B (with K dim removed)
|
||||
b_stride: Vec<Expression>,
|
||||
b_k_stride: Expression,
|
||||
// Output strides
|
||||
out_stride: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelBatchMatVec {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelBatchMatVec",
|
||||
&[
|
||||
("out_shape", ELIST),
|
||||
("k_dim", EXPRESSION),
|
||||
("a_stride", ELIST),
|
||||
("a_k_stride", EXPRESSION),
|
||||
("b_stride", ELIST),
|
||||
("b_k_stride", EXPRESSION),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
; Match Mul node (broadcast multiply)
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Output shape must have 3+ dimensions (batched)
|
||||
(= ?out_shape (ECons ?batch_or_d0 (ECons ?d1 (ECons ?d2 ?rest))))
|
||||
|
||||
; k_stride must be contiguous
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Get A's k-dimension stride (second from end in Mul's a_stride)
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 1))
|
||||
|
||||
; Get B's k-dimension stride (second from end in Mul's b_stride)
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 1))
|
||||
|
||||
; A's k stride must be contiguous (row-major A)
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; B's k stride must be contiguous (col-major B)
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
; Must be F32
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; Remove the k-dimension from A strides for the kernel
|
||||
(let ?a_kern_stride (RemoveNthFromEnd ?a_stride 1))
|
||||
; Remove the k-dimension from B strides
|
||||
(let ?b_kern_stride (RemoveNthFromEnd ?b_stride 1))
|
||||
|
||||
(let ?bmv (Op (KernelBatchMatVec
|
||||
?out_shape ?k
|
||||
?a_kern_stride ?a_k_stride
|
||||
?b_kern_stride ?b_k_stride
|
||||
?sum_out_stride (F32)) (ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?bmv)
|
||||
(set (dtype ?bmv) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name \"batch mat-vec\"
|
||||
)"
|
||||
)]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
k_dim: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
|
||||
a_stride: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
a_k_stride: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
|
||||
b_stride: extract_expr_list(egraph, kind_children[4], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
b_k_stride: extract_expr(egraph, kind_children[5], expr_cache).unwrap(),
|
||||
out_stride: extract_expr_list(egraph, kind_children[6], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[7]),
|
||||
})),
|
||||
input_enodes, // A, B
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelBatchMatVec {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars: FxHashSet<char> = self
|
||||
.out_shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.a_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.b_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.out_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.k_dim.dyn_vars())
|
||||
.chain(self.a_k_stride.dyn_vars())
|
||||
.chain(self.b_k_stride.dyn_vars())
|
||||
.collect();
|
||||
|
||||
let n_outputs: Expression = self.out_shape.iter().copied().product();
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
// Each output element is a dot product of length K.
|
||||
// We launch one block of 256 threads per output element.
|
||||
// Threads cooperatively reduce K using warp shuffles.
|
||||
let a_idx = flatten_strides(&self.out_shape, &self.a_stride).to_kernel();
|
||||
let b_idx = flatten_strides(&self.out_shape, &self.b_stride).to_kernel();
|
||||
let out_idx = flatten_strides(&self.out_shape, &self.out_stride).to_kernel();
|
||||
let k_expr = self.k_dim.to_kernel();
|
||||
let a_k_stride_expr = self
|
||||
.a_k_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel();
|
||||
let b_k_stride_expr = self
|
||||
.b_k_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel();
|
||||
|
||||
let kernel = format!(
|
||||
"
|
||||
#define WARP_SIZE 32
|
||||
#define THREADS_PER_BLOCK 256
|
||||
#define FULL_MASK 0xffffffff
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void batch_matvec(float *out, const float *A, const float *B{dyn_dims_param}) {{
|
||||
__shared__ float warp_sums[THREADS_PER_BLOCK / WARP_SIZE];
|
||||
long long const_z = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
int lane_id = tid % WARP_SIZE;
|
||||
int warp_id = tid / WARP_SIZE;
|
||||
|
||||
long long a_base = {a_idx};
|
||||
long long b_base = {b_idx};
|
||||
long long K = {k_expr};
|
||||
long long a_k_stride = {a_k_stride_expr};
|
||||
long long b_k_stride = {b_k_stride_expr};
|
||||
|
||||
float partial = 0.0f;
|
||||
for (long long k = tid; k < K; k += THREADS_PER_BLOCK) {{
|
||||
partial += A[a_base + k * a_k_stride] * B[b_base + k * b_k_stride];
|
||||
}}
|
||||
|
||||
#pragma unroll
|
||||
for (int s = WARP_SIZE / 2; s > 0; s /= 2) {{
|
||||
partial += __shfl_down_sync(FULL_MASK, partial, s);
|
||||
}}
|
||||
|
||||
if (lane_id == 0) {{
|
||||
warp_sums[warp_id] = partial;
|
||||
}}
|
||||
__syncthreads();
|
||||
|
||||
if (warp_id == 0) {{
|
||||
int cnt = THREADS_PER_BLOCK / WARP_SIZE;
|
||||
float block_sum = tid < cnt ? warp_sums[tid] : 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int s = cnt / 2; s > 0; s /= 2) {{
|
||||
block_sum += __shfl_down_sync(FULL_MASK, block_sum, s);
|
||||
}}
|
||||
|
||||
if (tid == 0) {{
|
||||
out[{out_idx}] = block_sum;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("batch_matvec").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(n_outputs, 1.into(), 1.into()), // grid: one block per output
|
||||
(256.into(), 1.into(), 1.into()), // block: 256 threads
|
||||
32.into(), // shared mem for warp_sums
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
let n = self.output_size();
|
||||
// Each output loads K elements from A and K elements from B
|
||||
n * self.k_dim * 2 * 4
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// Each output: K multiply-adds = 2*K FLOPs
|
||||
self.output_size() * self.k_dim * 2
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"BatchMatVec"
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// KernelBatchMatMul: General batched matmul with arbitrary strides
|
||||
// Like KernelBatchMatVec but handles non-contiguous K strides (e.g., transposed
|
||||
// inputs) and non-uniform batch strides (e.g., GQA expansion). One block of 256
|
||||
// threads per output element; threads cooperatively reduce along K.
|
||||
// =============================================================================
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelBatchMatMul {
|
||||
out_shape: Vec<Expression>,
|
||||
k_dim: Expression,
|
||||
a_stride: Vec<Expression>,
|
||||
a_k_stride: Expression,
|
||||
b_stride: Vec<Expression>,
|
||||
b_k_stride: Expression,
|
||||
out_stride: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelBatchMatMul {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelBatchMatMul",
|
||||
&[
|
||||
("out_shape", ELIST),
|
||||
("k_dim", EXPRESSION),
|
||||
("a_stride", ELIST),
|
||||
("a_k_stride", EXPRESSION),
|
||||
("b_stride", ELIST),
|
||||
("b_k_stride", EXPRESSION),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
; Match Mul node (broadcast multiply)
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Output shape must have 3+ dimensions (batched)
|
||||
(= ?out_shape (ECons ?batch_or_d0 (ECons ?d1 (ECons ?d2 ?rest))))
|
||||
|
||||
; k_stride must be contiguous in the Sum output
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; K must be > 1 (K=1 is a degenerate outer product, not a real matmul)
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
; Get A's and B's k-dimension strides (no contiguity requirement)
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 1))
|
||||
|
||||
; One of A's non-k strides must be 0 (broadcast along n)
|
||||
(= (MNum 0) (nth_from_end ?a_stride 0))
|
||||
|
||||
; One of B's non-k strides must be 0 (broadcast along m)
|
||||
(= (MNum 0) (nth_from_end ?b_stride 2))
|
||||
|
||||
; Must be F32
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?a_kern_stride (RemoveNthFromEnd ?a_stride 1))
|
||||
(let ?b_kern_stride (RemoveNthFromEnd ?b_stride 1))
|
||||
|
||||
(let ?bmm (Op (KernelBatchMatMul
|
||||
?out_shape ?k
|
||||
?a_kern_stride ?a_k_stride
|
||||
?b_kern_stride ?b_k_stride
|
||||
?sum_out_stride (F32)) (ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?bmm)
|
||||
(set (dtype ?bmm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name \"batch matmul\"
|
||||
)"
|
||||
)]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
k_dim: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
|
||||
a_stride: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
a_k_stride: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
|
||||
b_stride: extract_expr_list(egraph, kind_children[4], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
b_k_stride: extract_expr(egraph, kind_children[5], expr_cache).unwrap(),
|
||||
out_stride: extract_expr_list(egraph, kind_children[6], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[7]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelBatchMatMul {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars: FxHashSet<char> = self
|
||||
.out_shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.a_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.b_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.out_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.k_dim.dyn_vars())
|
||||
.chain(self.a_k_stride.dyn_vars())
|
||||
.chain(self.b_k_stride.dyn_vars())
|
||||
.collect();
|
||||
|
||||
let n_outputs: Expression = self.out_shape.iter().copied().product();
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
let a_idx = flatten_strides(&self.out_shape, &self.a_stride).to_kernel();
|
||||
let b_idx = flatten_strides(&self.out_shape, &self.b_stride).to_kernel();
|
||||
let out_idx = flatten_strides(&self.out_shape, &self.out_stride).to_kernel();
|
||||
let k_expr = self.k_dim.to_kernel();
|
||||
let a_k_stride_expr = self
|
||||
.a_k_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel();
|
||||
let b_k_stride_expr = self
|
||||
.b_k_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel();
|
||||
|
||||
let kernel = format!(
|
||||
"
|
||||
#define WARP_SIZE 32
|
||||
#define THREADS_PER_BLOCK 256
|
||||
#define FULL_MASK 0xffffffff
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void batch_matmul(float *out, const float *A, const float *B{dyn_dims_param}) {{
|
||||
__shared__ float warp_sums[THREADS_PER_BLOCK / WARP_SIZE];
|
||||
long long const_z = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
int lane_id = tid % WARP_SIZE;
|
||||
int warp_id = tid / WARP_SIZE;
|
||||
|
||||
long long a_base = {a_idx};
|
||||
long long b_base = {b_idx};
|
||||
long long K = {k_expr};
|
||||
long long a_k_stride = {a_k_stride_expr};
|
||||
long long b_k_stride = {b_k_stride_expr};
|
||||
|
||||
float partial = 0.0f;
|
||||
for (long long k = tid; k < K; k += THREADS_PER_BLOCK) {{
|
||||
partial += A[a_base + k * a_k_stride] * B[b_base + k * b_k_stride];
|
||||
}}
|
||||
|
||||
#pragma unroll
|
||||
for (int s = WARP_SIZE / 2; s > 0; s /= 2) {{
|
||||
partial += __shfl_down_sync(FULL_MASK, partial, s);
|
||||
}}
|
||||
|
||||
if (lane_id == 0) {{
|
||||
warp_sums[warp_id] = partial;
|
||||
}}
|
||||
__syncthreads();
|
||||
|
||||
if (warp_id == 0) {{
|
||||
int cnt = THREADS_PER_BLOCK / WARP_SIZE;
|
||||
float block_sum = tid < cnt ? warp_sums[tid] : 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int s = cnt / 2; s > 0; s /= 2) {{
|
||||
block_sum += __shfl_down_sync(FULL_MASK, block_sum, s);
|
||||
}}
|
||||
|
||||
if (tid == 0) {{
|
||||
out[{out_idx}] = block_sum;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("batch_matmul").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(n_outputs, 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
32.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
let n = self.output_size();
|
||||
n * self.k_dim * 2 * 4
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.output_size() * self.k_dim * 2
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"BatchMatMul"
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// KernelSoftmax: Fused softmax over last dimension
|
||||
// Matches: Mul(Recip(Sum(Exp2(Sub(x, Max(x))))), Exp2(Sub(x, Max(x))))
|
||||
@@ -1338,9 +769,21 @@ impl KernelOp for KernelSoftmax {
|
||||
#define FULL_MASK 0xffffffff
|
||||
#define NEG_INF_F __int_as_float(0xff800000)
|
||||
{dyn_defines}
|
||||
#define LOG2E 1.4426950408889634f
|
||||
|
||||
extern \"C\" {{
|
||||
// Online normalizer calculation for softmax (Milakov & Gimelshein 2018).
|
||||
|
||||
// Merge two partial (max, sum) pairs using the online softmax rule.
|
||||
__device__ __forceinline__ void merge_md(float *m, float *d, float m2, float d2) {{
|
||||
float new_m = fmaxf(*m, m2);
|
||||
*d = *d * exp2f((*m - new_m) * LOG2E) + d2 * exp2f((m2 - new_m) * LOG2E);
|
||||
*m = new_m;
|
||||
}}
|
||||
|
||||
__global__ void fused_softmax(float *out, const float *inp{dyn_dims_param}) {{
|
||||
__shared__ float shared[THREADS_PER_BLOCK / WARP_SIZE];
|
||||
__shared__ float sh_m[THREADS_PER_BLOCK / WARP_SIZE];
|
||||
__shared__ float sh_d[THREADS_PER_BLOCK / WARP_SIZE];
|
||||
long long const_z = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
int lane_id = tid % WARP_SIZE;
|
||||
@@ -1352,55 +795,36 @@ extern \"C\" {{
|
||||
long long in_stride = {in_reduce_stride};
|
||||
long long out_stride = {out_reduce_stride};
|
||||
|
||||
// Pass 1: find max
|
||||
float max_val = NEG_INF_F;
|
||||
// Pass 1: one read of inp produces (global_max, global_sum).
|
||||
float m = NEG_INF_F, d = 0.0f;
|
||||
for (long long i = tid; i < N; i += THREADS_PER_BLOCK) {{
|
||||
max_val = fmaxf(max_val, inp[in_base + i * in_stride]);
|
||||
merge_md(&m, &d, inp[in_base + i * in_stride], 1.0f);
|
||||
}}
|
||||
// Warp reduce: collapse 32 threads within each warp down to lane 0.
|
||||
#pragma unroll
|
||||
for (int s = WARP_SIZE / 2; s > 0; s /= 2) {{
|
||||
max_val = fmaxf(max_val, __shfl_down_sync(FULL_MASK, max_val, s));
|
||||
merge_md(&m, &d, __shfl_down_sync(FULL_MASK, m, s), __shfl_down_sync(FULL_MASK, d, s));
|
||||
}}
|
||||
if (lane_id == 0) shared[warp_id] = max_val;
|
||||
if (lane_id == 0) {{ sh_m[warp_id] = m; sh_d[warp_id] = d; }}
|
||||
__syncthreads();
|
||||
// Block reduce: warp 0 collapses the 8 warp results down to one.
|
||||
if (warp_id == 0) {{
|
||||
max_val = tid < (THREADS_PER_BLOCK / WARP_SIZE) ? shared[tid] : NEG_INF_F;
|
||||
m = tid < (THREADS_PER_BLOCK / WARP_SIZE) ? sh_m[tid] : NEG_INF_F;
|
||||
d = tid < (THREADS_PER_BLOCK / WARP_SIZE) ? sh_d[tid] : 0.0f;
|
||||
#pragma unroll
|
||||
for (int s = (THREADS_PER_BLOCK / WARP_SIZE) / 2; s > 0; s /= 2) {{
|
||||
max_val = fmaxf(max_val, __shfl_down_sync(FULL_MASK, max_val, s));
|
||||
merge_md(&m, &d, __shfl_down_sync(FULL_MASK, m, s), __shfl_down_sync(FULL_MASK, d, s));
|
||||
}}
|
||||
shared[0] = max_val;
|
||||
sh_m[0] = m;
|
||||
sh_d[0] = d;
|
||||
}}
|
||||
__syncthreads();
|
||||
max_val = shared[0];
|
||||
float global_max = sh_m[0];
|
||||
float inv_sum = 1.0f / sh_d[0];
|
||||
|
||||
// Pass 2: compute exp2 and sum
|
||||
float sum_val = 0.0f;
|
||||
// Pass 2: write final softmax values.
|
||||
for (long long i = tid; i < N; i += THREADS_PER_BLOCK) {{
|
||||
float v = exp2f((inp[in_base + i * in_stride] - max_val) * 1.4426950408889634f);
|
||||
out[out_base + i * out_stride] = v; // store exp temporarily
|
||||
sum_val += v;
|
||||
}}
|
||||
#pragma unroll
|
||||
for (int s = WARP_SIZE / 2; s > 0; s /= 2) {{
|
||||
sum_val += __shfl_down_sync(FULL_MASK, sum_val, s);
|
||||
}}
|
||||
if (lane_id == 0) shared[warp_id] = sum_val;
|
||||
__syncthreads();
|
||||
if (warp_id == 0) {{
|
||||
sum_val = tid < (THREADS_PER_BLOCK / WARP_SIZE) ? shared[tid] : 0.0f;
|
||||
#pragma unroll
|
||||
for (int s = (THREADS_PER_BLOCK / WARP_SIZE) / 2; s > 0; s /= 2) {{
|
||||
sum_val += __shfl_down_sync(FULL_MASK, sum_val, s);
|
||||
}}
|
||||
shared[0] = sum_val;
|
||||
}}
|
||||
__syncthreads();
|
||||
float inv_sum = 1.0f / shared[0];
|
||||
|
||||
// Pass 3: normalize
|
||||
for (long long i = tid; i < N; i += THREADS_PER_BLOCK) {{
|
||||
out[out_base + i * out_stride] *= inv_sum;
|
||||
out[out_base + i * out_stride] = exp2f((inp[in_base + i * in_stride] - global_max) * LOG2E) * inv_sum;
|
||||
}}
|
||||
}}
|
||||
}}"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -34,6 +34,7 @@ fn cuda_dtype(dtype: DType) -> &'static str {
|
||||
DType::Bf16 => "__nv_bfloat16",
|
||||
DType::TF32 => "float", // TF32 uses float storage, tensor cores handle the format
|
||||
DType::Int => "int",
|
||||
DType::I64 => "long long",
|
||||
DType::I16 => "short",
|
||||
DType::U16 => "unsigned short",
|
||||
DType::I8 => "signed char",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -22,6 +22,10 @@ fn build_dynamic_matmul_graph(k: usize, n: usize) -> (Graph, NodeIndex, NodeInde
|
||||
(cx, a.id, b.id, c.id)
|
||||
}
|
||||
|
||||
fn bucket_options(buckets: &[DimBucket]) -> CompileOptions {
|
||||
CompileOptions::default().dim_buckets('s', buckets)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_dispatch_simple() {
|
||||
// Tests that bucketed compilation produces correct results for different dim values
|
||||
@@ -31,9 +35,10 @@ fn test_bucket_dispatch_simple() {
|
||||
|
||||
let (mut cx, a, b) = build_dynamic_add_graph();
|
||||
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(bucket_options(&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, 4),
|
||||
]));
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
// Set dummy input for search
|
||||
@@ -41,7 +46,11 @@ fn test_bucket_dispatch_simple() {
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(5),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
// Test bucket 1: s=1
|
||||
cx.set_dim('s', 1);
|
||||
@@ -73,9 +82,10 @@ fn test_bucket_matmul_dynamic() {
|
||||
let n = 4;
|
||||
let (mut cx, a, b_tensor, c) = build_dynamic_matmul_graph(k, n);
|
||||
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 8)]);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(bucket_options(&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, 8),
|
||||
]));
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
@@ -85,7 +95,11 @@ fn test_bucket_matmul_dynamic() {
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(5),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
// Execute at s=1
|
||||
cx.set_dim('s', 1);
|
||||
@@ -135,12 +149,16 @@ fn test_bucket_results_match_unbucketed() {
|
||||
// Non-bucketed run
|
||||
let (mut cx1, a1, b1) = build_dynamic_add_graph();
|
||||
cx1.set_dim('s', 3);
|
||||
cx1.build_search_space::<CudaRuntime>();
|
||||
cx1.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt1 = CudaRuntime::initialize(stream.clone());
|
||||
let input_data = random_f32_vec(12, seed, -1.0, 1.0);
|
||||
rt1.set_data(a1, input_data.clone());
|
||||
let mut rng1 = SmallRng::seed_from_u64(seed);
|
||||
rt1 = cx1.search_options(rt1, SearchOptions::new(5), &mut rng1);
|
||||
rt1 = cx1.search_with_rng(
|
||||
rt1,
|
||||
CompileOptions::default().search_graph_limit(5),
|
||||
&mut rng1,
|
||||
);
|
||||
rt1.set_data(a1, input_data.clone());
|
||||
rt1.execute(&cx1.dyn_map);
|
||||
let result_unbucketed = rt1.get_f32(b1);
|
||||
@@ -148,12 +166,15 @@ fn test_bucket_results_match_unbucketed() {
|
||||
// Bucketed run with bucket that covers s=3
|
||||
let (mut cx2, a2, b2) = build_dynamic_add_graph();
|
||||
cx2.set_dim('s', 3);
|
||||
cx2.set_dim_buckets('s', &[DimBucket::new(1, 4)]);
|
||||
cx2.build_search_space::<CudaRuntime>();
|
||||
cx2.build_search_space::<CudaRuntime>(bucket_options(&[DimBucket::new(1, 4)]));
|
||||
let mut rt2 = CudaRuntime::initialize(stream.clone());
|
||||
rt2.set_data(a2, input_data.clone());
|
||||
let mut rng2 = SmallRng::seed_from_u64(seed);
|
||||
rt2 = cx2.search_options(rt2, SearchOptions::new(5), &mut rng2);
|
||||
rt2 = cx2.search_with_rng(
|
||||
rt2,
|
||||
CompileOptions::default().search_graph_limit(5),
|
||||
&mut rng2,
|
||||
);
|
||||
rt2.set_data(a2, input_data.clone());
|
||||
rt2.execute(&cx2.dyn_map);
|
||||
let result_bucketed = rt2.get_f32(b2);
|
||||
@@ -172,14 +193,20 @@ fn test_bucket_out_of_range_panics() {
|
||||
};
|
||||
|
||||
let (mut cx, a, _b) = build_dynamic_add_graph();
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(bucket_options(&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, 4),
|
||||
]));
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_options(rt, SearchOptions::new(3), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(3),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
// s=10 is outside all buckets — should panic
|
||||
cx.set_dim('s', 10);
|
||||
@@ -197,14 +224,18 @@ fn test_bucket_no_buckets_backward_compat() {
|
||||
let (mut cx, a, b) = build_dynamic_add_graph();
|
||||
cx.set_dim('s', 2);
|
||||
|
||||
// No set_dim_buckets call
|
||||
// No bucket options
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let input_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
||||
rt.set_data(a, input_data.clone());
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_options(rt, SearchOptions::new(3), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(3),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -237,9 +268,10 @@ fn test_bucket_switch_preserves_weights() {
|
||||
let n = 4;
|
||||
let (mut cx, a, b_tensor, c) = build_dynamic_matmul_graph(k, n);
|
||||
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(bucket_options(&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, 4),
|
||||
]));
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
@@ -249,7 +281,11 @@ fn test_bucket_switch_preserves_weights() {
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(5),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
// Execute with bucket 1 (s=1)
|
||||
cx.set_dim('s', 1);
|
||||
@@ -297,15 +333,17 @@ fn test_bucket_multiple_executions_same_bucket() {
|
||||
|
||||
let (mut cx, a, b) = build_dynamic_add_graph();
|
||||
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 8)]);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(bucket_options(&[DimBucket::new(1, 8)]));
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_options(rt, SearchOptions::new(3), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(3),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
// Execute at different sizes within the same bucket
|
||||
for s in [1, 2, 4, 8] {
|
||||
@@ -323,8 +361,7 @@ fn test_bucket_multiple_executions_same_bucket() {
|
||||
#[test]
|
||||
#[should_panic(expected = "Overlapping buckets")]
|
||||
fn test_bucket_overlapping_ranges_panics() {
|
||||
let mut cx = Graph::default();
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 4), DimBucket::new(3, 8)]);
|
||||
let _ = bucket_options(&[DimBucket::new(1, 4), DimBucket::new(3, 8)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -10,7 +10,7 @@ use crate::runtime::CudaRuntime;
|
||||
|
||||
/// Helper: build search space and extract all possible kernel names across many random choices.
|
||||
fn extract_all_kernel_names(cx: &mut Graph) -> Vec<String> {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let egraph = cx.egraph().expect("egraph not built");
|
||||
let ops = cx.egglog_ops().expect("ops not built");
|
||||
let custom_ops = &cx.custom_ops;
|
||||
@@ -199,7 +199,7 @@ fn test_scatter_execution_correctness() {
|
||||
|
||||
let result = src.scatter(indexes, dest).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let egraph = cx.egraph().expect("egraph not built");
|
||||
let ops = cx.egglog_ops().expect("ops not built");
|
||||
|
||||
@@ -298,7 +298,7 @@ fn test_scatter_kv_cache_roundtrip() {
|
||||
// Return cache for round-trip
|
||||
let cache_output = cache_out.output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
@@ -307,7 +307,7 @@ fn test_scatter_kv_cache_roundtrip() {
|
||||
rt.set_data(src, vec![10.0f32]);
|
||||
rt.set_data(indexes, vec![0i32]);
|
||||
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
|
||||
// Print and verify which scatter variant was selected
|
||||
let scatter_names: Vec<_> = rt
|
||||
@@ -415,7 +415,7 @@ fn test_scatter_dual_cache() {
|
||||
let k_cache_out = k_out.output();
|
||||
let v_cache_out = v_out.output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
@@ -427,7 +427,11 @@ fn test_scatter_dual_cache() {
|
||||
|
||||
// Use seeded search for deterministic variant selection.
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(5),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
// Print and verify selected variants
|
||||
let scatter_names: Vec<_> = rt
|
||||
@@ -535,7 +539,7 @@ fn test_scatter_rows_dynamic_prefill_roundtrip() {
|
||||
let gathered = gather_rows(updated, gather_idx, D).output();
|
||||
let cache_out = updated.output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
cx.set_dim('s', S);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
@@ -554,7 +558,11 @@ fn test_scatter_rows_dynamic_prefill_roundtrip() {
|
||||
rt.set_data(gather_idx, scatter);
|
||||
rt.set_data(cache, (0..SLOTS * D).map(|i| i as f32).collect::<Vec<_>>());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_eq!(rt.get_f32(gathered), expected_gather);
|
||||
@@ -733,7 +741,7 @@ fn test_tiny_gqa_attention_batched_matches_sequential_prefill() {
|
||||
|
||||
cx.set_dim('s', S);
|
||||
cx.set_dim('c', S);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let q_data: Vec<f32> = (0..S * Q_DIM)
|
||||
.map(|i| ((i as f32 + 1.0) * 0.031).sin())
|
||||
@@ -763,7 +771,11 @@ fn test_tiny_gqa_attention_batched_matches_sequential_prefill() {
|
||||
rt.set_data(k_cache, zero_k.clone());
|
||||
rt.set_data(v_cache, zero_v.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let batched_attn = rt.get_f32(attn_out);
|
||||
let batched_k = rt.get_f32(k_out);
|
||||
@@ -844,7 +856,7 @@ fn test_original_gqa_attention_batched_matches_sequential_prefill() {
|
||||
|
||||
cx.set_dim('s', S);
|
||||
cx.set_dim('p', 0);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let q_data: Vec<f32> = (0..S * Q_DIM)
|
||||
.map(|i| ((i as f32 + 1.0) * 0.031).sin())
|
||||
@@ -865,7 +877,11 @@ fn test_original_gqa_attention_batched_matches_sequential_prefill() {
|
||||
rt.set_data(k_cache, zero_k.clone());
|
||||
rt.set_data(v_cache, zero_v.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let batched_attn = rt.get_f32(attn_out);
|
||||
let batched_k = rt.get_f32(k_out);
|
||||
@@ -925,7 +941,7 @@ fn test_dynamic_expanded_causal_mask_softmax() {
|
||||
|
||||
cx.set_dim('s', S);
|
||||
cx.set_dim('c', S);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let mut mask_data = vec![0.0f32; S * S];
|
||||
for row in 0..S {
|
||||
@@ -937,7 +953,11 @@ fn test_dynamic_expanded_causal_mask_softmax() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(mask, mask_data);
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(weights);
|
||||
|
||||
@@ -991,7 +1011,7 @@ fn test_tiny_gqa_value_matmul_with_expanded_kv() {
|
||||
|
||||
cx.set_dim('s', S);
|
||||
cx.set_dim('c', S);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let v_data: Vec<f32> = (0..S * KV_DIM)
|
||||
.map(|i| ((i as f32 + 5.0) * 0.029).sin())
|
||||
@@ -1007,7 +1027,11 @@ fn test_tiny_gqa_value_matmul_with_expanded_kv() {
|
||||
rt.set_data(v_full, v_data.clone());
|
||||
rt.set_data(mask, mask_data);
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(out);
|
||||
|
||||
@@ -1055,7 +1079,7 @@ fn test_broadcast_merge_gqa_value_matmul_matches_cpu() {
|
||||
|
||||
cx.set_dim('s', S);
|
||||
cx.set_dim('c', S);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let v_data: Vec<f32> = (0..N_KV_HEADS * S * HEAD_DIM)
|
||||
.map(|i| ((i as f32 + 5.0) * 0.029).sin())
|
||||
@@ -1073,7 +1097,11 @@ fn test_broadcast_merge_gqa_value_matmul_matches_cpu() {
|
||||
rt.set_data(v_full, v_data.clone());
|
||||
rt.set_data(weights, weights_data);
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(out);
|
||||
|
||||
@@ -1115,7 +1143,7 @@ fn test_transpose_merge_split_roundtrip_matches_cpu() {
|
||||
let roundtrip = flat.split_dims(1, D).transpose(0, 1).output();
|
||||
|
||||
cx.set_dim('s', S);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let x_data: Vec<f32> = (0..H * S * D)
|
||||
.map(|i| ((i as f32 + 0.75) * 0.051).sin())
|
||||
@@ -1124,7 +1152,11 @@ fn test_transpose_merge_split_roundtrip_matches_cpu() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(x, x_data.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(roundtrip);
|
||||
|
||||
@@ -1158,7 +1190,7 @@ fn test_batched_moe_x_expand_matmul_matches_cpu() {
|
||||
.output();
|
||||
|
||||
cx.set_dim('s', S);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let x_data: Vec<f32> = (0..S * H)
|
||||
.map(|i| ((i as f32 + 0.5) * 0.137).sin())
|
||||
@@ -1171,7 +1203,11 @@ fn test_batched_moe_x_expand_matmul_matches_cpu() {
|
||||
rt.set_data(x, x_data.clone());
|
||||
rt.set_data(w, w_data.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(out);
|
||||
|
||||
@@ -1211,7 +1247,7 @@ fn test_batched_topk_axis1_matches_cpu() {
|
||||
let topk = routing.topk_indexes(K, 1).output();
|
||||
|
||||
cx.set_dim('s', S);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let routing_data: Vec<f32> = (0..S * E)
|
||||
.map(|i| ((i as f32 + 3.25) * 0.113).sin() + ((i as f32 + 7.0) * 0.019).cos() * 0.1)
|
||||
@@ -1220,7 +1256,11 @@ fn test_batched_topk_axis1_matches_cpu() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(routing, routing_data.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_i32(topk);
|
||||
|
||||
@@ -1250,7 +1290,7 @@ fn test_batched_argsort_axis1_matches_cpu() {
|
||||
let argsort = routing.argsort(1, true).output();
|
||||
|
||||
cx.set_dim('s', S);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let routing_data: Vec<f32> = (0..S * E)
|
||||
.map(|i| ((i as f32 + 3.25) * 0.113).sin() + ((i as f32 + 7.0) * 0.019).cos() * 0.1)
|
||||
@@ -1259,7 +1299,11 @@ fn test_batched_argsort_axis1_matches_cpu() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(routing, routing_data.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_i32(argsort);
|
||||
|
||||
@@ -1290,7 +1334,7 @@ fn test_dynamic_3d_sum_axis1_matches_cpu() {
|
||||
let out = input.sum(1).output();
|
||||
|
||||
cx.set_dim('s', S);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let data: Vec<f32> = (0..S * A * B)
|
||||
.map(|i| ((i as f32 + 4.0) * 0.031).sin())
|
||||
@@ -1299,7 +1343,11 @@ fn test_dynamic_3d_sum_axis1_matches_cpu() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(input, data.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(out);
|
||||
|
||||
@@ -1347,7 +1395,7 @@ fn test_batched_argsort_ranks_axis1_matches_cpu() {
|
||||
.output();
|
||||
|
||||
cx.set_dim('s', S);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let routing_data: Vec<f32> = (0..S * E)
|
||||
.map(|i| ((i as f32 + 3.25) * 0.113).sin() + ((i as f32 + 7.0) * 0.019).cos() * 0.1)
|
||||
@@ -1356,7 +1404,11 @@ fn test_batched_argsort_ranks_axis1_matches_cpu() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(routing, routing_data.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_i32(ranks);
|
||||
|
||||
@@ -1391,11 +1443,15 @@ fn test_dynamic_3d_flat_index_iota_rows() {
|
||||
.output();
|
||||
|
||||
cx.set_dim('s', S);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_i32(idx);
|
||||
|
||||
@@ -1431,14 +1487,18 @@ fn test_dynamic_2d_to_3d_gather_rows() {
|
||||
let out = data.gather(idx).output();
|
||||
|
||||
cx.set_dim('s', S);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let data_values: Vec<i32> = (0..S * E).map(|i| ((i * 17 + 5) % 1000) as i32).collect();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(data, data_values.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_i32(out);
|
||||
|
||||
@@ -1479,7 +1539,7 @@ fn test_batched_gather_experts_matches_cpu() {
|
||||
let out = weights.gather(exp_base + exp_within).output();
|
||||
|
||||
cx.set_dim('s', S);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let topk_data: Vec<i32> = (0..S * K).map(|i| ((i * 5 + 3) % E) as i32).collect();
|
||||
let weights_data: Vec<f32> = (0..E * D1 * D2)
|
||||
@@ -1490,7 +1550,11 @@ fn test_batched_gather_experts_matches_cpu() {
|
||||
rt.set_data(topk, topk_data.clone());
|
||||
rt.set_data(weights, weights_data.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_options(rt, SearchOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(out);
|
||||
|
||||
|
||||
482
crates/luminal_cuda_lite/src/tests/conv2d_rewrite.rs
Normal file
482
crates/luminal_cuda_lite/src/tests/conv2d_rewrite.rs
Normal file
@@ -0,0 +1,482 @@
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
NodeId, SerializedEGraph, egglog_to_llir, random_initial_choice, validate_choice_set,
|
||||
},
|
||||
prelude::*,
|
||||
};
|
||||
use rand::{SeedableRng, rngs::StdRng};
|
||||
|
||||
use crate::{kernel::KernelOp, runtime::CudaRuntime};
|
||||
|
||||
use super::utilities::{assert_close, get_cuda_stream};
|
||||
|
||||
fn conv2d_bias_hlir(
|
||||
x: GraphTensor,
|
||||
weight: GraphTensor,
|
||||
bias: GraphTensor,
|
||||
kernel_h: usize,
|
||||
kernel_w: usize,
|
||||
) -> GraphTensor {
|
||||
let unfolded = x.unfold(
|
||||
vec![1usize, kernel_h, kernel_w],
|
||||
vec![1usize, 1, 1],
|
||||
vec![1usize, 1, 1],
|
||||
);
|
||||
let output_spatial_dims = unfolded.dims()[1..3].to_vec();
|
||||
|
||||
let mut patches = unfolded.squeeze(3).permute(&[1, 2, 0, 3, 4]);
|
||||
while patches.dims().len() > 3 {
|
||||
let last = patches.dims().len();
|
||||
patches = patches.merge_dims(last - 2, last - 1);
|
||||
}
|
||||
let patches = patches.merge_dims(0, 1);
|
||||
|
||||
let out = patches.matmul(weight.t());
|
||||
let out = out
|
||||
.split_dims(0, output_spatial_dims[1])
|
||||
.permute(&[2, 0, 1]);
|
||||
let out_dims = out.dims();
|
||||
out + bias.expand_dim(1, out_dims[1]).expand_dim(2, out_dims[2])
|
||||
}
|
||||
|
||||
fn build_conv_graph() -> (Graph, GraphTensor, GraphTensor, GraphTensor, GraphTensor) {
|
||||
let mut cx = Graph::new();
|
||||
let x = cx.tensor((2usize, 5usize, 6usize));
|
||||
let weight = cx.tensor((3usize, 2usize * 3 * 2));
|
||||
let bias = cx.tensor(3usize);
|
||||
let out = conv2d_bias_hlir(x, weight, bias, 3, 2).output();
|
||||
(cx, x, weight, bias, out)
|
||||
}
|
||||
|
||||
fn conv2d_bias_padded_hlir(
|
||||
x: GraphTensor,
|
||||
weight: GraphTensor,
|
||||
bias: GraphTensor,
|
||||
kernel: usize,
|
||||
padding: usize,
|
||||
) -> GraphTensor {
|
||||
let zero = Expression::from(0);
|
||||
let pad = Expression::from(padding);
|
||||
let padded = x.pad(vec![(zero, zero), (pad, pad), (pad, pad)], 0.0);
|
||||
conv2d_bias_hlir(padded, weight, bias, kernel, kernel)
|
||||
}
|
||||
|
||||
fn build_padded_conv_graph() -> (Graph, GraphTensor, GraphTensor, GraphTensor, GraphTensor) {
|
||||
let mut cx = Graph::new();
|
||||
let x = cx.tensor((2usize, 4usize, 5usize));
|
||||
let weight = cx.tensor((3usize, 2usize * 3 * 3));
|
||||
let bias = cx.tensor(3usize);
|
||||
let out = conv2d_bias_padded_hlir(x, weight, bias, 3, 1).output();
|
||||
(cx, x, weight, bias, out)
|
||||
}
|
||||
|
||||
fn nearest_upsample_2x_hlir(x: GraphTensor) -> GraphTensor {
|
||||
let stage1 = x.expand_dim(2, 2usize).merge_dims(1, 2);
|
||||
stage1.expand_dim(3, 2usize).merge_dims(2, 3)
|
||||
}
|
||||
|
||||
fn build_upsample_conv_graph() -> (Graph, GraphTensor, GraphTensor, GraphTensor, GraphTensor) {
|
||||
let mut cx = Graph::new();
|
||||
let x = cx.tensor((2usize, 3usize, 4usize));
|
||||
let weight = cx.tensor((3usize, 2usize * 3 * 3));
|
||||
let bias = cx.tensor(3usize);
|
||||
let up = nearest_upsample_2x_hlir(x);
|
||||
let out = conv2d_bias_padded_hlir(up, weight, bias, 3, 1).output();
|
||||
(cx, x, weight, bias, out)
|
||||
}
|
||||
|
||||
fn conv1x1_bias_hlir(x: GraphTensor, weight: GraphTensor, bias: GraphTensor) -> GraphTensor {
|
||||
let dims = x.dims();
|
||||
let h = dims[1];
|
||||
let w = dims[2];
|
||||
let xt = x.permute(&[1, 2, 0]).merge_dims(0, 1);
|
||||
let out = xt.matmul(weight.t());
|
||||
let out = out.split_dims(0, w).permute(&[2, 0, 1]);
|
||||
out + bias.expand_dim(1, h).expand_dim(2, w)
|
||||
}
|
||||
|
||||
fn build_conv1x1_graph() -> (Graph, GraphTensor, GraphTensor, GraphTensor, GraphTensor) {
|
||||
let mut cx = Graph::new();
|
||||
let x = cx.tensor((2usize, 4usize, 5usize));
|
||||
let weight = cx.tensor((3usize, 2usize));
|
||||
let bias = cx.tensor(3usize);
|
||||
let out = conv1x1_bias_hlir(x, weight, bias).output();
|
||||
(cx, x, weight, bias, out)
|
||||
}
|
||||
|
||||
fn conv2d_matmul_without_conv_output_shape(
|
||||
x: GraphTensor,
|
||||
weight: GraphTensor,
|
||||
bias: GraphTensor,
|
||||
kernel_h: usize,
|
||||
kernel_w: usize,
|
||||
) -> GraphTensor {
|
||||
let unfolded = x.unfold(
|
||||
vec![1usize, kernel_h, kernel_w],
|
||||
vec![1usize, 1, 1],
|
||||
vec![1usize, 1, 1],
|
||||
);
|
||||
|
||||
let mut patches = unfolded.squeeze(3).permute(&[1, 2, 0, 3, 4]);
|
||||
while patches.dims().len() > 3 {
|
||||
let last = patches.dims().len();
|
||||
patches = patches.merge_dims(last - 2, last - 1);
|
||||
}
|
||||
let patches = patches.merge_dims(0, 1);
|
||||
|
||||
let out = patches.matmul(weight.t());
|
||||
let out_dims = out.dims();
|
||||
out + bias.expand_dim(0, out_dims[0])
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_conv2d_rewrite_matches_unfold_matmul_bias() {
|
||||
let (mut cx, _, _, _, _) = build_conv_graph();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
|
||||
assert!(
|
||||
!op_ir_nodes(egraph, "KernelConv2D").is_empty(),
|
||||
"expected generic conv2d rewrite candidate"
|
||||
);
|
||||
assert!(
|
||||
op_ir_nodes(egraph, "Add").is_empty(),
|
||||
"generic conv2d cleanup should prune the final bias Add fallback"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_conv2d_rewrite_matches_conv1x1_matmul_bias() {
|
||||
let (mut cx, _, _, _, _) = build_conv1x1_graph();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
|
||||
assert!(
|
||||
!op_ir_nodes(egraph, "KernelConv2D").is_empty(),
|
||||
"expected generic conv2d rewrite candidate for 1x1 conv"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_conv2d_rewrite_requires_conv_output_shape() {
|
||||
let mut cx = Graph::new();
|
||||
let x = cx.tensor((2usize, 5usize, 6usize));
|
||||
let weight = cx.tensor((3usize, 2usize * 3 * 2));
|
||||
let bias = cx.tensor(3usize);
|
||||
conv2d_matmul_without_conv_output_shape(x, weight, bias, 3, 2).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
|
||||
assert!(
|
||||
op_ir_nodes(egraph, "KernelConv2D").is_empty(),
|
||||
"matmul+bias without [C_out,H_out,W_out] conv output shape should not match KernelConv2D"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_conv2d_candidate_executes_unfold_matmul_bias() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut cx, x, weight, bias, out) = build_conv_graph();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let llir = extract_forced_kernel_llir(&mut cx, "GenericConv2D");
|
||||
|
||||
let input: Vec<f32> = (0..2 * 5 * 6).map(|i| i as f32 * 0.03 - 0.4).collect();
|
||||
let weights: Vec<f32> = (0..3 * 2 * 3 * 2)
|
||||
.map(|i| (i as f32 % 11.0) * 0.04 - 0.2)
|
||||
.collect();
|
||||
let biases = vec![0.25_f32, -0.15, 0.05];
|
||||
let expected = reference_conv2d(
|
||||
&input,
|
||||
&weights,
|
||||
&biases,
|
||||
ConvCase {
|
||||
c_in: 2,
|
||||
h: 5,
|
||||
w: 6,
|
||||
c_out: 3,
|
||||
kh: 3,
|
||||
kw: 2,
|
||||
padding_h: 0,
|
||||
padding_w: 0,
|
||||
},
|
||||
);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(x, input);
|
||||
rt.set_data(weight, weights);
|
||||
rt.set_data(bias, biases);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_conv2d_candidate_executes_conv1x1_matmul_bias() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut cx, x, weight, bias, out) = build_conv1x1_graph();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let llir = extract_forced_kernel_llir(&mut cx, "GenericConv2D");
|
||||
|
||||
let input: Vec<f32> = (0..2 * 4 * 5).map(|i| i as f32 * 0.07 - 1.0).collect();
|
||||
let weights: Vec<f32> = (0..3 * 2).map(|i| (i as f32 % 5.0) * 0.11 - 0.2).collect();
|
||||
let biases = vec![0.2_f32, -0.1, 0.4];
|
||||
let expected = reference_conv2d(
|
||||
&input,
|
||||
&weights,
|
||||
&biases,
|
||||
ConvCase {
|
||||
c_in: 2,
|
||||
h: 4,
|
||||
w: 5,
|
||||
c_out: 3,
|
||||
kh: 1,
|
||||
kw: 1,
|
||||
padding_h: 0,
|
||||
padding_w: 0,
|
||||
},
|
||||
);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(x, input);
|
||||
rt.set_data(weight, weights);
|
||||
rt.set_data(bias, biases);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_conv2d_candidate_executes_padded_unfold_matmul_bias() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut cx, x, weight, bias, out) = build_padded_conv_graph();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let llir = extract_forced_kernel_llir(&mut cx, "GenericConv2D");
|
||||
|
||||
let input: Vec<f32> = (0..2 * 4 * 5).map(|i| i as f32 * 0.05 - 0.5).collect();
|
||||
let weights: Vec<f32> = (0..3 * 2 * 3 * 3)
|
||||
.map(|i| (i as f32 % 13.0) * 0.03 - 0.17)
|
||||
.collect();
|
||||
let biases = vec![0.15_f32, -0.25, 0.35];
|
||||
let expected = reference_conv2d(
|
||||
&input,
|
||||
&weights,
|
||||
&biases,
|
||||
ConvCase {
|
||||
c_in: 2,
|
||||
h: 4,
|
||||
w: 5,
|
||||
c_out: 3,
|
||||
kh: 3,
|
||||
kw: 3,
|
||||
padding_h: 1,
|
||||
padding_w: 1,
|
||||
},
|
||||
);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(x, input);
|
||||
rt.set_data(weight, weights);
|
||||
rt.set_data(bias, biases);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_conv2d_candidate_executes_upsample_view_input() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut cx, x, weight, bias, out) = build_upsample_conv_graph();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let llir = extract_forced_kernel_llir(&mut cx, "GenericConv2D");
|
||||
|
||||
let input: Vec<f32> = (0..2 * 3 * 4).map(|i| i as f32 * 0.09 - 0.8).collect();
|
||||
let weights: Vec<f32> = (0..3 * 2 * 3 * 3)
|
||||
.map(|i| (i as f32 % 17.0) * 0.025 - 0.2)
|
||||
.collect();
|
||||
let biases = vec![0.05_f32, -0.1, 0.2];
|
||||
let upsampled = reference_nearest_upsample_2x(&input, 2, 3, 4);
|
||||
let expected = reference_conv2d(
|
||||
&upsampled,
|
||||
&weights,
|
||||
&biases,
|
||||
ConvCase {
|
||||
c_in: 2,
|
||||
h: 6,
|
||||
w: 8,
|
||||
c_out: 3,
|
||||
kh: 3,
|
||||
kw: 3,
|
||||
padding_h: 1,
|
||||
padding_w: 1,
|
||||
},
|
||||
);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(x, input);
|
||||
rt.set_data(weight, weights);
|
||||
rt.set_data(bias, biases);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
struct ConvCase {
|
||||
c_in: usize,
|
||||
h: usize,
|
||||
w: usize,
|
||||
c_out: usize,
|
||||
kh: usize,
|
||||
kw: usize,
|
||||
padding_h: usize,
|
||||
padding_w: usize,
|
||||
}
|
||||
|
||||
fn reference_nearest_upsample_2x(input: &[f32], c: usize, h: usize, w: usize) -> Vec<f32> {
|
||||
let mut out = vec![0.0_f32; c * h * 2 * w * 2];
|
||||
for ci in 0..c {
|
||||
for y in 0..h {
|
||||
for x in 0..w {
|
||||
let value = input[ci * h * w + y * w + x];
|
||||
for dy in 0..2 {
|
||||
for dx in 0..2 {
|
||||
let oy = y * 2 + dy;
|
||||
let ox = x * 2 + dx;
|
||||
out[ci * h * 2 * w * 2 + oy * w * 2 + ox] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn reference_conv2d(input: &[f32], weight: &[f32], bias: &[f32], case: ConvCase) -> Vec<f32> {
|
||||
let ConvCase {
|
||||
c_in,
|
||||
h,
|
||||
w,
|
||||
c_out,
|
||||
kh,
|
||||
kw,
|
||||
padding_h,
|
||||
padding_w,
|
||||
} = case;
|
||||
let h_out = h + 2 * padding_h - kh + 1;
|
||||
let w_out = w + 2 * padding_w - kw + 1;
|
||||
let mut out = vec![0.0; c_out * h_out * w_out];
|
||||
for co in 0..c_out {
|
||||
for oh in 0..h_out {
|
||||
for ow in 0..w_out {
|
||||
let mut acc = bias[co];
|
||||
for ci in 0..c_in {
|
||||
for r in 0..kh {
|
||||
for s in 0..kw {
|
||||
let Some(ih) = (oh + r).checked_sub(padding_h) else {
|
||||
continue;
|
||||
};
|
||||
let Some(iw) = (ow + s).checked_sub(padding_w) else {
|
||||
continue;
|
||||
};
|
||||
if ih >= h || iw >= w {
|
||||
continue;
|
||||
}
|
||||
let input_idx = ci * h * w + ih * w + iw;
|
||||
let weight_idx = co * c_in * kh * kw + (ci * kh + r) * kw + s;
|
||||
acc += input[input_idx] * weight[weight_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
out[co * h_out * w_out + oh * w_out + ow] = acc;
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn extract_forced_kernel_llir(cx: &mut Graph, kernel_name: &str) -> LLIRGraph {
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
let ops = cx
|
||||
.egglog_ops()
|
||||
.expect("search space should have registered egglog ops");
|
||||
let kernel_nodes = op_ir_nodes(egraph, "KernelConv2D");
|
||||
assert!(
|
||||
!kernel_nodes.is_empty(),
|
||||
"expected at least one {kernel_name} candidate"
|
||||
);
|
||||
|
||||
for (idx, kernel_node) in kernel_nodes.iter().enumerate() {
|
||||
let mut rng = StdRng::seed_from_u64(0xC0_2D00 + idx as u64);
|
||||
let mut choices = random_initial_choice(egraph, &mut rng);
|
||||
let kernel_class = &egraph.node_to_class[*kernel_node];
|
||||
choices.insert(kernel_class, kernel_node);
|
||||
|
||||
if validate_choice_set(egraph, &choices, ops).is_err() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
let llir = egglog_to_llir(
|
||||
egraph,
|
||||
choices,
|
||||
ops,
|
||||
&cx.custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
if llir_kernel_names(&llir).contains(&kernel_name) {
|
||||
return llir;
|
||||
}
|
||||
}
|
||||
|
||||
panic!("could not extract a valid {kernel_name} candidate");
|
||||
}
|
||||
|
||||
fn llir_kernel_names(llir: &LLIRGraph) -> Vec<&'static str> {
|
||||
llir.node_indices()
|
||||
.filter_map(|node| {
|
||||
llir[node]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.map(|kernel| kernel.kernel_name())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn op_ir_nodes<'a>(egraph: &'a SerializedEGraph, kind_label: &str) -> Vec<&'a NodeId> {
|
||||
let op_kind_classes = egraph
|
||||
.enodes
|
||||
.iter()
|
||||
.filter(|(_, (label, _))| label == kind_label)
|
||||
.map(|(node, _)| egraph.node_to_class[node].clone())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
egraph
|
||||
.enodes
|
||||
.iter()
|
||||
.filter_map(|(node, (label, children))| {
|
||||
(label == "Op"
|
||||
&& children
|
||||
.first()
|
||||
.is_some_and(|kind| op_kind_classes.contains(kind)))
|
||||
.then_some(node)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
@@ -7,10 +7,12 @@ use luminal::{
|
||||
prelude::*,
|
||||
};
|
||||
use rand::{SeedableRng, rngs::StdRng};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
host::{
|
||||
CublasLtMatrixOrders, CublasLtScaleValues, CublasLtTransposeOps, CublasLtTypeTuple, HostOp,
|
||||
cublaslt::{cublaslt_prepare_count_for_test, reset_cublaslt_prepare_count_for_test},
|
||||
cublaslt_c_d_layouts_match, cublaslt_epilogue, cublaslt_matrix_orders,
|
||||
cublaslt_scale_values, cublaslt_tensor_scale_inputs, cublaslt_transpose_ops,
|
||||
cublaslt_type_tuple,
|
||||
@@ -134,6 +136,45 @@ fn reference_matmul_2d(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Ve
|
||||
expected
|
||||
}
|
||||
|
||||
fn reference_mixed_chain(
|
||||
a: &[f32],
|
||||
pre: &[f32],
|
||||
b: &[f32],
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
) -> Vec<f32> {
|
||||
let mut expected = vec![0.0; m * n];
|
||||
for row in 0..m {
|
||||
for col in 0..n {
|
||||
let mut acc = 0.0;
|
||||
for inner in 0..k {
|
||||
acc += (a[row * k + inner] + pre[row * k + inner]) * b[inner * n + col];
|
||||
}
|
||||
expected[row * n + col] = acc.exp();
|
||||
}
|
||||
}
|
||||
expected
|
||||
}
|
||||
|
||||
fn cublaslt_available_for_runtime(stream: &Arc<cudarc::driver::CudaStream>) -> bool {
|
||||
crate::try_create_cublaslt(stream.clone()).is_ok()
|
||||
}
|
||||
|
||||
fn build_mixed_chain_graph(
|
||||
m: impl Into<Expression>,
|
||||
n: usize,
|
||||
k: usize,
|
||||
) -> (Graph, NodeIndex, NodeIndex, NodeIndex, NodeIndex) {
|
||||
let m = m.into();
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((m, k));
|
||||
let pre = cx.tensor((m, k));
|
||||
let b = cx.tensor((k, n));
|
||||
let out = ((a + pre).matmul(b).exp()).output();
|
||||
(cx, a.id, pre.id, b.id, out.id)
|
||||
}
|
||||
|
||||
fn add_in_place(values: &mut [f32], addends: &[f32]) {
|
||||
for (value, addend) in values.iter_mut().zip(addends) {
|
||||
*value += *addend;
|
||||
@@ -445,6 +486,54 @@ fn cublaslt_rewrites_cover_batched_row_order_layout_pairs() {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cublaslt_rewrites_cover_flux2_qk_transposed_matmul() {
|
||||
let mut cx = Graph::new();
|
||||
let q = cx.tensor((8usize, 4usize));
|
||||
let k = cx.tensor((8usize, 4usize));
|
||||
let _out = q.matmul(k.t()).output();
|
||||
|
||||
assert_cublaslt_rewrite(cx, "flux2 q @ k.t()", |llir| {
|
||||
cublaslt_matrix_order_tuples(llir).contains(&("ROW", "COL", "ROW", "ROW"))
|
||||
|| cublaslt_matrix_order_tuples(llir).contains(&("COL", "COL", "COL", "COL"))
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cublaslt_rewrites_cover_flux2_linear_bias_epilogue() {
|
||||
let mut cx = Graph::new();
|
||||
let x = cx.tensor((8usize, 4usize));
|
||||
let weight = cx.tensor((6usize, 4usize));
|
||||
let bias = cx.tensor(6usize);
|
||||
let _out = (x.matmul(weight.t()) + bias.expand_dim(0, 8usize)).output();
|
||||
|
||||
assert_cublaslt_epilogue_rewrite(
|
||||
cx,
|
||||
"flux2 x @ weight.t() + bias",
|
||||
"BIAS",
|
||||
Some(("COL", "COL", "COL", "COL")),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cublaslt_cleanup_prunes_flux2_broadcast_mul_fallback() {
|
||||
let mut cx = Graph::new();
|
||||
let q = cx.tensor((8usize, 4usize));
|
||||
let k = cx.tensor((8usize, 4usize));
|
||||
let _out = q.matmul(k.t()).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
assert!(
|
||||
!cublaslt_ir_nodes(egraph).is_empty(),
|
||||
"Flux2 q @ k.t() should have at least one cuBLASLt candidate"
|
||||
);
|
||||
assert!(
|
||||
op_ir_nodes(egraph, "Mul").is_empty(),
|
||||
"cuBLASLt cleanup should prune the broadcast Mul fallback once a cuBLASLt candidate exists"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cublaslt_rewrites_keep_c_and_d_layouts_equal_initially() {
|
||||
for case in LAYOUT_CASES {
|
||||
@@ -459,6 +548,463 @@ fn cublaslt_rewrites_keep_c_and_d_layouts_equal_initially() {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mixed_cuda_graph_cublaslt_kernel_chain_executes_correctly() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !cublaslt_available_for_runtime(&stream) {
|
||||
return;
|
||||
}
|
||||
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (m, n, k) = (7, 11, 5);
|
||||
let (mut cx, a, pre, b, out) = build_mixed_chain_graph(m, n, k);
|
||||
let llir = extract_forced_cublaslt_llir_where(&mut cx, "mixed graph chain", |llir| {
|
||||
cublaslt_scale_value_tuples(llir).contains(&(1.0, 0.0))
|
||||
});
|
||||
|
||||
let a_data = random_f32_vec(m * k, 0xCAFE_0001, -0.08, 0.08);
|
||||
let pre_data = random_f32_vec(m * k, 0xCAFE_0002, -0.03, 0.03);
|
||||
let b_data = random_f32_vec(k * n, 0xCAFE_0003, -0.08, 0.08);
|
||||
let expected = reference_mixed_chain(&a_data, &pre_data, &b_data, m, n, k);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(a, a_data);
|
||||
rt.set_data(pre, pre_data);
|
||||
rt.set_data(b, b_data);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out), &expected, 1e-5, 1e-5);
|
||||
let summaries = rt.debug_cuda_graph_summaries();
|
||||
let mixed = summaries
|
||||
.iter()
|
||||
.find(|summary| summary.n_cublaslt == 1)
|
||||
.expect("expected one CudaGraphOp to capture the cuBLASLt island");
|
||||
assert!(mixed.n_kernels >= 2, "expected kernels around cuBLASLt");
|
||||
assert_eq!(mixed.n_steps, mixed.n_kernels + mixed.n_cublaslt);
|
||||
assert_eq!(mixed.absorbed_host_nodes.len(), 1);
|
||||
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cuda_graph_cublaslt_only_executes_correctly() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !cublaslt_available_for_runtime(&stream) {
|
||||
return;
|
||||
}
|
||||
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (m, n, k) = (7, 11, 5);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((m, k));
|
||||
let b = cx.tensor((k, n));
|
||||
let out = a.matmul(b).output();
|
||||
let llir = extract_forced_cublaslt_llir_where(&mut cx, "cuBLASLt-only graph", |_| true);
|
||||
|
||||
let a_data = random_f32_vec(m * k, 0xC001_0001, -0.08, 0.08);
|
||||
let b_data = random_f32_vec(k * n, 0xC001_0002, -0.08, 0.08);
|
||||
let expected = reference_matmul_2d(&a_data, &b_data, m, n, k);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(a.id, a_data);
|
||||
rt.set_data(b.id, b_data);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
let summary = rt
|
||||
.debug_cuda_graph_summaries()
|
||||
.into_iter()
|
||||
.find(|summary| summary.n_cublaslt == 1)
|
||||
.expect("expected a cuBLASLt-only CudaGraphOp");
|
||||
assert_eq!(summary.n_kernels, 0);
|
||||
assert_eq!(summary.n_steps, 1);
|
||||
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mixed_cuda_graph_reuses_prepared_for_ordered_matching_cublaslt_ops() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !cublaslt_available_for_runtime(&stream) {
|
||||
return;
|
||||
}
|
||||
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (m, n, k) = (5, 8, 8);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((m, k));
|
||||
let b = cx.tensor((k, n));
|
||||
let first = a.matmul(b);
|
||||
let out = (a + first.sin()).matmul(b).output();
|
||||
let llir = extract_forced_cublaslt_llir_where(
|
||||
&mut cx,
|
||||
"ordered matching cuBLASLt prepared reuse",
|
||||
|llir| {
|
||||
let orders = cublaslt_matrix_order_tuples(llir);
|
||||
orders.len() == 2 && orders[0] == orders[1]
|
||||
},
|
||||
);
|
||||
|
||||
let a_data = random_f32_vec(m * k, 0xC001_1001, -0.08, 0.08);
|
||||
let b_data = random_f32_vec(k * n, 0xC001_1002, -0.08, 0.08);
|
||||
let first = reference_matmul_2d(&a_data, &b_data, m, n, k);
|
||||
let dep = a_data
|
||||
.iter()
|
||||
.zip(&first)
|
||||
.map(|(a, first)| a + first.sin())
|
||||
.collect::<Vec<_>>();
|
||||
let expected = reference_matmul_2d(&dep, &b_data, m, n, k);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(a.id, a_data);
|
||||
rt.set_data(b.id, b_data);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
let summary = rt
|
||||
.debug_cuda_graph_summaries()
|
||||
.into_iter()
|
||||
.find(|summary| summary.n_cublaslt == 2)
|
||||
.expect("expected one mixed CudaGraphOp with two cuBLASLt islands");
|
||||
assert_eq!(
|
||||
summary.n_cublaslt_prepared, 1,
|
||||
"dependency-ordered matching cuBLASLt calls should share prepared resources"
|
||||
);
|
||||
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cuda_graph_cublaslt_skips_prepare_when_unrelated_dyn_dim_changes() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !cublaslt_available_for_runtime(&stream) {
|
||||
return;
|
||||
}
|
||||
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (m, n, k) = (7, 11, 5);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((m, k));
|
||||
let b = cx.tensor((k, n));
|
||||
let out = a.matmul(b).output();
|
||||
a.output();
|
||||
b.output();
|
||||
cx.set_dim('p', 1);
|
||||
let llir = extract_forced_cublaslt_llir_where(
|
||||
&mut cx,
|
||||
"cuBLASLt unchanged under unrelated dyn dim",
|
||||
|_| true,
|
||||
);
|
||||
|
||||
let a_data = random_f32_vec(m * k, 0xC004_0001, -0.08, 0.08);
|
||||
let b_data = random_f32_vec(k * n, 0xC004_0002, -0.08, 0.08);
|
||||
let expected = reference_matmul_2d(&a_data, &b_data, m, n, k);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(a.id, a_data);
|
||||
rt.set_data(b.id, b_data);
|
||||
|
||||
reset_cublaslt_prepare_count_for_test();
|
||||
rt.execute(&cx.dyn_map);
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
let first_prepare_count = cublaslt_prepare_count_for_test();
|
||||
assert!(
|
||||
first_prepare_count > 0,
|
||||
"first execution should prepare the captured cuBLASLt island"
|
||||
);
|
||||
|
||||
cx.set_dim('p', 2);
|
||||
rt.execute(&cx.dyn_map);
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
assert_eq!(
|
||||
cublaslt_prepare_count_for_test(),
|
||||
first_prepare_count,
|
||||
"unrelated dyn dim changes should not redo expensive cuBLASLt prepare"
|
||||
);
|
||||
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cuda_graph_cublaslt_only_recaptures_on_dynamic_shape_change() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !cublaslt_available_for_runtime(&stream) {
|
||||
return;
|
||||
}
|
||||
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (n, k) = (11, 5);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(('m', k));
|
||||
let b = cx.tensor((k, n));
|
||||
let out = a.matmul(b).output();
|
||||
cx.set_dim('m', 7);
|
||||
let llir = extract_forced_cublaslt_llir_where(&mut cx, "cuBLASLt-only dynamic graph", |_| true);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
|
||||
for (m, seed) in [
|
||||
(7usize, 0xC002_0001),
|
||||
(9usize, 0xC002_0002),
|
||||
(7usize, 0xC002_0003),
|
||||
] {
|
||||
cx.set_dim('m', m);
|
||||
let a_data = random_f32_vec(m * k, seed, -0.08, 0.08);
|
||||
let b_data = random_f32_vec(k * n, seed + 10, -0.08, 0.08);
|
||||
let expected = reference_matmul_2d(&a_data, &b_data, m, n, k);
|
||||
|
||||
rt.set_data(a.id, a_data);
|
||||
rt.set_data(b.id, b_data);
|
||||
rt.execute(&cx.dyn_map);
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
let summary = rt
|
||||
.debug_cuda_graph_summaries()
|
||||
.into_iter()
|
||||
.find(|summary| summary.n_cublaslt == 1)
|
||||
.expect("expected a cuBLASLt-only CudaGraphOp after recapture");
|
||||
assert_eq!(summary.n_kernels, 0);
|
||||
assert_eq!(summary.n_steps, 1);
|
||||
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cublaslt_with_dynamic_c_spec_is_captured() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !cublaslt_available_for_runtime(&stream) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (n, k) = (11, 5);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(('c', k));
|
||||
let b = cx.tensor((k, n));
|
||||
let out = a.matmul(b).output();
|
||||
cx.set_dim('c', 7);
|
||||
let llir = extract_forced_cublaslt_llir_where(&mut cx, "dynamic c cuBLASLt graph", |_| true);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
|
||||
for (c, seed) in [(7usize, 0xC003_0001), (9usize, 0xC003_0002)] {
|
||||
cx.set_dim('c', c);
|
||||
let a_data = random_f32_vec(c * k, seed, -0.08, 0.08);
|
||||
let b_data = random_f32_vec(k * n, seed + 10, -0.08, 0.08);
|
||||
let expected = reference_matmul_2d(&a_data, &b_data, c, n, k);
|
||||
|
||||
rt.set_data(a.id, a_data);
|
||||
rt.set_data(b.id, b_data);
|
||||
rt.execute(&cx.dyn_map);
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
assert!(
|
||||
rt.debug_cuda_graph_summaries()
|
||||
.iter()
|
||||
.any(|summary| summary.n_cublaslt == 1),
|
||||
"c-dependent cuBLASLt should be absorbed into a CUDA graph"
|
||||
);
|
||||
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bucket_range_and_singleton_cublaslt_buckets_are_captured() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !cublaslt_available_for_runtime(&stream) {
|
||||
return;
|
||||
}
|
||||
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (n, k) = (11, 5);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(('s', k));
|
||||
let b = cx.tensor((k, n));
|
||||
let out = a.matmul(b).output();
|
||||
cx.set_dim('s', 1);
|
||||
let llir =
|
||||
extract_forced_cublaslt_llir_where(&mut cx, "bucketed s cuBLASLt graph capture", |_| true);
|
||||
|
||||
let dim_buckets = [('s', vec![DimBucket::new(1, 1), DimBucket::new(2, 4)])]
|
||||
.into_iter()
|
||||
.collect();
|
||||
let bucket_llirs = vec![
|
||||
(
|
||||
[('s', 0usize)].into_iter().collect(),
|
||||
[('s', 1usize)].into_iter().collect(),
|
||||
llir.clone(),
|
||||
),
|
||||
(
|
||||
[('s', 1usize)].into_iter().collect(),
|
||||
[('s', 3usize)].into_iter().collect(),
|
||||
llir,
|
||||
),
|
||||
];
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir_buckets(&dim_buckets, &bucket_llirs);
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
let a_data = random_f32_vec(k, 0xB001_0001, -0.08, 0.08);
|
||||
let b_data = random_f32_vec(k * n, 0xB001_0002, -0.08, 0.08);
|
||||
let expected = reference_matmul_2d(&a_data, &b_data, 1, n, k);
|
||||
rt.set_data(a.id, a_data);
|
||||
rt.set_data(b.id, b_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
assert!(
|
||||
rt.debug_cuda_graph_summaries()
|
||||
.iter()
|
||||
.any(|summary| summary.n_cublaslt == 1),
|
||||
"singleton s bucket should capture s-dependent cuBLASLt"
|
||||
);
|
||||
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
|
||||
assert!(
|
||||
rt.debug_active_bucket_stabilizes_intermediate_pointers(),
|
||||
"bucket with captured cuBLASLt needs stable intermediate pointers"
|
||||
);
|
||||
|
||||
cx.set_dim('s', 3);
|
||||
let a_data = random_f32_vec(3 * k, 0xB001_0003, -0.08, 0.08);
|
||||
let expected = reference_matmul_2d(&a_data, &b_data, 3, n, k);
|
||||
rt.set_data(a.id, a_data);
|
||||
rt.set_data(b.id, b_data);
|
||||
rt.execute(&cx.dyn_map);
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
assert!(
|
||||
rt.debug_cuda_graph_summaries()
|
||||
.iter()
|
||||
.any(|summary| summary.n_cublaslt == 1),
|
||||
"range s bucket should capture s-dependent cuBLASLt"
|
||||
);
|
||||
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
|
||||
assert!(
|
||||
rt.debug_active_bucket_stabilizes_intermediate_pointers(),
|
||||
"bucket with captured cuBLASLt needs stable intermediate pointers"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mixed_cuda_graph_cublaslt_recaptures_on_input_pointer_change() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !cublaslt_available_for_runtime(&stream) {
|
||||
return;
|
||||
}
|
||||
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (m, n, k) = (7, 11, 5);
|
||||
let (mut cx, a, pre, b, out) = build_mixed_chain_graph(m, n, k);
|
||||
let llir =
|
||||
extract_forced_cublaslt_llir_where(&mut cx, "mixed graph pointer recapture", |_| true);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
|
||||
reset_cublaslt_prepare_count_for_test();
|
||||
let mut first_prepare_count = None;
|
||||
for seed in [0xCC00_0001, 0xCC00_0002] {
|
||||
let a_data = random_f32_vec(m * k, seed, -0.08, 0.08);
|
||||
let pre_data = random_f32_vec(m * k, seed + 10, -0.03, 0.03);
|
||||
let b_data = random_f32_vec(k * n, seed + 20, -0.08, 0.08);
|
||||
let expected = reference_mixed_chain(&a_data, &pre_data, &b_data, m, n, k);
|
||||
|
||||
rt.set_data(a, a_data);
|
||||
rt.set_data(pre, pre_data);
|
||||
rt.set_data(b, b_data);
|
||||
rt.execute(&cx.dyn_map);
|
||||
assert_close(&rt.get_f32(out), &expected, 1e-5, 1e-5);
|
||||
if first_prepare_count.is_none() {
|
||||
first_prepare_count = Some(cublaslt_prepare_count_for_test());
|
||||
}
|
||||
}
|
||||
assert_eq!(
|
||||
cublaslt_prepare_count_for_test(),
|
||||
first_prepare_count.unwrap(),
|
||||
"A/B/C/D pointer-only recapture should reuse prepared cuBLASLt resources"
|
||||
);
|
||||
|
||||
let summaries = rt.debug_cuda_graph_summaries();
|
||||
assert!(
|
||||
summaries.iter().any(|summary| summary.n_cublaslt == 1),
|
||||
"expected cuBLASLt to remain captured after pointer recapture"
|
||||
);
|
||||
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mixed_cuda_graph_cublaslt_recaptures_on_dynamic_shape_change() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !cublaslt_available_for_runtime(&stream) {
|
||||
return;
|
||||
}
|
||||
if !crate::host::cublaslt::cublaslt_graph_capture_supported(&stream) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (n, k) = (11, 5);
|
||||
let (mut cx, a, pre, b, out) = build_mixed_chain_graph('m', n, k);
|
||||
cx.set_dim('m', 7);
|
||||
let llir =
|
||||
extract_forced_cublaslt_llir_where(&mut cx, "mixed graph dynamic recapture", |_| true);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
|
||||
for (m, seed) in [(7usize, 0xDD00_0001), (9usize, 0xDD00_0002)] {
|
||||
cx.set_dim('m', m);
|
||||
let a_data = random_f32_vec(m * k, seed, -0.08, 0.08);
|
||||
let pre_data = random_f32_vec(m * k, seed + 10, -0.03, 0.03);
|
||||
let b_data = random_f32_vec(k * n, seed + 20, -0.08, 0.08);
|
||||
let expected = reference_mixed_chain(&a_data, &pre_data, &b_data, m, n, k);
|
||||
|
||||
rt.set_data(a, a_data);
|
||||
rt.set_data(pre, pre_data);
|
||||
rt.set_data(b, b_data);
|
||||
rt.execute(&cx.dyn_map);
|
||||
assert_close(&rt.get_f32(out), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
let summaries = rt.debug_cuda_graph_summaries();
|
||||
assert!(
|
||||
summaries.iter().any(|summary| summary.n_cublaslt == 1),
|
||||
"expected cuBLASLt to remain captured after dynamic-shape recapture"
|
||||
);
|
||||
assert_eq!(rt.debug_standalone_cublaslt_host_ops(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA rewrite sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn cublaslt_rewrites_cover_2d_matmul_plus_c_beta_one() {
|
||||
@@ -979,7 +1525,7 @@ fn cublaslt_fp8_scaled_candidate_reaches_fused_output_scale_consumer() {
|
||||
let scaled_out = scaled_a.matmul(b).cast(DType::F32) * (a_scale * b_scale).expand_rhs((m, n));
|
||||
(scaled_out * side).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
|
||||
assert!(
|
||||
@@ -1013,7 +1559,7 @@ fn cublaslt_fp8_scaled_candidates_reach_fused_mlp_consumer() {
|
||||
* (up_input_scale * up_weight_scale).expand_rhs((m, n));
|
||||
(gate.swish() * up).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
|
||||
assert!(
|
||||
@@ -2888,7 +3434,7 @@ fn extract_forced_cublaslt_llir_where(
|
||||
case_name: &str,
|
||||
matches: impl Fn(&LLIRGraph) -> bool,
|
||||
) -> LLIRGraph {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
let ops = cx
|
||||
@@ -2943,7 +3489,7 @@ fn assert_no_forced_cublaslt_llir_where(
|
||||
case_name: &str,
|
||||
matches: impl Fn(&LLIRGraph) -> bool,
|
||||
) {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
let ops = cx
|
||||
@@ -2992,7 +3538,7 @@ fn assert_no_cublaslt_llir_where(
|
||||
case_name: &str,
|
||||
matches: impl Fn(&LLIRGraph) -> bool,
|
||||
) {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
let ops = cx
|
||||
@@ -3033,10 +3579,17 @@ fn assert_no_cublaslt_llir_where(
|
||||
}
|
||||
|
||||
fn cublaslt_ir_nodes(egraph: &SerializedEGraph) -> Vec<&NodeId> {
|
||||
let cublaslt_kind_classes = egraph
|
||||
op_ir_nodes(egraph, "cublaslt")
|
||||
.into_iter()
|
||||
.chain(op_ir_nodes(egraph, "cublaslt_scaled"))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn op_ir_nodes<'a>(egraph: &'a SerializedEGraph, kind_label: &str) -> Vec<&'a NodeId> {
|
||||
let op_kind_classes = egraph
|
||||
.enodes
|
||||
.iter()
|
||||
.filter(|(_, (label, _))| label == "cublaslt" || label == "cublaslt_scaled")
|
||||
.filter(|(_, (label, _))| label == kind_label)
|
||||
.map(|(node, _)| egraph.node_to_class[node].clone())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
@@ -3047,7 +3600,7 @@ fn cublaslt_ir_nodes(egraph: &SerializedEGraph) -> Vec<&NodeId> {
|
||||
(label == "Op"
|
||||
&& children
|
||||
.first()
|
||||
.is_some_and(|kind| cublaslt_kind_classes.contains(kind)))
|
||||
.is_some_and(|kind| op_kind_classes.contains(kind)))
|
||||
.then_some(node)
|
||||
})
|
||||
.collect()
|
||||
|
||||
@@ -83,13 +83,13 @@ fn run_reference_attention(
|
||||
let (mut cx, q_t, k_t, v_t, out_t) = build_attention_graph();
|
||||
cx.set_dim('s', batch_size);
|
||||
cx.set_dim('c', context_len);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.set_data(q_t, q.to_vec());
|
||||
rt.set_data(k_t, k.to_vec());
|
||||
rt.set_data(v_t, v.to_vec());
|
||||
rt = cx.search(rt, 3);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
|
||||
|
||||
rt.set_data(q_t, q.to_vec());
|
||||
rt.set_data(k_t, k.to_vec());
|
||||
@@ -779,7 +779,7 @@ fn flashinfer_extraction_reachable_from_search_space() {
|
||||
cx.set_dim('s', 1usize);
|
||||
cx.set_dim('c', 16usize);
|
||||
cx.set_dim('r', 2usize);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
|
||||
let egraph = cx
|
||||
.egraph()
|
||||
|
||||
@@ -293,7 +293,7 @@ struct FusedRegion {
|
||||
/// Helper: collect every distinct fused region reachable across many random
|
||||
/// extractions of the search space.
|
||||
fn extract_all_fused_regions(cx: &mut Graph) -> Vec<FusedRegion> {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let egraph = cx.egraph().expect("egraph not built");
|
||||
let ops = cx.egglog_ops().expect("ops not built");
|
||||
let custom_ops = &cx.custom_ops;
|
||||
|
||||
169
crates/luminal_cuda_lite/src/tests/generic_matmul_rewrite.rs
Normal file
169
crates/luminal_cuda_lite/src/tests/generic_matmul_rewrite.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
NodeId, SerializedEGraph, egglog_to_llir, random_initial_choice, validate_choice_set,
|
||||
},
|
||||
prelude::*,
|
||||
};
|
||||
use rand::{SeedableRng, rngs::StdRng};
|
||||
|
||||
use crate::{kernel::KernelOp, runtime::CudaRuntime};
|
||||
|
||||
use super::utilities::{assert_close, get_cuda_stream};
|
||||
|
||||
#[test]
|
||||
fn generic_matmul_covers_noncontiguous_merged_head_projection() {
|
||||
let mut cx = Graph::default();
|
||||
let heads = 3;
|
||||
let seq = 4;
|
||||
let head_dim = 5;
|
||||
let hidden = heads * head_dim;
|
||||
let out_dim = 7;
|
||||
|
||||
let attn = cx.tensor((heads, seq, head_dim));
|
||||
let weight = cx.tensor((out_dim, hidden));
|
||||
let merged = attn.transpose(0, 1).merge_dims(1, 2);
|
||||
merged.matmul(weight.t()).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let llir = extract_forced_kernel_llir(&mut cx, "GenericMatmul");
|
||||
let names = llir_kernel_names(&llir);
|
||||
|
||||
assert!(
|
||||
names.contains(&"GenericMatmul"),
|
||||
"expected generic matmul fallback, kernels: {names:?}"
|
||||
);
|
||||
assert!(
|
||||
!names.contains(&"Mul") && !names.contains(&"SumReduce"),
|
||||
"generic matmul should prune the broadcast multiply/sum fallback, kernels: {names:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_matmul_executes_noncontiguous_merged_head_projection() {
|
||||
let mut cx = Graph::default();
|
||||
let heads = 3;
|
||||
let seq = 4;
|
||||
let head_dim = 5;
|
||||
let hidden = heads * head_dim;
|
||||
let out_dim = 7;
|
||||
|
||||
let attn = cx.tensor((heads, seq, head_dim));
|
||||
let weight = cx.tensor((out_dim, hidden));
|
||||
let merged = attn.transpose(0, 1).merge_dims(1, 2);
|
||||
let output = merged.matmul(weight.t()).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let stream = get_cuda_stream().expect("CUDA device required for GenericMatmul execution test");
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
let attn_data = seeded_data(heads * seq * head_dim, 0.19, -0.09);
|
||||
let weight_data = seeded_data(out_dim * hidden, 0.14, -0.06);
|
||||
rt.set_data(attn, attn_data.as_slice());
|
||||
rt.set_data(weight, weight_data.as_slice());
|
||||
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
assert!(
|
||||
rt.kernel_names().contains(&"GenericMatmul"),
|
||||
"expected GenericMatmul to be selected, kernels: {:?}",
|
||||
rt.kernel_names()
|
||||
);
|
||||
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(output.id);
|
||||
|
||||
let mut expected = vec![0.0; seq * out_dim];
|
||||
for token in 0..seq {
|
||||
for out_col in 0..out_dim {
|
||||
let mut sum = 0.0;
|
||||
for inner in 0..hidden {
|
||||
let head = inner / head_dim;
|
||||
let dim = inner % head_dim;
|
||||
let attn_idx = head * seq * head_dim + token * head_dim + dim;
|
||||
sum += attn_data[attn_idx] * weight_data[out_col * hidden + inner];
|
||||
}
|
||||
expected[token * out_dim + out_col] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
assert_close(&result, &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
fn seeded_data(len: usize, scale: f32, bias: f32) -> Vec<f32> {
|
||||
(0..len)
|
||||
.map(|i| {
|
||||
let x = ((i * 37 + 11) % 97) as f32 / 97.0;
|
||||
x * scale + bias
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn extract_forced_kernel_llir(cx: &mut Graph, kernel_name: &str) -> LLIRGraph {
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
let ops = cx
|
||||
.egglog_ops()
|
||||
.expect("search space should have registered egglog ops");
|
||||
let kernel_nodes = op_ir_nodes(egraph, kernel_name);
|
||||
assert!(
|
||||
!kernel_nodes.is_empty(),
|
||||
"expected at least one {kernel_name} candidate"
|
||||
);
|
||||
|
||||
for (idx, kernel_node) in kernel_nodes.iter().enumerate() {
|
||||
let mut rng = StdRng::seed_from_u64(0x9EEE_0000 + idx as u64);
|
||||
let mut choices = random_initial_choice(egraph, &mut rng);
|
||||
let kernel_class = &egraph.node_to_class[*kernel_node];
|
||||
choices.insert(kernel_class, kernel_node);
|
||||
|
||||
if validate_choice_set(egraph, &choices, ops).is_err() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
let llir = egglog_to_llir(
|
||||
egraph,
|
||||
choices,
|
||||
ops,
|
||||
&cx.custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
if llir_kernel_names(&llir).contains(&kernel_name) {
|
||||
return llir;
|
||||
}
|
||||
}
|
||||
|
||||
panic!("could not extract a valid {kernel_name} candidate");
|
||||
}
|
||||
|
||||
fn llir_kernel_names(llir: &LLIRGraph) -> Vec<&'static str> {
|
||||
llir.node_indices()
|
||||
.filter_map(|node| {
|
||||
llir[node]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.map(|kernel| kernel.kernel_name())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn op_ir_nodes<'a>(egraph: &'a SerializedEGraph, kind_label: &str) -> Vec<&'a NodeId> {
|
||||
let op_kind_classes = egraph
|
||||
.enodes
|
||||
.iter()
|
||||
.filter(|(_, (label, _))| label == kind_label)
|
||||
.map(|(node, _)| egraph.node_to_class[node].clone())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
egraph
|
||||
.enodes
|
||||
.iter()
|
||||
.filter_map(|(node, (label, children))| {
|
||||
(label == "Op"
|
||||
&& children
|
||||
.first()
|
||||
.is_some_and(|kind| op_kind_classes.contains(kind)))
|
||||
.then_some(node)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
@@ -5,12 +5,16 @@ mod bucket_tests;
|
||||
#[cfg(test)]
|
||||
mod consumed_buffer_tests;
|
||||
#[cfg(test)]
|
||||
mod conv2d_rewrite;
|
||||
#[cfg(test)]
|
||||
mod cublaslt_rewrite_tests;
|
||||
#[cfg(test)]
|
||||
mod flashinfer;
|
||||
#[cfg(test)]
|
||||
mod fusion;
|
||||
#[cfg(test)]
|
||||
mod generic_matmul_rewrite;
|
||||
#[cfg(test)]
|
||||
mod model_fuzz;
|
||||
#[cfg(test)]
|
||||
mod op_functional_tests;
|
||||
|
||||
@@ -83,7 +83,7 @@ fn fuzz_mlp(seq: usize, hidden: usize, intermediate: usize, seed: u64) {
|
||||
let w_down = cx.tensor((hidden, intermediate));
|
||||
let out = swiglu_mlp(input, w_gate, w_up, w_down).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
|
||||
@@ -95,7 +95,7 @@ fn fuzz_mlp(seq: usize, hidden: usize, intermediate: usize, seed: u64) {
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -143,7 +143,7 @@ fn fuzz_norm_proj(seq: usize, hidden: usize, proj_dim: usize, eps: f32, seed: u6
|
||||
let proj_w = cx.tensor((proj_dim, hidden));
|
||||
let out = rms_norm(input, norm_w, eps).matmul(proj_w.t()).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
|
||||
@@ -156,7 +156,7 @@ fn fuzz_norm_proj(seq: usize, hidden: usize, proj_dim: usize, eps: f32, seed: u6
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(norm_w, norm_data.clone());
|
||||
rt.set_data(proj_w, proj_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -219,7 +219,7 @@ fn fuzz_layer_no_attn(
|
||||
let mlp_out = swiglu_mlp(mlp_normed, w_gate, w_up, w_down);
|
||||
let out = (x + mlp_out).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
|
||||
@@ -245,7 +245,7 @@ fn fuzz_layer_no_attn(
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -318,7 +318,7 @@ fn fuzz_mlp_hlir_only(seq: usize, hidden: usize, intermediate: usize, seed: u64)
|
||||
let w_down = cx.tensor((hidden, intermediate));
|
||||
let out = swiglu_mlp(input, w_gate, w_up, w_down).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
|
||||
@@ -330,7 +330,7 @@ fn fuzz_mlp_hlir_only(seq: usize, hidden: usize, intermediate: usize, seed: u64)
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -481,7 +481,7 @@ mod gemma {
|
||||
let mlp_normed = rms_norm(mlp_out, post_ff_norm_w, EPS);
|
||||
let out = (x + mlp_normed).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let seed = 800u64;
|
||||
@@ -518,7 +518,7 @@ mod gemma {
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -641,7 +641,7 @@ mod qwen {
|
||||
let embedding = cx.tensor((VOCAB, HIDDEN));
|
||||
let out = rms_norm(input, norm_w, EPS).matmul(embedding.t()).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let seed = 1300u64;
|
||||
@@ -655,7 +655,7 @@ mod qwen {
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(norm_w, norm_data.clone());
|
||||
rt.set_data(embedding, emb_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
|
||||
@@ -256,10 +256,10 @@ fn run_argsort_test(rows: usize, cols: usize, seed: u64) {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(input, data);
|
||||
rt = cx.search(rt, 10);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(10));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let out_dim0 = rt.get_i32(sorted_dim0.id);
|
||||
let out_dim1 = rt.get_i32(sorted_dim1.id);
|
||||
@@ -424,7 +424,7 @@ fn fuzz_test_cuda_genomes_impl(seed: u64) {
|
||||
let e = (d + c).relu();
|
||||
let out = e.output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let egraph = cx.egraph().unwrap();
|
||||
let ops = cx.egglog_ops().unwrap();
|
||||
|
||||
@@ -592,7 +592,7 @@ fn run_embed_test(vocab_size: usize, embed_dim: usize, seq_len: usize, seed: u64
|
||||
)
|
||||
.output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let token_data: Vec<i32> = random_i32_vec(seq_len, seed, 0, vocab_size as i32 - 1);
|
||||
@@ -600,7 +600,7 @@ fn run_embed_test(vocab_size: usize, embed_dim: usize, seq_len: usize, seed: u64
|
||||
|
||||
rt.set_data(token_ids, token_data.clone());
|
||||
rt.set_data(embed_table, embed_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
@@ -27,11 +27,11 @@ pub fn kernel_add_bandwidth_test() {
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
|
||||
// Warm up
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -2,10 +2,7 @@ use half::bf16;
|
||||
use luminal::{dtype::DType, prelude::*, shape::Expression};
|
||||
|
||||
use super::utilities::{assert_close, get_cuda_stream, random_f32_vec};
|
||||
use crate::{
|
||||
host::moe::{GLUMoE, GLUMoEMode},
|
||||
runtime::CudaRuntime,
|
||||
};
|
||||
use crate::{host::moe::GLUMoE, runtime::CudaRuntime};
|
||||
|
||||
const SEQ: usize = 2;
|
||||
const HIDDEN: usize = 32;
|
||||
@@ -173,30 +170,51 @@ fn gemma_gelu(x: GraphTensor) -> GraphTensor {
|
||||
x * scaled.sigmoid()
|
||||
}
|
||||
|
||||
fn glumoe_modes(rt: &CudaRuntime) -> Vec<GLUMoEMode> {
|
||||
rt.host_ops()
|
||||
.into_iter()
|
||||
.filter_map(|op| {
|
||||
op.as_any()
|
||||
.downcast_ref::<GLUMoE>()
|
||||
.map(|glumoe| glumoe.mode)
|
||||
})
|
||||
.collect()
|
||||
fn search_space_contains(cx: &Graph, op_name: &str) -> bool {
|
||||
let egraph = cx.egraph().expect("test should build an e-graph");
|
||||
|
||||
for (label, children) in egraph.enodes.values() {
|
||||
if label != "Op" {
|
||||
continue;
|
||||
}
|
||||
let Some(kind_eclass) = children.first() else {
|
||||
continue;
|
||||
};
|
||||
let Some((_, kind_enodes)) = egraph.eclasses.get(kind_eclass) else {
|
||||
continue;
|
||||
};
|
||||
if kind_enodes
|
||||
.iter()
|
||||
.any(|kind_node| egraph.enodes[kind_node].0 == op_name)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn run_qwen_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
|
||||
fn assert_glumoe_in_search_space(cx: &Graph) {
|
||||
assert!(
|
||||
search_space_contains(cx, "GLUMoE"),
|
||||
"GLUMoE was not in the e-graph search space"
|
||||
);
|
||||
}
|
||||
|
||||
fn run_qwen_moe(include_glumoe: bool) -> Vec<f32> {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return (vec![], vec![]);
|
||||
return vec![];
|
||||
};
|
||||
|
||||
let mut model = build_qwen_moe_graph();
|
||||
model.graph.set_dim('s', SEQ);
|
||||
if use_glumoe {
|
||||
model.graph.build_search_space::<CudaRuntime>();
|
||||
if include_glumoe {
|
||||
model
|
||||
.graph
|
||||
.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
} else {
|
||||
model
|
||||
.graph
|
||||
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>();
|
||||
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>(CompileOptions::default());
|
||||
}
|
||||
|
||||
let x_data = random_f32_vec(SEQ * HIDDEN, 11, -0.15, 0.15);
|
||||
@@ -215,25 +233,29 @@ fn run_qwen_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
|
||||
rt.set_data(model.router, router_data);
|
||||
rt.set_data(model.gate_up_weights, gate_up_data);
|
||||
rt.set_data(model.down_weights, down_data);
|
||||
rt = model.graph.search(rt, 10);
|
||||
rt = model
|
||||
.graph
|
||||
.search(rt, CompileOptions::default().search_graph_limit(10));
|
||||
rt.execute(&model.graph.dyn_map);
|
||||
|
||||
(rt.get_f32(model.output.id), glumoe_modes(&rt))
|
||||
rt.get_f32(model.output.id)
|
||||
}
|
||||
|
||||
fn run_gemma_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
|
||||
fn run_gemma_moe(include_glumoe: bool) -> Vec<f32> {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return (vec![], vec![]);
|
||||
return vec![];
|
||||
};
|
||||
|
||||
let mut model = build_gemma_moe_graph();
|
||||
model.graph.set_dim('s', SEQ);
|
||||
if use_glumoe {
|
||||
model.graph.build_search_space::<CudaRuntime>();
|
||||
if include_glumoe {
|
||||
model
|
||||
.graph
|
||||
.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
} else {
|
||||
model
|
||||
.graph
|
||||
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>();
|
||||
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>(CompileOptions::default());
|
||||
}
|
||||
|
||||
let router_input_data = random_f32_vec(SEQ * HIDDEN, 21, -0.15, 0.15);
|
||||
@@ -258,54 +280,60 @@ fn run_gemma_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
|
||||
rt.set_data(model.per_expert_scale, per_expert_scale_data);
|
||||
rt.set_data(model.gate_up_weights, gate_up_data);
|
||||
rt.set_data(model.down_weights, down_data);
|
||||
rt = model.graph.search(rt, 10);
|
||||
rt = model
|
||||
.graph
|
||||
.search(rt, CompileOptions::default().search_graph_limit(10));
|
||||
rt.execute(&model.graph.dyn_map);
|
||||
|
||||
(rt.get_f32(model.output.id), glumoe_modes(&rt))
|
||||
rt.get_f32(model.output.id)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_matches_qwen_swiglu_pattern() {
|
||||
let (_result, modes) = run_qwen_moe(true);
|
||||
if modes.is_empty() {
|
||||
if get_cuda_stream().is_none() {
|
||||
return;
|
||||
}
|
||||
|
||||
assert_eq!(modes, vec![GLUMoEMode::SwiGLUNormalized]);
|
||||
let mut model = build_qwen_moe_graph();
|
||||
model.graph.set_dim('s', SEQ);
|
||||
model
|
||||
.graph
|
||||
.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
assert_glumoe_in_search_space(&model.graph);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_matches_gemma_gelu_pattern() {
|
||||
let (_result, modes) = run_gemma_moe(true);
|
||||
if modes.is_empty() {
|
||||
if get_cuda_stream().is_none() {
|
||||
return;
|
||||
}
|
||||
|
||||
assert_eq!(modes, vec![GLUMoEMode::GemmaGELU]);
|
||||
let mut model = build_gemma_moe_graph();
|
||||
model.graph.set_dim('s', SEQ);
|
||||
model
|
||||
.graph
|
||||
.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
assert_glumoe_in_search_space(&model.graph);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_swiglu_matches_unfused_output() {
|
||||
let (expected, baseline_modes) = run_qwen_moe(false);
|
||||
let expected = run_qwen_moe(false);
|
||||
if expected.is_empty() {
|
||||
return;
|
||||
}
|
||||
assert!(baseline_modes.is_empty());
|
||||
|
||||
let (actual, fused_modes) = run_qwen_moe(true);
|
||||
assert_eq!(fused_modes, vec![GLUMoEMode::SwiGLUNormalized]);
|
||||
let actual = run_qwen_moe(true);
|
||||
assert_close(&actual, &expected, 3e-2, 3e-2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_gemma_gelu_matches_unfused_output() {
|
||||
let (expected, baseline_modes) = run_gemma_moe(false);
|
||||
let expected = run_gemma_moe(false);
|
||||
if expected.is_empty() {
|
||||
return;
|
||||
}
|
||||
assert!(baseline_modes.is_empty());
|
||||
|
||||
let (actual, fused_modes) = run_gemma_moe(true);
|
||||
assert_eq!(fused_modes, vec![GLUMoEMode::GemmaGELU]);
|
||||
let actual = run_gemma_moe(true);
|
||||
assert_close(&actual, &expected, 3e-2, 3e-2);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
use cudarc::driver::CudaContext;
|
||||
use luminal::{graph::Graph, op::Runtime};
|
||||
use luminal::{
|
||||
graph::{CompileOptions, Graph},
|
||||
op::Runtime,
|
||||
};
|
||||
|
||||
use crate::{kernel::apply_rope, runtime::CudaRuntime};
|
||||
|
||||
@@ -42,12 +45,12 @@ fn rope_matches_cpu_reference() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(x, x_data.clone());
|
||||
rt.set_data(cos, cos_data.clone());
|
||||
rt.set_data(sin, sin_data.clone());
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(y.id);
|
||||
|
||||
@@ -90,12 +93,12 @@ fn rope_flux2_shape() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(x, x_data.clone());
|
||||
rt.set_data(cos, cos_data.clone());
|
||||
rt.set_data(sin, sin_data.clone());
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(y.id);
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
//!
|
||||
//! These tests do not compare against a hand-written reference. They assert the
|
||||
//! stronger search invariant: every selectable LLIR graph from the same e-graph
|
||||
//! must produce the same outputs for the same runtime inputs.
|
||||
//! must produce finite, numerically close outputs for the same runtime inputs.
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[path = "../../../../examples/llama/src/model.rs"]
|
||||
@@ -92,8 +92,8 @@ fn llama_architecture_search_space_equivalence_fuzz() {
|
||||
.samples(SEARCH_EQUIV_SAMPLES)
|
||||
.generation_size(8)
|
||||
.mutations(3)
|
||||
.build_options(BuildSearchSpaceOptions::new().max_memory_mib(512))
|
||||
.output_f32(logits.id, "logits", 3e-3, 3e-3);
|
||||
.build_options(CompileOptions::default().max_memory_mib(512))
|
||||
.output_f32(logits.id, "logits", 5e-2, 5e-2);
|
||||
for (layer, (k_out, v_out)) in cache_outputs.into_iter().enumerate() {
|
||||
let k_out = k_out.output();
|
||||
let v_out = v_out.output();
|
||||
@@ -168,7 +168,7 @@ fn gemma_architecture_search_space_equivalence_fuzz() {
|
||||
.samples(SEARCH_EQUIV_SAMPLES)
|
||||
.generation_size(8)
|
||||
.mutations(3)
|
||||
.build_options(BuildSearchSpaceOptions::new().max_memory_mib(512))
|
||||
.build_options(CompileOptions::default().max_memory_mib(512))
|
||||
.input_f32(input.id, random_f32_vec(SEQ * HIDDEN, 101, -0.15, 0.15))
|
||||
.input_f32(attn_norm_w.id, random_f32_vec(HIDDEN, 102, 0.7, 1.3))
|
||||
.input_f32(post_attn_norm_w.id, random_f32_vec(HIDDEN, 103, 0.7, 1.3))
|
||||
@@ -263,7 +263,7 @@ fn moe_architecture_search_space_equivalence_fuzz() {
|
||||
.samples(SEARCH_EQUIV_SAMPLES)
|
||||
.generation_size(8)
|
||||
.mutations(3)
|
||||
.build_options(BuildSearchSpaceOptions::new().max_memory_mib(512))
|
||||
.build_options(CompileOptions::default().max_memory_mib(512))
|
||||
.input_f32(
|
||||
router_input.id,
|
||||
random_f32_vec(SEQ * HIDDEN, 201, -0.15, 0.15),
|
||||
@@ -353,7 +353,7 @@ fn moe_architecture_native_reference_fuzz() {
|
||||
.samples(SEARCH_EQUIV_SAMPLES)
|
||||
.generation_size(8)
|
||||
.mutations(3)
|
||||
.build_options(BuildSearchSpaceOptions::new().max_memory_mib(512))
|
||||
.build_options(CompileOptions::default().max_memory_mib(512))
|
||||
.native_reference()
|
||||
.input_f32(input.id, random_f32_vec(SEQ * HIDDEN, 301, -0.15, 0.15))
|
||||
.input_f32(
|
||||
|
||||
@@ -267,7 +267,7 @@ fn test_mini_transformer_layer() {
|
||||
let layer = MiniTransformerLayer::init(&mut cx);
|
||||
let out = layer.forward(input).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
let input_data = random_f32_vec(SEQ * HIDDEN, 42, -0.5, 0.5);
|
||||
@@ -280,7 +280,7 @@ fn test_mini_transformer_layer() {
|
||||
|
||||
// Use minimal search iterations to avoid excessive graph rewriting
|
||||
// which can cause float drift through softmax/RMSNorm reordering
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -303,7 +303,7 @@ fn test_mini_transformer_two_layers() {
|
||||
let x = layer1.forward(input);
|
||||
let out = layer2.forward(x).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
let input_data = random_f32_vec(SEQ * HIDDEN, 42, -0.5, 0.5);
|
||||
@@ -316,7 +316,7 @@ fn test_mini_transformer_two_layers() {
|
||||
rt.set_data(*tensor, data.clone());
|
||||
}
|
||||
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -361,7 +361,7 @@ fn test_transformer_multi_seed() {
|
||||
let layer = MiniTransformerLayer::init(&mut cx);
|
||||
let out = layer.forward(input).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let input_data = random_f32_vec(SEQ * HIDDEN, seed, -0.5, 0.5);
|
||||
@@ -372,7 +372,7 @@ fn test_transformer_multi_seed() {
|
||||
rt.set_data(*tensor, data.clone());
|
||||
}
|
||||
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -394,7 +394,7 @@ fn test_rms_norm_cuda() {
|
||||
let weight = cx.tensor(HIDDEN);
|
||||
let out = rms_norm(input, weight, 1e-5).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
let input_data = random_f32_vec(SEQ * HIDDEN, 1, -0.5, 0.5);
|
||||
@@ -404,7 +404,7 @@ fn test_rms_norm_cuda() {
|
||||
.collect();
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(weight, weight_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -433,7 +433,7 @@ fn test_self_attention_cuda() {
|
||||
let wo = cx.tensor((HIDDEN, HIDDEN));
|
||||
let out = self_attention(input, wq, wk, wv, wo).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
let input_data = random_f32_vec(SEQ * HIDDEN, 10, -0.5, 0.5);
|
||||
@@ -447,7 +447,7 @@ fn test_self_attention_cuda() {
|
||||
rt.set_data(wk, wk_data.clone());
|
||||
rt.set_data(wv, wv_data.clone());
|
||||
rt.set_data(wo, wo_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -479,7 +479,7 @@ fn test_swiglu_mlp_cuda() {
|
||||
let w_down = cx.tensor((HIDDEN, INTERMEDIATE));
|
||||
let out = swiglu_mlp(input, w_gate, w_up, w_down).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
let input_data = random_f32_vec(SEQ * HIDDEN, 20, -0.5, 0.5);
|
||||
@@ -491,7 +491,7 @@ fn test_swiglu_mlp_cuda() {
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -526,11 +526,11 @@ fn test_rolled_chained_scalar_muls() {
|
||||
let chained = ((x * 2.0_f32) * 3.0_f32) * 5.0_f32;
|
||||
let out = (chained + x).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let x_data = random_f32_vec(4 * 32, 101, -0.5, 0.5);
|
||||
rt.set_data(x, x_data.clone());
|
||||
rt = cx.search(rt, 3);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
use candle_core::{Device, Tensor, WithDType};
|
||||
use cudarc::driver::CudaContext;
|
||||
use half::{bf16, f16};
|
||||
use itertools::Itertools;
|
||||
use luminal::egglog_utils::{
|
||||
EGraphChoiceSet, egglog_to_llir, extract_generation, hash_choice_set, random_initial_choice,
|
||||
validate_choice_set,
|
||||
};
|
||||
use luminal::prelude::*;
|
||||
use luminal::prelude::{
|
||||
petgraph::{Direction, algo::toposort, visit::EdgeRef},
|
||||
*,
|
||||
};
|
||||
use num_traits::{Num, Signed};
|
||||
use rand::{Rng, SeedableRng, rngs::StdRng};
|
||||
use std::sync::Arc;
|
||||
@@ -180,7 +184,7 @@ pub struct SearchEquivalenceFuzzConfig {
|
||||
pub generation_size: usize,
|
||||
pub mutations: usize,
|
||||
pub max_attempts: usize,
|
||||
pub build_options: BuildSearchSpaceOptions,
|
||||
pub build_options: CompileOptions,
|
||||
pub reference: SearchEquivalenceReference,
|
||||
}
|
||||
|
||||
@@ -198,7 +202,7 @@ impl Default for SearchEquivalenceFuzzConfig {
|
||||
generation_size: 16,
|
||||
mutations: 2,
|
||||
max_attempts: 1_000,
|
||||
build_options: BuildSearchSpaceOptions::default(),
|
||||
build_options: CompileOptions::default(),
|
||||
reference: SearchEquivalenceReference::FirstCudaExtraction,
|
||||
}
|
||||
}
|
||||
@@ -210,6 +214,11 @@ pub struct SearchEquivalenceFuzzReport {
|
||||
pub skipped_invalid: usize,
|
||||
}
|
||||
|
||||
struct ChoiceRun {
|
||||
outputs: Vec<Vec<f32>>,
|
||||
llir_summary: String,
|
||||
}
|
||||
|
||||
pub struct CudaSearchEquivalenceFuzzer<'a> {
|
||||
cx: &'a mut Graph,
|
||||
stream: &'a Arc<cudarc::driver::CudaStream>,
|
||||
@@ -249,7 +258,7 @@ impl<'a> CudaSearchEquivalenceFuzzer<'a> {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build_options(mut self, build_options: BuildSearchSpaceOptions) -> Self {
|
||||
pub fn build_options(mut self, build_options: CompileOptions) -> Self {
|
||||
self.config.build_options = build_options;
|
||||
self
|
||||
}
|
||||
@@ -302,7 +311,8 @@ impl<'a> CudaSearchEquivalenceFuzzer<'a> {
|
||||
/// LLIR graphs, runs each with identical inputs, and verifies every requested
|
||||
/// f32 output matches the first valid extraction. The reference is intentionally
|
||||
/// another selected LLIR graph, not a hand-written CPU implementation: this
|
||||
/// catches cases where supposedly equivalent e-graph choices diverge.
|
||||
/// catches cases where supposedly equivalent e-graph choices diverge, including
|
||||
/// candidates that produce non-finite outputs.
|
||||
pub fn fuzz_cuda_search_space_equivalence(
|
||||
cx: &mut Graph,
|
||||
stream: &Arc<cudarc::driver::CudaStream>,
|
||||
@@ -317,11 +327,11 @@ pub fn fuzz_cuda_search_space_equivalence(
|
||||
|
||||
let native_reference_outputs = if config.reference == SearchEquivalenceReference::NativeRuntime
|
||||
{
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut native_rng = StdRng::seed_from_u64(config.seed);
|
||||
let mut native_rt = cx.search_options(
|
||||
let mut native_rt = cx.search_with_rng(
|
||||
NativeRuntime::default(),
|
||||
SearchOptions::new(1),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
&mut native_rng,
|
||||
);
|
||||
for input in inputs {
|
||||
@@ -338,7 +348,7 @@ pub fn fuzz_cuda_search_space_equivalence(
|
||||
None
|
||||
};
|
||||
|
||||
cx.build_search_space_with_options::<CudaRuntime>(config.build_options);
|
||||
cx.build_search_space::<CudaRuntime>(config.build_options);
|
||||
|
||||
let egraph = cx.egraph().expect("search space should be built");
|
||||
let ops = cx.egglog_ops().expect("search ops should be built");
|
||||
@@ -354,12 +364,12 @@ pub fn fuzz_cuda_search_space_equivalence(
|
||||
|
||||
let mut skipped_invalid = 0usize;
|
||||
let reference_is_cuda = native_reference_outputs.is_none();
|
||||
let (reference_hash, reference_outputs, mut tested) =
|
||||
let (reference_hash, reference_outputs, reference_llir_summary, mut tested) =
|
||||
if let Some(reference_outputs) = native_reference_outputs {
|
||||
(0, reference_outputs, 0usize)
|
||||
(0, reference_outputs, None, 0usize)
|
||||
} else {
|
||||
let mut attempts = 0usize;
|
||||
let (reference_hash, reference_outputs) = loop {
|
||||
let (reference_hash, reference_run) = loop {
|
||||
attempts += 1;
|
||||
if attempts > config.max_attempts {
|
||||
panic!(
|
||||
@@ -372,17 +382,19 @@ pub fn fuzz_cuda_search_space_equivalence(
|
||||
} else {
|
||||
let hash = hash_choice_set(&base);
|
||||
match run_choice_outputs(cx, stream, inputs, outputs, &base) {
|
||||
Ok(values) => break (hash, values),
|
||||
Err(err) => {
|
||||
skipped_invalid += 1;
|
||||
eprintln!("skipping invalid reference candidate hash={hash}: {err}");
|
||||
}
|
||||
Ok(run) => break (hash, run),
|
||||
Err(err) => panic!("reference candidate hash={hash} failed: {err}"),
|
||||
}
|
||||
}
|
||||
base = random_initial_choice(egraph, &mut rng);
|
||||
prev_selected.insert(hash_choice_set(&base));
|
||||
};
|
||||
(reference_hash, reference_outputs, 1usize)
|
||||
(
|
||||
reference_hash,
|
||||
reference_run.outputs,
|
||||
Some(reference_run.llir_summary),
|
||||
1usize,
|
||||
)
|
||||
};
|
||||
|
||||
let mut attempts = 0usize;
|
||||
@@ -415,12 +427,14 @@ pub fn fuzz_cuda_search_space_equivalence(
|
||||
continue;
|
||||
}
|
||||
|
||||
let candidate_outputs = run_choice_outputs(cx, stream, inputs, outputs, &candidate)
|
||||
let candidate_run = run_choice_outputs(cx, stream, inputs, outputs, &candidate)
|
||||
.unwrap_or_else(|err| panic!("candidate hash={candidate_hash} failed: {err}"));
|
||||
assert_fuzz_outputs_close(
|
||||
outputs,
|
||||
&reference_outputs,
|
||||
&candidate_outputs,
|
||||
&candidate_run.outputs,
|
||||
&candidate_run.llir_summary,
|
||||
reference_llir_summary.as_deref(),
|
||||
reference_hash,
|
||||
candidate_hash,
|
||||
);
|
||||
@@ -446,7 +460,7 @@ fn run_choice_outputs<'a>(
|
||||
inputs: &[CudaFuzzInput],
|
||||
outputs: &[F32OutputCheck],
|
||||
choices: &EGraphChoiceSet<'a>,
|
||||
) -> Result<Vec<Vec<f32>>, String> {
|
||||
) -> Result<ChoiceRun, String> {
|
||||
let egraph = cx.egraph().ok_or("search space was not built")?;
|
||||
let ops = cx.egglog_ops().ok_or("search ops were not built")?;
|
||||
let mut list_cache = FxHashMap::default();
|
||||
@@ -461,21 +475,86 @@ fn run_choice_outputs<'a>(
|
||||
None,
|
||||
);
|
||||
unroll_loops_in_llir(&mut llir_graph);
|
||||
let llir_summary = summarize_llir(&llir_graph);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.load_llir(&llir_graph);
|
||||
rt.preserve_intermediate_buffers_for_debug();
|
||||
for input in inputs {
|
||||
input.apply(&mut rt);
|
||||
}
|
||||
if std::env::var_os("LUMINAL_FUZZ_DUMP_LAST_LLIR").is_some() {
|
||||
let _ = std::fs::write("/tmp/luminal_fuzz_last_candidate_llir.txt", &llir_summary);
|
||||
}
|
||||
rt.execute(&cx.dyn_map);
|
||||
let topo_order = toposort(&llir_graph, None).map_err(|cycle| {
|
||||
format!(
|
||||
"extracted LLIR contains cycle at node {:?}",
|
||||
cycle.node_id()
|
||||
)
|
||||
})?;
|
||||
if let Some(report) = rt.first_nonfinite_f32_buffer_in_nodes(topo_order) {
|
||||
let dump_path = "/tmp/luminal_fuzz_bad_candidate_llir.txt";
|
||||
let _ = std::fs::write(dump_path, &llir_summary);
|
||||
let op = llir_graph
|
||||
.node_weight(report.node)
|
||||
.map(|op| format!("{op:?}"))
|
||||
.unwrap_or_else(|| "unknown op".to_string());
|
||||
return Err(format!(
|
||||
"LLIR produced non-finite F32 buffer node={} index={} value={} op={}; llir={dump_path}",
|
||||
report.node.index(),
|
||||
report.index,
|
||||
report.value,
|
||||
op
|
||||
));
|
||||
}
|
||||
|
||||
Ok(outputs.iter().map(|out| rt.get_f32(out.id)).collect())
|
||||
let values = outputs
|
||||
.iter()
|
||||
.map(|out| rt.get_f32(out.id))
|
||||
.collect::<Vec<_>>();
|
||||
for (spec, values) in outputs.iter().zip(&values) {
|
||||
if let Some((idx, value)) = values
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, value)| !value.is_finite())
|
||||
{
|
||||
let dump_path = "/tmp/luminal_fuzz_bad_candidate_llir.txt";
|
||||
let _ = std::fs::write(dump_path, &llir_summary);
|
||||
let internal = rt
|
||||
.first_nonfinite_f32_buffer()
|
||||
.map(|report| {
|
||||
let op = llir_graph
|
||||
.node_weight(report.node)
|
||||
.map(|op| format!("{op:?}"))
|
||||
.unwrap_or_else(|| "unknown op".to_string());
|
||||
format!(
|
||||
"; first observed non-finite buffer node={} index={} value={} op={}",
|
||||
report.node.index(),
|
||||
report.index,
|
||||
report.value,
|
||||
op
|
||||
)
|
||||
})
|
||||
.unwrap_or_default();
|
||||
return Err(format!(
|
||||
"output {} produced non-finite value {value} at index {idx}{internal}; llir={dump_path}",
|
||||
spec.name
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(ChoiceRun {
|
||||
outputs: values,
|
||||
llir_summary,
|
||||
})
|
||||
}
|
||||
|
||||
fn assert_fuzz_outputs_close(
|
||||
outputs: &[F32OutputCheck],
|
||||
expected: &[Vec<f32>],
|
||||
actual: &[Vec<f32>],
|
||||
candidate_llir_summary: &str,
|
||||
reference_llir_summary: Option<&str>,
|
||||
reference_hash: u64,
|
||||
candidate_hash: u64,
|
||||
) {
|
||||
@@ -508,8 +587,16 @@ fn assert_fuzz_outputs_close(
|
||||
worst = i;
|
||||
}
|
||||
if abs > spec.atol + spec.rtol * b.abs() {
|
||||
let dump_path = "/tmp/luminal_fuzz_bad_candidate_llir.txt";
|
||||
let _ = std::fs::write(dump_path, candidate_llir_summary);
|
||||
if let Some(reference_llir_summary) = reference_llir_summary {
|
||||
let _ = std::fs::write(
|
||||
"/tmp/luminal_fuzz_bad_reference_llir.txt",
|
||||
reference_llir_summary,
|
||||
);
|
||||
}
|
||||
panic!(
|
||||
"output {} mismatch candidate hash={candidate_hash} reference hash={reference_hash} index={i} actual={a} expected={b} abs={abs} rel={rel} tolerance={}",
|
||||
"output {} mismatch candidate hash={candidate_hash} reference hash={reference_hash} index={i} actual={a} expected={b} abs={abs} rel={rel} tolerance={} candidate_llir={dump_path}",
|
||||
spec.name,
|
||||
spec.atol + spec.rtol * b.abs()
|
||||
);
|
||||
@@ -522,6 +609,22 @@ fn assert_fuzz_outputs_close(
|
||||
}
|
||||
}
|
||||
|
||||
fn summarize_llir(llir_graph: &LLIRGraph) -> String {
|
||||
llir_graph
|
||||
.node_indices()
|
||||
.map(|idx| {
|
||||
let inputs = llir_graph
|
||||
.edges_directed(idx, Direction::Incoming)
|
||||
.sorted_by_key(|edge| edge.id())
|
||||
.map(|edge| edge.source().index().to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
format!("{} <- [{}]: {:?}", idx.index(), inputs, &llir_graph[idx])
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
/// Get the GPU compute capability as (major, minor).
|
||||
pub fn gpu_compute_cap() -> Option<(i32, i32)> {
|
||||
let ctx = CudaContext::new(0).ok()?;
|
||||
@@ -593,12 +696,12 @@ pub fn test_unary_cuda<T: TestDType>(
|
||||
let a = cx.tensor(shape.clone());
|
||||
let b = func(a).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let input_data = generator(n_elements, seed);
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = T::get_from_runtime(&rt, b.id);
|
||||
@@ -666,14 +769,14 @@ pub fn test_binary_cuda<T: TestDType>(
|
||||
let b = cx.tensor(b_shape.clone());
|
||||
let c = func(a, b).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let a_data = a_generator(a_elements, seed);
|
||||
let b_data = b_generator(b_elements, seed.wrapping_add(1));
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b, b_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = T::get_from_runtime(&rt, c.id);
|
||||
@@ -733,7 +836,7 @@ pub fn test_mod(
|
||||
let b = cx.tensor(b_shape.clone());
|
||||
let c = func(a, b).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let a_data = random_f32_vec(a_elements, seed, -0.5, 0.5);
|
||||
@@ -741,7 +844,7 @@ pub fn test_mod(
|
||||
let b_data = random_f32_vec(b_elements, seed.wrapping_add(1), 0.1, 0.5);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b, b_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(c);
|
||||
|
||||
@@ -19,7 +19,14 @@ bytemuck = "1.24.0"
|
||||
|
||||
[dev-dependencies]
|
||||
candle-core = "0.9.2-alpha.1"
|
||||
hf-hub = { version = "0.4", default-features = false, features = ["rustls-tls", "ureq"] }
|
||||
luminal_nn = { path = "../luminal_nn" }
|
||||
luminal_tracing = { path = "../luminal_tracing" }
|
||||
proptest = "1.9.0"
|
||||
rand = "0.9.2"
|
||||
rustc-hash = "2.1"
|
||||
tokenizers = "0.22.2"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
||||
[lints.rust]
|
||||
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(feature, values("cargo-clippy"))'] }
|
||||
|
||||
642
crates/luminal_metal/examples/llama_1b.rs
Normal file
642
crates/luminal_metal/examples/llama_1b.rs
Normal file
@@ -0,0 +1,642 @@
|
||||
use hf_hub::api::sync::Api;
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
graph::{CompileOptions, DimBucket, Graph},
|
||||
prelude::{F32Pow, GraphTensor, Runtime},
|
||||
};
|
||||
use luminal_metal::MetalRuntime;
|
||||
use luminal_nn::{LayerNorm, gather_rows, scatter_rows};
|
||||
use luminal_tracing::luminal_filter;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{
|
||||
error::Error,
|
||||
io::Write,
|
||||
path::PathBuf,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
const REPO_ID: &str = "unsloth/Llama-3.2-1B-Instruct";
|
||||
const MAX_SEQ_LEN: usize = 2048;
|
||||
const GEN_TOKENS: usize = 96;
|
||||
const SEARCH_GRAPHS: usize = 100;
|
||||
const SEARCH_MEMORY_MIB: usize = 1536;
|
||||
const PROMPT: &str = "In one short paragraph, explain neural networks using the words layers, neurons, learning, and data.";
|
||||
|
||||
const LAYERS: usize = 16;
|
||||
const HIDDEN: usize = 2048;
|
||||
const INTERMEDIATE: usize = 8192;
|
||||
const HEAD_DIM: usize = 64;
|
||||
const N_HEADS: usize = 32;
|
||||
const N_KV_HEADS: usize = 8;
|
||||
const KV_GROUPS: usize = N_HEADS / N_KV_HEADS;
|
||||
const KV_DIM: usize = N_KV_HEADS * HEAD_DIM;
|
||||
const VOCAB_SIZE: usize = 128256;
|
||||
const RMS_NORM_EPS: f32 = 1e-5;
|
||||
const ROPE_THETA: f32 = 500_000.0;
|
||||
const EOS_TOKEN: u32 = 128009;
|
||||
const STOP_TOKEN: u32 = 128001;
|
||||
|
||||
fn prepare_hf_model() -> Result<PathBuf, Box<dyn Error>> {
|
||||
let repo = Api::new()?.model(REPO_ID.to_string());
|
||||
let tokenizer_path = repo.get("tokenizer.json")?;
|
||||
repo.get("model.safetensors")?;
|
||||
Ok(tokenizer_path.parent().unwrap().to_path_buf())
|
||||
}
|
||||
|
||||
fn llama3_chat_prompt(user_prompt: &str) -> String {
|
||||
format!(
|
||||
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
struct StepProfile {
|
||||
total: Duration,
|
||||
execute: Duration,
|
||||
get_logits: Duration,
|
||||
cache_roundtrip: Duration,
|
||||
}
|
||||
|
||||
fn avg_ms(duration: Duration, n: usize) -> f64 {
|
||||
if n == 0 {
|
||||
0.0
|
||||
} else {
|
||||
duration.as_secs_f64() * 1e3 / n as f64
|
||||
}
|
||||
}
|
||||
|
||||
fn sample_greedy(logits_row: &[f32], seen: &FxHashSet<u32>, repetition_penalty: f32) -> u32 {
|
||||
let mut row = logits_row.to_vec();
|
||||
for &tok in seen {
|
||||
let logit = &mut row[tok as usize];
|
||||
if *logit > 0.0 {
|
||||
*logit /= repetition_penalty;
|
||||
} else {
|
||||
*logit *= repetition_penalty;
|
||||
}
|
||||
}
|
||||
row.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32
|
||||
}
|
||||
|
||||
fn causal_mask(q_pos: &[usize], context_len: usize) -> Vec<f32> {
|
||||
let mut mask = vec![-1e10f32; q_pos.len() * context_len];
|
||||
for (qi, &pos) in q_pos.iter().enumerate() {
|
||||
for ci in 0..context_len {
|
||||
if ci <= pos {
|
||||
mask[qi * context_len + ci] = 0.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
mask
|
||||
}
|
||||
|
||||
struct KVCache {
|
||||
k_caches: Vec<GraphTensor>,
|
||||
v_caches: Vec<GraphTensor>,
|
||||
}
|
||||
|
||||
impl KVCache {
|
||||
fn new(cx: &mut Graph, num_slots: usize) -> Self {
|
||||
let mut k_caches = Vec::with_capacity(LAYERS);
|
||||
let mut v_caches = Vec::with_capacity(LAYERS);
|
||||
for l in 0..LAYERS {
|
||||
k_caches.push(
|
||||
cx.named_tensor(format!("kv_cache.{l}.k"), (num_slots, KV_DIM))
|
||||
.persist(),
|
||||
);
|
||||
v_caches.push(
|
||||
cx.named_tensor(format!("kv_cache.{l}.v"), (num_slots, KV_DIM))
|
||||
.persist(),
|
||||
);
|
||||
}
|
||||
Self { k_caches, v_caches }
|
||||
}
|
||||
}
|
||||
|
||||
struct Llama {
|
||||
embedding: GraphTensor,
|
||||
layers: Vec<LlamaLayer>,
|
||||
lm_norm: LayerNorm,
|
||||
}
|
||||
|
||||
impl Llama {
|
||||
fn init(cx: &mut Graph) -> Self {
|
||||
let mut layers = Vec::with_capacity(LAYERS);
|
||||
for l in 0..LAYERS {
|
||||
layers.push(LlamaLayer {
|
||||
up: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.up_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
gate: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.gate_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
down: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.down_proj.weight"),
|
||||
(HIDDEN, INTERMEDIATE),
|
||||
)
|
||||
.persist(),
|
||||
q_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_proj.weight"),
|
||||
(HIDDEN, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
k_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
v_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.v_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
o_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
attn_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.input_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
),
|
||||
mlp_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
Self {
|
||||
embedding: cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist(),
|
||||
layers,
|
||||
lm_norm: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some("model.norm.weight"),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
input: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = input.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(input * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ input.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
let mut cache_outputs = Vec::with_capacity(LAYERS);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
x,
|
||||
q_pos,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
kv_cache.k_caches[i],
|
||||
kv_cache.v_caches[i],
|
||||
);
|
||||
x = x_new;
|
||||
cache_outputs.push((k_out, v_out));
|
||||
}
|
||||
|
||||
let logits = self.lm_norm.forward(x).matmul(self.embedding.t());
|
||||
(logits, cache_outputs)
|
||||
}
|
||||
}
|
||||
|
||||
struct LlamaLayer {
|
||||
up: GraphTensor,
|
||||
gate: GraphTensor,
|
||||
down: GraphTensor,
|
||||
q_proj: GraphTensor,
|
||||
k_proj: GraphTensor,
|
||||
v_proj: GraphTensor,
|
||||
o_proj: GraphTensor,
|
||||
attn_rms: LayerNorm,
|
||||
mlp_rms: LayerNorm,
|
||||
}
|
||||
|
||||
fn llama_rotary_embeddings(mut input: GraphTensor, pos_ids: GraphTensor) -> GraphTensor {
|
||||
input = input.split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
|
||||
let freqs = input
|
||||
.graph()
|
||||
.arange_options(0, HEAD_DIM, 2)
|
||||
.cast(DType::F32)
|
||||
/ HEAD_DIM as f32;
|
||||
let inv_freqs = ROPE_THETA.pow(freqs).reciprocal();
|
||||
let emb = pos_ids
|
||||
.cast(DType::F32)
|
||||
.expand_dim(1, 1)
|
||||
.matmul(inv_freqs.expand_dim(0, 1));
|
||||
|
||||
let x0 = input.slice((.., .., ..HEAD_DIM / 2));
|
||||
let x1 = input.slice((.., .., HEAD_DIM / 2..));
|
||||
|
||||
let cos = emb.cos().expand_dim(0, x0.dims()[0]);
|
||||
let sin = emb.sin().expand_dim(0, x0.dims()[0]);
|
||||
let x0_out = x0 * cos - x1 * sin;
|
||||
let x1_out = x1 * cos + x0 * sin;
|
||||
|
||||
x0_out
|
||||
.concat_along(x1_out, 2)
|
||||
.transpose(0, 1)
|
||||
.merge_dims(1, 2)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn attention(
|
||||
q_rope: GraphTensor,
|
||||
k_rope: GraphTensor,
|
||||
v: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let k_cache_out = scatter_rows(k_rope, scatter_idx, k_cache, KV_DIM);
|
||||
let v_cache_out = scatter_rows(v, scatter_idx, v_cache, KV_DIM);
|
||||
|
||||
let k = gather_rows(k_cache_out, gather_idx, KV_DIM);
|
||||
let v_ctx = gather_rows(v_cache_out, gather_idx, KV_DIM);
|
||||
|
||||
let q = (q_rope * 1.0).split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
let k = k.split_dims(1, HEAD_DIM).permute((1, 2, 0));
|
||||
let v_ctx = v_ctx.split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
|
||||
let k = k.expand_dim(1, KV_GROUPS).merge_dims(0, 1) * 1.0;
|
||||
let v_ctx = v_ctx.expand_dim(1, KV_GROUPS).merge_dims(0, 1) * 1.0;
|
||||
|
||||
let scores = q.matmul(k) / (HEAD_DIM as f32).sqrt();
|
||||
let masked_scores = scores + attn_mask.expand_dim(0, N_HEADS);
|
||||
let weights = masked_scores.softmax(2);
|
||||
let out = weights.matmul(v_ctx);
|
||||
let attn_out = out.transpose(0, 1).merge_dims(1, 2);
|
||||
|
||||
(attn_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl LlamaLayer {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.attn_rms.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
|
||||
let q_rope = llama_rotary_embeddings(q, q_pos);
|
||||
let k_rope = llama_rotary_embeddings(k, q_pos);
|
||||
let (attn_out, k_cache_out, v_cache_out) = attention(
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
);
|
||||
x += attn_out.matmul(self.o_proj.t());
|
||||
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
let mlp_out =
|
||||
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
|
||||
(x + mlp_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_model_step(
|
||||
cx: &mut Graph,
|
||||
runtime: &mut MetalRuntime,
|
||||
input: GraphTensor,
|
||||
q_pos_t: GraphTensor,
|
||||
scatter_idx_t: GraphTensor,
|
||||
gather_idx_t: GraphTensor,
|
||||
attn_mask_t: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
tokens: &[u32],
|
||||
q_pos: &[i32],
|
||||
scatter_idx: &[i32],
|
||||
gather_idx: &[i32],
|
||||
attn_mask: &[f32],
|
||||
) -> (Vec<f32>, StepProfile) {
|
||||
let start = Instant::now();
|
||||
cx.set_dim('s', tokens.len());
|
||||
cx.set_dim('c', gather_idx.len());
|
||||
|
||||
runtime.set_data(input, tokens.iter().map(|t| *t as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(q_pos_t, q_pos.to_vec());
|
||||
runtime.set_data(scatter_idx_t, scatter_idx.to_vec());
|
||||
runtime.set_data(gather_idx_t, gather_idx.to_vec());
|
||||
runtime.set_data(attn_mask_t, attn_mask.to_vec());
|
||||
runtime.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
|
||||
let execute_start = Instant::now();
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let execute = execute_start.elapsed();
|
||||
|
||||
let logits_start = Instant::now();
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let get_logits = logits_start.elapsed();
|
||||
|
||||
let cache_start = Instant::now();
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
let cache_roundtrip = cache_start.elapsed();
|
||||
|
||||
(
|
||||
logits_data,
|
||||
StepProfile {
|
||||
total: start.elapsed(),
|
||||
execute,
|
||||
get_logits,
|
||||
cache_roundtrip,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn Error>> {
|
||||
let _ = tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.try_init();
|
||||
|
||||
let model_dir = prepare_hf_model()?;
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json"))
|
||||
.map_err(|err| err as Box<dyn Error>)?;
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(llama3_chat_prompt(PROMPT), false)
|
||||
.map_err(|err| err as Box<dyn Error>)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
let q_pos_t = cx.named_tensor("q_pos", 's').as_dtype(DType::Int);
|
||||
let scatter_idx_t = cx.named_tensor("scatter_idx", 's').as_dtype(DType::Int);
|
||||
let gather_idx_t = cx.named_tensor("gather_idx", 'c').as_dtype(DType::Int);
|
||||
let attn_mask_t = cx.named_tensor("attn_mask", ('s', 'c'));
|
||||
let kv_cache = KVCache::new(&mut cx, MAX_SEQ_LEN);
|
||||
let (logits, cache_outputs) = Llama::init(&mut cx).forward(
|
||||
input,
|
||||
q_pos_t,
|
||||
scatter_idx_t,
|
||||
gather_idx_t,
|
||||
attn_mask_t,
|
||||
&kv_cache,
|
||||
);
|
||||
let logits = logits.output();
|
||||
for (k_out, v_out) in &cache_outputs {
|
||||
k_out.output();
|
||||
v_out.output();
|
||||
}
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('c', 1);
|
||||
let max_prefill = (prompt_tokens.len() + 16)
|
||||
.next_power_of_two()
|
||||
.min(MAX_SEQ_LEN);
|
||||
let max_context = (prompt_tokens.len() + GEN_TOKENS + 1)
|
||||
.next_power_of_two()
|
||||
.min(MAX_SEQ_LEN);
|
||||
let search_s = 16.min(max_prefill).max(2);
|
||||
let search_c = 16.min(max_context).max(2);
|
||||
let build_options = CompileOptions::default()
|
||||
.max_memory_mib(SEARCH_MEMORY_MIB)
|
||||
.dim_buckets(
|
||||
's',
|
||||
&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, max_prefill).representative(search_s),
|
||||
],
|
||||
)
|
||||
.dim_buckets(
|
||||
'c',
|
||||
&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, max_context).representative(search_c),
|
||||
],
|
||||
);
|
||||
|
||||
println!("Building E-Graph...");
|
||||
let egraph_start = Instant::now();
|
||||
cx.build_search_space::<MetalRuntime>(build_options);
|
||||
println!(
|
||||
" E-Graph build: {:.2} s",
|
||||
egraph_start.elapsed().as_secs_f64()
|
||||
);
|
||||
|
||||
println!("Loading weights...");
|
||||
let load_start = Instant::now();
|
||||
let mut runtime = MetalRuntime::initialize(());
|
||||
runtime.load_safetensors(&cx, model_dir.join("model.safetensors").to_str().unwrap());
|
||||
println!(" Weight load: {:.2} s", load_start.elapsed().as_secs_f64());
|
||||
|
||||
let cache_bytes = MAX_SEQ_LEN * KV_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
println!("Compiling...");
|
||||
let compile_start = Instant::now();
|
||||
cx.set_dim('s', search_s);
|
||||
cx.set_dim('c', search_c);
|
||||
runtime.set_data(input, vec![1; search_s]);
|
||||
runtime.set_data(q_pos_t, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(scatter_idx_t, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(gather_idx_t, (0..search_c as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(attn_mask_t, vec![0.0f32; search_s * search_c]);
|
||||
let search_options = CompileOptions::default().search_graph_limit(SEARCH_GRAPHS);
|
||||
runtime = cx.search(runtime, search_options);
|
||||
println!(
|
||||
" Search/compile: {:.2} s",
|
||||
compile_start.elapsed().as_secs_f64()
|
||||
);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
let prompt_len = prompt_tokens.len();
|
||||
let mut context_len = 0usize;
|
||||
let mut profiles = Vec::new();
|
||||
let mut seen_tokens = FxHashSet::default();
|
||||
let repetition_penalty = 1.05;
|
||||
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, GEN_TOKENS
|
||||
);
|
||||
|
||||
let mut generated = 0usize;
|
||||
let mut next_token = None;
|
||||
if GEN_TOKENS > 0 && prompt_len > 0 {
|
||||
let positions: Vec<usize> = (0..prompt_len).collect();
|
||||
let q_pos: Vec<i32> = positions.iter().map(|&p| p as i32).collect();
|
||||
let mask = causal_mask(&positions, prompt_len);
|
||||
let (logits_data, profile) = run_model_step(
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
q_pos_t,
|
||||
scatter_idx_t,
|
||||
gather_idx_t,
|
||||
attn_mask_t,
|
||||
logits,
|
||||
&kv_cache,
|
||||
&cache_outputs,
|
||||
&prompt_tokens,
|
||||
&q_pos,
|
||||
&q_pos,
|
||||
&q_pos,
|
||||
&mask,
|
||||
);
|
||||
context_len = prompt_len;
|
||||
|
||||
let token = sample_greedy(
|
||||
&logits_data[logits_data.len() - VOCAB_SIZE..],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
seen_tokens.insert(token);
|
||||
next_token = Some(token);
|
||||
generated = 1;
|
||||
profiles.push(profile);
|
||||
|
||||
if token != EOS_TOKEN && token != STOP_TOKEN {
|
||||
print!(
|
||||
"{}",
|
||||
tokenizer
|
||||
.decode(&[token], true)
|
||||
.map_err(|err| err as Box<dyn Error>)?
|
||||
);
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
|
||||
while generated < GEN_TOKENS {
|
||||
let current_token = match next_token {
|
||||
Some(token) if token != EOS_TOKEN && token != STOP_TOKEN => token,
|
||||
_ => break,
|
||||
};
|
||||
let gather_idx = (0..=context_len as i32).collect::<Vec<_>>();
|
||||
let mask = causal_mask(&[context_len], context_len + 1);
|
||||
let (logits_data, profile) = run_model_step(
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
q_pos_t,
|
||||
scatter_idx_t,
|
||||
gather_idx_t,
|
||||
attn_mask_t,
|
||||
logits,
|
||||
&kv_cache,
|
||||
&cache_outputs,
|
||||
&[current_token],
|
||||
&[context_len as i32],
|
||||
&[context_len as i32],
|
||||
&gather_idx,
|
||||
&mask,
|
||||
);
|
||||
context_len += 1;
|
||||
|
||||
let token = sample_greedy(
|
||||
&logits_data[logits_data.len() - VOCAB_SIZE..],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
seen_tokens.insert(token);
|
||||
next_token = Some(token);
|
||||
generated += 1;
|
||||
profiles.push(profile);
|
||||
|
||||
if token == EOS_TOKEN || token == STOP_TOKEN {
|
||||
break;
|
||||
}
|
||||
print!(
|
||||
"{}",
|
||||
tokenizer
|
||||
.decode(&[token], true)
|
||||
.map_err(|err| err as Box<dyn Error>)?
|
||||
);
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
println!();
|
||||
|
||||
let ttft = profiles.first().map(|p| p.total).unwrap_or_default();
|
||||
let decode_steps = profiles.len().saturating_sub(1);
|
||||
let decode_total: Duration = profiles.iter().skip(1).map(|p| p.total).sum();
|
||||
println!(" TTFT: {:.2} ms", ttft.as_secs_f64() * 1e3);
|
||||
println!(" TPOT: {:.2} ms", avg_ms(decode_total, decode_steps));
|
||||
|
||||
let execute_total: Duration = profiles.iter().map(|p| p.execute).sum();
|
||||
let logits_total: Duration = profiles.iter().map(|p| p.get_logits).sum();
|
||||
let cache_total: Duration = profiles.iter().map(|p| p.cache_roundtrip).sum();
|
||||
println!(
|
||||
" Profile: n={}, exec={:.2} ms, logits={:.2} ms, cache={:.2} ms",
|
||||
profiles.len(),
|
||||
avg_ms(execute_total, profiles.len()),
|
||||
avg_ms(logits_total, profiles.len()),
|
||||
avg_ms(cache_total, profiles.len()),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -31,10 +31,42 @@ impl DynBackend for MetalDynBackend {
|
||||
}
|
||||
}
|
||||
|
||||
/// Reject dtypes the Metal kernel emitters don't support.
|
||||
///
|
||||
/// Metal codegen has no native 64-bit integer or 64-bit float paths.
|
||||
/// Reaching the kernel emitter with one of these dtypes used to panic deep
|
||||
/// in MSL generation with an unhelpful error; surfacing a clean message
|
||||
/// at translate-time lets the user fall back to CPU or pick a narrower
|
||||
/// dtype before any Metal compilation runs.
|
||||
fn reject_unsupported_dtype(graph: &Graph) -> Result<(), String> {
|
||||
for node_id in graph.graph.node_indices() {
|
||||
if let Some(input) = (*graph.graph[node_id])
|
||||
.as_any()
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
{
|
||||
match input.dtype {
|
||||
DType::I64 | DType::F64 => {
|
||||
return Err(format!(
|
||||
"Metal backend does not support {:?} (input `{}`). \
|
||||
Metal codegen has no native 64-bit kernels; either \
|
||||
narrow the dtype (e.g. `.to(torch.int32)` / \
|
||||
`.to(torch.float32)`) before the boundary or \
|
||||
compile with the CPU / CUDA backend.",
|
||||
input.dtype, input.label
|
||||
));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn metal_factory(
|
||||
graph: &mut Graph,
|
||||
args: BackendCompileArgs,
|
||||
) -> Result<Box<dyn DynBackend>, String> {
|
||||
reject_unsupported_dtype(graph)?;
|
||||
compile_backend::<MetalRuntime>(
|
||||
graph,
|
||||
args,
|
||||
|
||||
@@ -6,10 +6,127 @@ pub use ops::*;
|
||||
use luminal::dtype::DType;
|
||||
use luminal::op::EgglogOp;
|
||||
use luminal::prelude::*;
|
||||
use metal::{Buffer, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineState, Device};
|
||||
use metal::{
|
||||
Buffer, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineState, Device,
|
||||
foreign_types::ForeignTypeRef, mps,
|
||||
};
|
||||
use objc::rc::StrongPtr;
|
||||
use objc::runtime::Object;
|
||||
use objc::{class, msg_send, sel, sel_impl};
|
||||
use std::cell::RefCell;
|
||||
|
||||
pub const DYN_SLOT_COUNT: usize = 26;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
struct MpsMatrixDescriptorKey {
|
||||
rows: usize,
|
||||
cols: usize,
|
||||
row_bytes: u64,
|
||||
data_type: isize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
struct MpsMatmulKey {
|
||||
transpose_lhs: bool,
|
||||
transpose_rhs: bool,
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: u64,
|
||||
beta: u64,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct MpsKernelCache {
|
||||
matrix_descriptors: FxHashMap<MpsMatrixDescriptorKey, StrongPtr>,
|
||||
matmul_kernels: FxHashMap<MpsMatmulKey, StrongPtr>,
|
||||
}
|
||||
|
||||
impl MpsKernelCache {
|
||||
pub(crate) fn matrix_descriptor(
|
||||
&mut self,
|
||||
rows: usize,
|
||||
cols: usize,
|
||||
row_bytes: u64,
|
||||
dtype: DType,
|
||||
) -> *mut Object {
|
||||
let key = MpsMatrixDescriptorKey {
|
||||
rows,
|
||||
cols,
|
||||
row_bytes,
|
||||
data_type: Self::mps_data_type(dtype),
|
||||
};
|
||||
let descriptor = self
|
||||
.matrix_descriptors
|
||||
.entry(key)
|
||||
.or_insert_with(|| unsafe {
|
||||
let descriptor: *mut Object = msg_send![
|
||||
class!(MPSMatrixDescriptor),
|
||||
matrixDescriptorWithRows: rows
|
||||
columns: cols
|
||||
rowBytes: row_bytes as usize
|
||||
dataType: key.data_type
|
||||
];
|
||||
StrongPtr::retain(descriptor)
|
||||
});
|
||||
**descriptor
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn matrix_multiplication(
|
||||
&mut self,
|
||||
command_buffer: &CommandBufferRef,
|
||||
transpose_lhs: bool,
|
||||
transpose_rhs: bool,
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: f64,
|
||||
beta: f64,
|
||||
) -> *mut Object {
|
||||
let key = MpsMatmulKey {
|
||||
transpose_lhs,
|
||||
transpose_rhs,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha: alpha.to_bits(),
|
||||
beta: beta.to_bits(),
|
||||
};
|
||||
let kernel = self.matmul_kernels.entry(key).or_insert_with(|| unsafe {
|
||||
let device: *mut Object = msg_send![command_buffer.as_ptr(), device];
|
||||
let kernel: *mut Object = msg_send![class!(MPSMatrixMultiplication), alloc];
|
||||
let kernel: *mut Object = msg_send![
|
||||
kernel,
|
||||
initWithDevice: device
|
||||
transposeLeft: transpose_lhs
|
||||
transposeRight: transpose_rhs
|
||||
resultRows: m
|
||||
resultColumns: n
|
||||
interiorColumns: k
|
||||
alpha: alpha
|
||||
beta: beta
|
||||
];
|
||||
StrongPtr::new(kernel)
|
||||
});
|
||||
**kernel
|
||||
}
|
||||
|
||||
fn mps_data_type(dtype: DType) -> isize {
|
||||
match dtype {
|
||||
DType::F32 | DType::TF32 => mps::MPSDataType::Float32 as isize,
|
||||
DType::F16 => mps::MPSDataType::Float16 as isize,
|
||||
unsupported => panic!("MPSMatmul does not support dtype {unsupported:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MetalEncodeContext<'a> {
|
||||
pub(crate) command_buffer: &'a CommandBufferRef,
|
||||
pub(crate) dyn_buffer: &'a Buffer,
|
||||
pub(crate) mps_cache: &'a RefCell<MpsKernelCache>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MetalMulInfo {
|
||||
pub shape: Vec<Expression>,
|
||||
@@ -52,19 +169,18 @@ pub trait MetalKernelOp: EgglogOp {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn encode(
|
||||
&self,
|
||||
command_buffer: &CommandBufferRef,
|
||||
context: &mut MetalEncodeContext<'_>,
|
||||
pipeline: Option<&ComputePipelineState>,
|
||||
inputs: &[&Buffer],
|
||||
output: &Buffer,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
dyn_buffer: &Buffer,
|
||||
_input_dtypes: &[DType],
|
||||
_output_dtype: DType,
|
||||
) {
|
||||
let pipeline = pipeline.expect("compute pipeline not compiled");
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let encoder = context.command_buffer.new_compute_command_encoder();
|
||||
let dyn_idx = inputs.len() as u64 + 1;
|
||||
encoder.set_buffer(dyn_idx, Some(dyn_buffer), 0);
|
||||
encoder.set_buffer(dyn_idx, Some(context.dyn_buffer), 0);
|
||||
self.encode_compute(encoder, pipeline, inputs, output, dyn_map);
|
||||
encoder.end_encoding();
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::{MPSMatrixLayout, MetalKernelOp, MetalMulInfo, MetalSumReduceInfo};
|
||||
use super::{MPSMatrixLayout, MetalEncodeContext, MetalKernelOp, MetalMulInfo, MetalSumReduceInfo};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
SerializedEGraph,
|
||||
@@ -19,9 +19,8 @@ use luminal::{
|
||||
shape::flatten_strides,
|
||||
};
|
||||
use metal::{
|
||||
Buffer, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineState, Device, MTLSize,
|
||||
Buffer, ComputeCommandEncoderRef, ComputePipelineState, Device, MTLLanguageVersion, MTLSize,
|
||||
foreign_types::{ForeignType, ForeignTypeRef},
|
||||
mps,
|
||||
};
|
||||
use objc::runtime::Object;
|
||||
use objc::{class, msg_send, sel, sel_impl};
|
||||
@@ -56,15 +55,21 @@ pub type MetalOps = (
|
||||
);
|
||||
|
||||
fn compile_shader(device: &Device, source: &str, function_name: &str) -> ComputePipelineState {
|
||||
let options = metal::CompileOptions::new();
|
||||
options.set_language_version(MTLLanguageVersion::V2_4);
|
||||
let library = device
|
||||
.new_library_with_source(source, &metal::CompileOptions::new())
|
||||
.expect("Failed to compile Metal shader");
|
||||
.new_library_with_source(source, &options)
|
||||
.unwrap_or_else(|err| {
|
||||
panic!("Failed to compile Metal shader {function_name}: {err:?}\n{source}")
|
||||
});
|
||||
let function = library
|
||||
.get_function(function_name, None)
|
||||
.expect("Failed to get function from library");
|
||||
device
|
||||
.new_compute_pipeline_state_with_function(&function)
|
||||
.expect("Failed to create compute pipeline state")
|
||||
.unwrap_or_else(|err| {
|
||||
panic!("Failed to create Metal compute pipeline state for {function_name}: {err:?}\n{source}")
|
||||
})
|
||||
}
|
||||
|
||||
fn lower_dynamic_consts(mut code: String) -> String {
|
||||
@@ -1039,42 +1044,33 @@ impl MetalKernelOp for MetalSumReduce {
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_outputs [[buffer({n_outputs_index})]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_index_in_threadgroup]],
|
||||
uint simd_lane [[thread_index_in_simdgroup]],
|
||||
uint simd_id [[simdgroup_index_in_threadgroup]]
|
||||
uint tid [[thread_index_in_threadgroup]]
|
||||
) {{
|
||||
if (gid >= n_outputs) return;
|
||||
|
||||
threadgroup float warp_sums[THREADS_PER_GROUP / 32];
|
||||
threadgroup float partials[THREADS_PER_GROUP];
|
||||
|
||||
int in_start = {in_idx};
|
||||
int iters = {iters};
|
||||
(void)dyn;
|
||||
|
||||
// Each thread accumulates multiple elements
|
||||
float sum = 0.0f;
|
||||
for (int i = tid; i < iters; i += THREADS_PER_GROUP) {{
|
||||
sum += {in_val};
|
||||
}}
|
||||
|
||||
// Warp-level reduction using simd_sum
|
||||
sum = simd_sum(sum);
|
||||
|
||||
// First lane of each warp writes to shared memory
|
||||
if (simd_lane == 0) {{
|
||||
warp_sums[simd_id] = sum;
|
||||
}}
|
||||
partials[tid] = sum;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// First warp does final reduction
|
||||
if (simd_id == 0) {{
|
||||
int n_warps = THREADS_PER_GROUP / 32;
|
||||
float block_sum = (tid < uint(n_warps)) ? warp_sums[tid] : 0.0f;
|
||||
block_sum = simd_sum(block_sum);
|
||||
|
||||
if (tid == 0) {{
|
||||
out[{out_idx}] = {out_val};
|
||||
for (uint stride = THREADS_PER_GROUP / 2; stride > 0; stride >>= 1) {{
|
||||
if (tid < stride) {{
|
||||
partials[tid] += partials[tid + stride];
|
||||
}}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}}
|
||||
|
||||
if (tid == 0) {{
|
||||
float block_sum = partials[0];
|
||||
out[{out_idx}] = {out_val};
|
||||
}}
|
||||
}}
|
||||
"#,
|
||||
@@ -1220,42 +1216,33 @@ impl MetalKernelOp for MetalMaxReduce {
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_outputs [[buffer({n_outputs_index})]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_index_in_threadgroup]],
|
||||
uint simd_lane [[thread_index_in_simdgroup]],
|
||||
uint simd_id [[simdgroup_index_in_threadgroup]]
|
||||
uint tid [[thread_index_in_threadgroup]]
|
||||
) {{
|
||||
if (gid >= n_outputs) return;
|
||||
|
||||
threadgroup float warp_maxs[THREADS_PER_GROUP / 32];
|
||||
threadgroup float partials[THREADS_PER_GROUP];
|
||||
|
||||
int in_start = {in_idx};
|
||||
int iters = {iters};
|
||||
(void)dyn;
|
||||
|
||||
// Each thread finds max of multiple elements
|
||||
float max_val = NEG_INF_F;
|
||||
for (int i = tid; i < iters; i += THREADS_PER_GROUP) {{
|
||||
max_val = fmax(max_val, {in_val});
|
||||
}}
|
||||
|
||||
// Warp-level reduction using simd_max
|
||||
max_val = simd_max(max_val);
|
||||
|
||||
// First lane of each warp writes to shared memory
|
||||
if (simd_lane == 0) {{
|
||||
warp_maxs[simd_id] = max_val;
|
||||
}}
|
||||
partials[tid] = max_val;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// First warp does final reduction
|
||||
if (simd_id == 0) {{
|
||||
int n_warps = THREADS_PER_GROUP / 32;
|
||||
float block_max = (tid < uint(n_warps)) ? warp_maxs[tid] : NEG_INF_F;
|
||||
block_max = simd_max(block_max);
|
||||
|
||||
if (tid == 0) {{
|
||||
out[{out_idx}] = {out_val};
|
||||
for (uint stride = THREADS_PER_GROUP / 2; stride > 0; stride >>= 1) {{
|
||||
if (tid < stride) {{
|
||||
partials[tid] = fmax(partials[tid], partials[tid + stride]);
|
||||
}}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}}
|
||||
|
||||
if (tid == 0) {{
|
||||
float block_max = partials[0];
|
||||
out[{out_idx}] = {out_val};
|
||||
}}
|
||||
}}
|
||||
"#,
|
||||
@@ -1427,8 +1414,6 @@ impl EgglogOp for MPSMatmul {
|
||||
let dt = v(format!("?{}_dt", name.replace('-', "_")));
|
||||
|
||||
rule(union(sum_op.clone(), mps_op.clone()))
|
||||
.subsume(sum_op.clone())
|
||||
.subsume(mul_op)
|
||||
.set(dtype(mps_op), dt.clone())
|
||||
.fact(eq(dt, dtype(sum_op)))
|
||||
.ruleset("kernel_lower")
|
||||
@@ -1464,6 +1449,17 @@ impl EgglogOp for MPSMatmul {
|
||||
1,
|
||||
1,
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
((= ?mul (Op (MetalMul ?shape ?as ?bs ?os) ?inputs))
|
||||
(= ?sum (Op (MetalSum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
(= ?sum (MPSMatmul ?m ?n ?k ?lhs ?lhsrs ?rhs ?rhsrs ?ors ?tl ?tr)))
|
||||
((delete (Op (MetalSum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
(delete (Op (MetalMul ?shape ?as ?bs ?os) ?inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"delete-broadcast-mul-sum-when-mps-matmul-exists\"
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -1505,14 +1501,6 @@ impl EgglogOp for MPSMatmul {
|
||||
}
|
||||
|
||||
impl MPSMatmul {
|
||||
fn mps_dtype(dtype: DType) -> mps::MPSDataType {
|
||||
match dtype {
|
||||
DType::F32 | DType::TF32 => mps::MPSDataType::Float32,
|
||||
DType::F16 => mps::MPSDataType::Float16,
|
||||
unsupported => panic!("MPSMatmul does not support dtype {unsupported:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn row_bytes(row_stride: Expression, dtype: DType, dyn_map: &FxHashMap<char, usize>) -> u64 {
|
||||
let elems = row_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
@@ -1521,19 +1509,6 @@ impl MPSMatmul {
|
||||
(elems * dtype.bits().div_ceil(8)) as u64
|
||||
}
|
||||
|
||||
fn descriptor(rows: usize, cols: usize, row_bytes: u64, dtype: DType) -> *mut Object {
|
||||
let data_type = Self::mps_dtype(dtype) as isize;
|
||||
unsafe {
|
||||
msg_send![
|
||||
class!(MPSMatrixDescriptor),
|
||||
matrixDescriptorWithRows: rows
|
||||
columns: cols
|
||||
rowBytes: row_bytes as usize
|
||||
dataType: data_type
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
fn matrix(buffer: &Buffer, descriptor: *mut Object) -> *mut Object {
|
||||
unsafe {
|
||||
let matrix: *mut Object = msg_send![class!(MPSMatrix), alloc];
|
||||
@@ -1589,12 +1564,11 @@ impl MetalKernelOp for MPSMatmul {
|
||||
|
||||
fn encode(
|
||||
&self,
|
||||
command_buffer: &CommandBufferRef,
|
||||
context: &mut MetalEncodeContext<'_>,
|
||||
_pipeline: Option<&ComputePipelineState>,
|
||||
inputs: &[&Buffer],
|
||||
output: &Buffer,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
_dyn_buffer: &Buffer,
|
||||
input_dtypes: &[DType],
|
||||
output_dtype: DType,
|
||||
) {
|
||||
@@ -1610,46 +1584,48 @@ impl MetalKernelOp for MPSMatmul {
|
||||
let rhs_rows = if self.transpose_rhs { n } else { k };
|
||||
let rhs_cols = if self.transpose_rhs { k } else { n };
|
||||
|
||||
let lhs_desc = Self::descriptor(
|
||||
lhs_rows,
|
||||
lhs_cols,
|
||||
Self::row_bytes(self.lhs_row_stride, lhs_dtype, dyn_map),
|
||||
lhs_dtype,
|
||||
);
|
||||
let rhs_desc = Self::descriptor(
|
||||
rhs_rows,
|
||||
rhs_cols,
|
||||
Self::row_bytes(self.rhs_row_stride, rhs_dtype, dyn_map),
|
||||
rhs_dtype,
|
||||
);
|
||||
let out_desc = Self::descriptor(
|
||||
m,
|
||||
n,
|
||||
Self::row_bytes(self.out_row_stride, output_dtype, dyn_map),
|
||||
output_dtype,
|
||||
);
|
||||
let (lhs_desc, rhs_desc, out_desc, kernel) = {
|
||||
let mut cache = context.mps_cache.borrow_mut();
|
||||
(
|
||||
cache.matrix_descriptor(
|
||||
lhs_rows,
|
||||
lhs_cols,
|
||||
Self::row_bytes(self.lhs_row_stride, lhs_dtype, dyn_map),
|
||||
lhs_dtype,
|
||||
),
|
||||
cache.matrix_descriptor(
|
||||
rhs_rows,
|
||||
rhs_cols,
|
||||
Self::row_bytes(self.rhs_row_stride, rhs_dtype, dyn_map),
|
||||
rhs_dtype,
|
||||
),
|
||||
cache.matrix_descriptor(
|
||||
m,
|
||||
n,
|
||||
Self::row_bytes(self.out_row_stride, output_dtype, dyn_map),
|
||||
output_dtype,
|
||||
),
|
||||
cache.matrix_multiplication(
|
||||
context.command_buffer,
|
||||
self.transpose_lhs,
|
||||
self.transpose_rhs,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
1.0,
|
||||
0.0,
|
||||
),
|
||||
)
|
||||
};
|
||||
|
||||
let lhs = Self::matrix(inputs[0], lhs_desc);
|
||||
let rhs = Self::matrix(inputs[1], rhs_desc);
|
||||
let out = Self::matrix(output, out_desc);
|
||||
|
||||
unsafe {
|
||||
let device: *mut Object = msg_send![command_buffer.as_ptr(), device];
|
||||
let kernel: *mut Object = msg_send![class!(MPSMatrixMultiplication), alloc];
|
||||
let kernel: *mut Object = msg_send![
|
||||
kernel,
|
||||
initWithDevice: device
|
||||
transposeLeft: self.transpose_lhs
|
||||
transposeRight: self.transpose_rhs
|
||||
resultRows: m
|
||||
resultColumns: n
|
||||
interiorColumns: k
|
||||
alpha: 1.0f64
|
||||
beta: 0.0f64
|
||||
];
|
||||
let _: () = msg_send![
|
||||
kernel,
|
||||
encodeToCommandBuffer: command_buffer.as_ptr()
|
||||
encodeToCommandBuffer: context.command_buffer.as_ptr()
|
||||
leftMatrix: lhs
|
||||
rightMatrix: rhs
|
||||
resultMatrix: out
|
||||
@@ -1657,7 +1633,6 @@ impl MetalKernelOp for MPSMatmul {
|
||||
let _: () = msg_send![lhs, release];
|
||||
let _: () = msg_send![rhs, release];
|
||||
let _: () = msg_send![out, release];
|
||||
let _: () = msg_send![kernel, release];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1839,8 +1814,6 @@ impl EgglogOp for MPSBatchedMatmul {
|
||||
let dt = v(format!("?{}_dt", name.replace('-', "_")));
|
||||
|
||||
rule(union(sum_op.clone(), mps_op.clone()))
|
||||
.subsume(sum_op.clone())
|
||||
.subsume(mul_op)
|
||||
.set(dtype(mps_op), dt.clone())
|
||||
.fact(eq(dt, dtype(sum_op)))
|
||||
.ruleset("kernel_lower")
|
||||
@@ -1878,6 +1851,17 @@ impl EgglogOp for MPSBatchedMatmul {
|
||||
),
|
||||
1,
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
((= ?mul (Op (MetalMul ?shape ?as ?bs ?os) ?inputs))
|
||||
(= ?sum (Op (MetalSum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
(= ?sum (MPSBatchedMatmul ?b ?m ?n ?k ?lhs ?lhsbs ?lhsrs ?rhs ?rhsbs ?rhsrs ?obs ?ors ?tl ?tr)))
|
||||
((delete (Op (MetalSum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
(delete (Op (MetalMul ?shape ?as ?bs ?os) ?inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"delete-broadcast-mul-sum-when-mps-batched-matmul-exists\"
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -1953,12 +1937,11 @@ impl MetalKernelOp for MPSBatchedMatmul {
|
||||
|
||||
fn encode(
|
||||
&self,
|
||||
command_buffer: &CommandBufferRef,
|
||||
context: &mut MetalEncodeContext<'_>,
|
||||
_pipeline: Option<&ComputePipelineState>,
|
||||
inputs: &[&Buffer],
|
||||
output: &Buffer,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
_dyn_buffer: &Buffer,
|
||||
input_dtypes: &[DType],
|
||||
output_dtype: DType,
|
||||
) {
|
||||
@@ -1982,25 +1965,26 @@ impl MetalKernelOp for MPSBatchedMatmul {
|
||||
let lhs_row_bytes = MPSMatmul::row_bytes(self.lhs_row_stride, lhs_dtype, dyn_map);
|
||||
let rhs_row_bytes = MPSMatmul::row_bytes(self.rhs_row_stride, rhs_dtype, dyn_map);
|
||||
let out_row_bytes = MPSMatmul::row_bytes(self.out_row_stride, output_dtype, dyn_map);
|
||||
let lhs_desc = MPSMatmul::descriptor(lhs_rows, lhs_cols, lhs_row_bytes, lhs_dtype);
|
||||
let rhs_desc = MPSMatmul::descriptor(rhs_rows, rhs_cols, rhs_row_bytes, rhs_dtype);
|
||||
let out_desc = MPSMatmul::descriptor(m, n, out_row_bytes, output_dtype);
|
||||
let (lhs_desc, rhs_desc, out_desc, kernel) = {
|
||||
let mut cache = context.mps_cache.borrow_mut();
|
||||
(
|
||||
cache.matrix_descriptor(lhs_rows, lhs_cols, lhs_row_bytes, lhs_dtype),
|
||||
cache.matrix_descriptor(rhs_rows, rhs_cols, rhs_row_bytes, rhs_dtype),
|
||||
cache.matrix_descriptor(m, n, out_row_bytes, output_dtype),
|
||||
cache.matrix_multiplication(
|
||||
context.command_buffer,
|
||||
self.transpose_lhs,
|
||||
self.transpose_rhs,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
1.0,
|
||||
0.0,
|
||||
),
|
||||
)
|
||||
};
|
||||
|
||||
unsafe {
|
||||
let device: *mut Object = msg_send![command_buffer.as_ptr(), device];
|
||||
let kernel: *mut Object = msg_send![class!(MPSMatrixMultiplication), alloc];
|
||||
let kernel: *mut Object = msg_send![
|
||||
kernel,
|
||||
initWithDevice: device
|
||||
transposeLeft: self.transpose_lhs
|
||||
transposeRight: self.transpose_rhs
|
||||
resultRows: m
|
||||
resultColumns: n
|
||||
interiorColumns: k
|
||||
alpha: 1.0f64
|
||||
beta: 0.0f64
|
||||
];
|
||||
|
||||
for batch_idx in 0..batch {
|
||||
let batch_expr = Expression::from(batch_idx as i64);
|
||||
let lhs_offset = self
|
||||
@@ -2027,7 +2011,7 @@ impl MetalKernelOp for MPSBatchedMatmul {
|
||||
let out = MPSMatmul::matrix_with_offset(output, out_offset as u64, out_desc);
|
||||
let _: () = msg_send![
|
||||
kernel,
|
||||
encodeToCommandBuffer: command_buffer.as_ptr()
|
||||
encodeToCommandBuffer: context.command_buffer.as_ptr()
|
||||
leftMatrix: lhs
|
||||
rightMatrix: rhs
|
||||
resultMatrix: out
|
||||
@@ -2036,7 +2020,6 @@ impl MetalKernelOp for MPSBatchedMatmul {
|
||||
let _: () = msg_send![rhs, release];
|
||||
let _: () = msg_send![out, release];
|
||||
}
|
||||
let _: () = msg_send![kernel, release];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2163,24 +2146,6 @@ impl EgglogOp for GenericMatmul {
|
||||
:name \"delete-broadcast-mul-sum-when-generic-matmul-exists\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
((= ?sum (GenericMatmul ?go ?gm ?gk ?gl ?glas ?gr ?grs ?gsis ?gsit ?gos))
|
||||
(= ?sum (MPSMatmul ?mm ?mn ?mk ?ml ?mls ?mr ?mrs ?mos ?mtl ?mtr)))
|
||||
((delete (GenericMatmul ?go ?gm ?gk ?gl ?glas ?gr ?grs ?gsis ?gsit ?gos)))
|
||||
:ruleset cleanup
|
||||
:name \"prefer-mps-over-generic-matmul\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
((= ?sum (GenericMatmul ?go ?gm ?gk ?gl ?glas ?gr ?grs ?gsis ?gsit ?gos))
|
||||
(= ?sum (MPSBatchedMatmul ?bb ?bm ?bn ?bk ?bl ?blbs ?blrs ?br ?brbs ?brrs ?bobs ?bors ?btl ?btr)))
|
||||
((delete (GenericMatmul ?go ?gm ?gk ?gl ?glas ?gr ?grs ?gsis ?gsit ?gos)))
|
||||
:ruleset cleanup
|
||||
:name \"prefer-mps-batched-over-generic-matmul\"
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -2265,13 +2230,11 @@ impl MetalKernelOp for GenericMatmul {
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_outputs [[buffer({n_outputs_index})]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_index_in_threadgroup]],
|
||||
uint simd_lane [[thread_index_in_simdgroup]],
|
||||
uint simd_id [[simdgroup_index_in_threadgroup]]
|
||||
uint tid [[thread_index_in_threadgroup]]
|
||||
) {{
|
||||
if (gid >= n_outputs) return;
|
||||
|
||||
threadgroup float warp_sums[THREADS_PER_GROUP / 32];
|
||||
threadgroup float partials[THREADS_PER_GROUP];
|
||||
int base_idx = {sum_base_idx};
|
||||
int iters = {iters};
|
||||
(void)dyn;
|
||||
@@ -2282,19 +2245,18 @@ impl MetalKernelOp for GenericMatmul {
|
||||
sum += ({lhs_val}) * ({rhs_val});
|
||||
}}
|
||||
|
||||
sum = simd_sum(sum);
|
||||
if (simd_lane == 0) {{
|
||||
warp_sums[simd_id] = sum;
|
||||
}}
|
||||
partials[tid] = sum;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (simd_id == 0) {{
|
||||
int n_warps = THREADS_PER_GROUP / 32;
|
||||
float block_sum = (tid < uint(n_warps)) ? warp_sums[tid] : 0.0f;
|
||||
block_sum = simd_sum(block_sum);
|
||||
if (tid == 0) {{
|
||||
out[{out_idx}] = {out_val};
|
||||
for (uint stride = THREADS_PER_GROUP / 2; stride > 0; stride >>= 1) {{
|
||||
if (tid < stride) {{
|
||||
partials[tid] += partials[tid + stride];
|
||||
}}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}}
|
||||
|
||||
if (tid == 0) {{
|
||||
float block_sum = partials[0];
|
||||
out[{out_idx}] = {out_val};
|
||||
}}
|
||||
}}
|
||||
"#,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
pub mod dyn_backend;
|
||||
pub mod kernel;
|
||||
mod memory_analysis;
|
||||
pub mod runtime;
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
1478
crates/luminal_metal/src/memory_analysis.rs
Normal file
1478
crates/luminal_metal/src/memory_analysis.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,9 @@
|
||||
use crate::kernel::{DYN_SLOT_COUNT, MetalKernelOp};
|
||||
use crate::kernel::{DYN_SLOT_COUNT, MetalEncodeContext, MetalKernelOp, MpsKernelCache};
|
||||
use half::{bf16, f16};
|
||||
use itertools::Itertools;
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
egglog_utils::SerializedEGraph,
|
||||
graph::{BucketLLIR, DimBucket, Graph, LLIRGraph},
|
||||
hlir::{Input, NativeData, Output},
|
||||
op::{ExecutionStats, Runtime, RuntimeStats, TimingMethod},
|
||||
@@ -16,15 +17,26 @@ use metal::{Buffer, CommandQueue, ComputePipelineState, Device, MTLResourceOptio
|
||||
use objc::rc::autoreleasepool;
|
||||
use objc::runtime::Object;
|
||||
use safetensors::{Dtype, SafeTensors};
|
||||
use std::{fs::File, time::Duration};
|
||||
use std::{cell::RefCell, fs::File, time::Duration};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MetalExecutionStep {
|
||||
node: NodeIndex,
|
||||
input_nodes: Vec<NodeIndex>,
|
||||
input_dtypes: Vec<DType>,
|
||||
output_dtype: DType,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MetalCompiledBucket {
|
||||
bucket_indices: FxHashMap<char, usize>,
|
||||
llir_graph: LLIRGraph,
|
||||
llir_to_hlir: FxHashMap<NodeIndex, NodeIndex>,
|
||||
node_dtypes: FxHashMap<NodeIndex, DType>,
|
||||
pipelines: FxHashMap<NodeIndex, ComputePipelineState>,
|
||||
output_alias_map: FxHashMap<NodeIndex, NodeIndex>,
|
||||
output_data_map: FxHashMap<NodeIndex, NodeIndex>,
|
||||
execution_plan: Vec<MetalExecutionStep>,
|
||||
}
|
||||
|
||||
pub struct MetalRuntime {
|
||||
@@ -36,16 +48,26 @@ pub struct MetalRuntime {
|
||||
pub hlir_buffers: FxHashMap<NodeIndex, Buffer>,
|
||||
/// Buffers for LLIR intermediate/output tensors
|
||||
pub buffers: FxHashMap<NodeIndex, Buffer>,
|
||||
/// Logical byte length for each active LLIR buffer.
|
||||
buffer_lengths: FxHashMap<NodeIndex, u64>,
|
||||
/// Dynamic dimensions table (a-z), shared across all kernels.
|
||||
dyn_buffer: Buffer,
|
||||
/// Retained MPS descriptors/kernels reused across command encodes.
|
||||
mps_cache: RefCell<MpsKernelCache>,
|
||||
/// The current LLIR graph
|
||||
llir_graph: LLIRGraph,
|
||||
/// LLIR input node -> HLIR input node.
|
||||
llir_to_hlir: FxHashMap<NodeIndex, NodeIndex>,
|
||||
/// Inferred runtime dtype for each LLIR node.
|
||||
node_dtypes: FxHashMap<NodeIndex, DType>,
|
||||
/// Compiled pipeline states for each kernel node
|
||||
pipelines: FxHashMap<NodeIndex, ComputePipelineState>,
|
||||
/// LLIR output node -> input node whose buffer contains the output.
|
||||
output_alias_map: FxHashMap<NodeIndex, NodeIndex>,
|
||||
/// HLIR output id -> LLIR node whose data feeds the output.
|
||||
output_data_map: FxHashMap<NodeIndex, NodeIndex>,
|
||||
/// Precomputed executable nodes and input metadata for the active LLIR graph.
|
||||
execution_plan: Vec<MetalExecutionStep>,
|
||||
/// Bucket definitions for dynamic dimensions.
|
||||
dim_buckets: FxHashMap<char, Vec<DimBucket>>,
|
||||
/// Compiled LLIR variants, one per bucket combination.
|
||||
@@ -64,22 +86,10 @@ impl MetalRuntime {
|
||||
}
|
||||
|
||||
fn output_data_node(&self, id: NodeIndex) -> NodeIndex {
|
||||
let output_id = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.find(|n| {
|
||||
if let Some(Output { node }) = self.llir_graph[*n].to_op::<Output>() {
|
||||
*node == id.index()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.expect("Cannot find output tensor!");
|
||||
|
||||
self.llir_graph
|
||||
.neighbors_directed(output_id, Direction::Incoming)
|
||||
.next()
|
||||
.unwrap()
|
||||
self.output_data_map
|
||||
.get(&id)
|
||||
.copied()
|
||||
.unwrap_or_else(|| panic!("Cannot find output tensor {id:?}!"))
|
||||
}
|
||||
|
||||
fn follow_aliases(&self, mut node: NodeIndex) -> NodeIndex {
|
||||
@@ -225,6 +235,7 @@ impl MetalRuntime {
|
||||
let data_id = self.follow_aliases(self.output_data_node(id.to_id()));
|
||||
|
||||
if let Some(buffer) = self.buffers.remove(&data_id) {
|
||||
self.buffer_lengths.remove(&data_id);
|
||||
return buffer;
|
||||
}
|
||||
|
||||
@@ -269,12 +280,21 @@ impl MetalRuntime {
|
||||
.map(|inp| inp.dtype)
|
||||
})
|
||||
.unwrap_or(DType::F32);
|
||||
let logical_bytes = self
|
||||
.buffer_lengths
|
||||
.get(&data_id)
|
||||
.copied()
|
||||
.unwrap_or_else(|| buffer.length());
|
||||
assert!(
|
||||
logical_bytes <= buffer.length(),
|
||||
"Logical buffer size exceeds allocated Metal buffer size"
|
||||
);
|
||||
|
||||
unsafe {
|
||||
match dtype {
|
||||
DType::F16 => {
|
||||
let ptr = buffer.contents() as *const f16;
|
||||
let len = buffer.length() as usize / std::mem::size_of::<f16>();
|
||||
let len = logical_bytes as usize / std::mem::size_of::<f16>();
|
||||
std::slice::from_raw_parts(ptr, len)
|
||||
.iter()
|
||||
.map(|v| v.to_f32())
|
||||
@@ -282,7 +302,7 @@ impl MetalRuntime {
|
||||
}
|
||||
DType::Int => {
|
||||
let ptr = buffer.contents() as *const i32;
|
||||
let len = buffer.length() as usize / std::mem::size_of::<i32>();
|
||||
let len = logical_bytes as usize / std::mem::size_of::<i32>();
|
||||
std::slice::from_raw_parts(ptr, len)
|
||||
.iter()
|
||||
.map(|v| *v as f32)
|
||||
@@ -290,7 +310,7 @@ impl MetalRuntime {
|
||||
}
|
||||
_ => {
|
||||
let ptr = buffer.contents() as *const f32;
|
||||
let len = buffer.length() as usize / std::mem::size_of::<f32>();
|
||||
let len = logical_bytes as usize / std::mem::size_of::<f32>();
|
||||
std::slice::from_raw_parts(ptr, len).to_vec()
|
||||
}
|
||||
}
|
||||
@@ -304,6 +324,26 @@ impl Runtime for MetalRuntime {
|
||||
type ExecReturn = ();
|
||||
type ProfileMetric = Duration;
|
||||
|
||||
fn late_egglog_passes(
|
||||
ops: &[std::sync::Arc<Box<dyn luminal::op::EgglogOp>>],
|
||||
options: &luminal::graph::CompileOptions,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> Vec<luminal::egglog_utils::LateEgglogPass> {
|
||||
vec![crate::memory_analysis::metal_memory_analysis_pass(
|
||||
ops,
|
||||
options.max_memory_bytes,
|
||||
dyn_map,
|
||||
)]
|
||||
}
|
||||
|
||||
fn estimate_graph_memory<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
choices: &luminal::egglog_utils::EGraphChoiceSet<'a>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> Option<usize> {
|
||||
crate::memory_analysis::estimate_graph_memory_bytes(egraph, choices, dyn_map)
|
||||
}
|
||||
|
||||
fn initialize(_: Self::CompileArg) -> Self {
|
||||
let device = Device::system_default().expect("No Metal device found!");
|
||||
let command_queue = device.new_command_queue();
|
||||
@@ -318,11 +358,16 @@ impl Runtime for MetalRuntime {
|
||||
input_data: FxHashMap::default(),
|
||||
hlir_buffers: FxHashMap::default(),
|
||||
buffers: FxHashMap::default(),
|
||||
buffer_lengths: FxHashMap::default(),
|
||||
dyn_buffer,
|
||||
mps_cache: RefCell::new(MpsKernelCache::default()),
|
||||
llir_graph: StableGraph::default(),
|
||||
llir_to_hlir: FxHashMap::default(),
|
||||
node_dtypes: FxHashMap::default(),
|
||||
pipelines: FxHashMap::default(),
|
||||
output_alias_map: FxHashMap::default(),
|
||||
output_data_map: FxHashMap::default(),
|
||||
execution_plan: vec![],
|
||||
dim_buckets: FxHashMap::default(),
|
||||
compiled_buckets: vec![],
|
||||
active_bucket: 0,
|
||||
@@ -336,6 +381,7 @@ impl Runtime for MetalRuntime {
|
||||
#[tracing::instrument(skip_all)]
|
||||
fn load_llir(&mut self, llir_graph: &LLIRGraph) {
|
||||
self.buffers.clear();
|
||||
self.buffer_lengths.clear();
|
||||
self.dim_buckets.clear();
|
||||
self.compiled_buckets = vec![self.compile_bucket(FxHashMap::default(), llir_graph)];
|
||||
self.activate_bucket(0);
|
||||
@@ -347,19 +393,25 @@ impl Runtime for MetalRuntime {
|
||||
llir_graph: &LLIRGraph,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
trials: usize,
|
||||
_timeout: Option<std::time::Duration>,
|
||||
timeout: Option<std::time::Duration>,
|
||||
) -> (Self::ProfileMetric, String) {
|
||||
self.load_llir(llir_graph);
|
||||
self.allocate_intermediate_buffers(dyn_map);
|
||||
|
||||
let trials = trials.max(1);
|
||||
let profile_start = std::time::Instant::now();
|
||||
let mut duration = Duration::default();
|
||||
let mut completed_trials = 0;
|
||||
for _ in 0..trials {
|
||||
let start = std::time::Instant::now();
|
||||
self.execute(dyn_map);
|
||||
duration += start.elapsed();
|
||||
completed_trials += 1;
|
||||
if timeout.is_some_and(|timeout| profile_start.elapsed() >= timeout) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
duration /= trials as u32;
|
||||
duration /= completed_trials as u32;
|
||||
|
||||
(duration, format!("{:.2?}", duration))
|
||||
}
|
||||
@@ -370,74 +422,43 @@ impl Runtime for MetalRuntime {
|
||||
self.select_bucket(dyn_map);
|
||||
self.allocate_active_intermediate_buffers(dyn_map);
|
||||
|
||||
let llir_to_hlir: FxHashMap<NodeIndex, NodeIndex> = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.filter_map(|n| {
|
||||
if let Some(Input { node, .. }) = self.llir_graph[n].to_op::<Input>() {
|
||||
Some((n, NodeIndex::new(*node)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
|
||||
|
||||
self.update_dyn_buffer(dyn_map);
|
||||
let command_buffer = self.command_queue.new_command_buffer();
|
||||
let mut encode_context = MetalEncodeContext {
|
||||
command_buffer,
|
||||
dyn_buffer: &self.dyn_buffer,
|
||||
mps_cache: &self.mps_cache,
|
||||
};
|
||||
|
||||
for node in topo_order {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some()
|
||||
|| self.llir_graph[node].to_op::<Output>().is_some()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
for step in &self.execution_plan {
|
||||
let kernel_op = self.llir_graph[step.node]
|
||||
.to_dialect::<dyn MetalKernelOp>()
|
||||
.expect("Execution plan referenced a non-Metal op");
|
||||
let pipeline = self.pipelines.get(&step.node);
|
||||
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let pipeline = self.pipelines.get(&node);
|
||||
let input_buffers: Vec<&Buffer> = step
|
||||
.input_nodes
|
||||
.iter()
|
||||
.map(|&n| self.buffer_for_llir_node(n, &self.llir_to_hlir))
|
||||
.collect();
|
||||
|
||||
let input_nodes: Vec<NodeIndex> = self
|
||||
.llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
let output_buffer = if let Some(alias_idx) = kernel_op.output_aliases_input() {
|
||||
input_buffers[alias_idx]
|
||||
} else {
|
||||
self.buffers
|
||||
.get(&step.node)
|
||||
.expect("Output buffer not allocated!")
|
||||
};
|
||||
|
||||
let input_buffers: Vec<&Buffer> = input_nodes
|
||||
.iter()
|
||||
.map(|&n| self.buffer_for_llir_node(n, &llir_to_hlir))
|
||||
.collect();
|
||||
let input_dtypes: Vec<DType> = input_nodes
|
||||
.iter()
|
||||
.map(|n| {
|
||||
self.node_dtypes
|
||||
.get(n)
|
||||
.copied()
|
||||
.unwrap_or_else(|| panic!("Missing inferred dtype for node {n:?}"))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output_buffer = if let Some(alias_idx) = kernel_op.output_aliases_input() {
|
||||
input_buffers[alias_idx]
|
||||
} else {
|
||||
self.buffers
|
||||
.get(&node)
|
||||
.expect("Output buffer not allocated!")
|
||||
};
|
||||
let output_dtype = self.node_dtypes.get(&node).copied().unwrap_or(DType::F32);
|
||||
|
||||
kernel_op.encode(
|
||||
command_buffer,
|
||||
pipeline,
|
||||
&input_buffers,
|
||||
output_buffer,
|
||||
dyn_map,
|
||||
&self.dyn_buffer,
|
||||
&input_dtypes,
|
||||
output_dtype,
|
||||
);
|
||||
}
|
||||
kernel_op.encode(
|
||||
&mut encode_context,
|
||||
pipeline,
|
||||
&input_buffers,
|
||||
output_buffer,
|
||||
dyn_map,
|
||||
&step.input_dtypes,
|
||||
step.output_dtype,
|
||||
);
|
||||
}
|
||||
|
||||
command_buffer.commit();
|
||||
@@ -447,6 +468,22 @@ impl Runtime for MetalRuntime {
|
||||
|
||||
fn clear_intermediate_buffers(&mut self) {
|
||||
self.buffers.clear();
|
||||
self.buffer_lengths.clear();
|
||||
}
|
||||
|
||||
fn intermediate_buffer_bytes(&self) -> usize {
|
||||
self.buffers
|
||||
.values()
|
||||
.map(|buffer| buffer.length() as usize)
|
||||
.sum()
|
||||
}
|
||||
|
||||
fn planned_intermediate_buffer_bytes(&self) -> Option<usize> {
|
||||
Some(self.intermediate_buffer_bytes())
|
||||
}
|
||||
|
||||
fn allocated_intermediate_buffer_bytes(&self) -> Option<usize> {
|
||||
Some(self.intermediate_buffer_bytes())
|
||||
}
|
||||
|
||||
fn load_llir_buckets(
|
||||
@@ -455,6 +492,7 @@ impl Runtime for MetalRuntime {
|
||||
bucket_llirs: &[BucketLLIR],
|
||||
) {
|
||||
self.buffers.clear();
|
||||
self.buffer_lengths.clear();
|
||||
self.dim_buckets = dim_buckets.clone();
|
||||
self.compiled_buckets = bucket_llirs
|
||||
.iter()
|
||||
@@ -497,7 +535,7 @@ impl MetalRuntime {
|
||||
fn create_input_buffer(&self, data: &NativeData, dtype: DType) -> Buffer {
|
||||
match dtype {
|
||||
DType::F32 => {
|
||||
let values: Vec<f32> = (0..data.len()).map(|i| data.f32(i)).collect();
|
||||
let values = data.to_f32_vec();
|
||||
self.device.new_buffer_with_data(
|
||||
values.as_ptr() as *const _,
|
||||
std::mem::size_of_val(values.as_slice()) as u64,
|
||||
@@ -505,7 +543,7 @@ impl MetalRuntime {
|
||||
)
|
||||
}
|
||||
DType::F16 => {
|
||||
let values: Vec<f16> = (0..data.len()).map(|i| data.f16(i)).collect();
|
||||
let values = data.to_f16_vec();
|
||||
self.device.new_buffer_with_data(
|
||||
values.as_ptr() as *const _,
|
||||
std::mem::size_of_val(values.as_slice()) as u64,
|
||||
@@ -513,7 +551,7 @@ impl MetalRuntime {
|
||||
)
|
||||
}
|
||||
DType::Int => {
|
||||
let values: Vec<i32> = (0..data.len()).map(|i| data.i32(i)).collect();
|
||||
let values = data.to_i32_vec();
|
||||
self.device.new_buffer_with_data(
|
||||
values.as_ptr() as *const _,
|
||||
std::mem::size_of_val(values.as_slice()) as u64,
|
||||
@@ -531,6 +569,7 @@ impl MetalRuntime {
|
||||
|
||||
fn allocate_active_intermediate_buffers(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
let mut planned = Vec::new();
|
||||
let capacity_dyn_map = self.active_capacity_dyn_map(dyn_map);
|
||||
|
||||
for node in self.llir_graph.node_indices() {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some() {
|
||||
@@ -541,28 +580,58 @@ impl MetalRuntime {
|
||||
if kernel_op.output_aliases_input().is_some() {
|
||||
continue;
|
||||
}
|
||||
let size = kernel_op.output_size().exec(dyn_map).unwrap();
|
||||
let dtype = self.node_dtypes.get(&node).copied().unwrap_or(DType::F32);
|
||||
let bytes = (size * dtype.bits().div_ceil(8)) as u64;
|
||||
let requested_bytes =
|
||||
Self::output_bytes(kernel_op.as_ref().as_ref(), dtype, dyn_map);
|
||||
let allocation_bytes =
|
||||
Self::output_bytes(kernel_op.as_ref().as_ref(), dtype, &capacity_dyn_map)
|
||||
.max(requested_bytes);
|
||||
let needs_buffer = self
|
||||
.buffers
|
||||
.get(&node)
|
||||
.is_none_or(|buffer| buffer.length() != bytes);
|
||||
.is_none_or(|buffer| requested_bytes > buffer.length());
|
||||
|
||||
planned.push((node, bytes, needs_buffer));
|
||||
planned.push((node, requested_bytes, allocation_bytes, needs_buffer));
|
||||
}
|
||||
}
|
||||
|
||||
for (node, bytes, needs_buffer) in planned {
|
||||
for (node, requested_bytes, allocation_bytes, needs_buffer) in planned {
|
||||
self.buffer_lengths.insert(node, requested_bytes);
|
||||
if needs_buffer {
|
||||
let buffer = self
|
||||
.device
|
||||
.new_buffer(bytes, MTLResourceOptions::StorageModeShared);
|
||||
.new_buffer(allocation_bytes, MTLResourceOptions::StorageModeShared);
|
||||
self.buffers.insert(node, buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn output_bytes(
|
||||
kernel_op: &dyn MetalKernelOp,
|
||||
dtype: DType,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> u64 {
|
||||
let size = kernel_op.output_size().exec(dyn_map).unwrap();
|
||||
(size * dtype.bits().div_ceil(8)) as u64
|
||||
}
|
||||
|
||||
fn active_capacity_dyn_map(&self, dyn_map: &FxHashMap<char, usize>) -> FxHashMap<char, usize> {
|
||||
let mut capacity_dyn_map = dyn_map.clone();
|
||||
let Some(active_bucket) = self.compiled_buckets.get(self.active_bucket) else {
|
||||
return capacity_dyn_map;
|
||||
};
|
||||
|
||||
for (&dim, buckets) in &self.dim_buckets {
|
||||
if let Some(&bucket_index) = active_bucket.bucket_indices.get(&dim)
|
||||
&& let Some(bucket) = buckets.get(bucket_index)
|
||||
{
|
||||
capacity_dyn_map.insert(dim, bucket.max);
|
||||
}
|
||||
}
|
||||
|
||||
capacity_dyn_map
|
||||
}
|
||||
|
||||
fn compile_bucket(
|
||||
&self,
|
||||
bucket_indices: FxHashMap<char, usize>,
|
||||
@@ -571,12 +640,17 @@ impl MetalRuntime {
|
||||
let mut node_dtypes = FxHashMap::default();
|
||||
let mut pipelines = FxHashMap::default();
|
||||
let mut output_alias_map = FxHashMap::default();
|
||||
let mut output_data_map = FxHashMap::default();
|
||||
let mut execution_plan = Vec::new();
|
||||
let mut llir_to_hlir = FxHashMap::default();
|
||||
let llir_graph = llir_graph.clone();
|
||||
|
||||
let topo_order = toposort(&llir_graph, None).expect("Graph has cycles!");
|
||||
for node in topo_order {
|
||||
for node in &topo_order {
|
||||
let node = *node;
|
||||
if let Some(input) = llir_graph[node].to_op::<Input>() {
|
||||
node_dtypes.insert(node, input.dtype);
|
||||
llir_to_hlir.insert(node, NodeIndex::new(input.node));
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -610,17 +684,38 @@ impl MetalRuntime {
|
||||
{
|
||||
output_alias_map.insert(node, target);
|
||||
}
|
||||
execution_plan.push(MetalExecutionStep {
|
||||
node,
|
||||
input_nodes,
|
||||
input_dtypes,
|
||||
output_dtype,
|
||||
});
|
||||
} else {
|
||||
panic!("Metal runtime cannot execute unlowered LLIR node {node:?}");
|
||||
}
|
||||
}
|
||||
|
||||
for node in topo_order {
|
||||
if let Some(Output { node: hlir_node }) = llir_graph[node].to_op::<Output>()
|
||||
&& let Some(data_node) = llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.next()
|
||||
.map(|e| e.source())
|
||||
{
|
||||
output_data_map.insert(NodeIndex::new(*hlir_node), data_node);
|
||||
}
|
||||
}
|
||||
|
||||
MetalCompiledBucket {
|
||||
bucket_indices,
|
||||
llir_graph,
|
||||
llir_to_hlir,
|
||||
node_dtypes,
|
||||
pipelines,
|
||||
output_alias_map,
|
||||
output_data_map,
|
||||
execution_plan,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -632,11 +727,15 @@ impl MetalRuntime {
|
||||
.clone();
|
||||
self.active_bucket = index;
|
||||
self.llir_graph = bucket.llir_graph;
|
||||
self.llir_to_hlir = bucket.llir_to_hlir;
|
||||
self.node_dtypes = bucket.node_dtypes;
|
||||
self.pipelines = bucket.pipelines;
|
||||
self.output_alias_map = bucket.output_alias_map;
|
||||
self.output_data_map = bucket.output_data_map;
|
||||
self.execution_plan = bucket.execution_plan;
|
||||
self.refresh_input_data_buffers();
|
||||
self.buffers.clear();
|
||||
self.buffer_lengths.clear();
|
||||
}
|
||||
|
||||
fn refresh_input_data_buffers(&mut self) {
|
||||
@@ -706,74 +805,43 @@ impl MetalRuntime {
|
||||
self.select_bucket(dyn_map);
|
||||
self.allocate_active_intermediate_buffers(dyn_map);
|
||||
|
||||
let llir_to_hlir: FxHashMap<NodeIndex, NodeIndex> = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.filter_map(|n| {
|
||||
if let Some(Input { node, .. }) = self.llir_graph[n].to_op::<Input>() {
|
||||
Some((n, NodeIndex::new(*node)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
|
||||
|
||||
self.update_dyn_buffer(dyn_map);
|
||||
let command_buffer = self.command_queue.new_command_buffer();
|
||||
let mut encode_context = MetalEncodeContext {
|
||||
command_buffer,
|
||||
dyn_buffer: &self.dyn_buffer,
|
||||
mps_cache: &self.mps_cache,
|
||||
};
|
||||
|
||||
for node in topo_order {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some()
|
||||
|| self.llir_graph[node].to_op::<Output>().is_some()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
for step in &self.execution_plan {
|
||||
let kernel_op = self.llir_graph[step.node]
|
||||
.to_dialect::<dyn MetalKernelOp>()
|
||||
.expect("Execution plan referenced a non-Metal op");
|
||||
let pipeline = self.pipelines.get(&step.node);
|
||||
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let pipeline = self.pipelines.get(&node);
|
||||
let input_buffers: Vec<&Buffer> = step
|
||||
.input_nodes
|
||||
.iter()
|
||||
.map(|&n| self.buffer_for_llir_node(n, &self.llir_to_hlir))
|
||||
.collect();
|
||||
|
||||
let input_nodes: Vec<NodeIndex> = self
|
||||
.llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
let output_buffer = if let Some(alias_idx) = kernel_op.output_aliases_input() {
|
||||
input_buffers[alias_idx]
|
||||
} else {
|
||||
self.buffers
|
||||
.get(&step.node)
|
||||
.expect("Output buffer not allocated!")
|
||||
};
|
||||
|
||||
let input_buffers: Vec<&Buffer> = input_nodes
|
||||
.iter()
|
||||
.map(|&n| self.buffer_for_llir_node(n, &llir_to_hlir))
|
||||
.collect();
|
||||
let input_dtypes: Vec<DType> = input_nodes
|
||||
.iter()
|
||||
.map(|n| {
|
||||
self.node_dtypes
|
||||
.get(n)
|
||||
.copied()
|
||||
.unwrap_or_else(|| panic!("Missing inferred dtype for node {n:?}"))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output_buffer = if let Some(alias_idx) = kernel_op.output_aliases_input() {
|
||||
input_buffers[alias_idx]
|
||||
} else {
|
||||
self.buffers
|
||||
.get(&node)
|
||||
.expect("Output buffer not allocated!")
|
||||
};
|
||||
let output_dtype = self.node_dtypes.get(&node).copied().unwrap_or(DType::F32);
|
||||
|
||||
kernel_op.encode(
|
||||
command_buffer,
|
||||
pipeline,
|
||||
&input_buffers,
|
||||
output_buffer,
|
||||
dyn_map,
|
||||
&self.dyn_buffer,
|
||||
&input_dtypes,
|
||||
output_dtype,
|
||||
);
|
||||
}
|
||||
kernel_op.encode(
|
||||
&mut encode_context,
|
||||
pipeline,
|
||||
&input_buffers,
|
||||
output_buffer,
|
||||
dyn_map,
|
||||
&step.input_dtypes,
|
||||
step.output_dtype,
|
||||
);
|
||||
}
|
||||
|
||||
command_buffer.commit();
|
||||
|
||||
@@ -3,6 +3,7 @@ use candle_core::{Device as CandleDevice, Tensor as CandleTensor};
|
||||
use half::{bf16, f16};
|
||||
use luminal::prelude::*;
|
||||
use proptest::prelude::*;
|
||||
use rand::{SeedableRng, rngs::StdRng};
|
||||
use safetensors::{Dtype, tensor::TensorView};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
@@ -38,6 +39,34 @@ fn bytes_of<T: bytemuck::NoUninit>(values: &[T]) -> Vec<u8> {
|
||||
bytemuck::cast_slice(values).to_vec()
|
||||
}
|
||||
|
||||
fn search_candidates(cx: &mut Graph, rt: MetalRuntime, limit: usize) -> MetalRuntime {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(limit),
|
||||
&mut rng,
|
||||
)
|
||||
}
|
||||
|
||||
fn egraph_has_op(cx: &Graph, op_name: &str) -> bool {
|
||||
cx.egraph()
|
||||
.expect("search space should be built")
|
||||
.enodes
|
||||
.values()
|
||||
.any(|(label, _)| label == op_name)
|
||||
}
|
||||
|
||||
fn assert_matmul_options(cx: &Graph, mps_op_name: &str) {
|
||||
assert!(
|
||||
egraph_has_op(cx, mps_op_name),
|
||||
"expected {mps_op_name} rewrite option in e-graph"
|
||||
);
|
||||
assert!(
|
||||
egraph_has_op(cx, "GenericMatmul"),
|
||||
"expected GenericMatmul rewrite option in e-graph"
|
||||
);
|
||||
}
|
||||
|
||||
fn write_test_safetensors(tensors: &[(&str, Dtype, Vec<usize>, Vec<u8>)]) -> PathBuf {
|
||||
let tensor_views: HashMap<String, TensorView<'_>> = tensors
|
||||
.iter()
|
||||
@@ -272,11 +301,11 @@ fn dynamic_dim_sum_reduce_runs() {
|
||||
let input = cx.tensor(('a', 2));
|
||||
let output = input.sum(0).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -290,13 +319,14 @@ fn metal_bucketed_dynamic_dim_dispatches_correct_graph() {
|
||||
let input = cx.tensor(('s', 4));
|
||||
let output = (input + input).output();
|
||||
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
|
||||
cx.set_dim('s', 1);
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(
|
||||
CompileOptions::default().dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]),
|
||||
);
|
||||
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, vec![1.0f32; 4]);
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
let s1_input = vec![1.0, 2.0, 3.0, 4.0];
|
||||
@@ -321,10 +351,10 @@ fn metal_int_arithmetic_preserves_large_values() {
|
||||
let large_index = (token * 1024) + 123;
|
||||
let mod_output = (large_index % 65_537).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(token, &[16_385i32]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -343,11 +373,11 @@ proptest! {
|
||||
let input = cx.tensor(len);
|
||||
let output = (input + input).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
let input_values: Vec<f32> = values.into_iter().take(len).collect();
|
||||
rt.set_data(input, &input_values);
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -365,11 +395,11 @@ proptest! {
|
||||
let input = cx.tensor(len);
|
||||
let output = (input * input).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
let input_values: Vec<f32> = values.into_iter().take(len).collect();
|
||||
rt.set_data(input, &input_values);
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -387,11 +417,11 @@ proptest! {
|
||||
let input = cx.tensor(len);
|
||||
let output = input.exp2().output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
let input_values: Vec<f32> = values.into_iter().take(len).collect();
|
||||
rt.set_data(input, &input_values);
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -401,6 +431,16 @@ proptest! {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_build_search_space_accepts_memory_budget() {
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor(4);
|
||||
let b = cx.tensor(4);
|
||||
(a * b).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default().max_memory_mib(1));
|
||||
}
|
||||
|
||||
/// Simple deterministic test for add
|
||||
#[test]
|
||||
fn metal_simple_add() {
|
||||
@@ -409,11 +449,11 @@ fn metal_simple_add() {
|
||||
let b = cx.tensor(4);
|
||||
let output = (a + b).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(a, &[1.0, 2.0, 3.0, 4.0]);
|
||||
rt.set_data(b, &[5.0, 6.0, 7.0, 8.0]);
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -429,11 +469,11 @@ fn metal_simple_mul() {
|
||||
let b = cx.tensor(4);
|
||||
let output = (a * b).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(a, &[1.0, 2.0, 3.0, 4.0]);
|
||||
rt.set_data(b, &[5.0, 6.0, 7.0, 8.0]);
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -448,10 +488,10 @@ fn metal_simple_exp2() {
|
||||
let input = cx.tensor(4);
|
||||
let output = input.exp2().output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, &[0.0, 1.0, 2.0, 3.0]);
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -465,10 +505,10 @@ fn metal_simple_log2() {
|
||||
let input = cx.tensor(4);
|
||||
let output = input.log2().output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, &[1.0, 2.0, 4.0, 8.0]);
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -482,7 +522,7 @@ fn metal_simple_sin() {
|
||||
let input = cx.tensor(4);
|
||||
let output = input.sin().output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(
|
||||
input,
|
||||
@@ -493,7 +533,7 @@ fn metal_simple_sin() {
|
||||
3.0 * std::f32::consts::FRAC_PI_2,
|
||||
],
|
||||
);
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -507,10 +547,10 @@ fn metal_simple_sqrt() {
|
||||
let input = cx.tensor(4);
|
||||
let output = input.sqrt().output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, &[1.0, 4.0, 9.0, 16.0]);
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -524,10 +564,10 @@ fn metal_simple_recip() {
|
||||
let input = cx.tensor(4);
|
||||
let output = input.reciprocal().output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, &[1.0, 2.0, 4.0, 5.0]);
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -542,11 +582,11 @@ fn metal_simple_mod() {
|
||||
let b = cx.tensor(4);
|
||||
let output = (a % b).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(a, &[7.0, 10.0, 15.0, 8.5]);
|
||||
rt.set_data(b, &[3.0, 4.0, 6.0, 2.5]);
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -561,11 +601,11 @@ fn metal_simple_less_than() {
|
||||
let b = cx.tensor(4);
|
||||
let output = a.lt(b).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(a, &[1.0, 5.0, 3.0, 4.0]);
|
||||
rt.set_data(b, &[2.0, 3.0, 3.0, 5.0]);
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -581,11 +621,11 @@ fn metal_simple_sum_reduce() {
|
||||
// sum over axis 1
|
||||
let output = input.sum(1).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
// [[1,2,3,4], [5,6,7,8]] -> [10, 26]
|
||||
rt.set_data(input, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -600,11 +640,11 @@ fn metal_simple_max_reduce() {
|
||||
// max over axis 1
|
||||
let output = input.max(1).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
// [[1,4,2,3], [8,5,7,6]] -> [4, 8]
|
||||
rt.set_data(input, &[1.0, 4.0, 2.0, 3.0, 8.0, 5.0, 7.0, 6.0]);
|
||||
rt = cx.search(rt, 5);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -618,10 +658,10 @@ fn metal_f16_cast_roundtrip() {
|
||||
let input = cx.tensor(4);
|
||||
let output = input.cast(DType::F16).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, &[1.0, -2.5, 3.25, 4.75]);
|
||||
rt = cx.search(rt, 3);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -638,11 +678,11 @@ fn metal_f16_intermediate_add_roundtrip() {
|
||||
.cast(DType::F32)
|
||||
.output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(a, &[1.0, 2.0, -3.0, 4.5]);
|
||||
rt.set_data(b, &[0.5, -1.0, 3.0, 0.25]);
|
||||
rt = cx.search(rt, 3);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -657,7 +697,7 @@ fn metal_specialized_matmul() {
|
||||
let b = cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN));
|
||||
let output = a.matmul(b).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let a_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
|
||||
@@ -665,7 +705,7 @@ fn metal_specialized_matmul() {
|
||||
|
||||
rt.set_data(a, &a_data);
|
||||
rt.set_data(b, &b_data);
|
||||
rt = cx.search(rt, 1);
|
||||
rt = search_candidates(&mut cx, rt, 32);
|
||||
assert!(
|
||||
rt.contains_matmul(),
|
||||
"expected Metal runtime to fuse matmul, kernels: {:?}",
|
||||
@@ -697,7 +737,8 @@ fn metal_regular_tiled_matmul_path() {
|
||||
let b = cx.tensor((k, n));
|
||||
let output = a.matmul(b).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
assert_matmul_options(&cx, "MPSMatmul");
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let a_data = seeded_data(m * k, 0.4, -0.2);
|
||||
@@ -705,19 +746,7 @@ fn metal_regular_tiled_matmul_path() {
|
||||
|
||||
rt.set_data(a, &a_data);
|
||||
rt.set_data(b, &b_data);
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("MPSMatmul")),
|
||||
"expected MPS matmul path, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
assert!(
|
||||
!kernels.iter().any(|k| k.contains("GenericMatmul")),
|
||||
"MPS-compatible matmul should not extract the generic fallback, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
rt = search_candidates(&mut cx, rt, 32);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -743,7 +772,8 @@ fn metal_mps_matmul_transposed_rhs_weight_layout() {
|
||||
let weight = cx.tensor((n, k));
|
||||
let output = a.matmul(weight.t()).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
assert_matmul_options(&cx, "MPSMatmul");
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let a_data = seeded_data(m * k, 0.35, -0.17);
|
||||
@@ -751,14 +781,7 @@ fn metal_mps_matmul_transposed_rhs_weight_layout() {
|
||||
|
||||
rt.set_data(a, &a_data);
|
||||
rt.set_data(weight, &weight_data);
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("transpose_rhs: true")),
|
||||
"expected MPS matmul to cover transposed row-major RHS, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
rt = search_candidates(&mut cx, rt, 32);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -784,7 +807,8 @@ fn metal_mps_matmul_transposed_lhs_layout() {
|
||||
let rhs = cx.tensor((k, n));
|
||||
let output = lhs_storage.t().matmul(rhs).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
assert_matmul_options(&cx, "MPSMatmul");
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let lhs_data = seeded_data(k * m, 0.31, -0.12);
|
||||
@@ -792,14 +816,7 @@ fn metal_mps_matmul_transposed_lhs_layout() {
|
||||
|
||||
rt.set_data(lhs_storage, &lhs_data);
|
||||
rt.set_data(rhs, &rhs_data);
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("transpose_lhs: true")),
|
||||
"expected MPS matmul to cover transposed row-major LHS, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
rt = search_candidates(&mut cx, rt, 32);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -829,21 +846,15 @@ fn metal_mps_batched_matmul_row_row_layout() {
|
||||
let b = cx.tensor((batch, k, n));
|
||||
let output = a.matmul(b).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
assert_matmul_options(&cx, "MPSBatchedMatmul");
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let a_data = seeded_data(batch * m * k, 0.17, -0.08);
|
||||
let b_data = seeded_data(batch * k * n, 0.11, -0.05);
|
||||
rt.set_data(a, &a_data);
|
||||
rt.set_data(b, &b_data);
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("MPSBatchedMatmul")),
|
||||
"expected MPS batched matmul path, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
rt = search_candidates(&mut cx, rt, 32);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -879,14 +890,18 @@ fn metal_generic_matmul_covers_noncontiguous_merged_head_projection() {
|
||||
let merged = attn.transpose(0, 1).merge_dims(1, 2);
|
||||
let output = merged.matmul(weight.t()).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
assert!(
|
||||
egraph_has_op(&cx, "GenericMatmul"),
|
||||
"expected GenericMatmul rewrite option in e-graph"
|
||||
);
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let attn_data = seeded_data(heads * seq * head_dim, 0.19, -0.09);
|
||||
let weight_data = seeded_data(out_dim * hidden, 0.14, -0.06);
|
||||
rt.set_data(attn, &attn_data);
|
||||
rt.set_data(weight, &weight_data);
|
||||
rt = cx.search(rt, 1);
|
||||
rt = search_candidates(&mut cx, rt, 32);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
@@ -934,23 +949,15 @@ fn metal_mps_batched_matmul_transposed_rhs_layout() {
|
||||
let weight = cx.tensor((batch, n, k));
|
||||
let output = a.matmul(weight.permute((0, 2, 1))).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
assert_matmul_options(&cx, "MPSBatchedMatmul");
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let a_data = seeded_data(batch * m * k, 0.13, -0.06);
|
||||
let weight_data = seeded_data(batch * n * k, 0.09, -0.04);
|
||||
rt.set_data(a, &a_data);
|
||||
rt.set_data(weight, &weight_data);
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels
|
||||
.iter()
|
||||
.any(|k| k.contains("MPSBatchedMatmul") && k.contains("transpose_rhs: true")),
|
||||
"expected MPS batched matmul transposed RHS path, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
rt = search_candidates(&mut cx, rt, 32);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -983,7 +990,8 @@ fn metal_mps_matmul_f16_transposed_rhs_weight_layout() {
|
||||
let weight = cx.tensor((n, k)).as_dtype(DType::F16);
|
||||
let output = a.matmul(weight.t()).cast(DType::F32).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
assert_matmul_options(&cx, "MPSMatmul");
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let a_data = seeded_data(m * k, 0.22, -0.07);
|
||||
@@ -991,14 +999,7 @@ fn metal_mps_matmul_f16_transposed_rhs_weight_layout() {
|
||||
|
||||
rt.set_data(a, to_f16_vec(&a_data));
|
||||
rt.set_data(weight, to_f16_vec(&weight_data));
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("transpose_rhs: true")),
|
||||
"expected MPS F16 matmul to cover transposed row-major RHS, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
rt = search_candidates(&mut cx, rt, 32);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -1021,7 +1022,7 @@ fn metal_rms_norm() {
|
||||
let weight = cx.tensor(TRANSFORMER_HIDDEN);
|
||||
let output = rms_norm(input, weight, 1e-5).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
|
||||
@@ -1029,7 +1030,7 @@ fn metal_rms_norm() {
|
||||
|
||||
rt.set_data(input, &input_data);
|
||||
rt.set_data(weight, &weight_data);
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1055,7 +1056,7 @@ fn metal_self_attention() {
|
||||
let wo = cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN));
|
||||
let output = self_attention(input, wq, wk, wv, wo).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
|
||||
@@ -1069,7 +1070,7 @@ fn metal_self_attention() {
|
||||
rt.set_data(wk, &wk_data);
|
||||
rt.set_data(wv, &wv_data);
|
||||
rt.set_data(wo, &wo_data);
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1114,7 +1115,7 @@ fn metal_self_attention_f16_weights() {
|
||||
.cast(DType::F32)
|
||||
.output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
|
||||
@@ -1128,7 +1129,7 @@ fn metal_self_attention_f16_weights() {
|
||||
rt.set_data(wk, to_f16_vec(&wk_data));
|
||||
rt.set_data(wv, to_f16_vec(&wv_data));
|
||||
rt.set_data(wo, to_f16_vec(&wo_data));
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1160,7 +1161,7 @@ fn metal_swiglu_mlp() {
|
||||
let w_down = cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_INTERMEDIATE));
|
||||
let output = swiglu_mlp(input, w_gate, w_up, w_down).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
|
||||
@@ -1172,7 +1173,7 @@ fn metal_swiglu_mlp() {
|
||||
rt.set_data(w_gate, &gate_data);
|
||||
rt.set_data(w_up, &up_data);
|
||||
rt.set_data(w_down, &down_data);
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1212,7 +1213,7 @@ fn metal_mini_transformer_layer() {
|
||||
let layer = MiniTransformerLayer::init(&mut cx);
|
||||
let output = layer.forward(input).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
|
||||
@@ -1222,7 +1223,7 @@ fn metal_mini_transformer_layer() {
|
||||
for (tensor, data) in &weight_data {
|
||||
rt.set_data(*tensor, data);
|
||||
}
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1278,7 +1279,7 @@ fn metal_mini_transformer_layer_f16_intermediate() {
|
||||
.cast(DType::F32);
|
||||
let output = (x + mlp_out).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
|
||||
@@ -1288,7 +1289,7 @@ fn metal_mini_transformer_layer_f16_intermediate() {
|
||||
for (tensor, data) in &weight_data {
|
||||
rt.set_data(*tensor, data);
|
||||
}
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1325,12 +1326,12 @@ fn test_scatter_basic() {
|
||||
let dest = cx.tensor(5);
|
||||
let result = src.scatter(indexes, dest).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(src, &[10.0, 20.0, 30.0]);
|
||||
rt.set_data(indexes, &[1.0, 3.0, 4.0]);
|
||||
rt.set_data(dest, &[0.0, 0.0, 0.0, 0.0, 0.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1347,12 +1348,12 @@ fn test_scatter_buffer_roundtrip() {
|
||||
let cache_out = src.scatter(indexes, cache);
|
||||
let read = cache_out.output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(src, &[0.0]);
|
||||
rt.set_data(indexes, &[0.0]);
|
||||
rt.set_zeros(cache, 4 * std::mem::size_of::<f32>());
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
|
||||
for (pos, value, expected) in [
|
||||
(0, 10.0, [10.0, 0.0, 0.0, 0.0]),
|
||||
@@ -1381,12 +1382,12 @@ fn test_load_safetensors_f32_survives_search_and_overrides_input_data() {
|
||||
let tensors = [("weights", Dtype::F32, vec![3], bytes_of(&weight_values))];
|
||||
let path = write_test_safetensors(&tensors);
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(weights, &[99.0, 99.0, 99.0]);
|
||||
rt.set_data(bias, &[0.5, 1.0, -1.5]);
|
||||
rt.load_safetensors(&cx, path.to_str().unwrap());
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1448,10 +1449,10 @@ fn test_load_safetensors_converts_supported_float_dtypes() {
|
||||
];
|
||||
let path = write_test_safetensors(&tensors);
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.load_safetensors(&cx, path.to_str().unwrap());
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1471,14 +1472,14 @@ fn test_gather_noncontiguous_data_uses_data_shape() {
|
||||
let indexes = cx.tensor((2, 2)).as_dtype(DType::Int);
|
||||
let out = data.gather(indexes).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(
|
||||
input,
|
||||
&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
|
||||
);
|
||||
rt.set_data(indexes, &[0.0, 3.0, 4.0, 7.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1493,12 +1494,12 @@ fn test_scatter_into_nonzero_dest() {
|
||||
let dest = cx.tensor(5);
|
||||
let result = src.scatter(indexes, dest).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(src, &[99.0]);
|
||||
rt.set_data(indexes, &[2f32]);
|
||||
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0, 5.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
|
||||
@@ -1520,12 +1521,12 @@ fn test_scatter_no_copy_remove_buffer_aliases_dest() {
|
||||
let dest = cx.tensor(5);
|
||||
let result = src.scatter(indexes, dest).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(src, &[7.0, 8.0]);
|
||||
rt.set_data(indexes, &[1.0, 3.0]);
|
||||
rt.set_data(dest, &[10.0, 20.0, 30.0, 40.0, 50.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1549,12 +1550,12 @@ fn test_scatter_no_copy_handles_2d_destination() {
|
||||
let dest = cx.tensor((2, 3));
|
||||
let result = src.scatter(indexes, dest).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(src, &[9.0, 8.0]);
|
||||
rt.set_data(indexes, &[2.0, 4.0]);
|
||||
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
|
||||
@@ -1576,12 +1577,12 @@ fn test_scatter_no_copy_not_selected_when_dest_has_another_consumer() {
|
||||
let scatter = src.scatter(indexes, dest).output();
|
||||
let dest_plus_one = (dest + 1.0).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(src, &[99.0]);
|
||||
rt.set_data(indexes, &[1.0]);
|
||||
rt.set_data(dest, &[10.0, 20.0, 30.0, 40.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
!kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
|
||||
@@ -1603,12 +1604,12 @@ fn test_scatter_all_positions() {
|
||||
let dest = cx.tensor(4);
|
||||
let result = src.scatter(indexes, dest).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(src, &[40.0, 30.0, 20.0, 10.0]);
|
||||
rt.set_data(indexes, &[3.0, 2.0, 1.0, 0.0]);
|
||||
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1623,11 +1624,11 @@ fn test_gather_preserves_data_dtype() {
|
||||
let indexes = cx.tensor(1).as_dtype(DType::Int);
|
||||
let out = data.gather(indexes).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(data, &[1.25, 2.5]);
|
||||
rt.set_data(indexes, &[1.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
|
||||
@@ -166,8 +166,11 @@ mod tests {
|
||||
let indices = cx.tensor(3).as_dtype(DType::Int);
|
||||
let result = gather_rows(data, indices, 3).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
// data = [[1,2,3], [4,5,6], [7,8,9], [10,11,12]]
|
||||
rt.set_data(
|
||||
@@ -192,8 +195,11 @@ mod tests {
|
||||
let dest = cx.tensor((4, 3));
|
||||
let result = scatter_rows(src, indices, dest, 3).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
rt.set_data(src.id, vec![10., 20., 30., 40., 50., 60.]);
|
||||
rt.set_data(indices.id, vec![1, 3]);
|
||||
@@ -218,8 +224,11 @@ mod tests {
|
||||
let updated_cache = scatter_rows(kv_new, scatter_idx, cache, 4);
|
||||
let gathered = gather_rows(updated_cache, gather_idx, 4).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
rt.set_data(kv_new.id, vec![1., 2., 3., 4., 5., 6., 7., 8.]);
|
||||
rt.set_data(scatter_idx.id, vec![1, 4]); // Write to slots 1 and 4
|
||||
@@ -271,8 +280,11 @@ mod tests {
|
||||
let k_cache_new = k_cache_new.output();
|
||||
let v_cache_new = v_cache_new.output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
// Q = [1, 0, 1, 0] → head0=[1,0], head1=[1,0]
|
||||
rt.set_data(q.id, vec![1., 0., 1., 0.]);
|
||||
@@ -344,8 +356,11 @@ mod tests {
|
||||
);
|
||||
let attn_out = attn_out.output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
// Setup: 1 cached token at slot 0, 1 new token written to slot 1
|
||||
// K cached at slot 0: [1, 0]
|
||||
@@ -416,8 +431,11 @@ mod tests {
|
||||
);
|
||||
let attn_out = attn_out.output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
// Cache has 1 token at slot 0
|
||||
let mut k_cache_data = vec![0.; num_slots * kv_dim];
|
||||
|
||||
@@ -183,8 +183,11 @@ mod tests {
|
||||
};
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
let input_data = vec![1.0, 2.0, 3.0];
|
||||
// Router strongly favors expert 0
|
||||
@@ -238,8 +241,11 @@ mod tests {
|
||||
};
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
let input_data = vec![1.0, 1.0];
|
||||
// Nearly-equal routing to all experts (slight differences to avoid argsort ties)
|
||||
@@ -292,8 +298,11 @@ mod tests {
|
||||
};
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
let input_data = vec![
|
||||
1.0, 0.0, 0.0, // batch 0: routes to expert via feature 0
|
||||
@@ -349,8 +358,11 @@ mod tests {
|
||||
};
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
let input_data = random_vec(in_dim);
|
||||
let router_data = random_vec(in_dim * n_experts);
|
||||
@@ -394,8 +406,11 @@ mod tests {
|
||||
};
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
let input_data = random_vec(batch * in_dim);
|
||||
let router_data = random_vec(in_dim * n_experts);
|
||||
|
||||
@@ -8,7 +8,7 @@ echo "=========================================="
|
||||
echo " Luminal Python: Full Test Suite"
|
||||
echo "=========================================="
|
||||
|
||||
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py"
|
||||
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py tests/test_dtype_boundary.py tests/test_torch_dtype_parity.py"
|
||||
CUDA_TESTS="tests/"
|
||||
|
||||
# ── Phase 1: Native Backend ─────────────────────────────────
|
||||
|
||||
@@ -16,7 +16,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
echo "Step 3: Running pytest..."
|
||||
# it is best not to add the full model tests, they end up running billion parameter models
|
||||
# on the CPU and it takes far to long
|
||||
uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
uv run pytest tests/test_hlir_ops.py tests/test_unary.py tests/test_dtype_boundary.py tests/test_torch_dtype_parity.py -v
|
||||
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
|
||||
@@ -5,6 +5,7 @@ use luminal::{
|
||||
visualization::ToDot,
|
||||
};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyBytes;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::typed_data::TypedData;
|
||||
@@ -73,22 +74,13 @@ fn solve_single_var_dim(expr: &Expression, dim_val: usize) -> Option<(char, usiz
|
||||
Some((var, candidate))
|
||||
}
|
||||
|
||||
/// Convert luminal DType to PT2 dtype integer code (for python interop)
|
||||
/// Types without a direct Pytorch equivalent map to the closest safe representation
|
||||
/// Convert luminal `DType` to a PT2 dtype code via `TorchDType`. Panics
|
||||
/// for luminal-specific dtypes that have no PyTorch counterpart (`I4`,
|
||||
/// `U4`, the F6 / F4 families, ...).
|
||||
fn luminal_dtype_to_pt2_code(dtype: DType) -> u32 {
|
||||
match dtype {
|
||||
DType::U8 => 1,
|
||||
DType::I8 => 2,
|
||||
DType::I16 => 3,
|
||||
DType::Int => 4, // i32
|
||||
DType::U16 => 4, // u16 -> i32 (Pytorch has no u16 in older versions)
|
||||
DType::F16 => 6,
|
||||
DType::F32 | DType::TF32 => 7,
|
||||
DType::F64 => 8,
|
||||
DType::Bool => 12,
|
||||
DType::Bf16 => 13,
|
||||
_ => panic!("luminal_dtype_to_pt2_code: unsupported dtype {:?}", dtype),
|
||||
}
|
||||
crate::torch_dtype::TorchDType::try_from(dtype)
|
||||
.map(|t| t.code())
|
||||
.unwrap_or_else(|d| panic!("luminal_dtype_to_pt2_code: unsupported dtype {d:?}"))
|
||||
}
|
||||
|
||||
/// Common intermediate result from translating a model graph.
|
||||
@@ -512,6 +504,65 @@ impl CompiledGraph {
|
||||
Ok(self.runtime.get_output_i32(*node_id))
|
||||
}
|
||||
|
||||
/// Read an output as f16 (returned as raw little-endian bytes —
|
||||
/// Python has no native f16, so the caller bit-casts via
|
||||
/// `torch.frombuffer(..., dtype=torch.float16)`). Strict: the
|
||||
/// producer node must already be `DType::F16`; no widening at
|
||||
/// the read boundary.
|
||||
fn get_output_f16<'py>(&self, py: Python<'py>, name: &str) -> PyResult<Bound<'py, PyBytes>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
let data = self.runtime.get_output_f16(*node_id);
|
||||
let bytes: &[u8] =
|
||||
unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 2) };
|
||||
Ok(PyBytes::new(py, bytes))
|
||||
}
|
||||
|
||||
/// Read an output as bf16 (returned as raw little-endian bytes —
|
||||
/// caller bit-casts via `torch.frombuffer(..., dtype=torch.
|
||||
/// bfloat16)`). Strict: the producer node must already be
|
||||
/// `DType::Bf16`; no widening at the read boundary.
|
||||
fn get_output_bf16<'py>(&self, py: Python<'py>, name: &str) -> PyResult<Bound<'py, PyBytes>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
let data = self.runtime.get_output_bf16(*node_id);
|
||||
let bytes: &[u8] =
|
||||
unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 2) };
|
||||
Ok(PyBytes::new(py, bytes))
|
||||
}
|
||||
|
||||
/// Read an output as i64. Strict: the producer node must already
|
||||
/// be `DType::I64`; no widening at the read boundary.
|
||||
fn get_output_i64(&self, name: &str) -> PyResult<Vec<i64>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
Ok(self.runtime.get_output_i64(*node_id))
|
||||
}
|
||||
|
||||
/// Read an output as f64. Strict: the producer node must already
|
||||
/// be `DType::F64`; no widening at the read boundary.
|
||||
fn get_output_f64(&self, name: &str) -> PyResult<Vec<f64>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
Ok(self.runtime.get_output_f64(*node_id))
|
||||
}
|
||||
|
||||
/// Get output tensor data by name as bool (copies to host).
|
||||
fn get_output_bool(&self, name: &str) -> PyResult<Vec<bool>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
|
||||
120
crates/luminal_python/rust/src/dim_arith.rs
Normal file
120
crates/luminal_python/rust/src/dim_arith.rs
Normal file
@@ -0,0 +1,120 @@
|
||||
//! Canonical-form helpers for dimension `Expression` arithmetic — used
|
||||
//! by the translator to keep shape arithmetic syntactically consistent
|
||||
//! across code paths.
|
||||
//!
|
||||
//! `Expression` equality is syntactic; `a * 8` and `8 * a` are distinct
|
||||
//! objects despite being mathematically equal. When two translator code
|
||||
//! paths build the same logical dim via differently-ordered
|
||||
//! multiplications, downstream `assert_eq!(self.dims(), rhs.dims())`
|
||||
//! checks in `GraphTensor::Add` / `Sub` / `Mul` / `Rem` panic. These
|
||||
//! helpers solve that at the construction site: every shape product
|
||||
//! goes through `product_of_dims`, which sorts the operand list before
|
||||
//! folding, so two callers passing the operands in different orders
|
||||
//! produce identical `Expression`s.
|
||||
//!
|
||||
//! Lives in `luminal_python` (rather than upstream `luminal::shape`) so
|
||||
//! the change is contained to the translator. luminal-core callers of
|
||||
//! `gather_elements` / `scatter_elements` / `scatter_nd` historically
|
||||
//! pass concrete dims, so they don't need this; the translator-local
|
||||
//! lowerings in `translator::movement_dynamic` do.
|
||||
//!
|
||||
//! The ordering matches what `pt2_expr.rs::normalize_mul_expr` was
|
||||
//! using locally before being promoted here — see that file for the
|
||||
//! original canonical-sort logic.
|
||||
|
||||
use luminal::prelude::Expression;
|
||||
|
||||
/// Sort key for the canonical commutative ordering. Sorts by RPN-term
|
||||
/// count first so single-term operands (variables, literals) sort
|
||||
/// before compound subexpressions; ties broken by debug repr so two
|
||||
/// single-term operands have a stable alphabetic order.
|
||||
///
|
||||
/// O(n) string alloc per compare — only call on shape products, never
|
||||
/// per-element in a kernel.
|
||||
#[inline]
|
||||
pub(crate) fn commutative_key(expr: &Expression) -> (usize, String) {
|
||||
(expr.len(), format!("{expr:?}"))
|
||||
}
|
||||
|
||||
/// Order `(a, b)` so the canonically-smaller expression is first.
|
||||
#[inline]
|
||||
pub(crate) fn sort_pair(a: Expression, b: Expression) -> (Expression, Expression) {
|
||||
if commutative_key(&a) <= commutative_key(&b) {
|
||||
(a, b)
|
||||
} else {
|
||||
(b, a)
|
||||
}
|
||||
}
|
||||
|
||||
/// Multiply two dim expressions with canonical operand ordering.
|
||||
#[inline]
|
||||
pub(crate) fn mul_dims(a: Expression, b: Expression) -> Expression {
|
||||
let (a, b) = sort_pair(a, b);
|
||||
a * b
|
||||
}
|
||||
|
||||
/// Add two dim expressions with canonical operand ordering.
|
||||
#[inline]
|
||||
pub(crate) fn add_dims(a: Expression, b: Expression) -> Expression {
|
||||
let (a, b) = sort_pair(a, b);
|
||||
a + b
|
||||
}
|
||||
|
||||
/// Product of a sequence of dim expressions. Operands are sorted
|
||||
/// canonically before folding so callers passing the same logical
|
||||
/// dim set in different orders produce identical `Expression`s.
|
||||
/// Empty sequence → `Expression::from(1usize)`.
|
||||
pub(crate) fn product_of_dims<I>(dims: I) -> Expression
|
||||
where
|
||||
I: IntoIterator<Item = Expression>,
|
||||
{
|
||||
let mut v: Vec<Expression> = dims.into_iter().collect();
|
||||
v.sort_by_key(commutative_key);
|
||||
v.into_iter()
|
||||
.fold(Expression::from(1usize), |acc, d| acc * d)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn mul_dims_canonicalises_commutative_order() {
|
||||
let a = Expression::from('a');
|
||||
let n = Expression::from(8i64);
|
||||
assert_eq!(mul_dims(a, n), mul_dims(n, a));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn product_of_dims_independent_of_input_order() {
|
||||
let a = Expression::from('a');
|
||||
let b = Expression::from('b');
|
||||
let n = Expression::from(8i64);
|
||||
let p1 = product_of_dims([a, n, b]);
|
||||
let p2 = product_of_dims([n, b, a]);
|
||||
let p3 = product_of_dims([b, a, n]);
|
||||
assert_eq!(p1, p2);
|
||||
assert_eq!(p1, p3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_product_is_one() {
|
||||
let empty: Vec<Expression> = vec![];
|
||||
assert_eq!(product_of_dims(empty), Expression::from(1usize));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mixed_numeric_types_canonicalise_together() {
|
||||
// `pt2_util` builds with `Expression::from(usize)` while tests /
|
||||
// direct callers reach for `i64`. The two literal paths must
|
||||
// produce identical reprs or `product_of_dims` will sort them
|
||||
// into different positions and we lose the canonical-form
|
||||
// guarantee across call sites.
|
||||
assert_eq!(Expression::from(8usize), Expression::from(8i64));
|
||||
let a = Expression::from('a');
|
||||
assert_eq!(
|
||||
product_of_dims([Expression::from(8usize), a]),
|
||||
product_of_dims([Expression::from(8i64), a]),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,6 @@
|
||||
mod compiled_graph;
|
||||
mod dim_arith;
|
||||
pub mod torch_dtype;
|
||||
pub mod typed_data;
|
||||
|
||||
// PT2 modules
|
||||
@@ -13,17 +15,32 @@ use compiled_graph::CompiledGraph;
|
||||
use pt2_compiled_model::process_pt2;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyCapsule;
|
||||
use std::collections::HashMap;
|
||||
use torch_dtype::TorchDType;
|
||||
|
||||
#[pymodule]
|
||||
fn luminal(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(process_pt2, m)?)?;
|
||||
m.add_class::<CompiledGraph>()?;
|
||||
m.add_function(wrap_pyfunction!(_native_factory_capsule, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(_torch_dtype_codes, m)?)?;
|
||||
#[cfg(feature = "cuda")]
|
||||
m.add_function(wrap_pyfunction!(_cuda_lite_factory_capsule, m)?)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// `{variant_name: pt2_code}` for every `TorchDType` variant. The Python
|
||||
/// parity test (`tests/test_torch_dtype_parity.py`) consumes this and
|
||||
/// asserts every entry matches `torch._export.serde.schema.ScalarType.<name>
|
||||
/// .value` — drift fails CI rather than silently miscompiling at runtime.
|
||||
#[pyfunction]
|
||||
fn _torch_dtype_codes() -> HashMap<&'static str, u32> {
|
||||
TorchDType::ALL
|
||||
.iter()
|
||||
.map(|v| (v.name(), v.code()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Factory capsule helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -7,10 +7,10 @@ use std::collections::HashMap;
|
||||
|
||||
use crate::compiled_graph::{CompiledGraph, DimParamMap, GraphTranslation, WeightData};
|
||||
use crate::pt2_expr::parse_sympy_expr;
|
||||
use crate::pt2_parser;
|
||||
use crate::pt2_schema;
|
||||
use crate::translator;
|
||||
use crate::typed_data::TypedData;
|
||||
use crate::{pt2_parser, pt2_util};
|
||||
|
||||
/// Pre-loaded weight/constant data paired with tensor sizes.
|
||||
type PreloadResult = (Vec<(String, TypedData)>, HashMap<String, usize>);
|
||||
@@ -374,52 +374,10 @@ fn safetensors_dtype_to_pt2(dtype: safetensors::Dtype) -> u32 {
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert raw bytes to TypedData using PT2 dtype numbering.
|
||||
/// Preserves native byte format for types luminal supports directly (f32, f16, bf16, i32, bool, u8, i8).
|
||||
/// Converts i64/f64/i16 to the closest luminal-native representation.
|
||||
/// Convert raw bytes to `TypedData` using PT2 dtype numbering. Thin
|
||||
/// wrapper around `TypedData::from_pytorch_bytes` — the dtype dispatch
|
||||
/// (including the narrow-int panic and unknown-code rejection) lives
|
||||
/// there, so this site stays a one-liner that just clones the slice.
|
||||
fn bytes_to_typed(bytes: &[u8], dtype: u32) -> TypedData {
|
||||
match dtype {
|
||||
// Types that map directly — preserve raw bytes
|
||||
7 => TypedData::from_raw(bytes.to_vec(), DType::F32),
|
||||
6 => TypedData::from_raw(bytes.to_vec(), DType::F16),
|
||||
13 => TypedData::from_raw(bytes.to_vec(), DType::Bf16),
|
||||
4 => TypedData::from_raw(bytes.to_vec(), DType::Int), // i32
|
||||
1 => TypedData::from_raw(bytes.to_vec(), DType::U8),
|
||||
2 => TypedData::from_raw(bytes.to_vec(), DType::I8),
|
||||
12 => TypedData::from_raw(bytes.to_vec(), DType::Bool),
|
||||
|
||||
// i64 → i32 (truncate, matching luminal's Int type)
|
||||
5 => {
|
||||
let i32s: Vec<i32> = bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as i32
|
||||
})
|
||||
.collect();
|
||||
TypedData::from_i32_vec(i32s)
|
||||
}
|
||||
// f64 → f32 (downcast, luminal has no F64 in practice for most ops)
|
||||
8 => {
|
||||
let f32s: Vec<f32> = bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
|
||||
})
|
||||
.collect();
|
||||
TypedData::from_f32_vec(f32s)
|
||||
}
|
||||
// i16 → i32 (widen to luminal's Int)
|
||||
3 => {
|
||||
let i32s: Vec<i32> = bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| i16::from_le_bytes([b[0], b[1]]) as i32)
|
||||
.collect();
|
||||
TypedData::from_i32_vec(i32s)
|
||||
}
|
||||
_ => {
|
||||
let luminal_dtype = pt2_util::torch_dtype_int_to_luminal(dtype);
|
||||
warn!("Unrecognized dtype {dtype}, interpreting as {luminal_dtype:?}");
|
||||
TypedData::from_raw(bytes.to_vec(), luminal_dtype)
|
||||
}
|
||||
}
|
||||
TypedData::from_pytorch_bytes(bytes.to_vec(), dtype)
|
||||
}
|
||||
|
||||
@@ -251,26 +251,12 @@ fn normalize_expr(expr: Expression) -> Expression {
|
||||
}
|
||||
}
|
||||
|
||||
fn commutative_key(expr: Expression) -> (usize, String) {
|
||||
(expr.len(), format!("{expr:?}"))
|
||||
}
|
||||
|
||||
fn sort_commutative(lhs: Expression, rhs: Expression) -> (Expression, Expression) {
|
||||
if commutative_key(lhs) <= commutative_key(rhs) {
|
||||
(lhs, rhs)
|
||||
} else {
|
||||
(rhs, lhs)
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_add_expr(lhs: Expression, rhs: Expression) -> Expression {
|
||||
let (lhs, rhs) = sort_commutative(lhs, rhs);
|
||||
normalize_expr(lhs + rhs)
|
||||
normalize_expr(crate::dim_arith::add_dims(lhs, rhs))
|
||||
}
|
||||
|
||||
fn normalize_mul_expr(lhs: Expression, rhs: Expression) -> Expression {
|
||||
let (lhs, rhs) = sort_commutative(lhs, rhs);
|
||||
normalize_expr(lhs * rhs)
|
||||
normalize_expr(crate::dim_arith::mul_dims(lhs, rhs))
|
||||
}
|
||||
|
||||
fn checked_add_opt(lhs: Option<i64>, rhs: Option<i64>) -> Option<i64> {
|
||||
|
||||
@@ -114,29 +114,17 @@ pub fn resolve_neg1_dim(target: &[i64], current_dims: &[Expression]) -> Vec<Expr
|
||||
}
|
||||
|
||||
if let Some(idx) = neg1_idx {
|
||||
let mut total = Expression::from(1usize);
|
||||
for d in current_dims {
|
||||
total *= *d;
|
||||
}
|
||||
if let (Some(total_val), Some(_)) = (
|
||||
{
|
||||
let mut t = 1i64;
|
||||
let mut all_concrete = true;
|
||||
for d in current_dims {
|
||||
if let Some(v) = d.to_usize() {
|
||||
t *= v as i64;
|
||||
} else {
|
||||
all_concrete = false;
|
||||
}
|
||||
}
|
||||
if all_concrete { Some(t) } else { None }
|
||||
},
|
||||
Some(known_product),
|
||||
) {
|
||||
result[idx] = Expression::from((total_val / known_product) as usize);
|
||||
} else {
|
||||
result[idx] = total / Expression::from(known_product as usize);
|
||||
}
|
||||
result[idx] = match current_dims
|
||||
.iter()
|
||||
.map(|d| d.to_usize())
|
||||
.collect::<Option<Vec<_>>>()
|
||||
{
|
||||
Some(vs) => Expression::from(vs.iter().product::<usize>() / known_product as usize),
|
||||
None => {
|
||||
crate::dim_arith::product_of_dims(current_dims.iter().copied())
|
||||
/ Expression::from(known_product as usize)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
result
|
||||
@@ -185,11 +173,12 @@ pub fn resolve_neg1_dim_exprs(
|
||||
if input_symbolic.is_empty() {
|
||||
result[idx] = Expression::from((input_concrete / target_concrete) as usize);
|
||||
} else {
|
||||
let mut expr = Expression::from((input_concrete / target_concrete) as usize);
|
||||
for s in &input_symbolic {
|
||||
expr *= *s;
|
||||
}
|
||||
result[idx] = expr;
|
||||
let mut operands: Vec<Expression> = Vec::with_capacity(input_symbolic.len() + 1);
|
||||
operands.push(Expression::from(
|
||||
(input_concrete / target_concrete) as usize,
|
||||
));
|
||||
operands.extend(input_symbolic.iter().copied());
|
||||
result[idx] = crate::dim_arith::product_of_dims(operands);
|
||||
}
|
||||
|
||||
result
|
||||
@@ -198,16 +187,29 @@ pub fn resolve_neg1_dim_exprs(
|
||||
}
|
||||
}
|
||||
|
||||
/// Map torch dtype integer (PT2 format) to luminal DType.
|
||||
/// PT2 numbering: 1=uint8, 2=int8, 3=int16, 4=int32, 5=int64, 6=float16, 7=float32, 8=float64, 12=bool, 13=bfloat16
|
||||
/// Map a PT2 dtype code to luminal `DType`. Panics for variants the IR
|
||||
/// doesn't model as first-class types (narrow ints `Byte` / `Char` /
|
||||
/// `Short`, the complex family, the float8 family) and for unknown
|
||||
/// codes — better to fail loudly at the translator boundary than to
|
||||
/// silently widen and lie about the user's dtype.
|
||||
pub fn torch_dtype_int_to_luminal(dtype: u32) -> DType {
|
||||
match dtype {
|
||||
6 => DType::F16,
|
||||
7 => DType::F32,
|
||||
8 => DType::F32, // float64 → F32 (no F64 in luminal)
|
||||
13 => DType::Bf16,
|
||||
12 => DType::Bool,
|
||||
1..=5 => DType::Int, // uint8, int8, int16, int32, int64
|
||||
_ => DType::F32,
|
||||
let t = crate::torch_dtype::TorchDType::from_code(dtype)
|
||||
.unwrap_or_else(|c| panic!("torch_dtype_int_to_luminal: unknown PT2 dtype code {c}"));
|
||||
match t {
|
||||
crate::torch_dtype::TorchDType::Byte
|
||||
| crate::torch_dtype::TorchDType::Char
|
||||
| crate::torch_dtype::TorchDType::Short => panic!(
|
||||
"torch_dtype_int_to_luminal: PT2 dtype {} (code {}) isn't a first-class \
|
||||
IR type yet — cast to torch.int32 at the call site, or wait for the \
|
||||
narrower-int IR follow-up.",
|
||||
t.name(),
|
||||
t.code(),
|
||||
),
|
||||
other => DType::try_from(other).unwrap_or_else(|t| {
|
||||
panic!(
|
||||
"torch_dtype_int_to_luminal: {} isn't a first-class luminal IR type",
|
||||
t.name()
|
||||
)
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
235
crates/luminal_python/rust/src/torch_dtype.rs
Normal file
235
crates/luminal_python/rust/src/torch_dtype.rs
Normal file
@@ -0,0 +1,235 @@
|
||||
//! Typed mirror of PyTorch's PT2 export-schema `ScalarType` enum.
|
||||
//!
|
||||
//! The PT2 export pipeline wire-serializes tensor dtypes as `u32` codes drawn
|
||||
//! from `torch._export.serde.schema.ScalarType` (an `IntEnum` on the Python
|
||||
//! side). Three sites in this crate used to carry duplicate raw-`u32` match
|
||||
//! arms with the canonical numbering hand-rolled in each — silent miscompile
|
||||
//! risk when PyTorch renumbers or adds a code. This module collapses those
|
||||
//! sites onto one typed enum and pins the numbering with a parity test that
|
||||
//! asserts every Rust variant matches `torch._export.serde.schema.ScalarType`
|
||||
//! at CI time (see `crates/luminal_python/tests/test_torch_dtype_parity.py`).
|
||||
//!
|
||||
//! Note: PyTorch's C++ `c10::ScalarType` uses a different numbering than the
|
||||
//! PT2 schema (PT2 reserves 0 for `Unknown`); we bind to the **PT2 schema**,
|
||||
//! not the c10 header, because that is what flows over our wire.
|
||||
|
||||
use luminal::prelude::DType;
|
||||
|
||||
/// PT2 export-schema dtype code. Discriminants match
|
||||
/// `torch._export.serde.schema.ScalarType` variant values exactly; drift is
|
||||
/// caught by `tests/test_torch_dtype_parity.py`.
|
||||
#[repr(u32)]
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub enum TorchDType {
|
||||
Unknown = 0,
|
||||
Byte = 1,
|
||||
Char = 2,
|
||||
Short = 3,
|
||||
Int = 4,
|
||||
Long = 5,
|
||||
Half = 6,
|
||||
Float = 7,
|
||||
Double = 8,
|
||||
ComplexHalf = 9,
|
||||
ComplexFloat = 10,
|
||||
ComplexDouble = 11,
|
||||
Bool = 12,
|
||||
BFloat16 = 13,
|
||||
Uint16 = 28,
|
||||
Float8E4m3Fn = 29,
|
||||
Float8E5m2 = 30,
|
||||
Float8E4m3Fnuz = 31,
|
||||
Float8E5m2Fnuz = 32,
|
||||
}
|
||||
|
||||
impl TorchDType {
|
||||
/// All variants, in declaration order. Used by the pyo3-exported parity
|
||||
/// table and by tests; add new variants here when PyTorch adds them.
|
||||
pub const ALL: &'static [TorchDType] = &[
|
||||
TorchDType::Unknown,
|
||||
TorchDType::Byte,
|
||||
TorchDType::Char,
|
||||
TorchDType::Short,
|
||||
TorchDType::Int,
|
||||
TorchDType::Long,
|
||||
TorchDType::Half,
|
||||
TorchDType::Float,
|
||||
TorchDType::Double,
|
||||
TorchDType::ComplexHalf,
|
||||
TorchDType::ComplexFloat,
|
||||
TorchDType::ComplexDouble,
|
||||
TorchDType::Bool,
|
||||
TorchDType::BFloat16,
|
||||
TorchDType::Uint16,
|
||||
TorchDType::Float8E4m3Fn,
|
||||
TorchDType::Float8E5m2,
|
||||
TorchDType::Float8E4m3Fnuz,
|
||||
TorchDType::Float8E5m2Fnuz,
|
||||
];
|
||||
|
||||
/// Canonical wire code (matches `ScalarType.<name>.value` in Python).
|
||||
#[inline]
|
||||
pub fn code(self) -> u32 {
|
||||
self as u32
|
||||
}
|
||||
|
||||
/// PyTorch schema variant name (e.g. `"LONG"`, `"BFLOAT16"`). Used by the
|
||||
/// parity test to align Rust variants with `ScalarType.<name>`.
|
||||
pub fn name(self) -> &'static str {
|
||||
match self {
|
||||
TorchDType::Unknown => "UNKNOWN",
|
||||
TorchDType::Byte => "BYTE",
|
||||
TorchDType::Char => "CHAR",
|
||||
TorchDType::Short => "SHORT",
|
||||
TorchDType::Int => "INT",
|
||||
TorchDType::Long => "LONG",
|
||||
TorchDType::Half => "HALF",
|
||||
TorchDType::Float => "FLOAT",
|
||||
TorchDType::Double => "DOUBLE",
|
||||
TorchDType::ComplexHalf => "COMPLEXHALF",
|
||||
TorchDType::ComplexFloat => "COMPLEXFLOAT",
|
||||
TorchDType::ComplexDouble => "COMPLEXDOUBLE",
|
||||
TorchDType::Bool => "BOOL",
|
||||
TorchDType::BFloat16 => "BFLOAT16",
|
||||
TorchDType::Uint16 => "UINT16",
|
||||
TorchDType::Float8E4m3Fn => "FLOAT8E4M3FN",
|
||||
TorchDType::Float8E5m2 => "FLOAT8E5M2",
|
||||
TorchDType::Float8E4m3Fnuz => "FLOAT8E4M3FNUZ",
|
||||
TorchDType::Float8E5m2Fnuz => "FLOAT8E5M2FNUZ",
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse from a wire code. `Err(code)` if the code isn't a known PyTorch
|
||||
/// variant — the caller decides whether to panic with context or fall
|
||||
/// through to a non-PT2 path.
|
||||
pub fn from_code(code: u32) -> Result<Self, u32> {
|
||||
for v in Self::ALL {
|
||||
if v.code() == code {
|
||||
return Ok(*v);
|
||||
}
|
||||
}
|
||||
Err(code)
|
||||
}
|
||||
}
|
||||
|
||||
/// PyTorch dtype → luminal `DType`. `Err(self)` for variants luminal's IR
|
||||
/// doesn't model as first-class types — the narrow ints (`Byte` / `Char` /
|
||||
/// `Short`), the complex family, and the float8 NUZ variants. `DType::U8`,
|
||||
/// `DType::I8`, `DType::I16` exist on the luminal side but the IR has no
|
||||
/// kernels / codegen for them, so we refuse the conversion here rather
|
||||
/// than silently producing a buffer the kernels can't actually run.
|
||||
/// Boundary code panics with the variant name on `Err`; cf.
|
||||
/// `typed_data::from_pytorch_bytes`, `pt2_util::torch_dtype_int_to_luminal`.
|
||||
impl TryFrom<TorchDType> for DType {
|
||||
type Error = TorchDType;
|
||||
fn try_from(t: TorchDType) -> Result<Self, Self::Error> {
|
||||
Ok(match t {
|
||||
TorchDType::Int => DType::Int,
|
||||
TorchDType::Long => DType::I64,
|
||||
TorchDType::Half => DType::F16,
|
||||
TorchDType::Float => DType::F32,
|
||||
TorchDType::Double => DType::F64,
|
||||
TorchDType::Bool => DType::Bool,
|
||||
TorchDType::BFloat16 => DType::Bf16,
|
||||
TorchDType::Float8E4m3Fn => DType::F8E4M3,
|
||||
TorchDType::Float8E5m2 => DType::F8E5M2,
|
||||
TorchDType::Byte
|
||||
| TorchDType::Char
|
||||
| TorchDType::Short
|
||||
| TorchDType::Uint16
|
||||
| TorchDType::Unknown
|
||||
| TorchDType::ComplexHalf
|
||||
| TorchDType::ComplexFloat
|
||||
| TorchDType::ComplexDouble
|
||||
| TorchDType::Float8E4m3Fnuz
|
||||
| TorchDType::Float8E5m2Fnuz => return Err(t),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// luminal `DType` → PyTorch dtype. `Err(dtype)` for luminal-specific
|
||||
/// variants without a first-class PyTorch counterpart — the narrow ints
|
||||
/// (`U8` / `I8` / `I16` / `U16`), the sub-byte / exotic widths (`I4`,
|
||||
/// `U4`, `F6E2M3`, ...), and `TF32`.
|
||||
///
|
||||
/// `TF32` is a compute-mode hint inside luminal, not a storage dtype on
|
||||
/// the PyTorch side (PyTorch has no `torch.tf32`); silently mapping it to
|
||||
/// `Float` would hand PyTorch an f32 buffer that the caller had been
|
||||
/// tracking as TF32 inside luminal. Refuse instead — a real cast to
|
||||
/// `DType::F32` upstream is the explicit way to bridge.
|
||||
impl TryFrom<DType> for TorchDType {
|
||||
type Error = DType;
|
||||
fn try_from(d: DType) -> Result<Self, Self::Error> {
|
||||
Ok(match d {
|
||||
DType::F32 => TorchDType::Float,
|
||||
DType::F64 => TorchDType::Double,
|
||||
DType::F16 => TorchDType::Half,
|
||||
DType::Bf16 => TorchDType::BFloat16,
|
||||
DType::Int => TorchDType::Int,
|
||||
DType::I64 => TorchDType::Long,
|
||||
DType::Bool => TorchDType::Bool,
|
||||
DType::F8E4M3 => TorchDType::Float8E4m3Fn,
|
||||
DType::F8E5M2 => TorchDType::Float8E5m2,
|
||||
_ => return Err(d),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn roundtrip_codes() {
|
||||
for v in TorchDType::ALL {
|
||||
assert_eq!(TorchDType::from_code(v.code()).unwrap(), *v);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn supported_dtypes_roundtrip() {
|
||||
// Only the variants luminal's IR models as first-class can
|
||||
// roundtrip cleanly. Narrow ints (`U8` / `I8` / `I16` / `U16`)
|
||||
// are intentionally excluded — see the `TryFrom` impls.
|
||||
for d in [
|
||||
DType::F32,
|
||||
DType::F64,
|
||||
DType::F16,
|
||||
DType::Bf16,
|
||||
DType::Int,
|
||||
DType::I64,
|
||||
DType::Bool,
|
||||
] {
|
||||
let t = TorchDType::try_from(d).expect("known DType");
|
||||
let back = DType::try_from(t).expect("known TorchDType");
|
||||
assert_eq!(d, back, "roundtrip mismatch for {d:?}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn narrow_ints_refuse_conversion() {
|
||||
// Forward (PyTorch → luminal) and reverse (luminal → PyTorch)
|
||||
// both refuse the narrow-int variants; downstream sites translate
|
||||
// the `Err` into a typed panic with the variant name.
|
||||
for t in [TorchDType::Byte, TorchDType::Char, TorchDType::Short] {
|
||||
assert!(DType::try_from(t).is_err(), "expected Err for {t:?}");
|
||||
}
|
||||
for d in [
|
||||
DType::U8,
|
||||
DType::I8,
|
||||
DType::I16,
|
||||
DType::U16,
|
||||
// TF32 is a luminal-internal compute-mode hint, not a PyTorch
|
||||
// storage dtype — refuse to silently alias it as `Float`.
|
||||
DType::TF32,
|
||||
] {
|
||||
assert!(TorchDType::try_from(d).is_err(), "expected Err for {d:?}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unknown_code_errors() {
|
||||
assert!(TorchDType::from_code(99).is_err());
|
||||
assert!(TorchDType::from_code(14).is_err()); // gap in PT2 numbering
|
||||
}
|
||||
}
|
||||
@@ -175,7 +175,11 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.pow.Tensor_Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let exp = self.get_float_arg(node, 1)?;
|
||||
a.pow(exp as f32)
|
||||
if (exp - 2.0).abs() < f64::EPSILON {
|
||||
a * a
|
||||
} else {
|
||||
a.pow(exp as f32)
|
||||
}
|
||||
}
|
||||
"torch.ops.aten.pow.Tensor_Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
|
||||
@@ -7,6 +7,7 @@ mod binary;
|
||||
mod conv;
|
||||
mod dispatch;
|
||||
mod movement;
|
||||
mod movement_dynamic;
|
||||
mod reduction;
|
||||
mod tensor;
|
||||
mod unary;
|
||||
|
||||
@@ -306,7 +306,11 @@ impl<'a> Translator<'a> {
|
||||
let mut target: Vec<Expression> = src_dims.to_vec();
|
||||
target[first_non_none_dim] = idx_dim_size;
|
||||
expanded.shape.expand(target);
|
||||
return Ok(source.gather_elements(expanded, first_non_none_dim));
|
||||
return Ok(super::movement_dynamic::pt2_gather_elements(
|
||||
source,
|
||||
expanded,
|
||||
first_non_none_dim,
|
||||
));
|
||||
}
|
||||
} else {
|
||||
bail!(
|
||||
@@ -426,7 +430,7 @@ impl<'a> Translator<'a> {
|
||||
let is_negative = indices_int.lt(zero).cast(DType::Int);
|
||||
let normalized = indices_int + is_negative * axis_dim;
|
||||
|
||||
let result = a.gather_elements(normalized, dim);
|
||||
let result = super::movement_dynamic::pt2_gather_elements(a, normalized, dim);
|
||||
Ok(if promoted_rank0 {
|
||||
result.squeeze(0)
|
||||
} else {
|
||||
@@ -440,7 +444,12 @@ impl<'a> Translator<'a> {
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let indices = self.get_input_tensor(node, 2)?;
|
||||
let src = self.get_input_tensor(node, 3)?;
|
||||
Ok(a.scatter_elements(indices.cast(DType::Int), src, dim))
|
||||
Ok(super::movement_dynamic::pt2_scatter_elements(
|
||||
a,
|
||||
indices.cast(DType::Int),
|
||||
src,
|
||||
dim,
|
||||
))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_scatter_value(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
@@ -463,7 +472,12 @@ impl<'a> Translator<'a> {
|
||||
bail!("scatter.value: unsupported scalar argument {:?}", value_arg);
|
||||
}
|
||||
.expand_rhs(indices.shape);
|
||||
Ok(a.scatter_elements(indices.cast(DType::Int), value, dim))
|
||||
Ok(super::movement_dynamic::pt2_scatter_elements(
|
||||
a,
|
||||
indices.cast(DType::Int),
|
||||
value,
|
||||
dim,
|
||||
))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_index_put(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
@@ -508,7 +522,7 @@ impl<'a> Translator<'a> {
|
||||
let indices = idx_tensor.cast(DType::Int);
|
||||
let new_last = indices.shape.len();
|
||||
let indices = indices.expand_dim(new_last, Expression::from(1usize));
|
||||
Ok(a.scatter_nd(indices, values))
|
||||
Ok(super::movement_dynamic::pt2_scatter_nd(a, indices, values))
|
||||
} else {
|
||||
bail!("index_put with multiple index tensors not yet supported");
|
||||
}
|
||||
|
||||
231
crates/luminal_python/rust/src/translator/movement_dynamic.rs
Normal file
231
crates/luminal_python/rust/src/translator/movement_dynamic.rs
Normal file
@@ -0,0 +1,231 @@
|
||||
//! Symbolic-dim-safe `gather_elements` / `scatter_elements` / `scatter_nd`
|
||||
//! lowerings for the PT2 translator.
|
||||
//!
|
||||
//! The luminal-core versions in `luminal::frontend::movement` require
|
||||
//! concrete shape dims — they call `d.to_usize().expect(...)` on every
|
||||
//! input dim and panic at translate-time when `torch.compile` hands us a
|
||||
//! batch dim, sequence-length dim, or any other dynamic dim. PT2's whole
|
||||
//! point is dynamic shapes, so we re-implement the same three ops here
|
||||
//! using `Expression`-typed shape arithmetic and only call luminal-core
|
||||
//! primitives that already accept `Expression`s (`Graph::constant`,
|
||||
//! `Graph::iota`, `flatten_strides`, `ShapeTracker::new(Vec<Expression>)`,
|
||||
//! `expand_dim`, `expand_rhs`, `flatten`, `slice_along`, `squeeze`,
|
||||
//! `cast`, `scatter`, `gather`).
|
||||
//!
|
||||
//! Every shape product flows through `crate::dim_arith::product_of_dims`
|
||||
//! so the `Expression`s we build are canonical: two callers that produce
|
||||
//! the same logical dim via differently-ordered multiplications end up
|
||||
//! with byte-identical `Expression`s. Without this, downstream dim-equality
|
||||
//! asserts in luminal-core's `Add` / `Sub` (see `src/frontend/binary.rs`)
|
||||
//! panic on `a*8` ≠ `8*a` after these helpers feed into broadcast paths.
|
||||
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::dim_arith::product_of_dims;
|
||||
|
||||
/// Row-major strides as `Expression`s. `stride[i] = prod(dims[i+1..])`.
|
||||
fn row_major_strides(dims: &[Expression]) -> Vec<Expression> {
|
||||
let rank = dims.len();
|
||||
(0..rank)
|
||||
.map(|i| product_of_dims(dims[i + 1..].iter().copied()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Build the additive non-axis contribution to a flat index over a
|
||||
/// rank-`rank` output of shape `out_shape`. The axis dim contributes
|
||||
/// 0; every other dim `d` contributes `iota_d * strides[d]`. Materialised
|
||||
/// via one `Graph::iota` call with `flatten_strides(out_shape, axis_exprs)`
|
||||
/// — same pattern luminal core uses, just with `Expression` throughout.
|
||||
fn non_axis_flat(
|
||||
graph: &mut Graph,
|
||||
out_shape: &[Expression],
|
||||
strides: &[Expression],
|
||||
axis: usize,
|
||||
) -> GraphTensor {
|
||||
let rank = out_shape.len();
|
||||
let axis_exprs: Vec<Expression> = (0..rank)
|
||||
.map(|d| {
|
||||
if d == axis {
|
||||
Expression::from(0)
|
||||
} else {
|
||||
Expression::from('z') * strides[d]
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
graph.iota(flatten_strides(out_shape, &axis_exprs), out_shape.to_vec())
|
||||
}
|
||||
|
||||
/// Wrap negative axis indices into `[0, axis_dim)`. Equivalent to
|
||||
/// `if idx < 0 { idx + axis_dim } else { idx }` in tensor form.
|
||||
fn normalize_negative_index(indices: GraphTensor, axis_dim: Expression) -> GraphTensor {
|
||||
let idx_f32 = indices.cast(DType::F32);
|
||||
let zero = idx_f32
|
||||
.graph()
|
||||
.constant_float(0.0)
|
||||
.expand_rhs(idx_f32.shape);
|
||||
let adj = idx_f32
|
||||
.graph()
|
||||
.constant(axis_dim)
|
||||
.cast(DType::F32)
|
||||
.expand_rhs(idx_f32.shape);
|
||||
let is_neg = idx_f32.lt(zero).cast(DType::F32);
|
||||
(idx_f32 + (is_neg * adj)).cast(DType::Int)
|
||||
}
|
||||
|
||||
/// Translator-local `gather_elements` that accepts symbolic shape dims.
|
||||
/// Mirrors `GraphTensor::gather_elements` semantics but uses
|
||||
/// `Expression`-typed shape arithmetic and only calls symbol-safe
|
||||
/// luminal-core primitives.
|
||||
///
|
||||
/// `output[i0,..,ik] = self[i0,..,i_{axis-1}, indices[i0,..,ik], i_{axis+1},..,ik]`
|
||||
pub fn pt2_gather_elements(data: GraphTensor, indexes: GraphTensor, axis: usize) -> GraphTensor {
|
||||
let dims = data.dims();
|
||||
let out_shape: Vec<Expression> = indexes.dims();
|
||||
let strides = row_major_strides(&dims);
|
||||
|
||||
let idx_normalized = normalize_negative_index(indexes, dims[axis]);
|
||||
let non_axis_flat = non_axis_flat(data.graph(), &out_shape, &strides, axis);
|
||||
|
||||
let stride_tensor = data
|
||||
.graph()
|
||||
.constant(strides[axis])
|
||||
.expand_rhs(idx_normalized.shape);
|
||||
let flat_idx = non_axis_flat + idx_normalized * stride_tensor;
|
||||
|
||||
data.gather(flat_idx)
|
||||
}
|
||||
|
||||
/// Translator-local `scatter_elements` that accepts symbolic shape dims.
|
||||
/// Same semantics as `GraphTensor::scatter_elements`.
|
||||
pub fn pt2_scatter_elements(
|
||||
data: GraphTensor,
|
||||
indices: GraphTensor,
|
||||
updates: GraphTensor,
|
||||
axis: usize,
|
||||
) -> GraphTensor {
|
||||
let data_dims = data.dims();
|
||||
let idx_shape: Vec<Expression> = indices.dims();
|
||||
let strides = row_major_strides(&data_dims);
|
||||
|
||||
let idx_normalized = normalize_negative_index(indices, data_dims[axis]);
|
||||
let non_axis_flat = non_axis_flat(data.graph(), &idx_shape, &strides, axis);
|
||||
|
||||
let stride_tensor = data
|
||||
.graph()
|
||||
.constant(strides[axis])
|
||||
.expand_rhs(idx_normalized.shape);
|
||||
let flat_dest = non_axis_flat + idx_normalized * stride_tensor;
|
||||
|
||||
let flat_dest_1d = flat_dest.flatten();
|
||||
let flat_updates = updates.flatten();
|
||||
let flat_data = data.flatten();
|
||||
|
||||
let output_flat = flat_updates.scatter(flat_dest_1d, flat_data);
|
||||
|
||||
// View-only reshape back to data shape; the buffer is already laid
|
||||
// out row-major from the scatter, so swapping the tracker is safe.
|
||||
let mut result = output_flat;
|
||||
result.shape = ShapeTracker::new(data_dims);
|
||||
result
|
||||
}
|
||||
|
||||
/// Translator-local `scatter_nd` that accepts symbolic shape dims.
|
||||
/// Mirrors `GraphTensor::scatter_nd` semantics.
|
||||
pub fn pt2_scatter_nd(
|
||||
data: GraphTensor,
|
||||
indices: GraphTensor,
|
||||
updates: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let indices = indices.cast(DType::Int);
|
||||
let data_dims = data.dims();
|
||||
let data_rank = data_dims.len();
|
||||
let idx_dims = indices.dims();
|
||||
let idx_rank = idx_dims.len();
|
||||
|
||||
// The last dim of indices is the index width K — it must be
|
||||
// concrete at translate-time because it controls how many
|
||||
// contribution terms we build statically. HuggingFace's MoE
|
||||
// accumulator (the path that brought us here via `index_put`)
|
||||
// always passes a literal; non-HF callers with a SymInt K would
|
||||
// need a different lowering.
|
||||
let k = idx_dims[idx_rank - 1]
|
||||
.to_usize()
|
||||
.expect("scatter_nd: indices innermost dim (K) must be concrete");
|
||||
assert!(k <= data_rank, "scatter_nd: K must be <= data rank");
|
||||
|
||||
// Batch shape = indices shape without last dim.
|
||||
let batch_shape: Vec<Expression> = idx_dims[..idx_rank - 1].to_vec();
|
||||
let batch_numel = product_of_dims(batch_shape.iter().copied());
|
||||
|
||||
// Trailing shape = data_shape[K..]
|
||||
let trailing_shape: Vec<Expression> = data_dims[k..].to_vec();
|
||||
let trailing_numel = product_of_dims(trailing_shape.iter().copied());
|
||||
|
||||
let data_strides = row_major_strides(&data_dims);
|
||||
|
||||
// Flatten batch dims of indices to [batch_numel, K] via view reshape.
|
||||
let mut indices_flat = indices;
|
||||
if idx_rank > 2 {
|
||||
indices_flat.shape = ShapeTracker::new(vec![batch_numel, Expression::from(k)]);
|
||||
}
|
||||
|
||||
let mut flat_base: Option<GraphTensor> = None;
|
||||
for (k_dim, stride) in data_strides.iter().copied().enumerate().take(k) {
|
||||
let idx_k = indices_flat.slice_along(k_dim..k_dim + 1, indices_flat.dims().len() - 1);
|
||||
let idx_k = idx_k.squeeze(idx_k.dims().len() - 1);
|
||||
|
||||
let stride_tensor = data.graph().constant(stride).expand_rhs(idx_k.shape);
|
||||
let contribution = idx_k * stride_tensor;
|
||||
|
||||
flat_base = Some(match flat_base {
|
||||
Some(fb) => fb + contribution,
|
||||
None => contribution,
|
||||
});
|
||||
}
|
||||
let flat_base = flat_base.unwrap();
|
||||
|
||||
// Trailing-numel concreteness drives whether we need the expand-and-fold
|
||||
// path. If trailing_shape is empty OR its numel collapses to 1, the flat
|
||||
// base is already the full destination index.
|
||||
let trailing_is_unit = trailing_shape.is_empty() || trailing_numel.to_usize() == Some(1);
|
||||
let mut full_flat_dest = if trailing_is_unit {
|
||||
flat_base
|
||||
} else {
|
||||
let mut base_expanded = flat_base.expand_dim(1, trailing_numel);
|
||||
|
||||
let trailing_rank = trailing_shape.len();
|
||||
for (ti, d) in (k..data_rank).enumerate() {
|
||||
let ar = data.graph().arange(data_dims[d]);
|
||||
let mut ar_shaped = ar;
|
||||
for _ in ti + 1..trailing_rank {
|
||||
let n = ar_shaped.dims().len();
|
||||
ar_shaped = ar_shaped.expand_dim(n, 1);
|
||||
}
|
||||
for _ in 0..ti {
|
||||
ar_shaped = ar_shaped.expand_dim(0, 1);
|
||||
}
|
||||
ar_shaped.shape.expand(trailing_shape.clone());
|
||||
let mut ar_flat = ar_shaped;
|
||||
ar_flat.shape = ShapeTracker::new(vec![trailing_numel]);
|
||||
ar_flat = ar_flat.expand_dim(0, batch_numel);
|
||||
|
||||
let stride_tensor = data
|
||||
.graph()
|
||||
.constant(data_strides[d])
|
||||
.expand_rhs(ar_flat.shape);
|
||||
base_expanded += ar_flat * stride_tensor;
|
||||
}
|
||||
base_expanded
|
||||
};
|
||||
|
||||
full_flat_dest = full_flat_dest.flatten();
|
||||
|
||||
let flat_updates = updates.flatten();
|
||||
let flat_data = data.flatten();
|
||||
|
||||
let output_flat = flat_updates.scatter(full_flat_dest, flat_data);
|
||||
|
||||
let mut result = output_flat;
|
||||
result.shape = ShapeTracker::new(data_dims);
|
||||
result
|
||||
}
|
||||
@@ -119,10 +119,8 @@ impl<'a> Translator<'a> {
|
||||
/// buffer would be sized for the un-sliced argsort tensor while the
|
||||
/// shape tracker reports a smaller rank.
|
||||
///
|
||||
/// The output dtype is `DType::Int` (luminal's 32-bit int); PT2
|
||||
/// metadata records int64 and the Python wrapper widens at the
|
||||
/// boundary, so the PyTorch contract is preserved end-to-end
|
||||
/// (LUM-486).
|
||||
/// The result is cast to `DType::I64` to match PyTorch's int64
|
||||
/// argmax / argmin indices.
|
||||
pub(crate) fn translate_argextremum(
|
||||
&mut self,
|
||||
node: &Node,
|
||||
@@ -149,7 +147,7 @@ impl<'a> Translator<'a> {
|
||||
None | Some(0) | Some(-1) => {
|
||||
// PyTorch returns scalar index 0 for rank-0 argmax/argmin.
|
||||
// `keepdim=True` does not add a dimension when the input is 0-d.
|
||||
return Ok(self.graph.constant(0i64).cast(DType::Int));
|
||||
return Ok(self.graph.constant(0i64).cast(DType::I64));
|
||||
}
|
||||
Some(dim) => {
|
||||
return Err(anyhow::anyhow!(
|
||||
@@ -188,6 +186,6 @@ impl<'a> Translator<'a> {
|
||||
} else {
|
||||
picked
|
||||
};
|
||||
Ok(result * 1)
|
||||
Ok((result * 1).cast(DType::I64))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -413,15 +413,18 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Build top-k outputs from a full stable argsort. Slice the indices
|
||||
// before gathering values so the gather shape matches the requested
|
||||
// top-k output rather than the full sort width.
|
||||
// top-k output rather than the full sort width. Cast to I64 so the
|
||||
// emitted indices match PyTorch's `torch.topk` semantics (indices
|
||||
// are int64); `gather_elements` accepts any int dtype on its index
|
||||
// operand, so a single I64 tensor serves both consumers.
|
||||
let full_argsort = a.stable_argsort(dim, true);
|
||||
let topk_indices = full_argsort.slice_along(..k, dim) * 1.0;
|
||||
let topk_indices = (full_argsort.slice_along(..k, dim) * 1.0).cast(DType::I64);
|
||||
|
||||
// Only build the outputs that are consumed.
|
||||
if let Some(val_name) = values_name
|
||||
&& !val_name.is_empty()
|
||||
{
|
||||
let values = a.gather_elements(topk_indices, dim);
|
||||
let values = super::movement_dynamic::pt2_gather_elements(a, topk_indices, dim);
|
||||
self.tensors.insert(val_name, values);
|
||||
}
|
||||
if let Some(idx_name) = indices_name {
|
||||
@@ -465,11 +468,12 @@ impl<'a> Translator<'a> {
|
||||
if let Some(val_name) = values_name
|
||||
&& !val_name.is_empty()
|
||||
{
|
||||
let values = a.gather_elements(full_argsort, dim);
|
||||
let values = super::movement_dynamic::pt2_gather_elements(a, full_argsort, dim);
|
||||
self.tensors.insert(val_name, values);
|
||||
}
|
||||
if let Some(idx_name) = indices_name {
|
||||
let indices = full_argsort * 1.0;
|
||||
// `torch.sort` returns int64 indices; cast at the PT2 boundary.
|
||||
let indices = (full_argsort * 1.0).cast(DType::I64);
|
||||
self.tensors.insert(idx_name, indices);
|
||||
}
|
||||
|
||||
|
||||
@@ -35,7 +35,12 @@ impl<'a> Translator<'a> {
|
||||
false
|
||||
};
|
||||
let dim = crate::pt2_util::normalize_dim(dim, a.shape.len());
|
||||
Ok(a.stable_argsort(dim, descending))
|
||||
// PyTorch's `torch.argsort` returns int64 unconditionally;
|
||||
// luminal's frontend `stable_argsort` returns i32 (storage-
|
||||
// efficient default for native Rust callers). Cast at the
|
||||
// PT2↔luminal boundary so the strict output-read path sees
|
||||
// an I64 buffer.
|
||||
Ok(a.stable_argsort(dim, descending).cast(DType::I64))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_unary_op(
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
//! through the PT2 path without forcing everything to f32.
|
||||
|
||||
use luminal::hlir::NativeData;
|
||||
use luminal::prelude::tracing::warn;
|
||||
use luminal::prelude::*;
|
||||
|
||||
/// A dtype-tagged byte buffer. All weight, constant, and input data flows through this type.
|
||||
@@ -149,62 +148,40 @@ impl TypedData {
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert raw bytes from a PyTorch tensor (identified by PT2 dtype code) to TypedData
|
||||
/// in luminal's native format. Handles widening/narrowing conversions for types where
|
||||
/// PyTorch's byte layout differs from luminal's:
|
||||
/// - i64 → i32, f64 → f32 (luminal has no 64-bit types)
|
||||
/// - i16 → i32, u8 → i32, i8 → i32 (luminal maps all integer types to i32 for PT2)
|
||||
/// Convert raw bytes from a PyTorch tensor (identified by PT2 dtype
|
||||
/// code) to `TypedData`. Supported dtypes preserve their raw bytes —
|
||||
/// no width changes at the FFI boundary. Narrow integer widths
|
||||
/// (`Byte` / `Char` / `Short`) panic: luminal's `NativeData` has no
|
||||
/// narrower-integer variants yet, so the only way they could pass
|
||||
/// through is via implicit widening to `i32`, which the no-implicit-
|
||||
/// cast directive forbids. Cast at the call site
|
||||
/// (`x.to(torch.int32)`) or wait for the narrower-int IR follow-up.
|
||||
pub fn from_pytorch_bytes(bytes: Vec<u8>, dtype_code: u32) -> Self {
|
||||
match dtype_code {
|
||||
// Types that map directly — preserve raw bytes
|
||||
7 => Self::from_raw(bytes, DType::F32),
|
||||
6 => Self::from_raw(bytes, DType::F16),
|
||||
13 => Self::from_raw(bytes, DType::Bf16),
|
||||
4 => Self::from_raw(bytes, DType::Int), // i32
|
||||
12 => Self::from_raw(bytes, DType::Bool),
|
||||
// i64 → i32 (truncate)
|
||||
5 => {
|
||||
let i32s: Vec<i32> = bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as i32
|
||||
})
|
||||
.collect();
|
||||
Self::from_i32_vec(i32s)
|
||||
}
|
||||
// f64 → f32 (downcast)
|
||||
8 => {
|
||||
let f32s: Vec<f32> = bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
|
||||
})
|
||||
.collect();
|
||||
Self::from_f32_vec(f32s)
|
||||
}
|
||||
// i16 → i32 (widen)
|
||||
3 => {
|
||||
let i32s: Vec<i32> = bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| i16::from_le_bytes([b[0], b[1]]) as i32)
|
||||
.collect();
|
||||
Self::from_i32_vec(i32s)
|
||||
}
|
||||
// u8 → i32 (widen)
|
||||
1 => {
|
||||
let i32s: Vec<i32> = bytes.iter().map(|&b| b as i32).collect();
|
||||
Self::from_i32_vec(i32s)
|
||||
}
|
||||
// i8 → i32 (widen, signed)
|
||||
2 => {
|
||||
let i32s: Vec<i32> = bytes.iter().map(|&b| (b as i8) as i32).collect();
|
||||
Self::from_i32_vec(i32s)
|
||||
}
|
||||
// Unknown: best-effort pass-through as f32
|
||||
_ => {
|
||||
warn!("Unrecognized pytorch dtype code {dtype_code}, interpreting as f32");
|
||||
Self::from_raw(bytes, DType::F32)
|
||||
}
|
||||
let t = crate::torch_dtype::TorchDType::from_code(dtype_code)
|
||||
.unwrap_or_else(|c| panic!("from_pytorch_bytes: unknown PT2 dtype code {c}"));
|
||||
match t {
|
||||
crate::torch_dtype::TorchDType::Float => Self::from_raw(bytes, DType::F32),
|
||||
crate::torch_dtype::TorchDType::Half => Self::from_raw(bytes, DType::F16),
|
||||
crate::torch_dtype::TorchDType::BFloat16 => Self::from_raw(bytes, DType::Bf16),
|
||||
crate::torch_dtype::TorchDType::Int => Self::from_raw(bytes, DType::Int),
|
||||
crate::torch_dtype::TorchDType::Bool => Self::from_raw(bytes, DType::Bool),
|
||||
crate::torch_dtype::TorchDType::Long => Self::from_raw(bytes, DType::I64),
|
||||
crate::torch_dtype::TorchDType::Double => Self::from_raw(bytes, DType::F64),
|
||||
crate::torch_dtype::TorchDType::Byte
|
||||
| crate::torch_dtype::TorchDType::Char
|
||||
| crate::torch_dtype::TorchDType::Short => panic!(
|
||||
"from_pytorch_bytes: PT2 dtype {} (code {}) isn't a first-class \
|
||||
IR type yet — cast to torch.int32 at the call site, or wait \
|
||||
for the narrower-int IR follow-up.",
|
||||
t.name(),
|
||||
t.code(),
|
||||
),
|
||||
other => panic!(
|
||||
"from_pytorch_bytes: PT2 dtype {} (code {}) isn't a first-class \
|
||||
IR type — no luminal mapping.",
|
||||
other.name(),
|
||||
other.code(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,20 @@ from .dtype_util import code_to_torch_dtype
|
||||
from .dtype_util import torch_dtype_code as _torch_dtype_code
|
||||
|
||||
|
||||
class DTypeBoundaryError(TypeError):
|
||||
"""Raised when the caller passes an input whose dtype does not match the
|
||||
compiled graph's declared input dtype.
|
||||
|
||||
The previous behaviour cast silently at every call, which (a) hid real
|
||||
precision bugs (e.g. f64 → f32 truncation on values outside the f32
|
||||
range) and (b) burnt CPU/GPU on a per-call allocation+copy that the
|
||||
user couldn't see in their profile. The contract is now strict:
|
||||
`model(x)` requires `x.dtype == model.input_dtypes[i]` for every
|
||||
positional input. Convert at the call site with
|
||||
`x.to(model.input_dtypes[i])` if you need a different dtype.
|
||||
"""
|
||||
|
||||
|
||||
class CompiledModel:
|
||||
"""Wrapper around CompiledGraph that handles PyTorch tensor conversion."""
|
||||
|
||||
@@ -35,14 +49,18 @@ class CompiledModel:
|
||||
self._supports_device_ptrs = getattr(
|
||||
graph_result, "supports_device_ptrs", False
|
||||
)
|
||||
# Expected input dtypes from graph (used to convert user inputs)
|
||||
# Expected input dtypes from graph. Every declared input MUST
|
||||
# have a dtype code — refuse to silently default to float32 if
|
||||
# the Rust side returned a shorter list than `input_names`.
|
||||
input_dtype_codes = graph_result.input_dtypes
|
||||
self._input_dtypes = [
|
||||
code_to_torch_dtype(input_dtype_codes[i])
|
||||
if i < len(input_dtype_codes)
|
||||
else torch.float32
|
||||
for i in range(len(self._input_names))
|
||||
]
|
||||
if len(input_dtype_codes) != len(self._input_names):
|
||||
raise RuntimeError(
|
||||
f"CompiledGraph returned {len(input_dtype_codes)} input dtype "
|
||||
f"codes for {len(self._input_names)} declared inputs "
|
||||
f"({self._input_names!r}) — every declared input needs a "
|
||||
f"matching dtype."
|
||||
)
|
||||
self._input_dtypes = [code_to_torch_dtype(c) for c in input_dtype_codes]
|
||||
|
||||
def set_dim(self, param_name: str, value: int) -> None:
|
||||
"""Set a dynamic dimension value by its param name."""
|
||||
@@ -95,13 +113,22 @@ class CompiledModel:
|
||||
for name, tensor, expected_dtype in zip(
|
||||
self._input_names, user_inputs, self._input_dtypes
|
||||
):
|
||||
if tensor.dtype != expected_dtype:
|
||||
raise DTypeBoundaryError(
|
||||
f"Luminal compiled input '{name}' expects "
|
||||
f"{expected_dtype} but got {tensor.dtype}. "
|
||||
"Convert at the call site with "
|
||||
f"`x.to({expected_dtype})` — the boundary used to silently "
|
||||
"cast (and warn) on every call, which masked precision "
|
||||
"bugs and burnt cycles on per-call allocation+copy."
|
||||
)
|
||||
if self._supports_device_ptrs and tensor.is_cuda:
|
||||
t = tensor.detach().contiguous().to(expected_dtype)
|
||||
t = tensor.detach().contiguous()
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
self._graph.set_input_device_ptr(name, t.data_ptr(), n_bytes)
|
||||
_input_refs.append(t)
|
||||
else:
|
||||
t = tensor.detach().cpu().contiguous().to(expected_dtype)
|
||||
t = tensor.detach().cpu().contiguous()
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
dtype_code = _torch_dtype_code(t.dtype)
|
||||
self._graph.set_input_from_ptr(name, t.data_ptr(), n_bytes, dtype_code)
|
||||
@@ -112,100 +139,120 @@ class CompiledModel:
|
||||
else:
|
||||
output_shapes = self._output_shapes
|
||||
|
||||
# Every declared output MUST have a dtype code; refuse to default
|
||||
# to float32 the way we used to if the Rust side returned fewer
|
||||
# codes than declared outputs.
|
||||
output_dtype_codes = self._graph.output_dtypes
|
||||
if len(output_dtype_codes) != len(self._output_names):
|
||||
raise RuntimeError(
|
||||
f"CompiledGraph returned {len(output_dtype_codes)} output "
|
||||
f"dtype codes for {len(self._output_names)} declared outputs "
|
||||
f"({self._output_names!r}) — every declared output needs a "
|
||||
f"matching dtype."
|
||||
)
|
||||
output_torch_dtypes = [code_to_torch_dtype(c) for c in output_dtype_codes]
|
||||
|
||||
# CUDA zero-copy path: pre-allocate output tensors and register their device
|
||||
# pointers so the final kernel writes directly into PyTorch's buffer.
|
||||
# Per-dtype dispatch table mapping `torch_dtype` → the typed
|
||||
# `_graph` getter for that dtype. Every supported dtype has an
|
||||
# explicit native-width getter; anything not listed raises
|
||||
# `NotImplementedError` from `_read_typed_output`. There is no
|
||||
# open-ended fallback — a missing entry means we don't know how
|
||||
# to read that dtype yet, and we'd rather fail loudly than
|
||||
# silently reinterpret bytes.
|
||||
#
|
||||
# `float16` / `bfloat16` getters return `uint16` bit patterns
|
||||
# (Python has no native `f16` / `bf16`); the helper below
|
||||
# bit-casts them back to the declared dtype via
|
||||
# `torch.frombuffer`. That's a reinterpret, not a numeric
|
||||
# cast — no precision change.
|
||||
#
|
||||
# Narrow ints (`int8` / `int16` / `uint8`) are intentionally
|
||||
# absent — luminal's IR refuses them at the FFI boundary (cf.
|
||||
# `pt2_util::torch_dtype_int_to_luminal`,
|
||||
# `typed_data::from_pytorch_bytes`), so a graph can never
|
||||
# declare a narrow-int output that reaches this dispatch.
|
||||
_zero_copy_native_floats = (torch.float32, torch.float16, torch.bfloat16)
|
||||
_output_readers = {
|
||||
torch.float32: ("get_output", torch.float32),
|
||||
torch.float64: ("get_output_f64", torch.float64),
|
||||
torch.float16: ("get_output_f16", torch.float16),
|
||||
torch.bfloat16: ("get_output_bf16", torch.bfloat16),
|
||||
torch.int64: ("get_output_i64", torch.int64),
|
||||
torch.int32: ("get_output_i32", torch.int32),
|
||||
torch.bool: ("get_output_bool", torch.bool),
|
||||
}
|
||||
|
||||
def _read_typed_output(name: str, shape, out_dtype) -> torch.Tensor:
|
||||
"""Pull one output back from the runtime at the right dtype.
|
||||
|
||||
Strict: any `out_dtype` not in `_output_readers` raises
|
||||
`NotImplementedError`. The previous code's open-ended
|
||||
fallback read the buffer as f32 and `.to(out_dtype)`'d
|
||||
back, which silently aliased dtypes we don't really
|
||||
support; refusing surfaces the gap.
|
||||
|
||||
For `float16` / `bfloat16` the typed getter returns
|
||||
`uint16` bit patterns (Python has no native half-precision
|
||||
float type); we bit-cast via `torch.tensor(..., uint16)`
|
||||
and `.view(half)` so the conversion is a reinterpret of the
|
||||
bytes, not a numeric cast.
|
||||
"""
|
||||
entry = _output_readers.get(out_dtype)
|
||||
if entry is None:
|
||||
raise NotImplementedError(
|
||||
f"Output '{name}' declared dtype {out_dtype} isn't "
|
||||
f"supported by the luminal read boundary. Add a typed "
|
||||
f"getter for this dtype (see `_output_readers`) or cast "
|
||||
f"the output to a supported dtype upstream."
|
||||
)
|
||||
getter_name, read_dtype = entry
|
||||
data = getattr(self._graph, getter_name)(name)
|
||||
if out_dtype in (torch.float16, torch.bfloat16):
|
||||
# Getter returned an immutable `bytes` from Rust; wrap in
|
||||
# `bytearray` to make the storage writable (suppresses
|
||||
# the "non-writable buffer" warning), then bit-cast via
|
||||
# `frombuffer` — no numeric conversion.
|
||||
tensor = torch.frombuffer(bytearray(data), dtype=out_dtype).reshape(
|
||||
tuple(shape)
|
||||
)
|
||||
else:
|
||||
tensor = torch.tensor(data, dtype=read_dtype).reshape(tuple(shape))
|
||||
return tensor.to(input_device)
|
||||
|
||||
# Pre-allocation is GPU-only: the CUDA kernel needs the
|
||||
# output's device pointer registered *before* `_graph.run()`
|
||||
# so the final kernel writes directly into PyTorch's buffer.
|
||||
# Only the float dtypes luminal natively writes
|
||||
# (`_zero_copy_native_floats`) take the zero-copy path; other
|
||||
# dtypes (int*, bool, f64) read back via `_read_typed_output`
|
||||
# after `run()` and so don't need a pre-allocated tensor at
|
||||
# this layer. CPU never zero-copies — there's no separate
|
||||
# device buffer to register against.
|
||||
_use_zero_copy = self._supports_device_ptrs
|
||||
output_tensors = []
|
||||
if _use_zero_copy:
|
||||
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
|
||||
out_dtype = (
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
out_dtype = output_torch_dtypes[i]
|
||||
out = torch.empty(shape, dtype=out_dtype, device=input_device)
|
||||
if out_dtype.is_floating_point:
|
||||
if out_dtype in _zero_copy_native_floats:
|
||||
self._graph.set_output_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
output_tensors.append(out)
|
||||
|
||||
# Run the graph
|
||||
self._graph.run()
|
||||
|
||||
# Integer dtypes for which we read the buffer as i32 and then cast.
|
||||
# Includes int64 because luminal collapses all integer types to its
|
||||
# 32-bit `Int` internally — we restore the original precision here.
|
||||
_int_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8)
|
||||
|
||||
# Collect outputs
|
||||
if _use_zero_copy:
|
||||
outputs = []
|
||||
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
|
||||
out_dtype = (
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
outputs = []
|
||||
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
|
||||
out_dtype = output_torch_dtypes[i]
|
||||
if _use_zero_copy and out_dtype in _zero_copy_native_floats:
|
||||
out = output_tensors[i]
|
||||
if out_dtype.is_floating_point:
|
||||
if not self._graph.output_is_zero_copy(name):
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
elif out_dtype in _int_dtypes:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.int32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
.to(input_device)
|
||||
if not self._graph.output_is_zero_copy(name):
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
elif out_dtype == torch.bool:
|
||||
data = self._graph.get_output_bool(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.bool)
|
||||
.reshape(tuple(shape))
|
||||
.to(input_device)
|
||||
)
|
||||
else:
|
||||
data = self._graph.get_output(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.float32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
.to(input_device)
|
||||
)
|
||||
outputs.append(out)
|
||||
else:
|
||||
# Native path: retrieve as f32, then convert to target dtype if needed.
|
||||
outputs = []
|
||||
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
|
||||
out_dtype = (
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
if out_dtype in _int_dtypes:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.int32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
)
|
||||
elif out_dtype == torch.bool:
|
||||
data = self._graph.get_output_bool(name)
|
||||
out = torch.tensor(data, dtype=torch.bool).reshape(tuple(shape))
|
||||
else:
|
||||
data = self._graph.get_output(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.float32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
)
|
||||
out = out.to(input_device)
|
||||
outputs.append(out)
|
||||
else:
|
||||
out = _read_typed_output(name, shape, out_dtype)
|
||||
outputs.append(out)
|
||||
|
||||
return tuple(outputs)
|
||||
|
||||
@@ -1,28 +1,61 @@
|
||||
"""Shared dtype utility functions for the luminal Python Bridge"""
|
||||
"""Shared dtype utility functions for the luminal Python bridge.
|
||||
|
||||
The PT2 dtype-code numbering is sourced from
|
||||
``torch._export.serde.schema.ScalarType`` at import time — PyTorch is the
|
||||
canonical source of truth on both sides of the FFI boundary. The Rust side
|
||||
mirrors the same enum in ``luminal_python/rust/src/torch_dtype.rs`` and is
|
||||
held in agreement by ``tests/test_torch_dtype_parity.py``.
|
||||
|
||||
``torch._export.serde.schema`` is a quasi-private API (leading underscore),
|
||||
but it is the module PT2 export actually wire-serializes against; binding
|
||||
to it here is the right boundary. If PyTorch reorganizes the module path,
|
||||
the import below will fail loudly at module load.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch._export.serde.schema import ScalarType
|
||||
|
||||
# Map each `torch.dtype` we care about to the PT2 code PyTorch itself
|
||||
# would emit for it. Looking up `ScalarType.<NAME>.value` keeps the
|
||||
# numbering in lockstep with PyTorch — if PyTorch renumbers, we pick
|
||||
# up the new code automatically (and the Rust parity test catches the
|
||||
# drift from the other side).
|
||||
_TORCH_DTYPE_TO_CODE = {
|
||||
torch.uint8: 1,
|
||||
torch.int8: 2,
|
||||
torch.int16: 3,
|
||||
torch.int32: 4,
|
||||
torch.int64: 5,
|
||||
torch.float16: 6,
|
||||
torch.float32: 7,
|
||||
torch.float64: 8,
|
||||
torch.bool: 12,
|
||||
torch.bfloat16: 13,
|
||||
torch.uint8: ScalarType.BYTE.value,
|
||||
torch.int8: ScalarType.CHAR.value,
|
||||
torch.int16: ScalarType.SHORT.value,
|
||||
torch.int32: ScalarType.INT.value,
|
||||
torch.int64: ScalarType.LONG.value,
|
||||
torch.float16: ScalarType.HALF.value,
|
||||
torch.float32: ScalarType.FLOAT.value,
|
||||
torch.float64: ScalarType.DOUBLE.value,
|
||||
torch.bool: ScalarType.BOOL.value,
|
||||
torch.bfloat16: ScalarType.BFLOAT16.value,
|
||||
}
|
||||
|
||||
_CODE_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_DTYPE_TO_CODE.items()}
|
||||
|
||||
|
||||
def torch_dtype_code(dtype):
|
||||
"""Map torch.dtype to PT2 dtype integer code."""
|
||||
return _TORCH_DTYPE_TO_CODE.get(dtype, 7) # default to f32
|
||||
"""Map torch.dtype to PT2 dtype integer code. Raises `KeyError`
|
||||
on an unsupported dtype rather than silently aliasing to FLOAT."""
|
||||
try:
|
||||
return _TORCH_DTYPE_TO_CODE[dtype]
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
f"torch_dtype_code: {dtype} isn't a supported PT2 dtype "
|
||||
f"(supported: {sorted(_TORCH_DTYPE_TO_CODE.keys(), key=str)})"
|
||||
) from None
|
||||
|
||||
|
||||
def code_to_torch_dtype(code):
|
||||
"""Map PT2 dtype integer code to torch.dtype."""
|
||||
return _CODE_TO_TORCH_DTYPE.get(code, torch.float32)
|
||||
"""Map PT2 dtype integer code to torch.dtype. Raises `KeyError`
|
||||
on an unknown code rather than silently defaulting to float32."""
|
||||
try:
|
||||
return _CODE_TO_TORCH_DTYPE[code]
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
f"code_to_torch_dtype: PT2 dtype code {code} isn't mapped "
|
||||
f"to a torch.dtype (known codes: "
|
||||
f"{sorted(_CODE_TO_TORCH_DTYPE.keys())})"
|
||||
) from None
|
||||
|
||||
250
crates/luminal_python/tests/test_dtype_boundary.py
Normal file
250
crates/luminal_python/tests/test_dtype_boundary.py
Normal file
@@ -0,0 +1,250 @@
|
||||
from dataclasses import dataclass
|
||||
import warnings
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
class BoundaryNoopModel(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if x.dtype is torch.bool:
|
||||
return x | torch.zeros((), dtype=torch.bool, device=x.device)
|
||||
return x + torch.zeros((), dtype=x.dtype, device=x.device)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DTypeCase:
|
||||
name: str
|
||||
dtype: torch.dtype
|
||||
values: Callable[[], torch.Tensor]
|
||||
xfail_reason: str | None = None
|
||||
|
||||
|
||||
DTYPE_CASES = [
|
||||
DTypeCase(
|
||||
"bool",
|
||||
torch.bool,
|
||||
lambda: torch.tensor([True, False, True], dtype=torch.bool),
|
||||
),
|
||||
DTypeCase(
|
||||
"uint8",
|
||||
torch.uint8,
|
||||
lambda: torch.tensor([0, 127, 255], dtype=torch.uint8),
|
||||
),
|
||||
DTypeCase(
|
||||
"int8",
|
||||
torch.int8,
|
||||
lambda: torch.tensor([-128, -1, 127], dtype=torch.int8),
|
||||
),
|
||||
DTypeCase(
|
||||
"int16",
|
||||
torch.int16,
|
||||
lambda: torch.tensor([-32768, -1, 32767], dtype=torch.int16),
|
||||
),
|
||||
DTypeCase(
|
||||
"int32",
|
||||
torch.int32,
|
||||
lambda: torch.tensor(
|
||||
[-2147483648, -1, 2147483647],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
),
|
||||
DTypeCase(
|
||||
"int64_i32_range",
|
||||
torch.int64,
|
||||
lambda: torch.tensor(
|
||||
[-2147483648, -1, 2147483647],
|
||||
dtype=torch.int64,
|
||||
),
|
||||
),
|
||||
DTypeCase(
|
||||
"float16",
|
||||
torch.float16,
|
||||
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.float16),
|
||||
),
|
||||
DTypeCase(
|
||||
"bfloat16",
|
||||
torch.bfloat16,
|
||||
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.bfloat16),
|
||||
),
|
||||
DTypeCase(
|
||||
"float32",
|
||||
torch.float32,
|
||||
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.float32),
|
||||
),
|
||||
DTypeCase(
|
||||
"float64_f32_exact",
|
||||
torch.float64,
|
||||
lambda: torch.tensor([1.0, 1.5, float(2**40)], dtype=torch.float64),
|
||||
),
|
||||
DTypeCase(
|
||||
"int64_outside_i32_range",
|
||||
torch.int64,
|
||||
lambda: torch.tensor([-(2**40), -1, 2**40], dtype=torch.int64),
|
||||
),
|
||||
DTypeCase(
|
||||
"float64_precision_sensitive",
|
||||
torch.float64,
|
||||
lambda: torch.tensor(
|
||||
[1.0, 1.0000000000000002, float(2**40) + 0.25],
|
||||
dtype=torch.float64,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _cuda_skip_reason() -> str | None:
|
||||
if not torch.cuda.is_available():
|
||||
return "CUDA is not available"
|
||||
|
||||
try:
|
||||
from luminal.luminal import _cuda_lite_factory_capsule
|
||||
|
||||
_cuda_lite_factory_capsule()
|
||||
except (ImportError, AttributeError, RuntimeError) as exc:
|
||||
return f"luminal_python was not built with CUDA support: {exc}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture(params=["cpu", "cuda"], ids=["cpu", "cuda"])
|
||||
def boundary_device(request) -> torch.device:
|
||||
device_name = request.param
|
||||
if device_name == "cuda":
|
||||
skip_reason = _cuda_skip_reason()
|
||||
if skip_reason is not None:
|
||||
pytest.skip(skip_reason)
|
||||
return torch.device(device_name)
|
||||
|
||||
|
||||
# Dtypes that round-trip the BoundaryNoopModel without an explicit
|
||||
# `x.to(model.input_dtypes[0])` cast at the call site. Anything not in this
|
||||
# set is a narrow integer (uint8 / int8 / int16) that luminal collapses to
|
||||
# `DType::Int` internally — the hard-reject contract makes the boundary
|
||||
# refuse the mismatched dtype, and the test for those lives in
|
||||
# `test_input_dtype_mismatch_rejects` instead.
|
||||
_FIRST_CLASS_NOOP_DTYPES = {
|
||||
"bool",
|
||||
"int32",
|
||||
"int64_i32_range",
|
||||
"int64_outside_i32_range",
|
||||
"float16",
|
||||
"bfloat16",
|
||||
"float32",
|
||||
"float64_f32_exact",
|
||||
"float64_precision_sensitive",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
pytest.param(
|
||||
case,
|
||||
marks=pytest.mark.xfail(reason=case.xfail_reason, strict=True)
|
||||
if case.xfail_reason is not None
|
||||
else (),
|
||||
id=case.name,
|
||||
)
|
||||
for case in DTYPE_CASES
|
||||
if case.name in _FIRST_CLASS_NOOP_DTYPES
|
||||
],
|
||||
)
|
||||
def test_boundary_noop_preserves_dtype_and_values(
|
||||
boundary_device: torch.device,
|
||||
case: DTypeCase,
|
||||
) -> None:
|
||||
model = BoundaryNoopModel().to(boundary_device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
|
||||
x = case.values().to(boundary_device)
|
||||
expected = model(x)
|
||||
actual = compiled(x)
|
||||
|
||||
assert isinstance(actual, torch.Tensor)
|
||||
assert actual.dtype == expected.dtype
|
||||
assert torch.equal(actual.cpu(), expected.cpu())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
pytest.param(case, id=case.name)
|
||||
for case in DTYPE_CASES
|
||||
# Narrow integer widths (uint8 / int8 / int16) aren't first-class in
|
||||
# luminal's IR — the translator refuses them outright. int64 /
|
||||
# float64 are first-class and round-trip without rejection.
|
||||
if case.name in {"uint8", "int8", "int16"}
|
||||
],
|
||||
)
|
||||
def test_input_dtype_mismatch_rejects(
|
||||
boundary_device: torch.device,
|
||||
case: DTypeCase,
|
||||
) -> None:
|
||||
"""Hard-reject contract: a graph whose declared input dtype is one of
|
||||
the narrow ints (uint8 / int8 / int16) fails at compile time with a
|
||||
clear panic from `torch_dtype_int_to_luminal`. Previously the
|
||||
translator silently widened narrow ints to `Int` (i32), which left
|
||||
the user's actual dtype invisible past the FFI boundary; today the
|
||||
failure points at the missing IR support directly.
|
||||
"""
|
||||
model = BoundaryNoopModel().to(boundary_device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
x = case.values().to(boundary_device)
|
||||
|
||||
# `pyo3_runtime.PanicException` inherits from `BaseException` (not
|
||||
# `Exception`), so `pytest.raises(Exception, ...)` would miss it.
|
||||
# Match on the panic message text — stable across torch versions.
|
||||
with pytest.raises(BaseException, match="isn't a first-class IR type yet"):
|
||||
compiled(x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
pytest.param(case, id=case.name)
|
||||
for case in DTYPE_CASES
|
||||
if case.name
|
||||
in {
|
||||
"bool",
|
||||
"int32",
|
||||
"float16",
|
||||
"bfloat16",
|
||||
"float32",
|
||||
# int64 / float64 are first-class in the IR — passing a tensor
|
||||
# of either dtype matches the graph's input dtype directly, no
|
||||
# conversion needed.
|
||||
"int64_i32_range",
|
||||
"int64_outside_i32_range",
|
||||
"float64_f32_exact",
|
||||
"float64_precision_sensitive",
|
||||
}
|
||||
],
|
||||
)
|
||||
def test_matching_dtype_does_not_raise(
|
||||
boundary_device: torch.device,
|
||||
case: DTypeCase,
|
||||
) -> None:
|
||||
"""Round-trip contract: a user input whose dtype matches the graph's
|
||||
declared input dtype runs without raising, with no warnings emitted at
|
||||
the boundary."""
|
||||
model = BoundaryNoopModel().to(boundary_device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
x = case.values().to(boundary_device)
|
||||
|
||||
with warnings.catch_warnings(record=True) as records:
|
||||
warnings.simplefilter("always")
|
||||
compiled(x)
|
||||
|
||||
boundary_warnings = [
|
||||
record
|
||||
for record in records
|
||||
if "boundary" in str(record.message).lower()
|
||||
or "convert" in str(record.message).lower()
|
||||
]
|
||||
assert boundary_warnings == [], (
|
||||
f"unexpected boundary-related warning(s): {boundary_warnings}"
|
||||
)
|
||||
109
crates/luminal_python/tests/test_gather_scatter_dynamic.py
Normal file
109
crates/luminal_python/tests/test_gather_scatter_dynamic.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Dynamic-shape regression coverage for the movement ops Qwen3-MoE /
|
||||
Gemma4-MoE exercise via `torch.compile`.
|
||||
|
||||
Three failure modes surfaced while debugging the Qwen3-30B-A3B path:
|
||||
|
||||
1. `gather_elements: index dim must be concrete` — `gather_elements`
|
||||
/ `scatter_elements` collected index dims as `Vec<usize>` via
|
||||
`.to_usize().expect(...)`. First forward worked; the second forward
|
||||
at a different seq_len made Dynamo emit a SymInt dim and tripped
|
||||
the assertion.
|
||||
2. `Dims must match to add tensors. left: [(a*8), 2048] right: [(8*a), 2048]`
|
||||
— different translator paths produced semantically-equal but
|
||||
syntactically-different `Expression` dims.
|
||||
3. `scatter_nd: data dim must be concrete` — same family as (1),
|
||||
reached via `translate_index_put` (HF's MoE accumulator).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from luminal.main import luminal_backend
|
||||
|
||||
|
||||
def _compile(model):
|
||||
return torch.compile(model, backend=luminal_backend)
|
||||
|
||||
|
||||
def test_gather_elements_dynamic_index_shape(device: torch.device) -> None:
|
||||
"""`torch.gather` with a dynamic batch dim on the index tensor."""
|
||||
|
||||
class GatherModel(torch.nn.Module):
|
||||
def forward(self, table: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
|
||||
expanded = table.unsqueeze(0).expand(indices.shape[0], -1, -1)
|
||||
idx = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, 1, 32)
|
||||
return torch.gather(expanded, 1, idx).squeeze(1)
|
||||
|
||||
model = GatherModel().to(device)
|
||||
compiled = _compile(model)
|
||||
table = torch.randn(8, 32, device=device)
|
||||
|
||||
for batch in [4, 7, 11, 4]:
|
||||
idx = torch.randint(0, 8, (batch,), device=device, dtype=torch.int64)
|
||||
assert torch.allclose(compiled(table, idx), model(table, idx), atol=1e-4)
|
||||
|
||||
|
||||
def test_scatter_elements_dynamic_index_shape(device: torch.device) -> None:
|
||||
"""`torch.scatter` with a dynamic batch dim on the index tensor."""
|
||||
|
||||
class ScatterModel(torch.nn.Module):
|
||||
def forward(self, values: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
|
||||
dest = torch.zeros(
|
||||
values.shape[0], 16, device=values.device, dtype=values.dtype
|
||||
)
|
||||
return dest.scatter(1, indices, values)
|
||||
|
||||
model = ScatterModel().to(device)
|
||||
compiled = _compile(model)
|
||||
|
||||
for batch in [4, 7, 11, 4]:
|
||||
# Distinct indices per row → no-overlap scatter for allclose.
|
||||
idx = torch.stack(
|
||||
[torch.randperm(16, device=device)[:4] for _ in range(batch)]
|
||||
).to(torch.int64)
|
||||
vals = torch.randn(batch, 4, device=device)
|
||||
assert torch.allclose(compiled(vals, idx), model(vals, idx), atol=1e-4)
|
||||
|
||||
|
||||
def test_scatter_nd_dynamic_data_shape(device: torch.device) -> None:
|
||||
"""`tensor[idx] = value` → `translate_index_put` → `scatter_nd`."""
|
||||
|
||||
class ScatterNDModel(torch.nn.Module):
|
||||
def forward(
|
||||
self, base: torch.Tensor, idx: torch.Tensor, vals: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
out = base.clone()
|
||||
out[idx] = vals
|
||||
return out
|
||||
|
||||
model = ScatterNDModel().to(device)
|
||||
compiled = _compile(model)
|
||||
|
||||
for batch in [4, 7, 11, 4]:
|
||||
base = torch.randn(16, 4, device=device)
|
||||
idx = torch.randperm(16, device=device)[:batch].to(torch.int64)
|
||||
vals = torch.randn(batch, 4, device=device)
|
||||
assert torch.allclose(
|
||||
compiled(base, idx, vals), model(base, idx, vals), atol=1e-4
|
||||
)
|
||||
|
||||
|
||||
def test_where_dynamic_shape_no_dim_mismatch_panic(device: torch.device) -> None:
|
||||
"""`torch.where` over inputs whose shape derives from a SymInt:
|
||||
two translator paths can produce `a*8` vs `8*a` for the same dim,
|
||||
which trips the dim-equality assert in luminal-core's `Sub` /
|
||||
`Add` without canonical ordering in `dim_arith`.
|
||||
"""
|
||||
|
||||
class WhereModel(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return torch.where(x > 0, x, y)
|
||||
|
||||
model = WhereModel().to(device)
|
||||
compiled = _compile(model)
|
||||
|
||||
for batch in [4, 7, 11, 4]:
|
||||
x = torch.randn(batch, 16, device=device)
|
||||
y = torch.randn(batch, 16, device=device)
|
||||
assert torch.allclose(compiled(x, y), model(x, y), atol=1e-4)
|
||||
@@ -230,7 +230,6 @@ def test_hf_llama_decode_loop_static(device: torch.device):
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
|
||||
def test_hf_llama3_1b_decode_loop_dynamic(device: torch.device):
|
||||
"""Decode loop on real Llama3.2-1B with pretrained weights.
|
||||
|
||||
@@ -286,7 +285,6 @@ def _gpu_mem(label):
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
|
||||
def test_hf_llama3_full(device: torch.device):
|
||||
"""HuggingFace LlamaForCausalLM — full Llama3.2-1B with real pretrained weights.
|
||||
|
||||
@@ -338,7 +336,6 @@ def test_hf_llama3_full(device: torch.device):
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
|
||||
def test_hf_llama3_large_full(device: torch.device):
|
||||
"""HuggingFace LlamaForCausalLM — full Llama-3.1-8B-Instruct with real pretrained weights.
|
||||
|
||||
@@ -365,7 +362,7 @@ def test_hf_llama3_large_full(device: torch.device):
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
out = compiled(input_ids)
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-5), (
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-4), (
|
||||
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
@@ -420,7 +417,6 @@ def test_dynamic_dim_reuse_no_recompile(device: torch.device):
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
|
||||
def test_hf_llama38b_full(device: torch.device):
|
||||
"""HuggingFace LlamaForCausalLM — full Llama-3.1-8B-Instruct with real pretrained weights.
|
||||
|
||||
@@ -447,7 +443,7 @@ def test_hf_llama38b_full(device: torch.device):
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
out = compiled(input_ids)
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-5), (
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-4), (
|
||||
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
47
crates/luminal_python/tests/test_torch_dtype_parity.py
Normal file
47
crates/luminal_python/tests/test_torch_dtype_parity.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Pin luminal's Rust `TorchDType` enum to PyTorch's PT2 schema.
|
||||
|
||||
The PT2 export pipeline wire-serializes dtypes as `u32` codes drawn from
|
||||
`torch._export.serde.schema.ScalarType`. luminal mirrors that enum in
|
||||
`crates/luminal_python/rust/src/torch_dtype.rs` and depends on the
|
||||
discriminants matching exactly. If PyTorch renumbers, adds, or removes a
|
||||
variant, this test fails loudly at CI time — better than a silent
|
||||
miscompile at runtime.
|
||||
"""
|
||||
|
||||
from torch._export.serde.schema import ScalarType
|
||||
|
||||
# `_torch_dtype_codes` is the pyo3-exported map `{variant_name: pt2_code}`.
|
||||
from luminal.luminal import _torch_dtype_codes
|
||||
|
||||
|
||||
def test_rust_variants_match_pytorch():
|
||||
"""Every Rust variant must agree with PyTorch's code for the same name."""
|
||||
rust = _torch_dtype_codes()
|
||||
pt = {v.name: v.value for v in ScalarType}
|
||||
mismatches = []
|
||||
for name, code in rust.items():
|
||||
if name not in pt:
|
||||
mismatches.append(f"{name}: luminal={code}, pytorch=<missing variant>")
|
||||
elif pt[name] != code:
|
||||
mismatches.append(f"{name}: luminal={code}, pytorch={pt[name]}")
|
||||
assert not mismatches, (
|
||||
"torch_dtype.rs and PyTorch's ScalarType have drifted:\n "
|
||||
+ "\n ".join(mismatches)
|
||||
)
|
||||
|
||||
|
||||
def test_no_pytorch_variants_missing_from_rust():
|
||||
"""Surface new PyTorch variants so we know to extend the Rust enum.
|
||||
|
||||
Failure here doesn't necessarily indicate a bug — it just means
|
||||
PyTorch added a dtype (e.g. a new float8 variant) and luminal should
|
||||
decide whether to mirror it. Update `TorchDType::ALL` in
|
||||
`torch_dtype.rs` plus the `TryFrom` impls to resolve.
|
||||
"""
|
||||
rust = _torch_dtype_codes()
|
||||
missing = [v.name for v in ScalarType if v.name not in rust]
|
||||
assert not missing, (
|
||||
"PyTorch ScalarType variants not mirrored in luminal::TorchDType: "
|
||||
f"{missing}. Extend TorchDType::ALL in torch_dtype.rs and decide "
|
||||
"whether each maps to a luminal DType variant."
|
||||
)
|
||||
@@ -37,7 +37,7 @@ use std::fs::File;
|
||||
use std::io::BufWriter;
|
||||
use std::time::Instant;
|
||||
|
||||
use luminal::graph::BuildSearchSpaceOptions;
|
||||
use luminal::graph::CompileOptions;
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use rand::{Rng, SeedableRng, rngs::StdRng};
|
||||
@@ -53,6 +53,10 @@ fn env_usize(name: &str, default: usize) -> usize {
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
fn search_options() -> CompileOptions {
|
||||
CompileOptions::default().search_graph_limit(env_usize("SEARCH_ITERS", 5))
|
||||
}
|
||||
|
||||
fn env_f32(name: &str, default: f32) -> f32 {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
@@ -159,11 +163,9 @@ fn run_text_encoder(prompt: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>
|
||||
s.parse::<usize>()
|
||||
.map_err(|_| std::env::VarError::NotPresent)
|
||||
}) {
|
||||
cx.build_search_space_with_options::<CudaRuntime>(
|
||||
BuildSearchSpaceOptions::new().max_memory_gib(g),
|
||||
);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default().max_memory_gib(g));
|
||||
} else {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
}
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
@@ -189,7 +191,7 @@ fn run_text_encoder(prompt: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>
|
||||
|
||||
println!("Compiling text encoder...");
|
||||
let t0 = Instant::now();
|
||||
runtime = cx.search(runtime, env_usize("SEARCH_ITERS", 5));
|
||||
runtime = cx.search(runtime, search_options());
|
||||
println!(" compile done in {:.1}s", t0.elapsed().as_secs_f64());
|
||||
|
||||
println!("Encoding prompt...");
|
||||
@@ -301,11 +303,9 @@ fn run_full_pipeline(
|
||||
s.parse::<usize>()
|
||||
.map_err(|_| std::env::VarError::NotPresent)
|
||||
}) {
|
||||
cx.build_search_space_with_options::<CudaRuntime>(
|
||||
BuildSearchSpaceOptions::new().max_memory_gib(g),
|
||||
);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default().max_memory_gib(g));
|
||||
} else {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
}
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
@@ -349,10 +349,9 @@ fn run_full_pipeline(
|
||||
{
|
||||
use rand::SeedableRng;
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(seed);
|
||||
let opts = luminal::graph::SearchOptions::new(env_usize("SEARCH_ITERS", 5));
|
||||
runtime = cx.search_options(runtime, opts, &mut rng);
|
||||
runtime = cx.search_with_rng(runtime, search_options(), &mut rng);
|
||||
} else {
|
||||
runtime = cx.search(runtime, env_usize("SEARCH_ITERS", 5));
|
||||
runtime = cx.search(runtime, search_options());
|
||||
}
|
||||
println!(" compile done in {:.1}s", t0.elapsed().as_secs_f64());
|
||||
|
||||
@@ -409,11 +408,9 @@ fn run_full_pipeline(
|
||||
s.parse::<usize>()
|
||||
.map_err(|_| std::env::VarError::NotPresent)
|
||||
}) {
|
||||
cx.build_search_space_with_options::<CudaRuntime>(
|
||||
BuildSearchSpaceOptions::new().max_memory_gib(g),
|
||||
);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default().max_memory_gib(g));
|
||||
} else {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
}
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
@@ -421,7 +418,7 @@ fn run_full_pipeline(
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
runtime.load_safetensors(&cx, vae_path.to_str().unwrap());
|
||||
runtime.set_data(latent_in, vae_input);
|
||||
runtime = cx.search(runtime, env_usize("SEARCH_ITERS", 5));
|
||||
runtime = cx.search(runtime, search_options());
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let img = runtime.get_f32(out);
|
||||
// VaeDecoder output is in roughly [-1, 1] range. Diffusers'
|
||||
|
||||
@@ -68,20 +68,10 @@ pub const WEIGHT_DTYPE: DType = DType::Bf16;
|
||||
// =============================================================================
|
||||
|
||||
fn linear_no_bias(x: GraphTensor, w: GraphTensor) -> GraphTensor {
|
||||
// Direct mixed-precision kernel: F32 A × BF16 B^T → F32 (M, N), with the
|
||||
// BF16 → F32 conversion happening on each load inside the kernel rather
|
||||
// than as a separate cast op. This keeps the BF16 weight in memory as-is
|
||||
// (a 24 GB → 48 GB cast for the full encoder would not fit on the GPU)
|
||||
// and bypasses the egglog matmul lowering, where the cublaslt 2D rule
|
||||
// doesn't reliably fire for these shapes — see kernel::matmul2d's docs.
|
||||
//
|
||||
// Falls back to the standard `x.matmul(w.cast(x.dtype).t())` lowering
|
||||
// for ranks > 2 (e.g. attention's batched (heads, seq, head_dim) form),
|
||||
// since the custom kernel is only 2D.
|
||||
if x.shape.len() == 2 && w.shape.len() == 2 {
|
||||
luminal_cuda_lite::kernel::linear_no_bias_bf16_w(x, w)
|
||||
if x.dtype == w.dtype {
|
||||
x.matmul(w.t())
|
||||
} else {
|
||||
x.matmul(w.cast(x.dtype).t())
|
||||
x.cast(w.dtype).matmul(w.t()).cast(x.dtype)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,13 +121,6 @@ fn apply_rope(x: GraphTensor, pos_ids: GraphTensor, n_heads: usize, theta: f32)
|
||||
/// Standard scaled dot-product attention over `(n_heads, seq_q, head_dim)`,
|
||||
/// `(n_heads, seq_k, head_dim)`, `(n_heads, seq_k, head_dim)` with a causal
|
||||
/// mask. Returns `(seq_q, n_heads * head_dim)`.
|
||||
///
|
||||
/// Routes the two batched matmuls through `kernel::matmul_3d_t` /
|
||||
/// `matmul_3d` rather than the egglog matmul lowering. The standard path
|
||||
/// has the same problem the VAE attention had (cublaslt batched rules
|
||||
/// fail to fire reliably; the broadcast Mul + SumReduce fallback creates
|
||||
/// a `(n_heads, M, N, K)` intermediate that scales O(seq²) and OOMs at
|
||||
/// seq_len ≥ ~256 even with BF16 weights elsewhere).
|
||||
fn causal_sdpa(
|
||||
q: GraphTensor,
|
||||
k: GraphTensor,
|
||||
@@ -148,13 +131,16 @@ fn causal_sdpa(
|
||||
let n_heads = q.dims()[0];
|
||||
let seq = q.dims()[1];
|
||||
let scale = (HEAD_DIM as f32).sqrt().recip();
|
||||
// The kernel needs contiguous batches; a `* 1.0` after the upstream
|
||||
// transpose / GQA-expand chain materialises the strided view.
|
||||
// Materialize strided views from the upstream transpose / GQA-expand chain
|
||||
// before expressing attention as HLIR matmuls. Today the generic batched
|
||||
// matmul fallback can handle those arbitrary strides correctly, but the
|
||||
// full model becomes too memory-heavy unless cuBLASLt sees contiguous
|
||||
// per-head matrices.
|
||||
let q = q * 1.0_f32;
|
||||
let k = k * 1.0_f32;
|
||||
let v = v * 1.0_f32;
|
||||
// Q @ K^T: (heads, seq, head_dim) @ (heads, seq, head_dim)^T = (heads, seq, seq).
|
||||
let scores = luminal_cuda_lite::kernel::matmul_3d_t(q, k) * scale;
|
||||
let scores = q.matmul(k.transpose(1, 2)) * scale;
|
||||
// Causal mask: positions where k_pos > q_pos are masked.
|
||||
let q_pos = cx.arange(seq).cast(DType::F32);
|
||||
let k_pos = cx.arange(seq).cast(DType::F32);
|
||||
@@ -177,13 +163,9 @@ fn causal_sdpa(
|
||||
let masked = scores + mask * (-1e10_f32);
|
||||
let weights = masked.softmax(2);
|
||||
// attn = weights @ v: (heads, seq, seq) @ (heads, seq, head_dim) = (heads, seq, head_dim).
|
||||
let attn = luminal_cuda_lite::kernel::matmul_3d(weights, v);
|
||||
// `transpose(0, 1).merge_dims(1, 2)` produces the merge_dims
|
||||
// non-contiguous K stride `(((z/HEAD_DIM)*HEAD_DIM)*SEQ)+(z%HEAD_DIM)`.
|
||||
// The cublaslt 2D rule requires `K stride = MIter` (contiguous), so
|
||||
// without forcing materialization here the downstream o_proj matmul
|
||||
// falls through to a broadcast Mul whose `(SEQ, HIDDEN, KV_DIM)`
|
||||
// intermediate is ~20 GB BF16 and OOMs the GPU during search.
|
||||
let attn = weights.matmul(v);
|
||||
// `transpose(0, 1).merge_dims(1, 2)` produces a non-contiguous K stride;
|
||||
// materialize before the downstream o_proj matmul.
|
||||
attn.transpose(0, 1).merge_dims(1, 2) * 1.0_f32 // (seq_q, n_heads*head_dim)
|
||||
}
|
||||
|
||||
@@ -372,8 +354,26 @@ pub fn format_chat(system_message: &str, user_prompt: &str) -> String {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use luminal::hlir::CustomOpKind;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn assert_no_custom_ops(cx: &Graph) {
|
||||
assert!(
|
||||
cx.custom_ops.is_empty(),
|
||||
"Flux2 text encoder helpers should use pure HLIR, not registered CustomOp wrappers"
|
||||
);
|
||||
let custom_nodes: Vec<_> = cx
|
||||
.graph
|
||||
.node_indices()
|
||||
.filter(|&node| cx.try_get_op::<CustomOpKind>(node).is_some())
|
||||
.collect();
|
||||
assert!(
|
||||
custom_nodes.is_empty(),
|
||||
"Flux2 text encoder graph contains CustomOpKind nodes: {custom_nodes:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_template_matches_jinja_output() {
|
||||
// Sanity check: the result is the deterministic concatenation we
|
||||
@@ -395,4 +395,23 @@ mod tests {
|
||||
// hidden_states[30] requires running 30 layers (0..29 inclusive).
|
||||
assert_eq!(NUM_LAYERS_USED, *TAP_LAYERS.iter().max().unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn text_encoder_helpers_use_no_custom_ops() {
|
||||
let mut cx = Graph::default();
|
||||
|
||||
let x = cx.named_tensor("x", (2usize, 3usize));
|
||||
let w = cx
|
||||
.named_tensor("w", (4usize, 3usize))
|
||||
.as_dtype(WEIGHT_DTYPE);
|
||||
let _ = linear_no_bias(x, w).output();
|
||||
|
||||
let q = cx.named_tensor("q", (1usize, 2usize, HEAD_DIM));
|
||||
let k = cx.named_tensor("k", (1usize, 2usize, HEAD_DIM));
|
||||
let v = cx.named_tensor("v", (1usize, 2usize, HEAD_DIM));
|
||||
let mask = cx.named_tensor("attention_mask", 2usize);
|
||||
let _ = causal_sdpa(q, k, v, mask).output();
|
||||
|
||||
assert_no_custom_ops(&cx);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,26 +120,10 @@ pub const WEIGHT_DTYPE: DType = DType::Bf16;
|
||||
// =============================================================================
|
||||
|
||||
fn linear_no_bias(x: GraphTensor, w: GraphTensor) -> GraphTensor {
|
||||
// For 2D inputs we go through `kernel::linear_no_bias_bf16_w`, which
|
||||
// is a direct mixed-precision SGEMM (F32 A × BF16 B^T → F32) that
|
||||
// converts BF16 → F32 on each load instead of materializing a
|
||||
// separate F32 cast tensor. Two reasons we don't use the egglog
|
||||
// matmul lowering for these:
|
||||
// 1. The cublaslt 2D rule fails to fire reliably for some matmul
|
||||
// shapes (see kernel::matmul2d's docs); even one bad genome
|
||||
// pick on the broadcast Mul + SumReduce fallback creates an
|
||||
// `(M, N, K)` intermediate that OOMs the GPU.
|
||||
// 2. Explicitly casting all BF16 weights to F32 first would more
|
||||
// than double the transformer's working set (~120 GB) and
|
||||
// wouldn't fit. The kernel keeps weights as BF16 in memory.
|
||||
//
|
||||
// Higher-rank cases (3D batched matmul inside attention) fall
|
||||
// through to the standard matmul lowering — those go through the
|
||||
// separate `matmul_3d` / `matmul_3d_t` helpers in `sdpa` below.
|
||||
if x.shape.len() == 2 && w.shape.len() == 2 {
|
||||
luminal_cuda_lite::kernel::linear_no_bias_bf16_w(x, w)
|
||||
if x.dtype == w.dtype {
|
||||
x.matmul(w.t())
|
||||
} else {
|
||||
x.matmul(w.cast(x.dtype).t())
|
||||
x.cast(w.dtype).matmul(w.t()).cast(x.dtype)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -191,20 +175,20 @@ fn apply_rope(x: GraphTensor, cos: GraphTensor, sin: GraphTensor) -> GraphTensor
|
||||
|
||||
/// Scaled dot-product attention with NO mask, no causal: standard SDPA.
|
||||
/// q, k, v: `(H, S, D)`. Returns `(S, H, D)`.
|
||||
///
|
||||
/// Routes through the direct batched matmul kernels for the same reason
|
||||
/// the text encoder does — see `text_encoder::causal_sdpa` for context.
|
||||
fn sdpa(q: GraphTensor, k: GraphTensor, v: GraphTensor) -> GraphTensor {
|
||||
let head_dim = q.dims()[2].to_usize().expect("head_dim must be static");
|
||||
let scale = (head_dim as f32).sqrt().recip();
|
||||
// The kernel needs contiguous batches; materialize the strided views
|
||||
// produced upstream (transpose / split_dims chains).
|
||||
// Materialize the strided views produced upstream (transpose /
|
||||
// split_dims chains) before expressing attention as HLIR matmuls. cuBLASLt
|
||||
// can represent the leading dimensions, but the current rewrite rules do
|
||||
// not yet match the interleaved per-head layout, so omitting these copies
|
||||
// falls back to a much larger generic plan in the full Flux2 graph.
|
||||
let q = q * 1.0_f32;
|
||||
let k = k * 1.0_f32;
|
||||
let v = v * 1.0_f32;
|
||||
let scores = luminal_cuda_lite::kernel::matmul_3d_t(q, k) * scale; // (H, S, S)
|
||||
let scores = q.matmul(k.transpose(1, 2)) * scale; // (H, S, S)
|
||||
let attn_w = scores.softmax(2);
|
||||
let attn = luminal_cuda_lite::kernel::matmul_3d(attn_w, v); // (H, S, D)
|
||||
let attn = attn_w.matmul(v); // (H, S, D)
|
||||
attn.transpose(0, 1) // (S, H, D)
|
||||
}
|
||||
|
||||
@@ -518,9 +502,8 @@ impl SingleStreamAttn {
|
||||
let q = q.transpose(0, 1);
|
||||
let k = k.transpose(0, 1);
|
||||
let v = v.transpose(0, 1);
|
||||
// `merge_dims(1, 2)` on (S, H, D) produces non-contiguous K
|
||||
// stride; force materialization so cublaslt can match the
|
||||
// downstream `to_out` matmul. See dual-stream block above.
|
||||
// `merge_dims(1, 2)` on (S, H, D) produces non-contiguous K stride;
|
||||
// materialize before the downstream `to_out` matmul.
|
||||
let attn = sdpa(q, k, v).merge_dims(1, 2) * 1.0_f32; // (S, HIDDEN)
|
||||
|
||||
let mlp = swiglu(mlp_in); // (S, MLP_HIDDEN)
|
||||
@@ -915,3 +898,44 @@ impl Flux2Transformer {
|
||||
// =============================================================================
|
||||
// Tests
|
||||
// =============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use luminal::hlir::CustomOpKind;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn assert_no_custom_ops(cx: &Graph) {
|
||||
assert!(
|
||||
cx.custom_ops.is_empty(),
|
||||
"Flux2 transformer helpers should use pure HLIR, not registered CustomOp wrappers"
|
||||
);
|
||||
let custom_nodes: Vec<_> = cx
|
||||
.graph
|
||||
.node_indices()
|
||||
.filter(|&node| cx.try_get_op::<CustomOpKind>(node).is_some())
|
||||
.collect();
|
||||
assert!(
|
||||
custom_nodes.is_empty(),
|
||||
"Flux2 transformer graph contains CustomOpKind nodes: {custom_nodes:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transformer_helpers_use_no_custom_ops() {
|
||||
let mut cx = Graph::default();
|
||||
|
||||
let x = cx.named_tensor("x", (3usize, 4usize));
|
||||
let w = cx
|
||||
.named_tensor("w", (5usize, 4usize))
|
||||
.as_dtype(WEIGHT_DTYPE);
|
||||
let _ = linear_no_bias(x, w).output();
|
||||
|
||||
let q = cx.named_tensor("q", (2usize, 3usize, 4usize));
|
||||
let k = cx.named_tensor("k", (2usize, 3usize, 4usize));
|
||||
let v = cx.named_tensor("v", (2usize, 3usize, 4usize));
|
||||
let _ = sdpa(q, k, v).output();
|
||||
|
||||
assert_no_custom_ops(&cx);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
//!
|
||||
//! ## Status
|
||||
//!
|
||||
//! - All three primitives (`conv2d_bias`, `group_norm`, `nearest_upsample_2x`)
|
||||
//! - All three building blocks (`conv2d_bias`, `group_norm`, `nearest_upsample_2x`)
|
||||
//! are implemented and **individually validated** against numerical
|
||||
//! references — see the tests at the bottom of this file.
|
||||
//! - Stitching them into the full decoder currently hits a `luminal_cuda_lite`
|
||||
@@ -70,14 +70,9 @@ fn decoder_block_channels(block_idx: usize) -> (usize, usize) {
|
||||
// HLIR primitive helpers
|
||||
// =============================================================================
|
||||
|
||||
/// 2D convolution with bias on a `(C_in, H, W)` input, weights stored as
|
||||
/// `(C_out, C_in, K, K)` flat-loaded, bias as `(C_out,)`. Returns
|
||||
/// 2D convolution with bias on a `(C_in, H, W)` input, weights stored flat as
|
||||
/// `(C_out, C_in * K * K)`, bias as `(C_out,)`. Returns
|
||||
/// `(C_out, H_out, W_out)` where `H_out = (H + 2*padding - kernel) / stride + 1`.
|
||||
///
|
||||
/// Wraps the direct conv kernel from [`luminal_cuda_lite::kernel::conv2d_bias`]
|
||||
/// (one CUDA thread per output element), which avoids materializing the
|
||||
/// `(H_out*W_out, C_in*K*K)` unfold intermediate that earlier HLIR-only
|
||||
/// implementations needed.
|
||||
fn conv2d_bias(
|
||||
x: GraphTensor,
|
||||
weight: GraphTensor,
|
||||
@@ -86,7 +81,58 @@ fn conv2d_bias(
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
) -> GraphTensor {
|
||||
luminal_cuda_lite::kernel::conv2d_bias(x, weight, bias, kernel, stride, padding)
|
||||
let dims = x.dims();
|
||||
assert_eq!(dims.len(), 3, "conv2d_bias expects (C, H, W)");
|
||||
let h = dims[1];
|
||||
let w = dims[2];
|
||||
|
||||
if kernel == 1 && stride == 1 && padding == 0 {
|
||||
let xt = x.permute(&[1, 2, 0]).merge_dims(0, 1); // (H*W, C_in)
|
||||
let out = xt.matmul(weight.t()); // (H*W, C_out)
|
||||
let out = out.split_dims(0, w).permute(&[2, 0, 1]); // (C_out, H, W)
|
||||
return out + bias.expand_dim(1, h).expand_dim(2, w);
|
||||
}
|
||||
|
||||
let zero = Expression::from(0);
|
||||
let pad = Expression::from(padding);
|
||||
let padded = if padding > 0 {
|
||||
x.pad(vec![(zero, zero), (pad, pad), (pad, pad)], 0.0)
|
||||
} else {
|
||||
x
|
||||
};
|
||||
|
||||
let unfolded = padded.unfold(
|
||||
vec![1usize, kernel, kernel],
|
||||
vec![1usize, stride, stride],
|
||||
vec![1usize, 1, 1],
|
||||
);
|
||||
let output_spatial_dims = unfolded.dims()[1..3].to_vec();
|
||||
|
||||
// (C, H_out, W_out, 1, K, K) -> (H_out, W_out, C, K, K)
|
||||
let mut patches = unfolded.squeeze(3).permute(&[1, 2, 0, 3, 4]);
|
||||
while patches.dims().len() > 3 {
|
||||
let last = patches.dims().len();
|
||||
patches = patches.merge_dims(last - 2, last - 1);
|
||||
}
|
||||
let patches = patches.merge_dims(0, 1); // (H_out*W_out, C_in*K*K)
|
||||
|
||||
let out = patches.matmul(weight.t()); // (H_out*W_out, C_out)
|
||||
let out = out
|
||||
.split_dims(0, output_spatial_dims[1])
|
||||
.permute(&[2, 0, 1]); // (C_out, H_out, W_out)
|
||||
let out_dims = out.dims();
|
||||
out + bias.expand_dim(1, out_dims[1]).expand_dim(2, out_dims[2])
|
||||
}
|
||||
|
||||
fn linear_bias(x: GraphTensor, weight: GraphTensor, bias: GraphTensor) -> GraphTensor {
|
||||
let out = x.matmul(weight.cast(x.dtype).t());
|
||||
let out_dims = out.dims();
|
||||
match out_dims.len() {
|
||||
1 => out + bias,
|
||||
2 => out + bias.expand_dim(0, out_dims[0]),
|
||||
3 => out + bias.expand_dim(0, out_dims[0]).expand_dim(1, out_dims[1]),
|
||||
n => panic!("linear_bias: unsupported rank {n}"),
|
||||
}
|
||||
}
|
||||
|
||||
/// PyTorch-style GroupNorm on a (C, H, W) tensor.
|
||||
@@ -148,9 +194,7 @@ fn group_norm(
|
||||
fn nearest_upsample_2x(x: GraphTensor) -> GraphTensor {
|
||||
// (C, H, W) -> (C, H, 2, W) -> (C, 2H, W) -> (C, 2H, W, 2) -> (C, 2H, 2W)
|
||||
let stage1 = x.expand_dim(2, 2_usize).merge_dims(1, 2);
|
||||
let stage2 = stage1.expand_dim(3, 2_usize).merge_dims(2, 3);
|
||||
// Materialize the broadcast view so subsequent ops see contiguous strides.
|
||||
stage2 + 0.0_f32
|
||||
stage1.expand_dim(3, 2_usize).merge_dims(2, 3)
|
||||
}
|
||||
|
||||
/// SiLU = x * sigmoid(x).
|
||||
@@ -300,30 +344,21 @@ impl AttnBlock {
|
||||
NORM_NUM_GROUPS,
|
||||
NORM_EPS,
|
||||
);
|
||||
// (C, H, W) -> (C, H*W) -> (H*W, C). The transpose at the end leaves
|
||||
// a column-major view, which the direct matmul kernels assume away;
|
||||
// `* 1.0` forces a contiguous row-major materialization.
|
||||
let merged = normed.merge_dims(1, 2).transpose(0, 1) * 1.0_f32;
|
||||
// (C, H, W) -> (C, H*W) -> (H*W, C). This is a column-major view
|
||||
// that cuBLASLt can consume directly.
|
||||
let merged = normed.merge_dims(1, 2).transpose(0, 1);
|
||||
|
||||
// Q, K, V projections — direct kernel routes around the cublaslt
|
||||
// 2D rule, which silently fails to fire for some of these matmuls
|
||||
// and lets search occasionally pick the broadcast Mul + SumReduce
|
||||
// fallback. At 1024² the bad path on `q @ kᵀ` allocates a
|
||||
// `(HW, HW, C) = (16384, 16384, 512)` ≈ 524 GiB intermediate.
|
||||
let q = luminal_cuda_lite::kernel::linear_bias(merged, self.to_q_w, self.to_q_b);
|
||||
let k = luminal_cuda_lite::kernel::linear_bias(merged, self.to_k_w, self.to_k_b);
|
||||
let v = luminal_cuda_lite::kernel::linear_bias(merged, self.to_v_w, self.to_v_b);
|
||||
let q = linear_bias(merged, self.to_q_w, self.to_q_b);
|
||||
let k = linear_bias(merged, self.to_k_w, self.to_k_b);
|
||||
let v = linear_bias(merged, self.to_v_w, self.to_v_b);
|
||||
|
||||
// Standard scaled dot-product attention over the spatial axis.
|
||||
// `q @ kᵀ` with k stored row-major as `(HW, C)`: matmul_2d_t handles
|
||||
// the transpose without materialising k as a separate tensor.
|
||||
let scale = (self.channels as f32).sqrt().recip();
|
||||
let scores = luminal_cuda_lite::kernel::matmul_2d_t(q, k) * scale;
|
||||
let scores = q.matmul(k.t()) * scale;
|
||||
let attn_w = scores.softmax(1);
|
||||
// attn_w is (HW, HW) row-major, v is (HW, C) row-major; plain matmul.
|
||||
let attn = luminal_cuda_lite::kernel::matmul_2d(attn_w, v);
|
||||
let attn = attn_w.matmul(v);
|
||||
|
||||
let out = luminal_cuda_lite::kernel::linear_bias(attn, self.to_out_w, self.to_out_b);
|
||||
let out = linear_bias(attn, self.to_out_w, self.to_out_b);
|
||||
// (H*W, C) -> (C, H*W) -> (C, H, W)
|
||||
let out = out.transpose(0, 1).split_dims(1, w);
|
||||
residual + out
|
||||
@@ -500,3 +535,342 @@ impl VaeDecoder {
|
||||
conv2d_bias(x, self.conv_out_w, self.conv_out_b, 3, 1, 1)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use luminal::hlir::CustomOpKind;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
|
||||
use super::*;
|
||||
|
||||
fn assert_no_custom_ops(cx: &Graph) {
|
||||
assert!(
|
||||
cx.custom_ops.is_empty(),
|
||||
"Flux2 VAE helpers should use pure HLIR, not registered CustomOp wrappers"
|
||||
);
|
||||
let custom_nodes: Vec<_> = cx
|
||||
.graph
|
||||
.node_indices()
|
||||
.filter(|&node| cx.try_get_op::<CustomOpKind>(node).is_some())
|
||||
.collect();
|
||||
assert!(
|
||||
custom_nodes.is_empty(),
|
||||
"Flux2 VAE graph contains CustomOpKind nodes: {custom_nodes:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vae_helpers_use_no_custom_ops() {
|
||||
let mut cx = Graph::default();
|
||||
|
||||
let x = cx.named_tensor("x", (2usize, 3usize, 3usize));
|
||||
let conv_w = cx.named_tensor("conv_w", (4usize, 2usize * 3 * 3));
|
||||
let conv_b = cx.named_tensor("conv_b", 4usize);
|
||||
let _ = conv2d_bias(x, conv_w, conv_b, 3, 1, 1).output();
|
||||
|
||||
let lin_x = cx.named_tensor("lin_x", (2usize, 3usize));
|
||||
let lin_w = cx.named_tensor("lin_w", (4usize, 3usize));
|
||||
let lin_b = cx.named_tensor("lin_b", 4usize);
|
||||
let _ = linear_bias(lin_x, lin_w, lin_b).output();
|
||||
|
||||
assert_no_custom_ops(&cx);
|
||||
}
|
||||
|
||||
struct Conv2dCase {
|
||||
c_in: usize,
|
||||
h: usize,
|
||||
w: usize,
|
||||
c_out: usize,
|
||||
kernel: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
}
|
||||
|
||||
fn reference_conv2d_bias(
|
||||
input: &[f32],
|
||||
weight: &[f32],
|
||||
bias: &[f32],
|
||||
case: Conv2dCase,
|
||||
) -> Vec<f32> {
|
||||
let Conv2dCase {
|
||||
c_in,
|
||||
h,
|
||||
w,
|
||||
c_out,
|
||||
kernel,
|
||||
stride,
|
||||
padding,
|
||||
} = case;
|
||||
let h_out = (h + 2 * padding - kernel) / stride + 1;
|
||||
let w_out = (w + 2 * padding - kernel) / stride + 1;
|
||||
let mut out = vec![0.0_f32; c_out * h_out * w_out];
|
||||
for co in 0..c_out {
|
||||
for oy in 0..h_out {
|
||||
for ox in 0..w_out {
|
||||
let mut acc = bias[co];
|
||||
for ci in 0..c_in {
|
||||
for ky in 0..kernel {
|
||||
for kx in 0..kernel {
|
||||
let iy_padded = oy * stride + ky;
|
||||
let ix_padded = ox * stride + kx;
|
||||
if iy_padded < padding || ix_padded < padding {
|
||||
continue;
|
||||
}
|
||||
let iy = iy_padded - padding;
|
||||
let ix = ix_padded - padding;
|
||||
if iy >= h || ix >= w {
|
||||
continue;
|
||||
}
|
||||
let input_idx = ci * h * w + iy * w + ix;
|
||||
let weight_idx = co * c_in * kernel * kernel
|
||||
+ ci * kernel * kernel
|
||||
+ ky * kernel
|
||||
+ kx;
|
||||
acc += input[input_idx] * weight[weight_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
out[co * h_out * w_out + oy * w_out + ox] = acc;
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn assert_close(actual: &[f32], expected: &[f32]) {
|
||||
assert_eq!(actual.len(), expected.len());
|
||||
for (idx, (a, e)) in actual.iter().zip(expected).enumerate() {
|
||||
assert!(
|
||||
(*a - *e).abs() < 1e-4,
|
||||
"value mismatch at {idx}: got {a}, expected {e}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn reference_nearest_upsample_2x(input: &[f32], c: usize, h: usize, w: usize) -> Vec<f32> {
|
||||
let mut out = vec![0.0_f32; c * h * 2 * w * 2];
|
||||
for ci in 0..c {
|
||||
for y in 0..h {
|
||||
for x in 0..w {
|
||||
let value = input[ci * h * w + y * w + x];
|
||||
for dy in 0..2 {
|
||||
for dx in 0..2 {
|
||||
let oy = y * 2 + dy;
|
||||
let ox = x * 2 + dx;
|
||||
out[ci * h * 2 * w * 2 + oy * w * 2 + ox] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
struct GroupNormCase {
|
||||
c: usize,
|
||||
h: usize,
|
||||
w: usize,
|
||||
num_groups: usize,
|
||||
eps: f32,
|
||||
}
|
||||
|
||||
fn reference_group_norm(
|
||||
input: &[f32],
|
||||
weight: &[f32],
|
||||
bias: &[f32],
|
||||
case: GroupNormCase,
|
||||
) -> Vec<f32> {
|
||||
let GroupNormCase {
|
||||
c,
|
||||
h,
|
||||
w,
|
||||
num_groups,
|
||||
eps,
|
||||
} = case;
|
||||
let group_size = c / num_groups;
|
||||
let group_volume = group_size * h * w;
|
||||
let mut out = vec![0.0_f32; input.len()];
|
||||
for group in 0..num_groups {
|
||||
let c_start = group * group_size;
|
||||
let mut mean = 0.0_f32;
|
||||
for ci in c_start..c_start + group_size {
|
||||
for idx in 0..h * w {
|
||||
mean += input[ci * h * w + idx];
|
||||
}
|
||||
}
|
||||
mean /= group_volume as f32;
|
||||
|
||||
let mut variance = 0.0_f32;
|
||||
for ci in c_start..c_start + group_size {
|
||||
for idx in 0..h * w {
|
||||
let centered = input[ci * h * w + idx] - mean;
|
||||
variance += centered * centered;
|
||||
}
|
||||
}
|
||||
variance /= group_volume as f32;
|
||||
let inv_std = (variance + eps).sqrt().recip();
|
||||
|
||||
for ci in c_start..c_start + group_size {
|
||||
for idx in 0..h * w {
|
||||
let flat = ci * h * w + idx;
|
||||
out[flat] = (input[flat] - mean) * inv_std * weight[ci] + bias[ci];
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn one_search() -> CompileOptions {
|
||||
CompileOptions::default().search_graph_limit(1)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv2d_bias_matches_reference() {
|
||||
let mut cx = Graph::default();
|
||||
let input_t = cx.named_tensor("input", (2usize, 3usize, 3usize));
|
||||
let weight_t = cx.named_tensor("weight", (2usize, 2usize * 3 * 3));
|
||||
let bias_t = cx.named_tensor("bias", 2usize);
|
||||
let out = conv2d_bias(input_t, weight_t, bias_t, 3, 1, 1).output();
|
||||
|
||||
let input: Vec<f32> = (0..18).map(|i| i as f32 * 0.1 - 0.7).collect();
|
||||
let weight: Vec<f32> = (0..36).map(|i| (i as f32 % 7.0) * 0.05 - 0.15).collect();
|
||||
let bias = vec![0.25_f32, -0.5_f32];
|
||||
let expected = reference_conv2d_bias(
|
||||
&input,
|
||||
&weight,
|
||||
&bias,
|
||||
Conv2dCase {
|
||||
c_in: 2,
|
||||
h: 3,
|
||||
w: 3,
|
||||
c_out: 2,
|
||||
kernel: 3,
|
||||
stride: 1,
|
||||
padding: 1,
|
||||
},
|
||||
);
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), one_search());
|
||||
rt.set_data(input_t, input);
|
||||
rt.set_data(weight_t, weight);
|
||||
rt.set_data(bias_t, bias);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(rt.get_f32(out.id), &expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nearest_upsample_2x_matches_reference_native() {
|
||||
let mut cx = Graph::default();
|
||||
let input_t = cx.named_tensor("input", (2usize, 3usize, 4usize));
|
||||
let out = nearest_upsample_2x(input_t).output();
|
||||
|
||||
let input: Vec<f32> = (0..2 * 3 * 4).map(|i| i as f32 - 11.0).collect();
|
||||
let expected = reference_nearest_upsample_2x(&input, 2, 3, 4);
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), one_search());
|
||||
rt.set_data(input_t, input);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(rt.get_f32(out.id), &expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nearest_upsample_2x_matches_reference_cuda() {
|
||||
let Ok(ctx) = CudaContext::new(0) else {
|
||||
return;
|
||||
};
|
||||
ctx.bind_to_thread().unwrap();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input_t = cx.named_tensor("input", (2usize, 3usize, 4usize));
|
||||
let out = nearest_upsample_2x(input_t).output();
|
||||
|
||||
let input: Vec<f32> = (0..2 * 3 * 4).map(|i| i as f32 - 11.0).collect();
|
||||
let expected = reference_nearest_upsample_2x(&input, 2, 3, 4);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(ctx.default_stream());
|
||||
rt.set_data(input_t, input);
|
||||
rt = cx.search(rt, one_search());
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn group_norm_matches_reference_native() {
|
||||
let mut cx = Graph::default();
|
||||
let input_t = cx.named_tensor("input", (4usize, 2usize, 3usize));
|
||||
let weight_t = cx.named_tensor("weight", 4usize);
|
||||
let bias_t = cx.named_tensor("bias", 4usize);
|
||||
let out = group_norm(input_t, weight_t, bias_t, 2, 1e-6).output();
|
||||
|
||||
let input: Vec<f32> = (0..4 * 2 * 3).map(|i| i as f32 * 0.2 - 2.0).collect();
|
||||
let weight = vec![0.7_f32, -0.2, 1.3, 0.5];
|
||||
let bias = vec![0.1_f32, -0.3, 0.4, -0.6];
|
||||
let expected = reference_group_norm(
|
||||
&input,
|
||||
&weight,
|
||||
&bias,
|
||||
GroupNormCase {
|
||||
c: 4,
|
||||
h: 2,
|
||||
w: 3,
|
||||
num_groups: 2,
|
||||
eps: 1e-6,
|
||||
},
|
||||
);
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), one_search());
|
||||
rt.set_data(input_t, input);
|
||||
rt.set_data(weight_t, weight);
|
||||
rt.set_data(bias_t, bias);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(rt.get_f32(out.id), &expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn group_norm_matches_reference_cuda() {
|
||||
let Ok(ctx) = CudaContext::new(0) else {
|
||||
return;
|
||||
};
|
||||
ctx.bind_to_thread().unwrap();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input_t = cx.named_tensor("input", (4usize, 2usize, 3usize));
|
||||
let weight_t = cx.named_tensor("weight", 4usize);
|
||||
let bias_t = cx.named_tensor("bias", 4usize);
|
||||
let out = group_norm(input_t, weight_t, bias_t, 2, 1e-6).output();
|
||||
|
||||
let input: Vec<f32> = (0..4 * 2 * 3).map(|i| i as f32 * 0.2 - 2.0).collect();
|
||||
let weight = vec![0.7_f32, -0.2, 1.3, 0.5];
|
||||
let bias = vec![0.1_f32, -0.3, 0.4, -0.6];
|
||||
let expected = reference_group_norm(
|
||||
&input,
|
||||
&weight,
|
||||
&bias,
|
||||
GroupNormCase {
|
||||
c: 4,
|
||||
h: 2,
|
||||
w: 3,
|
||||
num_groups: 2,
|
||||
eps: 1e-6,
|
||||
},
|
||||
);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(ctx.default_stream());
|
||||
rt.set_data(input_t, input);
|
||||
rt.set_data(weight_t, weight);
|
||||
rt.set_data(bias_t, bias);
|
||||
rt = cx.search(rt, one_search());
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,9 +53,20 @@ fn main() {
|
||||
k_out.output();
|
||||
v_out.output();
|
||||
}
|
||||
let max_prefill = (prompt_tokens.len() + 16)
|
||||
.next_power_of_two()
|
||||
.min(max_seq_len);
|
||||
let search_s = 16.min(max_prefill).max(2);
|
||||
let build_options = CompileOptions::default().dim_buckets(
|
||||
's',
|
||||
&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, max_prefill).representative(search_s),
|
||||
],
|
||||
);
|
||||
|
||||
println!("Building E-Graph...");
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(build_options);
|
||||
|
||||
println!("Loading weights...");
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
@@ -69,22 +80,12 @@ fn main() {
|
||||
}
|
||||
|
||||
println!("Compiling...");
|
||||
let max_prefill = (prompt_tokens.len() + 16)
|
||||
.next_power_of_two()
|
||||
.min(max_seq_len);
|
||||
let search_s = 16.min(max_prefill).max(2);
|
||||
cx.set_dim_buckets(
|
||||
's',
|
||||
&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, max_prefill).representative(search_s),
|
||||
],
|
||||
);
|
||||
cx.set_dim('s', search_s);
|
||||
cx.set_dim('p', 0);
|
||||
runtime.set_data(input, vec![1; search_s]);
|
||||
runtime.set_data(token_ids, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
|
||||
runtime = cx.search(runtime, search_options);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
|
||||
@@ -36,14 +36,16 @@ impl KVCache {
|
||||
let mut k_caches = Vec::with_capacity(LAYERS);
|
||||
let mut v_caches = Vec::with_capacity(LAYERS);
|
||||
for l in 0..LAYERS {
|
||||
let k = cx
|
||||
.named_tensor(format!("kv_cache.{l}.k"), (N_KV_HEADS, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
let v = cx
|
||||
.named_tensor(format!("kv_cache.{l}.v"), (N_KV_HEADS, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
k_caches.push(k);
|
||||
v_caches.push(v);
|
||||
k_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{l}.k"),
|
||||
(N_KV_HEADS, max_seq, HEAD_DIM),
|
||||
));
|
||||
v_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{l}.v"),
|
||||
(N_KV_HEADS, max_seq, HEAD_DIM),
|
||||
));
|
||||
}
|
||||
Self {
|
||||
k_caches,
|
||||
@@ -68,114 +70,11 @@ pub struct Gemma {
|
||||
|
||||
impl Gemma {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
let mut w = vec![];
|
||||
for l in 0..LAYERS {
|
||||
let is_local = (l + 1) % SLIDING_WINDOW_PATTERN != 0;
|
||||
let up = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.up_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let gate = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.gate_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let down = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.down_proj.weight"),
|
||||
(HIDDEN, INTERMEDIATE),
|
||||
)
|
||||
.persist();
|
||||
let q_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_proj.weight"),
|
||||
(Q_DIM, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let k_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let v_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.v_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let o_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, Q_DIM),
|
||||
)
|
||||
.persist();
|
||||
let q_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_norm.weight"),
|
||||
HEAD_DIM,
|
||||
)
|
||||
.persist();
|
||||
let k_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_norm.weight"),
|
||||
HEAD_DIM,
|
||||
)
|
||||
.persist();
|
||||
w.push(GemmaLayer {
|
||||
up,
|
||||
gate,
|
||||
down,
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
q_norm,
|
||||
k_norm,
|
||||
input_layernorm: gemma_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{l}.input_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_attention_layernorm: gemma_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{l}.post_attention_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
pre_feedforward_layernorm: gemma_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{l}.pre_feedforward_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm: gemma_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{l}.post_feedforward_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
is_local,
|
||||
rope_theta: if is_local {
|
||||
ROPE_THETA_LOCAL
|
||||
} else {
|
||||
ROPE_THETA_GLOBAL
|
||||
},
|
||||
rope_scaling_factor: if is_local { 1.0 } else { 8.0 },
|
||||
});
|
||||
}
|
||||
let lm_norm = gemma_norm(HIDDEN, "model.norm.weight", cx);
|
||||
let embedding = cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
let lm_head = cx
|
||||
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
Self {
|
||||
embedding,
|
||||
lm_head,
|
||||
layers: w,
|
||||
lm_norm,
|
||||
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
layers: (0..LAYERS).map(|l| GemmaLayer::init(cx, l)).collect(),
|
||||
lm_norm: gemma_norm(HIDDEN, "model.norm.weight", cx),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,11 +84,7 @@ impl Gemma {
|
||||
pos_ids: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = token_ids.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
let mut x = token_embedding(self.embedding, token_ids);
|
||||
let mut cache_outputs = Vec::with_capacity(LAYERS);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
@@ -226,6 +121,114 @@ struct GemmaLayer {
|
||||
rope_scaling_factor: f32,
|
||||
}
|
||||
|
||||
impl GemmaLayer {
|
||||
fn init(cx: &mut Graph, l: usize) -> Self {
|
||||
let is_local = !(l + 1).is_multiple_of(SLIDING_WINDOW_PATTERN);
|
||||
Self {
|
||||
up: layer_weight(cx, l, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
|
||||
gate: layer_weight(cx, l, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
|
||||
down: layer_weight(cx, l, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
|
||||
q_proj: layer_weight(cx, l, "self_attn.q_proj", (Q_DIM, HIDDEN)),
|
||||
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
|
||||
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
|
||||
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, Q_DIM)),
|
||||
q_norm: layer_weight(cx, l, "self_attn.q_norm", HEAD_DIM),
|
||||
k_norm: layer_weight(cx, l, "self_attn.k_norm", HEAD_DIM),
|
||||
input_layernorm: layer_norm(cx, l, "input_layernorm"),
|
||||
post_attention_layernorm: layer_norm(cx, l, "post_attention_layernorm"),
|
||||
pre_feedforward_layernorm: layer_norm(cx, l, "pre_feedforward_layernorm"),
|
||||
post_feedforward_layernorm: layer_norm(cx, l, "post_feedforward_layernorm"),
|
||||
is_local,
|
||||
rope_theta: if is_local {
|
||||
ROPE_THETA_LOCAL
|
||||
} else {
|
||||
ROPE_THETA_GLOBAL
|
||||
},
|
||||
rope_scaling_factor: if is_local { 1.0 } else { 8.0 },
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.input_layernorm.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
|
||||
let q_rope = gemma_rotary_embeddings(
|
||||
qk_norm(q, self.q_norm, N_HEADS),
|
||||
pos_ids,
|
||||
N_HEADS,
|
||||
self.rope_theta,
|
||||
self.rope_scaling_factor,
|
||||
);
|
||||
let k_rope = gemma_rotary_embeddings(
|
||||
qk_norm(k, self.k_norm, N_KV_HEADS),
|
||||
pos_ids,
|
||||
N_KV_HEADS,
|
||||
self.rope_theta,
|
||||
self.rope_scaling_factor,
|
||||
);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache_in,
|
||||
v_cache_in,
|
||||
max_seq,
|
||||
self.is_local,
|
||||
);
|
||||
|
||||
let attn_proj = attn_out.matmul(self.o_proj.t());
|
||||
let x = x + self.post_attention_layernorm.forward(attn_proj);
|
||||
|
||||
let x_ff = self.pre_feedforward_layernorm.forward(x);
|
||||
let mlp_out = (gemma_gelu(x_ff.matmul(self.gate.t())) * x_ff.matmul(self.up.t()))
|
||||
.matmul(self.down.t());
|
||||
(
|
||||
x + self.post_feedforward_layernorm.forward(mlp_out),
|
||||
k_cache_out,
|
||||
v_cache_out,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn persist(
|
||||
cx: &mut Graph,
|
||||
name: impl ToString,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
cx.named_tensor(name, shape).persist()
|
||||
}
|
||||
|
||||
fn layer_weight(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
persist(cx, format!("model.layers.{layer}.{suffix}.weight"), shape)
|
||||
}
|
||||
|
||||
fn layer_norm(cx: &mut Graph, layer: usize, name: &str) -> LayerNorm {
|
||||
gemma_norm(HIDDEN, &format!("model.layers.{layer}.{name}.weight"), cx)
|
||||
}
|
||||
|
||||
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
|
||||
let seq = token_ids.dims1();
|
||||
embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
)
|
||||
}
|
||||
|
||||
/// GELU using the identity: 0.5*x*(1+tanh(a)) = x*sigmoid(2*a)
|
||||
/// This produces far fewer e-graph nodes than the tanh-based expansion.
|
||||
#[allow(clippy::excessive_precision)]
|
||||
@@ -363,59 +366,3 @@ fn hlir_attention(
|
||||
|
||||
(out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl GemmaLayer {
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.input_layernorm.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
|
||||
// QK-norm + RoPE
|
||||
let q_normed = qk_norm(q, self.q_norm, N_HEADS);
|
||||
let k_normed = qk_norm(k, self.k_norm, N_KV_HEADS);
|
||||
let q_rope = gemma_rotary_embeddings(
|
||||
q_normed,
|
||||
pos_ids,
|
||||
N_HEADS,
|
||||
self.rope_theta,
|
||||
self.rope_scaling_factor,
|
||||
);
|
||||
let k_rope = gemma_rotary_embeddings(
|
||||
k_normed,
|
||||
pos_ids,
|
||||
N_KV_HEADS,
|
||||
self.rope_theta,
|
||||
self.rope_scaling_factor,
|
||||
);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache_in,
|
||||
v_cache_in,
|
||||
max_seq,
|
||||
self.is_local,
|
||||
);
|
||||
|
||||
// O projection + post-attention norm + residual
|
||||
let attn_proj = attn_out.matmul(self.o_proj.t());
|
||||
let attn_normed = self.post_attention_layernorm.forward(attn_proj);
|
||||
let x = x + attn_normed;
|
||||
|
||||
// Pre-feedforward norm + MLP + post-feedforward norm + residual
|
||||
let x_ff = self.pre_feedforward_layernorm.forward(x);
|
||||
let mlp_out = (gemma_gelu(x_ff.matmul(self.gate.t())) * x_ff.matmul(self.up.t()))
|
||||
.matmul(self.down.t());
|
||||
let mlp_normed = self.post_feedforward_layernorm.forward(mlp_out);
|
||||
(x + mlp_normed, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
tokenizers = "0.22.2"
|
||||
rustc-hash = "2"
|
||||
rand = "0.9.2"
|
||||
|
||||
# HuggingFace model download
|
||||
hf-hub = { version = "0.4", default-features = false, features = ["rustls-tls", "ureq"] }
|
||||
|
||||
@@ -5,11 +5,13 @@ use hf::prepare_hf_model;
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use model::*;
|
||||
use rand::{SeedableRng, rngs::SmallRng};
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const REPO_ID: &str = "google/gemma-4-26B-A4B";
|
||||
const SEARCH_SEED: u64 = 0;
|
||||
|
||||
fn env_bool(name: &str) -> bool {
|
||||
std::env::var(name)
|
||||
@@ -47,9 +49,20 @@ fn main() {
|
||||
k_out.output();
|
||||
v_out.output();
|
||||
}
|
||||
let max_prefill = (prompt_tokens.len() + 16)
|
||||
.next_power_of_two()
|
||||
.min(max_seq_len);
|
||||
let search_s = 16.min(max_prefill).max(2);
|
||||
let build_options = CompileOptions::default().dim_buckets(
|
||||
's',
|
||||
&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, max_prefill).representative(search_s),
|
||||
],
|
||||
);
|
||||
|
||||
println!("Building E-Graph...");
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(build_options);
|
||||
|
||||
println!("Loading weights...");
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
@@ -63,22 +76,15 @@ fn main() {
|
||||
}
|
||||
|
||||
println!("Compiling...");
|
||||
let max_prefill = (prompt_tokens.len() + 16)
|
||||
.next_power_of_two()
|
||||
.min(max_seq_len);
|
||||
let search_s = 16.min(max_prefill).max(2);
|
||||
cx.set_dim_buckets(
|
||||
's',
|
||||
&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, max_prefill).representative(search_s),
|
||||
],
|
||||
);
|
||||
cx.set_dim('s', search_s);
|
||||
cx.set_dim('p', 0);
|
||||
runtime.set_data(input, vec![1; search_s]);
|
||||
runtime.set_data(pos_ids, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
let mut rng = SmallRng::seed_from_u64(SEARCH_SEED);
|
||||
let search_options = CompileOptions::default()
|
||||
.search_graph_limit(search_graphs)
|
||||
.profile_timeout(Duration::from_secs(2));
|
||||
runtime = cx.search_with_rng(runtime, search_options, &mut rng);
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
|
||||
@@ -83,20 +83,16 @@ impl KVCache {
|
||||
let mut v_caches = Vec::with_capacity(LAYERS);
|
||||
for layer in 0..LAYERS {
|
||||
let spec = layer_spec(layer);
|
||||
let k = cx
|
||||
.named_tensor(
|
||||
format!("kv_cache.{layer}.k"),
|
||||
(spec.num_kv_heads, max_seq, spec.head_dim),
|
||||
)
|
||||
.persist();
|
||||
let v = cx
|
||||
.named_tensor(
|
||||
format!("kv_cache.{layer}.v"),
|
||||
(spec.num_kv_heads, max_seq, spec.head_dim),
|
||||
)
|
||||
.persist();
|
||||
k_caches.push(k);
|
||||
v_caches.push(v);
|
||||
k_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{layer}.k"),
|
||||
(spec.num_kv_heads, max_seq, spec.head_dim),
|
||||
));
|
||||
v_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{layer}.v"),
|
||||
(spec.num_kv_heads, max_seq, spec.head_dim),
|
||||
));
|
||||
}
|
||||
Self {
|
||||
k_caches,
|
||||
@@ -115,169 +111,13 @@ pub struct Gemma4MoE {
|
||||
|
||||
impl Gemma4MoE {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
let mut layers = Vec::with_capacity(LAYERS);
|
||||
for layer in 0..LAYERS {
|
||||
let spec = layer_spec(layer);
|
||||
let gate = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.mlp.gate_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let up = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.mlp.up_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let down = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.mlp.down_proj.weight"),
|
||||
(HIDDEN, INTERMEDIATE),
|
||||
)
|
||||
.persist();
|
||||
|
||||
let q_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.q_proj.weight"),
|
||||
(spec.q_dim, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let k_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.k_proj.weight"),
|
||||
(spec.kv_dim, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let v_proj = spec.has_v_proj.then(|| {
|
||||
cx.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.v_proj.weight"),
|
||||
(spec.kv_dim, HIDDEN),
|
||||
)
|
||||
.persist()
|
||||
});
|
||||
let o_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, spec.q_dim),
|
||||
)
|
||||
.persist();
|
||||
let q_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.q_norm.weight"),
|
||||
spec.head_dim,
|
||||
)
|
||||
.persist();
|
||||
let k_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.k_norm.weight"),
|
||||
spec.head_dim,
|
||||
)
|
||||
.persist();
|
||||
let layer_scalar = cx
|
||||
.named_tensor(format!("model.layers.{layer}.layer_scalar"), HIDDEN)
|
||||
.persist();
|
||||
|
||||
let router_scale = cx
|
||||
.named_tensor(format!("model.layers.{layer}.router.scale"), HIDDEN)
|
||||
.persist();
|
||||
let router_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.router.proj.weight"),
|
||||
(NUM_EXPERTS, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let per_expert_scale = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.router.per_expert_scale"),
|
||||
NUM_EXPERTS,
|
||||
)
|
||||
.persist();
|
||||
let gate_up_weights = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.experts.gate_up_proj"),
|
||||
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
|
||||
)
|
||||
.persist()
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.experts.down_proj"),
|
||||
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
|
||||
)
|
||||
.persist()
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
layers.push(Gemma4Layer {
|
||||
spec,
|
||||
gate,
|
||||
up,
|
||||
down,
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
q_norm,
|
||||
k_norm,
|
||||
layer_scalar,
|
||||
input_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.input_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_attention_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_attention_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
pre_feedforward_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.pre_feedforward_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_feedforward_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm_1: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_feedforward_layernorm_1.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm_2: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_feedforward_layernorm_2.weight"),
|
||||
cx,
|
||||
),
|
||||
pre_feedforward_layernorm_2: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.pre_feedforward_layernorm_2.weight"),
|
||||
cx,
|
||||
),
|
||||
moe: Gemma4SparseMoE {
|
||||
router_scale,
|
||||
router_proj,
|
||||
per_expert_scale,
|
||||
gate_up_weights,
|
||||
down_weights,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
let embedding = cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
let lm_head = cx
|
||||
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
let lm_norm = gemma4_norm(HIDDEN, "model.norm.weight", cx);
|
||||
|
||||
Self {
|
||||
embedding,
|
||||
lm_head,
|
||||
layers,
|
||||
lm_norm,
|
||||
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
layers: (0..LAYERS)
|
||||
.map(|layer| Gemma4Layer::init(cx, layer))
|
||||
.collect(),
|
||||
lm_norm: gemma4_norm(HIDDEN, "model.norm.weight", cx),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -287,11 +127,7 @@ impl Gemma4MoE {
|
||||
pos_ids: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = token_ids.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
let mut x = token_embedding(self.embedding, token_ids);
|
||||
|
||||
let mut cache_outputs = Vec::with_capacity(LAYERS);
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
@@ -342,6 +178,164 @@ struct Gemma4SparseMoE {
|
||||
down_weights: GraphTensor,
|
||||
}
|
||||
|
||||
impl Gemma4Layer {
|
||||
fn init(cx: &mut Graph, layer: usize) -> Self {
|
||||
let spec = layer_spec(layer);
|
||||
Self {
|
||||
spec,
|
||||
gate: layer_weight(cx, layer, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
|
||||
up: layer_weight(cx, layer, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
|
||||
down: layer_weight(cx, layer, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
|
||||
q_proj: layer_weight(cx, layer, "self_attn.q_proj", (spec.q_dim, HIDDEN)),
|
||||
k_proj: layer_weight(cx, layer, "self_attn.k_proj", (spec.kv_dim, HIDDEN)),
|
||||
v_proj: spec
|
||||
.has_v_proj
|
||||
.then(|| layer_weight(cx, layer, "self_attn.v_proj", (spec.kv_dim, HIDDEN))),
|
||||
o_proj: layer_weight(cx, layer, "self_attn.o_proj", (HIDDEN, spec.q_dim)),
|
||||
q_norm: layer_weight(cx, layer, "self_attn.q_norm", spec.head_dim),
|
||||
k_norm: layer_weight(cx, layer, "self_attn.k_norm", spec.head_dim),
|
||||
layer_scalar: layer_tensor(cx, layer, "layer_scalar", HIDDEN),
|
||||
input_layernorm: layer_norm(cx, layer, "input_layernorm"),
|
||||
post_attention_layernorm: layer_norm(cx, layer, "post_attention_layernorm"),
|
||||
pre_feedforward_layernorm: layer_norm(cx, layer, "pre_feedforward_layernorm"),
|
||||
post_feedforward_layernorm: layer_norm(cx, layer, "post_feedforward_layernorm"),
|
||||
post_feedforward_layernorm_1: layer_norm(cx, layer, "post_feedforward_layernorm_1"),
|
||||
post_feedforward_layernorm_2: layer_norm(cx, layer, "post_feedforward_layernorm_2"),
|
||||
pre_feedforward_layernorm_2: layer_norm(cx, layer, "pre_feedforward_layernorm_2"),
|
||||
moe: Gemma4SparseMoE {
|
||||
router_scale: layer_tensor(cx, layer, "router.scale", HIDDEN),
|
||||
router_proj: layer_weight(cx, layer, "router.proj", (NUM_EXPERTS, HIDDEN)),
|
||||
per_expert_scale: layer_tensor(cx, layer, "router.per_expert_scale", NUM_EXPERTS),
|
||||
gate_up_weights: layer_tensor(
|
||||
cx,
|
||||
layer,
|
||||
"experts.gate_up_proj",
|
||||
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
|
||||
)
|
||||
.as_dtype(DType::Bf16),
|
||||
down_weights: layer_tensor(
|
||||
cx,
|
||||
layer,
|
||||
"experts.down_proj",
|
||||
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
|
||||
)
|
||||
.as_dtype(DType::Bf16),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let residual = x;
|
||||
let x_attn = self.input_layernorm.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k_base = x_attn.matmul(self.k_proj.t());
|
||||
let v_base = if let Some(v_proj) = self.v_proj {
|
||||
x_attn.matmul(v_proj.t())
|
||||
} else {
|
||||
k_base
|
||||
};
|
||||
|
||||
let q_normed = qk_norm(q, self.q_norm, N_HEADS, self.spec.head_dim);
|
||||
let k_normed = qk_norm(
|
||||
k_base,
|
||||
self.k_norm,
|
||||
self.spec.num_kv_heads,
|
||||
self.spec.head_dim,
|
||||
);
|
||||
let v_normed = value_norm(v_base, self.spec.head_dim);
|
||||
|
||||
let q_rope = gemma4_rotary_embeddings(
|
||||
q_normed,
|
||||
pos_ids,
|
||||
N_HEADS,
|
||||
self.spec.head_dim,
|
||||
self.spec.rope_theta,
|
||||
self.spec.partial_rotary_factor,
|
||||
);
|
||||
let k_rope = gemma4_rotary_embeddings(
|
||||
k_normed,
|
||||
pos_ids,
|
||||
self.spec.num_kv_heads,
|
||||
self.spec.head_dim,
|
||||
self.spec.rope_theta,
|
||||
self.spec.partial_rotary_factor,
|
||||
);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
|
||||
q_rope, k_rope, v_normed, k_cache_in, v_cache_in, max_seq, self.spec,
|
||||
);
|
||||
|
||||
let attn_proj = attn_out.matmul(self.o_proj.t());
|
||||
let x = residual + self.post_attention_layernorm.forward(attn_proj);
|
||||
|
||||
let dense_ff = dense_ffn(
|
||||
self.pre_feedforward_layernorm.forward(x),
|
||||
self.gate,
|
||||
self.up,
|
||||
self.down,
|
||||
);
|
||||
let dense_ff = self.post_feedforward_layernorm_1.forward(dense_ff);
|
||||
|
||||
let moe_out = self
|
||||
.moe
|
||||
.forward(x, self.pre_feedforward_layernorm_2.forward(x));
|
||||
let moe_out = self.post_feedforward_layernorm_2.forward(moe_out);
|
||||
|
||||
let ff_out = self.post_feedforward_layernorm.forward(dense_ff + moe_out);
|
||||
let x = x + ff_out;
|
||||
let x = x * self
|
||||
.layer_scalar
|
||||
.expand_lhs(&x.dims()[..x.dims().len() - 1]);
|
||||
|
||||
(x, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
fn persist(
|
||||
cx: &mut Graph,
|
||||
name: impl ToString,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
cx.named_tensor(name, shape).persist()
|
||||
}
|
||||
|
||||
fn layer_tensor(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
persist(cx, format!("model.layers.{layer}.{suffix}"), shape)
|
||||
}
|
||||
|
||||
fn layer_weight(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
layer_tensor(cx, layer, &format!("{suffix}.weight"), shape)
|
||||
}
|
||||
|
||||
fn layer_norm(cx: &mut Graph, layer: usize, name: &str) -> LayerNorm {
|
||||
gemma4_norm(HIDDEN, &format!("model.layers.{layer}.{name}.weight"), cx)
|
||||
}
|
||||
|
||||
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
|
||||
let seq = token_ids.dims1();
|
||||
embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
)
|
||||
}
|
||||
|
||||
fn gemma4_norm(dim: usize, weight_name: &str, cx: &mut Graph) -> LayerNorm {
|
||||
LayerNorm::new(dim, Some(weight_name), None, false, RMS_NORM_EPS, cx)
|
||||
}
|
||||
@@ -505,81 +499,6 @@ fn hlir_attention(
|
||||
(out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl Gemma4Layer {
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let residual = x;
|
||||
let x_attn = self.input_layernorm.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k_base = x_attn.matmul(self.k_proj.t());
|
||||
let v_base = if let Some(v_proj) = self.v_proj {
|
||||
x_attn.matmul(v_proj.t())
|
||||
} else {
|
||||
k_base
|
||||
};
|
||||
|
||||
let q_normed = qk_norm(q, self.q_norm, N_HEADS, self.spec.head_dim);
|
||||
let k_normed = qk_norm(
|
||||
k_base,
|
||||
self.k_norm,
|
||||
self.spec.num_kv_heads,
|
||||
self.spec.head_dim,
|
||||
);
|
||||
let v_normed = value_norm(v_base, self.spec.head_dim);
|
||||
|
||||
let q_rope = gemma4_rotary_embeddings(
|
||||
q_normed,
|
||||
pos_ids,
|
||||
N_HEADS,
|
||||
self.spec.head_dim,
|
||||
self.spec.rope_theta,
|
||||
self.spec.partial_rotary_factor,
|
||||
);
|
||||
let k_rope = gemma4_rotary_embeddings(
|
||||
k_normed,
|
||||
pos_ids,
|
||||
self.spec.num_kv_heads,
|
||||
self.spec.head_dim,
|
||||
self.spec.rope_theta,
|
||||
self.spec.partial_rotary_factor,
|
||||
);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
|
||||
q_rope, k_rope, v_normed, k_cache_in, v_cache_in, max_seq, self.spec,
|
||||
);
|
||||
|
||||
let attn_proj = attn_out.matmul(self.o_proj.t());
|
||||
let x = residual + self.post_attention_layernorm.forward(attn_proj);
|
||||
|
||||
let dense_ff = dense_ffn(
|
||||
self.pre_feedforward_layernorm.forward(x),
|
||||
self.gate,
|
||||
self.up,
|
||||
self.down,
|
||||
);
|
||||
let dense_ff = self.post_feedforward_layernorm_1.forward(dense_ff);
|
||||
|
||||
let moe_out = self
|
||||
.moe
|
||||
.forward(x, self.pre_feedforward_layernorm_2.forward(x));
|
||||
let moe_out = self.post_feedforward_layernorm_2.forward(moe_out);
|
||||
|
||||
let ff_out = self.post_feedforward_layernorm.forward(dense_ff + moe_out);
|
||||
let x = x + ff_out;
|
||||
let x = x * self
|
||||
.layer_scalar
|
||||
.expand_lhs(&x.dims()[..x.dims().len() - 1]);
|
||||
|
||||
(x, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
fn dense_ffn(x: GraphTensor, gate: GraphTensor, up: GraphTensor, down: GraphTensor) -> GraphTensor {
|
||||
(gemma_gelu(x.matmul(gate.t())) * x.matmul(up.t())).matmul(down.t())
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ const FP8_REPO_ID: &str = "nvidia/Llama-3.1-8B-Instruct-FP8";
|
||||
const MAX_SEQ_LEN: usize = 4096;
|
||||
const GEN_TOKENS: usize = 500;
|
||||
const SEARCH_GRAPHS: usize = 500;
|
||||
const SEARCH_TRIALS: usize = 1;
|
||||
const SEARCH_TRIALS: usize = 10;
|
||||
const SEARCH_KEEP_BEST: usize = 4;
|
||||
const SEARCH_MEMORY_MIB: usize = 2048;
|
||||
const SEARCH_SEED: u64 = 0;
|
||||
@@ -290,12 +290,21 @@ fn main() {
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('c', 1);
|
||||
let max_prefill = (prompt_len + 16).next_power_of_two().min(MAX_SEQ_LEN);
|
||||
let search_s = 16.min(max_prefill).max(2);
|
||||
let build_options = CompileOptions::default()
|
||||
.max_memory_mib(SEARCH_MEMORY_MIB)
|
||||
.dim_buckets(
|
||||
's',
|
||||
&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, max_prefill).representative(search_s),
|
||||
],
|
||||
);
|
||||
|
||||
println!("Building E-Graph...");
|
||||
let egraph_start = std::time::Instant::now();
|
||||
cx.build_search_space_with_options::<CudaRuntime>(
|
||||
BuildSearchSpaceOptions::new().max_memory_mib(SEARCH_MEMORY_MIB),
|
||||
);
|
||||
cx.build_search_space::<CudaRuntime>(build_options);
|
||||
println!(
|
||||
" E-Graph build: {:.2} s",
|
||||
egraph_start.elapsed().as_secs_f64()
|
||||
@@ -318,15 +327,6 @@ fn main() {
|
||||
|
||||
println!("Compiling...");
|
||||
let compile_start = std::time::Instant::now();
|
||||
let max_prefill = (prompt_len + 16).next_power_of_two().min(MAX_SEQ_LEN);
|
||||
let search_s = 16.min(max_prefill).max(2);
|
||||
cx.set_dim_buckets(
|
||||
's',
|
||||
&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, max_prefill).representative(search_s),
|
||||
],
|
||||
);
|
||||
cx.set_dim('s', search_s);
|
||||
cx.set_dim('c', search_s);
|
||||
runtime.set_data(input, vec![1; search_s]);
|
||||
@@ -338,13 +338,11 @@ fn main() {
|
||||
println!(" Search trials: {SEARCH_TRIALS}");
|
||||
println!(" Search keep-best: {SEARCH_KEEP_BEST}");
|
||||
let mut rng = StdRng::seed_from_u64(SEARCH_SEED);
|
||||
runtime = cx.search_options(
|
||||
runtime,
|
||||
SearchOptions::new(SEARCH_GRAPHS)
|
||||
.trials(SEARCH_TRIALS)
|
||||
.keep_best(SEARCH_KEEP_BEST),
|
||||
&mut rng,
|
||||
);
|
||||
let search_options = CompileOptions::default()
|
||||
.search_graph_limit(SEARCH_GRAPHS)
|
||||
.trials(SEARCH_TRIALS)
|
||||
.keep_best(SEARCH_KEEP_BEST);
|
||||
runtime = cx.search_with_rng(runtime, search_options, &mut rng);
|
||||
println!(
|
||||
" Search/compile: {:.2} s",
|
||||
compile_start.elapsed().as_secs_f64()
|
||||
|
||||
@@ -111,125 +111,18 @@ impl Llama {
|
||||
config: LlamaConfig,
|
||||
fp8_linears: bool,
|
||||
) -> Self {
|
||||
let mut layers = Vec::with_capacity(config.layers);
|
||||
for l in 0..config.layers {
|
||||
layers.push(LlamaLayer {
|
||||
config,
|
||||
up: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.up_proj"),
|
||||
(config.intermediate, config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
up_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.up_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
gate: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.gate_proj"),
|
||||
(config.intermediate, config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
gate_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.gate_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
down: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.down_proj"),
|
||||
(config.hidden, config.intermediate),
|
||||
fp8_linears,
|
||||
),
|
||||
down_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.down_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
q_proj: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.q_proj"),
|
||||
(config.hidden, config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
q_proj_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.q_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
k_proj: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.k_proj"),
|
||||
(config.kv_dim(), config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
k_proj_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.k_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
v_proj: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.v_proj"),
|
||||
(config.kv_dim(), config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
v_proj_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.v_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
o_proj: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.o_proj"),
|
||||
(config.hidden, config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
o_proj_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.o_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
attn_rms: LayerNorm::new(
|
||||
config.hidden,
|
||||
Some(&format!("model.layers.{l}.input_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
1e-5,
|
||||
cx,
|
||||
),
|
||||
mlp_rms: LayerNorm::new(
|
||||
config.hidden,
|
||||
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
1e-5,
|
||||
cx,
|
||||
),
|
||||
});
|
||||
}
|
||||
Self {
|
||||
config,
|
||||
embedding: cx
|
||||
.named_tensor(
|
||||
"model.embed_tokens.weight",
|
||||
(config.vocab_size, config.hidden),
|
||||
)
|
||||
.persist(),
|
||||
layers,
|
||||
lm_head: cx
|
||||
.named_tensor("lm_head.weight", (config.vocab_size, config.hidden))
|
||||
.persist(),
|
||||
lm_norm: LayerNorm::new(
|
||||
config.hidden,
|
||||
Some("model.norm.weight"),
|
||||
None,
|
||||
false,
|
||||
1e-5,
|
||||
embedding: persist(
|
||||
cx,
|
||||
"model.embed_tokens.weight",
|
||||
(config.vocab_size, config.hidden),
|
||||
),
|
||||
layers: (0..config.layers)
|
||||
.map(|l| LlamaLayer::init(cx, l, config, fp8_linears))
|
||||
.collect(),
|
||||
lm_head: persist(cx, "lm_head.weight", (config.vocab_size, config.hidden)),
|
||||
lm_norm: rms_norm(cx, config.hidden, "model.norm.weight"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -243,12 +136,7 @@ impl Llama {
|
||||
attn_mask: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = input.dims1();
|
||||
let hidden = self.config.hidden;
|
||||
let mut x = self.embedding.gather(
|
||||
(input * hidden).expand_dim(1, hidden)
|
||||
+ input.graph().arange(hidden).expand_dim(0, seq),
|
||||
);
|
||||
let mut x = token_embedding(self.embedding, input, self.config.hidden);
|
||||
let mut cache_outputs = Vec::with_capacity(self.config.layers);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
@@ -311,6 +199,170 @@ struct Fp8LinearScales {
|
||||
weight: GraphTensor,
|
||||
}
|
||||
|
||||
impl LlamaLayer {
|
||||
fn init(cx: &mut Graph, l: usize, config: LlamaConfig, fp8: bool) -> Self {
|
||||
Self {
|
||||
config,
|
||||
up: layer_linear_weight(
|
||||
cx,
|
||||
l,
|
||||
"mlp.up_proj",
|
||||
(config.intermediate, config.hidden),
|
||||
fp8,
|
||||
),
|
||||
up_scales: layer_linear_scales(cx, l, "mlp.up_proj", fp8),
|
||||
gate: layer_linear_weight(
|
||||
cx,
|
||||
l,
|
||||
"mlp.gate_proj",
|
||||
(config.intermediate, config.hidden),
|
||||
fp8,
|
||||
),
|
||||
gate_scales: layer_linear_scales(cx, l, "mlp.gate_proj", fp8),
|
||||
down: layer_linear_weight(
|
||||
cx,
|
||||
l,
|
||||
"mlp.down_proj",
|
||||
(config.hidden, config.intermediate),
|
||||
fp8,
|
||||
),
|
||||
down_scales: layer_linear_scales(cx, l, "mlp.down_proj", fp8),
|
||||
q_proj: layer_linear_weight(
|
||||
cx,
|
||||
l,
|
||||
"self_attn.q_proj",
|
||||
(config.hidden, config.hidden),
|
||||
fp8,
|
||||
),
|
||||
q_proj_scales: layer_linear_scales(cx, l, "self_attn.q_proj", fp8),
|
||||
k_proj: layer_linear_weight(
|
||||
cx,
|
||||
l,
|
||||
"self_attn.k_proj",
|
||||
(config.kv_dim(), config.hidden),
|
||||
fp8,
|
||||
),
|
||||
k_proj_scales: layer_linear_scales(cx, l, "self_attn.k_proj", fp8),
|
||||
v_proj: layer_linear_weight(
|
||||
cx,
|
||||
l,
|
||||
"self_attn.v_proj",
|
||||
(config.kv_dim(), config.hidden),
|
||||
fp8,
|
||||
),
|
||||
v_proj_scales: layer_linear_scales(cx, l, "self_attn.v_proj", fp8),
|
||||
o_proj: layer_linear_weight(
|
||||
cx,
|
||||
l,
|
||||
"self_attn.o_proj",
|
||||
(config.hidden, config.hidden),
|
||||
fp8,
|
||||
),
|
||||
o_proj_scales: layer_linear_scales(cx, l, "self_attn.o_proj", fp8),
|
||||
attn_rms: rms_norm(
|
||||
cx,
|
||||
config.hidden,
|
||||
format!("model.layers.{l}.input_layernorm.weight"),
|
||||
),
|
||||
mlp_rms: rms_norm(
|
||||
cx,
|
||||
config.hidden,
|
||||
format!("model.layers.{l}.post_attention_layernorm.weight"),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.attn_rms.forward(x);
|
||||
let q = linear_matmul(x_attn, self.q_proj, self.q_proj_scales);
|
||||
let k = linear_matmul(x_attn, self.k_proj, self.k_proj_scales);
|
||||
let v = linear_matmul(x_attn, self.v_proj, self.v_proj_scales);
|
||||
|
||||
let q_rope = llama_rotary_embeddings(q, q_pos, self.config);
|
||||
let k_rope = llama_rotary_embeddings(k, q_pos, self.config);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = attention(
|
||||
AttentionInputs {
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
},
|
||||
self.config,
|
||||
);
|
||||
x += linear_matmul(attn_out, self.o_proj, self.o_proj_scales);
|
||||
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
let mlp_out = linear_matmul(x_mlp, self.gate, self.gate_scales).swish()
|
||||
* linear_matmul(x_mlp, self.up, self.up_scales);
|
||||
let mlp_out = linear_matmul(mlp_out, self.down, self.down_scales);
|
||||
(x + mlp_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn parameter_tensors(&self) -> Vec<GraphTensor> {
|
||||
let mut tensors = vec![
|
||||
self.up,
|
||||
self.gate,
|
||||
self.down,
|
||||
self.q_proj,
|
||||
self.k_proj,
|
||||
self.v_proj,
|
||||
self.o_proj,
|
||||
];
|
||||
for scales in [
|
||||
self.up_scales,
|
||||
self.gate_scales,
|
||||
self.down_scales,
|
||||
self.q_proj_scales,
|
||||
self.k_proj_scales,
|
||||
self.v_proj_scales,
|
||||
self.o_proj_scales,
|
||||
]
|
||||
.into_iter()
|
||||
.flatten()
|
||||
{
|
||||
tensors.push(scales.input);
|
||||
tensors.push(scales.weight);
|
||||
}
|
||||
if let Some(weight) = self.attn_rms.weight {
|
||||
tensors.push(weight);
|
||||
}
|
||||
if let Some(bias) = self.attn_rms.bias {
|
||||
tensors.push(bias);
|
||||
}
|
||||
if let Some(weight) = self.mlp_rms.weight {
|
||||
tensors.push(weight);
|
||||
}
|
||||
if let Some(bias) = self.mlp_rms.bias {
|
||||
tensors.push(bias);
|
||||
}
|
||||
tensors
|
||||
}
|
||||
}
|
||||
|
||||
fn persist(
|
||||
cx: &mut Graph,
|
||||
name: impl ToString,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
cx.named_tensor(name, shape).persist()
|
||||
}
|
||||
|
||||
fn linear_weight(
|
||||
cx: &mut Graph,
|
||||
prefix: impl ToString,
|
||||
@@ -325,6 +377,16 @@ fn linear_weight(
|
||||
}
|
||||
}
|
||||
|
||||
fn layer_linear_weight(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
fp8: bool,
|
||||
) -> GraphTensor {
|
||||
linear_weight(cx, format!("model.layers.{layer}.{suffix}"), shape, fp8)
|
||||
}
|
||||
|
||||
fn fp8_linear_scales(cx: &mut Graph, prefix: impl ToString, fp8: bool) -> Option<Fp8LinearScales> {
|
||||
if !fp8 {
|
||||
return None;
|
||||
@@ -340,6 +402,27 @@ fn fp8_linear_scales(cx: &mut Graph, prefix: impl ToString, fp8: bool) -> Option
|
||||
})
|
||||
}
|
||||
|
||||
fn layer_linear_scales(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
fp8: bool,
|
||||
) -> Option<Fp8LinearScales> {
|
||||
fp8_linear_scales(cx, format!("model.layers.{layer}.{suffix}"), fp8)
|
||||
}
|
||||
|
||||
fn rms_norm(cx: &mut Graph, dim: usize, weight_name: impl ToString) -> LayerNorm {
|
||||
LayerNorm::new(dim, Some(&weight_name.to_string()), None, false, 1e-5, cx)
|
||||
}
|
||||
|
||||
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor, hidden: usize) -> GraphTensor {
|
||||
let seq = token_ids.dims1();
|
||||
embedding.gather(
|
||||
(token_ids * hidden).expand_dim(1, hidden)
|
||||
+ token_ids.graph().arange(hidden).expand_dim(0, seq),
|
||||
)
|
||||
}
|
||||
|
||||
fn expand_scalar(scale: GraphTensor, like: GraphTensor) -> GraphTensor {
|
||||
scale.expand_rhs(like.dims())
|
||||
}
|
||||
@@ -443,87 +526,3 @@ fn attention(
|
||||
|
||||
(attn_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl LlamaLayer {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.attn_rms.forward(x);
|
||||
let q = linear_matmul(x_attn, self.q_proj, self.q_proj_scales);
|
||||
let k = linear_matmul(x_attn, self.k_proj, self.k_proj_scales);
|
||||
let v = linear_matmul(x_attn, self.v_proj, self.v_proj_scales);
|
||||
|
||||
let q_rope = llama_rotary_embeddings(q, q_pos, self.config);
|
||||
let k_rope = llama_rotary_embeddings(k, q_pos, self.config);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = attention(
|
||||
AttentionInputs {
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
},
|
||||
self.config,
|
||||
);
|
||||
x += linear_matmul(attn_out, self.o_proj, self.o_proj_scales);
|
||||
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
let mlp_out = linear_matmul(x_mlp, self.gate, self.gate_scales).swish()
|
||||
* linear_matmul(x_mlp, self.up, self.up_scales);
|
||||
let mlp_out = linear_matmul(mlp_out, self.down, self.down_scales);
|
||||
(x + mlp_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn parameter_tensors(&self) -> Vec<GraphTensor> {
|
||||
let mut tensors = vec![
|
||||
self.up,
|
||||
self.gate,
|
||||
self.down,
|
||||
self.q_proj,
|
||||
self.k_proj,
|
||||
self.v_proj,
|
||||
self.o_proj,
|
||||
];
|
||||
for scales in [
|
||||
self.up_scales,
|
||||
self.gate_scales,
|
||||
self.down_scales,
|
||||
self.q_proj_scales,
|
||||
self.k_proj_scales,
|
||||
self.v_proj_scales,
|
||||
self.o_proj_scales,
|
||||
]
|
||||
.into_iter()
|
||||
.flatten()
|
||||
{
|
||||
tensors.push(scales.input);
|
||||
tensors.push(scales.weight);
|
||||
}
|
||||
if let Some(weight) = self.attn_rms.weight {
|
||||
tensors.push(weight);
|
||||
}
|
||||
if let Some(bias) = self.attn_rms.bias {
|
||||
tensors.push(bias);
|
||||
}
|
||||
if let Some(weight) = self.mlp_rms.weight {
|
||||
tensors.push(weight);
|
||||
}
|
||||
if let Some(bias) = self.mlp_rms.bias {
|
||||
tensors.push(bias);
|
||||
}
|
||||
tensors
|
||||
}
|
||||
}
|
||||
|
||||
@@ -204,9 +204,20 @@ fn main() {
|
||||
k_out.output();
|
||||
v_out.output();
|
||||
}
|
||||
// Bucket s=1 (decode) vs s>1 (prefill/mixed). Each bucket gets its own
|
||||
// optimized compilation — decode can select warp-parallel kernels while
|
||||
// prefill can select tiled matmul / cuBLAS.
|
||||
let max_prefill = (tokens_a.len().max(tokens_b.len()) + 16).next_power_of_two();
|
||||
let build_options = CompileOptions::default().dim_buckets(
|
||||
's',
|
||||
&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, max_prefill).representative(16),
|
||||
],
|
||||
);
|
||||
|
||||
println!("Building E-Graph...");
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
cx.build_search_space::<CudaRuntime>(build_options);
|
||||
|
||||
println!("Loading weights...");
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
@@ -220,18 +231,6 @@ fn main() {
|
||||
}
|
||||
|
||||
println!("Compiling...");
|
||||
// Bucket s=1 (decode) vs s>1 (prefill/mixed). Each bucket gets its own
|
||||
// optimized compilation — decode can select warp-parallel kernels while
|
||||
// prefill can select tiled matmul / cuBLAS.
|
||||
let max_prefill = (tokens_a.len().max(tokens_b.len()) + 16).next_power_of_two();
|
||||
cx.set_dim_buckets(
|
||||
's',
|
||||
&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, max_prefill).representative(16),
|
||||
],
|
||||
);
|
||||
|
||||
// Dummy data sized for the largest representative (s=16, c=16)
|
||||
let search_s = 16;
|
||||
let search_c = 16;
|
||||
@@ -242,7 +241,8 @@ fn main() {
|
||||
runtime.set_data(scatter_idx_t, vec![0i32; search_s]);
|
||||
runtime.set_data(gather_idx_t, vec![0i32; search_c]);
|
||||
runtime.set_data(attn_mask_t, vec![0.0f32; search_s * search_c]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
|
||||
runtime = cx.search(runtime, search_options);
|
||||
|
||||
// Re-initialize KV cache after search (search consumes buffers)
|
||||
let cache_bytes = num_slots * KV_DIM * std::mem::size_of::<f32>();
|
||||
|
||||
@@ -25,8 +25,8 @@ pub struct PagedKVCache {
|
||||
|
||||
impl PagedKVCache {
|
||||
pub fn new(cx: &mut Graph, num_slots: usize) -> Self {
|
||||
let mut k_caches = vec![];
|
||||
let mut v_caches = vec![];
|
||||
let mut k_caches = Vec::with_capacity(LAYERS);
|
||||
let mut v_caches = Vec::with_capacity(LAYERS);
|
||||
for l in 0..LAYERS {
|
||||
k_caches.push(cx.named_tensor(format!("kv_cache.{l}.k"), (num_slots, KV_DIM)));
|
||||
v_caches.push(cx.named_tensor(format!("kv_cache.{l}.v"), (num_slots, KV_DIM)));
|
||||
@@ -44,78 +44,11 @@ pub struct Llama {
|
||||
|
||||
impl Llama {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
let mut layers = vec![];
|
||||
for l in 0..LAYERS {
|
||||
layers.push(LlamaLayer {
|
||||
up: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.up_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
gate: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.gate_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
down: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.down_proj.weight"),
|
||||
(HIDDEN, INTERMEDIATE),
|
||||
)
|
||||
.persist(),
|
||||
q_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_proj.weight"),
|
||||
(HIDDEN, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
k_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
v_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.v_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
o_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
attn_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.input_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
1e-5,
|
||||
cx,
|
||||
),
|
||||
mlp_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
1e-5,
|
||||
cx,
|
||||
),
|
||||
});
|
||||
}
|
||||
Self {
|
||||
embedding: cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist(),
|
||||
layers,
|
||||
lm_head: cx
|
||||
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist(),
|
||||
lm_norm: LayerNorm::new(HIDDEN, Some("model.norm.weight"), None, false, 1e-5, cx),
|
||||
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
layers: (0..LAYERS).map(|l| LlamaLayer::init(cx, l)).collect(),
|
||||
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
lm_norm: rms_norm(cx, "model.norm.weight"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,12 +74,8 @@ impl Llama {
|
||||
attn_mask: GraphTensor,
|
||||
kv_cache: &PagedKVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = input.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(input * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ input.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
let mut cache_outputs = vec![];
|
||||
let mut x = token_embedding(self.embedding, input);
|
||||
let mut cache_outputs = Vec::with_capacity(LAYERS);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
x,
|
||||
@@ -177,6 +106,99 @@ struct LlamaLayer {
|
||||
mlp_rms: LayerNorm,
|
||||
}
|
||||
|
||||
impl LlamaLayer {
|
||||
fn init(cx: &mut Graph, l: usize) -> Self {
|
||||
Self {
|
||||
up: layer_weight(cx, l, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
|
||||
gate: layer_weight(cx, l, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
|
||||
down: layer_weight(cx, l, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
|
||||
q_proj: layer_weight(cx, l, "self_attn.q_proj", (HIDDEN, HIDDEN)),
|
||||
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
|
||||
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
|
||||
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, HIDDEN)),
|
||||
attn_rms: rms_norm(cx, format!("model.layers.{l}.input_layernorm.weight")),
|
||||
mlp_rms: rms_norm(
|
||||
cx,
|
||||
format!("model.layers.{l}.post_attention_layernorm.weight"),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.attn_rms.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
|
||||
let q_rope = llama_rotary_embeddings(q, q_pos);
|
||||
let k_rope = llama_rotary_embeddings(k, q_pos);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = paged_attention(
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
);
|
||||
|
||||
x += attn_out.matmul(self.o_proj.t());
|
||||
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
let mlp_out =
|
||||
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
|
||||
(x + mlp_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
fn persist(
|
||||
cx: &mut Graph,
|
||||
name: impl ToString,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
cx.named_tensor(name, shape).persist()
|
||||
}
|
||||
|
||||
fn layer_weight(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
persist(cx, format!("model.layers.{layer}.{suffix}.weight"), shape)
|
||||
}
|
||||
|
||||
fn rms_norm(cx: &mut Graph, weight_name: impl ToString) -> LayerNorm {
|
||||
LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&weight_name.to_string()),
|
||||
None,
|
||||
false,
|
||||
1e-5,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
|
||||
let seq = token_ids.dims1();
|
||||
embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
)
|
||||
}
|
||||
|
||||
fn llama_rotary_embeddings(mut input: GraphTensor, pos_ids: GraphTensor) -> GraphTensor {
|
||||
input = input.split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
|
||||
@@ -264,44 +286,3 @@ fn paged_attention(
|
||||
let attn_out = out.transpose(0, 1).merge_dims(1, 2);
|
||||
(attn_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl LlamaLayer {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.attn_rms.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
|
||||
// Apply RoPE before scattering into cache
|
||||
let q_rope = llama_rotary_embeddings(q, q_pos);
|
||||
let k_rope = llama_rotary_embeddings(k, q_pos);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = paged_attention(
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
);
|
||||
|
||||
x += attn_out.matmul(self.o_proj.t());
|
||||
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
let mlp_out =
|
||||
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
|
||||
(x + mlp_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user