mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
15 Commits
dlrm-fused
...
dlrm-pt2-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bfcc41040e | ||
|
|
9d4a3bc555 | ||
|
|
6f8de66e3d | ||
|
|
cb9facfb11 | ||
|
|
2845e605c1 | ||
|
|
ccdb6f1540 | ||
|
|
3f57d94ecb | ||
|
|
3b36880c22 | ||
|
|
f94335b1b8 | ||
|
|
f62e3c50d0 | ||
|
|
eeeabd7c20 | ||
|
|
0f02466f3d | ||
|
|
156fac518e | ||
|
|
a3df68bd43 | ||
|
|
7a95e56a8b |
67
.github/workflows/test-full-cuda.yml
vendored
Normal file
67
.github/workflows/test-full-cuda.yml
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
name: Test Full CUDA
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
branches: ["main"]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
rust_cuda_ignored_tests:
|
||||
if: >-
|
||||
github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request_target'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'full-modal-ready'))
|
||||
name: Rust CUDA Ignored Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 300
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: pip install modal
|
||||
- name: Run ignored CUDA Rust tests on Modal
|
||||
env:
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
GPU_TYPE: H100
|
||||
MODAL_TIMEOUT: "14400"
|
||||
CARGO_TEST_ARGS: "--ignored --test-threads=1"
|
||||
run: modal run ci/modal_cargo_test.py
|
||||
|
||||
python_cuda_slow_tests:
|
||||
if: >-
|
||||
github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request_target'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'full-modal-ready'))
|
||||
name: Python CUDA Slow Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 300
|
||||
defaults:
|
||||
run:
|
||||
working-directory: crates/luminal_python
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: pip install modal
|
||||
- name: Run slow pytest CUDA tests on Modal
|
||||
env:
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: modal run modal_pytest_runner.py --gpu A100 --timeout 14400 tests/ -v -s -m slow
|
||||
17
.github/workflows/test-metal.yml
vendored
17
.github/workflows/test-metal.yml
vendored
@@ -17,3 +17,20 @@ jobs:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run Metal crate tests
|
||||
run: rustup update; cargo test --release -p luminal_metal --verbose -- --test-threads=1
|
||||
|
||||
llama_1b_metal_example:
|
||||
name: Llama 1B Metal Example
|
||||
runs-on: macos-14-xlarge
|
||||
timeout-minutes: 120
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Print runner hardware
|
||||
run: system_profiler SPHardwareDataType SPDisplaysDataType
|
||||
- name: Cache Hugging Face models
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/huggingface
|
||||
key: llama-1b-metal-hf-${{ runner.os }}-${{ runner.arch }}-v1
|
||||
- name: Run Llama 1B Metal example and validate output
|
||||
run: rustup update; python3 ci/metal_llama_1b_example.py
|
||||
|
||||
48
ci/metal_llama_1b_example.py
Normal file
48
ci/metal_llama_1b_example.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def run_and_capture(command: list[str], *, cwd: str, env: dict[str, str]) -> str:
    """Run ``command`` while streaming its output, and return everything it printed.

    The child's stderr is merged into its stdout. Each chunk read from the
    pipe is echoed to this process's stdout immediately and also kept, so
    the caller gets live logs plus the full transcript.

    Args:
        command: Program and arguments, passed to ``subprocess.Popen``.
        cwd: Working directory for the child process.
        env: Environment mapping for the child process.

    Returns:
        The combined stdout/stderr of the child, decoded as UTF-8 with
        undecodable bytes replaced.

    Raises:
        subprocess.CalledProcessError: If the child exits with a non-zero
            status; the captured text is attached as ``output``.
    """
    child = subprocess.Popen(
        command,
        cwd=cwd,
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    pipe = child.stdout
    assert pipe is not None

    captured = bytearray()
    # read1 returns b"" only at end-of-stream, which is the loop sentinel.
    for piece in iter(lambda: pipe.read1(4096), b""):
        sys.stdout.buffer.write(piece)
        sys.stdout.buffer.flush()
        captured.extend(piece)

    status = child.wait()
    transcript = bytes(captured).decode("utf-8", errors="replace")
    if status:
        raise subprocess.CalledProcessError(status, command, output=transcript)
    return transcript
|
||||
|
||||
|
||||
def main():
    """Build and run the Llama 1B Metal example, then validate what it printed.

    The repository root is taken from ``GITHUB_WORKSPACE`` (falling back to
    the current directory). ``ci/`` is put on ``sys.path`` so the sibling
    ``example_output`` module can be imported.

    Raises:
        AssertionError: If the output lacks either the ``TTFT:`` or the
            ``TPOT:`` marker.
        subprocess.CalledProcessError: If the cargo command fails.
    """
    workspace = os.environ.get("GITHUB_WORKSPACE", os.getcwd())
    sys.path.insert(0, os.path.join(workspace, "ci"))
    from example_output import validate_output

    cargo_cmd = ["cargo", "run", "--release", "-p", "luminal_metal", "--example", "llama_1b"]
    transcript = run_and_capture(cargo_cmd, cwd=workspace, env=os.environ.copy())

    # Both timing markers must be present for the run to count as complete.
    if any(marker not in transcript for marker in ("TTFT:", "TPOT:")):
        raise AssertionError("Llama 1B Metal example did not complete generation")

    # NOTE(review): validate_output lives in ci/example_output.py; presumably it
    # checks the generated text for the "llama" example — confirm there.
    validate_output("llama", transcript)


if __name__ == "__main__":
    main()
|
||||
@@ -1,8 +1,10 @@
|
||||
import modal
|
||||
import subprocess
|
||||
import os
|
||||
import shlex
|
||||
|
||||
gpu_type = os.environ.get("GPU_TYPE", "T4")
|
||||
modal_timeout = int(os.environ.get("MODAL_TIMEOUT", "7200"))
|
||||
CUDARC_CUDA_VERSION = "12080"
|
||||
|
||||
app = modal.App("luminal-ci-cargo-test")
|
||||
@@ -28,7 +30,7 @@ cuda_image = (
|
||||
@app.function(
|
||||
image=cuda_image,
|
||||
gpu=gpu_type,
|
||||
timeout=7200, # 2 hours
|
||||
timeout=modal_timeout,
|
||||
)
|
||||
def run_cargo_test():
|
||||
"""Run cargo test for luminal_cuda_lite on a Modal GPU."""
|
||||
@@ -43,17 +45,20 @@ def run_cargo_test():
|
||||
)
|
||||
compute_cap = result.stdout.strip().replace(".", "")
|
||||
|
||||
test_args = shlex.split(os.environ.get("CARGO_TEST_ARGS", "--test-threads=1"))
|
||||
cmd = [
|
||||
"cargo",
|
||||
"test",
|
||||
"--release",
|
||||
"-p",
|
||||
"luminal_cuda_lite",
|
||||
"--verbose",
|
||||
"--",
|
||||
*test_args,
|
||||
]
|
||||
print("Running:", " ".join(cmd), flush=True)
|
||||
subprocess.run(
|
||||
[
|
||||
"cargo",
|
||||
"test",
|
||||
"--release",
|
||||
"-p",
|
||||
"luminal_cuda_lite",
|
||||
"--verbose",
|
||||
"--",
|
||||
"--test-threads=1",
|
||||
],
|
||||
cmd,
|
||||
cwd=WORKDIR,
|
||||
env={
|
||||
**os.environ,
|
||||
|
||||
450
crates/luminal_cuda_lite/src/kernel/dlrm_megakernel.rs
Normal file
450
crates/luminal_cuda_lite/src/kernel/dlrm_megakernel.rs
Normal file
@@ -0,0 +1,450 @@
|
||||
//! DLRM-shape megakernel — one CUDA kernel does the full forward pass
|
||||
//! (bot MLP → N embedding gathers → dot-product interaction → top MLP)
|
||||
//! per (thread × batch row). All intermediate activations live in
|
||||
//! registers; weights are read straight from global memory and rely on
|
||||
//! the L1 cache (the full weight footprint is a few KB).
|
||||
//!
|
||||
//! Parameterized by the DLRM family shape: dense input width, bot MLP
|
||||
//! widths, number of sparse tables + their vocabs, embedding dim,
|
||||
//! top MLP widths. CUDA source is generated per-shape via `format!`
|
||||
//! and compiled through luminal's nvrtc wrapper with source-string
|
||||
//! caching (same path as [`crate::kernel::matmul2d::Matmul2DKernel`]).
|
||||
//!
|
||||
//! Used by `luminal_python`'s PT2 translator when it detects a DLRM-shape
|
||||
//! input graph — see `crates/luminal_python/rust/src/translator/dlrm_pattern.rs`.
|
||||
//! The standalone `examples/dlrm/src/megakernel.rs` is the proof-of-concept
|
||||
//! this module generalizes from.
|
||||
//!
|
||||
//! ## Input layout
|
||||
//!
|
||||
//! The kernel's input list (passed to `cx.custom_op`) is, in order:
|
||||
//! 1. dense_x F32 (B, n_dense_in)
|
||||
//! 2..2+n_sparse int32 indices per sparse table, each (B,)
|
||||
//! — luminal collapses all integer types to 32-bit Int,
|
||||
//! so the runtime delivers a 4-byte-per-element buffer
|
||||
//! regardless of the original PyTorch dtype.
|
||||
//! 2+n_sparse.. F32 embedding weights, one per table, each (V_k, m_spa)
|
||||
//! then bot Linear weight+bias pairs, in topological order
|
||||
//! then top Linear weight+bias pairs, in topological order
|
||||
//!
|
||||
//! The matcher in luminal_python lines up these inputs from the parsed
|
||||
//! PT2 graph; mismatches there will surface as wrong-output bugs in
|
||||
//! `tests/test_dlrm.py`, not as a crash.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
dtype::DType, op::CustomOp, op::LLIROp, prelude::FxHashMap, shape::Expression,
|
||||
};
|
||||
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
/// Static shape description for the DLRM family. Every dim is a `usize`
/// resolved at translate time — the kernel bakes them all into the CUDA
/// source as compile-time constants.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DlrmMegaKernel {
    /// Per-call batch size.
    pub batch: usize,
    /// Number of dense features (first element of `ln_bot`).
    pub n_dense_in: usize,
    /// Bot MLP layer widths. `ln_bot[0] == n_dense_in`; `ln_bot.last() == m_spa`.
    /// Must have at least 2 entries (one Linear layer).
    pub ln_bot: Vec<usize>,
    /// Number of sparse embedding tables.
    pub n_sparse: usize,
    /// Vocab size for each table (length == `n_sparse`).
    pub vocab_sizes: Vec<usize>,
    /// Sparse embedding dim (equal across tables, == bot MLP output width).
    pub m_spa: usize,
    /// Top MLP layer widths. `ln_top[0] == m_spa + n_pairs`; `ln_top.last() == 1`.
    pub ln_top: Vec<usize>,
}

impl DlrmMegaKernel {
    /// `n_feat = 1 + n_sparse` — number of feature vectors fed into the
    /// dot interaction (1 dense + sparse tables).
    fn n_feat(&self) -> usize {
        1 + self.n_sparse
    }

    /// `n_pairs = n_feat * (n_feat - 1) / 2` — number of strictly-lower-tri
    /// pairs produced by the dot interaction.
    fn n_pairs(&self) -> usize {
        let n = self.n_feat();
        // n >= 1 by construction, so `n - 1` cannot underflow.
        n * (n - 1) / 2
    }

    /// Validation: cheap up-front check that the shape is internally
    /// consistent.
    ///
    /// # Panics
    ///
    /// Panics (in all build profiles — these are plain `assert!`s) if any of
    /// the documented field relationships does not hold.
    fn validate(&self) {
        assert!(self.ln_bot.len() >= 2, "ln_bot must have ≥2 entries");
        assert!(self.ln_top.len() >= 2, "ln_top must have ≥2 entries");
        assert_eq!(self.ln_bot[0], self.n_dense_in, "ln_bot[0] must == n_dense_in");
        assert_eq!(*self.ln_bot.last().unwrap(), self.m_spa, "ln_bot.last() must == m_spa");
        assert_eq!(self.vocab_sizes.len(), self.n_sparse);
        assert_eq!(
            self.ln_top[0],
            self.m_spa + self.n_pairs(),
            "ln_top[0] must == m_spa + n_pairs"
        );
        assert_eq!(*self.ln_top.last().unwrap(), 1, "ln_top.last() must == 1 (binary classifier)");
        assert!(self.batch > 0);
    }

    /// Generate the CUDA source for this kernel shape.
    ///
    /// The returned string defines a single `extern "C"` kernel named
    /// `dlrm_mega` in which each thread handles one batch row `bi` and
    /// returns early when `bi >= batch`.
    ///
    /// # Panics
    ///
    /// Panics if `ln_bot` or `ln_top` is empty.
    fn cuda_source(&self) -> String {
        let n_feat = self.n_feat();
        let n_pairs = self.n_pairs();

        // ---- Kernel signature ------------------------------------------
        // `out` comes first, followed by the inputs in the order given in
        // the module docs: dense_x, per-table indices, per-table embedding
        // weights, bot weight/bias pairs, top weight/bias pairs.
        let mut sig = String::from(
            " float* __restrict__ out,\n const float* __restrict__ dense_x,\n",
        );
        for k in 0..self.n_sparse {
            // 32-bit signed indices.
            sig.push_str(&format!(" const int* __restrict__ idx_{k},\n"));
        }
        for k in 0..self.n_sparse {
            sig.push_str(&format!(" const float* __restrict__ emb_{k}_w,\n"));
        }
        // Bot MLP: one Linear per (ln_bot[i] → ln_bot[i+1]). Indexed as
        // w[j*in + i], i.e. (out, in) row-major, with bias (out,).
        for i in 0..self.ln_bot.len() - 1 {
            sig.push_str(&format!(" const float* __restrict__ bot_l{i}_w,\n"));
            sig.push_str(&format!(" const float* __restrict__ bot_l{i}_b,\n"));
        }
        for i in 0..self.ln_top.len() - 1 {
            // The very last parameter must not be followed by a comma.
            let trail = if i == self.ln_top.len() - 2 { "" } else { "," };
            sig.push_str(&format!(" const float* __restrict__ top_l{i}_w,\n"));
            sig.push_str(&format!(" const float* __restrict__ top_l{i}_b{trail}\n"));
        }

        // ---- Body --------------------------------------------------------
        let mut body = String::new();

        // 1. Load dense row. `layer_in` is reused as the input buffer of
        //    every bot layer (each layer copies its `out_w` outputs back into
        //    it), so it must be as wide as the widest bot layer — sizing it
        //    to `ln_bot[0]` overflows when a hidden layer is wider than the
        //    dense input.
        let max_bot = *self
            .ln_bot
            .iter()
            .max()
            .expect("ln_bot must be non-empty");
        body.push_str(&format!(
            " // Bot MLP layer 0 input: dense row\n \
             float layer_in[{max_bot}];\n \
             #pragma unroll\n \
             for (int i = 0; i < {n_dense_in}; ++i) layer_in[i] = dense_x[bi * {n_dense_in} + i];\n\n",
            n_dense_in = self.n_dense_in,
        ));

        // 2. Bot MLP — sequence of Linear+ReLU. Output of last layer
        //    becomes `x[m_spa]` for the interaction.
        for i in 0..self.ln_bot.len() - 1 {
            let in_w = self.ln_bot[i];
            let out_w = self.ln_bot[i + 1];
            body.push_str(&format!(
                " // Bot Linear {i}: ({in_w} → {out_w}) + ReLU\n \
                 float bot_l{i}_out[{out_w}];\n \
                 #pragma unroll\n \
                 for (int j = 0; j < {out_w}; ++j) {{\n \
                 float a = bot_l{i}_b[j];\n \
                 #pragma unroll\n \
                 for (int i_ = 0; i_ < {in_w}; ++i_) a += layer_in[i_] * bot_l{i}_w[j*{in_w} + i_];\n \
                 bot_l{i}_out[j] = fmaxf(a, 0.0f);\n \
                 }}\n \
                 // shuffle output into `layer_in` for the next iteration / interaction\n \
                 #pragma unroll\n \
                 for (int i = 0; i < {out_w}; ++i) layer_in[i] = bot_l{i}_out[i];\n\n",
            ));
        }
        // After the loop, `layer_in[..m_spa]` holds dense_out ("x").
        body.push_str(&format!(
            " float x[{m_spa}];\n \
             #pragma unroll\n \
             for (int i = 0; i < {m_spa}; ++i) x[i] = layer_in[i];\n\n",
            m_spa = self.m_spa,
        ));

        // 3. Sparse embedding gathers (one row per table, bag size 1).
        for k in 0..self.n_sparse {
            body.push_str(&format!(
                " // Embedding lookup {k}\n \
                 float ly_{k}[{m_spa}];\n \
                 {{\n \
                 int i_{k} = idx_{k}[bi];\n \
                 #pragma unroll\n \
                 for (int j = 0; j < {m_spa}; ++j) ly_{k}[j] = emb_{k}_w[i_{k}*{m_spa} + j];\n \
                 }}\n\n",
                m_spa = self.m_spa,
            ));
        }

        // 4. Dot interaction: compute n_pairs strictly-lower-tri dot products
        //    over the n_feat = 1 + n_sparse vectors (x, ly_0, ly_1, ...).
        //    Pair order: for i in 0..n_feat, for j in 0..i.
        //    Vec[0] = x, Vec[k+1] = ly_k.
        body.push_str(&format!(" float zflat[{n_pairs}];\n"));
        let vec_name = |idx: usize| -> String {
            if idx == 0 {
                "x".to_string()
            } else {
                format!("ly_{}", idx - 1)
            }
        };
        let mut pair_idx = 0usize;
        for i in 0..n_feat {
            for j in 0..i {
                let a = vec_name(i);
                let b = vec_name(j);
                // Each dot product is emitted fully expanded as a sum of products.
                let mut terms = Vec::with_capacity(self.m_spa);
                for d in 0..self.m_spa {
                    terms.push(format!("{a}[{d}]*{b}[{d}]"));
                }
                body.push_str(&format!(
                    " zflat[{pair_idx}] = {};\n",
                    terms.join(" + ")
                ));
                pair_idx += 1;
            }
        }
        body.push('\n');

        // 5. R = cat([x, zflat]) → top MLP input.
        let r_len = self.m_spa + n_pairs;
        body.push_str(&format!(
            " float r[{r_len}];\n \
             #pragma unroll\n \
             for (int i = 0; i < {m_spa}; ++i) r[i] = x[i];\n \
             #pragma unroll\n \
             for (int i = 0; i < {n_pairs}; ++i) r[{m_spa} + i] = zflat[i];\n\n",
            m_spa = self.m_spa,
        ));

        // 6. Top MLP: Linear+ReLU chain, ending with Linear+Sigmoid.
        //    `r` is the first layer input; a single array `top_in[]`, sized
        //    to the widest top layer, is reused for subsequent layers.
        let max_top = *self.ln_top.iter().max().unwrap();
        body.push_str(&format!(
            " float top_in[{max_top}];\n \
             #pragma unroll\n \
             for (int i = 0; i < {r_len}; ++i) top_in[i] = r[i];\n\n",
        ));
        let n_top_layers = self.ln_top.len() - 1;
        for i in 0..n_top_layers {
            let in_w = self.ln_top[i];
            let out_w = self.ln_top[i + 1];
            let is_last = i == n_top_layers - 1;
            body.push_str(&format!(
                " // Top Linear {i}: ({in_w} → {out_w})\n \
                 float top_l{i}_out[{out_w}];\n \
                 #pragma unroll\n \
                 for (int j = 0; j < {out_w}; ++j) {{\n \
                 float a = top_l{i}_b[j];\n \
                 #pragma unroll\n \
                 for (int i_ = 0; i_ < {in_w}; ++i_) a += top_in[i_] * top_l{i}_w[j*{in_w} + i_];\n \
                 top_l{i}_out[j] = {activation};\n \
                 }}\n",
                activation = if is_last {
                    "1.0f / (1.0f + __expf(-a))"
                } else {
                    "fmaxf(a, 0.0f)"
                },
            ));
            if !is_last {
                body.push_str(&format!(
                    " #pragma unroll\n \
                     for (int i = 0; i < {out_w}; ++i) top_in[i] = top_l{i}_out[i];\n\n",
                ));
            } else {
                // Final layer: write to global output. Only element 0 is
                // stored, matching the `ln_top.last() == 1` invariant.
                body.push_str(&format!(
                    " out[bi] = top_l{i}_out[0];\n",
                ));
            }
        }

        // Assemble the full source.
        format!(
            "extern \"C\" __global__ void dlrm_mega(\n{sig}) {{\n \
             int bi = blockIdx.x * blockDim.x + threadIdx.x;\n \
             if (bi >= {batch}) return;\n\n\
             {body}\
             }}\n",
            batch = self.batch,
        )
    }
}
|
||||
|
||||
impl KernelOp for DlrmMegaKernel {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
self.validate();
|
||||
let kernel = self.cuda_source();
|
||||
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
if std::env::var("LUMINAL_DLRM_MEGAKERNEL_DEBUG").is_ok() {
|
||||
let path = "/tmp/dlrm_megakernel_generated.cu";
|
||||
let _ = std::fs::write(path, &kernel);
|
||||
eprintln!("[DlrmMegaKernel] wrote generated source to {path}");
|
||||
}
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel)
|
||||
.expect("nvrtc compile failed for DLRM megakernel");
|
||||
let module = stream.context().load_module(ptx).expect("load_module");
|
||||
let func = module
|
||||
.load_function("dlrm_mega")
|
||||
.expect("load_function dlrm_mega");
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
const BLOCK: usize = 128;
|
||||
let grid_x = self.batch.div_ceil(BLOCK);
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(
|
||||
Expression::from(grid_x),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
(
|
||||
Expression::from(BLOCK),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
Expression::from(0usize),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
Expression::from(self.batch)
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
DType::F32
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// Per batch row:
|
||||
// dense: n_dense_in × f32
|
||||
// indices: n_sparse × i64
|
||||
// embs: n_sparse × m_spa × f32 (single row each)
|
||||
// bot Ws: sum(in*out) for each layer, × f32 (shared across batch — costed once)
|
||||
// bot bs: sum(out) × f32
|
||||
// top Ws/bs same shape
|
||||
let bot_w: usize = (0..self.ln_bot.len() - 1)
|
||||
.map(|i| self.ln_bot[i] * self.ln_bot[i + 1])
|
||||
.sum();
|
||||
let bot_b: usize = self.ln_bot.iter().skip(1).sum();
|
||||
let top_w: usize = (0..self.ln_top.len() - 1)
|
||||
.map(|i| self.ln_top[i] * self.ln_top[i + 1])
|
||||
.sum();
|
||||
let top_b: usize = self.ln_top.iter().skip(1).sum();
|
||||
let per_row =
|
||||
self.n_dense_in * 4 + self.n_sparse * 8 + self.n_sparse * self.m_spa * 4;
|
||||
let weights = (bot_w + bot_b + top_w + top_b) * 4;
|
||||
Expression::from(self.batch * per_row + weights)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
// batch × 1 × f32
|
||||
Expression::from(self.batch * 4)
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// Per row:
|
||||
// bot Linears: 2*in*out + out (FMAs + bias)
|
||||
// embedding gathers: 0 FMAs (loads)
|
||||
// dot interaction: n_pairs × m_spa MACs
|
||||
// top Linears: 2*in*out + out + (relu/sigmoid cost ~5)
|
||||
let bot: usize = (0..self.ln_bot.len() - 1)
|
||||
.map(|i| 2 * self.ln_bot[i] * self.ln_bot[i + 1] + self.ln_bot[i + 1])
|
||||
.sum();
|
||||
let dot = self.n_pairs() * self.m_spa * 2;
|
||||
let top: usize = (0..self.ln_top.len() - 1)
|
||||
.map(|i| 2 * self.ln_top[i] * self.ln_top[i + 1] + self.ln_top[i + 1])
|
||||
.sum();
|
||||
Expression::from(self.batch * (bot + dot + top + 5))
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"DlrmMega"
|
||||
}
|
||||
}
|
||||
|
||||
/// `CustomOp` wrapper for [`DlrmMegaKernel`]. Same pattern as
|
||||
/// [`crate::kernel::matmul2d::Matmul2DCustom`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DlrmMegaCustom(pub DlrmMegaKernel);
|
||||
|
||||
impl CustomOp for DlrmMegaCustom {
|
||||
fn to_llir_op(&self) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Small reference shape: 13 dense features, bot MLP 13→8→4, three
    /// sparse tables with embedding dim 4, top MLP 10→8→1, batch 2048.
    fn mini_dlrm() -> DlrmMegaKernel {
        DlrmMegaKernel {
            batch: 2048,
            n_dense_in: 13,
            ln_bot: vec![13, 8, 4],
            n_sparse: 3,
            vocab_sizes: vec![10, 20, 30],
            m_spa: 4,
            ln_top: vec![10, 8, 1],
        }
    }

    /// Derived counts match the hand-computed values and the shape validates.
    #[test]
    fn shape_invariants() {
        let kernel = mini_dlrm();
        let (feats, pairs) = (kernel.n_feat(), kernel.n_pairs());
        assert_eq!(feats, 4);
        assert_eq!(pairs, 6);
        assert_eq!(kernel.ln_top[0], kernel.m_spa + pairs);
        kernel.validate();
    }

    /// Structural check of the generated source only — nvrtc is not invoked.
    #[test]
    fn cuda_source_compiles_in_format() {
        let src = mini_dlrm().cuda_source();
        let expected_fragments = [
            // Entry point and batch guard.
            "extern \"C\" __global__ void dlrm_mega",
            "if (bi >= 2048)",
            // 3 embedding lookups.
            "ly_0[",
            "ly_1[",
            "ly_2[",
            // 6 dot products (last index is 5).
            "zflat[5]",
            // Sigmoid epilogue.
            "1.0f / (1.0f + __expf(-a))",
        ];
        for fragment in expected_fragments.iter() {
            assert!(src.contains(fragment));
        }
    }
}
|
||||
@@ -11,6 +11,7 @@ use uuid::Uuid;
|
||||
|
||||
pub mod conv2d;
|
||||
pub mod cuda_graph;
|
||||
pub mod dlrm_megakernel;
|
||||
pub mod fusion;
|
||||
pub mod hlir;
|
||||
pub mod matmul2d;
|
||||
@@ -19,6 +20,7 @@ pub mod rope;
|
||||
|
||||
pub use conv2d::KernelConv2D;
|
||||
pub use cuda_graph::*;
|
||||
pub use dlrm_megakernel::{DlrmMegaCustom, DlrmMegaKernel};
|
||||
pub use matmul2d::{
|
||||
Matmul2DCustom, Matmul2DKernel, linear_bias, linear_no_bias_bf16_w, matmul_2d, matmul_2d_t,
|
||||
matmul_3d, matmul_3d_t,
|
||||
|
||||
@@ -1338,9 +1338,21 @@ impl KernelOp for KernelSoftmax {
|
||||
#define FULL_MASK 0xffffffff
|
||||
#define NEG_INF_F __int_as_float(0xff800000)
|
||||
{dyn_defines}
|
||||
#define LOG2E 1.4426950408889634f
|
||||
|
||||
extern \"C\" {{
|
||||
// Online normalizer calculation for softmax (Milakov & Gimelshein 2018).
|
||||
|
||||
// Merge two partial (max, sum) pairs using the online softmax rule.
|
||||
__device__ __forceinline__ void merge_md(float *m, float *d, float m2, float d2) {{
|
||||
float new_m = fmaxf(*m, m2);
|
||||
*d = *d * exp2f((*m - new_m) * LOG2E) + d2 * exp2f((m2 - new_m) * LOG2E);
|
||||
*m = new_m;
|
||||
}}
|
||||
|
||||
__global__ void fused_softmax(float *out, const float *inp{dyn_dims_param}) {{
|
||||
__shared__ float shared[THREADS_PER_BLOCK / WARP_SIZE];
|
||||
__shared__ float sh_m[THREADS_PER_BLOCK / WARP_SIZE];
|
||||
__shared__ float sh_d[THREADS_PER_BLOCK / WARP_SIZE];
|
||||
long long const_z = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
int lane_id = tid % WARP_SIZE;
|
||||
@@ -1352,55 +1364,36 @@ extern \"C\" {{
|
||||
long long in_stride = {in_reduce_stride};
|
||||
long long out_stride = {out_reduce_stride};
|
||||
|
||||
// Pass 1: find max
|
||||
float max_val = NEG_INF_F;
|
||||
// Pass 1: one read of inp produces (global_max, global_sum).
|
||||
float m = NEG_INF_F, d = 0.0f;
|
||||
for (long long i = tid; i < N; i += THREADS_PER_BLOCK) {{
|
||||
max_val = fmaxf(max_val, inp[in_base + i * in_stride]);
|
||||
merge_md(&m, &d, inp[in_base + i * in_stride], 1.0f);
|
||||
}}
|
||||
// Warp reduce: collapse 32 threads within each warp down to lane 0.
|
||||
#pragma unroll
|
||||
for (int s = WARP_SIZE / 2; s > 0; s /= 2) {{
|
||||
max_val = fmaxf(max_val, __shfl_down_sync(FULL_MASK, max_val, s));
|
||||
merge_md(&m, &d, __shfl_down_sync(FULL_MASK, m, s), __shfl_down_sync(FULL_MASK, d, s));
|
||||
}}
|
||||
if (lane_id == 0) shared[warp_id] = max_val;
|
||||
if (lane_id == 0) {{ sh_m[warp_id] = m; sh_d[warp_id] = d; }}
|
||||
__syncthreads();
|
||||
// Block reduce: warp 0 collapses the 8 warp results down to one.
|
||||
if (warp_id == 0) {{
|
||||
max_val = tid < (THREADS_PER_BLOCK / WARP_SIZE) ? shared[tid] : NEG_INF_F;
|
||||
m = tid < (THREADS_PER_BLOCK / WARP_SIZE) ? sh_m[tid] : NEG_INF_F;
|
||||
d = tid < (THREADS_PER_BLOCK / WARP_SIZE) ? sh_d[tid] : 0.0f;
|
||||
#pragma unroll
|
||||
for (int s = (THREADS_PER_BLOCK / WARP_SIZE) / 2; s > 0; s /= 2) {{
|
||||
max_val = fmaxf(max_val, __shfl_down_sync(FULL_MASK, max_val, s));
|
||||
merge_md(&m, &d, __shfl_down_sync(FULL_MASK, m, s), __shfl_down_sync(FULL_MASK, d, s));
|
||||
}}
|
||||
shared[0] = max_val;
|
||||
sh_m[0] = m;
|
||||
sh_d[0] = d;
|
||||
}}
|
||||
__syncthreads();
|
||||
max_val = shared[0];
|
||||
float global_max = sh_m[0];
|
||||
float inv_sum = 1.0f / sh_d[0];
|
||||
|
||||
// Pass 2: compute exp2 and sum
|
||||
float sum_val = 0.0f;
|
||||
// Pass 2: write final softmax values.
|
||||
for (long long i = tid; i < N; i += THREADS_PER_BLOCK) {{
|
||||
float v = exp2f((inp[in_base + i * in_stride] - max_val) * 1.4426950408889634f);
|
||||
out[out_base + i * out_stride] = v; // store exp temporarily
|
||||
sum_val += v;
|
||||
}}
|
||||
#pragma unroll
|
||||
for (int s = WARP_SIZE / 2; s > 0; s /= 2) {{
|
||||
sum_val += __shfl_down_sync(FULL_MASK, sum_val, s);
|
||||
}}
|
||||
if (lane_id == 0) shared[warp_id] = sum_val;
|
||||
__syncthreads();
|
||||
if (warp_id == 0) {{
|
||||
sum_val = tid < (THREADS_PER_BLOCK / WARP_SIZE) ? shared[tid] : 0.0f;
|
||||
#pragma unroll
|
||||
for (int s = (THREADS_PER_BLOCK / WARP_SIZE) / 2; s > 0; s /= 2) {{
|
||||
sum_val += __shfl_down_sync(FULL_MASK, sum_val, s);
|
||||
}}
|
||||
shared[0] = sum_val;
|
||||
}}
|
||||
__syncthreads();
|
||||
float inv_sum = 1.0f / shared[0];
|
||||
|
||||
// Pass 3: normalize
|
||||
for (long long i = tid; i < N; i += THREADS_PER_BLOCK) {{
|
||||
out[out_base + i * out_stride] *= inv_sum;
|
||||
out[out_base + i * out_stride] = exp2f((inp[in_base + i * in_stride] - global_max) * LOG2E) * inv_sum;
|
||||
}}
|
||||
}}
|
||||
}}"
|
||||
|
||||
@@ -106,6 +106,12 @@ pub(crate) struct CompiledBucket {
|
||||
pub(crate) bucket_indices: FxHashMap<char, usize>,
|
||||
/// Whether HLIR pointers have been synced into this bucket's cached_buffer_ptrs
|
||||
pub(crate) hlir_synced: bool,
|
||||
/// Cached topological order of exec_graph nodes. Lazily populated on
|
||||
/// first execute() and invalidated only when the exec_graph itself
|
||||
/// changes (compilation, bucket rebuild). Avoids the per-call
|
||||
/// `petgraph::algo::toposort` Vec allocation + traversal — small but
|
||||
/// real in hot inference loops.
|
||||
pub(crate) exec_topo_order: Vec<NodeIndex>,
|
||||
}
|
||||
|
||||
impl CompiledBucket {
|
||||
@@ -130,6 +136,7 @@ impl CompiledBucket {
|
||||
intermediate_buffer_dims: FxHashSet::default(),
|
||||
bucket_indices: FxHashMap::default(),
|
||||
hlir_synced: false,
|
||||
exec_topo_order: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -225,7 +232,6 @@ impl CudaRuntime {
|
||||
result::memcpy_dtod_async(dst_ptr, src.ptr(), src.len(), stream.cu_stream())
|
||||
.expect("cuMemcpyDtoDAsync failed");
|
||||
}
|
||||
stream.synchronize().unwrap();
|
||||
dst
|
||||
}
|
||||
|
||||
@@ -328,6 +334,24 @@ impl CudaRuntime {
|
||||
pub unsafe fn set_device_ptr(&mut self, id: impl ToId, device_ptr: u64, n_bytes: usize) {
|
||||
debug_assert!(device_ptr != 0, "set_device_ptr called with null pointer");
|
||||
let id = id.to_id();
|
||||
// Fast path: if the same pointer is already registered, this is a no-op.
|
||||
// PyTorch's caching allocator routinely hands back the same device
|
||||
// pointer for the same logical tensor on each forward; bench loops in
|
||||
// particular hammer this. Skipping the cudarc upgrade_device_ptr +
|
||||
// ManuallyDrop reallocation + the changed_hlir insert + the per-bucket
|
||||
// ptr re-cache that fires on the next execute saves ~2µs per input.
|
||||
if let Some(CudaInput::Ptr(prev)) = self.hlir_buffers.get(&id) {
|
||||
if *prev == device_ptr {
|
||||
// Refresh the external_buffers view in case n_bytes shrank to
|
||||
// exactly cover the live region; cheap and keeps the slice
|
||||
// length correct without rebuilding the registration.
|
||||
if let Some(ext) = self.external_buffers.get(&id) {
|
||||
if ext.len() == n_bytes {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Create CudaSlice view via cudarc's upgrade_device_ptr.
|
||||
// ManuallyDrop prevents cuMemFree on drop (external allocator owns this memory).
|
||||
let slice = unsafe {
|
||||
@@ -1466,9 +1490,19 @@ impl Runtime for CudaRuntime {
|
||||
self.apply_output_ptr_registrations();
|
||||
|
||||
let total_start = std::time::Instant::now();
|
||||
// Populate the topo-order cache lazily — only on first execute for
|
||||
// this bucket. Walking exec_graph + allocating a Vec every iter
|
||||
// measurably shows up at small batches where the kernel work itself
|
||||
// is sub-microsecond and the per-call overhead dominates.
|
||||
{
|
||||
let bucket = &mut self.compiled_buckets[self.active_bucket];
|
||||
if bucket.exec_topo_order.is_empty() && bucket.exec_graph.node_count() > 0 {
|
||||
bucket.exec_topo_order = toposort(&bucket.exec_graph, None).unwrap();
|
||||
}
|
||||
}
|
||||
let bucket = &self.compiled_buckets[self.active_bucket];
|
||||
|
||||
for exec_node in toposort(&bucket.exec_graph, None).unwrap() {
|
||||
for &exec_node in &bucket.exec_topo_order {
|
||||
let exec_op = &bucket.exec_graph[exec_node];
|
||||
trace!("Executing: {:?}", exec_op);
|
||||
|
||||
@@ -1540,21 +1574,26 @@ impl Runtime for CudaRuntime {
|
||||
self.cuda_stream.synchronize().unwrap();
|
||||
self.last_total_time_us = total_start.elapsed().as_secs_f64() * 1_000_000.0;
|
||||
|
||||
// Populate last_kernel_stats from HostOps that report stats
|
||||
self.last_kernel_stats.clear();
|
||||
let bucket = &self.compiled_buckets[self.active_bucket];
|
||||
for exec_node in bucket.exec_graph.node_indices() {
|
||||
let exec_op = &bucket.exec_graph[exec_node];
|
||||
if let Some(name) = exec_op.internal.stats_name() {
|
||||
self.last_kernel_stats.push(KernelStats {
|
||||
name,
|
||||
execution_time_us: 0.0,
|
||||
bytes_loaded: 0,
|
||||
bytes_stored: 0,
|
||||
flops: 0,
|
||||
bandwidth_gbps: 0.0,
|
||||
tflops: 0.0,
|
||||
});
|
||||
// last_kernel_stats is only read by print_execution_stats() — a
|
||||
// diagnostic API. Populating the Vec on every execute() (looping all
|
||||
// exec nodes and calling stats_name() on each) is wasteful in
|
||||
// production inference loops. Gate it on the profiling flag.
|
||||
if self.profiling {
|
||||
self.last_kernel_stats.clear();
|
||||
let bucket = &self.compiled_buckets[self.active_bucket];
|
||||
for exec_node in bucket.exec_graph.node_indices() {
|
||||
let exec_op = &bucket.exec_graph[exec_node];
|
||||
if let Some(name) = exec_op.internal.stats_name() {
|
||||
self.last_kernel_stats.push(KernelStats {
|
||||
name,
|
||||
execution_time_us: 0.0,
|
||||
bytes_loaded: 0,
|
||||
bytes_stored: 0,
|
||||
flops: 0,
|
||||
bandwidth_gbps: 0.0,
|
||||
tflops: 0.0,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1576,11 +1615,22 @@ impl Runtime for CudaRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
// Free owned input buffers after a step so they're not held until the
|
||||
// next set_data overwrites them. External-pointer inputs (registered
|
||||
// via set_device_ptr) are caller-owned and the runtime doesn't free
|
||||
// their memory either way — consuming them only invalidates the
|
||||
// registration and forces the caller to re-register on the next
|
||||
// execute. That's pure waste in tight inference loops (e.g.
|
||||
// luminal_python's torch.compile backend, which re-invokes execute()
|
||||
// for every forward), so leave external-pointer entries in place.
|
||||
let to_consume: Vec<NodeIndex> = self
|
||||
.hlir_buffers
|
||||
.keys()
|
||||
.filter(|hlir_node| !inputs_with_outputs.contains(hlir_node))
|
||||
.copied()
|
||||
.iter()
|
||||
.filter(|(hlir_node, input)| {
|
||||
!inputs_with_outputs.contains(hlir_node)
|
||||
&& !matches!(input, CudaInput::Ptr(_))
|
||||
})
|
||||
.map(|(n, _)| *n)
|
||||
.collect();
|
||||
|
||||
for hlir_node in to_consume {
|
||||
|
||||
@@ -19,7 +19,14 @@ bytemuck = "1.24.0"
|
||||
|
||||
[dev-dependencies]
|
||||
candle-core = "0.9.2-alpha.1"
|
||||
hf-hub = { version = "0.4", default-features = false, features = ["rustls-tls", "ureq"] }
|
||||
luminal_nn = { path = "../luminal_nn" }
|
||||
luminal_tracing = { path = "../luminal_tracing" }
|
||||
proptest = "1.9.0"
|
||||
rand = "0.9.2"
|
||||
rustc-hash = "2.1"
|
||||
tokenizers = "0.22.2"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
||||
[lints.rust]
|
||||
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(feature, values("cargo-clippy"))'] }
|
||||
|
||||
641
crates/luminal_metal/examples/llama_1b.rs
Normal file
641
crates/luminal_metal/examples/llama_1b.rs
Normal file
@@ -0,0 +1,641 @@
|
||||
use hf_hub::api::sync::Api;
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
graph::{BuildSearchSpaceOptions, DimBucket, Graph},
|
||||
prelude::{F32Pow, GraphTensor, Runtime},
|
||||
};
|
||||
use luminal_metal::MetalRuntime;
|
||||
use luminal_nn::{LayerNorm, gather_rows, scatter_rows};
|
||||
use luminal_tracing::luminal_filter;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{
|
||||
error::Error,
|
||||
io::Write,
|
||||
path::PathBuf,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
const REPO_ID: &str = "unsloth/Llama-3.2-1B-Instruct";
|
||||
const MAX_SEQ_LEN: usize = 2048;
|
||||
const GEN_TOKENS: usize = 96;
|
||||
const SEARCH_GRAPHS: usize = 100;
|
||||
const SEARCH_MEMORY_MIB: usize = 1536;
|
||||
const PROMPT: &str = "In one short paragraph, explain neural networks using the words layers, neurons, learning, and data.";
|
||||
|
||||
const LAYERS: usize = 16;
|
||||
const HIDDEN: usize = 2048;
|
||||
const INTERMEDIATE: usize = 8192;
|
||||
const HEAD_DIM: usize = 64;
|
||||
const N_HEADS: usize = 32;
|
||||
const N_KV_HEADS: usize = 8;
|
||||
const KV_GROUPS: usize = N_HEADS / N_KV_HEADS;
|
||||
const KV_DIM: usize = N_KV_HEADS * HEAD_DIM;
|
||||
const VOCAB_SIZE: usize = 128256;
|
||||
const RMS_NORM_EPS: f32 = 1e-5;
|
||||
const ROPE_THETA: f32 = 500_000.0;
|
||||
const EOS_TOKEN: u32 = 128009;
|
||||
const STOP_TOKEN: u32 = 128001;
|
||||
|
||||
fn prepare_hf_model() -> Result<PathBuf, Box<dyn Error>> {
|
||||
let repo = Api::new()?.model(REPO_ID.to_string());
|
||||
let tokenizer_path = repo.get("tokenizer.json")?;
|
||||
repo.get("model.safetensors")?;
|
||||
Ok(tokenizer_path.parent().unwrap().to_path_buf())
|
||||
}
|
||||
|
||||
/// Wraps `user_prompt` in the chat markup used by this example: a single user
/// turn followed by an opened assistant turn.
fn llama3_chat_prompt(user_prompt: &str) -> String {
    const USER_TURN_OPEN: &str = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n";
    const ASSISTANT_TURN_OPEN: &str =
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n";

    [USER_TURN_OPEN, user_prompt, ASSISTANT_TURN_OPEN].concat()
}
|
||||
|
||||
/// Wall-clock timings collected for one call to `run_model_step`.
#[derive(Clone, Default)]
struct StepProfile {
    /// Time for the whole step, from input upload to cache hand-back.
    total: Duration,
    /// Time spent inside `runtime.execute`.
    execute: Duration,
    /// Time spent reading the logits back with `runtime.get_f32`.
    get_logits: Duration,
    /// Time spent moving the KV-cache output buffers back onto the cache inputs.
    cache_roundtrip: Duration,
}
|
||||
|
||||
/// Returns the mean duration per item in milliseconds, or `0.0` when `n` is
/// zero so that an empty sample never divides by zero.
fn avg_ms(duration: Duration, n: usize) -> f64 {
    match n {
        0 => 0.0,
        count => duration.as_secs_f64() * 1e3 / count as f64,
    }
}
|
||||
|
||||
/// Picks the highest-scoring token id from `logits_row` after applying a
/// repetition penalty to every id in `seen`.
///
/// A penalised positive logit is divided by `repetition_penalty`; a penalised
/// non-positive logit is multiplied by it. With a penalty above 1 both cases
/// lower the score. Ties resolve to the highest index, as `Iterator::max_by`
/// returns the last maximum.
///
/// Ids in `seen` that lie outside `logits_row` are ignored rather than causing
/// an out-of-bounds panic. The function is generic over the set's hasher, so it
/// accepts `FxHashSet<u32>` as well as a plain `std::collections::HashSet<u32>`.
///
/// # Panics
/// Panics if `logits_row` is empty, since there is then no token to return.
fn sample_greedy<S: std::hash::BuildHasher>(
    logits_row: &[f32],
    seen: &std::collections::HashSet<u32, S>,
    repetition_penalty: f32,
) -> u32 {
    // Work on a copy so the caller's logits stay untouched; only the seen
    // entries are adjusted, which keeps the cost at O(vocab + |seen|).
    let mut row = logits_row.to_vec();
    for &tok in seen {
        let Some(logit) = row.get_mut(tok as usize) else {
            continue;
        };
        if *logit > 0.0 {
            *logit /= repetition_penalty;
        } else {
            *logit *= repetition_penalty;
        }
    }

    row.iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .expect("sample_greedy requires a non-empty logits row")
        .0 as u32
}
|
||||
|
||||
/// Builds a row-major additive attention mask of shape
/// `(q_pos.len(), context_len)`.
///
/// Row `i` holds `0.0` for every context column `c <= q_pos[i]` and `-1e10`
/// for the rest, so a query at position `p` can only attend to slots `0..=p`.
fn causal_mask(q_pos: &[usize], context_len: usize) -> Vec<f32> {
    let mut mask = vec![-1e10f32; q_pos.len() * context_len];
    // `chunks_exact_mut` rejects a zero chunk size; a zero-width mask is empty anyway.
    if context_len == 0 {
        return mask;
    }

    for (row, &pos) in mask.chunks_exact_mut(context_len).zip(q_pos) {
        // Columns 0..=pos are visible, clamped to the row width.
        let visible = pos.saturating_add(1).min(context_len);
        row[..visible].fill(0.0);
    }
    mask
}
|
||||
|
||||
struct KVCache {
|
||||
k_caches: Vec<GraphTensor>,
|
||||
v_caches: Vec<GraphTensor>,
|
||||
}
|
||||
|
||||
impl KVCache {
|
||||
fn new(cx: &mut Graph, num_slots: usize) -> Self {
|
||||
let mut k_caches = Vec::with_capacity(LAYERS);
|
||||
let mut v_caches = Vec::with_capacity(LAYERS);
|
||||
for l in 0..LAYERS {
|
||||
k_caches.push(
|
||||
cx.named_tensor(format!("kv_cache.{l}.k"), (num_slots, KV_DIM))
|
||||
.persist(),
|
||||
);
|
||||
v_caches.push(
|
||||
cx.named_tensor(format!("kv_cache.{l}.v"), (num_slots, KV_DIM))
|
||||
.persist(),
|
||||
);
|
||||
}
|
||||
Self { k_caches, v_caches }
|
||||
}
|
||||
}
|
||||
|
||||
struct Llama {
|
||||
embedding: GraphTensor,
|
||||
layers: Vec<LlamaLayer>,
|
||||
lm_norm: LayerNorm,
|
||||
}
|
||||
|
||||
impl Llama {
|
||||
fn init(cx: &mut Graph) -> Self {
|
||||
let mut layers = Vec::with_capacity(LAYERS);
|
||||
for l in 0..LAYERS {
|
||||
layers.push(LlamaLayer {
|
||||
up: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.up_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
gate: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.gate_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
down: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.down_proj.weight"),
|
||||
(HIDDEN, INTERMEDIATE),
|
||||
)
|
||||
.persist(),
|
||||
q_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_proj.weight"),
|
||||
(HIDDEN, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
k_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
v_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.v_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
o_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
attn_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.input_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
),
|
||||
mlp_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
Self {
|
||||
embedding: cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist(),
|
||||
layers,
|
||||
lm_norm: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some("model.norm.weight"),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
input: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = input.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(input * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ input.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
let mut cache_outputs = Vec::with_capacity(LAYERS);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
x,
|
||||
q_pos,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
kv_cache.k_caches[i],
|
||||
kv_cache.v_caches[i],
|
||||
);
|
||||
x = x_new;
|
||||
cache_outputs.push((k_out, v_out));
|
||||
}
|
||||
|
||||
let logits = self.lm_norm.forward(x).matmul(self.embedding.t());
|
||||
(logits, cache_outputs)
|
||||
}
|
||||
}
|
||||
|
||||
struct LlamaLayer {
|
||||
up: GraphTensor,
|
||||
gate: GraphTensor,
|
||||
down: GraphTensor,
|
||||
q_proj: GraphTensor,
|
||||
k_proj: GraphTensor,
|
||||
v_proj: GraphTensor,
|
||||
o_proj: GraphTensor,
|
||||
attn_rms: LayerNorm,
|
||||
mlp_rms: LayerNorm,
|
||||
}
|
||||
|
||||
fn llama_rotary_embeddings(mut input: GraphTensor, pos_ids: GraphTensor) -> GraphTensor {
|
||||
input = input.split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
|
||||
let freqs = input
|
||||
.graph()
|
||||
.arange_options(0, HEAD_DIM, 2)
|
||||
.cast(DType::F32)
|
||||
/ HEAD_DIM as f32;
|
||||
let inv_freqs = ROPE_THETA.pow(freqs).reciprocal();
|
||||
let emb = pos_ids
|
||||
.cast(DType::F32)
|
||||
.expand_dim(1, 1)
|
||||
.matmul(inv_freqs.expand_dim(0, 1));
|
||||
|
||||
let x0 = input.slice((.., .., ..HEAD_DIM / 2));
|
||||
let x1 = input.slice((.., .., HEAD_DIM / 2..));
|
||||
|
||||
let cos = emb.cos().expand_dim(0, x0.dims()[0]);
|
||||
let sin = emb.sin().expand_dim(0, x0.dims()[0]);
|
||||
let x0_out = x0 * cos - x1 * sin;
|
||||
let x1_out = x1 * cos + x0 * sin;
|
||||
|
||||
x0_out
|
||||
.concat_along(x1_out, 2)
|
||||
.transpose(0, 1)
|
||||
.merge_dims(1, 2)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn attention(
|
||||
q_rope: GraphTensor,
|
||||
k_rope: GraphTensor,
|
||||
v: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let k_cache_out = scatter_rows(k_rope, scatter_idx, k_cache, KV_DIM);
|
||||
let v_cache_out = scatter_rows(v, scatter_idx, v_cache, KV_DIM);
|
||||
|
||||
let k = gather_rows(k_cache_out, gather_idx, KV_DIM);
|
||||
let v_ctx = gather_rows(v_cache_out, gather_idx, KV_DIM);
|
||||
|
||||
let q = (q_rope * 1.0).split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
let k = k.split_dims(1, HEAD_DIM).permute((1, 2, 0));
|
||||
let v_ctx = v_ctx.split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
|
||||
let k = k.expand_dim(1, KV_GROUPS).merge_dims(0, 1) * 1.0;
|
||||
let v_ctx = v_ctx.expand_dim(1, KV_GROUPS).merge_dims(0, 1) * 1.0;
|
||||
|
||||
let scores = q.matmul(k) / (HEAD_DIM as f32).sqrt();
|
||||
let masked_scores = scores + attn_mask.expand_dim(0, N_HEADS);
|
||||
let weights = masked_scores.softmax(2);
|
||||
let out = weights.matmul(v_ctx);
|
||||
let attn_out = out.transpose(0, 1).merge_dims(1, 2);
|
||||
|
||||
(attn_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl LlamaLayer {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.attn_rms.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
|
||||
let q_rope = llama_rotary_embeddings(q, q_pos);
|
||||
let k_rope = llama_rotary_embeddings(k, q_pos);
|
||||
let (attn_out, k_cache_out, v_cache_out) = attention(
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
);
|
||||
x += attn_out.matmul(self.o_proj.t());
|
||||
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
let mlp_out =
|
||||
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
|
||||
(x + mlp_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_model_step(
|
||||
cx: &mut Graph,
|
||||
runtime: &mut MetalRuntime,
|
||||
input: GraphTensor,
|
||||
q_pos_t: GraphTensor,
|
||||
scatter_idx_t: GraphTensor,
|
||||
gather_idx_t: GraphTensor,
|
||||
attn_mask_t: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
tokens: &[u32],
|
||||
q_pos: &[i32],
|
||||
scatter_idx: &[i32],
|
||||
gather_idx: &[i32],
|
||||
attn_mask: &[f32],
|
||||
) -> (Vec<f32>, StepProfile) {
|
||||
let start = Instant::now();
|
||||
cx.set_dim('s', tokens.len());
|
||||
cx.set_dim('c', gather_idx.len());
|
||||
|
||||
runtime.set_data(input, tokens.iter().map(|t| *t as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(q_pos_t, q_pos.to_vec());
|
||||
runtime.set_data(scatter_idx_t, scatter_idx.to_vec());
|
||||
runtime.set_data(gather_idx_t, gather_idx.to_vec());
|
||||
runtime.set_data(attn_mask_t, attn_mask.to_vec());
|
||||
runtime.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
|
||||
let execute_start = Instant::now();
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let execute = execute_start.elapsed();
|
||||
|
||||
let logits_start = Instant::now();
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let get_logits = logits_start.elapsed();
|
||||
|
||||
let cache_start = Instant::now();
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
let cache_roundtrip = cache_start.elapsed();
|
||||
|
||||
(
|
||||
logits_data,
|
||||
StepProfile {
|
||||
total: start.elapsed(),
|
||||
execute,
|
||||
get_logits,
|
||||
cache_roundtrip,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn Error>> {
|
||||
let _ = tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.try_init();
|
||||
|
||||
let model_dir = prepare_hf_model()?;
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json"))
|
||||
.map_err(|err| err as Box<dyn Error>)?;
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(llama3_chat_prompt(PROMPT), false)
|
||||
.map_err(|err| err as Box<dyn Error>)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
let q_pos_t = cx.named_tensor("q_pos", 's').as_dtype(DType::Int);
|
||||
let scatter_idx_t = cx.named_tensor("scatter_idx", 's').as_dtype(DType::Int);
|
||||
let gather_idx_t = cx.named_tensor("gather_idx", 'c').as_dtype(DType::Int);
|
||||
let attn_mask_t = cx.named_tensor("attn_mask", ('s', 'c'));
|
||||
let kv_cache = KVCache::new(&mut cx, MAX_SEQ_LEN);
|
||||
let (logits, cache_outputs) = Llama::init(&mut cx).forward(
|
||||
input,
|
||||
q_pos_t,
|
||||
scatter_idx_t,
|
||||
gather_idx_t,
|
||||
attn_mask_t,
|
||||
&kv_cache,
|
||||
);
|
||||
let logits = logits.output();
|
||||
for (k_out, v_out) in &cache_outputs {
|
||||
k_out.output();
|
||||
v_out.output();
|
||||
}
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('c', 1);
|
||||
|
||||
println!("Building E-Graph...");
|
||||
let egraph_start = Instant::now();
|
||||
cx.build_search_space_with_options::<MetalRuntime>(
|
||||
BuildSearchSpaceOptions::new().max_memory_mib(SEARCH_MEMORY_MIB),
|
||||
);
|
||||
println!(
|
||||
" E-Graph build: {:.2} s",
|
||||
egraph_start.elapsed().as_secs_f64()
|
||||
);
|
||||
|
||||
println!("Loading weights...");
|
||||
let load_start = Instant::now();
|
||||
let mut runtime = MetalRuntime::initialize(());
|
||||
runtime.load_safetensors(&cx, model_dir.join("model.safetensors").to_str().unwrap());
|
||||
println!(" Weight load: {:.2} s", load_start.elapsed().as_secs_f64());
|
||||
|
||||
let cache_bytes = MAX_SEQ_LEN * KV_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
println!("Compiling...");
|
||||
let compile_start = Instant::now();
|
||||
let max_prefill = (prompt_tokens.len() + 16)
|
||||
.next_power_of_two()
|
||||
.min(MAX_SEQ_LEN);
|
||||
let max_context = (prompt_tokens.len() + GEN_TOKENS + 1)
|
||||
.next_power_of_two()
|
||||
.min(MAX_SEQ_LEN);
|
||||
let search_s = 16.min(max_prefill).max(2);
|
||||
let search_c = 16.min(max_context).max(2);
|
||||
cx.set_dim_buckets(
|
||||
's',
|
||||
&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, max_prefill).representative(search_s),
|
||||
],
|
||||
);
|
||||
cx.set_dim_buckets(
|
||||
'c',
|
||||
&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, max_context).representative(search_c),
|
||||
],
|
||||
);
|
||||
cx.set_dim('s', search_s);
|
||||
cx.set_dim('c', search_c);
|
||||
runtime.set_data(input, vec![1; search_s]);
|
||||
runtime.set_data(q_pos_t, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(scatter_idx_t, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(gather_idx_t, (0..search_c as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(attn_mask_t, vec![0.0f32; search_s * search_c]);
|
||||
runtime = cx.search(runtime, SEARCH_GRAPHS);
|
||||
println!(
|
||||
" Search/compile: {:.2} s",
|
||||
compile_start.elapsed().as_secs_f64()
|
||||
);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
let prompt_len = prompt_tokens.len();
|
||||
let mut context_len = 0usize;
|
||||
let mut profiles = Vec::new();
|
||||
let mut seen_tokens = FxHashSet::default();
|
||||
let repetition_penalty = 1.05;
|
||||
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, GEN_TOKENS
|
||||
);
|
||||
|
||||
let mut generated = 0usize;
|
||||
let mut next_token = None;
|
||||
if GEN_TOKENS > 0 && prompt_len > 0 {
|
||||
let positions: Vec<usize> = (0..prompt_len).collect();
|
||||
let q_pos: Vec<i32> = positions.iter().map(|&p| p as i32).collect();
|
||||
let mask = causal_mask(&positions, prompt_len);
|
||||
let (logits_data, profile) = run_model_step(
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
q_pos_t,
|
||||
scatter_idx_t,
|
||||
gather_idx_t,
|
||||
attn_mask_t,
|
||||
logits,
|
||||
&kv_cache,
|
||||
&cache_outputs,
|
||||
&prompt_tokens,
|
||||
&q_pos,
|
||||
&q_pos,
|
||||
&q_pos,
|
||||
&mask,
|
||||
);
|
||||
context_len = prompt_len;
|
||||
|
||||
let token = sample_greedy(
|
||||
&logits_data[logits_data.len() - VOCAB_SIZE..],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
seen_tokens.insert(token);
|
||||
next_token = Some(token);
|
||||
generated = 1;
|
||||
profiles.push(profile);
|
||||
|
||||
if token != EOS_TOKEN && token != STOP_TOKEN {
|
||||
print!(
|
||||
"{}",
|
||||
tokenizer
|
||||
.decode(&[token], true)
|
||||
.map_err(|err| err as Box<dyn Error>)?
|
||||
);
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
|
||||
while generated < GEN_TOKENS {
|
||||
let current_token = match next_token {
|
||||
Some(token) if token != EOS_TOKEN && token != STOP_TOKEN => token,
|
||||
_ => break,
|
||||
};
|
||||
let gather_idx = (0..=context_len as i32).collect::<Vec<_>>();
|
||||
let mask = causal_mask(&[context_len], context_len + 1);
|
||||
let (logits_data, profile) = run_model_step(
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
q_pos_t,
|
||||
scatter_idx_t,
|
||||
gather_idx_t,
|
||||
attn_mask_t,
|
||||
logits,
|
||||
&kv_cache,
|
||||
&cache_outputs,
|
||||
&[current_token],
|
||||
&[context_len as i32],
|
||||
&[context_len as i32],
|
||||
&gather_idx,
|
||||
&mask,
|
||||
);
|
||||
context_len += 1;
|
||||
|
||||
let token = sample_greedy(
|
||||
&logits_data[logits_data.len() - VOCAB_SIZE..],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
seen_tokens.insert(token);
|
||||
next_token = Some(token);
|
||||
generated += 1;
|
||||
profiles.push(profile);
|
||||
|
||||
if token == EOS_TOKEN || token == STOP_TOKEN {
|
||||
break;
|
||||
}
|
||||
print!(
|
||||
"{}",
|
||||
tokenizer
|
||||
.decode(&[token], true)
|
||||
.map_err(|err| err as Box<dyn Error>)?
|
||||
);
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
println!();
|
||||
|
||||
let ttft = profiles.first().map(|p| p.total).unwrap_or_default();
|
||||
let decode_steps = profiles.len().saturating_sub(1);
|
||||
let decode_total: Duration = profiles.iter().skip(1).map(|p| p.total).sum();
|
||||
println!(" TTFT: {:.2} ms", ttft.as_secs_f64() * 1e3);
|
||||
println!(" TPOT: {:.2} ms", avg_ms(decode_total, decode_steps));
|
||||
|
||||
let execute_total: Duration = profiles.iter().map(|p| p.execute).sum();
|
||||
let logits_total: Duration = profiles.iter().map(|p| p.get_logits).sum();
|
||||
let cache_total: Duration = profiles.iter().map(|p| p.cache_roundtrip).sum();
|
||||
println!(
|
||||
" Profile: n={}, exec={:.2} ms, logits={:.2} ms, cache={:.2} ms",
|
||||
profiles.len(),
|
||||
avg_ms(execute_total, profiles.len()),
|
||||
avg_ms(logits_total, profiles.len()),
|
||||
avg_ms(cache_total, profiles.len()),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -6,10 +6,127 @@ pub use ops::*;
|
||||
use luminal::dtype::DType;
|
||||
use luminal::op::EgglogOp;
|
||||
use luminal::prelude::*;
|
||||
use metal::{Buffer, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineState, Device};
|
||||
use metal::{
|
||||
Buffer, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineState, Device,
|
||||
foreign_types::ForeignTypeRef, mps,
|
||||
};
|
||||
use objc::rc::StrongPtr;
|
||||
use objc::runtime::Object;
|
||||
use objc::{class, msg_send, sel, sel_impl};
|
||||
use std::cell::RefCell;
|
||||
|
||||
pub const DYN_SLOT_COUNT: usize = 26;
|
||||
|
||||
/// Cache key for an `MPSMatrixDescriptor`: every argument passed to the
/// descriptor factory in `MpsKernelCache::matrix_descriptor`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct MpsMatrixDescriptorKey {
    /// Number of matrix rows.
    rows: usize,
    /// Number of matrix columns.
    cols: usize,
    /// Stride between consecutive rows, in bytes.
    row_bytes: u64,
    /// Element type, stored as the `MPSDataType` value cast to `isize`.
    data_type: isize,
}
|
||||
|
||||
/// Cache key for an `MPSMatrixMultiplication` kernel: every argument passed to
/// its initialiser in `MpsKernelCache::matrix_multiplication`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct MpsMatmulKey {
    /// Whether the left operand is transposed.
    transpose_lhs: bool,
    /// Whether the right operand is transposed.
    transpose_rhs: bool,
    /// Result rows.
    m: usize,
    /// Result columns.
    n: usize,
    /// Interior (shared) dimension.
    k: usize,
    /// Bit pattern of the `f64` alpha scale (`f64::to_bits`), so the key can
    /// derive `Eq` and `Hash`.
    alpha: u64,
    /// Bit pattern of the `f64` beta scale (`f64::to_bits`).
    beta: u64,
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct MpsKernelCache {
|
||||
matrix_descriptors: FxHashMap<MpsMatrixDescriptorKey, StrongPtr>,
|
||||
matmul_kernels: FxHashMap<MpsMatmulKey, StrongPtr>,
|
||||
}
|
||||
|
||||
impl MpsKernelCache {
|
||||
pub(crate) fn matrix_descriptor(
|
||||
&mut self,
|
||||
rows: usize,
|
||||
cols: usize,
|
||||
row_bytes: u64,
|
||||
dtype: DType,
|
||||
) -> *mut Object {
|
||||
let key = MpsMatrixDescriptorKey {
|
||||
rows,
|
||||
cols,
|
||||
row_bytes,
|
||||
data_type: Self::mps_data_type(dtype),
|
||||
};
|
||||
let descriptor = self
|
||||
.matrix_descriptors
|
||||
.entry(key)
|
||||
.or_insert_with(|| unsafe {
|
||||
let descriptor: *mut Object = msg_send![
|
||||
class!(MPSMatrixDescriptor),
|
||||
matrixDescriptorWithRows: rows
|
||||
columns: cols
|
||||
rowBytes: row_bytes as usize
|
||||
dataType: key.data_type
|
||||
];
|
||||
StrongPtr::retain(descriptor)
|
||||
});
|
||||
**descriptor
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn matrix_multiplication(
|
||||
&mut self,
|
||||
command_buffer: &CommandBufferRef,
|
||||
transpose_lhs: bool,
|
||||
transpose_rhs: bool,
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: f64,
|
||||
beta: f64,
|
||||
) -> *mut Object {
|
||||
let key = MpsMatmulKey {
|
||||
transpose_lhs,
|
||||
transpose_rhs,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha: alpha.to_bits(),
|
||||
beta: beta.to_bits(),
|
||||
};
|
||||
let kernel = self.matmul_kernels.entry(key).or_insert_with(|| unsafe {
|
||||
let device: *mut Object = msg_send![command_buffer.as_ptr(), device];
|
||||
let kernel: *mut Object = msg_send![class!(MPSMatrixMultiplication), alloc];
|
||||
let kernel: *mut Object = msg_send![
|
||||
kernel,
|
||||
initWithDevice: device
|
||||
transposeLeft: transpose_lhs
|
||||
transposeRight: transpose_rhs
|
||||
resultRows: m
|
||||
resultColumns: n
|
||||
interiorColumns: k
|
||||
alpha: alpha
|
||||
beta: beta
|
||||
];
|
||||
StrongPtr::new(kernel)
|
||||
});
|
||||
**kernel
|
||||
}
|
||||
|
||||
fn mps_data_type(dtype: DType) -> isize {
|
||||
match dtype {
|
||||
DType::F32 | DType::TF32 => mps::MPSDataType::Float32 as isize,
|
||||
DType::F16 => mps::MPSDataType::Float16 as isize,
|
||||
unsupported => panic!("MPSMatmul does not support dtype {unsupported:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MetalEncodeContext<'a> {
|
||||
pub(crate) command_buffer: &'a CommandBufferRef,
|
||||
pub(crate) dyn_buffer: &'a Buffer,
|
||||
pub(crate) mps_cache: &'a RefCell<MpsKernelCache>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MetalMulInfo {
|
||||
pub shape: Vec<Expression>,
|
||||
@@ -52,19 +169,18 @@ pub trait MetalKernelOp: EgglogOp {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn encode(
|
||||
&self,
|
||||
command_buffer: &CommandBufferRef,
|
||||
context: &mut MetalEncodeContext<'_>,
|
||||
pipeline: Option<&ComputePipelineState>,
|
||||
inputs: &[&Buffer],
|
||||
output: &Buffer,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
dyn_buffer: &Buffer,
|
||||
_input_dtypes: &[DType],
|
||||
_output_dtype: DType,
|
||||
) {
|
||||
let pipeline = pipeline.expect("compute pipeline not compiled");
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let encoder = context.command_buffer.new_compute_command_encoder();
|
||||
let dyn_idx = inputs.len() as u64 + 1;
|
||||
encoder.set_buffer(dyn_idx, Some(dyn_buffer), 0);
|
||||
encoder.set_buffer(dyn_idx, Some(context.dyn_buffer), 0);
|
||||
self.encode_compute(encoder, pipeline, inputs, output, dyn_map);
|
||||
encoder.end_encoding();
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::{MPSMatrixLayout, MetalKernelOp, MetalMulInfo, MetalSumReduceInfo};
|
||||
use super::{MPSMatrixLayout, MetalEncodeContext, MetalKernelOp, MetalMulInfo, MetalSumReduceInfo};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
SerializedEGraph,
|
||||
@@ -19,9 +19,8 @@ use luminal::{
|
||||
shape::flatten_strides,
|
||||
};
|
||||
use metal::{
|
||||
Buffer, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineState, Device, MTLSize,
|
||||
Buffer, ComputeCommandEncoderRef, ComputePipelineState, Device, MTLLanguageVersion, MTLSize,
|
||||
foreign_types::{ForeignType, ForeignTypeRef},
|
||||
mps,
|
||||
};
|
||||
use objc::runtime::Object;
|
||||
use objc::{class, msg_send, sel, sel_impl};
|
||||
@@ -56,15 +55,21 @@ pub type MetalOps = (
|
||||
);
|
||||
|
||||
fn compile_shader(device: &Device, source: &str, function_name: &str) -> ComputePipelineState {
|
||||
let options = metal::CompileOptions::new();
|
||||
options.set_language_version(MTLLanguageVersion::V2_4);
|
||||
let library = device
|
||||
.new_library_with_source(source, &metal::CompileOptions::new())
|
||||
.expect("Failed to compile Metal shader");
|
||||
.new_library_with_source(source, &options)
|
||||
.unwrap_or_else(|err| {
|
||||
panic!("Failed to compile Metal shader {function_name}: {err:?}\n{source}")
|
||||
});
|
||||
let function = library
|
||||
.get_function(function_name, None)
|
||||
.expect("Failed to get function from library");
|
||||
device
|
||||
.new_compute_pipeline_state_with_function(&function)
|
||||
.expect("Failed to create compute pipeline state")
|
||||
.unwrap_or_else(|err| {
|
||||
panic!("Failed to create Metal compute pipeline state for {function_name}: {err:?}\n{source}")
|
||||
})
|
||||
}
|
||||
|
||||
fn lower_dynamic_consts(mut code: String) -> String {
|
||||
@@ -1039,42 +1044,33 @@ impl MetalKernelOp for MetalSumReduce {
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_outputs [[buffer({n_outputs_index})]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_index_in_threadgroup]],
|
||||
uint simd_lane [[thread_index_in_simdgroup]],
|
||||
uint simd_id [[simdgroup_index_in_threadgroup]]
|
||||
uint tid [[thread_index_in_threadgroup]]
|
||||
) {{
|
||||
if (gid >= n_outputs) return;
|
||||
|
||||
threadgroup float warp_sums[THREADS_PER_GROUP / 32];
|
||||
threadgroup float partials[THREADS_PER_GROUP];
|
||||
|
||||
int in_start = {in_idx};
|
||||
int iters = {iters};
|
||||
(void)dyn;
|
||||
|
||||
// Each thread accumulates multiple elements
|
||||
float sum = 0.0f;
|
||||
for (int i = tid; i < iters; i += THREADS_PER_GROUP) {{
|
||||
sum += {in_val};
|
||||
}}
|
||||
|
||||
// Warp-level reduction using simd_sum
|
||||
sum = simd_sum(sum);
|
||||
|
||||
// First lane of each warp writes to shared memory
|
||||
if (simd_lane == 0) {{
|
||||
warp_sums[simd_id] = sum;
|
||||
}}
|
||||
partials[tid] = sum;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// First warp does final reduction
|
||||
if (simd_id == 0) {{
|
||||
int n_warps = THREADS_PER_GROUP / 32;
|
||||
float block_sum = (tid < uint(n_warps)) ? warp_sums[tid] : 0.0f;
|
||||
block_sum = simd_sum(block_sum);
|
||||
|
||||
if (tid == 0) {{
|
||||
out[{out_idx}] = {out_val};
|
||||
for (uint stride = THREADS_PER_GROUP / 2; stride > 0; stride >>= 1) {{
|
||||
if (tid < stride) {{
|
||||
partials[tid] += partials[tid + stride];
|
||||
}}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}}
|
||||
|
||||
if (tid == 0) {{
|
||||
float block_sum = partials[0];
|
||||
out[{out_idx}] = {out_val};
|
||||
}}
|
||||
}}
|
||||
"#,
|
||||
@@ -1220,42 +1216,33 @@ impl MetalKernelOp for MetalMaxReduce {
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_outputs [[buffer({n_outputs_index})]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_index_in_threadgroup]],
|
||||
uint simd_lane [[thread_index_in_simdgroup]],
|
||||
uint simd_id [[simdgroup_index_in_threadgroup]]
|
||||
uint tid [[thread_index_in_threadgroup]]
|
||||
) {{
|
||||
if (gid >= n_outputs) return;
|
||||
|
||||
threadgroup float warp_maxs[THREADS_PER_GROUP / 32];
|
||||
threadgroup float partials[THREADS_PER_GROUP];
|
||||
|
||||
int in_start = {in_idx};
|
||||
int iters = {iters};
|
||||
(void)dyn;
|
||||
|
||||
// Each thread finds max of multiple elements
|
||||
float max_val = NEG_INF_F;
|
||||
for (int i = tid; i < iters; i += THREADS_PER_GROUP) {{
|
||||
max_val = fmax(max_val, {in_val});
|
||||
}}
|
||||
|
||||
// Warp-level reduction using simd_max
|
||||
max_val = simd_max(max_val);
|
||||
|
||||
// First lane of each warp writes to shared memory
|
||||
if (simd_lane == 0) {{
|
||||
warp_maxs[simd_id] = max_val;
|
||||
}}
|
||||
partials[tid] = max_val;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// First warp does final reduction
|
||||
if (simd_id == 0) {{
|
||||
int n_warps = THREADS_PER_GROUP / 32;
|
||||
float block_max = (tid < uint(n_warps)) ? warp_maxs[tid] : NEG_INF_F;
|
||||
block_max = simd_max(block_max);
|
||||
|
||||
if (tid == 0) {{
|
||||
out[{out_idx}] = {out_val};
|
||||
for (uint stride = THREADS_PER_GROUP / 2; stride > 0; stride >>= 1) {{
|
||||
if (tid < stride) {{
|
||||
partials[tid] = fmax(partials[tid], partials[tid + stride]);
|
||||
}}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}}
|
||||
|
||||
if (tid == 0) {{
|
||||
float block_max = partials[0];
|
||||
out[{out_idx}] = {out_val};
|
||||
}}
|
||||
}}
|
||||
"#,
|
||||
@@ -1427,8 +1414,6 @@ impl EgglogOp for MPSMatmul {
|
||||
let dt = v(format!("?{}_dt", name.replace('-', "_")));
|
||||
|
||||
rule(union(sum_op.clone(), mps_op.clone()))
|
||||
.subsume(sum_op.clone())
|
||||
.subsume(mul_op)
|
||||
.set(dtype(mps_op), dt.clone())
|
||||
.fact(eq(dt, dtype(sum_op)))
|
||||
.ruleset("kernel_lower")
|
||||
@@ -1464,6 +1449,17 @@ impl EgglogOp for MPSMatmul {
|
||||
1,
|
||||
1,
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
((= ?mul (Op (MetalMul ?shape ?as ?bs ?os) ?inputs))
|
||||
(= ?sum (Op (MetalSum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
(= ?sum (MPSMatmul ?m ?n ?k ?lhs ?lhsrs ?rhs ?rhsrs ?ors ?tl ?tr)))
|
||||
((delete (Op (MetalSum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
(delete (Op (MetalMul ?shape ?as ?bs ?os) ?inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"delete-broadcast-mul-sum-when-mps-matmul-exists\"
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -1505,14 +1501,6 @@ impl EgglogOp for MPSMatmul {
|
||||
}
|
||||
|
||||
impl MPSMatmul {
|
||||
fn mps_dtype(dtype: DType) -> mps::MPSDataType {
|
||||
match dtype {
|
||||
DType::F32 | DType::TF32 => mps::MPSDataType::Float32,
|
||||
DType::F16 => mps::MPSDataType::Float16,
|
||||
unsupported => panic!("MPSMatmul does not support dtype {unsupported:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn row_bytes(row_stride: Expression, dtype: DType, dyn_map: &FxHashMap<char, usize>) -> u64 {
|
||||
let elems = row_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
@@ -1521,19 +1509,6 @@ impl MPSMatmul {
|
||||
(elems * dtype.bits().div_ceil(8)) as u64
|
||||
}
|
||||
|
||||
fn descriptor(rows: usize, cols: usize, row_bytes: u64, dtype: DType) -> *mut Object {
|
||||
let data_type = Self::mps_dtype(dtype) as isize;
|
||||
unsafe {
|
||||
msg_send![
|
||||
class!(MPSMatrixDescriptor),
|
||||
matrixDescriptorWithRows: rows
|
||||
columns: cols
|
||||
rowBytes: row_bytes as usize
|
||||
dataType: data_type
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
fn matrix(buffer: &Buffer, descriptor: *mut Object) -> *mut Object {
|
||||
unsafe {
|
||||
let matrix: *mut Object = msg_send![class!(MPSMatrix), alloc];
|
||||
@@ -1589,12 +1564,11 @@ impl MetalKernelOp for MPSMatmul {
|
||||
|
||||
fn encode(
|
||||
&self,
|
||||
command_buffer: &CommandBufferRef,
|
||||
context: &mut MetalEncodeContext<'_>,
|
||||
_pipeline: Option<&ComputePipelineState>,
|
||||
inputs: &[&Buffer],
|
||||
output: &Buffer,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
_dyn_buffer: &Buffer,
|
||||
input_dtypes: &[DType],
|
||||
output_dtype: DType,
|
||||
) {
|
||||
@@ -1610,46 +1584,48 @@ impl MetalKernelOp for MPSMatmul {
|
||||
let rhs_rows = if self.transpose_rhs { n } else { k };
|
||||
let rhs_cols = if self.transpose_rhs { k } else { n };
|
||||
|
||||
let lhs_desc = Self::descriptor(
|
||||
lhs_rows,
|
||||
lhs_cols,
|
||||
Self::row_bytes(self.lhs_row_stride, lhs_dtype, dyn_map),
|
||||
lhs_dtype,
|
||||
);
|
||||
let rhs_desc = Self::descriptor(
|
||||
rhs_rows,
|
||||
rhs_cols,
|
||||
Self::row_bytes(self.rhs_row_stride, rhs_dtype, dyn_map),
|
||||
rhs_dtype,
|
||||
);
|
||||
let out_desc = Self::descriptor(
|
||||
m,
|
||||
n,
|
||||
Self::row_bytes(self.out_row_stride, output_dtype, dyn_map),
|
||||
output_dtype,
|
||||
);
|
||||
let (lhs_desc, rhs_desc, out_desc, kernel) = {
|
||||
let mut cache = context.mps_cache.borrow_mut();
|
||||
(
|
||||
cache.matrix_descriptor(
|
||||
lhs_rows,
|
||||
lhs_cols,
|
||||
Self::row_bytes(self.lhs_row_stride, lhs_dtype, dyn_map),
|
||||
lhs_dtype,
|
||||
),
|
||||
cache.matrix_descriptor(
|
||||
rhs_rows,
|
||||
rhs_cols,
|
||||
Self::row_bytes(self.rhs_row_stride, rhs_dtype, dyn_map),
|
||||
rhs_dtype,
|
||||
),
|
||||
cache.matrix_descriptor(
|
||||
m,
|
||||
n,
|
||||
Self::row_bytes(self.out_row_stride, output_dtype, dyn_map),
|
||||
output_dtype,
|
||||
),
|
||||
cache.matrix_multiplication(
|
||||
context.command_buffer,
|
||||
self.transpose_lhs,
|
||||
self.transpose_rhs,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
1.0,
|
||||
0.0,
|
||||
),
|
||||
)
|
||||
};
|
||||
|
||||
let lhs = Self::matrix(inputs[0], lhs_desc);
|
||||
let rhs = Self::matrix(inputs[1], rhs_desc);
|
||||
let out = Self::matrix(output, out_desc);
|
||||
|
||||
unsafe {
|
||||
let device: *mut Object = msg_send![command_buffer.as_ptr(), device];
|
||||
let kernel: *mut Object = msg_send![class!(MPSMatrixMultiplication), alloc];
|
||||
let kernel: *mut Object = msg_send![
|
||||
kernel,
|
||||
initWithDevice: device
|
||||
transposeLeft: self.transpose_lhs
|
||||
transposeRight: self.transpose_rhs
|
||||
resultRows: m
|
||||
resultColumns: n
|
||||
interiorColumns: k
|
||||
alpha: 1.0f64
|
||||
beta: 0.0f64
|
||||
];
|
||||
let _: () = msg_send![
|
||||
kernel,
|
||||
encodeToCommandBuffer: command_buffer.as_ptr()
|
||||
encodeToCommandBuffer: context.command_buffer.as_ptr()
|
||||
leftMatrix: lhs
|
||||
rightMatrix: rhs
|
||||
resultMatrix: out
|
||||
@@ -1657,7 +1633,6 @@ impl MetalKernelOp for MPSMatmul {
|
||||
let _: () = msg_send![lhs, release];
|
||||
let _: () = msg_send![rhs, release];
|
||||
let _: () = msg_send![out, release];
|
||||
let _: () = msg_send![kernel, release];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1839,8 +1814,6 @@ impl EgglogOp for MPSBatchedMatmul {
|
||||
let dt = v(format!("?{}_dt", name.replace('-', "_")));
|
||||
|
||||
rule(union(sum_op.clone(), mps_op.clone()))
|
||||
.subsume(sum_op.clone())
|
||||
.subsume(mul_op)
|
||||
.set(dtype(mps_op), dt.clone())
|
||||
.fact(eq(dt, dtype(sum_op)))
|
||||
.ruleset("kernel_lower")
|
||||
@@ -1878,6 +1851,17 @@ impl EgglogOp for MPSBatchedMatmul {
|
||||
),
|
||||
1,
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
((= ?mul (Op (MetalMul ?shape ?as ?bs ?os) ?inputs))
|
||||
(= ?sum (Op (MetalSum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
(= ?sum (MPSBatchedMatmul ?b ?m ?n ?k ?lhs ?lhsbs ?lhsrs ?rhs ?rhsbs ?rhsrs ?obs ?ors ?tl ?tr)))
|
||||
((delete (Op (MetalSum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
(delete (Op (MetalMul ?shape ?as ?bs ?os) ?inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"delete-broadcast-mul-sum-when-mps-batched-matmul-exists\"
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -1953,12 +1937,11 @@ impl MetalKernelOp for MPSBatchedMatmul {
|
||||
|
||||
fn encode(
|
||||
&self,
|
||||
command_buffer: &CommandBufferRef,
|
||||
context: &mut MetalEncodeContext<'_>,
|
||||
_pipeline: Option<&ComputePipelineState>,
|
||||
inputs: &[&Buffer],
|
||||
output: &Buffer,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
_dyn_buffer: &Buffer,
|
||||
input_dtypes: &[DType],
|
||||
output_dtype: DType,
|
||||
) {
|
||||
@@ -1982,25 +1965,26 @@ impl MetalKernelOp for MPSBatchedMatmul {
|
||||
let lhs_row_bytes = MPSMatmul::row_bytes(self.lhs_row_stride, lhs_dtype, dyn_map);
|
||||
let rhs_row_bytes = MPSMatmul::row_bytes(self.rhs_row_stride, rhs_dtype, dyn_map);
|
||||
let out_row_bytes = MPSMatmul::row_bytes(self.out_row_stride, output_dtype, dyn_map);
|
||||
let lhs_desc = MPSMatmul::descriptor(lhs_rows, lhs_cols, lhs_row_bytes, lhs_dtype);
|
||||
let rhs_desc = MPSMatmul::descriptor(rhs_rows, rhs_cols, rhs_row_bytes, rhs_dtype);
|
||||
let out_desc = MPSMatmul::descriptor(m, n, out_row_bytes, output_dtype);
|
||||
let (lhs_desc, rhs_desc, out_desc, kernel) = {
|
||||
let mut cache = context.mps_cache.borrow_mut();
|
||||
(
|
||||
cache.matrix_descriptor(lhs_rows, lhs_cols, lhs_row_bytes, lhs_dtype),
|
||||
cache.matrix_descriptor(rhs_rows, rhs_cols, rhs_row_bytes, rhs_dtype),
|
||||
cache.matrix_descriptor(m, n, out_row_bytes, output_dtype),
|
||||
cache.matrix_multiplication(
|
||||
context.command_buffer,
|
||||
self.transpose_lhs,
|
||||
self.transpose_rhs,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
1.0,
|
||||
0.0,
|
||||
),
|
||||
)
|
||||
};
|
||||
|
||||
unsafe {
|
||||
let device: *mut Object = msg_send![command_buffer.as_ptr(), device];
|
||||
let kernel: *mut Object = msg_send![class!(MPSMatrixMultiplication), alloc];
|
||||
let kernel: *mut Object = msg_send![
|
||||
kernel,
|
||||
initWithDevice: device
|
||||
transposeLeft: self.transpose_lhs
|
||||
transposeRight: self.transpose_rhs
|
||||
resultRows: m
|
||||
resultColumns: n
|
||||
interiorColumns: k
|
||||
alpha: 1.0f64
|
||||
beta: 0.0f64
|
||||
];
|
||||
|
||||
for batch_idx in 0..batch {
|
||||
let batch_expr = Expression::from(batch_idx as i64);
|
||||
let lhs_offset = self
|
||||
@@ -2027,7 +2011,7 @@ impl MetalKernelOp for MPSBatchedMatmul {
|
||||
let out = MPSMatmul::matrix_with_offset(output, out_offset as u64, out_desc);
|
||||
let _: () = msg_send![
|
||||
kernel,
|
||||
encodeToCommandBuffer: command_buffer.as_ptr()
|
||||
encodeToCommandBuffer: context.command_buffer.as_ptr()
|
||||
leftMatrix: lhs
|
||||
rightMatrix: rhs
|
||||
resultMatrix: out
|
||||
@@ -2036,7 +2020,6 @@ impl MetalKernelOp for MPSBatchedMatmul {
|
||||
let _: () = msg_send![rhs, release];
|
||||
let _: () = msg_send![out, release];
|
||||
}
|
||||
let _: () = msg_send![kernel, release];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2163,24 +2146,6 @@ impl EgglogOp for GenericMatmul {
|
||||
:name \"delete-broadcast-mul-sum-when-generic-matmul-exists\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
((= ?sum (GenericMatmul ?go ?gm ?gk ?gl ?glas ?gr ?grs ?gsis ?gsit ?gos))
|
||||
(= ?sum (MPSMatmul ?mm ?mn ?mk ?ml ?mls ?mr ?mrs ?mos ?mtl ?mtr)))
|
||||
((delete (GenericMatmul ?go ?gm ?gk ?gl ?glas ?gr ?grs ?gsis ?gsit ?gos)))
|
||||
:ruleset cleanup
|
||||
:name \"prefer-mps-over-generic-matmul\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
((= ?sum (GenericMatmul ?go ?gm ?gk ?gl ?glas ?gr ?grs ?gsis ?gsit ?gos))
|
||||
(= ?sum (MPSBatchedMatmul ?bb ?bm ?bn ?bk ?bl ?blbs ?blrs ?br ?brbs ?brrs ?bobs ?bors ?btl ?btr)))
|
||||
((delete (GenericMatmul ?go ?gm ?gk ?gl ?glas ?gr ?grs ?gsis ?gsit ?gos)))
|
||||
:ruleset cleanup
|
||||
:name \"prefer-mps-batched-over-generic-matmul\"
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -2265,13 +2230,11 @@ impl MetalKernelOp for GenericMatmul {
|
||||
constant int *dyn [[buffer({dyn_buffer_index})]],
|
||||
constant uint &n_outputs [[buffer({n_outputs_index})]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_index_in_threadgroup]],
|
||||
uint simd_lane [[thread_index_in_simdgroup]],
|
||||
uint simd_id [[simdgroup_index_in_threadgroup]]
|
||||
uint tid [[thread_index_in_threadgroup]]
|
||||
) {{
|
||||
if (gid >= n_outputs) return;
|
||||
|
||||
threadgroup float warp_sums[THREADS_PER_GROUP / 32];
|
||||
threadgroup float partials[THREADS_PER_GROUP];
|
||||
int base_idx = {sum_base_idx};
|
||||
int iters = {iters};
|
||||
(void)dyn;
|
||||
@@ -2282,19 +2245,18 @@ impl MetalKernelOp for GenericMatmul {
|
||||
sum += ({lhs_val}) * ({rhs_val});
|
||||
}}
|
||||
|
||||
sum = simd_sum(sum);
|
||||
if (simd_lane == 0) {{
|
||||
warp_sums[simd_id] = sum;
|
||||
}}
|
||||
partials[tid] = sum;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (simd_id == 0) {{
|
||||
int n_warps = THREADS_PER_GROUP / 32;
|
||||
float block_sum = (tid < uint(n_warps)) ? warp_sums[tid] : 0.0f;
|
||||
block_sum = simd_sum(block_sum);
|
||||
if (tid == 0) {{
|
||||
out[{out_idx}] = {out_val};
|
||||
for (uint stride = THREADS_PER_GROUP / 2; stride > 0; stride >>= 1) {{
|
||||
if (tid < stride) {{
|
||||
partials[tid] += partials[tid + stride];
|
||||
}}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}}
|
||||
|
||||
if (tid == 0) {{
|
||||
float block_sum = partials[0];
|
||||
out[{out_idx}] = {out_val};
|
||||
}}
|
||||
}}
|
||||
"#,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
pub mod dyn_backend;
|
||||
pub mod kernel;
|
||||
mod memory_analysis;
|
||||
pub mod runtime;
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
1478
crates/luminal_metal/src/memory_analysis.rs
Normal file
1478
crates/luminal_metal/src/memory_analysis.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,9 @@
|
||||
use crate::kernel::{DYN_SLOT_COUNT, MetalKernelOp};
|
||||
use crate::kernel::{DYN_SLOT_COUNT, MetalEncodeContext, MetalKernelOp, MpsKernelCache};
|
||||
use half::{bf16, f16};
|
||||
use itertools::Itertools;
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
egglog_utils::SerializedEGraph,
|
||||
graph::{BucketLLIR, DimBucket, Graph, LLIRGraph},
|
||||
hlir::{Input, NativeData, Output},
|
||||
op::{ExecutionStats, Runtime, RuntimeStats, TimingMethod},
|
||||
@@ -16,15 +17,26 @@ use metal::{Buffer, CommandQueue, ComputePipelineState, Device, MTLResourceOptio
|
||||
use objc::rc::autoreleasepool;
|
||||
use objc::runtime::Object;
|
||||
use safetensors::{Dtype, SafeTensors};
|
||||
use std::{fs::File, time::Duration};
|
||||
use std::{cell::RefCell, fs::File, time::Duration};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MetalExecutionStep {
|
||||
node: NodeIndex,
|
||||
input_nodes: Vec<NodeIndex>,
|
||||
input_dtypes: Vec<DType>,
|
||||
output_dtype: DType,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MetalCompiledBucket {
|
||||
bucket_indices: FxHashMap<char, usize>,
|
||||
llir_graph: LLIRGraph,
|
||||
llir_to_hlir: FxHashMap<NodeIndex, NodeIndex>,
|
||||
node_dtypes: FxHashMap<NodeIndex, DType>,
|
||||
pipelines: FxHashMap<NodeIndex, ComputePipelineState>,
|
||||
output_alias_map: FxHashMap<NodeIndex, NodeIndex>,
|
||||
output_data_map: FxHashMap<NodeIndex, NodeIndex>,
|
||||
execution_plan: Vec<MetalExecutionStep>,
|
||||
}
|
||||
|
||||
pub struct MetalRuntime {
|
||||
@@ -36,16 +48,26 @@ pub struct MetalRuntime {
|
||||
pub hlir_buffers: FxHashMap<NodeIndex, Buffer>,
|
||||
/// Buffers for LLIR intermediate/output tensors
|
||||
pub buffers: FxHashMap<NodeIndex, Buffer>,
|
||||
/// Logical byte length for each active LLIR buffer.
|
||||
buffer_lengths: FxHashMap<NodeIndex, u64>,
|
||||
/// Dynamic dimensions table (a-z), shared across all kernels.
|
||||
dyn_buffer: Buffer,
|
||||
/// Retained MPS descriptors/kernels reused across command encodes.
|
||||
mps_cache: RefCell<MpsKernelCache>,
|
||||
/// The current LLIR graph
|
||||
llir_graph: LLIRGraph,
|
||||
/// LLIR input node -> HLIR input node.
|
||||
llir_to_hlir: FxHashMap<NodeIndex, NodeIndex>,
|
||||
/// Inferred runtime dtype for each LLIR node.
|
||||
node_dtypes: FxHashMap<NodeIndex, DType>,
|
||||
/// Compiled pipeline states for each kernel node
|
||||
pipelines: FxHashMap<NodeIndex, ComputePipelineState>,
|
||||
/// LLIR output node -> input node whose buffer contains the output.
|
||||
output_alias_map: FxHashMap<NodeIndex, NodeIndex>,
|
||||
/// HLIR output id -> LLIR node whose data feeds the output.
|
||||
output_data_map: FxHashMap<NodeIndex, NodeIndex>,
|
||||
/// Precomputed executable nodes and input metadata for the active LLIR graph.
|
||||
execution_plan: Vec<MetalExecutionStep>,
|
||||
/// Bucket definitions for dynamic dimensions.
|
||||
dim_buckets: FxHashMap<char, Vec<DimBucket>>,
|
||||
/// Compiled LLIR variants, one per bucket combination.
|
||||
@@ -64,22 +86,10 @@ impl MetalRuntime {
|
||||
}
|
||||
|
||||
fn output_data_node(&self, id: NodeIndex) -> NodeIndex {
|
||||
let output_id = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.find(|n| {
|
||||
if let Some(Output { node }) = self.llir_graph[*n].to_op::<Output>() {
|
||||
*node == id.index()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.expect("Cannot find output tensor!");
|
||||
|
||||
self.llir_graph
|
||||
.neighbors_directed(output_id, Direction::Incoming)
|
||||
.next()
|
||||
.unwrap()
|
||||
self.output_data_map
|
||||
.get(&id)
|
||||
.copied()
|
||||
.unwrap_or_else(|| panic!("Cannot find output tensor {id:?}!"))
|
||||
}
|
||||
|
||||
fn follow_aliases(&self, mut node: NodeIndex) -> NodeIndex {
|
||||
@@ -225,6 +235,7 @@ impl MetalRuntime {
|
||||
let data_id = self.follow_aliases(self.output_data_node(id.to_id()));
|
||||
|
||||
if let Some(buffer) = self.buffers.remove(&data_id) {
|
||||
self.buffer_lengths.remove(&data_id);
|
||||
return buffer;
|
||||
}
|
||||
|
||||
@@ -269,12 +280,21 @@ impl MetalRuntime {
|
||||
.map(|inp| inp.dtype)
|
||||
})
|
||||
.unwrap_or(DType::F32);
|
||||
let logical_bytes = self
|
||||
.buffer_lengths
|
||||
.get(&data_id)
|
||||
.copied()
|
||||
.unwrap_or_else(|| buffer.length());
|
||||
assert!(
|
||||
logical_bytes <= buffer.length(),
|
||||
"Logical buffer size exceeds allocated Metal buffer size"
|
||||
);
|
||||
|
||||
unsafe {
|
||||
match dtype {
|
||||
DType::F16 => {
|
||||
let ptr = buffer.contents() as *const f16;
|
||||
let len = buffer.length() as usize / std::mem::size_of::<f16>();
|
||||
let len = logical_bytes as usize / std::mem::size_of::<f16>();
|
||||
std::slice::from_raw_parts(ptr, len)
|
||||
.iter()
|
||||
.map(|v| v.to_f32())
|
||||
@@ -282,7 +302,7 @@ impl MetalRuntime {
|
||||
}
|
||||
DType::Int => {
|
||||
let ptr = buffer.contents() as *const i32;
|
||||
let len = buffer.length() as usize / std::mem::size_of::<i32>();
|
||||
let len = logical_bytes as usize / std::mem::size_of::<i32>();
|
||||
std::slice::from_raw_parts(ptr, len)
|
||||
.iter()
|
||||
.map(|v| *v as f32)
|
||||
@@ -290,7 +310,7 @@ impl MetalRuntime {
|
||||
}
|
||||
_ => {
|
||||
let ptr = buffer.contents() as *const f32;
|
||||
let len = buffer.length() as usize / std::mem::size_of::<f32>();
|
||||
let len = logical_bytes as usize / std::mem::size_of::<f32>();
|
||||
std::slice::from_raw_parts(ptr, len).to_vec()
|
||||
}
|
||||
}
|
||||
@@ -304,6 +324,26 @@ impl Runtime for MetalRuntime {
|
||||
type ExecReturn = ();
|
||||
type ProfileMetric = Duration;
|
||||
|
||||
fn late_egglog_passes(
|
||||
ops: &[std::sync::Arc<Box<dyn luminal::op::EgglogOp>>],
|
||||
options: &luminal::graph::BuildSearchSpaceOptions,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> Vec<luminal::egglog_utils::LateEgglogPass> {
|
||||
vec![crate::memory_analysis::metal_memory_analysis_pass(
|
||||
ops,
|
||||
options.max_memory_bytes,
|
||||
dyn_map,
|
||||
)]
|
||||
}
|
||||
|
||||
fn estimate_graph_memory<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
choices: &luminal::egglog_utils::EGraphChoiceSet<'a>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> Option<usize> {
|
||||
crate::memory_analysis::estimate_graph_memory_bytes(egraph, choices, dyn_map)
|
||||
}
|
||||
|
||||
fn initialize(_: Self::CompileArg) -> Self {
|
||||
let device = Device::system_default().expect("No Metal device found!");
|
||||
let command_queue = device.new_command_queue();
|
||||
@@ -318,11 +358,16 @@ impl Runtime for MetalRuntime {
|
||||
input_data: FxHashMap::default(),
|
||||
hlir_buffers: FxHashMap::default(),
|
||||
buffers: FxHashMap::default(),
|
||||
buffer_lengths: FxHashMap::default(),
|
||||
dyn_buffer,
|
||||
mps_cache: RefCell::new(MpsKernelCache::default()),
|
||||
llir_graph: StableGraph::default(),
|
||||
llir_to_hlir: FxHashMap::default(),
|
||||
node_dtypes: FxHashMap::default(),
|
||||
pipelines: FxHashMap::default(),
|
||||
output_alias_map: FxHashMap::default(),
|
||||
output_data_map: FxHashMap::default(),
|
||||
execution_plan: vec![],
|
||||
dim_buckets: FxHashMap::default(),
|
||||
compiled_buckets: vec![],
|
||||
active_bucket: 0,
|
||||
@@ -336,6 +381,7 @@ impl Runtime for MetalRuntime {
|
||||
#[tracing::instrument(skip_all)]
|
||||
fn load_llir(&mut self, llir_graph: &LLIRGraph) {
|
||||
self.buffers.clear();
|
||||
self.buffer_lengths.clear();
|
||||
self.dim_buckets.clear();
|
||||
self.compiled_buckets = vec![self.compile_bucket(FxHashMap::default(), llir_graph)];
|
||||
self.activate_bucket(0);
|
||||
@@ -347,19 +393,25 @@ impl Runtime for MetalRuntime {
|
||||
llir_graph: &LLIRGraph,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
trials: usize,
|
||||
_timeout: Option<std::time::Duration>,
|
||||
timeout: Option<std::time::Duration>,
|
||||
) -> (Self::ProfileMetric, String) {
|
||||
self.load_llir(llir_graph);
|
||||
self.allocate_intermediate_buffers(dyn_map);
|
||||
|
||||
let trials = trials.max(1);
|
||||
let profile_start = std::time::Instant::now();
|
||||
let mut duration = Duration::default();
|
||||
let mut completed_trials = 0;
|
||||
for _ in 0..trials {
|
||||
let start = std::time::Instant::now();
|
||||
self.execute(dyn_map);
|
||||
duration += start.elapsed();
|
||||
completed_trials += 1;
|
||||
if timeout.is_some_and(|timeout| profile_start.elapsed() >= timeout) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
duration /= trials as u32;
|
||||
duration /= completed_trials as u32;
|
||||
|
||||
(duration, format!("{:.2?}", duration))
|
||||
}
|
||||
@@ -370,74 +422,43 @@ impl Runtime for MetalRuntime {
|
||||
self.select_bucket(dyn_map);
|
||||
self.allocate_active_intermediate_buffers(dyn_map);
|
||||
|
||||
let llir_to_hlir: FxHashMap<NodeIndex, NodeIndex> = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.filter_map(|n| {
|
||||
if let Some(Input { node, .. }) = self.llir_graph[n].to_op::<Input>() {
|
||||
Some((n, NodeIndex::new(*node)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
|
||||
|
||||
self.update_dyn_buffer(dyn_map);
|
||||
let command_buffer = self.command_queue.new_command_buffer();
|
||||
let mut encode_context = MetalEncodeContext {
|
||||
command_buffer,
|
||||
dyn_buffer: &self.dyn_buffer,
|
||||
mps_cache: &self.mps_cache,
|
||||
};
|
||||
|
||||
for node in topo_order {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some()
|
||||
|| self.llir_graph[node].to_op::<Output>().is_some()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
for step in &self.execution_plan {
|
||||
let kernel_op = self.llir_graph[step.node]
|
||||
.to_dialect::<dyn MetalKernelOp>()
|
||||
.expect("Execution plan referenced a non-Metal op");
|
||||
let pipeline = self.pipelines.get(&step.node);
|
||||
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let pipeline = self.pipelines.get(&node);
|
||||
let input_buffers: Vec<&Buffer> = step
|
||||
.input_nodes
|
||||
.iter()
|
||||
.map(|&n| self.buffer_for_llir_node(n, &self.llir_to_hlir))
|
||||
.collect();
|
||||
|
||||
let input_nodes: Vec<NodeIndex> = self
|
||||
.llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
let output_buffer = if let Some(alias_idx) = kernel_op.output_aliases_input() {
|
||||
input_buffers[alias_idx]
|
||||
} else {
|
||||
self.buffers
|
||||
.get(&step.node)
|
||||
.expect("Output buffer not allocated!")
|
||||
};
|
||||
|
||||
let input_buffers: Vec<&Buffer> = input_nodes
|
||||
.iter()
|
||||
.map(|&n| self.buffer_for_llir_node(n, &llir_to_hlir))
|
||||
.collect();
|
||||
let input_dtypes: Vec<DType> = input_nodes
|
||||
.iter()
|
||||
.map(|n| {
|
||||
self.node_dtypes
|
||||
.get(n)
|
||||
.copied()
|
||||
.unwrap_or_else(|| panic!("Missing inferred dtype for node {n:?}"))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output_buffer = if let Some(alias_idx) = kernel_op.output_aliases_input() {
|
||||
input_buffers[alias_idx]
|
||||
} else {
|
||||
self.buffers
|
||||
.get(&node)
|
||||
.expect("Output buffer not allocated!")
|
||||
};
|
||||
let output_dtype = self.node_dtypes.get(&node).copied().unwrap_or(DType::F32);
|
||||
|
||||
kernel_op.encode(
|
||||
command_buffer,
|
||||
pipeline,
|
||||
&input_buffers,
|
||||
output_buffer,
|
||||
dyn_map,
|
||||
&self.dyn_buffer,
|
||||
&input_dtypes,
|
||||
output_dtype,
|
||||
);
|
||||
}
|
||||
kernel_op.encode(
|
||||
&mut encode_context,
|
||||
pipeline,
|
||||
&input_buffers,
|
||||
output_buffer,
|
||||
dyn_map,
|
||||
&step.input_dtypes,
|
||||
step.output_dtype,
|
||||
);
|
||||
}
|
||||
|
||||
command_buffer.commit();
|
||||
@@ -447,6 +468,22 @@ impl Runtime for MetalRuntime {
|
||||
|
||||
fn clear_intermediate_buffers(&mut self) {
|
||||
self.buffers.clear();
|
||||
self.buffer_lengths.clear();
|
||||
}
|
||||
|
||||
fn intermediate_buffer_bytes(&self) -> usize {
|
||||
self.buffers
|
||||
.values()
|
||||
.map(|buffer| buffer.length() as usize)
|
||||
.sum()
|
||||
}
|
||||
|
||||
fn planned_intermediate_buffer_bytes(&self) -> Option<usize> {
|
||||
Some(self.intermediate_buffer_bytes())
|
||||
}
|
||||
|
||||
fn allocated_intermediate_buffer_bytes(&self) -> Option<usize> {
|
||||
Some(self.intermediate_buffer_bytes())
|
||||
}
|
||||
|
||||
fn load_llir_buckets(
|
||||
@@ -455,6 +492,7 @@ impl Runtime for MetalRuntime {
|
||||
bucket_llirs: &[BucketLLIR],
|
||||
) {
|
||||
self.buffers.clear();
|
||||
self.buffer_lengths.clear();
|
||||
self.dim_buckets = dim_buckets.clone();
|
||||
self.compiled_buckets = bucket_llirs
|
||||
.iter()
|
||||
@@ -497,7 +535,7 @@ impl MetalRuntime {
|
||||
fn create_input_buffer(&self, data: &NativeData, dtype: DType) -> Buffer {
|
||||
match dtype {
|
||||
DType::F32 => {
|
||||
let values: Vec<f32> = (0..data.len()).map(|i| data.f32(i)).collect();
|
||||
let values = data.to_f32_vec();
|
||||
self.device.new_buffer_with_data(
|
||||
values.as_ptr() as *const _,
|
||||
std::mem::size_of_val(values.as_slice()) as u64,
|
||||
@@ -505,7 +543,7 @@ impl MetalRuntime {
|
||||
)
|
||||
}
|
||||
DType::F16 => {
|
||||
let values: Vec<f16> = (0..data.len()).map(|i| data.f16(i)).collect();
|
||||
let values = data.to_f16_vec();
|
||||
self.device.new_buffer_with_data(
|
||||
values.as_ptr() as *const _,
|
||||
std::mem::size_of_val(values.as_slice()) as u64,
|
||||
@@ -513,7 +551,7 @@ impl MetalRuntime {
|
||||
)
|
||||
}
|
||||
DType::Int => {
|
||||
let values: Vec<i32> = (0..data.len()).map(|i| data.i32(i)).collect();
|
||||
let values = data.to_i32_vec();
|
||||
self.device.new_buffer_with_data(
|
||||
values.as_ptr() as *const _,
|
||||
std::mem::size_of_val(values.as_slice()) as u64,
|
||||
@@ -531,6 +569,7 @@ impl MetalRuntime {
|
||||
|
||||
fn allocate_active_intermediate_buffers(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
let mut planned = Vec::new();
|
||||
let capacity_dyn_map = self.active_capacity_dyn_map(dyn_map);
|
||||
|
||||
for node in self.llir_graph.node_indices() {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some() {
|
||||
@@ -541,28 +580,58 @@ impl MetalRuntime {
|
||||
if kernel_op.output_aliases_input().is_some() {
|
||||
continue;
|
||||
}
|
||||
let size = kernel_op.output_size().exec(dyn_map).unwrap();
|
||||
let dtype = self.node_dtypes.get(&node).copied().unwrap_or(DType::F32);
|
||||
let bytes = (size * dtype.bits().div_ceil(8)) as u64;
|
||||
let requested_bytes =
|
||||
Self::output_bytes(kernel_op.as_ref().as_ref(), dtype, dyn_map);
|
||||
let allocation_bytes =
|
||||
Self::output_bytes(kernel_op.as_ref().as_ref(), dtype, &capacity_dyn_map)
|
||||
.max(requested_bytes);
|
||||
let needs_buffer = self
|
||||
.buffers
|
||||
.get(&node)
|
||||
.is_none_or(|buffer| buffer.length() != bytes);
|
||||
.is_none_or(|buffer| requested_bytes > buffer.length());
|
||||
|
||||
planned.push((node, bytes, needs_buffer));
|
||||
planned.push((node, requested_bytes, allocation_bytes, needs_buffer));
|
||||
}
|
||||
}
|
||||
|
||||
for (node, bytes, needs_buffer) in planned {
|
||||
for (node, requested_bytes, allocation_bytes, needs_buffer) in planned {
|
||||
self.buffer_lengths.insert(node, requested_bytes);
|
||||
if needs_buffer {
|
||||
let buffer = self
|
||||
.device
|
||||
.new_buffer(bytes, MTLResourceOptions::StorageModeShared);
|
||||
.new_buffer(allocation_bytes, MTLResourceOptions::StorageModeShared);
|
||||
self.buffers.insert(node, buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn output_bytes(
|
||||
kernel_op: &dyn MetalKernelOp,
|
||||
dtype: DType,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> u64 {
|
||||
let size = kernel_op.output_size().exec(dyn_map).unwrap();
|
||||
(size * dtype.bits().div_ceil(8)) as u64
|
||||
}
|
||||
|
||||
fn active_capacity_dyn_map(&self, dyn_map: &FxHashMap<char, usize>) -> FxHashMap<char, usize> {
|
||||
let mut capacity_dyn_map = dyn_map.clone();
|
||||
let Some(active_bucket) = self.compiled_buckets.get(self.active_bucket) else {
|
||||
return capacity_dyn_map;
|
||||
};
|
||||
|
||||
for (&dim, buckets) in &self.dim_buckets {
|
||||
if let Some(&bucket_index) = active_bucket.bucket_indices.get(&dim)
|
||||
&& let Some(bucket) = buckets.get(bucket_index)
|
||||
{
|
||||
capacity_dyn_map.insert(dim, bucket.max);
|
||||
}
|
||||
}
|
||||
|
||||
capacity_dyn_map
|
||||
}
|
||||
|
||||
fn compile_bucket(
|
||||
&self,
|
||||
bucket_indices: FxHashMap<char, usize>,
|
||||
@@ -571,12 +640,17 @@ impl MetalRuntime {
|
||||
let mut node_dtypes = FxHashMap::default();
|
||||
let mut pipelines = FxHashMap::default();
|
||||
let mut output_alias_map = FxHashMap::default();
|
||||
let mut output_data_map = FxHashMap::default();
|
||||
let mut execution_plan = Vec::new();
|
||||
let mut llir_to_hlir = FxHashMap::default();
|
||||
let llir_graph = llir_graph.clone();
|
||||
|
||||
let topo_order = toposort(&llir_graph, None).expect("Graph has cycles!");
|
||||
for node in topo_order {
|
||||
for node in &topo_order {
|
||||
let node = *node;
|
||||
if let Some(input) = llir_graph[node].to_op::<Input>() {
|
||||
node_dtypes.insert(node, input.dtype);
|
||||
llir_to_hlir.insert(node, NodeIndex::new(input.node));
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -610,17 +684,38 @@ impl MetalRuntime {
|
||||
{
|
||||
output_alias_map.insert(node, target);
|
||||
}
|
||||
execution_plan.push(MetalExecutionStep {
|
||||
node,
|
||||
input_nodes,
|
||||
input_dtypes,
|
||||
output_dtype,
|
||||
});
|
||||
} else {
|
||||
panic!("Metal runtime cannot execute unlowered LLIR node {node:?}");
|
||||
}
|
||||
}
|
||||
|
||||
for node in topo_order {
|
||||
if let Some(Output { node: hlir_node }) = llir_graph[node].to_op::<Output>()
|
||||
&& let Some(data_node) = llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.next()
|
||||
.map(|e| e.source())
|
||||
{
|
||||
output_data_map.insert(NodeIndex::new(*hlir_node), data_node);
|
||||
}
|
||||
}
|
||||
|
||||
MetalCompiledBucket {
|
||||
bucket_indices,
|
||||
llir_graph,
|
||||
llir_to_hlir,
|
||||
node_dtypes,
|
||||
pipelines,
|
||||
output_alias_map,
|
||||
output_data_map,
|
||||
execution_plan,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -632,11 +727,15 @@ impl MetalRuntime {
|
||||
.clone();
|
||||
self.active_bucket = index;
|
||||
self.llir_graph = bucket.llir_graph;
|
||||
self.llir_to_hlir = bucket.llir_to_hlir;
|
||||
self.node_dtypes = bucket.node_dtypes;
|
||||
self.pipelines = bucket.pipelines;
|
||||
self.output_alias_map = bucket.output_alias_map;
|
||||
self.output_data_map = bucket.output_data_map;
|
||||
self.execution_plan = bucket.execution_plan;
|
||||
self.refresh_input_data_buffers();
|
||||
self.buffers.clear();
|
||||
self.buffer_lengths.clear();
|
||||
}
|
||||
|
||||
fn refresh_input_data_buffers(&mut self) {
|
||||
@@ -706,74 +805,43 @@ impl MetalRuntime {
|
||||
self.select_bucket(dyn_map);
|
||||
self.allocate_active_intermediate_buffers(dyn_map);
|
||||
|
||||
let llir_to_hlir: FxHashMap<NodeIndex, NodeIndex> = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.filter_map(|n| {
|
||||
if let Some(Input { node, .. }) = self.llir_graph[n].to_op::<Input>() {
|
||||
Some((n, NodeIndex::new(*node)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
|
||||
|
||||
self.update_dyn_buffer(dyn_map);
|
||||
let command_buffer = self.command_queue.new_command_buffer();
|
||||
let mut encode_context = MetalEncodeContext {
|
||||
command_buffer,
|
||||
dyn_buffer: &self.dyn_buffer,
|
||||
mps_cache: &self.mps_cache,
|
||||
};
|
||||
|
||||
for node in topo_order {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some()
|
||||
|| self.llir_graph[node].to_op::<Output>().is_some()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
for step in &self.execution_plan {
|
||||
let kernel_op = self.llir_graph[step.node]
|
||||
.to_dialect::<dyn MetalKernelOp>()
|
||||
.expect("Execution plan referenced a non-Metal op");
|
||||
let pipeline = self.pipelines.get(&step.node);
|
||||
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let pipeline = self.pipelines.get(&node);
|
||||
let input_buffers: Vec<&Buffer> = step
|
||||
.input_nodes
|
||||
.iter()
|
||||
.map(|&n| self.buffer_for_llir_node(n, &self.llir_to_hlir))
|
||||
.collect();
|
||||
|
||||
let input_nodes: Vec<NodeIndex> = self
|
||||
.llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
let output_buffer = if let Some(alias_idx) = kernel_op.output_aliases_input() {
|
||||
input_buffers[alias_idx]
|
||||
} else {
|
||||
self.buffers
|
||||
.get(&step.node)
|
||||
.expect("Output buffer not allocated!")
|
||||
};
|
||||
|
||||
let input_buffers: Vec<&Buffer> = input_nodes
|
||||
.iter()
|
||||
.map(|&n| self.buffer_for_llir_node(n, &llir_to_hlir))
|
||||
.collect();
|
||||
let input_dtypes: Vec<DType> = input_nodes
|
||||
.iter()
|
||||
.map(|n| {
|
||||
self.node_dtypes
|
||||
.get(n)
|
||||
.copied()
|
||||
.unwrap_or_else(|| panic!("Missing inferred dtype for node {n:?}"))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output_buffer = if let Some(alias_idx) = kernel_op.output_aliases_input() {
|
||||
input_buffers[alias_idx]
|
||||
} else {
|
||||
self.buffers
|
||||
.get(&node)
|
||||
.expect("Output buffer not allocated!")
|
||||
};
|
||||
let output_dtype = self.node_dtypes.get(&node).copied().unwrap_or(DType::F32);
|
||||
|
||||
kernel_op.encode(
|
||||
command_buffer,
|
||||
pipeline,
|
||||
&input_buffers,
|
||||
output_buffer,
|
||||
dyn_map,
|
||||
&self.dyn_buffer,
|
||||
&input_dtypes,
|
||||
output_dtype,
|
||||
);
|
||||
}
|
||||
kernel_op.encode(
|
||||
&mut encode_context,
|
||||
pipeline,
|
||||
&input_buffers,
|
||||
output_buffer,
|
||||
dyn_map,
|
||||
&step.input_dtypes,
|
||||
step.output_dtype,
|
||||
);
|
||||
}
|
||||
|
||||
command_buffer.commit();
|
||||
|
||||
@@ -3,6 +3,7 @@ use candle_core::{Device as CandleDevice, Tensor as CandleTensor};
|
||||
use half::{bf16, f16};
|
||||
use luminal::prelude::*;
|
||||
use proptest::prelude::*;
|
||||
use rand::{SeedableRng, rngs::StdRng};
|
||||
use safetensors::{Dtype, tensor::TensorView};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
@@ -38,6 +39,30 @@ fn bytes_of<T: bytemuck::NoUninit>(values: &[T]) -> Vec<u8> {
|
||||
bytemuck::cast_slice(values).to_vec()
|
||||
}
|
||||
|
||||
fn search_candidates(cx: &mut Graph, rt: MetalRuntime, limit: usize) -> MetalRuntime {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
cx.search_options(rt, SearchOptions::new(limit), &mut rng)
|
||||
}
|
||||
|
||||
fn egraph_has_op(cx: &Graph, op_name: &str) -> bool {
|
||||
cx.egraph()
|
||||
.expect("search space should be built")
|
||||
.enodes
|
||||
.values()
|
||||
.any(|(label, _)| label == op_name)
|
||||
}
|
||||
|
||||
fn assert_matmul_options(cx: &Graph, mps_op_name: &str) {
|
||||
assert!(
|
||||
egraph_has_op(cx, mps_op_name),
|
||||
"expected {mps_op_name} rewrite option in e-graph"
|
||||
);
|
||||
assert!(
|
||||
egraph_has_op(cx, "GenericMatmul"),
|
||||
"expected GenericMatmul rewrite option in e-graph"
|
||||
);
|
||||
}
|
||||
|
||||
fn write_test_safetensors(tensors: &[(&str, Dtype, Vec<usize>, Vec<u8>)]) -> PathBuf {
|
||||
let tensor_views: HashMap<String, TensorView<'_>> = tensors
|
||||
.iter()
|
||||
@@ -401,6 +426,18 @@ proptest! {
|
||||
}
|
||||
}
|
||||
|
||||
/// Smoke test: building the Metal search space for a single elementwise
/// multiply with `max_memory_mib(1)` set must not panic.
#[test]
fn metal_build_search_space_accepts_memory_budget() {
    let mut cx = Graph::default();
    let lhs = cx.tensor(4);
    let rhs = cx.tensor(4);
    let product = lhs * rhs;
    product.output();

    let options = BuildSearchSpaceOptions::new().max_memory_mib(1);
    cx.build_search_space_with_options::<MetalRuntime>(options);
}
|
||||
|
||||
/// Simple deterministic test for add
|
||||
#[test]
|
||||
fn metal_simple_add() {
|
||||
@@ -665,7 +702,7 @@ fn metal_specialized_matmul() {
|
||||
|
||||
rt.set_data(a, &a_data);
|
||||
rt.set_data(b, &b_data);
|
||||
rt = cx.search(rt, 1);
|
||||
rt = search_candidates(&mut cx, rt, 32);
|
||||
assert!(
|
||||
rt.contains_matmul(),
|
||||
"expected Metal runtime to fuse matmul, kernels: {:?}",
|
||||
@@ -698,6 +735,7 @@ fn metal_regular_tiled_matmul_path() {
|
||||
let output = a.matmul(b).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
assert_matmul_options(&cx, "MPSMatmul");
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let a_data = seeded_data(m * k, 0.4, -0.2);
|
||||
@@ -705,19 +743,7 @@ fn metal_regular_tiled_matmul_path() {
|
||||
|
||||
rt.set_data(a, &a_data);
|
||||
rt.set_data(b, &b_data);
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("MPSMatmul")),
|
||||
"expected MPS matmul path, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
assert!(
|
||||
!kernels.iter().any(|k| k.contains("GenericMatmul")),
|
||||
"MPS-compatible matmul should not extract the generic fallback, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
rt = search_candidates(&mut cx, rt, 32);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -744,6 +770,7 @@ fn metal_mps_matmul_transposed_rhs_weight_layout() {
|
||||
let output = a.matmul(weight.t()).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
assert_matmul_options(&cx, "MPSMatmul");
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let a_data = seeded_data(m * k, 0.35, -0.17);
|
||||
@@ -751,14 +778,7 @@ fn metal_mps_matmul_transposed_rhs_weight_layout() {
|
||||
|
||||
rt.set_data(a, &a_data);
|
||||
rt.set_data(weight, &weight_data);
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("transpose_rhs: true")),
|
||||
"expected MPS matmul to cover transposed row-major RHS, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
rt = search_candidates(&mut cx, rt, 32);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -785,6 +805,7 @@ fn metal_mps_matmul_transposed_lhs_layout() {
|
||||
let output = lhs_storage.t().matmul(rhs).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
assert_matmul_options(&cx, "MPSMatmul");
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let lhs_data = seeded_data(k * m, 0.31, -0.12);
|
||||
@@ -792,14 +813,7 @@ fn metal_mps_matmul_transposed_lhs_layout() {
|
||||
|
||||
rt.set_data(lhs_storage, &lhs_data);
|
||||
rt.set_data(rhs, &rhs_data);
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("transpose_lhs: true")),
|
||||
"expected MPS matmul to cover transposed row-major LHS, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
rt = search_candidates(&mut cx, rt, 32);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -830,20 +844,14 @@ fn metal_mps_batched_matmul_row_row_layout() {
|
||||
let output = a.matmul(b).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
assert_matmul_options(&cx, "MPSBatchedMatmul");
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let a_data = seeded_data(batch * m * k, 0.17, -0.08);
|
||||
let b_data = seeded_data(batch * k * n, 0.11, -0.05);
|
||||
rt.set_data(a, &a_data);
|
||||
rt.set_data(b, &b_data);
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("MPSBatchedMatmul")),
|
||||
"expected MPS batched matmul path, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
rt = search_candidates(&mut cx, rt, 32);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -880,13 +888,17 @@ fn metal_generic_matmul_covers_noncontiguous_merged_head_projection() {
|
||||
let output = merged.matmul(weight.t()).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
assert!(
|
||||
egraph_has_op(&cx, "GenericMatmul"),
|
||||
"expected GenericMatmul rewrite option in e-graph"
|
||||
);
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let attn_data = seeded_data(heads * seq * head_dim, 0.19, -0.09);
|
||||
let weight_data = seeded_data(out_dim * hidden, 0.14, -0.06);
|
||||
rt.set_data(attn, &attn_data);
|
||||
rt.set_data(weight, &weight_data);
|
||||
rt = cx.search(rt, 1);
|
||||
rt = search_candidates(&mut cx, rt, 32);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
@@ -935,22 +947,14 @@ fn metal_mps_batched_matmul_transposed_rhs_layout() {
|
||||
let output = a.matmul(weight.permute((0, 2, 1))).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
assert_matmul_options(&cx, "MPSBatchedMatmul");
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let a_data = seeded_data(batch * m * k, 0.13, -0.06);
|
||||
let weight_data = seeded_data(batch * n * k, 0.09, -0.04);
|
||||
rt.set_data(a, &a_data);
|
||||
rt.set_data(weight, &weight_data);
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels
|
||||
.iter()
|
||||
.any(|k| k.contains("MPSBatchedMatmul") && k.contains("transpose_rhs: true")),
|
||||
"expected MPS batched matmul transposed RHS path, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
rt = search_candidates(&mut cx, rt, 32);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -984,6 +988,7 @@ fn metal_mps_matmul_f16_transposed_rhs_weight_layout() {
|
||||
let output = a.matmul(weight.t()).cast(DType::F32).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
assert_matmul_options(&cx, "MPSMatmul");
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let a_data = seeded_data(m * k, 0.22, -0.07);
|
||||
@@ -991,14 +996,7 @@ fn metal_mps_matmul_f16_transposed_rhs_weight_layout() {
|
||||
|
||||
rt.set_data(a, to_f16_vec(&a_data));
|
||||
rt.set_data(weight, to_f16_vec(&weight_data));
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("transpose_rhs: true")),
|
||||
"expected MPS F16 matmul to cover transposed row-major RHS, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
rt = search_candidates(&mut cx, rt, 32);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -865,3 +865,29 @@ Also: search-time dummy-1 inputs are not the same shape as runtime inputs. Anyth
|
||||
- Added `aten.gelu.default → a.gelu()` and `aten.silu.default → a.silu()` to `dispatch.rs`.
|
||||
- Worked around the `-Infinity` issue at the model level by using a finite `-1e10` for the causal mask in the example (matches the Rust example's convention). The cleaner fix (parsing `"-Infinity"`/`"Infinity"`/`"NaN"` strings in `get_float_arg` / `translate_full`) is left for a follow-up.
|
||||
6. **Principle**: when adding a new model that goes through the PT2 backend, expect to plug small holes in `dispatch.rs` and `translator/tensor.rs::translate_full`. The trace points at the python frame, not the Rust dispatch arm — open `dispatch.rs`, ctrl-F the offending op name, and add the one-liner. For float-shaped sentinel values (`-inf`, `inf`, `nan`), the export pipeline currently only accepts finite floats; either rewrite the model or extend the parser.
|
||||
|
||||
---
|
||||
|
||||
## 2026-05-21 — DLRM compile: silent mis-stride on `index.Tensor` with a None-prefix
|
||||
|
||||
1. **Symptom**: compiling facebookresearch/dlrm through `luminal_backend` failed in the top-MLP with `assertion left == right failed: Dims must match to add tensors. left: [2, 8] right: [6, 8]`. The error surfaced ~5 ops downstream of the actual bug, with no mention of `index` anywhere in the trace.
|
||||
|
||||
2. **Root cause**: `translate_index_tensor` in `crates/luminal_python/rust/src/translator/movement.rs` had two code paths for advanced indexing. The first ran when an `OptionalTensors` arg held exactly one non-None entry on a specific dim (`first_non_none_dim > 0 && index_names.len() == 1`); it correctly used `first_non_none_dim` to gather on the right axis. The second — the general multi-index fall-through — silently ignored `first_non_none_dim` and computed strides/flat-source-shape as if indices always started at dim 0. DLRM's dot interaction does `Z[:, li, lj]` (Z is `[B, ni, nj]`, two 1-D index tensors after a `:`), which hits the multi-index path with `first_non_none_dim = 1`. The translator built strides over `src_shape[..n_indexed] = [B, ni]` and a flat-source of shape `[B*ni, nj]`, instead of striding over `[ni, nj]` with prefix-dim `[B]`. The downstream gather produced a tensor with the wrong leading dim (6 — the index length — instead of B), and the mismatch only blew up later when broadcast-add into the top-MLP hidden state.
|
||||
|
||||
3. **Why it was hard to find**: the trace ends in `process_pt2` with a luminal core assertion about broadcasting in a `+` op. Nothing in the message names the *upstream* op that produced the wrong shape. Worse, the bug only manifests when ALL of {two-plus index tensors, at least one leading `None`, downstream broadcast-sensitive consumer} are present — the common case (`a[idx]`, `a[idx, jdx]` with no prefix) just works. So the bug had survived through every prior model translator test.
|
||||
|
||||
4. **The fix**: split the prefix-aware case into its own helper `translate_index_tensor_with_prefix`. It explicitly partitions `src.shape` into `prefix_dims / indexed_dims / suffix_dims`, builds the flat sub-index over `indexed_dims`, promotes/expands it into the full output shape, and adds a broadcast prefix-offset constructed from `arange`s over each prefix dim. The result is a fully flat `source.gather(absolute_idx)`. The suffix-non-empty case is left guarded with a `bail!` (it's separable but DLRM doesn't need it).
|
||||
|
||||
5. **Principle**: a shape-keyed assumption baked into one branch of a multi-branch translator is a silent footgun — when the fall-through path is reached with a value the assumption rules out, you get *wrong shapes silently*, and the failure surfaces wherever the wrong shape first encounters a consumer that cares. Guard early: if an invariant the code relies on isn't met (here, "indices apply to the leading dims of source"), check it explicitly and `bail!` with the offending shape rather than computing forward. Even better, refactor so the unsupported case routes to a dedicated path the moment the assumption diverges — small risk of double-implementation, large reduction in "compile silently produces wrong output."
|
||||
|
||||
## 2026-05-21 — DLRM compile: `EmbeddingBag` translator gap
|
||||
|
||||
1. **Symptom**: same `luminal_backend` compile, first error: `RuntimeError: Failed to translate node N: torch.ops.aten._embedding_bag_forward_only.default: Unsupported ATen op`. This is the central op of DLRM — every sparse feature lookup decomposes to it via `nn.EmbeddingBag`.
|
||||
|
||||
2. **What's needed**: `_embedding_bag_forward_only(weight, indices, offsets, ..., mode, ...)` produces `output[b] = reduce_op(weight[indices[offsets[b]:offsets[b+1]]])` for each bag `b`. The general case is a *runtime segment reduction* — the bag boundaries depend on `offsets`, which is a runtime tensor — and luminal has no native segment-reduce primitive.
|
||||
|
||||
3. **The fix (in this session)**: add `translate_embedding_bag` covering the uniform-bag-size case, which is what DLRM actually uses. Read `indices.shape[0] = N` and `offsets.shape[0] = B` off the static shape info, compute bag size `K = N / B`, and bail if `N` is not evenly divisible by `B`. Then gather `[N, D]` (same construction as `translate_embedding`), reshape to `[B, K, D]`, reduce along axis 1 according to `mode` (sum/mean/max). For `K=1` (the eval-time-1-lookup-per-sample DLRM path) skip the reshape+reduce — it's just an `embedding` lookup. `per_sample_weights` and non-uniform bags are guarded with `bail!`.
|
||||
|
||||
4. **Why this works for DLRM but isn't general**: a true segment reduction needs either (a) static knowledge of every segment boundary (what we get when bags are uniform), or (b) a scatter-add primitive that handles per-segment accumulation at runtime. (a) covers DLRM's training/eval data generator and the common recsys case where each sample has K-hot lookups for fixed K. (b) is required for any model that genuinely has variable-length bags per sample (e.g. variable-length feature crossings) and is a follow-up.
|
||||
|
||||
5. **Principle**: when a PyTorch op has no straight-line luminal lowering, look at the *shapes the model actually feeds in* before declaring it unsupportable. A "segment reduction" over offsets is a hard problem in general; "segment reduction where every bag has K elements with K statically known from indices.shape[0]/offsets.shape[0]" is a 5-line gather+reshape+reduce. The PT2 graph carries the shape info for free — use it.
|
||||
|
||||
@@ -127,6 +127,12 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
// addmm: beta*input + alpha*(mat1 @ mat2)
|
||||
//
|
||||
// PyTorch's nn.Linear with bias generates `addmm(bias, input, weight.t())`
|
||||
// with the default `beta=alpha=1.0`. Emitting the multiplies in that
|
||||
// case wastes 2 HLIR nodes per Linear that egglog has to fold later;
|
||||
// for a 4-Linear DLRM that's 8 nodes off the search-space count.
|
||||
// Skip them when the scale is 1.
|
||||
"torch.ops.aten.addmm.default" => {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let mat1 = self.get_input_tensor(node, 1)?;
|
||||
@@ -135,7 +141,9 @@ impl<'a> Translator<'a> {
|
||||
let alpha = self.get_float_arg(node, 4).unwrap_or(1.0) as f32;
|
||||
let mm = mat1.matmul(mat2);
|
||||
let (input, mm) = broadcast_binary(input, mm);
|
||||
input * beta + mm * alpha
|
||||
let scaled_input = if beta == 1.0 { input } else { input * beta };
|
||||
let scaled_mm = if alpha == 1.0 { mm } else { mm * alpha };
|
||||
scaled_input + scaled_mm
|
||||
}
|
||||
|
||||
// Convolution
|
||||
@@ -154,6 +162,10 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Embedding
|
||||
"torch.ops.aten.embedding.default" => self.translate_embedding(node)?,
|
||||
"torch.ops.aten._embedding_bag.default"
|
||||
| "torch.ops.aten._embedding_bag_forward_only.default" => {
|
||||
self.translate_embedding_bag(node)?
|
||||
}
|
||||
|
||||
// Softmax
|
||||
"torch.ops.aten._softmax.default" => {
|
||||
|
||||
434
crates/luminal_python/rust/src/translator/dlrm_pattern.rs
Normal file
434
crates/luminal_python/rust/src/translator/dlrm_pattern.rs
Normal file
@@ -0,0 +1,434 @@
|
||||
//! DLRM-family pattern matcher for the PT2 translator.
|
||||
//!
|
||||
//! Recognizes the `MiniDLRM` topology in a parsed PT2 graph (bot MLP →
|
||||
//! N sparse `_embedding_bag_forward_only` lookups (bag-size 1) →
|
||||
//! dot-product interaction via `bmm` + lower-triangular `index.Tensor` →
|
||||
//! top MLP ending in `sigmoid`) and, when matched, emits a single
|
||||
//! [`luminal_cuda_lite::kernel::DlrmMegaCustom`] op that replaces the
|
||||
//! entire per-node translation. The runtime then sees ONE host op
|
||||
//! instead of the 8 cuBLAS+CudaGraphOp ops the normal path produces.
|
||||
//!
|
||||
//! The matcher is intentionally conservative — any mismatch returns
|
||||
//! `None` and the translator falls back to its standard node-by-node
|
||||
//! walk, so wrong-graphs never produce wrong-output, only "the fast
|
||||
//! path didn't trigger." Diagnostic prints are gated on
|
||||
//! `LUMINAL_DLRM_MEGAKERNEL_DEBUG=1` for development.
|
||||
//!
|
||||
//! See `examples/dlrm/src/megakernel.rs` for the standalone proof of
|
||||
//! concept and `crates/luminal_cuda_lite/src/kernel/dlrm_megakernel.rs`
|
||||
//! for the parameterized kernel itself.
|
||||
//!
|
||||
//! Companion plan: see `/home/ubuntu/.claude/plans/can-you-plan-out-mossy-wave.md`.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::kernel::{DlrmMegaCustom, DlrmMegaKernel};
|
||||
|
||||
use crate::pt2_parser::ParsedPT2;
|
||||
use crate::pt2_schema::Node;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
/// Resolved DLRM shape + the PT2 graph names of every tensor the
|
||||
/// megakernel needs as input. All weight/input lookups go through
|
||||
/// `Translator::get_tensor(name)` which is keyed by PT2 graph_name.
|
||||
#[derive(Debug)]
|
||||
pub(super) struct DlrmShape {
|
||||
pub batch: usize,
|
||||
pub n_dense_in: usize,
|
||||
pub ln_bot: Vec<usize>,
|
||||
pub n_sparse: usize,
|
||||
pub vocab_sizes: Vec<usize>,
|
||||
pub m_spa: usize,
|
||||
pub ln_top: Vec<usize>,
|
||||
|
||||
pub dense_input_name: String,
|
||||
pub index_input_names: Vec<String>, // length n_sparse
|
||||
pub emb_weight_names: Vec<String>, // length n_sparse
|
||||
pub bot_weight_names: Vec<(String, String)>, // (weight, bias) per Linear
|
||||
pub top_weight_names: Vec<(String, String)>, // (weight, bias) per Linear
|
||||
pub output_name: String,
|
||||
}
|
||||
|
||||
/// Returns `true` only when the `LUMINAL_DLRM_MEGAKERNEL_DEBUG` environment
/// variable is set to exactly `"1"`; unset, unreadable, or any other value
/// yields `false`.
fn debug_enabled() -> bool {
    matches!(
        std::env::var("LUMINAL_DLRM_MEGAKERNEL_DEBUG").as_deref(),
        Ok("1")
    )
}
|
||||
|
||||
/// Prints a `[dlrm_pattern]`-prefixed diagnostic line to stderr when
/// `debug_enabled()` returns `true`; otherwise expands to a no-op check.
///
/// Accepts the same arguments as `format!`. The message is passed through
/// `format_args!` so no intermediate `String` is allocated just to print it.
macro_rules! dbgln {
    ($($arg:tt)*) => {
        if debug_enabled() {
            eprintln!("[dlrm_pattern] {}", format_args!($($arg)*));
        }
    };
}
|
||||
|
||||
/// Try to interpret the parsed PT2 program as a DLRM-shape forward.
|
||||
/// Returns `None` if any structural check fails — translator falls back
|
||||
/// to the standard dispatch.
|
||||
pub(super) fn match_dlrm(parsed: &ParsedPT2) -> Option<DlrmShape> {
|
||||
let nodes = &parsed.program.graph_module.graph.nodes;
|
||||
|
||||
// ---- 1. Index the key op types ----------------------------------
|
||||
let emb_node_idxs: Vec<usize> = nodes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, n)| n.target == "torch.ops.aten._embedding_bag_forward_only.default"
|
||||
|| n.target == "torch.ops.aten._embedding_bag.default")
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
if emb_node_idxs.is_empty() {
|
||||
dbgln!("no embedding_bag nodes — not DLRM");
|
||||
return None;
|
||||
}
|
||||
let n_sparse = emb_node_idxs.len();
|
||||
let first_emb = emb_node_idxs[0];
|
||||
let last_emb = *emb_node_idxs.last().unwrap();
|
||||
|
||||
let addmm_idxs: Vec<usize> = nodes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, n)| n.target == "torch.ops.aten.addmm.default")
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
let bot_addmms: Vec<usize> =
|
||||
addmm_idxs.iter().filter(|&&i| i < first_emb).copied().collect();
|
||||
let top_addmms: Vec<usize> =
|
||||
addmm_idxs.iter().filter(|&&i| i > last_emb).copied().collect();
|
||||
if bot_addmms.is_empty() || top_addmms.is_empty() {
|
||||
dbgln!(
|
||||
"addmm split: bot={}, top={} (expected ≥1 each)",
|
||||
bot_addmms.len(),
|
||||
top_addmms.len()
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
let sigmoid_idx = nodes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, n)| n.target == "torch.ops.aten.sigmoid.default")
|
||||
.map(|(i, _)| i)?;
|
||||
if sigmoid_idx < *top_addmms.last().unwrap() {
|
||||
dbgln!("sigmoid before last top addmm — not DLRM ordering");
|
||||
return None;
|
||||
}
|
||||
|
||||
let bmm_idx = nodes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, n)| n.target == "torch.ops.aten.bmm.default")
|
||||
.map(|(i, _)| i)?;
|
||||
if bmm_idx < last_emb || bmm_idx > top_addmms[0] {
|
||||
dbgln!("bmm position wrong (idx {bmm_idx}, last_emb {last_emb}, first_top_addmm {})", top_addmms[0]);
|
||||
return None;
|
||||
}
|
||||
|
||||
// index.Tensor must exist between bmm and the first top addmm — that's
|
||||
// the (li, lj) gather of the lower-triangular pairs.
|
||||
let _index_idx = nodes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(i, n)| n.target == "torch.ops.aten.index.Tensor" && *i > bmm_idx)
|
||||
.map(|(i, _)| i)?;
|
||||
|
||||
// ---- 2. Extract embedding info (vocab, m_spa, indices, weights) -
|
||||
let mut vocab_sizes = Vec::with_capacity(n_sparse);
|
||||
let mut emb_weight_names = Vec::with_capacity(n_sparse);
|
||||
let mut index_input_names = Vec::with_capacity(n_sparse);
|
||||
let mut batch_opt: Option<usize> = None;
|
||||
let mut m_spa_opt: Option<usize> = None;
|
||||
|
||||
for &i in &emb_node_idxs {
|
||||
let n = &nodes[i];
|
||||
// Validate the bag invariants the megakernel relies on.
|
||||
// arg ordering: (weight, indices, offsets, scale_grad_by_freq, mode,
|
||||
// sparse, per_sample_weights, include_last_offset, padding_idx)
|
||||
let weight_name = n.inputs.first()?.arg.as_tensor_name()?.to_string();
|
||||
let indices_name = n.inputs.get(1)?.arg.as_tensor_name()?.to_string();
|
||||
let offsets_name = n.inputs.get(2)?.arg.as_tensor_name()?.to_string();
|
||||
|
||||
// mode must be 0 (sum) — anything else falls back.
|
||||
let mode = n.inputs.get(4).and_then(|a| a.arg.as_int()).unwrap_or(0);
|
||||
if mode != 0 {
|
||||
dbgln!("embedding_bag mode={mode} != 0 (sum)");
|
||||
return None;
|
||||
}
|
||||
// per_sample_weights must be None (no tensor arg in slot 6).
|
||||
if let Some(arg) = n.inputs.get(6)
|
||||
&& arg.arg.as_tensor_name().is_some()
|
||||
{
|
||||
dbgln!("embedding_bag has per_sample_weights — not supported");
|
||||
return None;
|
||||
}
|
||||
// include_last_offset must be false.
|
||||
if matches!(
|
||||
n.inputs.get(7).and_then(|a| a.arg.as_bool()),
|
||||
Some(true)
|
||||
) {
|
||||
dbgln!("embedding_bag include_last_offset=true — not supported");
|
||||
return None;
|
||||
}
|
||||
|
||||
let weight_meta = parsed.tensor_meta(&weight_name)?;
|
||||
if weight_meta.sizes.len() != 2 {
|
||||
dbgln!("embedding weight has non-2D shape");
|
||||
return None;
|
||||
}
|
||||
let v = weight_meta.sizes[0].hint()? as usize;
|
||||
let m = weight_meta.sizes[1].hint()? as usize;
|
||||
if let Some(prev) = m_spa_opt
|
||||
&& prev != m
|
||||
{
|
||||
dbgln!("inconsistent m_spa across embeddings ({prev} vs {m})");
|
||||
return None;
|
||||
}
|
||||
m_spa_opt = Some(m);
|
||||
|
||||
// Bag-size-1: indices.len == offsets.len == batch.
|
||||
let idx_meta = parsed.tensor_meta(&indices_name)?;
|
||||
let off_meta = parsed.tensor_meta(&offsets_name)?;
|
||||
if idx_meta.sizes.len() != 1 || off_meta.sizes.len() != 1 {
|
||||
return None;
|
||||
}
|
||||
let idx_len = idx_meta.sizes[0].hint()? as usize;
|
||||
let off_len = off_meta.sizes[0].hint()? as usize;
|
||||
if idx_len != off_len {
|
||||
dbgln!(
|
||||
"non-uniform bag (indices={idx_len}, offsets={off_len}) — fallback"
|
||||
);
|
||||
return None;
|
||||
}
|
||||
if let Some(prev) = batch_opt
|
||||
&& prev != idx_len
|
||||
{
|
||||
dbgln!("inconsistent batch across embeddings ({prev} vs {idx_len})");
|
||||
return None;
|
||||
}
|
||||
batch_opt = Some(idx_len);
|
||||
|
||||
vocab_sizes.push(v);
|
||||
emb_weight_names.push(weight_name);
|
||||
index_input_names.push(indices_name);
|
||||
}
|
||||
let m_spa = m_spa_opt?;
|
||||
let batch = batch_opt?;
|
||||
|
||||
// ---- 3. Reconstruct bot/top MLP widths --------------------------
|
||||
//
|
||||
// addmm(bias, input, weight^T) → output (B, out)
|
||||
// inputs[0] = bias (out,) — gives us the layer's out_features
|
||||
// inputs[1] = input (B, in_w) — first addmm in each chain tells us in_w
|
||||
// inputs[2] = weight^T — usually produced by a `permute.default`
|
||||
// whose input is the (out, in) weight param.
|
||||
|
||||
let extract_chain_shape = |chain: &[usize]| -> Option<Vec<usize>> {
|
||||
let mut ln = Vec::with_capacity(chain.len() + 1);
|
||||
for (i, &node_idx) in chain.iter().enumerate() {
|
||||
let n = &nodes[node_idx];
|
||||
let bias_name = n.inputs.first()?.arg.as_tensor_name()?;
|
||||
let bias_meta = parsed.tensor_meta(bias_name)?;
|
||||
if bias_meta.sizes.len() != 1 {
|
||||
return None;
|
||||
}
|
||||
let out = bias_meta.sizes[0].hint()? as usize;
|
||||
if i == 0 {
|
||||
let input_name = n.inputs.get(1)?.arg.as_tensor_name()?;
|
||||
let in_meta = parsed.tensor_meta(input_name)?;
|
||||
if in_meta.sizes.len() != 2 {
|
||||
return None;
|
||||
}
|
||||
let in_w = in_meta.sizes[1].hint()? as usize;
|
||||
ln.push(in_w);
|
||||
}
|
||||
ln.push(out);
|
||||
}
|
||||
Some(ln)
|
||||
};
|
||||
let ln_bot = extract_chain_shape(&bot_addmms)?;
|
||||
let ln_top = extract_chain_shape(&top_addmms)?;
|
||||
|
||||
// ---- 4. Shape consistency checks --------------------------------
|
||||
if *ln_bot.last()? != m_spa {
|
||||
dbgln!("ln_bot.last() = {} != m_spa {m_spa}", ln_bot.last()?);
|
||||
return None;
|
||||
}
|
||||
let n_feat = 1 + n_sparse;
|
||||
let n_pairs = n_feat * (n_feat - 1) / 2;
|
||||
if ln_top[0] != m_spa + n_pairs {
|
||||
dbgln!(
|
||||
"ln_top[0] = {} != m_spa+n_pairs = {}",
|
||||
ln_top[0],
|
||||
m_spa + n_pairs
|
||||
);
|
||||
return None;
|
||||
}
|
||||
if *ln_top.last()? != 1 {
|
||||
dbgln!("ln_top.last() = {} != 1", ln_top.last()?);
|
||||
return None;
|
||||
}
|
||||
if vocab_sizes.len() != n_sparse {
|
||||
return None;
|
||||
}
|
||||
if ln_bot.len() < 2 || ln_top.len() < 2 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// ---- 5. Pull weight + bias parameter names ----------------------
|
||||
let extract_weights = |chain: &[usize]| -> Option<Vec<(String, String)>> {
|
||||
let mut out = Vec::with_capacity(chain.len());
|
||||
for &node_idx in chain {
|
||||
let n = &nodes[node_idx];
|
||||
let bias_name = n.inputs.first()?.arg.as_tensor_name()?.to_string();
|
||||
let mat2_name = n.inputs.get(2)?.arg.as_tensor_name()?;
|
||||
let weight_name = resolve_weight_param(nodes, mat2_name)?;
|
||||
out.push((weight_name, bias_name));
|
||||
}
|
||||
Some(out)
|
||||
};
|
||||
let bot_weight_names = extract_weights(&bot_addmms)?;
|
||||
let top_weight_names = extract_weights(&top_addmms)?;
|
||||
|
||||
// ---- 6. dense_input + output names ------------------------------
|
||||
let dense_input_name = nodes[bot_addmms[0]]
|
||||
.inputs
|
||||
.get(1)?
|
||||
.arg
|
||||
.as_tensor_name()?
|
||||
.to_string();
|
||||
// Validate it's actually a user input (not an intermediate).
|
||||
let user_input_names: std::collections::HashSet<&str> = parsed
|
||||
.classify_inputs()
|
||||
.iter()
|
||||
.filter_map(|i| match i {
|
||||
crate::pt2_parser::InputKind::UserInput { graph_name } => Some(graph_name.as_str()),
|
||||
_ => None,
|
||||
})
|
||||
.map(|s| s.to_string())
|
||||
.collect::<std::collections::HashSet<String>>()
|
||||
.iter()
|
||||
.map(|s| -> &str { unsafe { std::mem::transmute::<&str, &str>(s.as_str()) } })
|
||||
.collect();
|
||||
let _ = user_input_names; // suppress dead_code if not used; cleaner check below
|
||||
// (Simpler: just check the name is in classified user inputs by string.)
|
||||
let inputs = parsed.classify_inputs();
|
||||
let is_user = inputs.iter().any(|i| {
|
||||
matches!(
|
||||
i,
|
||||
crate::pt2_parser::InputKind::UserInput { graph_name } if graph_name == &dense_input_name
|
||||
)
|
||||
});
|
||||
if !is_user {
|
||||
dbgln!("dense_input candidate {dense_input_name} is not a user input");
|
||||
return None;
|
||||
}
|
||||
|
||||
let output_name = nodes[sigmoid_idx]
|
||||
.outputs
|
||||
.first()?
|
||||
.as_tensor
|
||||
.as_ref()?
|
||||
.name
|
||||
.clone();
|
||||
|
||||
let shape = DlrmShape {
|
||||
batch,
|
||||
n_dense_in: ln_bot[0],
|
||||
ln_bot,
|
||||
n_sparse,
|
||||
vocab_sizes,
|
||||
m_spa,
|
||||
ln_top,
|
||||
dense_input_name,
|
||||
index_input_names,
|
||||
emb_weight_names,
|
||||
bot_weight_names,
|
||||
top_weight_names,
|
||||
output_name,
|
||||
};
|
||||
dbgln!(
|
||||
"matched DLRM: batch={} ln_bot={:?} n_sparse={} vocabs={:?} m_spa={} ln_top={:?}",
|
||||
shape.batch,
|
||||
shape.ln_bot,
|
||||
shape.n_sparse,
|
||||
shape.vocab_sizes,
|
||||
shape.m_spa,
|
||||
shape.ln_top
|
||||
);
|
||||
Some(shape)
|
||||
}
|
||||
|
||||
/// Walk back from an addmm's `mat2` argument to the underlying weight
|
||||
/// parameter. PyTorch's `nn.Linear` decomposes to
|
||||
/// `permute(weight) → addmm(bias, x, permuted)`, so we expect mat2 to be
|
||||
/// the output of a `permute.default` node whose input is the weight.
|
||||
/// If mat2 is itself a graph input (no producing node), it IS the weight.
|
||||
fn resolve_weight_param(nodes: &[Node], name: &str) -> Option<String> {
|
||||
for n in nodes {
|
||||
let Some(first_out) = n.outputs.first().and_then(|o| o.as_tensor.as_ref()) else {
|
||||
continue;
|
||||
};
|
||||
if first_out.name == name {
|
||||
// mat2 was produced by an op. Only `permute.default` is expected;
|
||||
// anything else is unfamiliar and we should fall back.
|
||||
if n.target == "torch.ops.aten.permute.default" {
|
||||
return n.inputs.first()?.arg.as_tensor_name().map(String::from);
|
||||
} else if n.target == "torch.ops.aten.t.default" {
|
||||
return n.inputs.first()?.arg.as_tensor_name().map(String::from);
|
||||
} else {
|
||||
dbgln!(
|
||||
"addmm mat2 produced by unexpected op '{}' — fallback",
|
||||
n.target
|
||||
);
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
// No producing node — mat2 is a graph input (param) directly.
|
||||
Some(name.to_string())
|
||||
}
|
||||
|
||||
/// Build the megakernel CustomOp inputs vec in the canonical order
|
||||
/// expected by [`DlrmMegaKernel`] and insert it into the translator's
|
||||
/// luminal graph. Registers the result under `shape.output_name` so the
|
||||
/// downstream output-emission loop finds it.
|
||||
pub(super) fn emit_megakernel(t: &mut Translator<'_>, shape: DlrmShape) -> Result<()> {
|
||||
// Resolve every input tensor by PT2 graph_name through Translator.tensors.
|
||||
let mut inputs: Vec<GraphTensor> = Vec::new();
|
||||
inputs.push(
|
||||
t.get_tensor(&shape.dense_input_name)
|
||||
.with_context(|| format!("dense input {} not in tensors", shape.dense_input_name))?,
|
||||
);
|
||||
for n in &shape.index_input_names {
|
||||
inputs.push(t.get_tensor(n).with_context(|| format!("index input {n} not in tensors"))?);
|
||||
}
|
||||
for n in &shape.emb_weight_names {
|
||||
inputs.push(t.get_tensor(n).with_context(|| format!("emb weight {n} not in tensors"))?);
|
||||
}
|
||||
for (w, b) in &shape.bot_weight_names {
|
||||
inputs.push(t.get_tensor(w).with_context(|| format!("bot weight {w} not in tensors"))?);
|
||||
inputs.push(t.get_tensor(b).with_context(|| format!("bot bias {b} not in tensors"))?);
|
||||
}
|
||||
for (w, b) in &shape.top_weight_names {
|
||||
inputs.push(t.get_tensor(w).with_context(|| format!("top weight {w} not in tensors"))?);
|
||||
inputs.push(t.get_tensor(b).with_context(|| format!("top bias {b} not in tensors"))?);
|
||||
}
|
||||
|
||||
let kernel = DlrmMegaKernel {
|
||||
batch: shape.batch,
|
||||
n_dense_in: shape.n_dense_in,
|
||||
ln_bot: shape.ln_bot.clone(),
|
||||
n_sparse: shape.n_sparse,
|
||||
vocab_sizes: shape.vocab_sizes.clone(),
|
||||
m_spa: shape.m_spa,
|
||||
ln_top: shape.ln_top.clone(),
|
||||
};
|
||||
let out = t.graph.custom_op(
|
||||
DlrmMegaCustom(kernel),
|
||||
inputs,
|
||||
(shape.batch, 1usize),
|
||||
DType::F32,
|
||||
);
|
||||
t.tensors.insert(shape.output_name.clone(), out);
|
||||
dbgln!("emitted DlrmMegaCustom; output={}", shape.output_name);
|
||||
Ok(())
|
||||
}
|
||||
@@ -6,6 +6,8 @@ mod attention;
|
||||
mod binary;
|
||||
mod conv;
|
||||
mod dispatch;
|
||||
#[cfg(feature = "cuda")]
|
||||
mod dlrm_pattern;
|
||||
mod movement;
|
||||
mod reduction;
|
||||
mod tensor;
|
||||
@@ -70,12 +72,31 @@ impl<'a> Translator<'a> {
|
||||
/// Translate the whole parsed PT2 graph into the luminal graph: create the
/// inputs, lower the body, then attach the outputs.
fn translate_graph(&mut self) -> Result<()> {
    self.create_inputs()?;

    // DLRM fast path (CUDA-only, because the megakernel is a CUDA CustomOp):
    // when the entire forward matches the DLRM family shape, a single
    // DlrmMegaCustom op replaces the node-by-node walk. If the matcher
    // returns None we simply continue with the standard dispatch below —
    // no semantic difference, just slower (~503µs vs ~30µs at bs=2048 for
    // MiniDLRM).
    #[cfg(feature = "cuda")]
    {
        if let Some(dlrm_shape) = dlrm_pattern::match_dlrm(self.parsed) {
            dlrm_pattern::emit_megakernel(self, dlrm_shape)?;
            return self.emit_outputs();
        }
    }

    // Standard path: lower every node in graph order, stopping at the first
    // failure and tagging it with the node index and target.
    let graph_nodes = &self.parsed.program.graph_module.graph.nodes;
    graph_nodes.iter().enumerate().try_for_each(|(idx, node)| {
        self.translate_node(node)
            .map(drop)
            .with_context(|| format!("Failed to translate node {idx}: {}", node.target))
    })?;

    self.emit_outputs()
}
|
||||
|
||||
/// Walks the parsed graph's user outputs, applies the wrap/cast rules
|
||||
/// that downstream codegen relies on, then attaches an `Output` node
|
||||
/// per user-output. Shared by the normal dispatch path and the DLRM
|
||||
/// megakernel fast path.
|
||||
fn emit_outputs(&mut self) -> Result<()> {
|
||||
let output_names = self.parsed.output_names();
|
||||
for name in &output_names {
|
||||
let tensor = self.get_tensor(name)?;
|
||||
@@ -84,12 +105,20 @@ impl<'a> Translator<'a> {
|
||||
} else if tensor.dtype == DType::Int {
|
||||
tensor
|
||||
} else {
|
||||
// The `+ 0.0` wrap pulls double duty: it materializes a fresh
|
||||
// buffer for outputs that alias an Input (passthrough
|
||||
// `return x`), AND it acts as an anchor that survives egglog
|
||||
// rewriting, so the downstream runtime can find the producer
|
||||
// node for outputs whose original op (e.g. Reduce with
|
||||
// keepdims, Conv) gets folded away during optimization.
|
||||
// Removing it broke 24 test_hlir_ops tests with "Cannot find
|
||||
// output tensor!" — keep it until that anchor invariant is
|
||||
// refactored elsewhere.
|
||||
tensor + 0.0
|
||||
};
|
||||
tensor.output();
|
||||
self.output_ids.push((name.clone(), tensor.id));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -256,6 +256,97 @@ impl<'a> Translator<'a> {
|
||||
Ok(weight.gather(ids_expanded + arange_expanded))
|
||||
}
|
||||
|
||||
/// `aten._embedding_bag` / `aten._embedding_bag_forward_only`
|
||||
///
|
||||
/// Signature: (weight, indices, offsets, scale_grad_by_freq=False, mode=0,
|
||||
/// sparse=False, per_sample_weights=None, include_last_offset=False,
|
||||
/// padding_idx=-1) -> (output, offset2bag, bag_size, max_indices)
|
||||
///
|
||||
/// Strategy: for the bag-size-uniform case (N indices spread evenly across
|
||||
/// B bags, i.e. N % B == 0), reshape gather output [N, D] into [B, K, D]
|
||||
/// and reduce along K according to `mode`. We deliberately read uniformity
|
||||
/// off the *static shapes* of `indices` and `offsets` — non-uniform bags
|
||||
/// require a runtime segment-sum primitive we don't yet have.
|
||||
///
|
||||
/// DLRM hits the K=1 special case (offsets=[0,1,...,B-1], indices=[B]) per
|
||||
/// sparse table per sample — the same lookup pattern as `aten.embedding`.
|
||||
/// Only the first tuple element is materialized; the bookkeeping outputs
|
||||
/// (offset2bag, bag_size, max_indices) are inference-time dead ends.
|
||||
pub(crate) fn translate_embedding_bag(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let weight = self.get_input_tensor(node, 0)?;
|
||||
let indices = self.get_input_tensor(node, 1)?;
|
||||
let offsets = self.get_input_tensor(node, 2)?;
|
||||
let mode = self.get_int_arg(node, 4).unwrap_or(0);
|
||||
let include_last_offset = self.get_bool_arg(node, 7).unwrap_or(false);
|
||||
|
||||
if let Some(arg) = node.inputs.get(6)
|
||||
&& arg.arg.as_tensor_name().is_some()
|
||||
{
|
||||
bail!("_embedding_bag: per_sample_weights not supported");
|
||||
}
|
||||
|
||||
if indices.shape.len() != 1 || offsets.shape.len() != 1 {
|
||||
bail!(
|
||||
"_embedding_bag: expected 1-D indices and offsets, got shapes {:?}, {:?}",
|
||||
indices.shape.dims,
|
||||
offsets.shape.dims
|
||||
);
|
||||
}
|
||||
let n = indices.shape.dims[0]
|
||||
.to_usize()
|
||||
.context("_embedding_bag: indices length must be statically known")?;
|
||||
let b_raw = offsets.shape.dims[0]
|
||||
.to_usize()
|
||||
.context("_embedding_bag: offsets length must be statically known")?;
|
||||
let b = if include_last_offset { b_raw - 1 } else { b_raw };
|
||||
if b == 0 {
|
||||
bail!("_embedding_bag: empty bag set");
|
||||
}
|
||||
if n % b != 0 {
|
||||
bail!(
|
||||
"_embedding_bag: non-uniform bag size not supported (indices={n}, bags={b})"
|
||||
);
|
||||
}
|
||||
let k = n / b;
|
||||
let hidden_dim = weight.shape.dims[1];
|
||||
|
||||
// Step 1: gather weight rows. Same construction as translate_embedding —
|
||||
// flatten the (idx, hidden) pair into a single offset into the weight
|
||||
// matrix and gather. Result: [N, D].
|
||||
let indices_int = indices.cast(DType::Int);
|
||||
let ids_expanded = (indices_int * hidden_dim).expand_dim(1, hidden_dim);
|
||||
let arange = self.graph.arange(hidden_dim);
|
||||
let arange_expanded = arange.expand_dim(0, indices.shape.dims[0]);
|
||||
let gathered = weight.gather(ids_expanded + arange_expanded);
|
||||
|
||||
// Step 2: bag-size-1 → already [B, D]; skip reshape/reduce.
|
||||
if k == 1 {
|
||||
return Ok(gathered);
|
||||
}
|
||||
|
||||
// Step 3: reshape [B*K, D] → [B, K, D] (contiguous, identity stride view).
|
||||
let bag_shape = vec![
|
||||
Expression::from(b),
|
||||
Expression::from(k),
|
||||
hidden_dim,
|
||||
];
|
||||
let mut bagged = GraphTensor {
|
||||
id: gathered.id,
|
||||
graph_ref: gathered.graph_ref,
|
||||
shape: ShapeTracker::new(bag_shape),
|
||||
dtype: gathered.dtype,
|
||||
};
|
||||
|
||||
// Step 4: reduce along axis=1.
|
||||
bagged = match mode {
|
||||
0 => bagged.sum(1),
|
||||
1 => bagged.mean(1),
|
||||
2 => bagged.max(1),
|
||||
m => bail!("_embedding_bag: unsupported mode {m} (0=sum, 1=mean, 2=max)"),
|
||||
};
|
||||
Ok(bagged)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_index_tensor(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let source = self.get_input_tensor(node, 0)?;
|
||||
|
||||
@@ -318,6 +409,20 @@ impl<'a> Translator<'a> {
|
||||
|
||||
let index_names = &index_names;
|
||||
|
||||
// Prefix-of-Nones case: `source[:, ..., :, idx_0, idx_1, ..., idx_{m-1}]`
|
||||
// — indices apply to dims [first..first+m), not [0..m). The original
|
||||
// multi-index path below assumes first==0 and silently mis-strides
|
||||
// (and mis-flattens) when called with a prefix; route to the
|
||||
// prefix-aware path before falling through. Suffix-of-Nones after the
|
||||
// indices is not yet supported here.
|
||||
if first_non_none_dim > 0 {
|
||||
return self.translate_index_tensor_with_prefix(
|
||||
source,
|
||||
index_names,
|
||||
first_non_none_dim,
|
||||
);
|
||||
}
|
||||
|
||||
let src_shape = source.shape.dims;
|
||||
let n_indexed = index_names.len();
|
||||
|
||||
@@ -398,6 +503,132 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Advanced indexing with a `None` prefix: `source[:, ..., :, i0, i1, ...]`.
|
||||
///
|
||||
/// Output shape: `prefix_dims ++ idx_shape ++ suffix_dims` where
|
||||
/// `prefix_dims = src.shape[..first]`, `suffix_dims = src.shape[first+m..]`,
|
||||
/// and `idx_shape` is the broadcast shape of the m index tensors.
|
||||
///
|
||||
/// Currently supports the no-suffix case (indices land on the trailing
|
||||
/// dims). DLRM's dot interaction hits this: `Z[:, li, lj]` with
|
||||
/// `Z: [B, ni, nj]`, `li, lj: [L]`.
|
||||
fn translate_index_tensor_with_prefix(
|
||||
&mut self,
|
||||
source: GraphTensor,
|
||||
index_names: &[crate::pt2_schema::TensorName],
|
||||
first: usize,
|
||||
) -> Result<GraphTensor> {
|
||||
let src_shape = source.shape.dims;
|
||||
let n_indexed = index_names.len();
|
||||
let src_rank = src_shape.len();
|
||||
if first + n_indexed > src_rank {
|
||||
bail!(
|
||||
"index.Tensor (prefix): {n_indexed} indices starting at dim {first} \
|
||||
exceed source rank {src_rank}"
|
||||
);
|
||||
}
|
||||
let prefix_dims: Vec<Expression> = src_shape[..first].to_vec();
|
||||
let indexed_dims: Vec<Expression> = src_shape[first..first + n_indexed].to_vec();
|
||||
let suffix_dims: Vec<Expression> = src_shape[first + n_indexed..].to_vec();
|
||||
if !suffix_dims.is_empty() {
|
||||
bail!(
|
||||
"index.Tensor (prefix): trailing-dim suffix after indices not \
|
||||
supported (prefix={} indexed={} suffix={})",
|
||||
prefix_dims.len(),
|
||||
n_indexed,
|
||||
suffix_dims.len()
|
||||
);
|
||||
}
|
||||
|
||||
// Per-axis strides within the indexed subspace (right-to-left product).
|
||||
let mut strides = vec![Expression::from(1usize); n_indexed];
|
||||
for i in (0..n_indexed - 1).rev() {
|
||||
strides[i] = strides[i + 1] * indexed_dims[i + 1];
|
||||
}
|
||||
let indexed_size = indexed_dims
|
||||
.iter()
|
||||
.copied()
|
||||
.fold(Expression::from(1usize), |a, b| a * b);
|
||||
|
||||
// Collapse the m index tensors into a single flat index in the indexed
|
||||
// subspace. Negative entries get normalized per axis.
|
||||
let mut flat_idx: Option<GraphTensor> = None;
|
||||
for (i, idx_name) in index_names.iter().enumerate() {
|
||||
let idx_t = self.get_tensor(&idx_name.name)?.cast(DType::Int);
|
||||
let axis_size = indexed_dims[i];
|
||||
let zero = self.graph.constant(0).expand_rhs(idx_t.shape);
|
||||
let is_neg = idx_t.lt(zero).cast(DType::Int);
|
||||
let idx_norm = idx_t + is_neg * axis_size;
|
||||
let stride = strides[i];
|
||||
let weighted = if stride.to_usize() == Some(1) {
|
||||
idx_norm
|
||||
} else {
|
||||
idx_norm * stride
|
||||
};
|
||||
flat_idx = Some(match flat_idx {
|
||||
Some(acc) => {
|
||||
let (a, w) = broadcast_binary(acc, weighted);
|
||||
a + w
|
||||
}
|
||||
None => weighted,
|
||||
});
|
||||
}
|
||||
let flat_idx = flat_idx.context("index.Tensor (prefix): no indices")?;
|
||||
let idx_shape: Vec<Expression> = flat_idx.shape.dims.to_vec();
|
||||
|
||||
// Build the absolute flat index over `source` viewed as 1D, shape
|
||||
// `prefix_dims ++ idx_shape`:
|
||||
// abs[p..., k...] = flat_prefix(p...) * indexed_size + flat_idx[k...]
|
||||
// Construct by promoting `flat_idx` to the full rank then adding a
|
||||
// broadcast prefix-offset tensor.
|
||||
let mut full_shape: Vec<Expression> = prefix_dims.clone();
|
||||
full_shape.extend_from_slice(&idx_shape);
|
||||
|
||||
// Promote flat_idx: insert prefix_dims leading axes, then expand.
|
||||
let mut idx_promoted = flat_idx;
|
||||
for _ in 0..prefix_dims.len() {
|
||||
idx_promoted = idx_promoted.expand_dim(0, Expression::from(1usize));
|
||||
}
|
||||
idx_promoted.shape.expand(full_shape.clone());
|
||||
|
||||
// Prefix offset: for each prefix dim pi (right-to-left), accumulate
|
||||
// arange(prefix_dims[pi]) * (product_of_more_inner_prefix_dims * indexed_size).
|
||||
let mut prefix_offset: Option<GraphTensor> = None;
|
||||
let mut cum_stride = indexed_size;
|
||||
for (pi, pd) in prefix_dims.iter().enumerate().rev() {
|
||||
let ar = self.graph.arange(*pd) * cum_stride;
|
||||
// arange is shape [pd]; lift it into full_shape at position pi.
|
||||
let mut ar_promoted = ar;
|
||||
for _ in 0..pi {
|
||||
ar_promoted = ar_promoted.expand_dim(0, Expression::from(1usize));
|
||||
}
|
||||
let trailing = full_shape.len() - pi - 1;
|
||||
for _ in 0..trailing {
|
||||
let r = ar_promoted.shape.len();
|
||||
ar_promoted = ar_promoted.expand_dim(r, Expression::from(1usize));
|
||||
}
|
||||
ar_promoted.shape.expand(full_shape.clone());
|
||||
prefix_offset = Some(match prefix_offset {
|
||||
Some(acc) => acc + ar_promoted,
|
||||
None => ar_promoted,
|
||||
});
|
||||
cum_stride = cum_stride * *pd;
|
||||
}
|
||||
|
||||
let final_idx = match prefix_offset {
|
||||
Some(po) => idx_promoted + po,
|
||||
None => idx_promoted,
|
||||
};
|
||||
|
||||
// Flatten source to 1D and gather with the absolute index.
|
||||
let total: Expression = src_shape
|
||||
.iter()
|
||||
.copied()
|
||||
.fold(Expression::from(1usize), |a, b| a * b);
|
||||
let fully_flat = reshape_tensor(source, vec![total]);
|
||||
Ok(fully_flat.gather(final_idx))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_gather(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
|
||||
@@ -43,6 +43,19 @@ class CompiledModel:
|
||||
else torch.float32
|
||||
for i in range(len(self._input_names))
|
||||
]
|
||||
# Pre-zip + caches for the hot path. The CudaRuntime now preserves
|
||||
# external-pointer registrations across execute() calls and treats
|
||||
# set_device_ptr as a no-op when the pointer is unchanged — caching
|
||||
# the (name, ptr) here avoids the pyo3 round-trip entirely in tight
|
||||
# loops where PyTorch's caching allocator keeps re-handing back the
|
||||
# same tensor (e.g. inference loops with reused activation buffers).
|
||||
self._input_specs = list(zip(self._input_names, self._input_dtypes))
|
||||
self._last_input_ptrs: dict[str, int] = {}
|
||||
# Output dtype/zero-copy decisions are properties of the compiled
|
||||
# graph and never change; computing them lazily and caching avoids
|
||||
# ~10µs of pyo3 calls per iter.
|
||||
self._output_torch_dtypes_cache = None
|
||||
self._output_zero_copy_cache = None
|
||||
|
||||
def set_dim(self, param_name: str, value: int) -> None:
|
||||
"""Set a dynamic dimension value by its param name."""
|
||||
@@ -89,22 +102,41 @@ class CompiledModel:
|
||||
|
||||
# Set user input data via pointer.
|
||||
# Convert to the graph's expected dtype so bytes match the Input node's dtype tag.
|
||||
# For CUDA inputs already in the expected dtype + contiguous, we
|
||||
# skip the detach/contiguous/to chain (those allocate new Tensor
|
||||
# objects even when they're no-ops) and short-circuit set_input_device_ptr
|
||||
# when the pointer hasn't moved since the last call. The runtime
|
||||
# treats same-ptr re-registration as a no-op too, but skipping the
|
||||
# pyo3 round-trip here saves another ~5µs per input.
|
||||
# For CUDA inputs, keep references alive so the caching allocator doesn't
|
||||
# recycle GPU memory before run() reads the pointers.
|
||||
_input_refs = []
|
||||
for name, tensor, expected_dtype in zip(
|
||||
self._input_names, user_inputs, self._input_dtypes
|
||||
):
|
||||
if self._supports_device_ptrs and tensor.is_cuda:
|
||||
t = tensor.detach().contiguous().to(expected_dtype)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
self._graph.set_input_device_ptr(name, t.data_ptr(), n_bytes)
|
||||
graph = self._graph
|
||||
last_input_ptrs = self._last_input_ptrs
|
||||
if self._supports_device_ptrs:
|
||||
for (name, expected_dtype), tensor in zip(self._input_specs, user_inputs):
|
||||
if (
|
||||
tensor.is_cuda
|
||||
and tensor.dtype is expected_dtype
|
||||
and tensor.is_contiguous()
|
||||
):
|
||||
t = tensor
|
||||
else:
|
||||
t = tensor.detach().contiguous().to(expected_dtype)
|
||||
ptr = t.data_ptr()
|
||||
if last_input_ptrs.get(name) != ptr:
|
||||
graph.set_input_device_ptr(name, ptr, t.numel() * t.element_size())
|
||||
last_input_ptrs[name] = ptr
|
||||
_input_refs.append(t)
|
||||
else:
|
||||
else:
|
||||
for (name, expected_dtype), tensor in zip(self._input_specs, user_inputs):
|
||||
t = tensor.detach().cpu().contiguous().to(expected_dtype)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
dtype_code = _torch_dtype_code(t.dtype)
|
||||
self._graph.set_input_from_ptr(name, t.data_ptr(), n_bytes, dtype_code)
|
||||
graph.set_input_from_ptr(
|
||||
name,
|
||||
t.data_ptr(),
|
||||
t.numel() * t.element_size(),
|
||||
_torch_dtype_code(t.dtype),
|
||||
)
|
||||
|
||||
# Resolve output shapes before run() (needed for pre-allocation).
|
||||
if self._has_dynamic_dims:
|
||||
|
||||
209
crates/luminal_python/tests/test_dlrm.py
Normal file
209
crates/luminal_python/tests/test_dlrm.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""End-to-end compile tests for a faithful DLRM-style recommender.
|
||||
|
||||
`MiniDLRM` below mirrors `DLRM_Net` from facebookresearch/dlrm:
|
||||
bottom-MLP on dense features, an `EmbeddingBag` per sparse table, dot-product
|
||||
interaction over the (1 + n_sparse) feature vectors, and a top-MLP. The
|
||||
forward signature `(dense_x, lS_o, lS_i)` matches DLRM exactly.
|
||||
|
||||
This is the smallest model that exercises the three translator paths added for
|
||||
DLRM:
|
||||
- `aten._embedding_bag_forward_only.default` (uniform-bag-size lowering)
|
||||
- `aten.index.Tensor` with a `None` prefix (`Z[:, li, lj]`)
|
||||
- the existing `aten.bmm` / `aten.cat` paths under the above feeders
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
import torch.nn as nn
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
class MiniDLRM(nn.Module):
    """Small DLRM: bottom MLP, sum-mode EmbeddingBag tables, dot interaction, top MLP.

    Args:
        m_spa: embedding width; also the required output width of the bottom MLP.
        ln_emb: number of rows of each sparse embedding table.
        ln_bot: layer widths of the bottom MLP; the last entry must equal ``m_spa``.
        ln_top: layer widths of the top MLP; the first entry must equal
            ``n_pairs + m_spa`` where ``n_pairs`` counts the unordered pairs
            among the ``1 + len(ln_emb)`` feature vectors.
    """

    def __init__(
        self,
        m_spa: int,
        ln_emb: list[int],
        ln_bot: list[int],
        ln_top: list[int],
    ) -> None:
        super().__init__()
        assert ln_bot[-1] == m_spa, "bottom MLP must end at m_spa"
        n_feat = 1 + len(ln_emb)
        n_pairs = n_feat * (n_feat - 1) // 2
        assert ln_top[0] == n_pairs + m_spa, (
            f"top MLP entry width must equal n_pairs ({n_pairs}) + m_spa ({m_spa}) "
            f"= {n_pairs + m_spa}, got {ln_top[0]}"
        )
        self.m_spa = m_spa
        # Construction order (tables, bottom MLP, top MLP) fixes the order in
        # which parameters draw from the RNG, so it is kept as-is.
        tables = [nn.EmbeddingBag(int(rows), m_spa, mode="sum") for rows in ln_emb]
        self.emb_l = nn.ModuleList(tables)
        self.bot_l = self._build_mlp(ln_bot, sigmoid_last=False)
        self.top_l = self._build_mlp(ln_top, sigmoid_last=True)

    @staticmethod
    def _build_mlp(sizes: list[int], sigmoid_last: bool) -> nn.Sequential:
        """Stack Linear layers for consecutive widths in ``sizes``.

        Every Linear is followed by ReLU, except that the final one is
        followed by Sigmoid when ``sigmoid_last`` is true.
        """
        modules: list[nn.Module] = []
        final_layer = len(sizes) - 2
        for layer_no, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            modules.append(nn.Linear(fan_in, fan_out, bias=True))
            ends_with_sigmoid = sigmoid_last and layer_no == final_layer
            modules.append(nn.Sigmoid() if ends_with_sigmoid else nn.ReLU())
        return nn.Sequential(*modules)

    def _apply_emb(
        self, lS_o: list[torch.Tensor], lS_i: list[torch.Tensor]
    ) -> list[torch.Tensor]:
        """Run table ``k`` on ``(lS_i[k], lS_o[k])`` for every table, in order."""
        return [bag(lS_i[k], lS_o[k]) for k, bag in enumerate(self.emb_l)]

    def _interact(self, x: torch.Tensor, ly: list[torch.Tensor]) -> torch.Tensor:
        """Concatenate ``x`` with the strictly-lower-triangular pairwise dot products.

        ``x`` and every tensor in ``ly`` are stacked into per-sample feature
        rows; the batched Gram matrix of those rows is computed and its
        entries below the diagonal are appended to ``x``.
        """
        batch_size, d = x.shape
        stacked = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
        gram = torch.bmm(stacked, torch.transpose(stacked, 1, 2))
        _, ni, nj = gram.shape
        # Row index i repeated i times, column indices 0..i-1 for each i.
        row_ids: list[int] = []
        for i in range(ni):
            row_ids.extend([i] * i)
        col_ids: list[int] = []
        for i in range(nj):
            col_ids.extend(range(i))
        li = torch.tensor(row_ids, device=x.device)
        lj = torch.tensor(col_ids, device=x.device)
        pairwise = gram[:, li, lj]
        return torch.cat([x, pairwise], dim=1)

    def forward(
        self,
        dense_x: torch.Tensor,
        lS_o: list[torch.Tensor],
        lS_i: list[torch.Tensor],
    ) -> torch.Tensor:
        """Bottom MLP → embedding lookups → dot interaction → top MLP."""
        dense_feat = self.bot_l(dense_x)
        sparse_feats = self._apply_emb(lS_o, lS_i)
        interacted = self._interact(dense_feat, sparse_feats)
        return self.top_l(interacted)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_inputs(
|
||||
batch_size: int,
|
||||
dense_dim: int,
|
||||
ln_emb: list[int],
|
||||
bag_size: int,
|
||||
device: torch.device,
|
||||
) -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
|
||||
dense_x = torch.rand(batch_size, dense_dim, device=device)
|
||||
if bag_size == 1:
|
||||
offsets = [
|
||||
torch.arange(batch_size, dtype=torch.long, device=device)
|
||||
for _ in ln_emb
|
||||
]
|
||||
indices = [
|
||||
torch.randint(0, int(n), (batch_size,), dtype=torch.long, device=device)
|
||||
for n in ln_emb
|
||||
]
|
||||
else:
|
||||
offsets = [
|
||||
torch.arange(
|
||||
0,
|
||||
batch_size * bag_size,
|
||||
bag_size,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
for _ in ln_emb
|
||||
]
|
||||
indices = [
|
||||
torch.randint(
|
||||
0, int(n), (batch_size * bag_size,), dtype=torch.long, device=device
|
||||
)
|
||||
for n in ln_emb
|
||||
]
|
||||
return dense_x, offsets, indices
|
||||
|
||||
|
||||
def _build_model(
    m_spa: int,
    ln_emb: list[int],
    ln_bot: list[int],
    device: torch.device,
) -> MiniDLRM:
    """Create a seeded, eval-mode ``MiniDLRM`` on ``device``.

    The top MLP is fixed to ``[n_pairs + m_spa, 8, 1]``, where ``n_pairs`` is
    the number of unordered pairs among the ``1 + len(ln_emb)`` features.
    """
    torch.manual_seed(0)
    feature_count = len(ln_emb) + 1
    pair_count = feature_count * (feature_count - 1) // 2
    top_widths = [pair_count + m_spa, 8, 1]
    return MiniDLRM(m_spa, ln_emb, ln_bot, top_widths).to(device).eval()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_dlrm_dot_bag1_smallbatch(device: torch.device) -> None:
    """Batch of 2 with one lookup per sample per table; compiled must match eager."""
    emb_dim = 4
    table_rows = [10, 20, 30]
    model = _build_model(emb_dim, table_rows, [13, 8, emb_dim], device)
    args = _make_inputs(
        batch_size=2, dense_dim=13, ln_emb=table_rows, bag_size=1, device=device
    )

    compiled_fn: Callable = torch.compile(model, backend=luminal_backend)
    with torch.no_grad():
        expected = model(*args)
        actual = compiled_fn(*args)
    assert torch.allclose(actual, expected, atol=1e-5)
|
||||
|
||||
|
||||
def test_dlrm_dot_bag1_largerbatch(device: torch.device) -> None:
    """Same single-lookup setup at batch 64, compared with a looser tolerance (1e-4)."""
    emb_dim = 4
    table_rows = [10, 20, 30]
    model = _build_model(emb_dim, table_rows, [13, 8, emb_dim], device)
    args = _make_inputs(
        batch_size=64, dense_dim=13, ln_emb=table_rows, bag_size=1, device=device
    )

    compiled_fn: Callable = torch.compile(model, backend=luminal_backend)
    with torch.no_grad():
        expected = model(*args)
        actual = compiled_fn(*args)
    assert torch.allclose(actual, expected, atol=1e-4)
|
||||
|
||||
|
||||
def test_dlrm_dot_multihot(device: torch.device) -> None:
    """Uniform bags of 3 indices each, so the bag reduction is actually exercised."""
    emb_dim = 4
    table_rows = [10, 20, 30]
    model = _build_model(emb_dim, table_rows, [13, 8, emb_dim], device)
    args = _make_inputs(
        batch_size=4, dense_dim=13, ln_emb=table_rows, bag_size=3, device=device
    )

    compiled_fn: Callable = torch.compile(model, backend=luminal_backend)
    with torch.no_grad():
        expected = model(*args)
        actual = compiled_fn(*args)
    assert torch.allclose(actual, expected, atol=1e-5)
|
||||
|
||||
|
||||
def test_dlrm_dot_larger_tables(device: torch.device) -> None:
    """Single-lookup bags against larger tables (50/100/200 rows)."""
    emb_dim = 4
    table_rows = [50, 100, 200]
    model = _build_model(emb_dim, table_rows, [13, 8, emb_dim], device)
    args = _make_inputs(
        batch_size=4, dense_dim=13, ln_emb=table_rows, bag_size=1, device=device
    )

    compiled_fn: Callable = torch.compile(model, backend=luminal_backend)
    with torch.no_grad():
        expected = model(*args)
        actual = compiled_fn(*args)
    assert torch.allclose(actual, expected, atol=1e-5)
|
||||
17
examples/dlrm/Cargo.toml
Normal file
17
examples/dlrm/Cargo.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[package]
|
||||
name = "dlrm"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[[bin]]
|
||||
name = "dlrm"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
safetensors = "0.7.0"
|
||||
memmap2 = "0.9.9"
|
||||
bytemuck = "1.24.0"
|
||||
rand = "0.9.2"
|
||||
306
examples/dlrm/bench.py
Normal file
306
examples/dlrm/bench.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""DLRM inference latency benchmark across PyTorch backends + luminal.
|
||||
|
||||
Backends measured:
|
||||
1. PyTorch eager
|
||||
2. torch.compile (default backend = inductor, mode="reduce-overhead")
|
||||
3. AOTInductor (export → aoti_compile_and_package → load → run)
|
||||
4. CUDA graphs (capture-replay around the eager model)
|
||||
5. PyTorch + luminal_backend (torch.compile with our PT2 → luminal backend)
|
||||
|
||||
The rust luminal path is measured separately by the dlrm binary's --bench
|
||||
flag and the two results are combined in the rank table later.
|
||||
|
||||
Methodology:
|
||||
- Same MiniDLRM at the small config, batch_size=BATCH (2048 below; must match export.py and
|
||||
the rust binary so the comparison is apples-to-apples).
|
||||
- 50 warmup iters per backend, 500 measured iters.
|
||||
- Per-iteration latency via paired cudaEvent_record + elapsed_time.
|
||||
- Report mean / p50 / p99 in microseconds; also dump every measurement
|
||||
to /tmp/dlrm_bench_<backend>.txt so other consumers can re-aggregate.
|
||||
|
||||
Run:
|
||||
/lambda/nfs/tucker-fs/second/luminal/crates/luminal_python/.venv/bin/python \
|
||||
examples/dlrm/bench.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import statistics
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
# MiniDLRM lives in tests.
|
||||
TESTS_DIR = (
|
||||
Path(__file__).resolve().parents[2] / "crates" / "luminal_python" / "tests"
|
||||
)
|
||||
sys.path.insert(0, str(TESTS_DIR))
|
||||
from test_dlrm import MiniDLRM # noqa: E402
|
||||
|
||||
from luminal import luminal_backend # noqa: E402
|
||||
|
||||
DEVICE = torch.device("cuda")
|
||||
WARMUP = 50
|
||||
ITERS = 500
|
||||
|
||||
M_SPA = 4
|
||||
LN_EMB = [10, 20, 30]
|
||||
LN_BOT = [13, 8, M_SPA]
|
||||
LN_TOP = [10, 8, 1]
|
||||
# Real-workload DLRM batch — kernel work dominates per-launch overhead.
|
||||
BATCH = 2048
|
||||
|
||||
|
||||
def make_model() -> torch.nn.Module:
    """Build the fixed-config MiniDLRM on ``DEVICE`` and switch it to eval mode.

    The global torch RNG is re-seeded with 0 immediately before construction,
    so every call starts from the same RNG state.
    """
    torch.manual_seed(0)
    dlrm = MiniDLRM(M_SPA, LN_EMB, LN_BOT, LN_TOP)
    dlrm = dlrm.to(DEVICE)
    return dlrm.eval()
|
||||
|
||||
|
||||
def make_inputs() -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
    """Create one deterministic input batch on ``DEVICE``.

    Returns:
        ``(dense_x, offsets, indices)`` where ``dense_x`` is a random
        ``(BATCH, LN_BOT[0])`` tensor and, for each entry of ``LN_EMB``,
        ``indices`` holds ``BATCH`` random int64 ids below that entry and
        ``offsets`` holds ``arange(BATCH)``.
    """
    torch.manual_seed(42)
    # The dense features are drawn first, then the per-table ids, so the RNG
    # stream is consumed in a fixed order.
    dense_x = torch.rand(BATCH, LN_BOT[0], device=DEVICE)
    indices = []
    for vocab_size in LN_EMB:
        indices.append(
            torch.randint(0, vocab_size, (BATCH,), dtype=torch.long, device=DEVICE)
        )
    offsets = []
    for _ in LN_EMB:
        offsets.append(torch.arange(BATCH, dtype=torch.long, device=DEVICE))
    return dense_x, offsets, indices
|
||||
|
||||
|
||||
def time_callable(fn: Callable[[], torch.Tensor], iters: int) -> list[float]:
    """Measure ``fn`` ``iters`` times with paired CUDA events.

    Each call is bracketed by its own start/end event; the device is
    synchronized once before the loop and once after it, and only then are the
    elapsed times read back.

    Returns:
        One latency per iteration, in microseconds (the millisecond value
        reported by ``elapsed_time`` multiplied by 1000).
    """
    begins = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    ends = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    torch.cuda.synchronize()
    for begin, end in zip(begins, ends):
        begin.record()
        fn()
        end.record()
    torch.cuda.synchronize()
    latencies_us = []
    for begin, end in zip(begins, ends):
        latencies_us.append(begin.elapsed_time(end) * 1000.0)
    return latencies_us
|
||||
|
||||
|
||||
def report(name: str, samples_us: list[float]) -> dict[str, float]:
    """Summarise one backend's latency samples, print them and dump them to disk.

    Args:
        name: Backend label; also used (with spaces replaced by underscores and
            parentheses removed) to build the dump file name.
        samples_us: Per-iteration latencies in microseconds.

    Returns:
        A dict with keys ``name``, ``mean``, ``p50``, ``p99`` and ``n``.

    Raises:
        ValueError: If ``samples_us`` is empty, since no statistic can be
            computed from zero samples.
    """
    if not samples_us:
        raise ValueError(f"no latency samples recorded for {name!r}")
    samples_us = sorted(samples_us)
    n = len(samples_us)
    mean = sum(samples_us) / n
    # Percentiles are read straight from the sorted list (no interpolation).
    p50 = samples_us[n // 2]
    p99 = samples_us[int(n * 0.99)]
    print(f" {name:<32s} mean={mean:8.2f}µs p50={p50:8.2f}µs p99={p99:8.2f}µs")
    # Dump every sample (in sorted order) for downstream aggregation.
    out_path = f"/tmp/dlrm_bench_{name.replace(' ', '_').replace('(', '').replace(')', '')}.txt"
    Path(out_path).write_text("\n".join(f"{s:.4f}" for s in samples_us))
    return {"name": name, "mean": mean, "p50": p50, "p99": p99, "n": n}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backends
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def bench_eager() -> dict[str, float]:
    """Benchmark the plain eager forward pass and report it as ``"eager"``."""
    dlrm = make_model()
    batch = make_inputs()

    @torch.no_grad()
    def run_once() -> torch.Tensor:
        return dlrm(*batch)

    # Untimed warmup iterations before measuring.
    for _ in range(WARMUP):
        run_once()
    samples = time_callable(run_once, ITERS)
    return report("eager", samples)
|
||||
|
||||
|
||||
def bench_torch_compile() -> dict[str, float]:
    """Benchmark ``torch.compile(model, mode="reduce-overhead")``.

    Dynamo state is reset first so earlier backends' compilations are not
    reused. Compilation is triggered by the untimed warmup calls.
    """
    torch._dynamo.reset()
    dlrm = make_model()
    batch = make_inputs()
    compiled_dlrm = torch.compile(dlrm, mode="reduce-overhead")

    @torch.no_grad()
    def run_once() -> torch.Tensor:
        return compiled_dlrm(*batch)

    for _ in range(WARMUP):
        run_once()
    samples = time_callable(run_once, ITERS)
    return report("torch.compile (inductor)", samples)
|
||||
|
||||
|
||||
def bench_aoti() -> dict[str, float]:
    """AOTInductor: export → compile-and-package → load → run.

    The model takes two lists of tensors; a small wrapper module exposes them
    as seven positional tensor arguments, and that flat signature is what gets
    exported, packaged and later called.
    """
    torch._dynamo.reset()
    model = make_model()
    dense_x, offsets, indices = make_inputs()

    class FlatWrapper(torch.nn.Module):
        """Adapter that regroups positional tensors into the model's list arguments."""

        def __init__(self, m: torch.nn.Module) -> None:
            super().__init__()
            self.m = m

        def forward(
            self,
            dense_x: torch.Tensor,
            o0: torch.Tensor,
            o1: torch.Tensor,
            o2: torch.Tensor,
            i0: torch.Tensor,
            i1: torch.Tensor,
            i2: torch.Tensor,
        ) -> torch.Tensor:
            offsets_list = [o0, o1, o2]
            indices_list = [i0, i1, i2]
            return self.m(dense_x, offsets_list, indices_list)

    flat_model = FlatWrapper(model).to(DEVICE).eval()
    flat_inputs = tuple([dense_x] + list(offsets) + list(indices))

    with torch.no_grad():
        ep = torch.export.export(flat_model, flat_inputs)
        # The package only needs to exist on disk until it has been loaded.
        with tempfile.TemporaryDirectory() as scratch_dir:
            pkg_path = os.path.join(scratch_dir, "dlrm.pt2")
            torch._inductor.aoti_compile_and_package(ep, package_path=pkg_path)
            loaded = torch._inductor.aoti_load_package(pkg_path)

    @torch.no_grad()
    def run_once() -> torch.Tensor:
        return loaded(*flat_inputs)

    for _ in range(WARMUP):
        run_once()
    samples = time_callable(run_once, ITERS)
    return report("AOTInductor", samples)
|
||||
|
||||
|
||||
def bench_cuda_graphs() -> dict[str, float]:
    """Capture the eager forward pass in a CUDA graph and time its replay.

    MiniDLRM's ``_interact`` creates its lower-triangular index tensors with
    ``torch.tensor([...], device=...)`` on every call, i.e. a fresh
    host→device copy that graph capture cannot handle. The indices are
    therefore built once up front as device tensors and ``_interact`` is
    replaced on this model instance with a version that reads them.
    """
    import types

    model = make_model()
    n_feat = 1 + len(LN_EMB)

    # Row/column indices of every strictly-lower-triangular (i, j) pair.
    tri_rows, tri_cols = [], []
    for i in range(n_feat):
        for j in range(i):
            tri_rows.append(i)
            tri_cols.append(j)
    li_const = torch.tensor(tri_rows, device=DEVICE)
    lj_const = torch.tensor(tri_cols, device=DEVICE)

    def _interact_static(self, x: torch.Tensor, ly: list[torch.Tensor]) -> torch.Tensor:
        bs, d = x.shape
        stacked = torch.cat([x] + ly, dim=1).view((bs, -1, d))
        pairwise = torch.bmm(stacked, torch.transpose(stacked, 1, 2))
        picked = pairwise[:, li_const, lj_const]
        return torch.cat([x, picked], dim=1)

    # Bind as a method so `self` resolves to the model instance.
    model._interact = types.MethodType(_interact_static, model)

    dense_x, offsets, indices = make_inputs()
    static_dense = dense_x.clone()
    static_offsets = [t.clone() for t in offsets]
    static_indices = [t.clone() for t in indices]

    @torch.no_grad()
    def fwd() -> torch.Tensor:
        return model(static_dense, static_offsets, static_indices)

    # Warm up on a side stream before capturing.
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        for _ in range(3):
            fwd()
    torch.cuda.current_stream().wait_stream(side_stream)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = fwd()

    @torch.no_grad()
    def replay_once() -> torch.Tensor:
        # Real workloads would copy fresh inputs into static_* here. For pure
        # replay-latency measurement the inputs are constant.
        graph.replay()
        return static_out

    for _ in range(WARMUP):
        replay_once()
    samples = time_callable(replay_once, ITERS)
    return report("CUDA graphs (eager capture)", samples)
|
||||
|
||||
|
||||
def bench_luminal_backend() -> dict[str, float]:
    """Benchmark ``torch.compile`` with ``luminal_backend`` as the backend.

    Dynamo state is reset first; compilation happens during the untimed
    warmup calls.
    """
    torch._dynamo.reset()
    dlrm = make_model()
    batch = make_inputs()
    compiled_dlrm = torch.compile(dlrm, backend=luminal_backend)

    @torch.no_grad()
    def run_once() -> torch.Tensor:
        return compiled_dlrm(*batch)

    for _ in range(WARMUP):
        run_once()
    samples = time_callable(run_once, ITERS)
    return report("luminal_backend (PT2)", samples)
|
||||
|
||||
|
||||
def main() -> None:
    """Run every backend benchmark, merge external rust samples, print a ranking.

    A backend that raises is reported as FAILED and left out of the ranking.
    Sample files written by the rust binary are picked up from /tmp when
    present; files without any usable sample are skipped. If nothing produced
    a result the ranking is skipped instead of crashing.
    """
    torch.cuda.synchronize()
    print(f"Device: {torch.cuda.get_device_name(0)}")
    print(f"PyTorch: {torch.__version__}")
    print(f"Config: m_spa={M_SPA} ln_emb={LN_EMB} batch={BATCH} iters={ITERS}\n")

    rows = []
    for fn in (bench_eager, bench_torch_compile, bench_aoti, bench_cuda_graphs, bench_luminal_backend):
        t0 = time.perf_counter()
        try:
            rows.append(fn())
        except Exception as e:
            print(f" FAILED {fn.__name__}: {type(e).__name__}: {e}")
        print(f" (setup+bench took {time.perf_counter() - t0:.1f}s)\n")

    # Pull in any externally-produced rust samples (rust luminal binary
    # writes both --bench and --mega samples to /tmp).
    for label, path_str in [
        ("rust luminal", "/tmp/dlrm_bench_rust_luminal.txt"),
        ("DLRM megakernel", "/tmp/dlrm_bench_megakernel.txt"),
    ]:
        p = Path(path_str)
        if not p.exists():
            continue
        # Ignore blank / whitespace-only lines so a trailing newline is harmless.
        samples_us = sorted(float(s) for s in p.read_text().splitlines() if s.strip())
        n = len(samples_us)
        if n == 0:
            # An empty file would otherwise divide by zero below.
            print(f" skipping {label}: no samples in {path_str}")
            continue
        rows.append({
            "name": label,
            "mean": sum(samples_us) / n,
            "p50": samples_us[n // 2],
            "p99": samples_us[int(n * 0.99)],
            "n": n,
        })
        print(f" {label:<32s} mean={rows[-1]['mean']:8.2f}µs "
              f"p50={rows[-1]['p50']:8.2f}µs p99={rows[-1]['p99']:8.2f}µs "
              f"(from {path_str})")

    if not rows:
        # Every backend failed and no external samples exist: nothing to rank.
        print("No benchmark results to rank.")
        return

    # Rank by mean latency.
    rows.sort(key=lambda r: r["mean"])
    print("=" * 60)
    print("Ranking (mean latency, lower is better):\n")
    fastest = rows[0]["mean"]
    print(f" {'#':<3}{'backend':<32s}{'mean µs':>10s}{'vs fastest':>14s}")
    for i, r in enumerate(rows):
        ratio = r["mean"] / fastest
        print(f" {i + 1:<3}{r['name']:<32s}{r['mean']:>10.2f}{ratio:>13.2f}x")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
328
examples/dlrm/bench_sweep.py
Normal file
328
examples/dlrm/bench_sweep.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""DLRM latency sweep: batch_size × n_sparse_tables × backend.
|
||||
|
||||
Reuses the per-backend timing primitives from `bench.py` but parameterises
|
||||
the model config so we can see how each backend scales along both DLRM's
|
||||
key axes: batch size (parallelism / kernel utilisation) and number of
|
||||
sparse tables (kernel launch count, host-side dispatch cost).
|
||||
|
||||
For each (batch, n_sparse) cell, runs:
|
||||
- PyTorch eager
|
||||
- torch.compile (mode='reduce-overhead')
|
||||
- AOTInductor
|
||||
- CUDA graphs (eager capture)
|
||||
- luminal_backend (PT2)
|
||||
|
||||
The rust luminal path can't be invoked from python; we skip it here. The
|
||||
single-config bench.py remains the cross-check that includes rust.
|
||||
|
||||
Output is one table per backend with rows = batch, cols = n_sparse, plus a
|
||||
final per-cell "winner" matrix.
|
||||
"""
|
||||
|
||||
import gc
|
||||
import sys
|
||||
import time
|
||||
import os
|
||||
import tempfile
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
TESTS_DIR = (
|
||||
Path(__file__).resolve().parents[2] / "crates" / "luminal_python" / "tests"
|
||||
)
|
||||
sys.path.insert(0, str(TESTS_DIR))
|
||||
from test_dlrm import MiniDLRM # noqa: E402
|
||||
|
||||
from luminal import luminal_backend # noqa: E402
|
||||
|
||||
DEVICE = torch.device("cuda")
|
||||
WARMUP = 25
|
||||
ITERS = 200 # halved vs bench.py to keep sweep wall-clock reasonable
|
||||
M_SPA = 4
|
||||
|
||||
# Sweep grid — real-workload DLRM batches where matmul efficiency is what's
|
||||
# actually being compared (sub-100 batches were launch-overhead dominated and
|
||||
# said more about wrapper cost than backend quality).
|
||||
BATCH_SIZES = [256, 1024, 2048, 4096]
|
||||
N_SPARSE_LIST = [3, 8, 16]
|
||||
|
||||
|
||||
def make_model(n_sparse: int):
    """Build a MiniDLRM with ``n_sparse`` embedding tables on ``DEVICE``.

    Args:
        n_sparse: Number of sparse tables; must lie between 0 and the number
            of predefined vocab sizes (16).

    Returns:
        ``(model, ln_emb)`` — the eval-mode model and the vocab sizes used for
        its tables.

    Raises:
        ValueError: If ``n_sparse`` is outside the supported range. Without
            this check the vocab list would be sliced silently while the top
            MLP input width is still derived from ``n_sparse``.
    """
    # Embedding table vocab sizes: alternate small/medium so the lookups
    # exercise different table widths without making setup time explode.
    base_vocabs = [10, 20, 30, 40, 60, 80, 100, 120, 160, 200, 240, 320, 400, 500, 640, 800]
    if not 0 <= n_sparse <= len(base_vocabs):
        raise ValueError(
            f"n_sparse must be between 0 and {len(base_vocabs)}, got {n_sparse}"
        )
    torch.manual_seed(0)
    ln_emb = base_vocabs[:n_sparse]
    ln_bot = [13, 8, M_SPA]
    # Top MLP input = dense features + one value per feature pair.
    n_feat = 1 + n_sparse
    n_pairs = n_feat * (n_feat - 1) // 2
    ln_top = [n_pairs + M_SPA, 8, 1]
    return MiniDLRM(M_SPA, ln_emb, ln_bot, ln_top).to(DEVICE).eval(), ln_emb
|
||||
|
||||
|
||||
def make_inputs(batch: int, ln_emb: list[int]):
    """Create one deterministic input batch of size ``batch`` on ``DEVICE``.

    Returns ``(dense_x, offsets, indices)``: a random ``(batch, 13)`` dense
    tensor plus, per entry of ``ln_emb``, ``batch`` random int64 ids below that
    entry and an ``arange(batch)`` offsets tensor.
    """
    torch.manual_seed(42)
    # Dense features first, then the per-table ids: fixed RNG consumption order.
    dense_x = torch.rand(batch, 13, device=DEVICE)
    indices = []
    for vocab_size in ln_emb:
        indices.append(
            torch.randint(0, vocab_size, (batch,), dtype=torch.long, device=DEVICE)
        )
    offsets = []
    for _ in ln_emb:
        offsets.append(torch.arange(batch, dtype=torch.long, device=DEVICE))
    return dense_x, offsets, indices
|
||||
|
||||
|
||||
def time_callable(fn, iters: int) -> list[float]:
    """Time ``fn`` over ``iters`` calls with paired CUDA events.

    The device is synchronized before the loop and after it; elapsed times are
    read only after the final synchronize. Returns one latency per call in
    microseconds.
    """
    begins = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    ends = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    torch.cuda.synchronize()
    for begin, end in zip(begins, ends):
        begin.record()
        fn()
        end.record()
    torch.cuda.synchronize()
    latencies_us = []
    for begin, end in zip(begins, ends):
        # elapsed_time reports milliseconds.
        latencies_us.append(begin.elapsed_time(end) * 1000.0)
    return latencies_us
|
||||
|
||||
|
||||
def mean_us(samples: list[float]) -> float:
    """Return the arithmetic mean of ``samples`` (left-to-right sum over count)."""
    total = sum(samples)
    count = len(samples)
    return total / count
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backends
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def bench_eager(model, inputs):
    """Return the mean eager-mode latency (µs) of ``model(*inputs)``."""

    @torch.no_grad()
    def run_once():
        return model(*inputs)

    # Untimed warmup iterations before measuring.
    for _ in range(WARMUP):
        run_once()
    samples = time_callable(run_once, ITERS)
    return mean_us(samples)
|
||||
|
||||
|
||||
def bench_torch_compile(model, inputs):
    """Return the mean latency (µs) of ``torch.compile(model, mode="reduce-overhead")``.

    Dynamo state is reset first; compilation is triggered by the untimed
    warmup calls.
    """
    torch._dynamo.reset()
    compiled_model = torch.compile(model, mode="reduce-overhead")

    @torch.no_grad()
    def run_once():
        return compiled_model(*inputs)

    for _ in range(WARMUP):
        run_once()
    samples = time_callable(run_once, ITERS)
    return mean_us(samples)
|
||||
|
||||
|
||||
def bench_aoti(model, inputs):
    """Return the mean latency (µs) of the AOTInductor-packaged model.

    The model is wrapped so that ``torch.export`` sees one positional tensor
    per argument: the dense input, then all offsets, then all indices.
    """
    torch._dynamo.reset()
    dense_x, offsets, indices = inputs
    n_sparse = len(offsets)

    class FlatWrapper(torch.nn.Module):
        """Adapter that regroups flat positional tensors into the model's list arguments."""

        def __init__(self, m, n_sparse: int):
            super().__init__()
            self.m = m
            self.n_sparse = n_sparse

        def forward(self, *args):
            n = self.n_sparse
            dense_x = args[0]
            rest = args[1:]
            offsets = list(rest[:n])
            indices = list(rest[n : 2 * n])
            return self.m(dense_x, offsets, indices)

    flat_model = FlatWrapper(model, n_sparse).to(DEVICE).eval()
    flat_inputs = tuple([dense_x] + list(offsets) + list(indices))

    with torch.no_grad():
        ep = torch.export.export(flat_model, flat_inputs)
        # The package only needs to exist on disk until it has been loaded.
        with tempfile.TemporaryDirectory() as scratch_dir:
            pkg = os.path.join(scratch_dir, "dlrm.pt2")
            torch._inductor.aoti_compile_and_package(ep, package_path=pkg)
            loaded = torch._inductor.aoti_load_package(pkg)

    @torch.no_grad()
    def run_once():
        return loaded(*flat_inputs)

    for _ in range(WARMUP):
        run_once()
    samples = time_callable(run_once, ITERS)
    return mean_us(samples)
|
||||
|
||||
|
||||
def bench_cuda_graphs(model, inputs):
    """Capture the eager forward as a CUDA graph and return its mean replay latency (µs).

    The interaction's li/lj index tensors are pre-built as device tensors and
    ``_interact`` is temporarily replaced on ``model`` so capture works (same
    trick the single-config bench uses).

    ``model`` is shared with the other backends in the sweep, so the patched
    ``_interact`` is removed again before returning — also when capture or
    timing raises. Otherwise every backend that runs afterwards would be
    measured on the patched model.
    """
    dense_x, offsets, indices = inputs
    n_sparse = len(offsets)
    n_feat = 1 + n_sparse
    li = torch.tensor([i for i in range(n_feat) for _ in range(i)], device=DEVICE)
    lj = torch.tensor([j for i in range(n_feat) for j in range(i)], device=DEVICE)

    def _interact_static(self, x, ly):
        bs, d = x.shape
        T = torch.cat([x] + ly, dim=1).view((bs, -1, d))
        Z = torch.bmm(T, torch.transpose(T, 1, 2))
        Zflat = Z[:, li, lj]
        return torch.cat([x, Zflat], dim=1)

    # Remember whether the instance already carried its own `_interact`
    # attribute so the patch can be undone exactly.
    had_override = "_interact" in vars(model)
    previous_override = vars(model).get("_interact")
    model._interact = types.MethodType(_interact_static, model)
    try:
        static_dense = dense_x.clone()
        static_offsets = [o.clone() for o in offsets]
        static_indices = [i.clone() for i in indices]

        @torch.no_grad()
        def fwd():
            return model(static_dense, static_offsets, static_indices)

        # Warm up on a side stream before capturing.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                fwd()
        torch.cuda.current_stream().wait_stream(s)

        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            _ = fwd()

        @torch.no_grad()
        def fn():
            g.replay()

        for _ in range(WARMUP):
            fn()
        return mean_us(time_callable(fn, ITERS))
    finally:
        # Replay never calls back into Python, so dropping the patch here does
        # not affect the captured graph.
        if had_override:
            model._interact = previous_override
        else:
            del model._interact
|
||||
|
||||
|
||||
def bench_luminal_backend(model, inputs):
    """Return the mean latency (µs) of ``torch.compile(model, backend=luminal_backend)``.

    Dynamo state is reset first; compilation happens during the untimed
    warmup calls.
    """
    torch._dynamo.reset()
    compiled_model = torch.compile(model, backend=luminal_backend)

    @torch.no_grad()
    def run_once():
        return compiled_model(*inputs)

    for _ in range(WARMUP):
        run_once()
    samples = time_callable(run_once, ITERS)
    return mean_us(samples)
|
||||
|
||||
|
||||
BACKENDS = [
|
||||
("eager", bench_eager),
|
||||
("torch.compile", bench_torch_compile),
|
||||
("AOTInductor", bench_aoti),
|
||||
("CUDA graphs", bench_cuda_graphs),
|
||||
("luminal_backend", bench_luminal_backend),
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Driver
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def fmt(v: float) -> str:
    """Format a latency for a table cell; NaN (a missing result) becomes a dash."""
    # NaN is the only float that compares unequal to itself.
    is_missing = v != v
    return " - " if is_missing else f"{v:7.1f}"
|
||||
|
||||
|
||||
def main() -> None:
    """Run the (n_sparse × batch × backend) sweep and print the result tables.

    One model and one input batch are built per (n_sparse, batch) cell and
    handed to every backend. A backend that raises is reported as FAILED and
    shows up as a dash in the tables. Garbage collection, CUDA cache release
    and a dynamo reset run after every backend, whether it succeeded or not.
    """
    print(f"Device: {torch.cuda.get_device_name(0)}")
    print(f"PyTorch: {torch.__version__}")
    print(
        f"Sweep: batch ∈ {BATCH_SIZES}, n_sparse ∈ {N_SPARSE_LIST}, "
        f"backends ∈ {[b[0] for b in BACKENDS]}, iters={ITERS}\n"
    )

    # results[backend_name][(batch, n_sparse)] = mean µs
    results: dict[str, dict[tuple[int, int], float]] = {
        name: {} for name, _ in BACKENDS
    }

    total_cells = len(BATCH_SIZES) * len(N_SPARSE_LIST) * len(BACKENDS)
    cell = 0
    for n_sparse in N_SPARSE_LIST:
        for batch in BATCH_SIZES:
            model, ln_emb = make_model(n_sparse)
            inputs = make_inputs(batch, ln_emb)
            for name, fn in BACKENDS:
                cell += 1
                t0 = time.perf_counter()
                try:
                    mu = fn(model, inputs)
                except Exception as e:
                    # Show only the last line of the message; an exception with
                    # an empty message has no lines at all.
                    msg_lines = str(e).splitlines()
                    last_line = msg_lines[-1][:80] if msg_lines else ""
                    print(
                        f" [{cell:>3}/{total_cells}] "
                        f"bs={batch:>4} n_sparse={n_sparse:>2} {name:<18s} "
                        f"FAILED: {type(e).__name__}: {last_line}"
                    )
                else:
                    results[name][(batch, n_sparse)] = mu
                    print(
                        f" [{cell:>3}/{total_cells}] "
                        f"bs={batch:>4} n_sparse={n_sparse:>2} {name:<18s} "
                        f"mean={mu:>7.1f}µs (took {time.perf_counter() - t0:.1f}s)"
                    )
                finally:
                    # Clean up after failures too, so a failed backend cannot
                    # leave memory or compile state behind for the next one.
                    gc.collect()
                    torch.cuda.empty_cache()
                    torch._dynamo.reset()

    # ---- Print one table per backend -----------------------------------
    print("\n" + "=" * 78)
    print("Latency in µs by backend (rows = batch, cols = n_sparse)")
    for name, _ in BACKENDS:
        print(f"\n {name}:")
        header = " " + "".join(f" n_sp={ns:<4}" for ns in N_SPARSE_LIST)
        print(header)
        for bs in BATCH_SIZES:
            row = f" bs={bs:<4} "
            for ns in N_SPARSE_LIST:
                v = results[name].get((bs, ns), float("nan"))
                row += f" {fmt(v)} "
            print(row)

    # ---- Print "fastest backend per cell" matrix -----------------------
    print("\n" + "=" * 78)
    print("Winner per cell (lowest mean µs):")
    print("\n " + "".join(f" n_sp={ns:<14}" for ns in N_SPARSE_LIST))
    for bs in BATCH_SIZES:
        row = f" bs={bs:<4} "
        for ns in N_SPARSE_LIST:
            options = [
                (name, results[name].get((bs, ns), float("inf"))) for name, _ in BACKENDS
            ]
            # Drop backends with no (or NaN) result for this cell.
            options = [(n, v) for n, v in options if v == v and v != float("inf")]
            if not options:
                row += " - "
                continue
            winner = min(options, key=lambda x: x[1])
            row += f" {winner[0]:<13s} {winner[1]:>6.1f}"
        print(row)

    # ---- luminal_backend vs eager: scaling story -----------------------
    print("\n" + "=" * 78)
    print("luminal_backend / eager (lower than 1.0 = luminal wins this cell):")
    print("\n " + "".join(f" n_sp={ns:<4}" for ns in N_SPARSE_LIST))
    for bs in BATCH_SIZES:
        row = f" bs={bs:<4} "
        for ns in N_SPARSE_LIST:
            le = results.get("luminal_backend", {}).get((bs, ns), float("nan"))
            eg = results.get("eager", {}).get((bs, ns), float("nan"))
            if eg != eg or le != le or eg == 0:
                row += " - "
            else:
                row += f" {le / eg:>5.2f}x"
        print(row)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
89
examples/dlrm/export.py
Normal file
89
examples/dlrm/export.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Three-way DLRM equivalence harness.
|
||||
|
||||
Builds the MiniDLRM at the fixed config used by examples/dlrm/src/main.rs,
|
||||
serializes weights + sample inputs + the PyTorch eager output to safetensors
|
||||
files that the rust binary loads. Also runs the PyTorch + luminal_backend
|
||||
path so the comparison happens in one place.
|
||||
|
||||
Saves:
|
||||
/tmp/dlrm_weights.safetensors — state_dict with PyTorch names
|
||||
/tmp/dlrm_inputs.safetensors — dense_x, indices_{0..2}, and `expected`
|
||||
(the PyTorch eager output, fp32)
|
||||
|
||||
Then run:
|
||||
cargo run --release --manifest-path examples/dlrm/Cargo.toml
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
# Import MiniDLRM from the test file we already authored.
|
||||
TESTS_DIR = Path(__file__).resolve().parents[2] / "crates" / "luminal_python" / "tests"
|
||||
sys.path.insert(0, str(TESTS_DIR))
|
||||
from test_dlrm import MiniDLRM # noqa: E402
|
||||
|
||||
# Backend (and venv) shared with the test runner.
|
||||
from luminal import luminal_backend # noqa: E402
|
||||
|
||||
M_SPA = 4
|
||||
LN_EMB = [10, 20, 30]
|
||||
LN_BOT = [13, 8, M_SPA]
|
||||
LN_TOP = [10, 8, 1]
|
||||
# Match the rust binary's BATCH_SIZE — real-workload DLRM batch where
|
||||
# compute-bound matmul efficiency is what's being measured.
|
||||
BATCH = 2048
|
||||
DEVICE = torch.device("cuda")
|
||||
|
||||
|
||||
def main() -> None:
    """Export MiniDLRM weights, sample inputs and the eager output for the rust binary.

    The eager output is first compared against the ``luminal_backend``
    compiled model and the run aborts if they differ by 1e-5 or more. Weights
    go to /tmp/dlrm_weights.safetensors; the dense input, the int32 indices
    and the eager output (``expected``) go to /tmp/dlrm_inputs.safetensors.
    """
    torch.manual_seed(0)

    model = MiniDLRM(M_SPA, LN_EMB, LN_BOT, LN_TOP).to(DEVICE).eval()

    dense_x = torch.rand(BATCH, LN_BOT[0], device=DEVICE)
    indices = []
    for vocab_size in LN_EMB:
        indices.append(
            torch.randint(0, vocab_size, (BATCH,), dtype=torch.long, device=DEVICE)
        )
    offsets = []
    for _ in LN_EMB:
        offsets.append(torch.arange(BATCH, dtype=torch.long, device=DEVICE))

    with torch.no_grad():
        eager_out = model(dense_x, offsets, indices)

    compiled = torch.compile(model, backend=luminal_backend)
    with torch.no_grad():
        luminal_out = compiled(dense_x, offsets, indices)

    abs_diff = (luminal_out - eager_out).abs()
    max_diff_lum = abs_diff.max().item()
    print(f"PyTorch eager output : {eager_out.flatten().tolist()}")
    print(f"PyTorch + luminal : {luminal_out.flatten().tolist()}")
    print(f" max |diff| eager vs luminal_backend : {max_diff_lum:.3e}")
    assert max_diff_lum < 1e-5, "PT eager and luminal_backend disagree"

    # Save weights — state_dict names already match what rust uses.
    weights = {}
    for key, tensor in model.state_dict().items():
        weights[key] = tensor.detach().cpu()
    save_file(weights, "/tmp/dlrm_weights.safetensors")
    print(f" wrote /tmp/dlrm_weights.safetensors ({len(weights)} tensors)")

    inputs = {}
    inputs["dense_x"] = dense_x.detach().cpu().contiguous()
    inputs["expected"] = eager_out.detach().cpu().contiguous()
    for table, ids in enumerate(indices):
        # Rust reads i32 indices.
        inputs[f"indices_{table}"] = ids.detach().cpu().to(torch.int32).contiguous()
    save_file(inputs, "/tmp/dlrm_inputs.safetensors")
    print(f" wrote /tmp/dlrm_inputs.safetensors ({len(inputs)} tensors)")

    print(
        "\nNext: cargo run --release --manifest-path examples/dlrm/Cargo.toml --bin dlrm"
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
637
examples/dlrm/src/main.rs
Normal file
637
examples/dlrm/src/main.rs
Normal file
@@ -0,0 +1,637 @@
|
||||
//! Pure-rust DLRM mirroring `MiniDLRM` from
|
||||
//! `crates/luminal_python/tests/test_dlrm.py`.
|
||||
//!
|
||||
//! Loads weights + sample inputs + expected output produced by
|
||||
//! `examples/dlrm/export.py`, runs the same compute graph through luminal's
|
||||
//! CUDA runtime, and prints max-abs diff vs the saved PyTorch eager output.
|
||||
//!
|
||||
//! Topology (fixed for now — same as MiniDLRM at the small-config we test):
|
||||
//! m_spa = 4
|
||||
//! ln_emb = [10, 20, 30] (3 sparse tables)
|
||||
//! ln_bot = [13, 8, 4] (Linear-ReLU-Linear-ReLU)
|
||||
//! ln_top = [10, 8, 1] (Linear-ReLU-Linear-Sigmoid)
|
||||
//!   batch_size = 2048, bag_size = 1
|
||||
//!
|
||||
//! Weight name convention matches the PyTorch state_dict (so
|
||||
//! `runtime.load_safetensors` matches by name with no remapping):
|
||||
//! emb_l.{k}.weight (V_k, m_spa)
|
||||
//! bot_l.{0,2}.{weight,bias} Linear in_features → out_features
|
||||
//! top_l.{0,2}.{weight,bias} same
|
||||
//! PyTorch stores Linear weight as (out, in); we permute when matmul'ing.
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::kernel::{DlrmMegaCustom, DlrmMegaKernel};
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use luminal_nn::gather_rows;
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::SafeTensors;
|
||||
|
||||
const M_SPA: usize = 4;
|
||||
const LN_EMB: [usize; 3] = [10, 20, 30];
|
||||
const LN_BOT: [usize; 3] = [13, 8, M_SPA];
|
||||
const LN_TOP: [usize; 3] = [10, 8, 1];
|
||||
// Real-workload DLRM batch — large enough that kernel work dominates the
|
||||
// per-launch overhead and the compute-bound performance is what's measured.
|
||||
const BATCH_SIZE: usize = 2048;
|
||||
|
||||
/// Linear with bias whose weight matches PyTorch's `nn.Linear` storage:
|
||||
/// shape `(out, in)`. Forward computes `input @ weight.T + bias`.
|
||||
struct Linear {
|
||||
weight: GraphTensor, // (out_features, in_features)
|
||||
bias: GraphTensor, // (out_features,)
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
fn new(cx: &mut Graph, prefix: &str, in_features: usize, out_features: usize) -> Self {
|
||||
Self {
|
||||
weight: cx
|
||||
.named_tensor(format!("{prefix}.weight"), (out_features, in_features))
|
||||
.persist(),
|
||||
bias: cx
|
||||
.named_tensor(format!("{prefix}.bias"), out_features)
|
||||
.persist(),
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, input: GraphTensor) -> GraphTensor {
|
||||
let out_features = self.weight.shape.dims[0];
|
||||
let mm = input.matmul(self.weight.permute((1, 0)));
|
||||
// Broadcast bias (out,) → output shape (..., out).
|
||||
let bias_b = self.bias.expand_dim(0, mm.shape.dims[0]);
|
||||
// bias_b shape: (B, out) — matches `mm` for 2-D input.
|
||||
let _ = out_features;
|
||||
mm + bias_b
|
||||
}
|
||||
}
|
||||
|
||||
// We use luminal's primitive .relu() / .sigmoid() rather than hand-rolling
|
||||
// them out of maximum / exp / reciprocal so the HLIR generated here matches
|
||||
// what the PT2 translator emits for `aten.relu.default` / `aten.sigmoid.default`
|
||||
// op-for-op. See dispatch.rs: both route through these same primitives.
|
||||
|
||||
/// Bottom MLP: Linear → ReLU → Linear → ReLU.
fn bot_forward(layers: &[Linear; 2], x: GraphTensor) -> GraphTensor {
    let [first, second] = layers;
    let hidden = first.forward(x).relu();
    second.forward(hidden).relu()
}
|
||||
|
||||
/// Top MLP: Linear → ReLU → Linear → Sigmoid.
fn top_forward(layers: &[Linear; 2], x: GraphTensor) -> GraphTensor {
    let [first, second] = layers;
    let hidden = first.forward(x).relu();
    second.forward(hidden).sigmoid()
}
|
||||
|
||||
/// Dot interaction: cat(dense, sparse...) → reshape → bmm → flat-tri-upper indexing.
|
||||
/// Matches MiniDLRM._interact in the python test.
|
||||
fn interact_features(
|
||||
cx: &mut Graph,
|
||||
dense: GraphTensor,
|
||||
sparse: &[GraphTensor],
|
||||
) -> GraphTensor {
|
||||
let batch = dense.shape.dims[0];
|
||||
let d = dense.shape.dims[1];
|
||||
let n_feat = 1 + sparse.len();
|
||||
|
||||
// T = cat([dense, *sparse], dim=1).view(B, n_feat, d)
|
||||
let mut t = dense;
|
||||
for s in sparse {
|
||||
t = t.concat_along(*s, 1);
|
||||
}
|
||||
// Reshape (B, n_feat * d) → (B, n_feat, d). concat_along leaves a contiguous
|
||||
// tensor so a fresh ShapeTracker is safe.
|
||||
let bagged = GraphTensor {
|
||||
id: t.id,
|
||||
graph_ref: t.graph_ref,
|
||||
shape: ShapeTracker::new((batch, Expression::from(n_feat), d)),
|
||||
dtype: t.dtype,
|
||||
};
|
||||
|
||||
// Z = bmm(T, T.transpose(1, 2)) → (B, n_feat, n_feat)
|
||||
let z = bagged.matmul(bagged.permute((0, 2, 1)));
|
||||
|
||||
// Strictly-lower-triangular indices into (n_feat, n_feat). For n_feat=4
|
||||
// these are 6 (i,j) pairs: (1,0),(2,0),(2,1),(3,0),(3,1),(3,2).
|
||||
let mut li = Vec::new();
|
||||
let mut lj = Vec::new();
|
||||
for i in 0..n_feat {
|
||||
for j in 0..i {
|
||||
li.push(i as i32);
|
||||
lj.push(j as i32);
|
||||
}
|
||||
}
|
||||
let n_pairs = li.len();
|
||||
|
||||
// Build flat_idx_per_pair[k] = li[k] * n_feat + lj[k] (constant across batch).
|
||||
let mut flat_idx_per_pair = Vec::with_capacity(n_pairs);
|
||||
for k in 0..n_pairs {
|
||||
flat_idx_per_pair.push(li[k] * n_feat as i32 + lj[k]);
|
||||
}
|
||||
|
||||
// Absolute flat index into Z viewed as 1D for each (b, k):
|
||||
// abs[b, k] = b * (n_feat*n_feat) + flat_idx_per_pair[k]
|
||||
let row_stride = n_feat * n_feat; // entries per batch in Z
|
||||
let arange_b = cx.arange(batch); // (B,) ints, values 0..B
|
||||
let abs_idx = arange_b.expand_dim(1, Expression::from(n_pairs))
|
||||
* Expression::from(row_stride);
|
||||
// pair_idx_const: (n_pairs,) ints, captured as a graph input we set once.
|
||||
let pair_idx = cx
|
||||
.named_tensor("__dot_pair_idx", n_pairs)
|
||||
.as_dtype(DType::Int)
|
||||
.persist();
|
||||
let abs_idx = abs_idx + pair_idx.expand_dim(0, batch);
|
||||
|
||||
// Gather Z as 1D.
|
||||
let z_flat = GraphTensor {
|
||||
id: z.id,
|
||||
graph_ref: z.graph_ref,
|
||||
shape: ShapeTracker::new(batch * row_stride),
|
||||
dtype: z.dtype,
|
||||
};
|
||||
let zflat_indexed = z_flat.gather(abs_idx); // (B, n_pairs)
|
||||
|
||||
// R = cat(dense, zflat_indexed, dim=1) → (B, d + n_pairs)
|
||||
dense.concat_along(zflat_indexed, 1)
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// Parse args: optional --bench / --stats / --mega, then positional paths.
|
||||
let mut bench_mode = false;
|
||||
let mut stats_mode = false;
|
||||
let mut mega_mode = false;
|
||||
let mut positional: Vec<String> = Vec::new();
|
||||
for arg in std::env::args().skip(1) {
|
||||
if arg == "--bench" {
|
||||
bench_mode = true;
|
||||
} else if arg == "--stats" {
|
||||
stats_mode = true;
|
||||
} else if arg == "--mega" {
|
||||
mega_mode = true;
|
||||
} else {
|
||||
positional.push(arg);
|
||||
}
|
||||
}
|
||||
let weights_path = positional
|
||||
.first()
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "/tmp/dlrm_weights.safetensors".to_string());
|
||||
let inputs_path = positional
|
||||
.get(1)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "/tmp/dlrm_inputs.safetensors".to_string());
|
||||
|
||||
if mega_mode {
|
||||
run_megakernel(&weights_path, &inputs_path, bench_mode);
|
||||
return;
|
||||
}
|
||||
|
||||
assert!(
|
||||
Path::new(&weights_path).exists(),
|
||||
"weights not found: {weights_path}. Run examples/dlrm/export.py first."
|
||||
);
|
||||
assert!(
|
||||
Path::new(&inputs_path).exists(),
|
||||
"inputs not found: {inputs_path}. Run examples/dlrm/export.py first."
|
||||
);
|
||||
|
||||
// ---- Build graph -----------------------------------------------------
|
||||
let mut cx = Graph::default();
|
||||
|
||||
let dense_in = cx
|
||||
.named_tensor("dense_x", (BATCH_SIZE, LN_BOT[0]));
|
||||
let idx_tensors: Vec<GraphTensor> = (0..LN_EMB.len())
|
||||
.map(|k| {
|
||||
cx.named_tensor(format!("indices_{k}"), BATCH_SIZE)
|
||||
.as_dtype(DType::Int)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Embedding tables (bag_size=1 → just row gather).
|
||||
let emb_weights: Vec<GraphTensor> = (0..LN_EMB.len())
|
||||
.map(|k| {
|
||||
cx.named_tensor(format!("emb_l.{k}.weight"), (LN_EMB[k], M_SPA))
|
||||
.persist()
|
||||
})
|
||||
.collect();
|
||||
let sparse_feats: Vec<GraphTensor> = (0..LN_EMB.len())
|
||||
.map(|k| gather_rows(emb_weights[k], idx_tensors[k], M_SPA))
|
||||
.collect();
|
||||
|
||||
// Bottom MLP: Linear 13→8, ReLU, Linear 8→4, ReLU.
|
||||
let bot = [
|
||||
Linear::new(&mut cx, "bot_l.0", LN_BOT[0], LN_BOT[1]),
|
||||
Linear::new(&mut cx, "bot_l.2", LN_BOT[1], LN_BOT[2]),
|
||||
];
|
||||
let dense_out = bot_forward(&bot, dense_in);
|
||||
|
||||
// Dot interaction → (B, n_pairs + m_spa) = (B, 10) for our config.
|
||||
let interacted = interact_features(&mut cx, dense_out, &sparse_feats);
|
||||
|
||||
// Top MLP: Linear 10→8, ReLU, Linear 8→1, Sigmoid.
|
||||
let top = [
|
||||
Linear::new(&mut cx, "top_l.0", LN_TOP[0], LN_TOP[1]),
|
||||
Linear::new(&mut cx, "top_l.2", LN_TOP[1], LN_TOP[2]),
|
||||
];
|
||||
let out = top_forward(&top, interacted).output();
|
||||
|
||||
// ---- Compile + load weights ------------------------------------------
|
||||
let ctx = CudaContext::new(0).expect("Failed to open CUDA device 0");
|
||||
let stream = ctx.default_stream();
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
runtime.load_safetensors(&cx, &weights_path);
|
||||
|
||||
// Set the strictly-lower-triangular pair index constant.
|
||||
let n_feat = 1 + LN_EMB.len();
|
||||
let mut pair_idx_vals = Vec::new();
|
||||
for i in 0..n_feat {
|
||||
for j in 0..i {
|
||||
pair_idx_vals.push((i * n_feat + j) as i32);
|
||||
}
|
||||
}
|
||||
// Find the named input by walking the graph.
|
||||
let pair_idx_id = find_named_input(&cx, "__dot_pair_idx")
|
||||
.expect("pair_idx tensor not found in graph");
|
||||
runtime.set_data(pair_idx_id, pair_idx_vals);
|
||||
|
||||
// Load inputs + expected output from safetensors.
|
||||
let inputs_mmap = unsafe {
|
||||
MmapOptions::new()
|
||||
.map(&std::fs::File::open(&inputs_path).unwrap())
|
||||
.unwrap()
|
||||
};
|
||||
let inputs_st = SafeTensors::deserialize(&inputs_mmap).unwrap();
|
||||
|
||||
let dense_x: Vec<f32> = bytemuck::cast_slice(inputs_st.tensor("dense_x").unwrap().data()).to_vec();
|
||||
runtime.set_data(dense_in, dense_x);
|
||||
for (k, idx_t) in idx_tensors.iter().enumerate() {
|
||||
let ix: Vec<i32> = bytemuck::cast_slice(
|
||||
inputs_st.tensor(&format!("indices_{k}")).unwrap().data(),
|
||||
)
|
||||
.to_vec();
|
||||
runtime.set_data(*idx_t, ix);
|
||||
}
|
||||
|
||||
// ---- Search (small budget — graph is tiny) ---------------------------
|
||||
use rand::SeedableRng;
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
|
||||
runtime = cx.search_options(runtime, SearchOptions::new(8).trials(1), &mut rng);
|
||||
|
||||
// ---- Execute and compare ---------------------------------------------
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let result = runtime.get_f32(out);
|
||||
|
||||
let expected_bytes = inputs_st.tensor("expected").unwrap().data();
|
||||
let expected: &[f32] = bytemuck::cast_slice(expected_bytes);
|
||||
|
||||
println!("rust output : {result:?}");
|
||||
println!("expected : {expected:?}");
|
||||
|
||||
let max_diff = result
|
||||
.iter()
|
||||
.zip(expected.iter())
|
||||
.map(|(a, b)| (a - b).abs())
|
||||
.fold(0.0_f32, f32::max);
|
||||
println!("max |diff| : {max_diff:.3e}");
|
||||
|
||||
assert!(
|
||||
max_diff < 1e-4,
|
||||
"rust output diverges from PyTorch eager (max_diff={max_diff})"
|
||||
);
|
||||
println!("OK — rust luminal matches PyTorch eager within 1e-4.");
|
||||
|
||||
if stats_mode {
|
||||
let host_ops = runtime.host_ops();
|
||||
println!("\n=== Active bucket host-op inventory ({} ops) ===", host_ops.len());
|
||||
let mut by_type: std::collections::BTreeMap<String, usize> =
|
||||
std::collections::BTreeMap::new();
|
||||
for op in &host_ops {
|
||||
let s = format!("{op:?}");
|
||||
let head = s.split_whitespace().next().unwrap_or(&s).to_string();
|
||||
*by_type.entry(head).or_insert(0) += 1;
|
||||
}
|
||||
for (k, v) in &by_type {
|
||||
println!(" {v:>3} {k}");
|
||||
}
|
||||
// Per-op detail: extract the cuBLASLt epilogue + shape signature so
|
||||
// we can see at a glance whether bias/relu fusion fired (the egglog
|
||||
// rewrites map matmul+add+maximum_f32(0) -> EPILOGUE_RELU_BIAS).
|
||||
println!("\n=== cuBLASLt op detail ===");
|
||||
for op in &host_ops {
|
||||
let s = format!("{op:?}");
|
||||
if !s.starts_with("CuBlasLt") {
|
||||
continue;
|
||||
}
|
||||
let epilogue = extract_field(&s, "epilogue:");
|
||||
let shape = (extract_field(&s, "m:"), extract_field(&s, "n:"), extract_field(&s, "k:"));
|
||||
println!(
|
||||
" m={:<8} n={:<8} k={:<8} epilogue={}",
|
||||
shape.0, shape.1, shape.2, epilogue
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if bench_mode {
|
||||
// Cache input vectors so the bench loop can re-set_data each iter (the
|
||||
// PyTorch backends do an equivalent staging step under the hood).
|
||||
let dense_vec: Vec<f32> =
|
||||
bytemuck::cast_slice(inputs_st.tensor("dense_x").unwrap().data()).to_vec();
|
||||
let idx_vecs: Vec<Vec<i32>> = (0..idx_tensors.len())
|
||||
.map(|k| {
|
||||
bytemuck::cast_slice(
|
||||
inputs_st.tensor(&format!("indices_{k}")).unwrap().data(),
|
||||
)
|
||||
.to_vec()
|
||||
})
|
||||
.collect();
|
||||
bench_rust(
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
out,
|
||||
dense_in,
|
||||
&idx_tensors,
|
||||
dense_vec,
|
||||
idx_vecs,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Benchmark the regular (non-mega) luminal graph.
///
/// Thin wrapper that forwards every argument unchanged to
/// `bench_through_luminal`, fixing the samples file to
/// `/tmp/dlrm_bench_rust_luminal.txt` and the printed label to
/// `[bench] rust luminal`.
///
/// Note that the shared loop re-sets all inputs, calls `execute`, and reads the
/// output back on every iteration, timing each iteration individually; it does
/// not replay with inputs left in place or time the whole run with one final
/// synchronization.
fn bench_rust(
    cx: &mut Graph,
    runtime: &mut CudaRuntime,
    out: GraphTensor,
    dense_in: GraphTensor,
    idx_tensors: &[GraphTensor],
    dense_vec: Vec<f32>,
    idx_vecs: Vec<Vec<i32>>,
) {
    bench_through_luminal(
        cx,
        runtime,
        out,
        dense_in,
        idx_tensors,
        dense_vec,
        idx_vecs,
        "/tmp/dlrm_bench_rust_luminal.txt",
        "[bench] rust luminal",
    );
}
|
||||
|
||||
/// Shared steady-state bench for any luminal graph + runtime. Re-sets
|
||||
/// inputs every iter, calls `execute`, then `get_f32` to force a stream
|
||||
/// sync. Dumps per-iter µs samples to `samples_path` for
|
||||
/// `examples/dlrm/bench.py` to merge into its ranking.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn bench_through_luminal(
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
out: GraphTensor,
|
||||
dense_in: GraphTensor,
|
||||
idx_tensors: &[GraphTensor],
|
||||
dense_vec: Vec<f32>,
|
||||
idx_vecs: Vec<Vec<i32>>,
|
||||
samples_path: &str,
|
||||
label: &str,
|
||||
) {
|
||||
const WARMUP: usize = 50;
|
||||
const ITERS: usize = 500;
|
||||
use std::time::Instant;
|
||||
|
||||
let bench_once = |runtime: &mut CudaRuntime| {
|
||||
runtime.set_data(dense_in, dense_vec.clone());
|
||||
for (k, t) in idx_tensors.iter().enumerate() {
|
||||
runtime.set_data(*t, idx_vecs[k].clone());
|
||||
}
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let _ = runtime.get_f32(out);
|
||||
};
|
||||
|
||||
for _ in 0..WARMUP {
|
||||
bench_once(runtime);
|
||||
}
|
||||
|
||||
let mut samples = Vec::with_capacity(ITERS);
|
||||
for _ in 0..ITERS {
|
||||
let t0 = Instant::now();
|
||||
bench_once(runtime);
|
||||
samples.push(t0.elapsed().as_secs_f64() * 1e6);
|
||||
}
|
||||
samples.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
let mean = samples.iter().sum::<f64>() / ITERS as f64;
|
||||
let p50 = samples[ITERS / 2];
|
||||
let p99 = samples[(ITERS as f64 * 0.99) as usize];
|
||||
println!(
|
||||
"\n{label}: mean={mean:.2}µs p50={p50:.2}µs p99={p99:.2}µs (n={ITERS})"
|
||||
);
|
||||
|
||||
let body = samples
|
||||
.iter()
|
||||
.map(|s| format!("{s:.4}"))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
std::fs::write(samples_path, body).expect("write bench samples");
|
||||
println!(" per-iter samples -> {samples_path}");
|
||||
}
|
||||
|
||||
/// `--mega`: build a luminal Graph whose entire forward is a single
|
||||
/// [`DlrmMegaCustom`] op, then run it through the standard
|
||||
/// `CudaRuntime` flow (load_safetensors → search → execute → get_f32).
|
||||
/// Verifies bitwise vs the saved PyTorch eager output, optionally
|
||||
/// benches steady-state per-call latency through the same `bench_rust`
|
||||
/// path the non-mega rust binary uses.
|
||||
///
|
||||
/// The point: same kernel as the PT2-backend fast path (the parameterized
|
||||
/// `DlrmMegaKernel` in `luminal_cuda_lite::kernel::dlrm_megakernel`),
|
||||
/// just constructed by hand instead of via the translator's pattern
|
||||
/// matcher. Everything past the `cx.custom_op` call — buffer
|
||||
/// management, weight loading, input registration, kernel dispatch,
|
||||
/// output retrieval — is luminal's runtime.
|
||||
fn run_megakernel(weights_path: &str, inputs_path: &str, bench: bool) {
|
||||
assert!(
|
||||
Path::new(weights_path).exists(),
|
||||
"weights not found: {weights_path}. Run examples/dlrm/export.py first."
|
||||
);
|
||||
assert!(
|
||||
Path::new(inputs_path).exists(),
|
||||
"inputs not found: {inputs_path}. Run examples/dlrm/export.py first."
|
||||
);
|
||||
|
||||
let mut cx = Graph::default();
|
||||
|
||||
// ---- User inputs -----------------------------------------------------
|
||||
let dense_in = cx.named_tensor("dense_x", (BATCH_SIZE, LN_BOT[0]));
|
||||
let idx_tensors: Vec<GraphTensor> = (0..LN_EMB.len())
|
||||
.map(|k| {
|
||||
cx.named_tensor(format!("indices_{k}"), BATCH_SIZE)
|
||||
.as_dtype(DType::Int)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// ---- Weights — names must match safetensors keys so the runtime's
|
||||
// load_safetensors matches by Input label.
|
||||
let emb_weights: Vec<GraphTensor> = (0..LN_EMB.len())
|
||||
.map(|k| {
|
||||
cx.named_tensor(format!("emb_l.{k}.weight"), (LN_EMB[k], M_SPA))
|
||||
.persist()
|
||||
})
|
||||
.collect();
|
||||
// PyTorch's nn.Linear stores weight as (out_features, in_features).
|
||||
let bot_l0_w = cx
|
||||
.named_tensor("bot_l.0.weight", (LN_BOT[1], LN_BOT[0]))
|
||||
.persist();
|
||||
let bot_l0_b = cx.named_tensor("bot_l.0.bias", LN_BOT[1]).persist();
|
||||
let bot_l1_w = cx
|
||||
.named_tensor("bot_l.2.weight", (LN_BOT[2], LN_BOT[1]))
|
||||
.persist();
|
||||
let bot_l1_b = cx.named_tensor("bot_l.2.bias", LN_BOT[2]).persist();
|
||||
let top_l0_w = cx
|
||||
.named_tensor("top_l.0.weight", (LN_TOP[1], LN_TOP[0]))
|
||||
.persist();
|
||||
let top_l0_b = cx.named_tensor("top_l.0.bias", LN_TOP[1]).persist();
|
||||
let top_l1_w = cx
|
||||
.named_tensor("top_l.2.weight", (LN_TOP[2], LN_TOP[1]))
|
||||
.persist();
|
||||
let top_l1_b = cx.named_tensor("top_l.2.bias", LN_TOP[2]).persist();
|
||||
|
||||
// ---- One CustomOp does the whole forward ----------------------------
|
||||
// Input order MUST match what DlrmMegaKernel's CUDA source expects:
|
||||
// dense, indices..., emb_weights..., bot Linears (w then b each),
|
||||
// top Linears (w then b each). See `kernel::dlrm_megakernel`.
|
||||
let mut inputs: Vec<GraphTensor> = vec![dense_in];
|
||||
inputs.extend(idx_tensors.iter().copied());
|
||||
inputs.extend(emb_weights.iter().copied());
|
||||
inputs.extend([
|
||||
bot_l0_w, bot_l0_b, bot_l1_w, bot_l1_b, top_l0_w, top_l0_b, top_l1_w, top_l1_b,
|
||||
]);
|
||||
|
||||
let kernel = DlrmMegaKernel {
|
||||
batch: BATCH_SIZE,
|
||||
n_dense_in: LN_BOT[0],
|
||||
ln_bot: LN_BOT.to_vec(),
|
||||
n_sparse: LN_EMB.len(),
|
||||
vocab_sizes: LN_EMB.to_vec(),
|
||||
m_spa: M_SPA,
|
||||
ln_top: LN_TOP.to_vec(),
|
||||
};
|
||||
let out = cx
|
||||
.custom_op(
|
||||
DlrmMegaCustom(kernel),
|
||||
inputs,
|
||||
(BATCH_SIZE, 1usize),
|
||||
DType::F32,
|
||||
)
|
||||
.output();
|
||||
|
||||
// ---- Compile + load weights — same path as the non-mega flow -------
|
||||
let ctx = CudaContext::new(0).expect("Failed to open CUDA device 0");
|
||||
let stream = ctx.default_stream();
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
runtime.load_safetensors(&cx, weights_path);
|
||||
|
||||
// ---- Inputs ---------------------------------------------------------
|
||||
let inputs_mmap = unsafe {
|
||||
MmapOptions::new()
|
||||
.map(&std::fs::File::open(inputs_path).unwrap())
|
||||
.unwrap()
|
||||
};
|
||||
let inputs_st = SafeTensors::deserialize(&inputs_mmap).unwrap();
|
||||
let dense_vec: Vec<f32> =
|
||||
bytemuck::cast_slice(inputs_st.tensor("dense_x").unwrap().data()).to_vec();
|
||||
runtime.set_data(dense_in, dense_vec.clone());
|
||||
let idx_vecs: Vec<Vec<i32>> = (0..idx_tensors.len())
|
||||
.map(|k| {
|
||||
bytemuck::cast_slice(
|
||||
inputs_st.tensor(&format!("indices_{k}")).unwrap().data(),
|
||||
)
|
||||
.to_vec()
|
||||
})
|
||||
.collect();
|
||||
for (k, idx_t) in idx_tensors.iter().enumerate() {
|
||||
runtime.set_data(*idx_t, idx_vecs[k].clone());
|
||||
}
|
||||
|
||||
// ---- Search ---------------------------------------------------------
|
||||
// Single-CustomOp graph: nothing to search over. One trial.
|
||||
use rand::SeedableRng;
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
|
||||
runtime = cx.search_options(runtime, SearchOptions::new(1).trials(1), &mut rng);
|
||||
|
||||
// ---- Execute + verify -----------------------------------------------
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let result = runtime.get_f32(out);
|
||||
|
||||
let expected: &[f32] =
|
||||
bytemuck::cast_slice(inputs_st.tensor("expected").unwrap().data());
|
||||
let max_diff = result
|
||||
.iter()
|
||||
.zip(expected.iter())
|
||||
.map(|(a, b)| (a - b).abs())
|
||||
.fold(0.0_f32, f32::max);
|
||||
println!(
|
||||
"[mega] output[..4]={:?} expected[..4]={:?} max|diff|={:.3e}",
|
||||
&result[..result.len().min(4)],
|
||||
&expected[..result.len().min(4)],
|
||||
max_diff
|
||||
);
|
||||
assert!(
|
||||
max_diff < 1e-4,
|
||||
"megakernel output diverges from PyTorch eager (max_diff={max_diff})"
|
||||
);
|
||||
println!("[mega] OK — luminal megakernel matches PyTorch eager within 1e-4");
|
||||
|
||||
// Inventory the host ops — should be exactly 1 (the DlrmMegaCustom).
|
||||
let host_ops = runtime.host_ops();
|
||||
println!("[mega] active bucket host-op count: {}", host_ops.len());
|
||||
|
||||
if bench {
|
||||
// Reuse the shared bench loop. Writes per-iter µs samples to
|
||||
// /tmp/dlrm_bench_megakernel.txt so examples/dlrm/bench.py picks
|
||||
// them up under the "DLRM megakernel" row.
|
||||
bench_through_luminal(
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
out,
|
||||
dense_in,
|
||||
&idx_tensors,
|
||||
dense_vec,
|
||||
idx_vecs,
|
||||
"/tmp/dlrm_bench_megakernel.txt",
|
||||
"[mega]",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Pull a `field value,` pair out of a Debug-formatted struct dump.
///
/// `field` is the field name including its trailing colon (e.g. `"m:"`).
/// Returns the text between `field` and the next `,` or `}` (or the end of
/// `s`), trimmed. Returns `"?"` when the field is not present.
///
/// When `field` starts with an identifier character, an occurrence only counts
/// if it is not the tail of a longer identifier: looking up `"m:"` skips the
/// `m:` inside `transform:` or `num:` and finds the real `m:` field.
///
/// Values that themselves contain `,` or `}` (nested structs, tuples) are cut
/// at the first such character.
fn extract_field(s: &str, field: &str) -> String {
    let is_ident = |c: char| c.is_alphanumeric() || c == '_';
    // Only enforce the boundary when the needle itself begins like an
    // identifier; otherwise any occurrence is accepted, as before.
    let needs_boundary = field.chars().next().is_some_and(is_ident);
    let hit = s.match_indices(field).map(|(at, _)| at).find(|&at| {
        !needs_boundary
            || !matches!(s[..at].chars().next_back(), Some(prev) if is_ident(prev))
    });
    let Some(idx) = hit else {
        return "?".to_string();
    };
    let tail = &s[idx + field.len()..];
    let end = tail
        .find(|c: char| matches!(c, ',' | '}'))
        .unwrap_or(tail.len());
    tail[..end].trim().to_string()
}
|
||||
|
||||
/// Walk the graph looking for an [`Input`] op with the given label. Used to
|
||||
/// recover a `NodeIndex` we can `set_data` against when the original
|
||||
/// `GraphTensor` handle isn't in scope.
|
||||
fn find_named_input(cx: &Graph, label: &str) -> Option<NodeIndex> {
|
||||
use luminal::hlir::Input;
|
||||
for n in cx.graph.node_indices() {
|
||||
if let Some(Input { label: l, .. }) =
|
||||
(*cx.graph[n]).as_any().downcast_ref::<Input>()
|
||||
{
|
||||
if l == label {
|
||||
return Some(n);
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
@@ -11,6 +11,7 @@ luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
tokenizers = "0.22.2"
|
||||
rustc-hash = "2"
|
||||
rand = "0.9.2"
|
||||
|
||||
# HuggingFace model download
|
||||
hf-hub = { version = "0.4", default-features = false, features = ["rustls-tls", "ureq"] }
|
||||
|
||||
@@ -5,11 +5,13 @@ use hf::prepare_hf_model;
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use model::*;
|
||||
use rand::{SeedableRng, rngs::SmallRng};
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const REPO_ID: &str = "google/gemma-4-26B-A4B";
|
||||
const SEARCH_SEED: u64 = 0;
|
||||
|
||||
fn env_bool(name: &str) -> bool {
|
||||
std::env::var(name)
|
||||
@@ -78,7 +80,12 @@ fn main() {
|
||||
cx.set_dim('p', 0);
|
||||
runtime.set_data(input, vec![1; search_s]);
|
||||
runtime.set_data(pos_ids, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
let mut rng = SmallRng::seed_from_u64(SEARCH_SEED);
|
||||
runtime = cx.search_options(
|
||||
runtime,
|
||||
SearchOptions::new(search_graphs).profile_timeout(Duration::from_secs(2)),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
|
||||
@@ -113,10 +113,6 @@ impl QwenRuntime for luminal_metal::MetalRuntime {
|
||||
fn get_f32(&self, id: NodeIndex) -> Vec<f32> {
|
||||
luminal_metal::MetalRuntime::get_f32(self, id)
|
||||
}
|
||||
|
||||
fn prepare_execute(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
luminal_metal::MetalRuntime::allocate_intermediate_buffers(self, dyn_map);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run_qwen<R>(mut runtime: R, config: QwenRunConfig) -> Result<(), Box<dyn Error>>
|
||||
@@ -177,6 +173,17 @@ where
|
||||
DimBucket::new(2, max_prefill).representative(search_s),
|
||||
],
|
||||
);
|
||||
let max_decode_p = config.max_seq_len.saturating_sub(1);
|
||||
let decode_p_representative = prompt_tokens.len().min(max_decode_p).max(1);
|
||||
let p_buckets = if max_decode_p == 0 {
|
||||
vec![DimBucket::new(0, 0)]
|
||||
} else {
|
||||
vec![
|
||||
DimBucket::new(0, 0),
|
||||
DimBucket::new(1, max_decode_p).representative(decode_p_representative),
|
||||
]
|
||||
};
|
||||
cx.set_dim_buckets('p', &p_buckets);
|
||||
cx.set_dim('s', search_s);
|
||||
cx.set_dim('p', 0);
|
||||
runtime.set_i32_data(input.id, vec![1; search_s]);
|
||||
|
||||
@@ -279,17 +279,17 @@ impl DynBackend for NativeDynBackend {
|
||||
|
||||
fn get_output_f32(&self, node: NodeIndex) -> Vec<f32> {
|
||||
let data = self.output_buffer(node);
|
||||
(0..data.len()).map(|i| data.f32(i)).collect()
|
||||
data.to_f32_vec()
|
||||
}
|
||||
|
||||
fn get_output_i32(&self, node: NodeIndex) -> Vec<i32> {
|
||||
let data = self.output_buffer(node);
|
||||
(0..data.len()).map(|i| data.i32(i)).collect()
|
||||
data.to_i32_vec()
|
||||
}
|
||||
|
||||
fn get_output_bool(&self, node: NodeIndex) -> Vec<bool> {
|
||||
let data = self.output_buffer(node);
|
||||
(0..data.len()).map(|i| data.bool(i)).collect()
|
||||
data.to_bool_vec()
|
||||
}
|
||||
|
||||
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
|
||||
@@ -1620,7 +1620,56 @@ pub fn extract_expr<'a>(
|
||||
|
||||
pub type EGraphChoiceSet<'a> = FxHashMap<&'a ClassId, &'a NodeId>;
|
||||
|
||||
/// Count the total number of possible IR/IList choice sets, capped at `limit`.
|
||||
/// Reports whether an eclass label contains one of the substrings `IR`,
/// `IList`, or `OpKind` (case-sensitive).
fn is_search_choice_eclass(label: &str) -> bool {
    const MARKERS: [&str; 3] = ["IR", "IList", "OpKind"];
    MARKERS.iter().any(|&marker| label.contains(marker))
}
|
||||
|
||||
fn extractor_list_len(egraph: &SerializedEGraph, eclass_id: &ClassId) -> Option<usize> {
|
||||
let mut len = 0usize;
|
||||
let mut cur_eclass: ClassId = eclass_id.clone();
|
||||
let mut visited: FxHashSet<ClassId> = FxHashSet::default();
|
||||
loop {
|
||||
if !visited.insert(cur_eclass.clone()) {
|
||||
return None;
|
||||
}
|
||||
let (label, enodes) = egraph.eclasses.get(&cur_eclass)?;
|
||||
if !label.contains("List") {
|
||||
return Some(len);
|
||||
}
|
||||
let head_enode = enodes.first()?;
|
||||
let head_label = &egraph.enodes[head_enode].0;
|
||||
if head_label == "ENil" || head_label == "INil" {
|
||||
return Some(len);
|
||||
}
|
||||
if head_label != "ECons" && head_label != "ICons" {
|
||||
return Some(len);
|
||||
}
|
||||
len += 1;
|
||||
let children = &egraph.enodes[head_enode].1;
|
||||
if children.len() < 2 {
|
||||
return Some(len);
|
||||
}
|
||||
cur_eclass = children[1].clone();
|
||||
}
|
||||
}
|
||||
|
||||
fn opkind_metadata_consistent(egraph: &SerializedEGraph, node: &NodeId) -> bool {
|
||||
let lens: Vec<usize> = egraph.enodes[node]
|
||||
.1
|
||||
.iter()
|
||||
.filter_map(|c| {
|
||||
let lbl = &egraph.eclasses[c].0;
|
||||
if lbl.contains("List") {
|
||||
extractor_list_len(egraph, c)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
lens.is_empty() || lens.iter().all(|l| *l == lens[0])
|
||||
}
|
||||
|
||||
/// Count the total number of possible searchable choice sets, capped at `limit`.
|
||||
///
|
||||
/// Search deduplicates candidates by `EGraphChoiceSet`, so this gives the exact
|
||||
/// number of candidates when it is below `limit` without risking overflow on
|
||||
@@ -1632,7 +1681,7 @@ pub fn count_choice_sets_up_to(egraph: &SerializedEGraph, limit: usize) -> usize
|
||||
|
||||
let mut count = 1usize;
|
||||
for (label, enodes) in egraph.eclasses.values() {
|
||||
if !label.contains("IR") && !label.contains("IList") {
|
||||
if !is_search_choice_eclass(label) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -1650,10 +1699,10 @@ pub fn random_initial_choice<'a>(
|
||||
) -> EGraphChoiceSet<'a> {
|
||||
let mut choices = FxHashMap::default();
|
||||
for (eclass, (label, enodes)) in &egraph.eclasses {
|
||||
if !label.contains("IR") && !label.contains("IList") {
|
||||
if !is_search_choice_eclass(label) {
|
||||
continue;
|
||||
}
|
||||
// Prefer synth-injected enodes when available — they point at
|
||||
// Use synth-injected enodes when available — they point at
|
||||
// deterministic single-variant kind eclasses produced by the
|
||||
// deep-clone fallback in `inject_kernel_alternatives`, so the
|
||||
// extractor's first-enode walk is guaranteed length-consistent.
|
||||
@@ -1667,7 +1716,18 @@ pub fn random_initial_choice<'a>(
|
||||
.enumerate()
|
||||
.filter_map(|(i, n)| n.as_ref().starts_with("synth_").then_some(i))
|
||||
.collect();
|
||||
let pick_idx = if !synth_indices.is_empty() {
|
||||
let consistent_opkind_indices: Vec<usize> = if label == "OpKind" {
|
||||
enodes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, n)| opkind_metadata_consistent(egraph, n).then_some(i))
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
let pick_idx = if !consistent_opkind_indices.is_empty() {
|
||||
consistent_opkind_indices[rng.random_range(0..consistent_opkind_indices.len())]
|
||||
} else if !synth_indices.is_empty() {
|
||||
synth_indices[rng.random_range(0..synth_indices.len())]
|
||||
} else {
|
||||
rng.random_range(0..enodes.len())
|
||||
@@ -1684,9 +1744,9 @@ pub fn validate_choice_set<'a>(
|
||||
choices: &EGraphChoiceSet<'a>,
|
||||
ops: &[Arc<Box<dyn EgglogOp>>],
|
||||
) -> Result<(), String> {
|
||||
// Check all IR/IList eclasses have a choice
|
||||
// Check all searchable eclasses have a choice.
|
||||
for (eclass, (label, enodes)) in &egraph.eclasses {
|
||||
if !label.contains("IR") && !label.contains("IList") {
|
||||
if !is_search_choice_eclass(label) {
|
||||
continue;
|
||||
}
|
||||
let Some(chosen) = choices.get(eclass) else {
|
||||
@@ -1719,7 +1779,7 @@ pub fn validate_choice_set<'a>(
|
||||
.eclasses
|
||||
.get(ch)
|
||||
.ok_or_else(|| format!("Eclass {} not found", ch.as_ref()))?;
|
||||
if label.contains("IR") || label.contains("IList") {
|
||||
if is_search_choice_eclass(label) {
|
||||
let n = choices
|
||||
.get(ch)
|
||||
.ok_or_else(|| format!("No choice for reachable eclass {}", ch.as_ref()))?;
|
||||
@@ -1745,14 +1805,12 @@ pub fn validate_choice_set<'a>(
|
||||
if op_name == "Op" {
|
||||
// Normalized op — check OpKind child
|
||||
if let Some(kind_eclass) = children.first() {
|
||||
if let Some((_, kind_enodes)) = egraph.eclasses.get(kind_eclass) {
|
||||
if let Some(kn) = kind_enodes.first() {
|
||||
let kind_name = &egraph.enodes[kn].0;
|
||||
if kind_name != "CustomOpKind"
|
||||
&& !ops.iter().any(|op| op.sort().name == *kind_name)
|
||||
{
|
||||
return Err(format!("No extractor for OpKind {kind_name}"));
|
||||
}
|
||||
if let Some(kn) = choices.get(kind_eclass) {
|
||||
let kind_name = &egraph.enodes[kn].0;
|
||||
if kind_name != "CustomOpKind"
|
||||
&& !ops.iter().any(|op| op.sort().name == *kind_name)
|
||||
{
|
||||
return Err(format!("No extractor for OpKind {kind_name}"));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1813,9 +1871,7 @@ pub fn extract_generation<'a>(
|
||||
let mutable_classes: Vec<&ClassId> = egraph
|
||||
.eclasses
|
||||
.iter()
|
||||
.filter(|(_, (label, enodes))| {
|
||||
(label.contains("IR") || label.contains("IList")) && enodes.len() > 1
|
||||
})
|
||||
.filter(|(_, (label, enodes))| is_search_choice_eclass(label) && enodes.len() > 1)
|
||||
.map(|(class_id, _)| class_id)
|
||||
.collect();
|
||||
|
||||
@@ -1848,9 +1904,21 @@ pub fn extract_generation<'a>(
|
||||
for _ in 0..rng.random_range(1..=mutations_per_generation) {
|
||||
// Pick a random mutable eclass
|
||||
let class_id = mutable_classes[rng.random_range(0..mutable_classes.len())];
|
||||
let (_, enodes) = &egraph.eclasses[class_id];
|
||||
let (label, enodes) = &egraph.eclasses[class_id];
|
||||
// Pick a random enode for this class
|
||||
let new_node = &enodes[rng.random_range(0..enodes.len())];
|
||||
let consistent_opkind_nodes: Vec<&NodeId> = if label == "OpKind" {
|
||||
enodes
|
||||
.iter()
|
||||
.filter(|n| opkind_metadata_consistent(egraph, n))
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
let new_node = if !consistent_opkind_nodes.is_empty() {
|
||||
consistent_opkind_nodes[rng.random_range(0..consistent_opkind_nodes.len())]
|
||||
} else {
|
||||
&enodes[rng.random_range(0..enodes.len())]
|
||||
};
|
||||
// Insert returns the previous binding (if any); fold the diff
|
||||
// into the running hash. If the new pick equals the old one,
|
||||
// the two XORs cancel and `child_hash` is unchanged — exactly
|
||||
@@ -1932,7 +2000,7 @@ pub fn egglog_to_llir_from_root<'a>(
|
||||
let mut reachability_stack = vec![choices[root_class]];
|
||||
while let Some(r) = reachability_stack.pop() {
|
||||
for ch in &egraph.enodes[r].1 {
|
||||
if egraph.eclasses[ch].0.contains("IR") || egraph.eclasses[ch].0.contains("IList") {
|
||||
if is_search_choice_eclass(&egraph.eclasses[ch].0) {
|
||||
let n = choices[ch];
|
||||
if !reachable.contains(n) {
|
||||
reachability_stack.push(n);
|
||||
@@ -1968,69 +2036,19 @@ pub fn egglog_to_llir_from_root<'a>(
|
||||
// structurally-equivalent kind enodes whose ELIST children
|
||||
// were unioned but resolve (under the extractor's first-enode
|
||||
// walk) to inconsistent lengths — picking such an enode causes
|
||||
// a downstream `flatten_strides` length mismatch. Prefer the
|
||||
// first kind enode whose ELIST children all walk to the same
|
||||
// length; fall back to the original first enode if no
|
||||
// consistent candidate exists (rare; only happens for ops
|
||||
// outside the runnable subgraph).
|
||||
// a downstream `flatten_strides` length mismatch. Candidate
|
||||
// generation filters these out where possible; this fallback is
|
||||
// structural only and does not rank backend implementations.
|
||||
let kind_enodes = &egraph.eclasses[kind_eclass].1;
|
||||
let extractor_length = |eclass_id: &ClassId| -> Option<usize> {
|
||||
let mut len = 0usize;
|
||||
let mut cur_eclass: ClassId = eclass_id.clone();
|
||||
let mut visited: FxHashSet<ClassId> = FxHashSet::default();
|
||||
loop {
|
||||
if !visited.insert(cur_eclass.clone()) {
|
||||
return None;
|
||||
}
|
||||
let (label, enodes) = egraph.eclasses.get(&cur_eclass)?;
|
||||
if !label.contains("List") {
|
||||
return Some(len);
|
||||
}
|
||||
let head_enode = enodes.first()?;
|
||||
let head_label = &egraph.enodes[head_enode].0;
|
||||
if head_label == "ENil" || head_label == "INil" {
|
||||
return Some(len);
|
||||
}
|
||||
if head_label != "ECons" && head_label != "ICons" {
|
||||
return Some(len);
|
||||
}
|
||||
len += 1;
|
||||
let children = &egraph.enodes[head_enode].1;
|
||||
if children.len() < 2 {
|
||||
return Some(len);
|
||||
}
|
||||
cur_eclass = children[1].clone();
|
||||
}
|
||||
};
|
||||
let elist_lens_for = |n: &NodeId| -> Vec<usize> {
|
||||
egraph.enodes[n]
|
||||
.1
|
||||
.iter()
|
||||
.filter_map(|c| {
|
||||
let lbl = &egraph.eclasses[c].0;
|
||||
if lbl.contains("List") {
|
||||
extractor_length(c)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
let is_consistent = |n: &NodeId| -> bool {
|
||||
let lens = elist_lens_for(n);
|
||||
lens.is_empty() || lens.iter().all(|l| *l == lens[0])
|
||||
};
|
||||
let is_kernel = |n: &NodeId| -> bool {
|
||||
let l = &egraph.enodes[n].0;
|
||||
l.starts_with("Kernel") || l.starts_with("Fused")
|
||||
};
|
||||
// Prefer a consistent kernel kind; then any consistent;
|
||||
// then any kernel; then fall back to first.
|
||||
let kind_enode = kind_enodes
|
||||
.iter()
|
||||
.find(|n| is_kernel(n) && is_consistent(n))
|
||||
.or_else(|| kind_enodes.iter().find(|n| is_consistent(n)))
|
||||
.or_else(|| kind_enodes.iter().find(|n| is_kernel(n)))
|
||||
let kind_enode = choices
|
||||
.get(kind_eclass)
|
||||
.copied()
|
||||
.filter(|n| opkind_metadata_consistent(egraph, n))
|
||||
.or_else(|| {
|
||||
kind_enodes
|
||||
.iter()
|
||||
.find(|n| opkind_metadata_consistent(egraph, n))
|
||||
})
|
||||
.unwrap_or(&kind_enodes[0]);
|
||||
let kind_label = &egraph.enodes[kind_enode].0;
|
||||
|
||||
@@ -2039,8 +2057,7 @@ pub fn egglog_to_llir_from_root<'a>(
|
||||
.1
|
||||
.iter()
|
||||
.map(|c| {
|
||||
if egraph.eclasses[c].0.contains("IR") || egraph.eclasses[c].0.contains("IList")
|
||||
{
|
||||
if is_search_choice_eclass(&egraph.eclasses[c].0) {
|
||||
choices[c]
|
||||
} else {
|
||||
&egraph.eclasses[c].1[0]
|
||||
@@ -2085,8 +2102,7 @@ pub fn egglog_to_llir_from_root<'a>(
|
||||
.1
|
||||
.iter()
|
||||
.map(|c| {
|
||||
if egraph.eclasses[c].0.contains("IR") || egraph.eclasses[c].0.contains("IList")
|
||||
{
|
||||
if is_search_choice_eclass(&egraph.eclasses[c].0) {
|
||||
choices[c]
|
||||
} else {
|
||||
&egraph.eclasses[c].1[0]
|
||||
@@ -2165,10 +2181,11 @@ mod tests {
|
||||
let egraph = egraph(vec![
|
||||
eclass("a", "IR", 2),
|
||||
eclass("b", "IList", 3),
|
||||
eclass("op", "OpKind", 5),
|
||||
eclass("c", "Shape", 99),
|
||||
]);
|
||||
|
||||
assert_eq!(count_choice_sets_up_to(&egraph, 100), 6);
|
||||
assert_eq!(count_choice_sets_up_to(&egraph, 100), 30);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
110
src/graph.rs
110
src/graph.rs
@@ -173,15 +173,16 @@ impl BuildSearchSpaceOptions {
|
||||
pub struct SearchOptions {
|
||||
/// Maximum number of graphs to evaluate
|
||||
pub limit: usize,
|
||||
/// Number of offspring per generation (default: 30)
|
||||
/// Number of offspring per generation (default: 10)
|
||||
pub generation_size: usize,
|
||||
/// Number of mutations applied to each offspring (default: 30)
|
||||
/// Number of mutations applied to each offspring (default: 10)
|
||||
pub mutations: usize,
|
||||
/// Number of profiling trials per candidate (default: 3)
|
||||
pub trials: usize,
|
||||
/// Number of best genomes to keep as parents per generation (default: 1)
|
||||
pub keep_best: usize,
|
||||
/// Optional per-candidate profiling timeout.
|
||||
/// Per-candidate profiling timeout. If a profile call reaches this budget,
|
||||
/// that candidate is discarded and search continues.
|
||||
pub profile_timeout: Option<std::time::Duration>,
|
||||
/// Optional per-group search timeout.
|
||||
pub group_timeout: Option<std::time::Duration>,
|
||||
@@ -194,11 +195,11 @@ impl SearchOptions {
|
||||
pub fn new(limit: usize) -> Self {
|
||||
Self {
|
||||
limit,
|
||||
generation_size: 30,
|
||||
mutations: 30,
|
||||
generation_size: 10,
|
||||
mutations: 10,
|
||||
trials: 3,
|
||||
keep_best: 1,
|
||||
profile_timeout: None,
|
||||
profile_timeout: Some(std::time::Duration::from_secs(1)),
|
||||
group_timeout: None,
|
||||
profile_dims: FxHashMap::default(),
|
||||
}
|
||||
@@ -315,6 +316,27 @@ fn maybe_dump_selected_llir(label: &str, dyn_map: &FxHashMap<char, usize>, llir:
|
||||
}
|
||||
}
|
||||
|
||||
fn random_choice_generation<'a, G: rand::Rng>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
generation_size: usize,
|
||||
prev_selected: &mut FxHashSet<u64>,
|
||||
rng: &mut G,
|
||||
) -> Vec<crate::egglog_utils::EGraphChoiceSet<'a>> {
|
||||
let mut generation = Vec::with_capacity(generation_size);
|
||||
let max_attempts = generation_size.saturating_mul(100);
|
||||
let mut attempts = 0;
|
||||
|
||||
while generation.len() < generation_size && attempts < max_attempts {
|
||||
attempts += 1;
|
||||
let genome = random_initial_choice(egraph, rng);
|
||||
if prev_selected.insert(hash_choice_set(&genome)) {
|
||||
generation.push(genome);
|
||||
}
|
||||
}
|
||||
|
||||
generation
|
||||
}
|
||||
|
||||
/// A Luminal compute graph.
|
||||
///
|
||||
/// All computation is represented as a directed acyclic graph.
|
||||
@@ -1347,6 +1369,11 @@ impl Graph {
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
runtime.clear_intermediate_buffers();
|
||||
let profile_timed_out = |elapsed: std::time::Duration| {
|
||||
options
|
||||
.profile_timeout
|
||||
.is_some_and(|timeout| elapsed >= timeout)
|
||||
};
|
||||
|
||||
// Find a viable initial genome (may need multiple attempts if some panic)
|
||||
let (mut best_genome, mut best_metric, display, mut n_graphs);
|
||||
@@ -1385,13 +1412,15 @@ impl Graph {
|
||||
// unrolled graph size.
|
||||
collapse_loops_to_first_iter(&mut graph);
|
||||
runtime.clear_intermediate_buffers();
|
||||
let profile_start = std::time::Instant::now();
|
||||
let (rep_metric, rep_display) = runtime.profile(
|
||||
&graph,
|
||||
&profile_dyn_map,
|
||||
options.trials,
|
||||
options.profile_timeout,
|
||||
);
|
||||
let has_nan = runtime.has_nan_outputs(&graph, &profile_dyn_map);
|
||||
let timed_out = profile_timed_out(profile_start.elapsed());
|
||||
let has_nan = !timed_out && runtime.has_nan_outputs(&graph, &profile_dyn_map);
|
||||
(
|
||||
rep_metric,
|
||||
append_memory_display(
|
||||
@@ -1401,11 +1430,12 @@ impl Graph {
|
||||
runtime.allocated_intermediate_buffer_bytes(),
|
||||
),
|
||||
has_nan,
|
||||
timed_out,
|
||||
)
|
||||
}));
|
||||
|
||||
match result {
|
||||
Ok((metric, disp, false)) => {
|
||||
Ok((metric, disp, false, false)) => {
|
||||
best_genome = genome;
|
||||
best_metric = R::aggregate_profile_metrics(&[metric]);
|
||||
display = disp;
|
||||
@@ -1435,6 +1465,7 @@ impl Graph {
|
||||
// Track top-N parents for offspring generation
|
||||
let mut parents: Vec<(R::ProfileMetric, crate::egglog_utils::EGraphChoiceSet<'_>)> =
|
||||
vec![(best_metric.clone(), best_genome.clone())];
|
||||
let mut resample_generation = false;
|
||||
|
||||
while n_graphs < search_limit {
|
||||
if options
|
||||
@@ -1446,26 +1477,33 @@ impl Graph {
|
||||
|
||||
// Generate offspring from all parents, dividing budget evenly
|
||||
let budget = (search_limit - n_graphs).min(options.generation_size);
|
||||
let per_parent = budget.div_ceil(parents.len());
|
||||
let mut all_offspring = Vec::new();
|
||||
for (_, parent_genome) in &parents {
|
||||
let remaining = budget.saturating_sub(all_offspring.len());
|
||||
if remaining == 0 {
|
||||
break;
|
||||
let all_offspring = if resample_generation {
|
||||
random_choice_generation(egraph, budget, &mut prev_selected, rng)
|
||||
} else {
|
||||
let per_parent = budget.div_ceil(parents.len());
|
||||
let mut offspring = Vec::new();
|
||||
for (_, parent_genome) in &parents {
|
||||
let remaining = budget.saturating_sub(offspring.len());
|
||||
if remaining == 0 {
|
||||
break;
|
||||
}
|
||||
offspring.extend(extract_generation(
|
||||
egraph,
|
||||
parent_genome,
|
||||
per_parent.min(remaining),
|
||||
options.mutations,
|
||||
&mut prev_selected,
|
||||
rng,
|
||||
));
|
||||
}
|
||||
all_offspring.extend(extract_generation(
|
||||
egraph,
|
||||
parent_genome,
|
||||
per_parent.min(remaining),
|
||||
options.mutations,
|
||||
&mut prev_selected,
|
||||
rng,
|
||||
));
|
||||
}
|
||||
offspring
|
||||
};
|
||||
if all_offspring.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
let mut generation_found_non_timeout = false;
|
||||
|
||||
for genome in all_offspring {
|
||||
if options
|
||||
.group_timeout
|
||||
@@ -1502,13 +1540,16 @@ impl Graph {
|
||||
// before profiling — see initial-genome path.
|
||||
collapse_loops_to_first_iter(&mut llir_graph);
|
||||
runtime.clear_intermediate_buffers();
|
||||
let profile_start = std::time::Instant::now();
|
||||
let (rep_metric, rep_display) = runtime.profile(
|
||||
&llir_graph,
|
||||
&profile_dyn_map,
|
||||
options.trials,
|
||||
options.profile_timeout,
|
||||
);
|
||||
let has_nan = runtime.has_nan_outputs(&llir_graph, &profile_dyn_map);
|
||||
let timed_out = profile_timed_out(profile_start.elapsed());
|
||||
let has_nan =
|
||||
!timed_out && runtime.has_nan_outputs(&llir_graph, &profile_dyn_map);
|
||||
(
|
||||
rep_metric,
|
||||
append_memory_display(
|
||||
@@ -1518,15 +1559,28 @@ impl Graph {
|
||||
runtime.allocated_intermediate_buffer_bytes(),
|
||||
),
|
||||
has_nan,
|
||||
timed_out,
|
||||
)
|
||||
}));
|
||||
|
||||
let (new_metric, display_metric) = match profile_result {
|
||||
Ok((metric, display, false)) => {
|
||||
Ok((metric, display, false, false)) => {
|
||||
generation_found_non_timeout = true;
|
||||
(R::aggregate_profile_metrics(&[metric]), display)
|
||||
}
|
||||
Ok((_, _, true)) | Err(_) => {
|
||||
// NaN or panic — redraw bars and skip
|
||||
Ok((_, _, _, true)) | Err(_) => {
|
||||
// Timed out or panicked — redraw bars and skip.
|
||||
for _ in 1..n_bar_lines {
|
||||
print!("\x1b[1A");
|
||||
}
|
||||
print!("\r\x1b[2K");
|
||||
render_bars(n_graphs, search_limit, bucket_progress);
|
||||
std::io::stdout().flush().unwrap();
|
||||
continue;
|
||||
}
|
||||
Ok((_, _, true, false)) => {
|
||||
generation_found_non_timeout = true;
|
||||
// Completed profiling but produced NaNs — redraw bars and skip.
|
||||
for _ in 1..n_bar_lines {
|
||||
print!("\x1b[1A");
|
||||
}
|
||||
@@ -1577,6 +1631,8 @@ impl Graph {
|
||||
render_bars(n_graphs, search_limit, bucket_progress);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
resample_generation = !generation_found_non_timeout;
|
||||
}
|
||||
|
||||
// Clear progress bars
|
||||
|
||||
206
src/hlir.rs
206
src/hlir.rs
@@ -1613,8 +1613,7 @@ fn bin_fn<A: Copy>(
|
||||
a_ind: StridedIterator,
|
||||
a: &[A],
|
||||
b_ind: StridedIterator,
|
||||
b: &NativeData,
|
||||
b_get: impl Fn(&NativeData, usize) -> A,
|
||||
b: &[A],
|
||||
op: impl Fn(A, A) -> A,
|
||||
) -> Vec<A> {
|
||||
let a_shape = a_ind.shape.clone();
|
||||
@@ -1634,7 +1633,36 @@ fn bin_fn<A: Copy>(
|
||||
"bin_fn: b index {j} out of bounds (b.len={}), shape={b_shape:?}, strides={b_strides:?}",
|
||||
b.len(),
|
||||
);
|
||||
op(a[i], b_get(b, j))
|
||||
op(a[i], b[j])
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn bin_cmp_fn<A: Copy>(
|
||||
a_ind: StridedIterator,
|
||||
a: &[A],
|
||||
b_ind: StridedIterator,
|
||||
b: &[A],
|
||||
op: impl Fn(A, A) -> bool,
|
||||
) -> Vec<bool> {
|
||||
let a_shape = a_ind.shape.clone();
|
||||
let a_strides = a_ind.strides.clone();
|
||||
let b_shape = b_ind.shape.clone();
|
||||
let b_strides = b_ind.strides.clone();
|
||||
a_ind
|
||||
.zip(b_ind)
|
||||
.map(|(i, j)| {
|
||||
assert!(
|
||||
i < a.len(),
|
||||
"bin_cmp_fn: a index {i} out of bounds (a.len={}), shape={a_shape:?}, strides={a_strides:?}",
|
||||
a.len(),
|
||||
);
|
||||
assert!(
|
||||
j < b.len(),
|
||||
"bin_cmp_fn: b index {j} out of bounds (b.len={}), shape={b_shape:?}, strides={b_strides:?}",
|
||||
b.len(),
|
||||
);
|
||||
op(a[i], b[j])
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
@@ -1708,20 +1736,23 @@ impl NativeOp for Add {
|
||||
StridedIterator::new(&self.shape, &self.a_strides, dyn_map),
|
||||
StridedIterator::new(&self.shape, &self.b_strides, dyn_map),
|
||||
);
|
||||
match a {
|
||||
NativeData::F32(a) => {
|
||||
NativeData::F32(bin_fn(a_ind, a, b_ind, b, NativeData::f32, |x, y| x + y))
|
||||
match (a, b) {
|
||||
(NativeData::F32(a), NativeData::F32(b)) => {
|
||||
NativeData::F32(bin_fn(a_ind, a, b_ind, b, |x, y| x + y))
|
||||
}
|
||||
NativeData::F16(a) => {
|
||||
NativeData::F16(bin_fn(a_ind, a, b_ind, b, NativeData::f16, |x, y| x + y))
|
||||
(NativeData::F16(a), NativeData::F16(b)) => {
|
||||
NativeData::F16(bin_fn(a_ind, a, b_ind, b, |x, y| x + y))
|
||||
}
|
||||
NativeData::Bf16(a) => {
|
||||
NativeData::Bf16(bin_fn(a_ind, a, b_ind, b, NativeData::bf16, |x, y| x + y))
|
||||
(NativeData::Bf16(a), NativeData::Bf16(b)) => {
|
||||
NativeData::Bf16(bin_fn(a_ind, a, b_ind, b, |x, y| x + y))
|
||||
}
|
||||
NativeData::Int(a) => {
|
||||
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x + y))
|
||||
(NativeData::Int(a), NativeData::Int(b)) => {
|
||||
NativeData::Int(bin_fn(a_ind, a, b_ind, b, |x, y| x + y))
|
||||
}
|
||||
NativeData::Bool(_) => panic!("Cannot add Bool tensors, cast to F32 first"),
|
||||
(NativeData::Bool(_), NativeData::Bool(_)) => {
|
||||
panic!("Cannot add Bool tensors, cast to F32 first")
|
||||
}
|
||||
_ => panic!("Add inputs must have the same dtype"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1795,20 +1826,23 @@ impl NativeOp for Mul {
|
||||
StridedIterator::new(&self.shape, &self.a_strides, dyn_map),
|
||||
StridedIterator::new(&self.shape, &self.b_strides, dyn_map),
|
||||
);
|
||||
match a {
|
||||
NativeData::F32(a) => {
|
||||
NativeData::F32(bin_fn(a_ind, a, b_ind, b, NativeData::f32, |x, y| x * y))
|
||||
match (a, b) {
|
||||
(NativeData::F32(a), NativeData::F32(b)) => {
|
||||
NativeData::F32(bin_fn(a_ind, a, b_ind, b, |x, y| x * y))
|
||||
}
|
||||
NativeData::F16(a) => {
|
||||
NativeData::F16(bin_fn(a_ind, a, b_ind, b, NativeData::f16, |x, y| x * y))
|
||||
(NativeData::F16(a), NativeData::F16(b)) => {
|
||||
NativeData::F16(bin_fn(a_ind, a, b_ind, b, |x, y| x * y))
|
||||
}
|
||||
NativeData::Bf16(a) => {
|
||||
NativeData::Bf16(bin_fn(a_ind, a, b_ind, b, NativeData::bf16, |x, y| x * y))
|
||||
(NativeData::Bf16(a), NativeData::Bf16(b)) => {
|
||||
NativeData::Bf16(bin_fn(a_ind, a, b_ind, b, |x, y| x * y))
|
||||
}
|
||||
NativeData::Int(a) => {
|
||||
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x * y))
|
||||
(NativeData::Int(a), NativeData::Int(b)) => {
|
||||
NativeData::Int(bin_fn(a_ind, a, b_ind, b, |x, y| x * y))
|
||||
}
|
||||
NativeData::Bool(_) => panic!("Cannot multiply Bool tensors, cast to F32 first"),
|
||||
(NativeData::Bool(_), NativeData::Bool(_)) => {
|
||||
panic!("Cannot multiply Bool tensors, cast to F32 first")
|
||||
}
|
||||
_ => panic!("Mul inputs must have the same dtype"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1882,20 +1916,21 @@ impl NativeOp for Mod {
|
||||
StridedIterator::new(&self.shape, &self.a_strides, dyn_map),
|
||||
StridedIterator::new(&self.shape, &self.b_strides, dyn_map),
|
||||
);
|
||||
match a {
|
||||
NativeData::F32(a) => {
|
||||
NativeData::F32(bin_fn(a_ind, a, b_ind, b, NativeData::f32, |x, y| x % y))
|
||||
match (a, b) {
|
||||
(NativeData::F32(a), NativeData::F32(b)) => {
|
||||
NativeData::F32(bin_fn(a_ind, a, b_ind, b, |x, y| x % y))
|
||||
}
|
||||
NativeData::F16(a) => {
|
||||
NativeData::F16(bin_fn(a_ind, a, b_ind, b, NativeData::f16, |x, y| x % y))
|
||||
(NativeData::F16(a), NativeData::F16(b)) => {
|
||||
NativeData::F16(bin_fn(a_ind, a, b_ind, b, |x, y| x % y))
|
||||
}
|
||||
NativeData::Bf16(a) => {
|
||||
NativeData::Bf16(bin_fn(a_ind, a, b_ind, b, NativeData::bf16, |x, y| x % y))
|
||||
(NativeData::Bf16(a), NativeData::Bf16(b)) => {
|
||||
NativeData::Bf16(bin_fn(a_ind, a, b_ind, b, |x, y| x % y))
|
||||
}
|
||||
NativeData::Int(a) => {
|
||||
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x % y))
|
||||
(NativeData::Int(a), NativeData::Int(b)) => {
|
||||
NativeData::Int(bin_fn(a_ind, a, b_ind, b, |x, y| x % y))
|
||||
}
|
||||
NativeData::Bool(_) => panic!("Cannot mod Bool tensors"),
|
||||
(NativeData::Bool(_), NativeData::Bool(_)) => panic!("Cannot mod Bool tensors"),
|
||||
_ => panic!("Mod inputs must have the same dtype"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1970,13 +2005,24 @@ impl NativeOp for LessThan {
|
||||
StridedIterator::new(&self.shape, &self.a_strides, dyn_map),
|
||||
StridedIterator::new(&self.shape, &self.b_strides, dyn_map),
|
||||
);
|
||||
// Comparison always returns Bool
|
||||
NativeData::Bool(
|
||||
a_ind
|
||||
.zip(b_ind)
|
||||
.map(|(i, j)| NativeData::f32(a, i) < NativeData::f32(b, j))
|
||||
.collect(),
|
||||
)
|
||||
match (a, b) {
|
||||
(NativeData::F32(a), NativeData::F32(b)) => {
|
||||
NativeData::Bool(bin_cmp_fn(a_ind, a, b_ind, b, |x, y| x < y))
|
||||
}
|
||||
(NativeData::F16(a), NativeData::F16(b)) => {
|
||||
NativeData::Bool(bin_cmp_fn(a_ind, a, b_ind, b, |x, y| x < y))
|
||||
}
|
||||
(NativeData::Bf16(a), NativeData::Bf16(b)) => {
|
||||
NativeData::Bool(bin_cmp_fn(a_ind, a, b_ind, b, |x, y| x < y))
|
||||
}
|
||||
(NativeData::Int(a), NativeData::Int(b)) => {
|
||||
NativeData::Bool(bin_cmp_fn(a_ind, a, b_ind, b, |x, y| x < y))
|
||||
}
|
||||
(NativeData::Bool(a), NativeData::Bool(b)) => {
|
||||
NativeData::Bool(bin_cmp_fn(a_ind, a, b_ind, b, |x, y| !x & y))
|
||||
}
|
||||
_ => panic!("LessThan inputs must have the same dtype"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2708,16 +2754,7 @@ impl NativeData {
|
||||
pub fn f32(&self, i: usize) -> f32 {
|
||||
match self {
|
||||
NativeData::F32(v) => v[i],
|
||||
NativeData::F16(v) => v[i].to_f32(),
|
||||
NativeData::Bf16(v) => v[i].to_f32(),
|
||||
NativeData::Int(v) => v[i] as f32,
|
||||
NativeData::Bool(v) => {
|
||||
if v[i] {
|
||||
1.0
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
_ => panic!("NativeData::f32 called on non-F32 data"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2725,10 +2762,7 @@ impl NativeData {
|
||||
pub fn f16(&self, i: usize) -> f16 {
|
||||
match self {
|
||||
NativeData::F16(v) => v[i],
|
||||
NativeData::F32(v) => f16::from_f32(v[i]),
|
||||
NativeData::Bf16(v) => f16::from_f32(v[i].to_f32()),
|
||||
NativeData::Int(v) => f16::from_f32(v[i] as f32),
|
||||
NativeData::Bool(v) => f16::from_f32(if v[i] { 1.0 } else { 0.0 }),
|
||||
_ => panic!("NativeData::f16 called on non-F16 data"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2736,10 +2770,7 @@ impl NativeData {
|
||||
pub fn bf16(&self, i: usize) -> bf16 {
|
||||
match self {
|
||||
NativeData::Bf16(v) => v[i],
|
||||
NativeData::F32(v) => bf16::from_f32(v[i]),
|
||||
NativeData::F16(v) => bf16::from_f32(v[i].to_f32()),
|
||||
NativeData::Int(v) => bf16::from_f32(v[i] as f32),
|
||||
NativeData::Bool(v) => bf16::from_f32(if v[i] { 1.0 } else { 0.0 }),
|
||||
_ => panic!("NativeData::bf16 called on non-Bf16 data"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2747,16 +2778,7 @@ impl NativeData {
|
||||
pub fn i32(&self, i: usize) -> i32 {
|
||||
match self {
|
||||
NativeData::Int(v) => v[i],
|
||||
NativeData::F32(v) => v[i] as i32,
|
||||
NativeData::F16(v) => v[i].to_f32() as i32,
|
||||
NativeData::Bf16(v) => v[i].to_f32() as i32,
|
||||
NativeData::Bool(v) => {
|
||||
if v[i] {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
_ => panic!("NativeData::i32 called on non-Int data"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2764,10 +2786,50 @@ impl NativeData {
|
||||
pub fn bool(&self, i: usize) -> bool {
|
||||
match self {
|
||||
NativeData::Bool(v) => v[i],
|
||||
NativeData::F32(v) => v[i] != 0.0,
|
||||
NativeData::F16(v) => v[i].to_f32() != 0.0,
|
||||
NativeData::Bf16(v) => v[i].to_f32() != 0.0,
|
||||
NativeData::Int(v) => v[i] != 0,
|
||||
_ => panic!("NativeData::bool called on non-Bool data"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_f32_vec(&self) -> Vec<f32> {
|
||||
match self {
|
||||
NativeData::F32(v) => v.clone(),
|
||||
NativeData::F16(v) => v.iter().map(|v| v.to_f32()).collect(),
|
||||
NativeData::Bf16(v) => v.iter().map(|v| v.to_f32()).collect(),
|
||||
NativeData::Int(v) => v.iter().map(|v| *v as f32).collect(),
|
||||
NativeData::Bool(v) => v.iter().map(|v| if *v { 1.0 } else { 0.0 }).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_f16_vec(&self) -> Vec<f16> {
|
||||
match self {
|
||||
NativeData::F32(v) => v.iter().copied().map(f16::from_f32).collect(),
|
||||
NativeData::F16(v) => v.clone(),
|
||||
NativeData::Bf16(v) => v.iter().map(|v| f16::from_f32(v.to_f32())).collect(),
|
||||
NativeData::Int(v) => v.iter().map(|v| f16::from_f32(*v as f32)).collect(),
|
||||
NativeData::Bool(v) => v
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(if *v { 1.0 } else { 0.0 }))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_i32_vec(&self) -> Vec<i32> {
|
||||
match self {
|
||||
NativeData::F32(v) => v.iter().map(|v| *v as i32).collect(),
|
||||
NativeData::F16(v) => v.iter().map(|v| v.to_f32() as i32).collect(),
|
||||
NativeData::Bf16(v) => v.iter().map(|v| v.to_f32() as i32).collect(),
|
||||
NativeData::Int(v) => v.clone(),
|
||||
NativeData::Bool(v) => v.iter().map(|v| if *v { 1 } else { 0 }).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_bool_vec(&self) -> Vec<bool> {
|
||||
match self {
|
||||
NativeData::F32(v) => v.iter().map(|v| *v != 0.0).collect(),
|
||||
NativeData::F16(v) => v.iter().map(|v| v.to_f32() != 0.0).collect(),
|
||||
NativeData::Bf16(v) => v.iter().map(|v| v.to_f32() != 0.0).collect(),
|
||||
NativeData::Int(v) => v.iter().map(|v| *v != 0).collect(),
|
||||
NativeData::Bool(v) => v.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::fmt::Debug;
|
||||
use crate::egglog_utils::{
|
||||
extract_generation, hash_choice_set, random_initial_choice, validate_choice_set,
|
||||
};
|
||||
use crate::hlir::Output;
|
||||
use crate::hlir::{Add as NativeAdd, LessThan as NativeLessThan, NativeData, NativeOp, Output};
|
||||
use crate::prelude::*;
|
||||
use candle_core::{Device, Tensor};
|
||||
use proptest::prelude::*;
|
||||
@@ -430,6 +430,34 @@ fn fuzz_test_genome_execution() {
|
||||
|
||||
// --- Consumed-input semantics tests ---
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Add inputs must have the same dtype")]
|
||||
fn native_add_rejects_mixed_dtypes() {
|
||||
let op = NativeAdd {
|
||||
shape: vec![2.into()],
|
||||
a_strides: vec![1.into()],
|
||||
b_strides: vec![1.into()],
|
||||
input_shapes: vec![],
|
||||
};
|
||||
let a = NativeData::F32(vec![1.0, 2.0]);
|
||||
let b = NativeData::Int(vec![1, 2]);
|
||||
op.execute(vec![&a, &b], &FxHashMap::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "LessThan inputs must have the same dtype")]
|
||||
fn native_less_than_rejects_mixed_dtypes() {
|
||||
let op = NativeLessThan {
|
||||
shape: vec![2.into()],
|
||||
a_strides: vec![1.into()],
|
||||
b_strides: vec![1.into()],
|
||||
input_shapes: vec![],
|
||||
};
|
||||
let a = NativeData::F32(vec![1.0, 2.0]);
|
||||
let b = NativeData::Int(vec![1, 2]);
|
||||
op.execute(vec![&a, &b], &FxHashMap::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_inputs_consumed_after_execute() {
|
||||
|
||||
Reference in New Issue
Block a user