mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
1 Commits
rust-examp
...
strided-in
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a4bda06d64 |
12
AGENTS.md
12
AGENTS.md
@@ -8,14 +8,4 @@ All other functionality is split into crates in the `crates/` directory. For ins
|
||||
## Testing Instructions
|
||||
- Find the CI plan in the .github/workflows folder.
|
||||
- Currently running `cargo test` in luminal_metal and luminal_cuda_lite require access to an Apple and Nvidia GPU respectively.
|
||||
- PRs must have no clippy errors and `cargo fmt` must be ran before a PR is submitted.
|
||||
|
||||
## Debugging and Correctness
|
||||
- Treat model examples as specifications of the intended architecture. Do not change model code, prompt templates, weights, or example logic to hide compiler/runtime/search bugs unless the model code is demonstrably semantically wrong.
|
||||
- When outputs are incorrect, first root-cause the failing compiler/runtime path. Prefer isolating the bad LLIR/HLIR graph, rewrite, op lowering, shape/stride assumption, layout contract, or runtime implementation that caused the mismatch.
|
||||
- Avoid narrow special-case fixes. A fix should state and enforce the general invariant it relies on, or explicitly document why the affected operation is only valid for a restricted layout/shape and ensure rewrites enforce that restriction.
|
||||
- For e-graph/search issues, assume all selectable LLIR graphs are intended to be semantically equivalent. If two selectable graphs disagree, debug the equivalence violation rather than selecting around the bad graph.
|
||||
- Add regression tests at the level where the bug occurred. Prefer tests that compare against a semantic reference such as `NativeRuntime` or a small independent reference, and use fixed seeds for any randomized search/fuzz test so failures are reproducible.
|
||||
|
||||
## Compiler Rewrite Boundary
|
||||
- All graph pattern matching and op selection must be expressed in egglog rewrites. Do not add Rust-side LLIR graph post-passes that search for op patterns, fuse kernels, select backend ops, or otherwise rewrite extracted graphs after egglog. If a backend needs a fused/specialized op, add the match and rewrite in egglog and let extraction produce that op directly.
|
||||
- PRs must have no clippy errors and `cargo fmt` must be ran before a PR is submitted.
|
||||
50
README.md
50
README.md
@@ -55,27 +55,23 @@ Luminal can run Q8 Llama 3 8B at ~80% of theoretical max performance on an H100.
|
||||
|
||||
The core of Luminal is and always will be minimal. It should be possible to understand the entire core library in an afternoon.
|
||||
|
||||
### PyTorch-native
|
||||
|
||||
Luminal directly integrates with PyTorch as a compiler backend. Simply do `torch.compile(model, backend=luminal_cuda)` to compile your PyTorch models. We also have an excellent tensor API in Rust.
|
||||
|
||||
### RISC-style architecture
|
||||
|
||||
Everything in Luminal boils down to 15 primitive ops:
|
||||
Everything in Luminal boils down to 14 primitive ops:
|
||||
|
||||
- Unary - `Log2, Exp2, Sin, Sqrt, Recip`
|
||||
- Binary - `Add, Mul, Mod, LessThan`
|
||||
- Other - `SumReduce, MaxReduce, Iota, Gather, Scatter, Cast`
|
||||
- Other - `SumReduce, MaxReduce, Iota, Gather, Cast`
|
||||
|
||||
These ops are enough to support transformers, convnets, and nearly every popular model in the world.
|
||||
These ops are enough to support transformers, convnets, and nearly every popular model.
|
||||
|
||||
### Search
|
||||
|
||||
The best heuristic is no heuristic. Luminal tries to search every possible decision to give the compiler the flexibility to discover complex optimizations. This allows us to automatically discover Flash Attention and other similarly complex optimizations without relying on hand-written operations or heuristics. It also allows us to stay extremely small and simple long into the future and beat the performance of far larger frameworks.
|
||||
The best heuristic is no heuristic. We try to search every possible decision to give the compiler the most flexibility to discover complex optimizations. This allows us to automatically derive Flash Attention and other similarly complex rewrites. It also allows us to stay extremely small long into the future and beat the performance of far larger frameworks with tons of handwritten kernels.
|
||||
|
||||
### Native
|
||||
|
||||
The current ML ecosystem is too fragmented, and the solution isn't another layer of abstraction. Luminal is written in rust, and interacts directly with the accelerator APIs (CUDA, Metal, etc.). No indirections or abstractions, compatability layers, docker containers, or virtual environments. Just a statically-linked rust crate.
|
||||
The current ML ecosystem is too fragmented, and the solution isn't another layer of abstraction. Luminal is written in rust, and interacts directly with the CUDA / Metal APIs. No indirections or abstractions, docker containers, or virtual environments. Just a statically-linked rust crate.
|
||||
|
||||
### Validated against Pytorch
|
||||
|
||||
@@ -89,45 +85,39 @@ Most deep learning libraries are eager-first, meaning each op call directly oper
|
||||
|
||||
However, this isn't great for performance. What makes sense for a developer doesn't work well for the machine, in the same way that no one writes assembly by hand. Most libraries try to fix this problem by tacking on operator fusion or JIT compilation to try to change the compilation flow to something better for the machine. Turns out this is [super](https://docs.pytorch.org/docs/stable/torch.compiler_dynamo_overview.html) [difficult](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) [even](https://pytorch.org/docs/stable/jit.html) [for](https://pytorch.org/docs/stable/fx.html#torch.fx.symbolic_trace) Pytorch!
|
||||
|
||||
### What about XLA?
|
||||
|
||||
XLA, torch.compile, TVM, and other traditional compiler stacks suffer from complexity explosion. They are made up of a very large set of destructive (one-direction) rewrite rules that lower and optimize a graph from a high-level representation to low-level machine code. But since these rules are destructive, they are required to only fire when it's certian that there's a performance benefit. This leads to the rules becoming very complex, special-cased, and numerous. Once additional hardware backends, model architectures, and new dtypes get thrown in, they suffer from the weight of their complexity and often produce very suboptimal code, requiring DSLs like Pallas or Triton to regain performance.
|
||||
|
||||
### Compile everything
|
||||
|
||||
A core tenet of Luminal is ahead-of-time compilation. Whenever possible, push everything to compile time and leave nothing to run time. Luminal takes an approach more similar to [XLA](https://www.tensorflow.org/xla), and [tinygrad](https://github.com/tinygrad/tinygrad). Everything's static here. When you write out an expression like `x + y`, no actual computation happens. The operation is recorded to a directed acyclic computation graph for execution later. Only once `graph.execute()` is ran does the computation happen. _But isn't that just lazy execution?_ Yes it is! But in luminal **everything is done this way**. All neural networks are built up as a static computation graphs, compiled, and executed later.
|
||||
|
||||
### First-class dynamism
|
||||
|
||||
A fully-static world would be nice, but we live in a world of nessecary dynamism. So we model dynamic shapes natively, as symbolic dimensions. Luminal supports arbitrary symbolic dimensions, including complex expressions, to give us shapes like `(s, 4096)`, `(b, h, w + 3)`, etc. This rich representation gives the compiler full visibility into shapes and lets it still do aggressive specialization.
|
||||
A core tenet of Luminal is ahead-of-time compilation. Whenever possible, push everything to compile time and leave nothing to run time. Luminal takes an approach more similar to [XLA](https://www.tensorflow.org/xla), and [tinygrad](https://github.com/tinygrad/tinygrad). Everything's static here. When you write out an expression like `x + y`, no actual computation happens. The operation is recorded to a directed acyclic computation graph for execution later. Only once `graph.execute()` is ran does the computation happen. _But isn't that just lazy execution?_ Yes it is! But in luminal **everything is done this way**. All neural networks are built up as one or a few static computation graphs, compiled, and executed later.
|
||||
|
||||
**But why?**
|
||||
|
||||
A consequence of this is that the actual computation that gets ran can be radically different than the code that was written. Since we have an entire neural network fully represented in a compute graph, Luminal has global knowledge. This means we can push most ML complexity to the compiler. For instance, devices, datatypes, and even autograd is modeled ahead of time and optimized by the compiler!
|
||||
A consequence of this is that the actual computation that gets ran can be radically different than the code that was written. Since we have an entire neural network fully represented in a compute graph, our compilers have global knowledge. This means we can push most ML complexity to the compilers. For instance, devices, datatypes, and execution schedules are all handled by compliers. Even autograd is handled by a compiler!
|
||||
|
||||
Now we can do:
|
||||
|
||||
- Aggressive kernel fusion
|
||||
- Shape-specific kernels compiled at runtime
|
||||
- Low-precision dtypes (mxfp4, nvfp4, fp8, etc.)
|
||||
- Complex mutli-device parallelism topologies, searched ahead-of-time
|
||||
- Networks can be written in generic code, but compiled and ran fast on hyper-specific architectures
|
||||
- Devices and Dtypes are handled through compilers (just run the CUDA compiler to convert the graph to use CUDA kernels, then the fp16 compiler to convert to half-precision kernels)
|
||||
- Networks can be written in generic code, but compiled and ran fast on hyper-specific architectures (try writing a PyTorch network that works with both TF32 dtypes and TPUs; get ready for if statement hell...)
|
||||
|
||||
## Where are we?
|
||||
|
||||
- Native PyTorch support
|
||||
- Many kernel libraries supported in the search space (FlashInfer, cuBLASLt, etc.)
|
||||
- Many models implemented in our Rust tensor API in `examples/`.
|
||||
- Search is partially merged. We are between 1.0 and 2.0 (search), which will be completed within the next month or so.
|
||||
- Metal and Cuda are supported for running models on Macs and Nvidia GPUs respectively, in both full and half precision.
|
||||
- Full training support with graph-based autograd.
|
||||
- Llama 3, Phi 3, Whisper and Yolo v8 are implemented in `examples/`. See instructions above for running.
|
||||
- We have a small library of NN modules in `luminal_nn`, including transformers.
|
||||
- A significant amount of high-level ops are implemented in `hl_ops`. We are aiming to match the most used ~80% of the pytorch api.
|
||||
|
||||
Some things on the roadmap:
|
||||
|
||||
- More fine-grained dialects supporting thread- and warp-level intrinsics like TMA and tcgen.05
|
||||
- ROCm backend
|
||||
- More public infernce accelerator backends (coming very soon...)
|
||||
- Public benchmarking suite
|
||||
- Automatically searched model parallelism (TP, PP, EPS, EPR, SP, etc.)
|
||||
- Expand the search space to utilize Tensor Cores more flexibly
|
||||
- Bring cuda to parity with Metal
|
||||
- Add Blackwell intrinsics, such as TMEM and TMA
|
||||
- Build a ROCm backend
|
||||
- Build benchmarking suite to test against other libs
|
||||
- Distributed data, pipeline and tensor parallel.
|
||||
- Beat PT 2.0 perf on LLM inference _and_ training
|
||||
- Write compiler for quantum photonic retro encabulator
|
||||
- Build dyson swarm
|
||||
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
import re
|
||||
|
||||
ANSI_ESCAPE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")
|
||||
|
||||
EXPECTED_OUTPUT = {
|
||||
"gemma4_moe": [
|
||||
"city of romance, art and culture",
|
||||
],
|
||||
"whisper": [
|
||||
"ask not what your country can do for you",
|
||||
],
|
||||
}
|
||||
|
||||
EXPECTED_CONCEPTS = {
|
||||
"llama": [
|
||||
["layers"],
|
||||
["neurons", "nodes"],
|
||||
["learn", "learning", "adapt"],
|
||||
["data", "patterns", "features"],
|
||||
],
|
||||
"gemma": [
|
||||
["neural network", "neural networks"],
|
||||
["nodes", "neurons"],
|
||||
["layers"],
|
||||
["weights"],
|
||||
["training", "learn", "learns"],
|
||||
],
|
||||
"qwen": [
|
||||
["neural network", "neural networks"],
|
||||
["computational model", "computational system"],
|
||||
["brain"],
|
||||
["layers"],
|
||||
["neurons", "nodes"],
|
||||
["learn", "learning", "training"],
|
||||
],
|
||||
"qwen3_moe": [
|
||||
["capital"],
|
||||
["france"],
|
||||
["paris"],
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def normalize_output(output: str) -> str:
|
||||
output = ANSI_ESCAPE.sub("", output)
|
||||
output = output.replace("\r", "\n")
|
||||
return re.sub(r"\s+", " ", output).casefold()
|
||||
|
||||
|
||||
def validate_output(example: str, output: str):
|
||||
normalized_output = normalize_output(output)
|
||||
|
||||
expected_concepts = EXPECTED_CONCEPTS.get(example)
|
||||
if expected_concepts is not None:
|
||||
missing = [
|
||||
concept_group
|
||||
for concept_group in expected_concepts
|
||||
if not any(normalize_output(term) in normalized_output for term in concept_group)
|
||||
]
|
||||
if missing:
|
||||
expected = "\n - ".join(" / ".join(group) for group in expected_concepts)
|
||||
missing_terms = "\n - ".join(" / ".join(group) for group in missing)
|
||||
raise AssertionError(
|
||||
f"Output check failed for {example!r}.\n"
|
||||
f"Expected concept groups:\n - {expected}\n"
|
||||
f"Missing concept groups:\n - {missing_terms}"
|
||||
)
|
||||
|
||||
expected = ", ".join(" / ".join(group) for group in expected_concepts)
|
||||
print(f"\nOutput check passed for {example!r}: found concepts {expected}")
|
||||
return
|
||||
|
||||
expected_phrases = EXPECTED_OUTPUT.get(example)
|
||||
if expected_phrases is None:
|
||||
raise ValueError(f"No expected output phrases configured for example {example!r}")
|
||||
|
||||
for phrase in expected_phrases:
|
||||
if normalize_output(phrase) in normalized_output:
|
||||
print(f"\nOutput check passed for {example!r}: found {phrase!r}")
|
||||
return
|
||||
|
||||
expected = "\n - ".join(expected_phrases)
|
||||
raise AssertionError(
|
||||
f"Output check failed for {example!r}. Expected one of:\n - {expected}"
|
||||
)
|
||||
@@ -1,46 +0,0 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from example_output import validate_output
|
||||
|
||||
def run_and_capture(command: list[str], *, cwd: str, env: dict[str, str]) -> str:
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert process.stdout is not None
|
||||
|
||||
chunks = []
|
||||
while True:
|
||||
chunk = process.stdout.read1(4096)
|
||||
if not chunk:
|
||||
break
|
||||
sys.stdout.buffer.write(chunk)
|
||||
sys.stdout.buffer.flush()
|
||||
chunks.append(chunk)
|
||||
|
||||
return_code = process.wait()
|
||||
output = b"".join(chunks).decode("utf-8", errors="replace")
|
||||
if return_code:
|
||||
raise subprocess.CalledProcessError(return_code, command, output=output)
|
||||
return output
|
||||
|
||||
|
||||
def main():
|
||||
repo_root = os.environ.get("GITHUB_WORKSPACE", os.getcwd())
|
||||
output = run_and_capture(
|
||||
["cargo", "run", "--release", "-p", "qwen", "--features", "metal"],
|
||||
cwd=repo_root,
|
||||
env=os.environ.copy(),
|
||||
)
|
||||
if "TTFT:" not in output or "TPOT:" not in output:
|
||||
raise AssertionError("qwen Metal example did not complete generation")
|
||||
validate_output("qwen", output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
@@ -20,8 +21,28 @@ hf_cache = modal.Volume.from_name(
|
||||
|
||||
WORKDIR = "/workspace/luminal"
|
||||
|
||||
EXAMPLE_CARGO_ARGS = {
|
||||
"qwen": ["--features", "cuda"],
|
||||
ANSI_ESCAPE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")
|
||||
|
||||
EXPECTED_OUTPUT = {
|
||||
"llama": [
|
||||
"complex system modeled after the structure and function of the human brain",
|
||||
],
|
||||
"gemma": [
|
||||
"recognize pictures of cats",
|
||||
"little detectives looking for specific features",
|
||||
],
|
||||
"qwen": [
|
||||
"computational model inspired by the structure and function of the human brain",
|
||||
],
|
||||
"qwen3_moe": [
|
||||
"The capital of France is Paris",
|
||||
],
|
||||
"gemma4_moe": [
|
||||
"city of romance, art and culture",
|
||||
],
|
||||
"whisper": [
|
||||
"ask not what your country can do for you",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@@ -51,6 +72,28 @@ def run_and_capture(command: list[str], *, cwd: str, env: dict[str, str]) -> str
|
||||
return output
|
||||
|
||||
|
||||
def normalize_output(output: str) -> str:
|
||||
output = ANSI_ESCAPE.sub("", output)
|
||||
output = output.replace("\r", "\n")
|
||||
return re.sub(r"\s+", " ", output).casefold()
|
||||
|
||||
|
||||
def validate_output(example: str, output: str):
|
||||
expected_phrases = EXPECTED_OUTPUT.get(example)
|
||||
if expected_phrases is None:
|
||||
raise ValueError(f"No expected output phrases configured for example {example!r}")
|
||||
|
||||
normalized_output = normalize_output(output)
|
||||
for phrase in expected_phrases:
|
||||
if normalize_output(phrase) in normalized_output:
|
||||
print(f"\nOutput check passed for {example!r}: found {phrase!r}")
|
||||
return
|
||||
|
||||
expected = "\n - ".join(expected_phrases)
|
||||
raise AssertionError(
|
||||
f"Output check failed for {example!r}. Expected one of:\n - {expected}"
|
||||
)
|
||||
|
||||
cuda_image = (
|
||||
modal.Image.from_registry(
|
||||
"nvcr.io/nvidia/pytorch:25.03-py3"
|
||||
@@ -80,8 +123,6 @@ cuda_image = (
|
||||
def run_example(example: str):
|
||||
"""Build and run a luminal example on a Modal GPU."""
|
||||
subprocess.run(["nvidia-smi"], check=True)
|
||||
sys.path.insert(0, f"{WORKDIR}/ci")
|
||||
from example_output import validate_output
|
||||
|
||||
run_env = {
|
||||
**os.environ,
|
||||
@@ -89,7 +130,7 @@ def run_example(example: str):
|
||||
"HF_HOME": HF_CACHE_PATH,
|
||||
}
|
||||
output = run_and_capture(
|
||||
["cargo", "run", "--release", *EXAMPLE_CARGO_ARGS.get(example, [])],
|
||||
["cargo", "run", "--release"],
|
||||
cwd=f"{WORKDIR}/examples/{example}",
|
||||
env=run_env,
|
||||
)
|
||||
|
||||
@@ -29,7 +29,6 @@ colorize = "*"
|
||||
|
||||
[dev-dependencies]
|
||||
candle-core = { version = "0.9.2", features = ["cuda"] }
|
||||
luminal_nn = { path = "../luminal_nn" }
|
||||
proptest = "1.9.0"
|
||||
rand = "0.9.2"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
||||
198
crates/luminal_cuda_lite/src/host/compute_attn_mask.rs
Normal file
198
crates/luminal_cuda_lite/src/host/compute_attn_mask.rs
Normal file
@@ -0,0 +1,198 @@
|
||||
//! ComputeAttnMask — fused op that computes the paged attention mask from indptrs.
|
||||
//!
|
||||
//! This op exists so the indptr tensors (qo_indptr, kv_indptr) are visible in the
|
||||
//! same e-graph chunk as the attention pattern, letting the FlashInfer egglog rule
|
||||
//! capture them directly.
|
||||
//!
|
||||
//! Inputs (3): q_pos (s,) Int, qo_indptr (r,) Int, kv_indptr (r,) Int.
|
||||
//! Output: mask (s, c) F32 where mask[i, j] = 0.0 (attend) or -1e10 (block).
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{EXPRESSION, OP_KIND},
|
||||
extract_expr,
|
||||
},
|
||||
op::{EgglogOp, HLIROp, LLIROp},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
cudarc::driver::{CudaStream, result},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
};
|
||||
|
||||
/// Computes the paged attention mask from indptr arrays.
|
||||
///
|
||||
/// The mask encodes both request-membership and causality:
|
||||
/// `mask[i, j] = 0.0` if query `i` and context `j` belong to the same request AND
|
||||
/// context `j`'s local position is `<= q_pos[i]`; `-1e10` otherwise.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ComputeAttnMask {
|
||||
pub s_dim: Expression,
|
||||
pub c_dim: Expression,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ComputeAttnMask {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "ComputeAttnMask(s={}, c={})", self.s_dim, self.c_dim)
|
||||
}
|
||||
}
|
||||
|
||||
impl HLIROp for ComputeAttnMask {
|
||||
fn to_egglog(&self, inputs: &[(NodeIndex, String)]) -> String {
|
||||
format!(
|
||||
"(Op (ComputeAttnMask {} {}) (ICons {} (ICons {} (ICons {} (INil)))))",
|
||||
self.s_dim.to_egglog(),
|
||||
self.c_dim.to_egglog(),
|
||||
inputs[0].1, // q_pos
|
||||
inputs[1].1, // qo_indptr
|
||||
inputs[2].1, // kv_indptr
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for ComputeAttnMask {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"ComputeAttnMask",
|
||||
&[("s_dim", EXPRESSION), ("c_dim", EXPRESSION)],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// No rewrites — inserted directly by model code.
|
||||
vec![]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let s_dim = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
|
||||
let c_dim = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
|
||||
let op = Self { s_dim, c_dim };
|
||||
let llir_op = LLIROp::new::<dyn HostOp>(Box::new(op) as Box<dyn HostOp>);
|
||||
(llir_op, input_enodes)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for ComputeAttnMask {
|
||||
fn execute(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
if inputs.len() < 3 {
|
||||
anyhow::bail!(
|
||||
"ComputeAttnMask expects 3 inputs (q_pos, qo_indptr, kv_indptr), got {}",
|
||||
inputs.len()
|
||||
);
|
||||
}
|
||||
|
||||
let s = self
|
||||
.s_dim
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("ComputeAttnMask s_dim unresolved"))?;
|
||||
let c = self
|
||||
.c_dim
|
||||
.exec(dyn_map)
|
||||
.ok_or_else(|| anyhow::anyhow!("ComputeAttnMask c_dim unresolved"))?;
|
||||
let r = *dyn_map
|
||||
.get(&'r')
|
||||
.ok_or_else(|| anyhow::anyhow!("ComputeAttnMask requires dynamic dim 'r'"))?;
|
||||
|
||||
let get_buf = |name: &str, node: NodeIndex| -> anyhow::Result<DeviceBuffer> {
|
||||
buffers.get(&node).copied().ok_or_else(|| {
|
||||
anyhow::anyhow!("ComputeAttnMask missing {name} buffer for {node:?}")
|
||||
})
|
||||
};
|
||||
|
||||
let q_pos_buf = get_buf("q_pos", inputs[0])?;
|
||||
let qo_indptr_buf = get_buf("qo_indptr", inputs[1])?;
|
||||
let kv_indptr_buf = get_buf("kv_indptr", inputs[2])?;
|
||||
let out_buf = get_buf("output", self_node)?;
|
||||
|
||||
let q_pos = dtoh_i32(stream, q_pos_buf.ptr(), s)?;
|
||||
let qo_indptr = dtoh_i32(stream, qo_indptr_buf.ptr(), r)?;
|
||||
let kv_indptr = dtoh_i32(stream, kv_indptr_buf.ptr(), r)?;
|
||||
|
||||
let mut mask = vec![-1e10f32; s * c];
|
||||
for i in 0..s {
|
||||
let q_req = indptr_to_request(&qo_indptr, i as i32);
|
||||
for j in 0..c {
|
||||
let c_req = indptr_to_request(&kv_indptr, j as i32);
|
||||
if q_req == c_req && q_req >= 0 {
|
||||
let c_local = j as i32 - kv_indptr[c_req as usize];
|
||||
if c_local <= q_pos[i] {
|
||||
mask[i * c + j] = 0.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mask_bytes =
|
||||
unsafe { std::slice::from_raw_parts(mask.as_ptr() as *const u8, mask.len() * 4) };
|
||||
unsafe {
|
||||
let res = cudarc::driver::sys::cuMemcpyHtoD_v2(
|
||||
out_buf.ptr(),
|
||||
mask_bytes.as_ptr() as *const std::ffi::c_void,
|
||||
mask_bytes.len(),
|
||||
);
|
||||
if res != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
|
||||
anyhow::bail!("ComputeAttnMask cuMemcpyHtoD failed: {res:?}");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.s_dim * self.c_dim
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn stats_name(&self) -> Option<&'static str> {
|
||||
Some("ComputeAttnMask")
|
||||
}
|
||||
}
|
||||
|
||||
fn dtoh_i32(stream: &Arc<CudaStream>, dev_ptr: u64, len: usize) -> anyhow::Result<Vec<i32>> {
|
||||
let mut host = vec![0u8; len * std::mem::size_of::<i32>()];
|
||||
unsafe {
|
||||
result::memcpy_dtoh_async(&mut host, dev_ptr, stream.cu_stream())?;
|
||||
}
|
||||
stream.synchronize()?;
|
||||
let v = unsafe {
|
||||
let mut bytes = std::mem::ManuallyDrop::new(host);
|
||||
Vec::from_raw_parts(bytes.as_mut_ptr() as *mut i32, len, len)
|
||||
};
|
||||
Ok(v)
|
||||
}
|
||||
|
||||
/// Given an indptr array `[0, a, b, ...]`, find which segment `idx` belongs to.
|
||||
/// Returns `count(indptr[i] <= idx) - 1`.
|
||||
fn indptr_to_request(indptr: &[i32], idx: i32) -> i32 {
|
||||
indptr.iter().filter(|&&v| v <= idx).count() as i32 - 1
|
||||
}
|
||||
@@ -8,436 +8,6 @@
|
||||
; describes as descriptor A = logical B, descriptor B = logical A, transa=T,
|
||||
; transb=N.
|
||||
|
||||
(rule
|
||||
(
|
||||
; Match the scaled FP8 linear form directly before the unscaled FP8
|
||||
; matmul rewrite can hide the quantize/dequant scale structure.
|
||||
(= ?scaled_activation (Op (Mul
|
||||
?activation_shape
|
||||
?raw_activation_strides
|
||||
?recip_activation_strides
|
||||
?activation_out_strides)
|
||||
(ICons ?raw_activation (ICons ?recip_input_scale (INil)))))
|
||||
(= ?recip_input_scale (Op (Recip
|
||||
?activation_shape
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?recip_out_strides)
|
||||
(ICons ?input_scale (INil))))
|
||||
(= ?a (Op (Cast ?a_size ?a_dtype) (ICons ?scaled_activation (INil))))
|
||||
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
(= ?scale_product (Op (Mul (ENil) (ENil) (ENil) (ENil))
|
||||
(ICons ?input_scale (ICons ?weight_scale (INil)))))
|
||||
(= ?scaled (Op (Mul
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_out_strides)
|
||||
(ICons ?cast (ICons ?scale_product (INil)))))
|
||||
(= ?cast_strides ?scaled_out_strides)
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?b_dtype (dtype ?b))
|
||||
(cublaslt_fp8_f32_output_pair ?a_dtype ?b_dtype)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt_scaled
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?b_dtype ?a_dtype (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (ICons ?weight_scale (ICons ?input_scale (INil)))))))
|
||||
(union ?scaled ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt scaled fp8 row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?scaled_activation (Op (Mul
|
||||
?activation_shape
|
||||
?raw_activation_strides
|
||||
?recip_activation_strides
|
||||
?activation_out_strides)
|
||||
(ICons ?raw_activation (ICons ?recip_input_scale (INil)))))
|
||||
(= ?recip_input_scale (Op (Recip
|
||||
?activation_shape
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?recip_out_strides)
|
||||
(ICons ?input_scale (INil))))
|
||||
(= ?a (Op (Cast ?a_size ?a_dtype) (ICons ?scaled_activation (INil))))
|
||||
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
(= ?scale_product (Op (Mul (ENil) (ENil) (ENil) (ENil))
|
||||
(ICons ?input_scale (ICons ?weight_scale (INil)))))
|
||||
(= ?scaled (Op (Mul
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_out_strides)
|
||||
(ICons ?cast (ICons ?scale_product (INil)))))
|
||||
(= ?cast_strides ?scaled_out_strides)
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?b_dtype (dtype ?b))
|
||||
(cublaslt_fp8_f32_output_pair ?a_dtype ?b_dtype)
|
||||
(= ?scaled (Op (cublaslt_scaled
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?b_dtype ?a_dtype (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (ICons ?weight_scale (ICons ?input_scale (INil)))))))
|
||||
(= ?cast (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?b_dtype ?a_dtype (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
)
|
||||
(
|
||||
(delete (Op (Mul
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ENil)))
|
||||
?scaled_out_strides)
|
||||
(ICons ?cast (ICons ?scale_product (INil)))))
|
||||
(delete (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
(MNum 1)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
(MNum 0)
|
||||
?b_dtype ?a_dtype (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
)
|
||||
:ruleset cleanup
|
||||
:name "delete raw fp8 path when scaled cublaslt covers direct output scale"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
; Fusion growth can make the live path consume a raw FP8 cuBLASLt
|
||||
; candidate through an internal CudaBinaryElementwise scale multiply,
|
||||
; instead of the original HLIR output-scale Mul. The scalar scale
|
||||
; product is tensor-wide, so the two scalar factors can be passed as
|
||||
; cuBLASLt A/B scale inputs and the internal multiply can be bypassed.
|
||||
(= ?raw_gemm (Op (cublaslt
|
||||
?cm ?cn ?ck
|
||||
?cta ?ctb
|
||||
?cao ?cbo ?cco ?cdo
|
||||
?clda ?cldb ?cldc ?cldd
|
||||
?cbc ?csa ?csb ?csc ?csd
|
||||
?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(cublaslt_fp8_f32_output_pair ?cadt ?cbdt)
|
||||
(= ?ccdt (F32))
|
||||
(= ?cddt (F32))
|
||||
(= ?cbeta 0.0)
|
||||
(= ?cepilogue "DEFAULT")
|
||||
|
||||
(= ?fs_cast (Op (FusionStart
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(F32))
|
||||
(ICons ?raw_gemm (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?out_m (ECons ?out_n (ENil))))
|
||||
(= ?scale_strides (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
|
||||
(= ?fs_a_scale (Op (FusionStart (ENil) (ENil) (F32))
|
||||
(ICons ?a_scale (INil))))
|
||||
(= ?fs_b_scale (Op (FusionStart (ENil) (ENil) (F32))
|
||||
(ICons ?b_scale (INil))))
|
||||
(= ?scale_product_inner (Op (CudaBinaryElementwise
|
||||
"Mul"
|
||||
(ENil)
|
||||
(ENil)
|
||||
(ENil)
|
||||
(ENil)
|
||||
(F32))
|
||||
(ICons ?fs_a_scale (ICons ?fs_b_scale (INil)))))
|
||||
(= ?scale_product (Op (FusionEnd (ENil) (ENil) (F32))
|
||||
(ICons ?scale_product_inner (INil))))
|
||||
(= ?fs_scale (Op (FusionStart
|
||||
?out_shape
|
||||
?scale_strides
|
||||
(F32))
|
||||
(ICons ?scale_product (INil))))
|
||||
(= ?fused_scale (Op (CudaBinaryElementwise
|
||||
"Mul"
|
||||
?out_shape
|
||||
?cast_strides
|
||||
?scale_strides
|
||||
?scaled_out_strides
|
||||
(F32))
|
||||
(ICons ?fs_cast (ICons ?fs_scale (INil)))))
|
||||
(= ?cast_strides ?scaled_out_strides)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt_scaled
|
||||
?cm ?cn ?ck
|
||||
?cta ?ctb
|
||||
?cao ?cbo ?cco ?cdo
|
||||
?clda ?cldb ?cldc ?cldd
|
||||
?cbc ?csa ?csb ?csc ?csd
|
||||
?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?a_scale (ICons ?b_scale (INil)))))))
|
||||
(let ?fs_sgemm (Op (FusionStart ?out_shape ?scaled_out_strides (F32))
|
||||
(ICons ?sgemm (INil))))
|
||||
(union ?fused_scale ?fs_sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
(set (dtype ?fs_sgemm) (F32))
|
||||
)
|
||||
:ruleset fusion_grow
|
||||
:name "cublaslt scaled fp8 fused output-scale f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?raw_gemm (Op (cublaslt
|
||||
?cm ?cn ?ck
|
||||
?cta ?ctb
|
||||
?cao ?cbo ?cco ?cdo
|
||||
?clda ?cldb ?cldc ?cldd
|
||||
?cbc ?csa ?csb ?csc ?csd
|
||||
?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(cublaslt_fp8_f32_output_pair ?cadt ?cbdt)
|
||||
(= ?ccdt (F32))
|
||||
(= ?cddt (F32))
|
||||
(= ?cbeta 0.0)
|
||||
(= ?cepilogue "DEFAULT")
|
||||
|
||||
(= ?fs_cast (Op (FusionStart
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(F32))
|
||||
(ICons ?raw_gemm (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?out_m (ECons ?out_n (ENil))))
|
||||
(= ?scale_strides (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
|
||||
(= ?fs_a_scale (Op (FusionStart (ENil) (ENil) (F32))
|
||||
(ICons ?a_scale (INil))))
|
||||
(= ?fs_b_scale (Op (FusionStart (ENil) (ENil) (F32))
|
||||
(ICons ?b_scale (INil))))
|
||||
(= ?scale_product_inner (Op (CudaBinaryElementwise
|
||||
"Mul"
|
||||
(ENil)
|
||||
(ENil)
|
||||
(ENil)
|
||||
(ENil)
|
||||
(F32))
|
||||
(ICons ?fs_a_scale (ICons ?fs_b_scale (INil)))))
|
||||
(= ?scale_product (Op (FusionEnd (ENil) (ENil) (F32))
|
||||
(ICons ?scale_product_inner (INil))))
|
||||
(= ?fs_scale (Op (FusionStart
|
||||
?out_shape
|
||||
?scale_strides
|
||||
(F32))
|
||||
(ICons ?scale_product (INil))))
|
||||
(= ?fused_scale (Op (CudaBinaryElementwise
|
||||
"Mul"
|
||||
?out_shape
|
||||
?cast_strides
|
||||
?scale_strides
|
||||
?scaled_out_strides
|
||||
(F32))
|
||||
(ICons ?fs_cast (ICons ?fs_scale (INil)))))
|
||||
(= ?cast_strides ?scaled_out_strides)
|
||||
|
||||
(= ?sgemm (Op (cublaslt_scaled
|
||||
?cm ?cn ?ck
|
||||
?cta ?ctb
|
||||
?cao ?cbo ?cco ?cdo
|
||||
?clda ?cldb ?cldc ?cldd
|
||||
?cbc ?csa ?csb ?csc ?csd
|
||||
?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue)
|
||||
(ICons ?a (ICons ?b (ICons ?a_scale (ICons ?b_scale (INil)))))))
|
||||
(= ?fused_scale (Op (FusionStart ?out_shape ?scaled_out_strides (F32))
|
||||
(ICons ?sgemm (INil))))
|
||||
)
|
||||
(
|
||||
(delete (Op (cublaslt
|
||||
?cm ?cn ?ck
|
||||
?cta ?ctb
|
||||
?cao ?cbo ?cco ?cdo
|
||||
?clda ?cldb ?cldc ?cldd
|
||||
?cbc ?csa ?csb ?csc ?csd
|
||||
?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(delete (Op (CudaBinaryElementwise
|
||||
"Mul"
|
||||
?out_shape
|
||||
?cast_strides
|
||||
?scale_strides
|
||||
?scaled_out_strides
|
||||
(F32))
|
||||
(ICons ?fs_cast (ICons ?fs_scale (INil)))))
|
||||
)
|
||||
:ruleset cleanup
|
||||
:name "delete raw fp8 path when scaled cublaslt covers fused output scale"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
; Batched form of the scaled FP8 linear rewrite. The scale operands are
|
||||
; scalar tensors expanded across the last three output/activation axes.
|
||||
(= ?scaled_activation (Op (Mul
|
||||
?activation_shape
|
||||
?raw_activation_strides
|
||||
?recip_activation_strides
|
||||
?activation_out_strides)
|
||||
(ICons ?raw_activation (ICons ?recip_input_scale (INil)))))
|
||||
(= ?recip_input_scale (Op (Recip
|
||||
?activation_shape
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?recip_out_strides)
|
||||
(ICons ?input_scale (INil))))
|
||||
(= ?a (Op (Cast ?a_size ?a_dtype) (ICons ?scaled_activation (INil))))
|
||||
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
(= ?scale_product (Op (Mul (ENil) (ENil) (ENil) (ENil))
|
||||
(ICons ?input_scale (ICons ?weight_scale (INil)))))
|
||||
(= ?scaled (Op (Mul
|
||||
?out_shape
|
||||
?cast_strides
|
||||
(ECons (MNum 0) (ECons (MNum 0) (ECons (MNum 0) (ENil))))
|
||||
?scaled_out_strides)
|
||||
(ICons ?cast (ICons ?scale_product (INil)))))
|
||||
(= ?cast_strides ?scaled_out_strides)
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= ?b_dtype (dtype ?b))
|
||||
(cublaslt_fp8_f32_output_pair ?a_dtype ?b_dtype)
|
||||
)
|
||||
(
|
||||
(let ?sgemm (Op (cublaslt_scaled
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
"COL" "COL" "COL" "COL"
|
||||
?b_n_stride
|
||||
?a_m_stride
|
||||
?n
|
||||
?n
|
||||
?batch
|
||||
?b_batch_stride
|
||||
?a_batch_stride
|
||||
(MMul ?m ?n)
|
||||
(MMul ?m ?n)
|
||||
?b_dtype ?a_dtype (F32) (F32) "32F" "F32" 1.0 0.0 "DEFAULT")
|
||||
(ICons ?b (ICons ?a (ICons ?weight_scale (ICons ?input_scale (INil)))))))
|
||||
(union ?scaled ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublaslt scaled fp8 batched row-major x column-major f32 output"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
@@ -69,8 +69,6 @@ pub struct CuBlasLt {
|
||||
alpha: f64,
|
||||
beta: f64,
|
||||
epilogue: cublasLtEpilogue_t,
|
||||
a_scale_input: bool,
|
||||
b_scale_input: bool,
|
||||
cublaslt: OnceLock<Arc<CudaBlasLT>>,
|
||||
}
|
||||
|
||||
@@ -105,62 +103,52 @@ impl Default for CuBlasLt {
|
||||
alpha: 1.0,
|
||||
beta: 0.0,
|
||||
epilogue: cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
|
||||
a_scale_input: false,
|
||||
b_scale_input: false,
|
||||
cublaslt: OnceLock::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct CuBlasLtScaled;
|
||||
|
||||
fn cublaslt_sort(name: &'static str) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
name,
|
||||
&[
|
||||
("m", EXPRESSION),
|
||||
("n", EXPRESSION),
|
||||
("k", EXPRESSION),
|
||||
("a_layout", STRING),
|
||||
("b_layout", STRING),
|
||||
("a_order", STRING),
|
||||
("b_order", STRING),
|
||||
("c_order", STRING),
|
||||
("d_order", STRING),
|
||||
("lda", EXPRESSION),
|
||||
("ldb", EXPRESSION),
|
||||
("ldc", EXPRESSION),
|
||||
("ldd", EXPRESSION),
|
||||
("batch_count", EXPRESSION),
|
||||
("stride_a", EXPRESSION),
|
||||
("stride_b", EXPRESSION),
|
||||
("stride_c", EXPRESSION),
|
||||
("stride_d", EXPRESSION),
|
||||
("a_dtype", DTYPE),
|
||||
("b_dtype", DTYPE),
|
||||
("c_dtype", DTYPE),
|
||||
("d_dtype", DTYPE),
|
||||
("compute_type", STRING),
|
||||
("scale_dtype", STRING),
|
||||
("alpha", F64),
|
||||
("beta", F64),
|
||||
("epilogue", STRING),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
impl EgglogOp for CuBlasLt {
|
||||
fn sort(&self) -> SortDef {
|
||||
cublaslt_sort("cublaslt")
|
||||
sort(
|
||||
OP_KIND,
|
||||
"cublaslt",
|
||||
&[
|
||||
("m", EXPRESSION),
|
||||
("n", EXPRESSION),
|
||||
("k", EXPRESSION),
|
||||
("a_layout", STRING),
|
||||
("b_layout", STRING),
|
||||
("a_order", STRING),
|
||||
("b_order", STRING),
|
||||
("c_order", STRING),
|
||||
("d_order", STRING),
|
||||
("lda", EXPRESSION),
|
||||
("ldb", EXPRESSION),
|
||||
("ldc", EXPRESSION),
|
||||
("ldd", EXPRESSION),
|
||||
("batch_count", EXPRESSION),
|
||||
("stride_a", EXPRESSION),
|
||||
("stride_b", EXPRESSION),
|
||||
("stride_c", EXPRESSION),
|
||||
("stride_d", EXPRESSION),
|
||||
("a_dtype", DTYPE),
|
||||
("b_dtype", DTYPE),
|
||||
("c_dtype", DTYPE),
|
||||
("d_dtype", DTYPE),
|
||||
("compute_type", STRING),
|
||||
("scale_dtype", STRING),
|
||||
("alpha", F64),
|
||||
("beta", F64),
|
||||
("epilogue", STRING),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
let c_input = usize::from(self.beta != 0.0);
|
||||
let bias_input = usize::from(epilogue_uses_bias(self.epilogue));
|
||||
let scale_inputs = usize::from(self.a_scale_input) + usize::from(self.b_scale_input);
|
||||
2 + c_input + bias_input + scale_inputs
|
||||
2 + c_input + bias_input
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
@@ -170,54 +158,39 @@ impl EgglogOp for CuBlasLt {
|
||||
(cublaslt_base_dtype (F32))
|
||||
(cublaslt_base_dtype (F16))
|
||||
(cublaslt_base_dtype (Bf16))
|
||||
(cublaslt_base_dtype (TF32))
|
||||
(relation cublaslt_fp8_dtype (DType))
|
||||
(cublaslt_fp8_dtype (F8E4M3))
|
||||
(cublaslt_fp8_dtype (F8E5M2))
|
||||
(relation cublaslt_fp8_f32_output_pair (DType DType))
|
||||
(cublaslt_fp8_f32_output_pair (F8E4M3) (F8E4M3))
|
||||
(cublaslt_fp8_f32_output_pair (F8E4M3) (F8E5M2))
|
||||
(cublaslt_fp8_f32_output_pair (F8E5M2) (F8E4M3))",
|
||||
(cublaslt_base_dtype (TF32))",
|
||||
),
|
||||
Rule::raw(include_str!["cublaslt_RmRm_rewrite.egg"]), // row row
|
||||
Rule::raw(include_str!["cublaslt_RmCm_rewrite.egg"]), // row col
|
||||
Rule::raw(include_str!["cublaslt_CmRm_rewrite.egg"]), // col row
|
||||
Rule::raw(include_str!["cublaslt_CmCm_rewrite.egg"]), // col col
|
||||
Rule::raw(include_str!["cublaslt_fp8_rewrite.egg"]),
|
||||
Rule::raw(include_str!["cublaslt_row_order_rewrite.egg"]),
|
||||
Rule::raw(include_str!["cublaslt_mixed_dtype_rewrite.egg"]),
|
||||
Rule::raw(include_str!["cublaslt_scale_rewrite.egg"]),
|
||||
Rule::raw(include_str!["cublaslt_beta_rewrite.egg"]),
|
||||
Rule::raw(include_str!["cublaslt_epilogue_rewrite.egg"]),
|
||||
Rule::raw(include_str!["cublaslt_row_order_rewrite.egg"]),
|
||||
// Delete the matmul-broadcast Mul eclass when the consuming Sum
|
||||
// eclass has a `cublaslt` or `KernelBatchMatMul` alternative. The
|
||||
// cuBLASLt / batched-matmul rewrite rules only union those enodes
|
||||
// into the Sum eclass after the broadcast pattern check passes,
|
||||
// so their presence is the matmul-broadcast signal — no further
|
||||
// stride-form check needed.
|
||||
//
|
||||
// Delete the HLIR `Mul` and its generic fusion-region alternative
|
||||
// from the Mul eclass. Emptying that eclass lets the empty-eclass
|
||||
// cascade prune the downstream Sum / KernelSum fallback. cuBLAS,
|
||||
// TileMatmulFullSplit, KernelBatchMatVec, and KernelBatchMatMul all
|
||||
// take original (a, b) inputs rather than the Mul eclass, so they
|
||||
// survive the cascade and remain as the matmul output alternative.
|
||||
// Delete KernelMul matmul broadcast intermediates when the Sum eclass
|
||||
// has a cublaslt or KernelBatchMatMul alternative. This prevents OOM
|
||||
// from O(m*k*n) intermediates at large seq_len. cuBLAS, TileMatmulFullSplit,
|
||||
// KernelBatchMatVec, and KernelBatchMatMul all take original inputs
|
||||
// (not the Mul eclass), so they survive the cascade.
|
||||
Rule::raw("(rule
|
||||
((= ?mul (Op (Mul ?shape ?as ?bs ?os) ?inputs))
|
||||
(= ?mul (Op (FusionEnd ?fshape ?fos ?fdt) ?finputs))
|
||||
((= ?mul (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs))
|
||||
(= (MNum 0) (nth_from_end ?as 1))
|
||||
(= (MNum 0) (nth_from_end ?bs 2))
|
||||
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (cublaslt ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?ci)))
|
||||
((delete (Op (Mul ?shape ?as ?bs ?os) ?inputs))
|
||||
(delete (Op (FusionEnd ?fshape ?fos ?fdt) ?finputs)))
|
||||
((delete (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs)))
|
||||
:ruleset cleanup
|
||||
)"),
|
||||
Rule::raw("(rule
|
||||
((= ?mul (Op (Mul ?shape ?as ?bs ?os) ?inputs))
|
||||
(= ?mul (Op (FusionEnd ?fshape ?fos ?fdt) ?finputs))
|
||||
((= ?mul (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs))
|
||||
(= (MNum 0) (nth_from_end ?as 1))
|
||||
(= (MNum 0) (nth_from_end ?bs 2))
|
||||
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (KernelBatchMatMul ?bos ?bk ?bas ?baks ?bbs ?bbks ?bouts ?bdt) ?bi)))
|
||||
((delete (Op (Mul ?shape ?as ?bs ?os) ?inputs))
|
||||
(delete (Op (FusionEnd ?fshape ?fos ?fdt) ?finputs)))
|
||||
((delete (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs)))
|
||||
:ruleset cleanup
|
||||
)"),
|
||||
]
|
||||
@@ -304,104 +277,6 @@ impl EgglogOp for CuBlasLt {
|
||||
alpha,
|
||||
beta,
|
||||
epilogue,
|
||||
a_scale_input: false,
|
||||
b_scale_input: false,
|
||||
cublaslt: OnceLock::new(),
|
||||
};
|
||||
trace!(?extracted_state);
|
||||
|
||||
let extracted = LLIROp::new::<dyn HostOp>(Box::new(extracted_state) as Box<dyn HostOp>);
|
||||
|
||||
(extracted, input_enodes)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for CuBlasLtScaled {
|
||||
fn sort(&self) -> SortDef {
|
||||
cublaslt_sort("cublaslt_scaled")
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
4
|
||||
}
|
||||
|
||||
#[allow(unused_variables)]
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let m = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
|
||||
let n = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
|
||||
let k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
|
||||
|
||||
let a_layout = parse_cublas_op(&egraph.enodes[kind_children[3]].0);
|
||||
let b_layout = parse_cublas_op(&egraph.enodes[kind_children[4]].0);
|
||||
let a_order = parse_cublaslt_order(&egraph.enodes[kind_children[5]].0);
|
||||
let b_order = parse_cublaslt_order(&egraph.enodes[kind_children[6]].0);
|
||||
let c_order = parse_cublaslt_order(&egraph.enodes[kind_children[7]].0);
|
||||
let d_order = parse_cublaslt_order(&egraph.enodes[kind_children[8]].0);
|
||||
|
||||
let lda = extract_expr(egraph, kind_children[9], expr_cache).unwrap();
|
||||
let ldb = extract_expr(egraph, kind_children[10], expr_cache).unwrap();
|
||||
let ldc = extract_expr(egraph, kind_children[11], expr_cache).unwrap();
|
||||
let ldd = extract_expr(egraph, kind_children[12], expr_cache).unwrap();
|
||||
|
||||
let batch_count = extract_expr(egraph, kind_children[13], expr_cache).unwrap();
|
||||
let stride_a = extract_expr(egraph, kind_children[14], expr_cache).unwrap();
|
||||
let stride_b = extract_expr(egraph, kind_children[15], expr_cache).unwrap();
|
||||
let stride_c = extract_expr(egraph, kind_children[16], expr_cache).unwrap();
|
||||
let stride_d = extract_expr(egraph, kind_children[17], expr_cache).unwrap();
|
||||
|
||||
let a_dtype = extract_dtype(egraph, kind_children[18]);
|
||||
let b_dtype = extract_dtype(egraph, kind_children[19]);
|
||||
let c_dtype = extract_dtype(egraph, kind_children[20]);
|
||||
let d_dtype = extract_dtype(egraph, kind_children[21]);
|
||||
let compute_type_str = &egraph.enodes[kind_children[22]].0;
|
||||
let scale_dtype_str = &egraph.enodes[kind_children[23]].0;
|
||||
let compute_type = parse_cublaslt_compute_type(compute_type_str, a_dtype);
|
||||
let scale_dtype = parse_cublaslt_scale_dtype(scale_dtype_str, a_dtype);
|
||||
let alpha = parse_cublaslt_scalar(&egraph.enodes[kind_children[24]].0);
|
||||
let beta = parse_cublaslt_scalar(&egraph.enodes[kind_children[25]].0);
|
||||
let epilogue = parse_cublaslt_epilogue(&egraph.enodes[kind_children[26]].0);
|
||||
|
||||
let extracted_state = CuBlasLt {
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
a_layout,
|
||||
b_layout,
|
||||
a_order,
|
||||
b_order,
|
||||
c_order,
|
||||
d_order,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
ldd,
|
||||
batch_count,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
stride_d,
|
||||
a_dtype,
|
||||
b_dtype,
|
||||
c_dtype,
|
||||
d_dtype,
|
||||
compute_type,
|
||||
scale_dtype,
|
||||
alpha,
|
||||
beta,
|
||||
epilogue,
|
||||
a_scale_input: true,
|
||||
b_scale_input: true,
|
||||
cublaslt: OnceLock::new(),
|
||||
};
|
||||
trace!(?extracted_state);
|
||||
@@ -645,8 +520,6 @@ struct LtMatmulPointers {
|
||||
c: u64,
|
||||
d: u64,
|
||||
bias: Option<u64>,
|
||||
a_scale: Option<u64>,
|
||||
b_scale: Option<u64>,
|
||||
}
|
||||
|
||||
struct LtRawDescriptors {
|
||||
@@ -794,12 +667,12 @@ fn run_cublaslt_matmul(
|
||||
let workspace = unsafe { stream.alloc::<u8>(spec.workspace_size)? };
|
||||
let (workspace_ptr, _workspace_guard) = workspace.device_ptr(stream);
|
||||
|
||||
let a_scale = if cuda_dtype_needs_tensorwide_scale(spec.a.dtype) && ptrs.a_scale.is_none() {
|
||||
let a_scale = if cuda_dtype_needs_tensorwide_scale(spec.a.dtype) {
|
||||
Some(stream.clone_htod(&[1.0f32])?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let b_scale = if cuda_dtype_needs_tensorwide_scale(spec.b.dtype) && ptrs.b_scale.is_none() {
|
||||
let b_scale = if cuda_dtype_needs_tensorwide_scale(spec.b.dtype) {
|
||||
Some(stream.clone_htod(&[1.0f32])?)
|
||||
} else {
|
||||
None
|
||||
@@ -855,17 +728,13 @@ fn run_cublaslt_matmul(
|
||||
}
|
||||
}
|
||||
|
||||
let (a_scale_ptr, _a_scale_guard) = if let Some(ptr) = ptrs.a_scale {
|
||||
(Some(ptr), None)
|
||||
} else if let Some(scale) = &a_scale {
|
||||
let (a_scale_ptr, _a_scale_guard) = if let Some(scale) = &a_scale {
|
||||
let (ptr, guard) = scale.device_ptr(stream);
|
||||
(Some(ptr), Some(guard))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let (b_scale_ptr, _b_scale_guard) = if let Some(ptr) = ptrs.b_scale {
|
||||
(Some(ptr), None)
|
||||
} else if let Some(scale) = &b_scale {
|
||||
let (b_scale_ptr, _b_scale_guard) = if let Some(scale) = &b_scale {
|
||||
let (ptr, guard) = scale.device_ptr(stream);
|
||||
(Some(ptr), Some(guard))
|
||||
} else {
|
||||
@@ -988,8 +857,6 @@ fn resolve_cublaslt_pointers(
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
beta: f64,
|
||||
epilogue: cublasLtEpilogue_t,
|
||||
a_scale_input: bool,
|
||||
b_scale_input: bool,
|
||||
) -> anyhow::Result<LtMatmulPointers> {
|
||||
if inputs.len() < 2 {
|
||||
return Err(anyhow::anyhow!(
|
||||
@@ -1010,25 +877,24 @@ fn resolve_cublaslt_pointers(
|
||||
.get(&self_node)
|
||||
.ok_or_else(|| anyhow::anyhow!("missing cuBLASLt output buffer"))?
|
||||
.ptr();
|
||||
let mut next_input = 2;
|
||||
let c = if beta == 0.0 {
|
||||
d
|
||||
} else {
|
||||
let c_input = inputs.get(next_input).ok_or_else(|| {
|
||||
anyhow::anyhow!("cuBLASLt matmul with beta={beta} requires a third C input")
|
||||
})?;
|
||||
next_input += 1;
|
||||
} else if let Some(c_input) = inputs.get(2) {
|
||||
buffers
|
||||
.get(c_input)
|
||||
.ok_or_else(|| anyhow::anyhow!("missing cuBLASLt C input buffer"))?
|
||||
.ptr()
|
||||
} else {
|
||||
return Err(anyhow::anyhow!(
|
||||
"cuBLASLt matmul with beta={beta} requires a third C input"
|
||||
));
|
||||
};
|
||||
|
||||
let bias_input_index = if beta == 0.0 { 2 } else { 3 };
|
||||
let bias = if epilogue_uses_bias(epilogue) {
|
||||
let bias_input = inputs.get(next_input).ok_or_else(|| {
|
||||
let bias_input = inputs.get(bias_input_index).ok_or_else(|| {
|
||||
anyhow::anyhow!("cuBLASLt matmul with {epilogue:?} epilogue requires a bias input")
|
||||
})?;
|
||||
next_input += 1;
|
||||
Some(
|
||||
buffers
|
||||
.get(bias_input)
|
||||
@@ -1039,44 +905,7 @@ fn resolve_cublaslt_pointers(
|
||||
None
|
||||
};
|
||||
|
||||
let a_scale = if a_scale_input {
|
||||
let scale_input = inputs
|
||||
.get(next_input)
|
||||
.ok_or_else(|| anyhow::anyhow!("cuBLASLt matmul requires an A scale input pointer"))?;
|
||||
next_input += 1;
|
||||
Some(
|
||||
buffers
|
||||
.get(scale_input)
|
||||
.ok_or_else(|| anyhow::anyhow!("missing cuBLASLt A scale input buffer"))?
|
||||
.ptr(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let b_scale = if b_scale_input {
|
||||
let scale_input = inputs
|
||||
.get(next_input)
|
||||
.ok_or_else(|| anyhow::anyhow!("cuBLASLt matmul requires a B scale input pointer"))?;
|
||||
Some(
|
||||
buffers
|
||||
.get(scale_input)
|
||||
.ok_or_else(|| anyhow::anyhow!("missing cuBLASLt B scale input buffer"))?
|
||||
.ptr(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(LtMatmulPointers {
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
d,
|
||||
bias,
|
||||
a_scale,
|
||||
b_scale,
|
||||
})
|
||||
Ok(LtMatmulPointers { a, b, c, d, bias })
|
||||
}
|
||||
|
||||
fn epilogue_uses_bias(epilogue: cublasLtEpilogue_t) -> bool {
|
||||
@@ -1149,11 +978,6 @@ impl CuBlasLt {
|
||||
&& normalize(self.stride_c) == normalize(self.stride_d)
|
||||
&& self.c_order == self.d_order
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn tensor_scale_inputs(&self) -> (bool, bool) {
|
||||
(self.a_scale_input, self.b_scale_input)
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for CuBlasLt {
|
||||
@@ -1198,15 +1022,7 @@ impl HostOp for CuBlasLt {
|
||||
let alpha = LtScalar::from_f64(self.scale_dtype, self.alpha)?;
|
||||
let beta = LtScalar::from_f64(self.scale_dtype, self.beta)?;
|
||||
|
||||
let ptrs = resolve_cublaslt_pointers(
|
||||
self_node,
|
||||
inputs,
|
||||
buffers,
|
||||
self.beta,
|
||||
self.epilogue,
|
||||
self.a_scale_input,
|
||||
self.b_scale_input,
|
||||
)?;
|
||||
let ptrs = resolve_cublaslt_pointers(self_node, inputs, buffers, self.beta, self.epilogue)?;
|
||||
|
||||
let (a_rows, a_cols) = if a_layout == cublasOperation_t::CUBLAS_OP_N {
|
||||
(m, k)
|
||||
@@ -1381,8 +1197,6 @@ mod tests {
|
||||
&buffers,
|
||||
0.0,
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -1407,8 +1221,6 @@ mod tests {
|
||||
&buffers,
|
||||
0.0,
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -1433,8 +1245,6 @@ mod tests {
|
||||
&buffers,
|
||||
1.0,
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -1459,8 +1269,6 @@ mod tests {
|
||||
&buffers,
|
||||
0.0,
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BIAS,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -1471,41 +1279,6 @@ mod tests {
|
||||
assert_eq!(ptrs.bias, Some(0xB1A5));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cublaslt_pointers_use_tensor_scale_inputs_after_base_inputs() {
|
||||
let output = NodeIndex::new(0);
|
||||
let a = NodeIndex::new(1);
|
||||
let b = NodeIndex::new(2);
|
||||
let a_scale = NodeIndex::new(3);
|
||||
let b_scale = NodeIndex::new(4);
|
||||
let buffers = buffers_for(&[
|
||||
(output, 0xD000),
|
||||
(a, 0xA000),
|
||||
(b, 0xB000),
|
||||
(a_scale, 0xA5A5),
|
||||
(b_scale, 0xB5B5),
|
||||
]);
|
||||
|
||||
let ptrs = resolve_cublaslt_pointers(
|
||||
output,
|
||||
&[a, b, a_scale, b_scale],
|
||||
&buffers,
|
||||
0.0,
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
|
||||
true,
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(ptrs.a, 0xA000);
|
||||
assert_eq!(ptrs.b, 0xB000);
|
||||
assert_eq!(ptrs.c, 0xD000);
|
||||
assert_eq!(ptrs.d, 0xD000);
|
||||
assert_eq!(ptrs.bias, None);
|
||||
assert_eq!(ptrs.a_scale, Some(0xA5A5));
|
||||
assert_eq!(ptrs.b_scale, Some(0xB5B5));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cublaslt_pointers_reject_two_input_nonzero_beta() {
|
||||
let output = NodeIndex::new(0);
|
||||
@@ -1519,8 +1292,6 @@ mod tests {
|
||||
&buffers,
|
||||
1.0,
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.unwrap_err();
|
||||
|
||||
@@ -1543,8 +1314,6 @@ mod tests {
|
||||
&buffers,
|
||||
0.0,
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BIAS,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.unwrap_err();
|
||||
|
||||
|
||||
@@ -27,16 +27,19 @@ pub fn find_indptr_inputs<'a>(
|
||||
mask_node: &'a NodeId,
|
||||
) -> IndptrNodes<'a> {
|
||||
// Step 1: Validate mask = Add(scaled_allowed, neg_constant)
|
||||
let mask_inputs = logical_binary_inputs(egraph, mask_node, "Add").unwrap_or_else(|| {
|
||||
let (mask_label, mask_children) = &egraph.enodes[mask_node];
|
||||
assert!(
|
||||
mask_label == "Op",
|
||||
"find_indptr_inputs: mask node is not an Op (label={mask_label})"
|
||||
);
|
||||
let mask_kind = resolve_first_node(egraph, &mask_children[0]);
|
||||
let mask_kind_label = &egraph.enodes[mask_kind].0;
|
||||
panic!("find_indptr_inputs: mask is not an Add (kind={mask_kind_label})");
|
||||
});
|
||||
let (mask_label, mask_children) = &egraph.enodes[mask_node];
|
||||
assert!(
|
||||
mask_label == "Op",
|
||||
"find_indptr_inputs: mask node is not an Op (label={mask_label})"
|
||||
);
|
||||
let mask_kind = resolve_first_node(egraph, &mask_children[0]);
|
||||
let mask_kind_label = &egraph.enodes[mask_kind].0;
|
||||
assert!(
|
||||
mask_kind_label.contains("Add"),
|
||||
"find_indptr_inputs: mask is not an Add (kind={mask_kind_label})"
|
||||
);
|
||||
|
||||
let mask_inputs = walk_ilist_simple(egraph, &mask_children[1]);
|
||||
assert_eq!(
|
||||
mask_inputs.len(),
|
||||
2,
|
||||
@@ -95,9 +98,15 @@ fn find_1e10_mul<'a>(
|
||||
mask_add_inputs: &[&'a NodeId],
|
||||
) -> (&'a NodeId, &'a NodeId) {
|
||||
for &input_node in mask_add_inputs {
|
||||
let Some(mul_inputs) = logical_binary_inputs(egraph, input_node, "Mul") else {
|
||||
let (label, children) = &egraph.enodes[input_node];
|
||||
if label != "Op" {
|
||||
continue;
|
||||
};
|
||||
}
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
if !egraph.enodes[kind].0.contains("Mul") {
|
||||
continue;
|
||||
}
|
||||
let mul_inputs = walk_ilist_simple(egraph, &children[1]);
|
||||
if mul_inputs.len() != 2 {
|
||||
continue;
|
||||
}
|
||||
@@ -143,7 +152,6 @@ fn find_1e10_mul<'a>(
|
||||
}
|
||||
|
||||
fn is_constant(egraph: &SerializedEGraph, node: &NodeId, expected: f32) -> bool {
|
||||
let node = resolve_op_with_kind(egraph, node, "Constant").unwrap_or(node);
|
||||
let (label, children) = &egraph.enodes[node];
|
||||
if label != "Op" {
|
||||
return false;
|
||||
@@ -238,91 +246,3 @@ fn resolve_first_ir_node<'a>(egraph: &'a SerializedEGraph, eclass: &ClassId) ->
|
||||
}
|
||||
&nodes[0]
|
||||
}
|
||||
|
||||
fn resolve_op_with_kind<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
node: &'a NodeId,
|
||||
kind_substr: &str,
|
||||
) -> Option<&'a NodeId> {
|
||||
let class = egraph.node_to_class.get(node)?;
|
||||
for candidate in &egraph.eclasses[class].1 {
|
||||
let (label, children) = &egraph.enodes[candidate];
|
||||
if label != "Op" || children.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
if egraph.enodes[kind].0.contains(kind_substr) {
|
||||
return Some(candidate);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn logical_binary_inputs<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
node: &'a NodeId,
|
||||
op_name: &str,
|
||||
) -> Option<Vec<&'a NodeId>> {
|
||||
if let Some(op_node) = resolve_op_with_kind(egraph, node, op_name) {
|
||||
let (_, children) = &egraph.enodes[op_node];
|
||||
return Some(walk_ilist_simple(egraph, &children[1]));
|
||||
}
|
||||
|
||||
let (label, children) = &egraph.enodes[node];
|
||||
if label != "Op" || children.len() < 2 {
|
||||
return None;
|
||||
}
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
if egraph.enodes[kind].0.contains("CudaBinaryElementwise") {
|
||||
let opcode_class = egraph.enodes[kind].1.first()?;
|
||||
let opcode_node = resolve_first_node(egraph, opcode_class);
|
||||
if egraph.enodes[opcode_node].0.trim_matches('"') != op_name {
|
||||
return None;
|
||||
}
|
||||
return Some(
|
||||
walk_ilist_simple(egraph, &children[1])
|
||||
.into_iter()
|
||||
.map(|input| unwrap_fusion_start(egraph, input))
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
if !egraph.enodes[kind].0.contains("FusionEnd") {
|
||||
return None;
|
||||
}
|
||||
let fe_inputs = walk_ilist_simple(egraph, &children[1]);
|
||||
let elem = *fe_inputs.first()?;
|
||||
let (elem_label, elem_children) = &egraph.enodes[elem];
|
||||
if elem_label != "Op" || elem_children.len() < 2 {
|
||||
return None;
|
||||
}
|
||||
let elem_kind = resolve_first_node(egraph, &elem_children[0]);
|
||||
if !egraph.enodes[elem_kind].0.contains("CudaBinaryElementwise") {
|
||||
return None;
|
||||
}
|
||||
let opcode_class = egraph.enodes[elem_kind].1.first()?;
|
||||
let opcode_node = resolve_first_node(egraph, opcode_class);
|
||||
if egraph.enodes[opcode_node].0.trim_matches('"') != op_name {
|
||||
return None;
|
||||
}
|
||||
Some(
|
||||
walk_ilist_simple(egraph, &elem_children[1])
|
||||
.into_iter()
|
||||
.map(|input| unwrap_fusion_start(egraph, input))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn unwrap_fusion_start<'a>(egraph: &'a SerializedEGraph, node: &'a NodeId) -> &'a NodeId {
|
||||
let (label, children) = &egraph.enodes[node];
|
||||
if label != "Op" || children.len() < 2 {
|
||||
return node;
|
||||
}
|
||||
let kind = resolve_first_node(egraph, &children[0]);
|
||||
if !egraph.enodes[kind].0.contains("FusionStart") {
|
||||
return node;
|
||||
}
|
||||
walk_ilist_simple(egraph, &children[1])
|
||||
.first()
|
||||
.copied()
|
||||
.unwrap_or(node)
|
||||
}
|
||||
|
||||
@@ -89,16 +89,6 @@
|
||||
?mask_add_out_strides)
|
||||
(ICons ?scaled_qk (ICons ?mask (INil)))))
|
||||
|
||||
; FlashInfer needs qo_indptr/kv_indptr to be recoverable from the mask
|
||||
; expression. Do not match examples that pass a precomputed mask Input.
|
||||
(= ?mask (Op (Add ?inner_mask_shape ?inner_mask_a_strides ?inner_mask_b_strides ?inner_mask_out_strides)
|
||||
(ICons ?mask_scaled_allowed (ICons ?mask_offset (INil)))))
|
||||
(= ?mask_scaled_allowed (Op (Mul ?allowed_shape ?allowed_strides ?scale_const_strides ?scaled_allowed_strides)
|
||||
(ICons ?mask_allowed (ICons ?mask_scale_const (INil)))))
|
||||
(= ?mask_scale_const (Op (Constant ?mask_scale_val) (INil)))
|
||||
(> ?mask_scale_val 9999999999.0)
|
||||
(< ?mask_scale_val 10000000001.0)
|
||||
|
||||
; ── K GQA broadcast: Mul(K_gathered, 1.0) with zero-stride constant ──
|
||||
; Shape: (nheads, hdim, c) — 3D
|
||||
(= ?k_gqa_const (Op (Constant 1.000000) (INil)))
|
||||
|
||||
@@ -2,16 +2,19 @@ use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use crate::cudarc::driver::{CudaStream, DriverError, result};
|
||||
use luminal::{op::EgglogOp, prelude::*};
|
||||
pub mod compute_attn_mask;
|
||||
mod cublas;
|
||||
mod cublaslt;
|
||||
pub mod flashinfer;
|
||||
pub mod moe;
|
||||
|
||||
pub use compute_attn_mask::ComputeAttnMask;
|
||||
|
||||
pub type Ops = (
|
||||
// cublas::CuBlasSgemmV2,
|
||||
cublaslt::CuBlasLt,
|
||||
cublaslt::CuBlasLtScaled,
|
||||
moe::GLUMoE,
|
||||
compute_attn_mask::ComputeAttnMask,
|
||||
flashinfer::FlashInferAttention,
|
||||
);
|
||||
|
||||
@@ -76,16 +79,6 @@ pub(crate) fn cublaslt_c_d_layouts_match(op: &dyn HostOp) -> Option<bool> {
|
||||
.map(cublaslt::CuBlasLt::c_d_layouts_match)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type CublasLtTensorScaleInputs = (bool, bool);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cublaslt_tensor_scale_inputs(op: &dyn HostOp) -> Option<CublasLtTensorScaleInputs> {
|
||||
op.as_any()
|
||||
.downcast_ref::<cublaslt::CuBlasLt>()
|
||||
.map(cublaslt::CuBlasLt::tensor_scale_inputs)
|
||||
}
|
||||
|
||||
/// Non-owning device buffer handle used by host operations.
|
||||
///
|
||||
/// Runtime-owned intermediates may be a whole `CudaSlice`, a subregion inside
|
||||
|
||||
@@ -195,10 +195,6 @@
|
||||
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
(= ?topk_row_offsets (Op (Iota ?topk_row_offsets_expr ?topk_row_offsets_range) (INil)))
|
||||
(= ?topk_flat_idx (Op (Add ?topk_flat_idx_shape ?topk_flat_idx_a_stride ?topk_flat_idx_b_stride ?topk_flat_idx_out_stride) (ICons ?topk_row_offsets (ICons ?topk_idx (INil)))))
|
||||
(= ?topk_vals (Op (Gather ?topk_vals_gather_idx_shape ?topk_vals_gather_idx_stride ?topk_vals_gather_data_shape ?topk_vals_gather_data_stride) (ICons ?topk_flat_idx (ICons ?routing_weights (INil)))))
|
||||
|
||||
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?topk_vals (INil)))))
|
||||
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
)
|
||||
@@ -215,37 +211,6 @@
|
||||
:name "GLUMoE fused expert computation (swiglu)"
|
||||
)
|
||||
|
||||
; ===== Final fusion: mode 2 (SwiGLU with row-normalized top-k weights) =====
|
||||
(rule
|
||||
(
|
||||
(= ?down_state (glumoe_swiglu_down ?dn_matmul))
|
||||
(= ?down_state (MkGLUMoESwiGLUDownState ?dn_io ?dn_matmul_k ?dn_within_range ?swiglu_state ?topk_idx ?down_w))
|
||||
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
(= ?topk_row_offsets (Op (Iota ?topk_row_offsets_expr ?topk_row_offsets_range) (INil)))
|
||||
(= ?topk_flat_idx (Op (Add ?topk_flat_idx_shape ?topk_flat_idx_a_stride ?topk_flat_idx_b_stride ?topk_flat_idx_out_stride) (ICons ?topk_row_offsets (ICons ?topk_idx (INil)))))
|
||||
(= ?topk_vals (Op (Gather ?topk_vals_gather_idx_shape ?topk_vals_gather_idx_stride ?topk_vals_gather_data_shape ?topk_vals_gather_data_stride) (ICons ?topk_flat_idx (ICons ?routing_weights (INil)))))
|
||||
(= ?topk_norm (Op (Sum ?topk_norm_shape ?output_k ?topk_norm_in_stride ?topk_norm_k_stride ?topk_norm_out_stride) (ICons ?topk_vals (INil))))
|
||||
(= ?topk_norm_factor (Op (Recip ?topk_norm_recip_shape ?topk_norm_recip_in_stride ?topk_norm_recip_out_stride) (ICons ?topk_norm (INil))))
|
||||
(= ?normed_topk (Op (Mul ?normed_topk_shape ?normed_topk_a_stride ?normed_topk_b_stride ?normed_topk_out_stride) (ICons ?topk_vals (ICons ?topk_norm_factor (INil)))))
|
||||
|
||||
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?normed_topk (INil)))))
|
||||
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
)
|
||||
(
|
||||
(let ?glumoe (Op (GLUMoE
|
||||
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
|
||||
?gu_within_range ?dn_within_range (MNum 2))
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?topk_vals (INil)))))))))
|
||||
(union ?output ?glumoe)
|
||||
(subsume (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
(subsume (Op (KernelSum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride (F32)) (ICons ?weighted (INil))))
|
||||
)
|
||||
:ruleset glumoe
|
||||
:name "GLUMoE fused expert computation (normalized swiglu)"
|
||||
)
|
||||
|
||||
; ===== Final fusion: mode 1 (Gemma GELU) =====
|
||||
(rule
|
||||
(
|
||||
|
||||
@@ -50,7 +50,7 @@ const WORKSPACE_SIZE: usize = 32 * 1024 * 1024; // 32 MiB
|
||||
/// 3: gate_up_w [E, gate_up_dim, hidden] BF16
|
||||
/// 4: down_w [E, hidden, intermediate] BF16
|
||||
/// 5: mode_aux
|
||||
/// - SwiGLU/SwiGLUNormalized: ignored (rewriter wires `topk_values` again)
|
||||
/// - SwiGLU: ignored (rewriter wires `topk_values` again)
|
||||
/// - GemmaGELU: per_expert_scale [E] F32
|
||||
///
|
||||
/// Output: [seq, hidden] F32
|
||||
@@ -78,7 +78,6 @@ pub struct GLUMoE {
|
||||
pub(crate) enum GLUMoEMode {
|
||||
SwiGLU,
|
||||
GemmaGELU,
|
||||
SwiGLUNormalized,
|
||||
}
|
||||
|
||||
impl GLUMoEMode {
|
||||
@@ -86,7 +85,6 @@ impl GLUMoEMode {
|
||||
match mode_id {
|
||||
0 => Self::SwiGLU,
|
||||
1 => Self::GemmaGELU,
|
||||
2 => Self::SwiGLUNormalized,
|
||||
other => {
|
||||
panic!("Unknown GLUMoE mode id: {other}");
|
||||
}
|
||||
@@ -95,7 +93,7 @@ impl GLUMoEMode {
|
||||
|
||||
fn activation_kernel_mode(self) -> i32 {
|
||||
match self {
|
||||
Self::SwiGLU | Self::SwiGLUNormalized => 0,
|
||||
Self::SwiGLU => 0,
|
||||
Self::GemmaGELU => 1,
|
||||
}
|
||||
}
|
||||
@@ -385,22 +383,22 @@ impl HostOp for GLUMoE {
|
||||
let mode_aux_buf = get_buffer("mode aux", inputs[5])?;
|
||||
let output_buf = get_buffer("output", self_node)?; // [seq, hidden] F32
|
||||
|
||||
let min_topk_bytes = seq * top_k * 4;
|
||||
let topk_bytes = seq * top_k * 4;
|
||||
if x_buf.len() < output_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE x buffer too small: have {} bytes, need {output_bytes}",
|
||||
x_buf.len()
|
||||
);
|
||||
}
|
||||
if topk_idx_buf.len() < min_topk_bytes {
|
||||
if topk_idx_buf.len() < topk_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk index buffer too small: have {} bytes, need {min_topk_bytes}",
|
||||
"GLUMoE topk index buffer too small: have {} bytes, need {topk_bytes}",
|
||||
topk_idx_buf.len()
|
||||
);
|
||||
}
|
||||
if topk_vals_buf.len() < min_topk_bytes {
|
||||
if topk_vals_buf.len() < topk_bytes {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk value buffer too small: have {} bytes, need {min_topk_bytes}",
|
||||
"GLUMoE topk value buffer too small: have {} bytes, need {topk_bytes}",
|
||||
topk_vals_buf.len()
|
||||
);
|
||||
}
|
||||
@@ -442,83 +440,24 @@ impl HostOp for GLUMoE {
|
||||
|
||||
// Read top-k routing values from GPU
|
||||
let topk_idx_host: Vec<u8> = topk_idx_buf.clone_dtoh(stream)?;
|
||||
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host);
|
||||
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host[..topk_bytes]);
|
||||
let topk_vals_host: Vec<u8> = topk_vals_buf.clone_dtoh(stream)?;
|
||||
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host);
|
||||
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host[..topk_bytes]);
|
||||
|
||||
if !topk_idx_i32.len().is_multiple_of(seq) {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk index element count {} is not divisible by seq {seq}",
|
||||
topk_idx_i32.len()
|
||||
);
|
||||
}
|
||||
if !topk_vals_f32.len().is_multiple_of(seq) {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk value element count {} is not divisible by seq {seq}",
|
||||
topk_vals_f32.len()
|
||||
);
|
||||
}
|
||||
let topk_idx_row_stride = topk_idx_i32.len() / seq;
|
||||
let topk_vals_row_stride = topk_vals_f32.len() / seq;
|
||||
if topk_idx_row_stride < top_k {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk index row stride {topk_idx_row_stride} is smaller than top_k {top_k}"
|
||||
);
|
||||
}
|
||||
if topk_vals_row_stride < top_k {
|
||||
anyhow::bail!(
|
||||
"GLUMoE topk value row stride {topk_vals_row_stride} is smaller than top_k {top_k}"
|
||||
);
|
||||
}
|
||||
|
||||
let topk_idx_at = |token: usize, expert: usize| -> i32 {
|
||||
topk_idx_i32[token * topk_idx_row_stride + expert]
|
||||
};
|
||||
let topk_val_at = |token: usize, expert: usize| -> f32 {
|
||||
topk_vals_f32[token * topk_vals_row_stride + expert]
|
||||
};
|
||||
|
||||
for t in 0..seq {
|
||||
for i in 0..top_k {
|
||||
let expert_idx = topk_idx_at(t, i);
|
||||
if expert_idx < 0 || expert_idx as usize >= num_experts {
|
||||
anyhow::bail!(
|
||||
"GLUMoE expert index {expert_idx} at token {t} top-k position {i} out of bounds for {num_experts} experts"
|
||||
);
|
||||
}
|
||||
for (pos, &expert_idx) in topk_idx_i32.iter().enumerate() {
|
||||
if expert_idx < 0 || expert_idx as usize >= num_experts {
|
||||
anyhow::bail!(
|
||||
"GLUMoE expert index {expert_idx} at routing position {pos} out of bounds for {num_experts} experts"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Mode-dependent expert weights used for the final reduction:
|
||||
// - SwiGLU: direct topk values
|
||||
// - SwiGLUNormalized: normalize topk values row-wise
|
||||
// - GemmaGELU: normalize topk values and scale by per-expert factors
|
||||
let mut expert_weights_storage: Vec<f32> = Vec::new();
|
||||
let expert_weights_f32: &[f32] = match self.mode {
|
||||
GLUMoEMode::SwiGLU => {
|
||||
if topk_vals_row_stride == top_k {
|
||||
topk_vals_f32
|
||||
} else {
|
||||
expert_weights_storage.resize(seq * top_k, 0.0);
|
||||
for t in 0..seq {
|
||||
for i in 0..top_k {
|
||||
expert_weights_storage[t * top_k + i] = topk_val_at(t, i);
|
||||
}
|
||||
}
|
||||
&expert_weights_storage
|
||||
}
|
||||
}
|
||||
GLUMoEMode::SwiGLUNormalized => {
|
||||
expert_weights_storage.resize(seq * top_k, 0.0);
|
||||
for t in 0..seq {
|
||||
let norm = (0..top_k).map(|i| topk_val_at(t, i)).sum::<f32>();
|
||||
let inv_norm = if norm != 0.0 { norm.recip() } else { 0.0 };
|
||||
for i in 0..top_k {
|
||||
expert_weights_storage[t * top_k + i] = topk_val_at(t, i) * inv_norm;
|
||||
}
|
||||
}
|
||||
&expert_weights_storage
|
||||
}
|
||||
GLUMoEMode::SwiGLU => topk_vals_f32,
|
||||
GLUMoEMode::GemmaGELU => {
|
||||
let per_expert_scale_host: Vec<u8> = mode_aux_buf.clone_dtoh(stream)?;
|
||||
let per_expert_scale_bytes = num_experts * 4;
|
||||
@@ -532,10 +471,12 @@ impl HostOp for GLUMoE {
|
||||
bytemuck::cast_slice(&per_expert_scale_host[..per_expert_scale_bytes]);
|
||||
expert_weights_storage.resize(seq * top_k, 0.0);
|
||||
for t in 0..seq {
|
||||
let norm = (0..top_k).map(|i| topk_val_at(t, i)).sum::<f32>();
|
||||
let base = t * top_k;
|
||||
let vals = &topk_vals_f32[base..base + top_k];
|
||||
let norm = vals.iter().copied().sum::<f32>();
|
||||
let inv_norm = if norm != 0.0 { norm.recip() } else { 0.0 };
|
||||
for i in 0..top_k {
|
||||
let expert_idx = topk_idx_at(t, i) as usize;
|
||||
let expert_idx = topk_idx_i32[base + i] as usize;
|
||||
if expert_idx >= per_expert_scale_f32.len() {
|
||||
anyhow::bail!(
|
||||
"GLUMoE Gemma mode expert index {} out of bounds {}",
|
||||
@@ -544,8 +485,7 @@ impl HostOp for GLUMoE {
|
||||
);
|
||||
}
|
||||
let scale = per_expert_scale_f32[expert_idx];
|
||||
expert_weights_storage[t * top_k + i] =
|
||||
topk_val_at(t, i) * inv_norm * scale;
|
||||
expert_weights_storage[base + i] = vals[i] * inv_norm * scale;
|
||||
}
|
||||
}
|
||||
&expert_weights_storage
|
||||
@@ -585,10 +525,12 @@ impl HostOp for GLUMoE {
|
||||
|
||||
for t in 0..seq {
|
||||
let x_t_ptr = xbf16_ptr + (t * hidden * 2) as u64; // BF16
|
||||
let expert_indices = &topk_idx_i32[t * top_k..(t + 1) * top_k];
|
||||
let weights = &expert_weights_f32[t * top_k..(t + 1) * top_k];
|
||||
|
||||
for (i, &weight) in weights.iter().enumerate() {
|
||||
let expert_idx = topk_idx_at(t, i) as usize;
|
||||
for (i, (&expert_idx, &weight)) in expert_indices.iter().zip(weights.iter()).enumerate()
|
||||
{
|
||||
let expert_idx = expert_idx as usize;
|
||||
|
||||
// a. Gate+Up matmul (BF16 in, BF16 out)
|
||||
let expert_gu_ptr = gate_up_ptr + expert_idx as u64 * gu_stride;
|
||||
|
||||
@@ -1,289 +0,0 @@
|
||||
//! Direct conv2d_bias kernel — fuses unfold + matmul + bias into one
|
||||
//! CUDA kernel with no `(H_out*W_out, C_in*K*K)` intermediate matrix.
|
||||
//!
|
||||
//! This is exposed as a luminal `CustomOp`, not a standard egglog-rewritten
|
||||
//! `KernelOp`, because the conv has no useful fusion opportunities with
|
||||
//! surrounding ops in the graphs it's used in (the VAE's resnet blocks),
|
||||
//! and pattern-matching the unfold+permute+merge_dims+matmul+bias chain
|
||||
//! reliably from egglog is significantly more work than just bypassing
|
||||
//! the egglog rewrite path entirely.
|
||||
//!
|
||||
//! The kernel is one-thread-per-output: each thread computes
|
||||
//! `out[co, ho, wo] = bias[co] + sum_{ci,ki,kj} input[ci, ho*S+ki-P, wo*S+kj-P] * weight[co, ci, ki, kj]`
|
||||
//! with bounds checks on the spatial dims for padding. This is far from
|
||||
//! peak FLOPs (no shared-memory tiling, no warp-level reduction over K)
|
||||
//! but it's correct and the memory footprint is just the input + weight +
|
||||
//! bias + output buffers — no `(M, K)` or `(M, N, K)` intermediate, so it
|
||||
//! scales linearly with the actual conv FLOPs rather than blowing up at
|
||||
//! large H/W like the unfold-based formulation.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::prelude::FxHashMap;
|
||||
use luminal::{
|
||||
dtype::DType, graph::Graph, op::CustomOp, op::LLIROp, prelude::GraphTensor, shape::Expression,
|
||||
};
|
||||
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
/// Direct conv2d-with-bias kernel. All shape/kernel params are static
|
||||
/// (baked into the CUDA source via #defines), so each conv shape gets
|
||||
/// its own compiled kernel. Inputs (in order): input `(C_in, H_in, W_in)`,
|
||||
/// weight `(C_out, C_in*K*K)` (i.e. flattened `(C_out, C_in, K, K)`), bias
|
||||
/// `(C_out,)`. Output: `(C_out, H_out, W_out)`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Conv2DKernel {
|
||||
pub c_in: usize,
|
||||
pub h_in: usize,
|
||||
pub w_in: usize,
|
||||
pub c_out: usize,
|
||||
pub kernel: usize,
|
||||
pub stride: usize,
|
||||
pub padding: usize,
|
||||
pub h_out: usize,
|
||||
pub w_out: usize,
|
||||
}
|
||||
|
||||
impl Conv2DKernel {
|
||||
fn output_elements(&self) -> usize {
|
||||
self.c_out * self.h_out * self.w_out
|
||||
}
|
||||
}
|
||||
|
||||
const THREADS_PER_BLOCK: usize = 256;
|
||||
|
||||
impl KernelOp for Conv2DKernel {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let total = self.output_elements();
|
||||
let grid = total.div_ceil(THREADS_PER_BLOCK);
|
||||
|
||||
let kernel = format!(
|
||||
"
|
||||
extern \"C\" __global__ void conv2d_bias_kernel(
|
||||
float* __restrict__ out,
|
||||
const float* __restrict__ input,
|
||||
const float* __restrict__ weight,
|
||||
const float* __restrict__ bias
|
||||
) {{
|
||||
const int TOTAL = {total};
|
||||
const int CIN = {c_in};
|
||||
const int H = {h_in};
|
||||
const int W = {w_in};
|
||||
const int HOUT = {h_out};
|
||||
const int WOUT = {w_out};
|
||||
const int K = {k};
|
||||
const int S = {s};
|
||||
const int P = {p};
|
||||
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= TOTAL) return;
|
||||
int hw = HOUT * WOUT;
|
||||
int co = idx / hw;
|
||||
int rem = idx - co * hw;
|
||||
int ho = rem / WOUT;
|
||||
int wo = rem - ho * WOUT;
|
||||
|
||||
float acc = bias[co];
|
||||
int weight_co_base = co * (CIN * K * K);
|
||||
for (int ci = 0; ci < CIN; ci++) {{
|
||||
int input_ci_base = ci * (H * W);
|
||||
int weight_ci_base = weight_co_base + ci * (K * K);
|
||||
#pragma unroll
|
||||
for (int ki = 0; ki < K; ki++) {{
|
||||
int hi = ho * S + ki - P;
|
||||
if (hi < 0 || hi >= H) continue;
|
||||
int input_row_base = input_ci_base + hi * W;
|
||||
int weight_row_base = weight_ci_base + ki * K;
|
||||
#pragma unroll
|
||||
for (int kj = 0; kj < K; kj++) {{
|
||||
int wj = wo * S + kj - P;
|
||||
if (wj < 0 || wj >= W) continue;
|
||||
acc += input[input_row_base + wj] * weight[weight_row_base + kj];
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
out[idx] = acc;
|
||||
}}
|
||||
",
|
||||
total = total,
|
||||
c_in = self.c_in,
|
||||
h_in = self.h_in,
|
||||
w_in = self.w_in,
|
||||
h_out = self.h_out,
|
||||
w_out = self.w_out,
|
||||
k = self.kernel,
|
||||
s = self.stride,
|
||||
p = self.padding,
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("conv2d_bias_kernel").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(
|
||||
Expression::from(grid),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
(
|
||||
Expression::from(THREADS_PER_BLOCK),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
Expression::from(0usize),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
Expression::from(self.output_elements())
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
DType::F32
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// Per output: C_in * K * K input loads + same many weight loads + 1 bias load.
|
||||
let per_out = self.c_in * self.kernel * self.kernel * 2 + 1;
|
||||
Expression::from(self.output_elements() * per_out * 4)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// 2 * C_in * K * K mul-adds per output, plus the bias add = +1.
|
||||
let per_out = self.c_in * self.kernel * self.kernel * 2 + 1;
|
||||
Expression::from(self.output_elements() * per_out)
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Conv2DBias"
|
||||
}
|
||||
}
|
||||
|
||||
/// luminal `CustomOp` that wraps `Conv2DKernel`. Lets us drop the kernel
|
||||
/// straight into an HLIR graph via `cx.custom_op(...)` without going
|
||||
/// through egglog rewrites.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Conv2DCustom(pub Conv2DKernel);
|
||||
|
||||
impl CustomOp for Conv2DCustom {
|
||||
fn to_llir_op(&self) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
|
||||
}
|
||||
}
|
||||
|
||||
/// 2D conv-with-bias on a `(C_in, H, W)` F32 input tensor, with weights
|
||||
/// stored as `(C_out, C_in*K*K)` and bias as `(C_out,)`. Stride/padding/kernel
|
||||
/// are static. Output: `(C_out, H_out, W_out)`.
|
||||
///
|
||||
/// This is a thin wrapper over [`Conv2DKernel`] that hides the
|
||||
/// `cx.custom_op` plumbing. All inputs MUST be `DType::F32` and contiguous
|
||||
/// row-major; pass `tensor * 1.0_f32` first if you have a strided view.
|
||||
pub fn conv2d_bias(
|
||||
input: GraphTensor,
|
||||
weight: GraphTensor,
|
||||
bias: GraphTensor,
|
||||
kernel: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
) -> GraphTensor {
|
||||
assert_eq!(input.dtype, DType::F32, "conv2d_bias requires F32 input");
|
||||
assert_eq!(weight.dtype, DType::F32, "conv2d_bias requires F32 weight");
|
||||
assert_eq!(bias.dtype, DType::F32, "conv2d_bias requires F32 bias");
|
||||
|
||||
let dims = input.dims();
|
||||
assert_eq!(dims.len(), 3, "conv2d_bias expects (C_in, H, W) input");
|
||||
let c_in = dims[0].to_usize().expect("C_in must be a static dim");
|
||||
let h_in = dims[1].to_usize().expect("H must be a static dim");
|
||||
let w_in = dims[2].to_usize().expect("W must be a static dim");
|
||||
|
||||
let w_dims = weight.dims();
|
||||
assert_eq!(
|
||||
w_dims.len(),
|
||||
2,
|
||||
"conv2d_bias expects weight (C_out, C_in*K*K)"
|
||||
);
|
||||
let c_out = w_dims[0].to_usize().expect("C_out must be a static dim");
|
||||
let w_kk = w_dims[1]
|
||||
.to_usize()
|
||||
.expect("weight inner dim must be static");
|
||||
assert_eq!(
|
||||
w_kk,
|
||||
c_in * kernel * kernel,
|
||||
"weight inner dim {w_kk} != C_in*K*K = {}",
|
||||
c_in * kernel * kernel,
|
||||
);
|
||||
|
||||
let b_dims = bias.dims();
|
||||
assert_eq!(b_dims.len(), 1, "conv2d_bias expects bias (C_out,)");
|
||||
assert_eq!(
|
||||
b_dims[0].to_usize().expect("bias dim must be static"),
|
||||
c_out
|
||||
);
|
||||
|
||||
assert!(
|
||||
h_in + 2 * padding >= kernel,
|
||||
"padded H_in ({}) is smaller than kernel ({})",
|
||||
h_in + 2 * padding,
|
||||
kernel,
|
||||
);
|
||||
assert!(
|
||||
w_in + 2 * padding >= kernel,
|
||||
"padded W_in ({}) is smaller than kernel ({})",
|
||||
w_in + 2 * padding,
|
||||
kernel,
|
||||
);
|
||||
let h_out = (h_in + 2 * padding - kernel) / stride + 1;
|
||||
let w_out = (w_in + 2 * padding - kernel) / stride + 1;
|
||||
|
||||
let kern = Conv2DKernel {
|
||||
c_in,
|
||||
h_in,
|
||||
w_in,
|
||||
c_out,
|
||||
kernel,
|
||||
stride,
|
||||
padding,
|
||||
h_out,
|
||||
w_out,
|
||||
};
|
||||
let cx: &mut Graph = unsafe { &mut *input.graph_ref };
|
||||
cx.custom_op(
|
||||
Conv2DCustom(kern),
|
||||
vec![input, weight, bias],
|
||||
(c_out, h_out, w_out),
|
||||
DType::F32,
|
||||
)
|
||||
}
|
||||
@@ -1,378 +0,0 @@
|
||||
// =========================================================================
|
||||
// Generic CUDA elementwise ops used inside FusionStart/FusionEnd regions.
|
||||
//
|
||||
// CUDA elementwise execution is represented as a FusionEnd-rooted region even
|
||||
// for a single op. These ops are therefore region-internal only; standalone
|
||||
// compilation is intentionally unsupported.
|
||||
// =========================================================================
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, OP_KIND, STRING},
|
||||
extract_dtype, extract_expr_list,
|
||||
},
|
||||
op::*,
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
pub type Ops = (CudaUnaryElementwise, CudaBinaryElementwise);
|
||||
|
||||
type CompileOut = (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
);
|
||||
|
||||
fn extract_string_label(egraph: &SerializedEGraph, node: &ENodeId) -> String {
|
||||
egraph.enodes[node].0.trim_matches('"').to_string()
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct CudaUnaryElementwise {
|
||||
pub(crate) op: String,
|
||||
pub(crate) shape: Vec<Expression>,
|
||||
pub(crate) in_strides: Vec<Expression>,
|
||||
pub(crate) out_strides: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for CudaUnaryElementwise {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"CudaUnaryElementwise",
|
||||
&[
|
||||
("op", STRING),
|
||||
("shape", ELIST),
|
||||
("strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
let mut rules = Vec::new();
|
||||
for (hlir, opcode) in [
|
||||
("Sin", "Sin"),
|
||||
("Sqrt", "Sqrt"),
|
||||
("Exp2", "Exp2"),
|
||||
("Log2", "Log2"),
|
||||
("Recip", "Recip"),
|
||||
] {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?u (Op ({hlir} ?shape ?s ?out_s) (ICons ?x (INil))))
|
||||
(= ?dt (dtype ?u))
|
||||
) (
|
||||
(let ?fs (Op (FusionStart ?shape ?s ?dt) (ICons ?x (INil))))
|
||||
(let ?elem (Op (CudaUnaryElementwise \"{opcode}\" ?shape ?s ?out_s ?dt)
|
||||
(ICons ?fs (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(union ?u ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
) :ruleset kernel_lower :name \"cuda-elem-singleton-{hlir}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
rules.push(Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?shape ?x_stride ?const_stride ?inter_stride) (ICons ?x (ICons ?exp_const (INil)))))
|
||||
(= ?exp2 (Op (Exp2 ?shape ?inter_stride ?out_stride) (ICons ?mul (INil))))
|
||||
(= ?dt (dtype ?x))
|
||||
(= ?cv (Op (Constant ?val) (INil)))
|
||||
(= ?exp_const ?cv)
|
||||
(> ?val 1.44)
|
||||
(< ?val 1.45)
|
||||
)
|
||||
(
|
||||
(let ?fs (Op (FusionStart ?shape ?x_stride ?dt) (ICons ?x (INil))))
|
||||
(let ?elem (Op (CudaUnaryElementwise \"Exp\" ?shape ?x_stride ?out_stride ?dt)
|
||||
(ICons ?fs (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_stride ?dt) (ICons ?elem (INil))))
|
||||
(union ?exp2 ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
)
|
||||
:ruleset direct_kernel
|
||||
:name \"direct-exp-region\"
|
||||
)",
|
||||
));
|
||||
|
||||
rules.push(Rule::raw(
|
||||
"(datatype*
|
||||
(CudaSigmoidScaledState
|
||||
(MkCudaSigmoidScaledState IR EList EList DType)
|
||||
)
|
||||
)
|
||||
(function cuda_sigmoid_scaled (IR) CudaSigmoidScaledState :merge new)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?neg1 (Op (Constant ?nv) (INil)))
|
||||
(< ?nv -0.99)
|
||||
(> ?nv -1.01)
|
||||
(= ?neg_x (Op (Mul ?shape ?x_stride ?neg_stride ?neg_out_stride) (ICons ?x (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant ?lv) (INil)))
|
||||
(> ?lv 1.44)
|
||||
(< ?lv 1.45)
|
||||
(= ?scaled (Op (Mul ?shape ?neg_out_stride ?log2e_stride ?scaled_stride) (ICons ?neg_x (ICons ?log2e (INil)))))
|
||||
(= ?dt (dtype ?x))
|
||||
)
|
||||
(
|
||||
(set (cuda_sigmoid_scaled ?scaled)
|
||||
(MkCudaSigmoidScaledState ?x ?shape ?x_stride ?dt))
|
||||
)
|
||||
:ruleset direct_kernel
|
||||
:name \"direct-sigmoid-scaled-region-marker\"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?scaled_state (cuda_sigmoid_scaled ?scaled))
|
||||
(= ?scaled_state (MkCudaSigmoidScaledState ?x ?shape ?x_stride ?dt))
|
||||
(= ?exp2 (Op (Exp2 ?shape ?scaled_stride ?exp_stride) (ICons ?scaled (INil))))
|
||||
(= ?one (Op (Constant ?ov) (INil)))
|
||||
(> ?ov 0.99)
|
||||
(< ?ov 1.01)
|
||||
(= ?plus_one (Op (Add ?shape ?exp_stride ?one_stride ?add_stride) (ICons ?exp2 (ICons ?one (INil)))))
|
||||
(= ?sig_out (Op (Recip ?shape ?add_stride ?out_stride) (ICons ?plus_one (INil))))
|
||||
)
|
||||
(
|
||||
(let ?fs (Op (FusionStart ?shape ?x_stride ?dt) (ICons ?x (INil))))
|
||||
(let ?elem (Op (CudaUnaryElementwise \"Sigmoid\" ?shape ?x_stride ?out_stride ?dt)
|
||||
(ICons ?fs (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_stride ?dt) (ICons ?elem (INil))))
|
||||
(union ?sig_out ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
)
|
||||
:ruleset direct_kernel
|
||||
:name \"direct-sigmoid-region\"
|
||||
)",
|
||||
));
|
||||
|
||||
rules
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
op: extract_string_label(egraph, kind_children[0]),
|
||||
shape: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache).unwrap(),
|
||||
in_strides: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
out_strides: extract_expr_list(egraph, kind_children[3], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[4]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for CudaUnaryElementwise {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!("CudaUnaryElementwise must be compiled through fusion region codegen")
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.output_size()
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"CudaUnaryElementwise"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct CudaBinaryElementwise {
|
||||
pub(crate) op: String,
|
||||
pub(crate) out_shape: Vec<Expression>,
|
||||
pub(crate) a_stride: Vec<Expression>,
|
||||
pub(crate) b_stride: Vec<Expression>,
|
||||
pub(crate) out_stride: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for CudaBinaryElementwise {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"CudaBinaryElementwise",
|
||||
&[
|
||||
("op", STRING),
|
||||
("shape", ELIST),
|
||||
("a_strides", ELIST),
|
||||
("b_strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
Rule::raw(
|
||||
"(rule (
|
||||
(= ?bin (Op (Add ?shape ?a_s ?b_s ?out_s) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?dt (dtype ?bin))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?elem (Op (CudaBinaryElementwise \"Add\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(union ?bin ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
) :ruleset kernel_lower :name \"cuda-elem-singleton-Add\")",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule (
|
||||
(= ?bin (Op (Mul ?shape ?a_s ?b_s ?out_s) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?dt (dtype ?a))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?elem (Op (CudaBinaryElementwise \"Mul\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(union ?bin ?fe)
|
||||
(set (dtype ?fe) ?dt)
|
||||
) :ruleset kernel_lower :name \"cuda-elem-singleton-Mul\")",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let mut out_shape =
|
||||
extract_expr_list(egraph, kind_children[1], list_cache, expr_cache).unwrap();
|
||||
let mut a_stride =
|
||||
extract_expr_list(egraph, kind_children[2], list_cache, expr_cache).unwrap();
|
||||
let mut b_stride =
|
||||
extract_expr_list(egraph, kind_children[3], list_cache, expr_cache).unwrap();
|
||||
let mut out_stride =
|
||||
extract_expr_list(egraph, kind_children[4], list_cache, expr_cache).unwrap();
|
||||
let n = out_shape
|
||||
.len()
|
||||
.min(a_stride.len())
|
||||
.min(b_stride.len())
|
||||
.min(out_stride.len());
|
||||
out_shape.truncate(n);
|
||||
a_stride.truncate(n);
|
||||
b_stride.truncate(n);
|
||||
out_stride.truncate(n);
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
op: extract_string_label(egraph, kind_children[0]),
|
||||
out_shape,
|
||||
a_stride,
|
||||
b_stride,
|
||||
out_stride,
|
||||
dtype: extract_dtype(egraph, kind_children[5]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for CudaBinaryElementwise {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!("CudaBinaryElementwise must be compiled through fusion region codegen")
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes() * 2
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.output_size()
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"CudaBinaryElementwise"
|
||||
}
|
||||
}
|
||||
301
crates/luminal_cuda_lite/src/kernel/fusion/fused_ops.rs
Normal file
301
crates/luminal_cuda_lite/src/kernel/fusion/fused_ops.rs
Normal file
@@ -0,0 +1,301 @@
|
||||
// =========================================================================
|
||||
// Fused elementwise op variants used inside FusionStart/FusionEnd regions.
|
||||
//
|
||||
// Each `FusedX` struct mirrors its un-fused `KernelX` sibling field-for-field
|
||||
// and serves a single purpose: give the egglog rules a distinct sort to
|
||||
// rewrite into so a pair-fuse rule's RHS can never re-match its own LHS
|
||||
// pattern. Cascade prevention by typing.
|
||||
//
|
||||
// Each FusedX must be absorbed into a FusionEnd-rooted region and compiled by
|
||||
// `region_codegen`; standalone compilation is intentionally unsupported.
|
||||
// =========================================================================
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, OP_KIND},
|
||||
extract_dtype, extract_expr_list,
|
||||
},
|
||||
op::*,
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
pub type Ops = (
|
||||
FusedSin,
|
||||
FusedSqrt,
|
||||
FusedExp,
|
||||
FusedExp2,
|
||||
FusedLog2,
|
||||
FusedRecip,
|
||||
FusedAdd,
|
||||
FusedMul,
|
||||
);
|
||||
|
||||
// Standard `compile()` return tuple (matches the trait signature).
|
||||
type CompileOut = (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
);
|
||||
|
||||
/// Generate `pub struct $Name { … unary fields … }` plus its `EgglogOp` and
|
||||
/// `KernelOp` impls. `$kernel_name` names the CUDA function (and the cache
|
||||
/// key); `$body` is the per-op CUDA expression, e.g. `"sinf(in[{in_idx}])"`.
|
||||
macro_rules! impl_fused_unary {
|
||||
($Name:ident, $sort:literal, $kernel_name:literal, $body:literal) => {
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct $Name {
|
||||
pub(crate) shape: Vec<Expression>,
|
||||
pub(crate) in_strides: Vec<Expression>,
|
||||
pub(crate) out_strides: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for $Name {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
$sort,
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
Vec::new()
|
||||
}
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
in_strides: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[1],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
out_strides: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[2],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[3]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for $Name {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!(concat!(
|
||||
$sort,
|
||||
" must be compiled through fusion region codegen"
|
||||
))
|
||||
}
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
fn flops(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
$sort
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// As `impl_fused_unary!` but for binary ops: 5-field sort signature
|
||||
/// (shape + per-input strides + out_stride + dtype), n_inputs = 2.
|
||||
/// `$op_str` is the CUDA infix operator, e.g. `"+"`, `"*"`.
|
||||
macro_rules! impl_fused_binary {
|
||||
($Name:ident, $sort:literal, $kernel_name:literal, $op_str:literal) => {
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct $Name {
|
||||
pub(crate) out_shape: Vec<Expression>,
|
||||
pub(crate) a_stride: Vec<Expression>,
|
||||
pub(crate) b_stride: Vec<Expression>,
|
||||
pub(crate) out_stride: Vec<Expression>,
|
||||
pub(crate) dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for $Name {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
$sort,
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("a_strides", ELIST),
|
||||
("b_strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
Vec::new()
|
||||
}
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[0],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
a_stride: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[1],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
b_stride: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[2],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
out_stride: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[3],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[4]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for $Name {
|
||||
fn compile(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> CompileOut {
|
||||
unreachable!(concat!(
|
||||
$sort,
|
||||
" must be compiled through fusion region codegen"
|
||||
))
|
||||
}
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
let bytes = (self.output_size() * self.dtype.bits()).ceil_div(8);
|
||||
bytes + bytes
|
||||
}
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
fn flops(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
$sort
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_fused_unary!(FusedSin, "FusedSin", "fused_sin_k", "sinf(in[{in_idx}])");
|
||||
impl_fused_unary!(
|
||||
FusedSqrt,
|
||||
"FusedSqrt",
|
||||
"fused_sqrt_k",
|
||||
"sqrtf(in[{in_idx}])"
|
||||
);
|
||||
impl_fused_unary!(FusedExp, "FusedExp", "fused_exp_k", "expf(in[{in_idx}])");
|
||||
impl_fused_unary!(
|
||||
FusedExp2,
|
||||
"FusedExp2",
|
||||
"fused_exp2_k",
|
||||
"exp2f(in[{in_idx}])"
|
||||
);
|
||||
impl_fused_unary!(
|
||||
FusedLog2,
|
||||
"FusedLog2",
|
||||
"fused_log2_k",
|
||||
"log2f(in[{in_idx}])"
|
||||
);
|
||||
impl_fused_unary!(
|
||||
FusedRecip,
|
||||
"FusedRecip",
|
||||
"fused_recip_k",
|
||||
"1.0f / in[{in_idx}]"
|
||||
);
|
||||
|
||||
impl_fused_binary!(FusedAdd, "FusedAdd", "fused_add_k", "+");
|
||||
impl_fused_binary!(FusedMul, "FusedMul", "fused_mul_k", "*");
|
||||
@@ -9,8 +9,8 @@
|
||||
//
|
||||
// `FusionEnd::rewrites()` carries the seven rule families that build and
|
||||
// extend regions (pair-fuse / grow / merge); the actual single-kernel
|
||||
// codegen lives in `region_codegen`. Both markers' `compile()` is
|
||||
// `unreachable!()` — region codegen folds them away
|
||||
// codegen lives in `region_codegen`. Like FusedX, both markers'
|
||||
// `compile()` is `unreachable!()` — region codegen folds them away
|
||||
// before kernel_to_host's compile loop reaches an interior node.
|
||||
// =========================================================================
|
||||
|
||||
@@ -142,164 +142,218 @@ impl EgglogOp for FusionEnd {
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// Generic region growth works directly from HLIR elementwise ops into
|
||||
// `Cuda*Elementwise` region nodes. The concrete HLIR op still appears in
|
||||
// the egraph, so fusion remains a normal nondestructive alternative, but
|
||||
// the region-internal representation is arity based instead of one
|
||||
// dedicated fused sort per operation.
|
||||
// Seven rule families build and extend FE-bracketed regions. Each
|
||||
// pair-fuse rule's LHS pattern matches *un-fused* `KernelX` ops; the
|
||||
// RHS produces `FusedX` variants in a different egglog sort, so the
|
||||
// rule's own output cannot re-match its LHS — cascade is prevented
|
||||
// by typing rather than by a discriminator field.
|
||||
//
|
||||
// Stride compatibility is expressed by reusing variable names: a
|
||||
// unary inside a region matches `(KernelU ?shape ?s ?s ?dt)` (in =
|
||||
// out, no transpose); a binary feeding a downstream op binds the
|
||||
// binary's out-stride to the downstream op's in-stride along the
|
||||
// connecting side.
|
||||
let mut rules = Vec::new();
|
||||
|
||||
// (KernelX kind, FusedX kind)
|
||||
let unaries: &[(&str, &str)] = &[
|
||||
("Sin", "Sin"),
|
||||
("Sqrt", "Sqrt"),
|
||||
("Exp2", "Exp2"),
|
||||
("Log2", "Log2"),
|
||||
("Recip", "Recip"),
|
||||
("KernelSin", "FusedSin"),
|
||||
("KernelSqrt", "FusedSqrt"),
|
||||
("KernelExp", "FusedExp"),
|
||||
("KernelExp2", "FusedExp2"),
|
||||
("KernelLog2", "FusedLog2"),
|
||||
("KernelRecip", "FusedRecip"),
|
||||
];
|
||||
// (KernelX kind, FusedX kind, rule-name label)
|
||||
let binaries: &[(&str, &str, &str)] = &[
|
||||
("KernelAdd", "FusedAdd", "Add"),
|
||||
("KernelMul", "FusedMul", "Mul"),
|
||||
];
|
||||
let binaries: &[(&str, &str)] = &[("Add", "Add"), ("Mul", "Mul")];
|
||||
|
||||
// Grow FE → unary consumer: U(FE(inner)) → FE(CudaUnary(inner)).
|
||||
for (hlir, opcode) in unaries {
|
||||
// 1. Pair-fuse U → U: U2(U1(x)) → FE(FU2(FU1(FS(x)))).
|
||||
for (ki1, fi1) in unaries {
|
||||
for (ko2, fo2) in unaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?u1 (Op ({ki1} ?shape ?s ?s ?dt) (ICons ?x (INil))))
|
||||
(= ?u2 (Op ({ko2} ?shape ?s ?s ?dt) (ICons ?u1 (INil))))
|
||||
) (
|
||||
(let ?fs (Op (FusionStart ?shape ?s ?dt) (ICons ?x (INil))))
|
||||
(let ?fu1 (Op ({fi1} ?shape ?s ?s ?dt) (ICons ?fs (INil))))
|
||||
(let ?fu2 (Op ({fo2} ?shape ?s ?s ?dt) (ICons ?fu1 (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?fu2 (INil))))
|
||||
(union ?u2 ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-U-U-{ki1}-{ko2}\")"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Pair-fuse B → U: U(B(a, b)) → FE(FU(FB(FS(a), FS(b)))).
|
||||
for (kb, fb, lb) in binaries {
|
||||
for (ku, fu) in unaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?bin (Op ({kb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?u (Op ({ku} ?shape ?o_s ?o_s ?dt) (ICons ?bin (INil))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fbin (Op ({fb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(let ?fu (Op ({fu} ?shape ?o_s ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fu (INil))))
|
||||
(union ?u ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-B-U-{lb}-{ku}\")"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Pair-fuse U → B (lhs / rhs): unary feeds binary's A or B input.
|
||||
// LHS: B(U(a), b) → FE(FB(FU(FS(a)), FS(b))).
|
||||
// RHS: B(a, U(b)) → FE(FB(FS(a), FU(FS(b)))).
|
||||
for (ku, fu) in unaries {
|
||||
for (kb, fb, lb) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?u (Op ({ku} ?shape ?u_s ?u_s ?dt) (ICons ?a (INil))))
|
||||
(= ?bin (Op ({kb} ?shape ?u_s ?b_s ?o_s ?dt)
|
||||
(ICons ?u (ICons ?b (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?u_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fu (Op ({fu} ?shape ?u_s ?u_s ?dt) (ICons ?fs_a (INil))))
|
||||
(let ?fbin (Op ({fb} ?shape ?u_s ?b_s ?o_s ?dt)
|
||||
(ICons ?fu (ICons ?fs_b (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-U-B-lhs-{ku}-{lb}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?u (Op ({ku} ?shape ?u_s ?u_s ?dt) (ICons ?b (INil))))
|
||||
(= ?bin (Op ({kb} ?shape ?a_s ?u_s ?o_s ?dt)
|
||||
(ICons ?a (ICons ?u (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?u_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fu (Op ({fu} ?shape ?u_s ?u_s ?dt) (ICons ?fs_b (INil))))
|
||||
(let ?fbin (Op ({fb} ?shape ?a_s ?u_s ?o_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fu (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-U-B-rhs-{ku}-{lb}\")"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Pair-fuse B → B (lhs / rhs): inner binary feeds outer's A or B.
|
||||
for (kbi, fbi, lbi) in binaries {
|
||||
for (kbo, fbo, lbo) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?bi (Op ({kbi} ?shape ?ai_s ?bi_s ?oi_s ?dt)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?bo (Op ({kbo} ?shape ?oi_s ?co_s ?oo_s ?dt)
|
||||
(ICons ?bi (ICons ?c (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?ai_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?bi_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fs_c (Op (FusionStart ?shape ?co_s ?dt) (ICons ?c (INil))))
|
||||
(let ?fbi (Op ({fbi} ?shape ?ai_s ?bi_s ?oi_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(let ?fbo (Op ({fbo} ?shape ?oi_s ?co_s ?oo_s ?dt)
|
||||
(ICons ?fbi (ICons ?fs_c (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?oo_s ?dt) (ICons ?fbo (INil))))
|
||||
(union ?bo ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-B-B-lhs-{lbi}-{lbo}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?bi (Op ({kbi} ?shape ?ai_s ?bi_s ?oi_s ?dt)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?bo (Op ({kbo} ?shape ?co_s ?oi_s ?oo_s ?dt)
|
||||
(ICons ?c (ICons ?bi (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?ai_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?bi_s ?dt) (ICons ?b (INil))))
|
||||
(let ?fs_c (Op (FusionStart ?shape ?co_s ?dt) (ICons ?c (INil))))
|
||||
(let ?fbi (Op ({fbi} ?shape ?ai_s ?bi_s ?oi_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(let ?fbo (Op ({fbo} ?shape ?co_s ?oi_s ?oo_s ?dt)
|
||||
(ICons ?fs_c (ICons ?fbi (INil)))))
|
||||
(let ?fe (Op (FusionEnd ?shape ?oo_s ?dt) (ICons ?fbo (INil))))
|
||||
(union ?bo ?fe)
|
||||
) :ruleset fusion_pair :name \"pair-fuse-B-B-rhs-{lbi}-{lbo}\")"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Grow FE → U: U(FE(inner)) → FE(FU(inner)). No new FS.
|
||||
for (ku, fu) in unaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?inner (INil))))
|
||||
(= ?u (Op ({hlir} ?shape ?s ?s) (ICons ?fe (INil))))
|
||||
(= ?u (Op ({ku} ?shape ?s ?s ?dt) (ICons ?fe (INil))))
|
||||
) (
|
||||
(let ?elem (Op (CudaUnaryElementwise \"{opcode}\" ?shape ?s ?s ?dt)
|
||||
(ICons ?inner (INil))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?elem (INil))))
|
||||
(let ?fu (Op ({fu} ?shape ?s ?s ?dt) (ICons ?inner (INil))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?fu (INil))))
|
||||
(union ?u ?new_fe)
|
||||
(set (dtype ?new_fe) ?dt)
|
||||
) :ruleset fusion_grow :name \"grow-FE-U-{hlir}\")"
|
||||
) :ruleset fusion_grow :name \"grow-FE-U-{ku}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
// Grow FE → binary consumer, left and right orientations.
|
||||
for (hlir, opcode) in binaries {
|
||||
// 6. Grow FE → B (lhs / rhs): one input is the FE, the other external.
|
||||
for (kb, fb, lb) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?shape ?a_s ?dt) (ICons ?inner_a (INil))))
|
||||
(= ?bin (Op ({hlir} ?shape ?a_s ?b_s ?out_s)
|
||||
(= ?bin (Op ({kb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?fe (ICons ?b (INil)))))
|
||||
) (
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(let ?fbin (Op ({fb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?inner_a (ICons ?fs_b (INil)))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?new_fe)
|
||||
(set (dtype ?new_fe) ?dt)
|
||||
) :ruleset fusion_grow :name \"grow-FE-B-lhs-{hlir}\")"
|
||||
) :ruleset fusion_grow :name \"grow-FE-B-lhs-{lb}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?shape ?b_s ?dt) (ICons ?inner_b (INil))))
|
||||
(= ?bin (Op ({hlir} ?shape ?a_s ?b_s ?out_s)
|
||||
(= ?bin (Op ({kb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?a (ICons ?fe (INil)))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(let ?fbin (Op ({fb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?fs_a (ICons ?inner_b (INil)))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?new_fe)
|
||||
(set (dtype ?new_fe) ?dt)
|
||||
) :ruleset fusion_grow :name \"grow-FE-B-rhs-{hlir}\")"
|
||||
) :ruleset fusion_grow :name \"grow-FE-B-rhs-{lb}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
// Absorb an elementwise producer through a FusionStart boundary. This
|
||||
// makes a region that initially treats `producer(...)` as an external
|
||||
// input able to pull that producer inside later.
|
||||
for (hlir, opcode) in unaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?u (Op ({hlir} ?shape ?s ?s) (ICons ?x (INil))))
|
||||
(= ?fs_u (Op (FusionStart ?shape ?s ?dt) (ICons ?u (INil))))
|
||||
) (
|
||||
(let ?fs_x (Op (FusionStart ?shape ?s ?dt) (ICons ?x (INil))))
|
||||
(let ?elem (Op (CudaUnaryElementwise \"{opcode}\" ?shape ?s ?s ?dt)
|
||||
(ICons ?fs_x (INil))))
|
||||
(union ?fs_u ?elem)
|
||||
) :ruleset fusion_grow :name \"grow-U-FS-{hlir}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?inner_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?inner (INil))))
|
||||
(= ?bad_fs (Op (FusionStart ?shape ?s ?dt) (ICons ?inner_fe (INil))))
|
||||
(= ?bad_elem (Op (CudaUnaryElementwise \"{opcode}\" ?shape ?s ?s ?dt)
|
||||
(ICons ?bad_fs (INil))))
|
||||
(= ?bad_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?bad_elem (INil))))
|
||||
(= ?good_elem (Op (CudaUnaryElementwise \"{opcode}\" ?shape ?s ?s ?dt)
|
||||
(ICons ?inner (INil))))
|
||||
(= ?good_fe (Op (FusionEnd ?shape ?s ?dt) (ICons ?good_elem (INil))))
|
||||
(= ?bad_fe ?good_fe)
|
||||
) (
|
||||
(delete (Op (FusionStart ?shape ?s ?dt) (ICons ?inner_fe (INil))))
|
||||
) :ruleset cleanup :name \"cleanup-nested-FS-FE-unary-{hlir}\")"
|
||||
)));
|
||||
}
|
||||
for (hlir, opcode) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?bin (Op ({hlir} ?shape ?a_s ?b_s ?out_s)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?fs_bin (Op (FusionStart ?shape ?out_s ?dt) (ICons ?bin (INil))))
|
||||
) (
|
||||
(let ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(let ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(let ?elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?fs_b (INil)))))
|
||||
(union ?fs_bin ?elem)
|
||||
) :ruleset fusion_grow :name \"grow-B-FS-{hlir}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?inner_fe (Op (FusionEnd ?shape ?a_s ?dt) (ICons ?inner_a (INil))))
|
||||
(= ?bad_fs (Op (FusionStart ?shape ?a_s ?dt) (ICons ?inner_fe (INil))))
|
||||
(= ?fs_b (Op (FusionStart ?shape ?b_s ?dt) (ICons ?b (INil))))
|
||||
(= ?bad_elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?bad_fs (ICons ?fs_b (INil)))))
|
||||
(= ?bad_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?bad_elem (INil))))
|
||||
(= ?good_elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?inner_a (ICons ?fs_b (INil)))))
|
||||
(= ?good_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?good_elem (INil))))
|
||||
(= ?bad_fe ?good_fe)
|
||||
) (
|
||||
(delete (Op (FusionStart ?shape ?a_s ?dt) (ICons ?inner_fe (INil))))
|
||||
) :ruleset cleanup :name \"cleanup-nested-FS-FE-binary-lhs-{hlir}\")"
|
||||
)));
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?inner_fe (Op (FusionEnd ?shape ?b_s ?dt) (ICons ?inner_b (INil))))
|
||||
(= ?bad_fs (Op (FusionStart ?shape ?b_s ?dt) (ICons ?inner_fe (INil))))
|
||||
(= ?fs_a (Op (FusionStart ?shape ?a_s ?dt) (ICons ?a (INil))))
|
||||
(= ?bad_elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?bad_fs (INil)))))
|
||||
(= ?bad_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?bad_elem (INil))))
|
||||
(= ?good_elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(ICons ?fs_a (ICons ?inner_b (INil)))))
|
||||
(= ?good_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?good_elem (INil))))
|
||||
(= ?bad_fe ?good_fe)
|
||||
) (
|
||||
(delete (Op (FusionStart ?shape ?b_s ?dt) (ICons ?inner_fe (INil))))
|
||||
) :ruleset cleanup :name \"cleanup-nested-FS-FE-binary-rhs-{hlir}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
// Merge two FEs at a binary: B(FE(ia), FE(ib)) → FE(CudaBinary(ia, ib)).
|
||||
for (hlir, opcode) in binaries {
|
||||
// 7. Merge two FEs at a binary: B(FE(ia), FE(ib)) → FE(FB(ia, ib)).
|
||||
//
|
||||
// This is destructive: after creating the larger region, subsume the
|
||||
// two smaller FusionEnd rows. Without that, independently-grown left
|
||||
// and right regions form a Cartesian product, then those alternatives
|
||||
// can merge again higher in the graph.
|
||||
for (kb, fb, lb) in binaries {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule (
|
||||
(= ?fe_a (Op (FusionEnd ?shape ?a_s ?dt) (ICons ?inner_a (INil))))
|
||||
(= ?fe_b (Op (FusionEnd ?shape ?b_s ?dt) (ICons ?inner_b (INil))))
|
||||
(= ?bin (Op ({hlir} ?shape ?a_s ?b_s ?out_s)
|
||||
(= ?bin (Op ({kb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?fe_a (ICons ?fe_b (INil)))))
|
||||
) (
|
||||
(let ?elem (Op (CudaBinaryElementwise \"{opcode}\" ?shape ?a_s ?b_s ?out_s ?dt)
|
||||
(let ?fbin (Op ({fb} ?shape ?a_s ?b_s ?o_s ?dt)
|
||||
(ICons ?inner_a (ICons ?inner_b (INil)))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?out_s ?dt) (ICons ?elem (INil))))
|
||||
(let ?new_fe (Op (FusionEnd ?shape ?o_s ?dt) (ICons ?fbin (INil))))
|
||||
(union ?bin ?new_fe)
|
||||
(set (dtype ?new_fe) ?dt)
|
||||
) :ruleset fusion_merge :name \"merge-FE-FE-{hlir}\")"
|
||||
(subsume (Op (FusionEnd ?shape ?a_s ?dt) (ICons ?inner_a (INil))))
|
||||
(subsume (Op (FusionEnd ?shape ?b_s ?dt) (ICons ?inner_b (INil))))
|
||||
) :ruleset fusion_merge :name \"merge-FE-FE-{lb}\")"
|
||||
)));
|
||||
}
|
||||
|
||||
|
||||
@@ -2,21 +2,25 @@
|
||||
//!
|
||||
//! - `markers` — `FusionStart` / `FusionEnd` ops + the seven egglog rule
|
||||
//! families that build and extend FE-bracketed regions.
|
||||
//! - `elementwise` — generic region-internal CUDA elementwise op variants.
|
||||
//! - `fused_ops` — eight `FusedX` op variants (interior to a region) so
|
||||
//! pair-fuse rules' RHS sit in a different egglog sort than their LHS,
|
||||
//! blocking cascade by typing.
|
||||
//! - `region_codegen` — `kernel_to_host` calls into here to collapse each
|
||||
//! FE-rooted region into a single CUDA kernel at compile time.
|
||||
//!
|
||||
//! The LLIR keeps `FusionStart` / generic elementwise / `FusionEnd` nodes after
|
||||
//! The LLIR keeps `FusionStart` / `FusedX` / `FusionEnd` nodes after
|
||||
//! extraction; `region_codegen` is the only place that walks them.
|
||||
|
||||
pub mod elementwise;
|
||||
pub mod fused_ops;
|
||||
pub mod markers;
|
||||
pub mod region_codegen;
|
||||
|
||||
pub use elementwise::{CudaBinaryElementwise, CudaUnaryElementwise};
|
||||
pub use fused_ops::{
|
||||
FusedAdd, FusedExp, FusedExp2, FusedLog2, FusedMul, FusedRecip, FusedSin, FusedSqrt,
|
||||
};
|
||||
pub use markers::{FusionEnd, FusionStart};
|
||||
|
||||
/// All fusion-related op types that the egglog runtime needs to know about
|
||||
/// (markers + interior generic elementwise variants). Combined into a flat
|
||||
/// tuple for the `Ops` registry in `kernel::mod`.
|
||||
pub type Ops = (markers::Ops, elementwise::Ops);
|
||||
/// (markers + interior FusedX variants). Combined into a flat tuple for the
|
||||
/// `Ops` registry in `kernel::mod`.
|
||||
pub type Ops = (markers::Ops, fused_ops::Ops);
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
// =========================================================================
|
||||
// Region codegen for FusionStart / FusionEnd-bracketed fused regions.
|
||||
//
|
||||
// Older fusion lowering left elementwise / FusionStart / FusionEnd nodes in the post-extraction
|
||||
// PR1 left FusedX / FusionStart / FusionEnd nodes in the post-extraction
|
||||
// LLIR, each compiling to its own standalone CUDA kernel. PR2 collapses
|
||||
// every FusionEnd-rooted region into ONE fused CUDA kernel at codegen
|
||||
// time — without rewriting the LLIR.
|
||||
//
|
||||
// Pipeline:
|
||||
// `kernel_to_host` builds a Vec<CompileUnit> from the topo order:
|
||||
// - CompileUnit::Single(node) — unfused non-region kernels, compiled as before.
|
||||
// - CompileUnit::Region(rgn) — one FE + its interior elementwise DAG +
|
||||
// - CompileUnit::Single(node) — un-fused KernelX, compiled as before.
|
||||
// - CompileUnit::Region(rgn) — one FE + its interior FusedX DAG +
|
||||
// its FS leaves. Compiled here as a
|
||||
// single CUDA kernel that reads from
|
||||
// the region's external inputs once,
|
||||
// chains all elementwise bodies through
|
||||
// chains all FusedX bodies through
|
||||
// register-resident locals, and writes
|
||||
// the FE's output.
|
||||
//
|
||||
// The CompiledKernel for a Region is keyed on the FE node and stores
|
||||
// `inputs = external producer NodeIndices` (one per interior FusionStart),
|
||||
// so the existing buffer-pointer wiring in to_host.rs picks up the right
|
||||
// device pointers at execute time. Interior Cuda*Elementwise / FusionStart nodes
|
||||
// device pointers at execute time. Interior FusedX / FusionStart nodes
|
||||
// never enter the kernels Vec — they have no buffers, no launches.
|
||||
// =========================================================================
|
||||
|
||||
@@ -40,7 +40,6 @@ use as_any::Downcast;
|
||||
use crate::{
|
||||
compile_module_image_for_current_device, cuda_dtype,
|
||||
kernel::KernelOp,
|
||||
kernel::fusion::elementwise::{CudaBinaryElementwise, CudaUnaryElementwise},
|
||||
kernel::fusion::markers::{FusionEnd, FusionStart},
|
||||
kernel::hlir::{dtype_includes, generate_dyn_dims_defines},
|
||||
};
|
||||
@@ -53,10 +52,10 @@ use crate::{
|
||||
pub(crate) struct RegionUnit {
|
||||
/// The FusionEnd node that anchors this region.
|
||||
pub fe_node: NodeIndex,
|
||||
/// Interior Cuda*Elementwise nodes, in topological order (predecessors before
|
||||
/// Interior FusedX nodes, in topological order (predecessors before
|
||||
/// consumers). Used to emit register-binding statements in dependency
|
||||
/// order in the fused CUDA kernel body.
|
||||
pub elementwise_topo: Vec<NodeIndex>,
|
||||
pub fusedx_topo: Vec<NodeIndex>,
|
||||
/// FusionStart nodes that bound the region's leaves. One per external
|
||||
/// read site — duplicates (different FS LLIR nodes wrapping the same
|
||||
/// upstream tensor) are kept separate so each read uses its own
|
||||
@@ -80,13 +79,13 @@ pub(crate) enum CompileUnit {
|
||||
|
||||
/// Group a sub-DAG's topo order into compile units. Each FusionEnd node
|
||||
/// becomes the root of a `CompileUnit::Region`; the region's interior
|
||||
/// Cuda*Elementwise and FusionStart nodes are absorbed into that region and removed
|
||||
/// FusedX and FusionStart nodes are absorbed into that region and removed
|
||||
/// from the per-node iteration. Anything else is wrapped in
|
||||
/// `CompileUnit::Single`.
|
||||
/// Globally-absorbed FS / FE markers — the set of marker nodes that any
|
||||
/// `FusionEnd` in the LLIR walks back to during region detection. A
|
||||
/// marker is "absorbed" iff some FE in the LLIR can reach it by walking
|
||||
/// incoming edges through `FusionEnd` / Cuda*Elementwise nodes, stopping at
|
||||
/// incoming edges through `FusionEnd` / `FusedX` nodes, stopping at
|
||||
/// `FusionStart` leaves.
|
||||
///
|
||||
/// This is computed once over the full LLIR rather than per-convex-
|
||||
@@ -124,7 +123,7 @@ pub(crate) fn globally_absorbed_markers(llir_graph: &LLIRGraph) -> FxHashSet<Nod
|
||||
absorbed.insert(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
Some(_) if is_region_elementwise(llir_graph, pred) => {
|
||||
Some(other) if other.starts_with("Fused") => {
|
||||
absorbed.insert(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
@@ -188,12 +187,12 @@ pub(crate) fn build_compile_units(
|
||||
absorbed.insert(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
Some(_) if is_region_elementwise(llir_graph, pred) => {
|
||||
Some(other) if other.starts_with("Fused") => {
|
||||
interior.push(pred);
|
||||
stack.push(pred);
|
||||
}
|
||||
_ => {
|
||||
// Non-marker, non-elementwise predecessor inside what
|
||||
// Non-marker, non-FusedX predecessor inside what
|
||||
// we thought was a region. Shouldn't happen with
|
||||
// the current rules; treat conservatively: do
|
||||
// not absorb it. This means the region is
|
||||
@@ -230,56 +229,7 @@ pub(crate) fn build_compile_units(
|
||||
llir_graph
|
||||
.neighbors_directed(fs, Direction::Incoming)
|
||||
.next()
|
||||
.unwrap_or_else(|| {
|
||||
// Dump the malformed structure: which FE
|
||||
// triggered the walk, every node in fs_topo and
|
||||
// interior_topo, and each FS's incoming /
|
||||
// outgoing degree. Helps localize whether the
|
||||
// missing edge came from extraction or a
|
||||
// downstream LLIR transform.
|
||||
if std::env::var("LUMINAL_DEBUG_FUSION_PANIC").is_ok() {
|
||||
eprintln!(
|
||||
"FusionStart panic: fe={} (kernel={:?})",
|
||||
node.index(),
|
||||
llir_graph.node_weight(node).and_then(|op| {
|
||||
op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name())
|
||||
}),
|
||||
);
|
||||
eprintln!(" fs_topo ({}):", fs_topo.len());
|
||||
for &f in &fs_topo {
|
||||
let in_deg = llir_graph
|
||||
.neighbors_directed(f, Direction::Incoming)
|
||||
.count();
|
||||
let out_deg = llir_graph
|
||||
.neighbors_directed(f, Direction::Outgoing)
|
||||
.count();
|
||||
let kn = llir_graph
|
||||
.node_weight(f)
|
||||
.and_then(|op| {
|
||||
op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name())
|
||||
})
|
||||
.unwrap_or("?");
|
||||
eprintln!(
|
||||
" fs={} kind={} in_deg={} out_deg={}",
|
||||
f.index(),
|
||||
kn,
|
||||
in_deg,
|
||||
out_deg,
|
||||
);
|
||||
}
|
||||
eprintln!(" interior_topo ({}):", interior_topo.len());
|
||||
for &i in &interior_topo {
|
||||
let kn = llir_graph
|
||||
.node_weight(i)
|
||||
.and_then(|op| {
|
||||
op.to_dialect::<dyn KernelOp>().map(|k| k.kernel_name())
|
||||
})
|
||||
.unwrap_or("?");
|
||||
eprintln!(" interior={} kind={}", i.index(), kn);
|
||||
}
|
||||
}
|
||||
panic!("FusionStart with no predecessor")
|
||||
})
|
||||
.expect("FusionStart with no predecessor")
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -290,7 +240,7 @@ pub(crate) fn build_compile_units(
|
||||
node,
|
||||
RegionUnit {
|
||||
fe_node: node,
|
||||
elementwise_topo: interior_topo,
|
||||
fusedx_topo: interior_topo,
|
||||
fs_nodes: fs_topo,
|
||||
external_inputs,
|
||||
},
|
||||
@@ -319,53 +269,24 @@ pub(crate) fn build_compile_units(
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Per-elementwise body templates.
|
||||
// Per-FusedX body templates.
|
||||
//
|
||||
// Each entry takes the names of the local variables holding the op's
|
||||
// inputs and returns a CUDA expression evaluating to the op's output
|
||||
// (a register-resident value, no buffer involved).
|
||||
// =========================================================================
|
||||
|
||||
fn is_region_elementwise(llir_graph: &LLIRGraph, node: NodeIndex) -> bool {
|
||||
llir_graph
|
||||
.node_weight(node)
|
||||
.and_then(|op| op.to_dialect::<dyn KernelOp>())
|
||||
.is_some_and(|op| {
|
||||
(***op).downcast_ref::<CudaUnaryElementwise>().is_some()
|
||||
|| (***op).downcast_ref::<CudaBinaryElementwise>().is_some()
|
||||
})
|
||||
}
|
||||
|
||||
fn elementwise_value(local: &str, dtype: DType) -> String {
|
||||
if matches!(dtype, DType::F8E4M3 | DType::F8E5M2 | DType::F8UE8M0) {
|
||||
format!("static_cast<float>({local})")
|
||||
} else {
|
||||
local.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn elementwise_init_expr(expr: &str, dtype: DType, cuda_ty: &str) -> String {
|
||||
if matches!(dtype, DType::F8E4M3 | DType::F8E5M2 | DType::F8UE8M0) {
|
||||
format!("{cuda_ty}({expr})")
|
||||
} else {
|
||||
expr.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn elementwise_body(op: &str, locals: &[&str], dtype: DType) -> String {
|
||||
let a = || elementwise_value(locals[0], dtype);
|
||||
let b = || elementwise_value(locals[1], dtype);
|
||||
match op {
|
||||
"Sin" => format!("sinf({})", a()),
|
||||
"Sqrt" => format!("sqrtf({})", a()),
|
||||
"Exp" => format!("expf({})", a()),
|
||||
"Exp2" => format!("exp2f({})", a()),
|
||||
"Log2" => format!("log2f({})", a()),
|
||||
"Recip" => format!("1.0f / {}", a()),
|
||||
"Sigmoid" => format!("1.0f / (1.0f + expf(-{}))", a()),
|
||||
"Add" => format!("{} + {}", a(), b()),
|
||||
"Mul" => format!("{} * {}", a(), b()),
|
||||
other => panic!("region_codegen: unknown elementwise op {other}"),
|
||||
fn fused_body(name: &str, locals: &[&str]) -> String {
|
||||
match name {
|
||||
"FusedSin" => format!("sinf({})", locals[0]),
|
||||
"FusedSqrt" => format!("sqrtf({})", locals[0]),
|
||||
"FusedExp" => format!("expf({})", locals[0]),
|
||||
"FusedExp2" => format!("exp2f({})", locals[0]),
|
||||
"FusedLog2" => format!("log2f({})", locals[0]),
|
||||
"FusedRecip" => format!("1.0f / {}", locals[0]),
|
||||
"FusedAdd" => format!("{} + {}", locals[0], locals[1]),
|
||||
"FusedMul" => format!("{} * {}", locals[0], locals[1]),
|
||||
other => panic!("region_codegen: unknown FusedX op {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -403,7 +324,7 @@ pub(crate) fn compile_region(
|
||||
let dtype: DType = fe_struct.dtype;
|
||||
|
||||
// Aggregate all dynamic vars used anywhere in the region (FS strides,
|
||||
// FE strides and elementwise shapes.
|
||||
// FE strides, FusedX shape — all FusedX share `out_shape`, but their
|
||||
// own strides are likewise relevant for any future stride-affine ops).
|
||||
let mut all_vars: FxHashSet<char> = FxHashSet::default();
|
||||
all_vars.extend(out_shape.iter().flat_map(|e| e.dyn_vars()));
|
||||
@@ -413,19 +334,6 @@ pub(crate) fn compile_region(
|
||||
let fs_struct: &FusionStart = (***fs_op).downcast_ref::<FusionStart>().unwrap();
|
||||
all_vars.extend(fs_struct.strides.iter().flat_map(|e| e.dyn_vars()));
|
||||
}
|
||||
for &elem_idx in ®ion.elementwise_topo {
|
||||
let elem_op = llir_graph[elem_idx].to_dialect::<dyn KernelOp>().unwrap();
|
||||
if let Some(elem) = (***elem_op).downcast_ref::<CudaUnaryElementwise>() {
|
||||
all_vars.extend(elem.shape.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(elem.in_strides.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(elem.out_strides.iter().flat_map(|e| e.dyn_vars()));
|
||||
} else if let Some(elem) = (***elem_op).downcast_ref::<CudaBinaryElementwise>() {
|
||||
all_vars.extend(elem.out_shape.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(elem.a_stride.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(elem.b_stride.iter().flat_map(|e| e.dyn_vars()));
|
||||
all_vars.extend(elem.out_stride.iter().flat_map(|e| e.dyn_vars()));
|
||||
}
|
||||
}
|
||||
|
||||
let cuda_ty = cuda_dtype(dtype);
|
||||
let includes = dtype_includes(&[dtype]);
|
||||
@@ -451,19 +359,19 @@ pub(crate) fn compile_region(
|
||||
}
|
||||
let signature = signature_params.join(", ");
|
||||
|
||||
// Body: read FS leaves, then walk elementwise nodes in topo order emitting a
|
||||
// Body: read FS leaves, then walk FusedX in topo order emitting a
|
||||
// local per op, then write FE output. Every node gets a local keyed
|
||||
// by a position-in-region index so the kernel string is invariant
|
||||
// under NodeIndex churn (each `egglog_to_llir` reissues NodeIndexes,
|
||||
// so naming locals by `n.index()` would invalidate the kernel
|
||||
// string cache on every search candidate). Indices: FS leaves get
|
||||
// 0..fs_nodes.len(), elementwise nodes get fs_nodes.len()..(+ elementwise_topo.len()).
|
||||
// 0..fs_nodes.len(), FusedX get fs_nodes.len()..(+ fusedx_topo.len()).
|
||||
let mut local_idx_map: FxHashMap<NodeIndex, usize> = FxHashMap::default();
|
||||
for (i, &fs_idx) in region.fs_nodes.iter().enumerate() {
|
||||
local_idx_map.insert(fs_idx, i);
|
||||
}
|
||||
let fs_count = region.fs_nodes.len();
|
||||
for (i, &op_idx) in region.elementwise_topo.iter().enumerate() {
|
||||
for (i, &op_idx) in region.fusedx_topo.iter().enumerate() {
|
||||
local_idx_map.insert(op_idx, fs_count + i);
|
||||
}
|
||||
let local_name = |n: NodeIndex| format!("v_{}", local_idx_map[&n]);
|
||||
@@ -486,22 +394,12 @@ pub(crate) fn compile_region(
|
||||
));
|
||||
}
|
||||
|
||||
// Elementwise ops in topo order. Each looks up its predecessor locals
|
||||
// FusedX ops in topo order. Each looks up its predecessor locals
|
||||
// (in incoming-edge id order to match the original op's input
|
||||
// arity / position).
|
||||
for &op_idx in ®ion.elementwise_topo {
|
||||
for &op_idx in ®ion.fusedx_topo {
|
||||
let op_ref = llir_graph[op_idx].to_dialect::<dyn KernelOp>().unwrap();
|
||||
let (elem_name, elem_dtype) =
|
||||
if let Some(elem) = (***op_ref).downcast_ref::<CudaUnaryElementwise>() {
|
||||
(elem.op.as_str(), elem.dtype)
|
||||
} else if let Some(elem) = (***op_ref).downcast_ref::<CudaBinaryElementwise>() {
|
||||
(elem.op.as_str(), elem.dtype)
|
||||
} else {
|
||||
panic!(
|
||||
"region_codegen: expected Cuda*Elementwise op, got {}",
|
||||
op_ref.kernel_name()
|
||||
);
|
||||
};
|
||||
let op_name = op_ref.kernel_name();
|
||||
|
||||
let mut input_locals: Vec<String> = llir_graph
|
||||
.edges_directed(op_idx, Direction::Incoming)
|
||||
@@ -520,16 +418,15 @@ pub(crate) fn compile_region(
|
||||
input_locals = edges.into_iter().map(|(_, src)| local_name(src)).collect();
|
||||
let inputs_ref: Vec<&str> = input_locals.iter().map(|s| s.as_str()).collect();
|
||||
|
||||
let expr = elementwise_body(elem_name, &inputs_ref, elem_dtype);
|
||||
let expr = elementwise_init_expr(&expr, elem_dtype, cuda_ty);
|
||||
let expr = fused_body(op_name, &inputs_ref);
|
||||
body.push_str(&format!(
|
||||
" {cuda_ty} {name} = {expr};\n",
|
||||
name = local_name(op_idx),
|
||||
));
|
||||
}
|
||||
|
||||
// FE write: pick the elementwise node feeding FE (its single incoming edge in
|
||||
// the region — an elementwise node or, in degenerate single-FS regions which
|
||||
// FE write: pick the FusedX feeding FE (its single incoming edge in
|
||||
// the region — a FusedX or, in degenerate single-FS regions which
|
||||
// shouldn't arise, an FS).
|
||||
let fe_input: NodeIndex = llir_graph
|
||||
.neighbors_directed(region.fe_node, Direction::Incoming)
|
||||
@@ -577,63 +474,3 @@ pub(crate) fn compile_region(
|
||||
constants: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::kernel::fusion::elementwise::CudaBinaryElementwise;
|
||||
use luminal::op::LLIROp;
|
||||
use luminal::prelude::petgraph::algo::toposort;
|
||||
|
||||
/// Helper: wrap a `KernelOp` in an `LLIROp` of the kernel dialect.
|
||||
fn llir_of(op: impl KernelOp + 'static) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(op) as Box<dyn KernelOp>)
|
||||
}
|
||||
|
||||
/// Reproducer for the `FusionStart with no predecessor` panic at
|
||||
/// `region_codegen.rs:232`. The egglog rolling pass + iterated mode
|
||||
/// (`LUMINAL_LOOP_ROLL_ITERATE=1`) has been observed to produce LLIR
|
||||
/// graphs where a `FusionStart` marker is reached as a region leaf
|
||||
/// during the FE→FS walk but has no incoming edge — meaning the
|
||||
/// region has nothing to read from. `build_compile_units` then
|
||||
/// panics when constructing `external_inputs` because every FS leaf
|
||||
/// is required to have exactly one external producer.
|
||||
///
|
||||
/// Until that path is fixed, this test pins the failure mode so a
|
||||
/// regression doesn't silently change the panic message or location.
|
||||
/// `should_panic` rather than `ignore` so it stays runnable in CI
|
||||
/// and surfaces if the panic ever moves.
|
||||
#[test]
|
||||
#[should_panic(expected = "FusionStart with no predecessor")]
|
||||
fn fusion_start_with_no_predecessor_panics() {
|
||||
// Minimal reproducer:
|
||||
//
|
||||
// (no input) ──▶ FusionStart ──▶ CudaBinaryElementwise ──▶ FusionEnd
|
||||
//
|
||||
// CudaBinaryElementwise is a binary op (n_inputs = 2) so a real region would
|
||||
// have two FS leaves. For this panic-shape test only the *first*
|
||||
// FS leaf needs a missing predecessor — `build_compile_units`
|
||||
// panics in `expect("FusionStart with no predecessor")` as soon
|
||||
// as any FS in `fs_topo` lacks one. We add only one FS edge so
|
||||
// CudaBinaryElementwise has a dangling second input slot, but that's fine:
|
||||
// we're testing the specific panic path inside `build_compile_units`,
|
||||
// not full kernel codegen.
|
||||
let mut llir: LLIRGraph = LLIRGraph::default();
|
||||
|
||||
let fs_node = llir.add_node(llir_of(FusionStart::default()));
|
||||
let fadd_node = llir.add_node(llir_of(CudaBinaryElementwise::default()));
|
||||
let fe_node = llir.add_node(llir_of(FusionEnd::default()));
|
||||
|
||||
// FusionStart → CudaBinaryElementwise → FusionEnd.
|
||||
llir.add_edge(fs_node, fadd_node, ());
|
||||
llir.add_edge(fadd_node, fe_node, ());
|
||||
|
||||
let topo = toposort(&llir, None).expect("LLIR cycle in test setup");
|
||||
let absorbed = globally_absorbed_markers(&llir);
|
||||
|
||||
// This is the call that panics with `FusionStart with no
|
||||
// predecessor` because `fs_node`'s incoming-edges iterator is
|
||||
// empty.
|
||||
let _ = build_compile_units(&topo, &llir, &absorbed);
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,427 +0,0 @@
|
||||
//! Direct 2D matmul kernel — bypasses egglog rewrites, used as a custom op
|
||||
//! for matmul shapes where the cublaslt egg rules don't reliably fire.
|
||||
//!
|
||||
//! The cublaslt 2D rules in `host/cublaslt/cublaslt_*Cm_rewrite.egg` /
|
||||
//! `cublaslt_Rm*_rewrite.egg` are *supposed* to match any 2D matmul whose
|
||||
//! Mul + SumReduce broadcast lowering has the expected stride patterns,
|
||||
//! and the conditional matmul cleanup is *supposed* to delete the
|
||||
//! elementwise Mul + KernelSumReduce fallback whenever a cublaslt alternative
|
||||
//! exists. In practice both fail to fire reliably for the VAE's mid-block
|
||||
//! `AttnBlock` matmuls — at 1024² that lets the search occasionally pick
|
||||
//! the broadcast-Mul path for `q @ kᵀ`, generating a `(HW, HW, C) =
|
||||
//! (16384, 16384, 512)` ≈ 524 GiB single intermediate that OOMs the GPU.
|
||||
//!
|
||||
//! Same approach as `kernel::conv2d`: define a `KernelOp`, wrap it in a
|
||||
//! `CustomOp`, expose a tiny `pub fn` so callers don't see the
|
||||
//! `cx.custom_op` plumbing. This is opaque to egglog by design — we
|
||||
//! aren't trying to fuse with surrounding ops, just guarantee a sane
|
||||
//! lowering for the matmuls we know are problematic.
|
||||
//!
|
||||
//! The CUDA implementation is a textbook 2D-blocked SGEMM:
|
||||
//! * 16×16 output tile per block (256 threads)
|
||||
//! * Tiled load of A and B into shared memory in K-size chunks
|
||||
//! * Each thread accumulates one output element across all K-tiles
|
||||
//! * Optional bias broadcast along the M axis at write-out
|
||||
//! * `transpose_b` toggles between row-major B `(K, N)` and row-major
|
||||
//! B `(N, K)` (i.e. the `A @ Bᵀ` pattern that linear/projection
|
||||
//! layers use).
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
dtype::DType, op::CustomOp, op::LLIROp, prelude::FxHashMap, prelude::GraphTensor,
|
||||
shape::Expression,
|
||||
};
|
||||
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
/// Direct 2D matmul `(M, K) × {(K, N) | (N, K)} → (M, N)` with optional
|
||||
/// per-output-column bias and an optional batch axis. A and output are
|
||||
/// always F32. B can be F32 or BF16; BF16 is converted to F32 on each
|
||||
/// load, which avoids materializing the cast as a separate intermediate
|
||||
/// tensor (important for the text encoder / transformer where the F32-
|
||||
/// cast weights would not fit in GPU memory). All shape parameters are
|
||||
/// static (baked into the CUDA source via #defines).
|
||||
///
|
||||
/// When `batch > 1` the kernel does `batch` independent 2D matmuls in
|
||||
/// parallel: A is `(batch, M, K)`, B is `(batch, *, *)` with the same
|
||||
/// per-batch shape, output is `(batch, M, N)`. All three are assumed
|
||||
/// contiguous row-major across batches (i.e. `a_batch_stride = M*K`,
|
||||
/// `b_batch_stride = K*N` or `N*K` depending on `transpose_b`,
|
||||
/// `out_batch_stride = M*N`). Bias does NOT have a batch axis — it's
|
||||
/// `(N,)` and broadcast across batches.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Matmul2DKernel {
|
||||
pub m: usize,
|
||||
pub n: usize,
|
||||
pub k: usize,
|
||||
pub batch: usize,
|
||||
/// If `true`, B is interpreted as `(N, K)` row-major and accessed as
|
||||
/// `B[n][k]` (i.e. `A @ Bᵀ`). If `false`, B is `(K, N)` row-major and
|
||||
/// accessed as `B[k][n]` (i.e. `A @ B`).
|
||||
pub transpose_b: bool,
|
||||
pub has_bias: bool,
|
||||
/// Storage dtype of B. Currently F32 or BF16 are supported.
|
||||
pub weight_dtype: DType,
|
||||
}
|
||||
|
||||
const TILE: usize = 16;
|
||||
|
||||
impl KernelOp for Matmul2DKernel {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let bias_param = if self.has_bias {
|
||||
", const float* __restrict__ bias"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
let bias_add = if self.has_bias {
|
||||
" acc += bias[n];\n"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
// We want Bs[ty][tx] = B_effective[k0+ty][b_n_base+tx] where:
|
||||
// transpose_b=false: B is (K, N) row-major → B[(k0+ty)*N + (b_n_base+tx)]
|
||||
// transpose_b=true: B is (N, K) row-major → B[(b_n_base+tx)*K + (k0+ty)]
|
||||
// Plus the per-batch offset (`b_batch_off`).
|
||||
let b_index_expr = if self.transpose_b {
|
||||
"b_batch_off + (b_n_base + tx) * K + (k0 + ty)"
|
||||
} else {
|
||||
"b_batch_off + (k0 + ty) * N + (b_n_base + tx)"
|
||||
};
|
||||
// Convert B's element to float on load. For BF16 we declare B as
|
||||
// `__nv_bfloat16*` and use `__bfloat162float`; for F32 it's a no-op.
|
||||
let (b_param_type, b_load_expr, bf16_include) = match self.weight_dtype {
|
||||
DType::F32 => (
|
||||
"const float* __restrict__ B",
|
||||
format!("B[{b_index_expr}]"),
|
||||
"",
|
||||
),
|
||||
DType::Bf16 => (
|
||||
"const __nv_bfloat16* __restrict__ B",
|
||||
format!("__bfloat162float(B[{b_index_expr}])"),
|
||||
"#include <cuda_bf16.h>\n",
|
||||
),
|
||||
other => panic!("Matmul2DKernel: unsupported weight_dtype {other:?}"),
|
||||
};
|
||||
|
||||
let kernel = format!(
|
||||
"
|
||||
{bf16_include}extern \"C\" __global__ void matmul_2d_kernel(
|
||||
float* __restrict__ C,
|
||||
const float* __restrict__ A,
|
||||
{b_param_type}{bias_param}
|
||||
) {{
|
||||
const int M = {m};
|
||||
const int N = {n};
|
||||
const int K = {k};
|
||||
const int TILE = {tile};
|
||||
|
||||
__shared__ float As[{tile}][{tile}];
|
||||
__shared__ float Bs[{tile}][{tile}];
|
||||
|
||||
int bx = blockIdx.x; // tile column (n)
|
||||
int by = blockIdx.y; // tile row (m)
|
||||
int batch = blockIdx.z; // batch index (0..BATCH-1)
|
||||
int tx = threadIdx.x; // 0..TILE-1, output col within tile
|
||||
int ty = threadIdx.y; // 0..TILE-1, output row within tile
|
||||
|
||||
int m_global = by * TILE + ty;
|
||||
int n_global = bx * TILE + tx;
|
||||
|
||||
int a_m_base = by * TILE;
|
||||
int b_n_base = bx * TILE;
|
||||
|
||||
// Per-batch base pointer offsets (contiguous row-major across batches).
|
||||
int a_batch_off = batch * (M * K);
|
||||
int b_batch_off = batch * (K * N);
|
||||
int c_batch_off = batch * (M * N);
|
||||
|
||||
float acc = 0.0f;
|
||||
|
||||
int n_tiles = (K + TILE - 1) / TILE;
|
||||
for (int t = 0; t < n_tiles; ++t) {{
|
||||
int k0 = t * TILE;
|
||||
|
||||
// Load A tile (TILE, TILE) row-major from A[m, k]: A[(by*TILE+ty)*K + (k0+tx)]
|
||||
int a_m = a_m_base + ty;
|
||||
int a_k = k0 + tx;
|
||||
As[ty][tx] = (a_m < M && a_k < K) ? A[a_batch_off + a_m * K + a_k] : 0.0f;
|
||||
|
||||
// Load B tile depending on transpose_b
|
||||
int b_n_or_k = b_n_base + tx; // for transpose_b=true this is N; for =false this is N
|
||||
int b_k_or_k = k0 + ty; // similarly
|
||||
// We compute Bs[ty][tx] such that the inner loop reads Bs[k_local][n_local] = B[k][n].
|
||||
// For transpose_b=true (B is (N,K)): B[k][n] in math = B_storage[n][k] = B[(b_n_base+tx)*K + (k0+ty)]
|
||||
// For transpose_b=false (B is (K,N)): B[k][n] in math = B_storage[k][n] = B[(k0+ty)*N + (b_n_base+tx)]
|
||||
bool b_in_bounds = ({transpose_b} ? (b_n_or_k < N && b_k_or_k < K)
|
||||
: (b_k_or_k < K && b_n_or_k < N));
|
||||
Bs[ty][tx] = b_in_bounds ? ({b_load_expr}) : 0.0f;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int kk = 0; kk < {tile}; ++kk) {{
|
||||
acc += As[ty][kk] * Bs[kk][tx];
|
||||
}}
|
||||
__syncthreads();
|
||||
}}
|
||||
|
||||
if (m_global < M && n_global < N) {{
|
||||
int n = n_global;
|
||||
{bias_add} C[c_batch_off + m_global * N + n_global] = acc;
|
||||
}}
|
||||
}}
|
||||
",
|
||||
m = self.m,
|
||||
n = self.n,
|
||||
k = self.k,
|
||||
tile = TILE,
|
||||
transpose_b = self.transpose_b,
|
||||
b_load_expr = b_load_expr,
|
||||
b_param_type = b_param_type,
|
||||
bias_param = bias_param,
|
||||
bias_add = bias_add,
|
||||
bf16_include = bf16_include,
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("matmul_2d_kernel").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
let grid_x = self.n.div_ceil(TILE);
|
||||
let grid_y = self.m.div_ceil(TILE);
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(
|
||||
Expression::from(grid_x),
|
||||
Expression::from(grid_y),
|
||||
Expression::from(self.batch),
|
||||
),
|
||||
(
|
||||
Expression::from(TILE),
|
||||
Expression::from(TILE),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
Expression::from(0usize),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
Expression::from(self.batch * self.m * self.n)
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
DType::F32
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// K elements from A (F32) + K elements from B (F32 or BF16) + maybe bias (F32).
|
||||
let b_bytes = match self.weight_dtype {
|
||||
DType::F32 => 4,
|
||||
DType::Bf16 => 2,
|
||||
_ => 4,
|
||||
};
|
||||
let bias_bytes = if self.has_bias { 4 } else { 0 };
|
||||
Expression::from(
|
||||
self.batch * self.m * self.n * (self.k * 4 + self.k * b_bytes + bias_bytes),
|
||||
)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
let per_out = self.k * 2 + if self.has_bias { 1 } else { 0 };
|
||||
Expression::from(self.batch * self.m * self.n * per_out)
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Matmul2D"
|
||||
}
|
||||
}
|
||||
|
||||
/// CustomOp wrapper for [`Matmul2DKernel`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Matmul2DCustom(pub Matmul2DKernel);
|
||||
|
||||
impl CustomOp for Matmul2DCustom {
|
||||
fn to_llir_op(&self) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
|
||||
}
|
||||
}
|
||||
|
||||
/// `(M, K) @ (K, N) -> (M, N)` for row-major F32 inputs. No bias.
|
||||
pub fn matmul_2d(a: GraphTensor, b: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ false, None)
|
||||
}
|
||||
|
||||
/// `(M, K) @ (N, K)ᵀ -> (M, N)` for row-major F32 inputs. No bias.
|
||||
/// Use this for `A @ Bᵀ` where B is stored row-major as `(N, K)` — the
|
||||
/// pattern produced by linear / projection layers (`x @ w.t()`).
|
||||
pub fn matmul_2d_t(a: GraphTensor, b: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ true, None)
|
||||
}
|
||||
|
||||
/// Linear projection with bias: `(M, K) @ (N, K)ᵀ + bias` where bias is
|
||||
/// `(N,)`, row-major F32 throughout.
|
||||
pub fn linear_bias(a: GraphTensor, b: GraphTensor, bias: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ true, Some(bias))
|
||||
}
|
||||
|
||||
/// Mixed-precision linear (no bias): `A (F32, M, K) @ B (BF16, N, K)ᵀ → (F32, M, N)`.
|
||||
///
|
||||
/// Lowers as plain HLIR — `Cast(A, BF16) @ permute(B_bf16) → Cast(F32)`.
|
||||
/// The activation cast and output cast are tiny (M*K and M*N elements;
|
||||
/// the K=hidden weight stays BF16). The inner BF16 matmul matches the
|
||||
/// existing cublaslt rewrite rules and runs as
|
||||
/// `CUBLAS_COMPUTE_32F_FAST_16BF` — Hopper's native 2× BF16 path.
|
||||
pub fn linear_no_bias_bf16_w(a: GraphTensor, b_bf16: GraphTensor) -> GraphTensor {
|
||||
assert_eq!(a.dtype, DType::F32, "linear_no_bias_bf16_w expects F32 A");
|
||||
assert_eq!(
|
||||
b_bf16.dtype,
|
||||
DType::Bf16,
|
||||
"linear_no_bias_bf16_w expects BF16 B"
|
||||
);
|
||||
let a_dims = a.dims();
|
||||
let b_dims = b_bf16.dims();
|
||||
assert_eq!(a_dims.len(), 2);
|
||||
assert_eq!(b_dims.len(), 2);
|
||||
let a_bf16 = a.cast(DType::Bf16);
|
||||
let b_kn = b_bf16.permute((1, 0));
|
||||
a_bf16.matmul(b_kn).cast(DType::F32)
|
||||
}
|
||||
|
||||
/// Batched matmul: `A (B, M, K) @ B (B, K, N) → (B, M, N)`, all F32 row-major.
|
||||
pub fn matmul_3d(a: GraphTensor, b: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ false, None)
|
||||
}
|
||||
|
||||
/// Batched matmul with B-transpose: `A (B, M, K) @ B (B, N, K)ᵀ → (B, M, N)`.
|
||||
pub fn matmul_3d_t(a: GraphTensor, b: GraphTensor) -> GraphTensor {
|
||||
matmul_inner(a, b, /*transpose_b=*/ true, None)
|
||||
}
|
||||
|
||||
fn matmul_inner(
|
||||
a: GraphTensor,
|
||||
b: GraphTensor,
|
||||
transpose_b: bool,
|
||||
bias: Option<GraphTensor>,
|
||||
) -> GraphTensor {
|
||||
assert_eq!(a.dtype, DType::F32, "matmul requires F32 A");
|
||||
let weight_dtype = b.dtype;
|
||||
assert!(
|
||||
matches!(weight_dtype, DType::F32 | DType::Bf16),
|
||||
"matmul B must be F32 or BF16, got {weight_dtype:?}",
|
||||
);
|
||||
let a_dims = a.dims();
|
||||
let b_dims = b.dims();
|
||||
assert_eq!(
|
||||
a_dims.len(),
|
||||
b_dims.len(),
|
||||
"matmul A/B rank mismatch: {} vs {}",
|
||||
a_dims.len(),
|
||||
b_dims.len(),
|
||||
);
|
||||
assert!(
|
||||
a_dims.len() == 2 || a_dims.len() == 3,
|
||||
"matmul expects rank 2 or 3, got rank {}",
|
||||
a_dims.len(),
|
||||
);
|
||||
|
||||
let (batch, a_off) = if a_dims.len() == 3 {
|
||||
let ba = a_dims[0].to_usize().expect("batch dim must be static");
|
||||
let bb = b_dims[0].to_usize().expect("batch dim must be static");
|
||||
assert_eq!(
|
||||
ba, bb,
|
||||
"matmul batch dim mismatch: A batch={ba}, B batch={bb}"
|
||||
);
|
||||
(ba, 1)
|
||||
} else {
|
||||
(1, 0)
|
||||
};
|
||||
|
||||
let m = a_dims[a_off].to_usize().expect("M must be a static dim");
|
||||
let k_a = a_dims[a_off + 1]
|
||||
.to_usize()
|
||||
.expect("K (A) must be a static dim");
|
||||
let (n, k_b) = if transpose_b {
|
||||
// B per-batch is (N, K)
|
||||
let n = b_dims[a_off].to_usize().expect("N must be a static dim");
|
||||
let k = b_dims[a_off + 1]
|
||||
.to_usize()
|
||||
.expect("K (B) must be a static dim");
|
||||
(n, k)
|
||||
} else {
|
||||
// B per-batch is (K, N)
|
||||
let k = b_dims[a_off]
|
||||
.to_usize()
|
||||
.expect("K (B) must be a static dim");
|
||||
let n = b_dims[a_off + 1]
|
||||
.to_usize()
|
||||
.expect("N must be a static dim");
|
||||
(n, k)
|
||||
};
|
||||
assert_eq!(k_a, k_b, "matmul K mismatch: A K={k_a}, B K={k_b}");
|
||||
let k = k_a;
|
||||
|
||||
let has_bias = bias.is_some();
|
||||
if let Some(bias) = bias {
|
||||
let bdims = bias.dims();
|
||||
assert_eq!(bdims.len(), 1, "matmul bias must be 1D");
|
||||
assert_eq!(
|
||||
bdims[0].to_usize().expect("bias dim must be static"),
|
||||
n,
|
||||
"matmul bias size must equal N"
|
||||
);
|
||||
assert_eq!(bias.dtype, DType::F32, "matmul bias must be F32");
|
||||
}
|
||||
|
||||
let kern = Matmul2DKernel {
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
batch,
|
||||
transpose_b,
|
||||
has_bias,
|
||||
weight_dtype,
|
||||
};
|
||||
let cx = unsafe { &mut *a.graph_ref };
|
||||
let inputs: Vec<GraphTensor> = if let Some(bias) = bias {
|
||||
vec![a, b, bias]
|
||||
} else {
|
||||
vec![a, b]
|
||||
};
|
||||
if batch == 1 {
|
||||
cx.custom_op(Matmul2DCustom(kern), inputs, (m, n), DType::F32)
|
||||
} else {
|
||||
cx.custom_op(Matmul2DCustom(kern), inputs, (batch, m, n), DType::F32)
|
||||
}
|
||||
}
|
||||
@@ -9,21 +9,12 @@ use luminal_tracing::schema::{
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub mod conv2d;
|
||||
pub mod cuda_graph;
|
||||
pub mod fusion;
|
||||
pub mod hlir;
|
||||
pub mod matmul2d;
|
||||
pub mod other_ops;
|
||||
pub mod rope;
|
||||
|
||||
pub use conv2d::{Conv2DCustom, Conv2DKernel, conv2d_bias};
|
||||
pub use cuda_graph::*;
|
||||
pub use matmul2d::{
|
||||
Matmul2DCustom, Matmul2DKernel, linear_bias, linear_no_bias_bf16_w, matmul_2d, matmul_2d_t,
|
||||
matmul_3d, matmul_3d_t,
|
||||
};
|
||||
pub use rope::{RoPECustom, RoPEKernel, apply_rope};
|
||||
|
||||
pub type Ops = (hlir::Ops, other_ops::Ops, fusion::Ops);
|
||||
|
||||
|
||||
@@ -23,6 +23,8 @@ pub type Ops = (
|
||||
KernelBatchMatMul,
|
||||
KernelScatterNoCopy,
|
||||
KernelSoftmax,
|
||||
KernelExp,
|
||||
KernelSigmoid,
|
||||
);
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -623,7 +625,7 @@ extern \"C\" {{
|
||||
// KernelBatchMatVec: Fused batched matrix-vector product for attention
|
||||
// Matches: Mul(broadcast) + Sum pattern for [B, 1, K] x [B, K, N] -> [B, 1, N]
|
||||
// or [B, M, K] x [B, K, N] -> [B, M, N] with small M
|
||||
// Replaces the broadcast elementwise Mul + single-threaded KernelSumReduce pipeline
|
||||
// Replaces the broadcast KernelMul + single-threaded KernelSumReduce pipeline
|
||||
// =============================================================================
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -1454,3 +1456,393 @@ extern \"C\" {{
|
||||
"Softmax"
|
||||
}
|
||||
}
|
||||
|
||||
// KernelExp: native exp (uses expf instead of exp2f * constant)
|
||||
// Single-kernel alternative to the 3-kernel Constant+Mul+Exp2 path.
|
||||
// Improves numerical precision by avoiding the truncated log2(e) constant.
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelExp {
|
||||
shape: Vec<Expression>,
|
||||
in_strides: Vec<Expression>,
|
||||
out_strides: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelExp {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelExp",
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
// Match Exp2(Mul(x, log2e_constant)) directly.
|
||||
// This matches the pattern created by frontend exp() = (self * (1/ln(2))).exp2()
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?shape ?x_stride ?const_stride ?inter_stride) (ICons ?x (ICons ?exp_const (INil)))))
|
||||
(= ?exp2 (Op (Exp2 ?shape ?inter_stride ?out_stride) (ICons ?mul (INil))))
|
||||
(= ?dt (dtype ?x))
|
||||
(= ?cv (Op (Constant ?val) (INil)))
|
||||
(= ?exp_const ?cv)
|
||||
(> ?val 1.44)
|
||||
(< ?val 1.45)
|
||||
)
|
||||
(
|
||||
(let ?kexp (Op (KernelExp ?shape ?x_stride ?out_stride ?dt) (ICons ?x (INil))))
|
||||
(union ?exp2 ?kexp)
|
||||
(set (dtype ?kexp) ?dt)
|
||||
)
|
||||
:ruleset direct_kernel
|
||||
:name \"direct-exp-fusion\"
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap(),
|
||||
in_strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
out_strides: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[3]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelExp {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars = self
|
||||
.shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.in_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.out_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect::<FxHashSet<_>>();
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let n_elements = self
|
||||
.shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.to_kernel();
|
||||
let out_idx = flatten_strides(&self.shape, &self.out_strides).to_kernel();
|
||||
let in_idx = flatten_strides(&self.shape, &self.in_strides).to_kernel();
|
||||
let kernel = format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void exp_k({dtype} *out, const {dtype} *in{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {n_elements}) return;
|
||||
out[{out_idx}] = expf(in[{in_idx}]);
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("exp_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let out_size = self.shape.iter().copied().product::<Expression>();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Exp"
|
||||
}
|
||||
}
|
||||
|
||||
// KernelSigmoid: fused sigmoid = 1/(1+exp(-x))
|
||||
// Single-kernel alternative to the 5-kernel Neg+Exp+Const+Add+Recip path.
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelSigmoid {
|
||||
shape: Vec<Expression>,
|
||||
in_strides: Vec<Expression>,
|
||||
out_strides: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelSigmoid {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelSigmoid",
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
// Stage the HLIR sigmoid pattern through a small marker so repeated
|
||||
// default passes do not re-run one large join over every Mul/Add/Recip.
|
||||
Rule::raw(
|
||||
"(datatype*
|
||||
(KernelSigmoidScaledState
|
||||
(MkKernelSigmoidScaledState IR EList EList DType)
|
||||
)
|
||||
)
|
||||
(function kernel_sigmoid_scaled (IR) KernelSigmoidScaledState :merge new)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?neg1 (Op (Constant ?nv) (INil)))
|
||||
(< ?nv -0.99)
|
||||
(> ?nv -1.01)
|
||||
(= ?neg_x (Op (Mul ?shape ?x_stride ?neg_stride ?neg_out_stride) (ICons ?x (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant ?lv) (INil)))
|
||||
(> ?lv 1.44)
|
||||
(< ?lv 1.45)
|
||||
(= ?scaled (Op (Mul ?shape ?neg_out_stride ?log2e_stride ?scaled_stride) (ICons ?neg_x (ICons ?log2e (INil)))))
|
||||
(= ?dt (dtype ?x))
|
||||
)
|
||||
(
|
||||
(set (kernel_sigmoid_scaled ?scaled)
|
||||
(MkKernelSigmoidScaledState ?x ?shape ?x_stride ?dt))
|
||||
)
|
||||
:ruleset direct_kernel
|
||||
:name \"direct-sigmoid-scaled-marker\"
|
||||
)
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?scaled_state (kernel_sigmoid_scaled ?scaled))
|
||||
(= ?scaled_state (MkKernelSigmoidScaledState ?x ?shape ?x_stride ?dt))
|
||||
(= ?exp2 (Op (Exp2 ?shape ?scaled_stride ?exp_stride) (ICons ?scaled (INil))))
|
||||
(= ?one (Op (Constant ?ov) (INil)))
|
||||
(> ?ov 0.99)
|
||||
(< ?ov 1.01)
|
||||
(= ?plus_one (Op (Add ?shape ?exp_stride ?one_stride ?add_stride) (ICons ?exp2 (ICons ?one (INil)))))
|
||||
(= ?sig_out (Op (Recip ?shape ?add_stride ?out_stride) (ICons ?plus_one (INil))))
|
||||
)
|
||||
(
|
||||
(let ?ksig (Op (KernelSigmoid ?shape ?x_stride ?out_stride ?dt) (ICons ?x (INil))))
|
||||
(union ?sig_out ?ksig)
|
||||
(set (dtype ?ksig) ?dt)
|
||||
)
|
||||
:ruleset direct_kernel
|
||||
:name \"direct-sigmoid-fusion\"
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap(),
|
||||
in_strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
out_strides: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[3]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelSigmoid {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars = self
|
||||
.shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.in_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.out_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect::<FxHashSet<_>>();
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let n_elements = self
|
||||
.shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.to_kernel();
|
||||
let out_idx = flatten_strides(&self.shape, &self.out_strides).to_kernel();
|
||||
let in_idx = flatten_strides(&self.shape, &self.in_strides).to_kernel();
|
||||
let kernel = format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void sigmoid_k({dtype} *out, const {dtype} *in{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {n_elements}) return;
|
||||
out[{out_idx}] = 1.0f / (1.0f + expf(-in[{in_idx}]));
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("sigmoid_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let out_size = self.shape.iter().copied().product::<Expression>();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// neg + exp + add + recip = ~4 ops per element
|
||||
self.shape.iter().copied().product::<Expression>() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Sigmoid"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,189 +0,0 @@
|
||||
//! Fused RoPE (rotary position embedding) — interleaved-pair convention.
|
||||
//!
|
||||
//! Replaces flux2's 6-op RoPE chain (split / slice / squeeze / neg / concat /
|
||||
//! merge_dims / 4× cast / mul / add) with a single kernel launch per call.
|
||||
//! ~120 RoPE calls per forward pass at full DiT depth.
|
||||
//!
|
||||
//! Convention: `repeat_interleave_real=True` (Flux 2 / diffusers), so adjacent
|
||||
//! dim pairs rotate together. For an input `[a0, b0, a1, b1, ...]` and per-
|
||||
//! position `(cos, sin)`, the output is
|
||||
//! `out[2j] = x[2j] * cos[2j] - x[2j+1] * sin[2j]`
|
||||
//! `out[2j+1] = x[2j+1] * cos[2j+1] + x[2j] * sin[2j+1]`
|
||||
//!
|
||||
//! Layout: x `(S, H, D)`, cos/sin `(S, D)` (broadcast across H).
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
dtype::DType, op::CustomOp, op::LLIROp, prelude::FxHashMap, prelude::GraphTensor,
|
||||
shape::Expression,
|
||||
};
|
||||
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use crate::kernel::KernelOp;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RoPEKernel {
|
||||
pub s: usize,
|
||||
pub h: usize,
|
||||
pub d: usize,
|
||||
}
|
||||
|
||||
const TPB: usize = 64;
|
||||
|
||||
impl KernelOp for RoPEKernel {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let s = self.s;
|
||||
let h = self.h;
|
||||
let d = self.d;
|
||||
assert!(d.is_multiple_of(2), "RoPE head_dim must be even");
|
||||
let kernel = format!(
|
||||
r#"
|
||||
extern "C" __global__ void rope_kernel(
|
||||
float* __restrict__ out,
|
||||
const float* __restrict__ x,
|
||||
const float* __restrict__ cos_,
|
||||
const float* __restrict__ sin_
|
||||
) {{
|
||||
const int S = {s};
|
||||
const int H = {h};
|
||||
const int D = {d};
|
||||
int sh = blockIdx.x; // 0..S*H
|
||||
int s_idx = sh / H;
|
||||
int tid = threadIdx.x;
|
||||
|
||||
const float* xr = x + sh * D;
|
||||
const float* cosr = cos_ + s_idx * D;
|
||||
const float* sinr = sin_ + s_idx * D;
|
||||
float* yr = out + sh * D;
|
||||
|
||||
for (int i = tid; i < D; i += {TPB}) {{
|
||||
float xi = xr[i];
|
||||
float xpair;
|
||||
if ((i & 1) == 0) {{
|
||||
// even: paired with i+1, rotated value is -x[i+1]
|
||||
xpair = -xr[i + 1];
|
||||
}} else {{
|
||||
// odd: paired with i-1, rotated value is +x[i-1]
|
||||
xpair = xr[i - 1];
|
||||
}}
|
||||
yr[i] = xi * cosr[i] + xpair * sinr[i];
|
||||
}}
|
||||
}}
|
||||
"#
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((m, f)) = compile_cache.get(&kernel) {
|
||||
(m.clone(), f.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("rope_kernel").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
"rope_kernel".to_string(),
|
||||
(
|
||||
Expression::from(s * h),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
(
|
||||
Expression::from(TPB),
|
||||
Expression::from(1usize),
|
||||
Expression::from(1usize),
|
||||
),
|
||||
Expression::from(0usize),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
Expression::from(self.s * self.h * self.d)
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
DType::F32
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// x: full (S,H,D); cos/sin: (S,D) read H times each but cached.
|
||||
Expression::from(self.s * self.h * self.d * 4 + self.s * self.d * 4 * 2)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// 4 per output element (mul, neg/load, mul, add).
|
||||
Expression::from(self.s * self.h * self.d * 4)
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"RoPE"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RoPECustom(pub RoPEKernel);
|
||||
|
||||
impl CustomOp for RoPECustom {
|
||||
fn to_llir_op(&self) -> LLIROp {
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(self.0.clone()) as Box<dyn KernelOp>)
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply RoPE: `x` shape `(S, H, D)` F32, `cos`/`sin` shape `(S, D)` F32.
|
||||
/// Returns `(S, H, D)` F32.
|
||||
pub fn apply_rope(x: GraphTensor, cos: GraphTensor, sin: GraphTensor) -> GraphTensor {
|
||||
assert_eq!(x.dtype, DType::F32, "RoPE x must be F32");
|
||||
let cos = if cos.dtype == DType::F32 {
|
||||
cos
|
||||
} else {
|
||||
cos.cast(DType::F32)
|
||||
};
|
||||
let sin = if sin.dtype == DType::F32 {
|
||||
sin
|
||||
} else {
|
||||
sin.cast(DType::F32)
|
||||
};
|
||||
let x_dims = x.dims();
|
||||
assert_eq!(x_dims.len(), 3, "RoPE x must be 3-D (S, H, D)");
|
||||
let s = x_dims[0].to_usize().expect("RoPE: S must be static");
|
||||
let h = x_dims[1].to_usize().expect("RoPE: H must be static");
|
||||
let d = x_dims[2].to_usize().expect("RoPE: D must be static");
|
||||
let cos_dims = cos.dims();
|
||||
let sin_dims = sin.dims();
|
||||
assert_eq!(cos_dims.len(), 2, "RoPE cos must be 2-D (S, D)");
|
||||
assert_eq!(sin_dims.len(), 2, "RoPE sin must be 2-D (S, D)");
|
||||
assert_eq!(cos_dims[0].to_usize().unwrap(), s, "RoPE cos S mismatch");
|
||||
assert_eq!(cos_dims[1].to_usize().unwrap(), d, "RoPE cos D mismatch");
|
||||
assert_eq!(sin_dims[0].to_usize().unwrap(), s, "RoPE sin S mismatch");
|
||||
assert_eq!(sin_dims[1].to_usize().unwrap(), d, "RoPE sin D mismatch");
|
||||
|
||||
let kern = RoPEKernel { s, h, d };
|
||||
let cx = unsafe { &mut *x.graph_ref };
|
||||
cx.custom_op(RoPECustom(kern), vec![x, cos, sin], (s, h, d), DType::F32)
|
||||
}
|
||||
@@ -192,32 +192,6 @@ impl CudaGraphOp {
|
||||
state: RefCell::new(state),
|
||||
}
|
||||
}
|
||||
|
||||
/// LLIR node IDs of every kernel in this CudaGraphOp, in the order
|
||||
/// they execute inside the compiled CUDA graph. This is the
|
||||
/// toposort `kernel_to_host` used at compile time, preserved here
|
||||
/// so the runtime can compute live ranges that match real
|
||||
/// execution order: each kernel in `state.kernels` was added to
|
||||
/// the CUDA graph with `prev_graph_node` as its sole dependency,
|
||||
/// which serializes them.
|
||||
pub fn kernel_topo_order(&self) -> Vec<NodeIndex> {
|
||||
self.state.borrow().kernels.iter().map(|k| k.node).collect()
|
||||
}
|
||||
|
||||
/// Direct LLIR-node inputs of one kernel inside this CudaGraphOp.
|
||||
/// Used by the runtime's live-range pass to refine intra-graph
|
||||
/// consumer positions: a kernel's input can stop being live as
|
||||
/// soon as that specific kernel finishes, not when the whole
|
||||
/// CudaGraphOp finishes.
|
||||
pub fn kernel_inputs(&self, kernel_node: NodeIndex) -> Vec<NodeIndex> {
|
||||
self.state
|
||||
.borrow()
|
||||
.kernels
|
||||
.iter()
|
||||
.find(|k| k.node == kernel_node)
|
||||
.map(|k| k.inputs.clone())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for CudaGraphOp {
|
||||
@@ -840,7 +814,7 @@ pub fn kernel_to_host(
|
||||
}
|
||||
|
||||
let kernel_subgraphs = partition_marked_convex(llir_graph, &kernel_ops_in_graph).unwrap();
|
||||
// Compute the set of FS / FE / Cuda*Elementwise nodes globally absorbed by some
|
||||
// Compute the set of FS / FE / FusedX nodes globally absorbed by some
|
||||
// FusionEnd in the LLIR. Used by `build_compile_units` to suppress
|
||||
// standalone marker compile units for shared FS leaves whose consumers
|
||||
// live in a different convex subgraph than the FS itself.
|
||||
@@ -1000,7 +974,7 @@ pub fn kernel_to_host(
|
||||
// (so FE provides trait methods like output_size /
|
||||
// build_params) but its `inputs` are the external
|
||||
// producers, not FE's literal LLIR predecessors —
|
||||
// those are interior elementwise nodes that don't exist
|
||||
// those are interior FusedX nodes that don't exist
|
||||
// as buffer-bearing nodes from the host's view.
|
||||
let fe_op_ref = llir_graph[region.fe_node]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
@@ -1165,7 +1139,7 @@ pub fn kernel_to_host(
|
||||
}
|
||||
|
||||
// Strip fully-absorbed marker nodes (FusionStart, nested FusionEnd,
|
||||
// Cuda*Elementwise) from the LLIR. Region codegen has already folded them into
|
||||
// FusedX) from the LLIR. Region codegen has already folded them into
|
||||
// a single fused CUDA function anchored at each region's root
|
||||
// FusionEnd; the absorbed nodes have no consumers outside the region
|
||||
// and never need their own buffers. Removing them keeps later
|
||||
|
||||
@@ -237,7 +237,6 @@ pub(crate) fn split_egraph_by_memory_limit(
|
||||
let mut split = splitter.split();
|
||||
|
||||
compact_egraph_after_prune(&mut split);
|
||||
validate_unique_loop_markers(&split);
|
||||
let stats = MemorySplitStats {
|
||||
original_enodes,
|
||||
split_enodes: split.enodes.len(),
|
||||
@@ -443,9 +442,6 @@ impl<'a> StateSplitter<'a> {
|
||||
}
|
||||
}
|
||||
"Op" => self.split_op_node(owner_class, node, label, children),
|
||||
label if direct_loop_marker(label) => {
|
||||
self.split_direct_loop_marker_node(owner_class, node, label.to_string(), children)
|
||||
}
|
||||
_ => {
|
||||
let Some((idx, child_class)) =
|
||||
first_child_with_sort_index(self.original, &children, "IR")
|
||||
@@ -483,9 +479,6 @@ impl<'a> StateSplitter<'a> {
|
||||
|
||||
let input_states = self.split_list_class(inputs_class);
|
||||
for kind_node in kind_nodes {
|
||||
let Some((kind_label, _)) = self.original.enodes.get(kind_node) else {
|
||||
continue;
|
||||
};
|
||||
let Some(kind) =
|
||||
kind_memory_for_node(self.original, &self.sort_by_name, kind_node, self.dyn_map)
|
||||
else {
|
||||
@@ -495,33 +488,6 @@ impl<'a> StateSplitter<'a> {
|
||||
continue;
|
||||
}
|
||||
let kind_split_class = self.kind_singleton_class(kind_node);
|
||||
if loop_op_kind(kind_label) {
|
||||
// Loop OpKinds are structural markers. Keep the marker singleton and
|
||||
// pick one feasible state for the data flowing through it.
|
||||
let Some((state, input_split_class)) = input_states
|
||||
.iter()
|
||||
.filter_map(|(input_state, input_split_class)| {
|
||||
let state = op_memory_state(kind, input_state)?;
|
||||
(state.peak <= self.limit).then(|| (state, input_split_class.clone()))
|
||||
})
|
||||
.min_by_key(|(state, _)| (state.peak, state.live))
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let mut split_children = children.clone();
|
||||
split_children[0] = kind_split_class;
|
||||
split_children[1] = input_split_class;
|
||||
self.add_ir_state_node(
|
||||
owner_class,
|
||||
state,
|
||||
label.clone(),
|
||||
split_children,
|
||||
source_node,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
for (input_state, input_split_class) in &input_states {
|
||||
let Some(state) = op_memory_state(kind, input_state) else {
|
||||
continue;
|
||||
@@ -543,33 +509,6 @@ impl<'a> StateSplitter<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
fn split_direct_loop_marker_node(
|
||||
&mut self,
|
||||
owner_class: &ClassId,
|
||||
source_node: &NodeId,
|
||||
label: String,
|
||||
children: Vec<ClassId>,
|
||||
) {
|
||||
let Some((idx, child_class)) = first_child_with_sort_index(self.original, &children, "IR")
|
||||
else {
|
||||
return;
|
||||
};
|
||||
// LoopStart/LoopEnd identity is part of the loop scaffold, so state
|
||||
// splitting must not clone the marker across child-state variants.
|
||||
let Some((state, state_class)) = self
|
||||
.split_ir_class(&child_class)
|
||||
.into_iter()
|
||||
.filter(|(state, _)| state.peak <= self.limit)
|
||||
.min_by_key(|(state, _)| (state.peak, state.live))
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
let mut split_children = children;
|
||||
split_children[idx] = state_class;
|
||||
self.add_ir_state_node(owner_class, state, label, split_children, source_node);
|
||||
}
|
||||
|
||||
fn split_list_class(&mut self, class: &ClassId) -> Vec<(ListMemoryState, ClassId)> {
|
||||
if let Some(states) = self.list_memo.get(class) {
|
||||
return states.clone();
|
||||
@@ -1053,10 +992,7 @@ fn choose_kind_node<'a>(egraph: &'a SerializedEGraph, kind_class: &ClassId) -> O
|
||||
};
|
||||
let is_kernel = |node: &&NodeId| -> bool {
|
||||
let label = &egraph.enodes[*node].0;
|
||||
label.starts_with("Kernel")
|
||||
|| label.starts_with("Cuda")
|
||||
|| label == "FusionStart"
|
||||
|| label == "FusionEnd"
|
||||
label.starts_with("Kernel") || label.starts_with("Fused")
|
||||
};
|
||||
|
||||
kind_enodes
|
||||
@@ -1143,94 +1079,12 @@ fn compact_egraph_after_prune(egraph: &mut SerializedEGraph) {
|
||||
}
|
||||
|
||||
fn zero_local_op_kind(kind: &str) -> bool {
|
||||
loop_op_kind(kind)
|
||||
}
|
||||
|
||||
fn loop_op_kind(kind: &str) -> bool {
|
||||
matches!(
|
||||
kind,
|
||||
"LoopInput" | "LoopInputStatic" | "LoopOutput" | "LoopOutputSelect"
|
||||
)
|
||||
}
|
||||
|
||||
fn direct_loop_marker(kind: &str) -> bool {
|
||||
matches!(kind, "LoopStart" | "LoopEnd")
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
struct LoopMarkerKey {
|
||||
label: String,
|
||||
fields: Vec<String>,
|
||||
}
|
||||
|
||||
fn validate_unique_loop_markers(egraph: &SerializedEGraph) {
|
||||
let mut seen = FxHashMap::default();
|
||||
for node in egraph.enodes.keys() {
|
||||
for key in loop_marker_keys_for_node(egraph, node) {
|
||||
if let Some(previous) = seen.insert(key.clone(), node.clone()) {
|
||||
panic!(
|
||||
"CUDA memory splitter duplicated loop marker {key:?}: {previous:?} and {node:?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn loop_marker_keys_for_node(egraph: &SerializedEGraph, node: &NodeId) -> Vec<LoopMarkerKey> {
|
||||
let Some((label, children)) = egraph.enodes.get(node) else {
|
||||
return Vec::new();
|
||||
};
|
||||
if direct_loop_marker(label) {
|
||||
return vec![LoopMarkerKey {
|
||||
label: label.clone(),
|
||||
fields: field_signature(egraph, children.iter().skip(1)),
|
||||
}];
|
||||
}
|
||||
if label != "Op" {
|
||||
return Vec::new();
|
||||
}
|
||||
let Some(kind_class) = children.first() else {
|
||||
return Vec::new();
|
||||
};
|
||||
let Some((sort, kind_nodes)) = egraph.eclasses.get(kind_class) else {
|
||||
return Vec::new();
|
||||
};
|
||||
if sort != "OpKind" {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
kind_nodes
|
||||
.iter()
|
||||
.filter_map(|kind_node| {
|
||||
let (kind_label, kind_children) = egraph.enodes.get(kind_node)?;
|
||||
loop_op_kind(kind_label).then(|| LoopMarkerKey {
|
||||
label: kind_label.clone(),
|
||||
fields: field_signature(egraph, kind_children.iter()),
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn field_signature<'a>(
|
||||
egraph: &SerializedEGraph,
|
||||
fields: impl Iterator<Item = &'a ClassId>,
|
||||
) -> Vec<String> {
|
||||
fields
|
||||
.map(|class| {
|
||||
let node_label = egraph
|
||||
.eclasses
|
||||
.get(class)
|
||||
.and_then(|(_, nodes)| {
|
||||
nodes
|
||||
.iter()
|
||||
.find_map(|node| egraph.enodes.get(node).map(|(label, _)| label.clone()))
|
||||
})
|
||||
.unwrap_or_else(|| "<missing>".to_string());
|
||||
format!("{}:{node_label}", class.as_ref())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn cuda_sort_map() -> FxHashMap<String, SortDef> {
|
||||
<(crate::kernel::Ops, crate::host::Ops) as luminal::op::IntoEgglogOp>::into_vec()
|
||||
.into_iter()
|
||||
@@ -1250,7 +1104,7 @@ fn local_output_bytes<'a>(
|
||||
) -> Option<Expression> {
|
||||
match sort.name.as_str() {
|
||||
name if zero_local_op_kind(name) => Some(0.into()),
|
||||
name if name.starts_with("Cuda") || name == "FusionStart" => Some(0.into()),
|
||||
name if name.starts_with("Fused") || name == "FusionStart" => Some(0.into()),
|
||||
"KernelConstant" => Some(4.into()),
|
||||
"KernelIota" => Some(expr_field(egraph, sort, kind_children, "range", expr_cache)? * 4),
|
||||
"KernelLessThan" => Some(n_elements_field(
|
||||
@@ -1281,7 +1135,7 @@ fn local_output_bytes<'a>(
|
||||
let dtype = dtype_field(egraph, sort, kind_children, "dtype")?;
|
||||
Some(bytes_for_elements(size, dtype))
|
||||
}
|
||||
"cublaslt" | "cublaslt_scaled" => {
|
||||
"cublaslt" => {
|
||||
let batch = expr_field(egraph, sort, kind_children, "batch_count", expr_cache)?;
|
||||
let m = expr_field(egraph, sort, kind_children, "m", expr_cache)?;
|
||||
let n = expr_field(egraph, sort, kind_children, "n", expr_cache)?;
|
||||
@@ -1359,7 +1213,7 @@ fn n_elements_field<'a>(
|
||||
fn output_bytes_rules(sort: &SortDef) -> Vec<String> {
|
||||
match sort.name.as_str() {
|
||||
name if zero_local_op_kind(name) => vec![output_bytes_rule(sort, "(MNum 0)", "zero")],
|
||||
name if name.starts_with("Cuda") || name == "FusionStart" => {
|
||||
name if name.starts_with("Fused") || name == "FusionStart" => {
|
||||
vec![output_bytes_rule(sort, "(MNum 0)", "zero")]
|
||||
}
|
||||
"KernelConstant" => vec![output_bytes_rule(sort, "(MNum 4)", "f32-scalar")],
|
||||
@@ -1390,7 +1244,7 @@ fn output_bytes_rules(sort: &SortDef) -> Vec<String> {
|
||||
&["(= ?__cuda_elems (n_elements ?batch_shape))"],
|
||||
)],
|
||||
"KernelCast" => dtype_output_bytes_rules(sort, "size", "dtype"),
|
||||
"cublaslt" | "cublaslt_scaled" => {
|
||||
"cublaslt" => {
|
||||
dtype_output_bytes_rules_for_expr(sort, "(MMul (MMul ?batch_count ?m) ?n)", "d_dtype")
|
||||
}
|
||||
"GLUMoE" => vec![output_bytes_rule(
|
||||
@@ -1517,9 +1371,7 @@ fn output_bytes_rule_with_facts(
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
cuda_memory_analysis_pass, estimate_graph_memory_bytes, loop_marker_keys_for_node,
|
||||
};
|
||||
use super::{cuda_memory_analysis_pass, estimate_graph_memory_bytes};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
EGraphChoiceSet, SerializedEGraph, count_choice_sets_up_to, random_initial_choice,
|
||||
@@ -1531,7 +1383,11 @@ mod tests {
|
||||
};
|
||||
|
||||
fn ops() -> Vec<std::sync::Arc<Box<dyn luminal::op::EgglogOp>>> {
|
||||
let mut ops = <(crate::kernel::Ops, crate::host::Ops) as IntoEgglogOp>::into_vec();
|
||||
let mut ops = <(
|
||||
crate::kernel::hlir::Ops,
|
||||
crate::kernel::other_ops::Ops,
|
||||
crate::host::Ops,
|
||||
) as IntoEgglogOp>::into_vec();
|
||||
ops.extend(<HLIROps as IntoEgglogOp>::into_vec());
|
||||
ops
|
||||
}
|
||||
@@ -1543,13 +1399,13 @@ mod tests {
|
||||
.expect("cuda memory pass should parse and run")
|
||||
}
|
||||
|
||||
fn kernel_mod(name: &str, size: &str, a: &str, b: &str) -> String {
|
||||
fn kernel_add(name: &str, size: usize, a: &str, b: &str) -> String {
|
||||
format!(
|
||||
r#"
|
||||
(let {name}
|
||||
(Op
|
||||
(KernelMod
|
||||
(ECons {size} (ENil))
|
||||
(KernelAdd
|
||||
(ECons (MNum {size}) (ENil))
|
||||
(ECons (MIter) (ENil))
|
||||
(ECons (MIter) (ENil))
|
||||
(ECons (MIter) (ENil))
|
||||
@@ -1598,20 +1454,25 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cuda_memory_late_pass_runs_on_kernel_mod() {
|
||||
fn cuda_memory_late_pass_runs_on_kernel_add() {
|
||||
let ops = ops();
|
||||
let late_pass = cuda_memory_analysis_pass(&ops, None, &FxHashMap::default());
|
||||
let program = format!(
|
||||
r#"
|
||||
let program = r#"
|
||||
(let t0 (Input 0 "" (F32)))
|
||||
(let t1 (Input 1 "" (F32)))
|
||||
{}
|
||||
(let t2
|
||||
(Op
|
||||
(KernelAdd
|
||||
(ECons (MNum 4) (ENil))
|
||||
(ECons (MIter) (ENil))
|
||||
(ECons (MIter) (ENil))
|
||||
(ECons (MIter) (ENil))
|
||||
(F32))
|
||||
(ICons t0 (ICons t1 (INil)))))
|
||||
(let t3 (Output t2 2))
|
||||
"#,
|
||||
kernel_mod("t2", "(MNum 4)", "t0", "t1"),
|
||||
);
|
||||
"#;
|
||||
|
||||
run_egglog_with_late_passes(&program, "t3", &ops, false, &[late_pass])
|
||||
run_egglog_with_late_passes(program, "t3", &ops, false, &[late_pass])
|
||||
.expect("cuda memory pass should parse and run");
|
||||
}
|
||||
|
||||
@@ -1638,55 +1499,6 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cuda_memory_state_split_does_not_duplicate_loop_markers() {
|
||||
let program = format!(
|
||||
r#"
|
||||
(let t0 (Input 0 "" (F32)))
|
||||
(let t1 (Input 1 "" (F32)))
|
||||
{}
|
||||
{}
|
||||
(union small big)
|
||||
(let loop_start (LoopStart small 0 0 (MNum 2) (F32)))
|
||||
(let loop_end (LoopEnd small 0 0 (F32)))
|
||||
(let loop_input (Op (LoopInput 0 0 (F32)) (ICons small (ICons t0 (INil)))))
|
||||
(let loop_output (Op (LoopOutput 0 0 (F32)) (ICons small (INil))))
|
||||
(let loop_select (Op (LoopOutputSelect 0 0 0 (F32)) (ICons loop_output (INil))))
|
||||
(let out_start (Output loop_start 2))
|
||||
(let out_end (Output loop_end 3))
|
||||
(let out_input (Output loop_input 4))
|
||||
(let out_select (Output loop_select 5))
|
||||
(let out_a (OutputJoin out_start out_end))
|
||||
(let out_b (OutputJoin out_input out_select))
|
||||
(let out (OutputJoin out_a out_b))
|
||||
"#,
|
||||
kernel_mod("small", "(MNum 4)", "t0", "t1"),
|
||||
kernel_mod("big", "(MNum 8)", "t0", "t1"),
|
||||
);
|
||||
|
||||
let egraph = run_memory_egraph(&program, "out", Some(1024));
|
||||
let mut marker_counts = FxHashMap::<String, usize>::default();
|
||||
for node in egraph.enodes.keys() {
|
||||
for key in loop_marker_keys_for_node(&egraph, node) {
|
||||
*marker_counts.entry(key.label).or_default() += 1;
|
||||
}
|
||||
}
|
||||
|
||||
for marker in [
|
||||
"LoopStart",
|
||||
"LoopEnd",
|
||||
"LoopInput",
|
||||
"LoopOutput",
|
||||
"LoopOutputSelect",
|
||||
] {
|
||||
assert_eq!(
|
||||
marker_counts.get(marker).copied().unwrap_or_default(),
|
||||
1,
|
||||
"{marker} should not be duplicated by memory state splitting"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cuda_memory_estimates_peak_for_two_live_inputs() {
|
||||
let program = format!(
|
||||
@@ -1698,9 +1510,9 @@ mod tests {
|
||||
{}
|
||||
(let out (Output parent 3))
|
||||
"#,
|
||||
kernel_mod("left", "(MNum 4)", "t0", "t1"),
|
||||
kernel_mod("right", "(MNum 4)", "t0", "t1"),
|
||||
kernel_mod("parent", "(MNum 4)", "left", "right"),
|
||||
kernel_add("left", 4, "t0", "t1"),
|
||||
kernel_add("right", 4, "t0", "t1"),
|
||||
kernel_add("parent", 4, "left", "right"),
|
||||
);
|
||||
let egraph = run_memory_egraph(&program, "out", None);
|
||||
let mut rng = rand::rng();
|
||||
@@ -1734,7 +1546,7 @@ mod tests {
|
||||
(ICons dest (ICons indexes (ICons src (INil))))))
|
||||
(let out (Output scatter 4))
|
||||
"#,
|
||||
kernel_mod("dest", "(MNum 4)", "t0", "t1"),
|
||||
kernel_add("dest", 4, "t0", "t1"),
|
||||
);
|
||||
let egraph = run_memory_egraph(&program, "out", None);
|
||||
let mut rng = rand::rng();
|
||||
@@ -1757,8 +1569,8 @@ mod tests {
|
||||
(union small big)
|
||||
(let out (Output small 2))
|
||||
"#,
|
||||
kernel_mod("small", "(MNum 4)", "t0", "t1"),
|
||||
kernel_mod("big", "(MNum 32)", "t0", "t1"),
|
||||
kernel_add("small", 4, "t0", "t1"),
|
||||
kernel_add("big", 32, "t0", "t1"),
|
||||
);
|
||||
|
||||
let egraph = run_memory_egraph(&program, "out", Some(64));
|
||||
@@ -1778,17 +1590,22 @@ mod tests {
|
||||
let mut dyn_map = FxHashMap::default();
|
||||
dyn_map.insert('s', 4);
|
||||
let late_pass = cuda_memory_analysis_pass(&ops, Some(16), &dyn_map);
|
||||
let program = format!(
|
||||
r#"
|
||||
let program = r#"
|
||||
(let t0 (Input 0 "" (F32)))
|
||||
(let t1 (Input 1 "" (F32)))
|
||||
{}
|
||||
(let add
|
||||
(Op
|
||||
(KernelAdd
|
||||
(ECons (MVar "s") (ENil))
|
||||
(ECons (MIter) (ENil))
|
||||
(ECons (MIter) (ENil))
|
||||
(ECons (MIter) (ENil))
|
||||
(F32))
|
||||
(ICons t0 (ICons t1 (INil)))))
|
||||
(let out (Output add 2))
|
||||
"#,
|
||||
kernel_mod("add", "(MVar \"s\")", "t0", "t1"),
|
||||
);
|
||||
"#;
|
||||
|
||||
let egraph = run_egglog_with_late_passes(&program, "out", &ops, false, &[late_pass])
|
||||
let egraph = run_egglog_with_late_passes(program, "out", &ops, false, &[late_pass])
|
||||
.expect("cuda memory pass should parse and run");
|
||||
assert_eq!(count_choice_sets_up_to(&egraph, 10), 1);
|
||||
|
||||
@@ -1811,9 +1628,9 @@ mod tests {
|
||||
{}
|
||||
(let out (Output parent 3))
|
||||
"#,
|
||||
kernel_mod("left", "(MNum 12)", "t0", "t1"),
|
||||
kernel_mod("right", "(MNum 12)", "t0", "t1"),
|
||||
kernel_mod("parent", "(MNum 4)", "left", "right"),
|
||||
kernel_add("left", 12, "t0", "t1"),
|
||||
kernel_add("right", 12, "t0", "t1"),
|
||||
kernel_add("parent", 4, "left", "right"),
|
||||
);
|
||||
|
||||
let egraph = run_memory_egraph(&program, "out", Some(64));
|
||||
@@ -1842,11 +1659,11 @@ mod tests {
|
||||
{}
|
||||
(let out (Output parent 4))
|
||||
"#,
|
||||
kernel_mod("left_small", "(MNum 4)", "t0", "t1"),
|
||||
kernel_mod("left_medium", "(MNum 8)", "t0", "t1"),
|
||||
kernel_mod("left_big", "(MNum 12)", "t0", "t1"),
|
||||
kernel_mod("right_small", "(MNum 4)", "t0", "t1"),
|
||||
kernel_mod("parent", "(MNum 4)", "left_small", "right_small"),
|
||||
kernel_add("left_small", 4, "t0", "t1"),
|
||||
kernel_add("left_medium", 8, "t0", "t1"),
|
||||
kernel_add("left_big", 12, "t0", "t1"),
|
||||
kernel_add("right_small", 4, "t0", "t1"),
|
||||
kernel_add("parent", 4, "left_small", "right_small"),
|
||||
);
|
||||
|
||||
let uncapped_start = std::time::Instant::now();
|
||||
|
||||
@@ -287,12 +287,7 @@ impl CudaRuntime {
|
||||
let dev = f32s.to_cuda_input(&self.cuda_stream);
|
||||
self.hlir_buffers.insert(node, dev);
|
||||
}
|
||||
safetensors::Dtype::U8
|
||||
| safetensors::Dtype::BF16
|
||||
| safetensors::Dtype::F16
|
||||
| safetensors::Dtype::F8_E4M3
|
||||
| safetensors::Dtype::F8_E5M2
|
||||
| safetensors::Dtype::F8_E8M0 => {
|
||||
safetensors::Dtype::U8 | safetensors::Dtype::BF16 | safetensors::Dtype::F16 => {
|
||||
let bytes = tensor.data();
|
||||
let dev = bytes.to_cuda_input(&self.cuda_stream);
|
||||
self.hlir_buffers.insert(node, dev);
|
||||
@@ -1194,7 +1189,7 @@ impl Runtime for CudaRuntime {
|
||||
}
|
||||
|
||||
fn estimate_graph_memory<'a>(
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
egraph: &'a SerializedEGraph,
|
||||
choices: &luminal::egglog_utils::EGraphChoiceSet<'a>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> Option<usize> {
|
||||
@@ -1348,8 +1343,8 @@ impl Runtime for CudaRuntime {
|
||||
&mut self,
|
||||
llir_graph: &LLIRGraph,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
trials: usize,
|
||||
timeout: Option<std::time::Duration>,
|
||||
_trials: usize,
|
||||
_timeout: Option<std::time::Duration>,
|
||||
) -> (Self::ProfileMetric, String) {
|
||||
// Clear active bucket's arena before loading new LLIR for profiling.
|
||||
if !self.compiled_buckets.is_empty() {
|
||||
@@ -1357,18 +1352,10 @@ impl Runtime for CudaRuntime {
|
||||
}
|
||||
self.load_llir(llir_graph);
|
||||
self.profiling = true;
|
||||
let profile_start = std::time::Instant::now();
|
||||
let mut durations = Vec::with_capacity(trials.max(1));
|
||||
for _ in 0..trials.max(1) {
|
||||
let start = std::time::Instant::now();
|
||||
self.execute(dyn_map);
|
||||
durations.push(start.elapsed());
|
||||
if timeout.is_some_and(|timeout| profile_start.elapsed() >= timeout) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let start = std::time::Instant::now();
|
||||
self.execute(dyn_map);
|
||||
self.profiling = false;
|
||||
let duration = durations.iter().sum::<std::time::Duration>() / durations.len() as u32;
|
||||
let duration = start.elapsed();
|
||||
|
||||
let total_bytes: usize = self
|
||||
.last_kernel_stats
|
||||
@@ -1670,8 +1657,8 @@ impl CudaRuntime {
|
||||
//
|
||||
// The default assumption is "yes" for ordinary kernel ops
|
||||
// (Conv outputs, matmul outputs, etc). FusionStart and
|
||||
// Cuda*Elementwise are the exceptions — they're synthetic
|
||||
// nodes that the fusion rewrites add inside a region; the
|
||||
// Fused* are the exceptions — they're synthetic markers
|
||||
// that the fusion rewrites add inside a region; the
|
||||
// megakernel computes them in registers and never writes
|
||||
// to memory, so allocating a buffer would just be waste.
|
||||
//
|
||||
@@ -1686,12 +1673,12 @@ impl CudaRuntime {
|
||||
// an unrelated downstream op that lives in another region.
|
||||
//
|
||||
// Safe over-approximation: if the node is a FusionStart /
|
||||
// Cuda*Elementwise and *any* of its consumers is a FusionStart
|
||||
// Fused* and *any* of its consumers is a FusionStart
|
||||
// (which can only happen when that consumer is the leaf
|
||||
// of a different region) or a non-marker op (e.g. an
|
||||
// unfused Add/Mul reading the value directly), allocate a
|
||||
// buffer so cross-region reads have somewhere to land.
|
||||
let is_marker = kernel_name == "FusionStart" || kernel_name.starts_with("Cuda");
|
||||
let is_marker = kernel_name == "FusionStart" || kernel_name.starts_with("Fused");
|
||||
let has_external_consumer = is_marker
|
||||
&& llir_graph
|
||||
.neighbors_directed(node, Direction::Outgoing)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,7 @@
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
egglog_utils::{
|
||||
ClassId, NodeId, SerializedEGraph, egglog_to_llir, random_initial_choice,
|
||||
validate_choice_set,
|
||||
NodeId, SerializedEGraph, egglog_to_llir, random_initial_choice, validate_choice_set,
|
||||
},
|
||||
prelude::*,
|
||||
};
|
||||
@@ -12,8 +11,7 @@ use crate::{
|
||||
host::{
|
||||
CublasLtMatrixOrders, CublasLtScaleValues, CublasLtTransposeOps, CublasLtTypeTuple, HostOp,
|
||||
cublaslt_c_d_layouts_match, cublaslt_epilogue, cublaslt_matrix_orders,
|
||||
cublaslt_scale_values, cublaslt_tensor_scale_inputs, cublaslt_transpose_ops,
|
||||
cublaslt_type_tuple,
|
||||
cublaslt_scale_values, cublaslt_transpose_ops, cublaslt_type_tuple,
|
||||
},
|
||||
runtime::CudaRuntime,
|
||||
};
|
||||
@@ -902,196 +900,6 @@ fn cublaslt_fp8_e4m3_beta_candidate_executes_2d_matmul_plus_f32_c() {
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA FP8 rewrite sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn cublaslt_fp8_scaled_candidate_executes_2d_matmul_f32_output() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !gpu_supports_cublaslt_fp8_launch(DType::F8E4M3) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (m, n, k) = (16, 16, 16);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((m, k));
|
||||
let a_scale = cx.tensor(());
|
||||
let b_scale = cx.tensor(());
|
||||
let b_input = cx.tensor((n, k)).as_dtype(DType::F8E4M3);
|
||||
let b = b_input.t();
|
||||
let scaled_a = (a / a_scale.expand_rhs((m, k))).cast(DType::F8E4M3);
|
||||
let out =
|
||||
(scaled_a.matmul(b).cast(DType::F32) * (a_scale * b_scale).expand_rhs((m, n))).output();
|
||||
let expected_tuple = (
|
||||
DType::F8E4M3,
|
||||
DType::F8E4M3,
|
||||
DType::F32,
|
||||
DType::F32,
|
||||
"32F",
|
||||
DType::F32,
|
||||
);
|
||||
let llir = extract_forced_cublaslt_llir_where(&mut cx, "functional scaled fp8", |llir| {
|
||||
cublaslt_type_tuples(llir).contains(&expected_tuple)
|
||||
&& cublaslt_tensor_scale_input_tuples(llir).contains(&(true, true))
|
||||
&& cublaslt_transpose_op_tuples(llir).contains(&("T", "N"))
|
||||
&& cublaslt_matrix_order_tuples(llir).contains(&("COL", "COL", "COL", "COL"))
|
||||
});
|
||||
|
||||
let input_scale = 0.25f32;
|
||||
let weight_scale = 2.0f32;
|
||||
let (a_fp8_bytes, a_values) = fp8_exact_bytes(DType::F8E4M3, m * k, 7);
|
||||
let a_data = a_values
|
||||
.iter()
|
||||
.map(|value| value * input_scale)
|
||||
.collect::<Vec<_>>();
|
||||
let (b_bytes, b_storage_values) = fp8_exact_bytes(DType::F8E4M3, k * n, 9);
|
||||
let b_values = logical_b_from_column_major_storage(&b_storage_values, n, k);
|
||||
let mut expected = reference_matmul_2d(&a_values, &b_values, m, n, k);
|
||||
for value in &mut expected {
|
||||
*value *= input_scale * weight_scale;
|
||||
}
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(a, a_data);
|
||||
rt.set_data(a_scale, vec![input_scale]);
|
||||
rt.set_data(b_scale, vec![weight_scale]);
|
||||
rt.set_data(b_input, b_bytes);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
// Keep the raw bytes live in the test construction: a_data was chosen so
|
||||
// the explicit scaled cast quantizes to these exact FP8 values.
|
||||
assert_eq!(a_fp8_bytes.len(), m * k);
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cublaslt_fp8_scaled_candidate_reaches_fused_output_scale_consumer() {
|
||||
let (m, n, k) = (16, 16, 16);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((m, k));
|
||||
let a_scale = cx.tensor(());
|
||||
let b_scale = cx.tensor(());
|
||||
let b_input = cx.tensor((n, k)).as_dtype(DType::F8E4M3);
|
||||
let b = b_input.t();
|
||||
let side = cx.tensor((m, n));
|
||||
let scaled_a = (a / a_scale.expand_rhs((m, k))).cast(DType::F8E4M3);
|
||||
let scaled_out = scaled_a.matmul(b).cast(DType::F32) * (a_scale * b_scale).expand_rhs((m, n));
|
||||
(scaled_out * side).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
|
||||
assert!(
|
||||
dataflow_reachable_cublaslt_scaled_count(egraph) > 0,
|
||||
"scaled cuBLASLt must remain reachable when fusion growth consumes the output-scale multiply internally"
|
||||
);
|
||||
assert_eq!(
|
||||
dataflow_reachable_cublaslt_raw_fp8_count(egraph),
|
||||
0,
|
||||
"raw FP8 cuBLASLt must be deleted when a scaled equivalent covers the fused output-scale consumer"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cublaslt_fp8_scaled_candidates_reach_fused_mlp_consumer() {
|
||||
let (m, n, k) = (16, 32, 16);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((m, k));
|
||||
let gate_input_scale = cx.tensor(());
|
||||
let gate_weight_scale = cx.tensor(());
|
||||
let up_input_scale = cx.tensor(());
|
||||
let up_weight_scale = cx.tensor(());
|
||||
let gate_weight = cx.tensor((n, k)).as_dtype(DType::F8E4M3);
|
||||
let up_weight = cx.tensor((n, k)).as_dtype(DType::F8E4M3);
|
||||
|
||||
let scaled_gate_a = (a / gate_input_scale.expand_rhs((m, k))).cast(DType::F8E4M3);
|
||||
let gate = scaled_gate_a.matmul(gate_weight.t()).cast(DType::F32)
|
||||
* (gate_input_scale * gate_weight_scale).expand_rhs((m, n));
|
||||
let scaled_up_a = (a / up_input_scale.expand_rhs((m, k))).cast(DType::F8E4M3);
|
||||
let up = scaled_up_a.matmul(up_weight.t()).cast(DType::F32)
|
||||
* (up_input_scale * up_weight_scale).expand_rhs((m, n));
|
||||
(gate.swish() * up).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
|
||||
assert!(
|
||||
dataflow_reachable_cublaslt_scaled_count(egraph) >= 2,
|
||||
"scaled cuBLASLt candidates must remain reachable through fused MLP gate/up consumers"
|
||||
);
|
||||
assert_eq!(
|
||||
dataflow_reachable_cublaslt_raw_fp8_count(egraph),
|
||||
0,
|
||||
"raw FP8 cuBLASLt must be deleted when a scaled equivalent covers the fused MLP consumer"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA FP8 rewrite sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn cublaslt_fp8_scaled_candidate_executes_batched_matmul_f32_output() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
if !gpu_supports_cublaslt_fp8_launch(DType::F8E4M3) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (batch, m, n, k) = (2, 16, 16, 16);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((batch, m, k));
|
||||
let a_scale = cx.tensor(());
|
||||
let b_scale = cx.tensor(());
|
||||
let b_input = cx.tensor((batch, n, k)).as_dtype(DType::F8E4M3);
|
||||
let b = b_input.transpose(1, 2);
|
||||
let scaled_a = (a / a_scale.expand_rhs((batch, m, k))).cast(DType::F8E4M3);
|
||||
let lhs = scaled_a.expand_dim(2, n);
|
||||
let rhs = b.permute((0, 2, 1)).expand_dim(1, m);
|
||||
let mul = unchecked_mul_same_shape(lhs, rhs, DType::F8E4M3);
|
||||
let matmul = mul.sum(3).cast(DType::F32);
|
||||
let out = (matmul * (a_scale * b_scale).expand_rhs((batch, m, n))).output();
|
||||
let expected_tuple = (
|
||||
DType::F8E4M3,
|
||||
DType::F8E4M3,
|
||||
DType::F32,
|
||||
DType::F32,
|
||||
"32F",
|
||||
DType::F32,
|
||||
);
|
||||
let llir =
|
||||
extract_forced_cublaslt_llir_where(&mut cx, "functional scaled batched fp8", |llir| {
|
||||
cublaslt_type_tuples(llir).contains(&expected_tuple)
|
||||
&& cublaslt_tensor_scale_input_tuples(llir).contains(&(true, true))
|
||||
&& cublaslt_transpose_op_tuples(llir).contains(&("T", "N"))
|
||||
&& cublaslt_matrix_order_tuples(llir).contains(&("COL", "COL", "COL", "COL"))
|
||||
});
|
||||
|
||||
let input_scale = 0.5f32;
|
||||
let weight_scale = 1.5f32;
|
||||
let (a_fp8_bytes, a_values) = fp8_exact_bytes(DType::F8E4M3, batch * m * k, 11);
|
||||
let a_data = a_values
|
||||
.iter()
|
||||
.map(|value| value * input_scale)
|
||||
.collect::<Vec<_>>();
|
||||
let (b_bytes, b_storage_values) = fp8_exact_bytes(DType::F8E4M3, batch * k * n, 13);
|
||||
let b_values = logical_b_from_batched_column_major_storage(&b_storage_values, batch, n, k);
|
||||
let mut expected = reference_matmul_batched(&a_values, &b_values, batch, m, n, k);
|
||||
for value in &mut expected {
|
||||
*value *= input_scale * weight_scale;
|
||||
}
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(a, a_data);
|
||||
rt.set_data(a_scale, vec![input_scale]);
|
||||
rt.set_data(b_scale, vec![weight_scale]);
|
||||
rt.set_data(b_input, b_bytes);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_eq!(a_fp8_bytes.len(), batch * m * k);
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
fn cublaslt_fp8_candidate_executes_2d_matmul_f32_output(a_dtype: DType, b_dtype: DType) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
@@ -2360,85 +2168,6 @@ fn cublaslt_row_order_candidate_executes_2d_layout_pairs() {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "large row-order CUDA functional repro for llama lm_head shape"]
|
||||
fn cublaslt_row_order_candidate_executes_large_lm_head_like_projection() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (m, n, k) = (1, 128_256, 64);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((m, k));
|
||||
let b_input = cx.tensor((n, k));
|
||||
let b = b_input.t();
|
||||
let out = a.matmul(b).output();
|
||||
let expected_orders = ("ROW", "COL", "ROW", "ROW");
|
||||
let llir = extract_forced_cublaslt_llir_where(&mut cx, "lm_head-like row-order", |llir| {
|
||||
cublaslt_matrix_order_tuples(llir).contains(&expected_orders)
|
||||
&& cublaslt_scale_value_tuples(llir).contains(&(1.0, 0.0))
|
||||
});
|
||||
|
||||
let a_data = random_f32_vec(m * k, 0x1A11_A000, -0.5, 0.5);
|
||||
let b_data = random_f32_vec(n * k, 0x1A11_B000, -0.5, 0.5);
|
||||
let mut expected = vec![0.0f32; m * n];
|
||||
for col in 0..n {
|
||||
let mut sum = 0.0f32;
|
||||
for kk in 0..k {
|
||||
sum += a_data[kk] * b_data[col * k + kk];
|
||||
}
|
||||
expected[col] = sum;
|
||||
}
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(a, a_data);
|
||||
rt.set_data(b_input, b_data);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "large row-order CUDA functional repro for llama MLP residual beta=1 shape"]
|
||||
fn cublaslt_row_order_beta_one_candidate_executes_llama_mlp_residual_like_projection() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (m, n, k) = (1, 4096, 64);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((m, k));
|
||||
let b_input = cx.tensor((n, k));
|
||||
let b = b_input.t();
|
||||
let c = cx.tensor((m, n));
|
||||
let out = (a.matmul(b) + c).output();
|
||||
let expected_orders = ("ROW", "COL", "ROW", "ROW");
|
||||
let llir = extract_forced_cublaslt_llir_where(&mut cx, "mlp residual row-order", |llir| {
|
||||
cublaslt_matrix_order_tuples(llir).contains(&expected_orders)
|
||||
&& cublaslt_scale_value_tuples(llir).contains(&(1.0, 1.0))
|
||||
});
|
||||
|
||||
let a_data = random_f32_vec(m * k, 0x1A12_A000, -0.5, 0.5);
|
||||
let b_data = random_f32_vec(n * k, 0x1A12_B000, -0.5, 0.5);
|
||||
let c_data = random_f32_vec(m * n, 0x1A12_C000, -0.5, 0.5);
|
||||
let mut expected = c_data.clone();
|
||||
for col in 0..n {
|
||||
for kk in 0..k {
|
||||
expected[col] += a_data[kk] * b_data[col * k + kk];
|
||||
}
|
||||
}
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(a, a_data);
|
||||
rt.set_data(b_input, b_data);
|
||||
rt.set_data(c, c_data);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "expensive CUDA functional candidate sweep; run with cargo test -p luminal_cuda_lite -- --ignored"]
|
||||
fn cublaslt_row_order_candidate_executes_batched_row_major_matmul() {
|
||||
@@ -3036,7 +2765,7 @@ fn cublaslt_ir_nodes(egraph: &SerializedEGraph) -> Vec<&NodeId> {
|
||||
let cublaslt_kind_classes = egraph
|
||||
.enodes
|
||||
.iter()
|
||||
.filter(|(_, (label, _))| label == "cublaslt" || label == "cublaslt_scaled")
|
||||
.filter(|(_, (label, _))| label == "cublaslt")
|
||||
.map(|(node, _)| egraph.node_to_class[node].clone())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
@@ -3053,87 +2782,6 @@ fn cublaslt_ir_nodes(egraph: &SerializedEGraph) -> Vec<&NodeId> {
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn dataflow_reachable_cublaslt_scaled_count(egraph: &SerializedEGraph) -> usize {
|
||||
dataflow_reachable_cublaslt_count(egraph, true)
|
||||
}
|
||||
|
||||
fn dataflow_reachable_cublaslt_raw_fp8_count(egraph: &SerializedEGraph) -> usize {
|
||||
dataflow_reachable_cublaslt_count(egraph, false)
|
||||
}
|
||||
|
||||
fn dataflow_reachable_cublaslt_count(egraph: &SerializedEGraph, scaled: bool) -> usize {
|
||||
let reachable = dataflow_reachable_ir_classes(egraph);
|
||||
egraph
|
||||
.enodes
|
||||
.iter()
|
||||
.filter(|(node, (label, children))| {
|
||||
label == "Op"
|
||||
&& reachable.contains(&egraph.node_to_class[*node])
|
||||
&& children.first().is_some_and(|kind_class| {
|
||||
egraph
|
||||
.eclasses
|
||||
.get(kind_class)
|
||||
.is_some_and(|(_, kind_nodes)| {
|
||||
kind_nodes.iter().any(|kind_node| {
|
||||
egraph.enodes.get(kind_node).is_some_and(|(kind_label, _)| {
|
||||
if scaled {
|
||||
kind_label == "cublaslt_scaled"
|
||||
} else {
|
||||
kind_label == "cublaslt"
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
.count()
|
||||
}
|
||||
|
||||
fn dataflow_reachable_ir_classes(egraph: &SerializedEGraph) -> FxHashSet<ClassId> {
|
||||
let mut reachable = FxHashSet::default();
|
||||
let mut stack = egraph.roots.clone();
|
||||
while let Some(class) = stack.pop() {
|
||||
if !reachable.insert(class.clone()) {
|
||||
continue;
|
||||
}
|
||||
let Some((sort, nodes)) = egraph.eclasses.get(&class) else {
|
||||
continue;
|
||||
};
|
||||
for node in nodes {
|
||||
let Some((label, children)) = egraph.enodes.get(node) else {
|
||||
continue;
|
||||
};
|
||||
match (sort.as_str(), label.as_str()) {
|
||||
("IR", "Output") => {
|
||||
if let Some(child) = children.first() {
|
||||
stack.push(child.clone());
|
||||
}
|
||||
}
|
||||
("IR", "OutputJoin") => stack.extend(children.iter().cloned()),
|
||||
("IR", "Op") => {
|
||||
if let Some(inputs) = children.get(1) {
|
||||
stack.push(inputs.clone());
|
||||
}
|
||||
}
|
||||
("IR", _) => {
|
||||
for child in children {
|
||||
if egraph
|
||||
.eclasses
|
||||
.get(child)
|
||||
.is_some_and(|(child_sort, _)| child_sort == "IR")
|
||||
{
|
||||
stack.push(child.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
("IList", "ICons") => stack.extend(children.iter().cloned()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
reachable
|
||||
}
|
||||
|
||||
fn llir_has_cublaslt(llir: &LLIRGraph) -> bool {
|
||||
!cublaslt_type_tuples(llir).is_empty()
|
||||
}
|
||||
@@ -3152,13 +2800,6 @@ fn cublaslt_scale_value_tuples(llir: &LLIRGraph) -> Vec<CublasLtScaleValues> {
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn cublaslt_tensor_scale_input_tuples(llir: &LLIRGraph) -> Vec<(bool, bool)> {
|
||||
llir.node_weights()
|
||||
.filter_map(|op| op.to_dialect::<dyn HostOp>())
|
||||
.filter_map(|host_op| cublaslt_tensor_scale_inputs(host_op.as_ref().as_ref()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn cublaslt_epilogues(llir: &LLIRGraph) -> Vec<&'static str> {
|
||||
llir.node_weights()
|
||||
.filter_map(|op| op.to_dialect::<dyn HostOp>())
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
//! 1. Pure egglog metadata (no GPU): trait wiring, sort + rewrite parse cleanly.
|
||||
//! 2. Egglog rule firing (no GPU): the rule unifies on a real paged-attention
|
||||
//! HLIR and does NOT fire on bare attention or unrelated matmul/Gather mixes.
|
||||
//! 3. Mask helper correctness (GPU): the primitive-op `test_compute_attn_mask` builder produces the right (s, c) mask.
|
||||
//! 3. Mask op correctness (GPU): `ComputeAttnMask` produces the right (s, c) mask.
|
||||
//! 4. Full kernel correctness (GPU + JIT): direct `FlashInferAttention::execute`
|
||||
//! compared against a luminal-compiled reference attention graph.
|
||||
//!
|
||||
@@ -18,7 +18,7 @@ use luminal::op::{EgglogOp, IntoEgglogOp};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::host::flashinfer::FlashInferAttention;
|
||||
use crate::host::{DeviceBuffer, HostOp};
|
||||
use crate::host::{ComputeAttnMask, DeviceBuffer, HostOp};
|
||||
use crate::runtime::CudaRuntime;
|
||||
use crate::tests::utilities::get_cuda_stream;
|
||||
|
||||
@@ -285,6 +285,106 @@ fn flashinfer_op_sort_shape() {
|
||||
assert!(dbg.contains("FlashInferAttention"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compute_attn_mask_registers() {
|
||||
assert!(
|
||||
ops_contains_sort("ComputeAttnMask"),
|
||||
"ComputeAttnMask is not in CudaRuntime::Ops"
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Layer 2: ComputeAttnMask correctness ────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn compute_attn_mask_matches_cpu_reference() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
// 2 sequences, seq0 length=3, seq1 length=2 → s=2 queries (one per seq, decode),
|
||||
// c=5 total context tokens (3+2).
|
||||
let s_dim = 2usize;
|
||||
let c_dim = 5usize;
|
||||
let q_pos: Vec<i32> = vec![2, 1]; // last position in each seq
|
||||
let qo_indptr: Vec<i32> = vec![0, 1, 2];
|
||||
let kv_indptr: Vec<i32> = vec![0, 3, 5];
|
||||
let r = kv_indptr.len();
|
||||
|
||||
let q_pos_buf = stream
|
||||
.clone_htod(unsafe {
|
||||
std::slice::from_raw_parts(q_pos.as_ptr() as *const u8, q_pos.len() * 4)
|
||||
})
|
||||
.unwrap();
|
||||
let qo_buf = stream
|
||||
.clone_htod(unsafe {
|
||||
std::slice::from_raw_parts(qo_indptr.as_ptr() as *const u8, qo_indptr.len() * 4)
|
||||
})
|
||||
.unwrap();
|
||||
let kv_buf = stream
|
||||
.clone_htod(unsafe {
|
||||
std::slice::from_raw_parts(kv_indptr.as_ptr() as *const u8, kv_indptr.len() * 4)
|
||||
})
|
||||
.unwrap();
|
||||
let out_bytes = s_dim * c_dim * 4;
|
||||
let out_buf = unsafe { stream.alloc::<u8>(out_bytes).unwrap() };
|
||||
|
||||
let op = ComputeAttnMask {
|
||||
s_dim: Expression::from(s_dim),
|
||||
c_dim: Expression::from(c_dim),
|
||||
};
|
||||
|
||||
let q_pos_n = NodeIndex::new(0);
|
||||
let qo_n = NodeIndex::new(1);
|
||||
let kv_n = NodeIndex::new(2);
|
||||
let out_n = NodeIndex::new(3);
|
||||
|
||||
let mut buffers = FxHashMap::default();
|
||||
buffers.insert(
|
||||
q_pos_n,
|
||||
DeviceBuffer::new(q_pos_buf.device_ptr(&stream).0, q_pos.len() * 4),
|
||||
);
|
||||
buffers.insert(
|
||||
qo_n,
|
||||
DeviceBuffer::new(qo_buf.device_ptr(&stream).0, qo_indptr.len() * 4),
|
||||
);
|
||||
buffers.insert(
|
||||
kv_n,
|
||||
DeviceBuffer::new(kv_buf.device_ptr(&stream).0, kv_indptr.len() * 4),
|
||||
);
|
||||
buffers.insert(
|
||||
out_n,
|
||||
DeviceBuffer::new(out_buf.device_ptr(&stream).0, out_bytes),
|
||||
);
|
||||
|
||||
let inputs = [q_pos_n, qo_n, kv_n];
|
||||
let mut dyn_map = FxHashMap::default();
|
||||
dyn_map.insert('r', r);
|
||||
|
||||
op.execute(&stream, out_n, &inputs, &buffers, &dyn_map)
|
||||
.unwrap();
|
||||
stream.synchronize().unwrap();
|
||||
|
||||
let host_bytes = stream.clone_dtoh(&out_buf).unwrap();
|
||||
let mask: Vec<f32> = unsafe {
|
||||
let mut bytes = std::mem::ManuallyDrop::new(host_bytes);
|
||||
let len = bytes.len() / 4;
|
||||
Vec::from_raw_parts(bytes.as_mut_ptr() as *mut f32, len, len)
|
||||
};
|
||||
|
||||
// Expected: query 0 (q_pos=2, seq 0) attends to ctx [0, 3) i.e. mask[0, 0..3]=0;
|
||||
// query 1 (q_pos=1, seq 1) attends to ctx [3, 5) i.e. mask[1, 3..5]=0.
|
||||
// Everywhere else is -1e10.
|
||||
let mut expected = vec![-1e10f32; s_dim * c_dim];
|
||||
for j in 0..3 {
|
||||
expected[0 * c_dim + j] = 0.0;
|
||||
}
|
||||
for j in 3..5 {
|
||||
expected[1 * c_dim + j] = 0.0;
|
||||
}
|
||||
|
||||
assert_eq!(mask, expected);
|
||||
}
|
||||
|
||||
// ─── Layer 3: FlashInfer kernel correctness ──────────────────────────────
|
||||
|
||||
#[test]
|
||||
@@ -427,7 +527,7 @@ fn test_indptr_to_request_idx(
|
||||
n: Expression,
|
||||
) -> GraphTensor {
|
||||
let r = indptr.dims1();
|
||||
let indices = graph.arange(n).expand_dim(1, r);
|
||||
let indices = graph.arange(n.clone()).expand_dim(1, r.clone());
|
||||
let indptr_2d = indptr.expand_dim(0, n);
|
||||
let ge = indptr_2d.le(indices).cast(luminal::dtype::DType::Int);
|
||||
ge.sum(1).cast(luminal::dtype::DType::Int) - 1
|
||||
@@ -441,13 +541,13 @@ fn test_compute_attn_mask(
|
||||
c: Expression,
|
||||
) -> GraphTensor {
|
||||
let s = q_pos.dims1();
|
||||
let q_request = test_indptr_to_request_idx(graph, qo_indptr, s);
|
||||
let c_request = test_indptr_to_request_idx(graph, kv_indptr, c);
|
||||
let c_arange = graph.arange(c);
|
||||
let q_request = test_indptr_to_request_idx(graph, qo_indptr, s.clone());
|
||||
let c_request = test_indptr_to_request_idx(graph, kv_indptr, c.clone());
|
||||
let c_arange = graph.arange(c.clone());
|
||||
let c_kv_start = kv_indptr.gather(c_request);
|
||||
let c_local_pos = c_arange - c_kv_start;
|
||||
let q_req_2d = q_request.expand_dim(1, c);
|
||||
let c_req_2d = c_request.expand_dim(0, s);
|
||||
let q_req_2d = q_request.expand_dim(1, c.clone());
|
||||
let c_req_2d = c_request.expand_dim(0, s.clone());
|
||||
let same = q_req_2d.eq(c_req_2d);
|
||||
let c_pos_2d = c_local_pos.expand_dim(0, s);
|
||||
let qp_2d = q_pos.expand_dim(1, c);
|
||||
@@ -477,7 +577,6 @@ fn scatter_rows(
|
||||
|
||||
/// Handles to every named input of the paged-attention test graph, returned
|
||||
/// alongside the graph so the GA-selection test can `set_data` on each one.
|
||||
#[allow(dead_code)]
|
||||
struct PagedAttnHandles {
|
||||
q_rope: GraphTensor,
|
||||
k_rope: GraphTensor,
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
use as_any::Downcast;
|
||||
use luminal::egglog_utils::{egglog_to_llir, random_initial_choice};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::kernel::KernelOp;
|
||||
use crate::kernel::fusion::{CudaBinaryElementwise, CudaUnaryElementwise};
|
||||
use crate::runtime::CudaRuntime;
|
||||
use crate::tests::utilities::{
|
||||
TOLERANCE_SAFETY_FACTOR, dtype_epsilon, random_f32_vec, test_binary_cuda, test_unary_cuda,
|
||||
@@ -88,7 +86,7 @@ fn test_unary_fusion_preserves_output() {
|
||||
#[test]
|
||||
fn test_three_unary_ops_fuse() {
|
||||
// A chain of 3 pure-elementwise unaries with matching strides should be
|
||||
// reachable as a single marker region containing all three elementwise ops.
|
||||
// reachable as a single marker region containing all three FusedX ops.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(16);
|
||||
let _b = a.sin().sqrt().exp2().output();
|
||||
@@ -106,7 +104,7 @@ fn test_three_unary_ops_fuse() {
|
||||
#[test]
|
||||
fn test_four_unary_ops_fuse() {
|
||||
// 4-op chain should collapse into a single marker region containing all
|
||||
// four elementwise ops (one pair-fuse + repeated grow-FE→U firings).
|
||||
// four FusedX ops (one pair-fuse + repeated grow-FE→U firings).
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(16);
|
||||
let _b = a.sin().sqrt().exp2().log2().output();
|
||||
@@ -319,15 +317,8 @@ fn extract_all_fused_regions(cx: &mut Graph) -> Vec<FusedRegion> {
|
||||
|
||||
let name_of = |idx: NodeIndex| -> Option<String> {
|
||||
llir.node_weight(idx).and_then(|op| {
|
||||
op.to_dialect::<dyn KernelOp>().map(|k| {
|
||||
if let Some(elem) = (***k).downcast_ref::<CudaUnaryElementwise>() {
|
||||
format!("Fused{}", elem.op)
|
||||
} else if let Some(elem) = (***k).downcast_ref::<CudaBinaryElementwise>() {
|
||||
format!("Fused{}", elem.op)
|
||||
} else {
|
||||
k.kernel_name().to_string()
|
||||
}
|
||||
})
|
||||
op.to_dialect::<dyn KernelOp>()
|
||||
.map(|k| k.kernel_name().to_string())
|
||||
})
|
||||
};
|
||||
|
||||
@@ -352,13 +343,12 @@ fn extract_all_fused_regions(cx: &mut Graph) -> Vec<FusedRegion> {
|
||||
|
||||
// Resolve chains of nested FusionStart wrappers (cascade artifact)
|
||||
// to the real external source. A FusionStart whose incoming neighbor
|
||||
// is itself a FusionStart is a cascade layer, not a new external
|
||||
// tensor. A FusionEnd predecessor is a real external region output
|
||||
// in the generic singleton-region model, so do not walk through it.
|
||||
// is itself a FusionStart — or a FusionEnd whose region is fully
|
||||
// inside ours — is a cascade layer, not a new external tensor.
|
||||
let resolve_source = |mut n: NodeIndex| -> NodeIndex {
|
||||
loop {
|
||||
match name_of(n).as_deref() {
|
||||
Some("FusionStart") => {
|
||||
Some("FusionStart") | Some("FusionEnd") => {
|
||||
let mut inc = llir.neighbors_directed(n, petgraph::Direction::Incoming);
|
||||
match inc.next() {
|
||||
Some(p) => n = p,
|
||||
@@ -389,6 +379,15 @@ fn extract_all_fused_regions(cx: &mut Graph) -> Vec<FusedRegion> {
|
||||
let mut inc =
|
||||
llir.neighbors_directed(pred, petgraph::Direction::Incoming);
|
||||
match inc.next() {
|
||||
Some(src_node)
|
||||
if name_of(src_node).as_deref() == Some("FusionEnd") =>
|
||||
{
|
||||
// Merge adjacent regions — treat the FS/FE
|
||||
// pair as internal; walk past the upstream
|
||||
// FE into its region.
|
||||
visited.insert(src_node);
|
||||
stack.push(src_node);
|
||||
}
|
||||
Some(src_node) => {
|
||||
start_sources.insert(resolve_source(src_node));
|
||||
}
|
||||
@@ -468,15 +467,6 @@ fn test_single_binary_does_not_fuse_alone() {
|
||||
fn test_chain_of_binaries_fuses() {
|
||||
// `(a + b) * c`: three external inputs collapse into one region with
|
||||
// internal [Add, Mul] and 3 FusionStarts.
|
||||
//
|
||||
// Requires BB family, which is opt-in at runtime via
|
||||
// LUMINAL_FUSION_FAMILIES. Set it before the graph build so the rules
|
||||
// emitted from FusionEnd::rewrites include the B-B pair-fuse rules.
|
||||
// SAFETY: tests run in parallel; we set this before constructing the
|
||||
// Graph, and never unset, so concurrent tests just see BB on.
|
||||
unsafe {
|
||||
std::env::set_var("LUMINAL_FUSION_FAMILIES", "uu,bu,ub,bb");
|
||||
}
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
@@ -530,13 +520,6 @@ fn test_unary_then_binary_fuses() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
// Subsume in grow rules (introduced to bound the BB partial-FE explosion)
|
||||
// means a multi-consumer producer can no longer be fused into the same
|
||||
// region as all its consumers — only one branch wins. The diamond's `t`
|
||||
// has two consumers, so the structural "one 5-op region" outcome is no
|
||||
// longer guaranteed. Numerical correctness still holds (see
|
||||
// test_diamond_dag_preserves_output).
|
||||
#[ignore = "asserts pre-subsume ideal multi-consumer fusion shape"]
|
||||
fn test_diamond_dag_fuses() {
|
||||
// The canonical diamond-DAG example agreed with the user:
|
||||
// t = a + b; u = exp2(t); v = sin(t); w = u * a; out = w + v
|
||||
@@ -667,7 +650,6 @@ fn test_diamond_dag_preserves_output() {
|
||||
// ---- Marker invariant tests ----
|
||||
|
||||
#[test]
|
||||
#[ignore = "asserts pre-subsume ideal multi-consumer fusion shape"]
|
||||
fn test_fused_region_has_exactly_one_end() {
|
||||
// Design invariant: a fused region always has exactly one FusionEnd.
|
||||
// Uses the diamond DAG so there's real fan-in/out inside the region.
|
||||
@@ -695,7 +677,6 @@ fn test_fused_region_has_exactly_one_end() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "asserts pre-subsume ideal multi-consumer fusion shape"]
|
||||
fn test_fused_region_starts_match_distinct_external_tensors() {
|
||||
// Design invariant: FusionStart count == number of distinct external input
|
||||
// tensors, NOT number of edges crossing the boundary. In the diamond DAG
|
||||
@@ -787,10 +768,6 @@ fn test_pair_fuse_binary_to_binary_rhs() {
|
||||
// Pair-fuse B→B (RHS variant): `c * (a + b)`. The inner binary feeds the
|
||||
// outer binary's B input, exercising the mirror direction of the rule
|
||||
// covered by test_chain_of_binaries_fuses.
|
||||
// See test_chain_of_binaries_fuses for the LUMINAL_FUSION_FAMILIES note.
|
||||
unsafe {
|
||||
std::env::set_var("LUMINAL_FUSION_FAMILIES", "uu,bu,ub,bb");
|
||||
}
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let b = cx.tensor(8);
|
||||
@@ -832,7 +809,6 @@ fn test_grow_fe_to_binary_rhs() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "asserts pre-subsume two-FE merge shape; numerical correctness preserved"]
|
||||
fn test_merge_two_regions_at_outer_binary() {
|
||||
// Merge: `(sin(a) + b) + (sqrt(c) + d)`. Each side independently pair-fuses
|
||||
// U→B on its own (the unary gives the inner Add a fusion partner that
|
||||
|
||||
@@ -19,8 +19,4 @@ mod performance_tests;
|
||||
#[cfg(test)]
|
||||
mod qwen3_moe_rewrite;
|
||||
#[cfg(test)]
|
||||
mod rope_test;
|
||||
#[cfg(test)]
|
||||
mod search_equivalence_fuzz;
|
||||
#[cfg(test)]
|
||||
mod transformer;
|
||||
|
||||
@@ -305,7 +305,7 @@ fn fuzz_layer_no_attn(
|
||||
}
|
||||
|
||||
/// Test a SwiGLU MLP with HLIR-only to specifically verify
|
||||
/// the HLIR matmul decomposition (elementwise Mul + KernelSumReduce).
|
||||
/// the HLIR matmul decomposition (KernelMul + KernelSumReduce).
|
||||
fn fuzz_mlp_hlir_only(seq: usize, hidden: usize, intermediate: usize, seed: u64) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::cuda_bandwidth_gbps;
|
||||
use crate::runtime::CudaRuntime;
|
||||
|
||||
/// Test that measures bandwidth utilization for a large element-wise add kernel.
|
||||
/// This demonstrates that generic fused Add can achieve reasonable bandwidth with large tensors.
|
||||
/// This demonstrates that KernelAdd can achieve reasonable bandwidth with large tensors.
|
||||
#[test]
|
||||
pub fn kernel_add_bandwidth_test() {
|
||||
// 64M elements = 256MB per tensor, 768MB total memory traffic (2 reads + 1 write)
|
||||
@@ -40,7 +40,7 @@ pub fn kernel_add_bandwidth_test() {
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
// Print stats
|
||||
println!("\n=== Large Fused Add Bandwidth Test ===");
|
||||
println!("\n=== Large KernelAdd Bandwidth Test ===");
|
||||
println!(
|
||||
"Tensor size: {} elements ({} MB per tensor)",
|
||||
size,
|
||||
|
||||
@@ -8,10 +8,10 @@ use crate::{
|
||||
};
|
||||
|
||||
const SEQ: usize = 2;
|
||||
const HIDDEN: usize = 32;
|
||||
const HIDDEN: usize = 16;
|
||||
const NUM_EXPERTS: usize = 8;
|
||||
const TOP_K: usize = 2;
|
||||
const MOE_INTERMEDIATE: usize = 12;
|
||||
const MOE_INTERMEDIATE: usize = 6;
|
||||
const RMS_NORM_EPS: f32 = 1e-6;
|
||||
|
||||
struct QwenMoeGraph {
|
||||
@@ -58,7 +58,6 @@ fn build_qwen_moe_graph() -> QwenMoeGraph {
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
let top_k_values = top_k_values / top_k_values.sum(n - 1).expand_dim(n - 1, TOP_K);
|
||||
|
||||
let gate_up_gathered = gather_experts(x, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = x.expand_dim(n - 1, TOP_K).unsqueeze(n);
|
||||
@@ -271,7 +270,7 @@ fn test_glumoe_matches_qwen_swiglu_pattern() {
|
||||
return;
|
||||
}
|
||||
|
||||
assert_eq!(modes, vec![GLUMoEMode::SwiGLUNormalized]);
|
||||
assert_eq!(modes, vec![GLUMoEMode::SwiGLU]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -293,7 +292,7 @@ fn test_glumoe_swiglu_matches_unfused_output() {
|
||||
assert!(baseline_modes.is_empty());
|
||||
|
||||
let (actual, fused_modes) = run_qwen_moe(true);
|
||||
assert_eq!(fused_modes, vec![GLUMoEMode::SwiGLUNormalized]);
|
||||
assert_eq!(fused_modes, vec![GLUMoEMode::SwiGLU]);
|
||||
assert_close(&actual, &expected, 3e-2, 3e-2);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
use cudarc::driver::CudaContext;
|
||||
use luminal::{graph::Graph, op::Runtime};
|
||||
|
||||
use crate::{kernel::apply_rope, runtime::CudaRuntime};
|
||||
|
||||
fn cpu_rope(x: &[f32], cos: &[f32], sin: &[f32], s: usize, h: usize, d: usize) -> Vec<f32> {
|
||||
assert!(d.is_multiple_of(2));
|
||||
let mut out = vec![0.0f32; s * h * d];
|
||||
for si in 0..s {
|
||||
for hi in 0..h {
|
||||
for i in 0..d {
|
||||
let xi = x[si * h * d + hi * d + i];
|
||||
let xpair = if i % 2 == 0 {
|
||||
-x[si * h * d + hi * d + i + 1]
|
||||
} else {
|
||||
x[si * h * d + hi * d + i - 1]
|
||||
};
|
||||
let c = cos[si * d + i];
|
||||
let sn = sin[si * d + i];
|
||||
out[si * h * d + hi * d + i] = xi * c + xpair * sn;
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rope_matches_cpu_reference() {
|
||||
let s = 8;
|
||||
let h = 4;
|
||||
let d = 32;
|
||||
let mut cx = Graph::default();
|
||||
let x = cx.tensor((s, h, d));
|
||||
let cos = cx.tensor((s, d));
|
||||
let sin = cx.tensor((s, d));
|
||||
let y = apply_rope(x, cos, sin).output();
|
||||
|
||||
let x_data: Vec<f32> = (0..s * h * d).map(|i| ((i as f32) * 0.013).sin()).collect();
|
||||
let cos_data: Vec<f32> = (0..s * d).map(|i| ((i as f32) * 0.017).cos()).collect();
|
||||
let sin_data: Vec<f32> = (0..s * d).map(|i| ((i as f32) * 0.017).sin()).collect();
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(x, x_data.clone());
|
||||
rt.set_data(cos, cos_data.clone());
|
||||
rt.set_data(sin, sin_data.clone());
|
||||
rt = cx.search(rt, 1);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(y.id);
|
||||
|
||||
let expected = cpu_rope(&x_data, &cos_data, &sin_data, s, h, d);
|
||||
let mut max_err = 0.0f32;
|
||||
for (g, e) in got.iter().zip(expected.iter()) {
|
||||
let err = (g - e).abs();
|
||||
if err > max_err {
|
||||
max_err = err;
|
||||
}
|
||||
}
|
||||
eprintln!("rope: max abs err: {max_err}");
|
||||
assert!(max_err < 1e-5, "max abs error {max_err} too high");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rope_flux2_shape() {
|
||||
// Flux 2 transformer attention: S=1536 (img+txt), H=48, D=128.
|
||||
let s = 1536;
|
||||
let h = 48;
|
||||
let d = 128;
|
||||
let mut cx = Graph::default();
|
||||
let x = cx.tensor((s, h, d));
|
||||
let cos = cx.tensor((s, d));
|
||||
let sin = cx.tensor((s, d));
|
||||
let y = apply_rope(x, cos, sin).output();
|
||||
|
||||
use rand::{Rng, SeedableRng};
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(11);
|
||||
let x_data: Vec<f32> = (0..s * h * d)
|
||||
.map(|_| rng.random_range(-2.0..2.0_f32))
|
||||
.collect();
|
||||
let cos_data: Vec<f32> = (0..s * d)
|
||||
.map(|_| rng.random_range(-1.0..1.0_f32))
|
||||
.collect();
|
||||
let sin_data: Vec<f32> = (0..s * d)
|
||||
.map(|_| rng.random_range(-1.0..1.0_f32))
|
||||
.collect();
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(x, x_data.clone());
|
||||
rt.set_data(cos, cos_data.clone());
|
||||
rt.set_data(sin, sin_data.clone());
|
||||
rt = cx.search(rt, 1);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(y.id);
|
||||
|
||||
let expected = cpu_rope(&x_data, &cos_data, &sin_data, s, h, d);
|
||||
let mut max_err = 0.0f32;
|
||||
for (g, e) in got.iter().zip(expected.iter()) {
|
||||
let err = (g - e).abs();
|
||||
if err > max_err {
|
||||
max_err = err;
|
||||
}
|
||||
}
|
||||
eprintln!("rope flux2: max abs err: {max_err}");
|
||||
assert!(max_err < 1e-4, "max abs error {max_err} too high");
|
||||
}
|
||||
@@ -1,374 +0,0 @@
|
||||
//! End-to-end e-graph search-space equivalence fuzz tests.
|
||||
//!
|
||||
//! These tests do not compare against a hand-written reference. They assert the
|
||||
//! stronger search invariant: every selectable LLIR graph from the same e-graph
|
||||
//! must produce the same outputs for the same runtime inputs.
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[path = "../../../../examples/llama/src/model.rs"]
|
||||
mod llama_model;
|
||||
|
||||
use half::bf16;
|
||||
use luminal::{dtype::DType, prelude::*, shape::Expression};
|
||||
use rand::{Rng, SeedableRng, rngs::StdRng};
|
||||
|
||||
use super::utilities::{CudaSearchEquivalenceFuzzer, get_cuda_stream, random_f32_vec};
|
||||
|
||||
const SEARCH_EQUIV_SAMPLES: usize = 32;
|
||||
|
||||
fn random_bf16_vec(n: usize, seed: u64, low: f32, high: f32) -> Vec<bf16> {
|
||||
random_f32_vec(n, seed, low, high)
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn rms_norm(x: GraphTensor, weight: GraphTensor, eps: f32) -> GraphTensor {
|
||||
let normed = x.std_norm(x.shape.last_axis(), eps);
|
||||
normed * weight.expand_lhs(&x.dims()[..x.dims().len() - 1])
|
||||
}
|
||||
|
||||
#[allow(clippy::excessive_precision)]
|
||||
fn gemma_gelu(x: GraphTensor) -> GraphTensor {
|
||||
let scaled = 1.5957691216 * x * (1. + 0.044715 * x * x);
|
||||
x * scaled.sigmoid()
|
||||
}
|
||||
|
||||
fn gather_experts(
|
||||
graph_source: GraphTensor,
|
||||
top_k_indices: GraphTensor,
|
||||
weights: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let (_, d1, d2) = weights.dims3();
|
||||
let io = d1 * d2;
|
||||
let base = top_k_indices * io;
|
||||
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
|
||||
let n_base = base.dims().len();
|
||||
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
|
||||
let mut exp_within = within;
|
||||
for (axis, dim) in base.dims().iter().enumerate() {
|
||||
exp_within = exp_within.expand_dim(axis, *dim);
|
||||
}
|
||||
let expert_flat_idx = exp_base + exp_within;
|
||||
weights.gather(expert_flat_idx)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llama_architecture_search_space_equivalence_fuzz() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
const SEQ: usize = 2;
|
||||
const CTX: usize = 3;
|
||||
const SLOTS: usize = 4;
|
||||
|
||||
let config = llama_model::LlamaConfig {
|
||||
layers: 2,
|
||||
hidden: 32,
|
||||
intermediate: 64,
|
||||
head_dim: 8,
|
||||
kv_groups: 2,
|
||||
vocab_size: 64,
|
||||
};
|
||||
|
||||
let mut cx = Graph::default();
|
||||
cx.set_dim('s', SEQ);
|
||||
cx.set_dim('c', CTX);
|
||||
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
let q_pos = cx.named_tensor("q_pos", 's').as_dtype(DType::Int);
|
||||
let scatter_idx = cx.named_tensor("scatter_idx", 's').as_dtype(DType::Int);
|
||||
let gather_idx = cx.named_tensor("gather_idx", 'c').as_dtype(DType::Int);
|
||||
let attn_mask = cx.named_tensor("attn_mask", ('s', 'c'));
|
||||
let kv_cache = llama_model::KVCache::new_with_config(&mut cx, SLOTS, config);
|
||||
let llama = llama_model::Llama::init_with_config(&mut cx, config);
|
||||
|
||||
let (logits, cache_outputs) =
|
||||
llama.forward(input, q_pos, scatter_idx, gather_idx, attn_mask, &kv_cache);
|
||||
let logits = logits.output();
|
||||
let mut fuzzer = CudaSearchEquivalenceFuzzer::new(&mut cx, &stream)
|
||||
.seed(0x5EED_1234)
|
||||
.samples(SEARCH_EQUIV_SAMPLES)
|
||||
.generation_size(8)
|
||||
.mutations(3)
|
||||
.build_options(BuildSearchSpaceOptions::new().max_memory_mib(512))
|
||||
.output_f32(logits.id, "logits", 3e-3, 3e-3);
|
||||
for (layer, (k_out, v_out)) in cache_outputs.into_iter().enumerate() {
|
||||
let k_out = k_out.output();
|
||||
let v_out = v_out.output();
|
||||
fuzzer = fuzzer.output_f32(k_out.id, format!("layer{layer}.k_cache"), 3e-3, 3e-3);
|
||||
fuzzer = fuzzer.output_f32(v_out.id, format!("layer{layer}.v_cache"), 3e-3, 3e-3);
|
||||
}
|
||||
|
||||
let mut rng = StdRng::seed_from_u64(0x11A_AA55);
|
||||
fuzzer = fuzzer
|
||||
.input_i32(input.id, vec![3, 17])
|
||||
.input_i32(q_pos.id, vec![1, 2])
|
||||
.input_i32(scatter_idx.id, vec![1, 2])
|
||||
.input_i32(gather_idx.id, vec![0, 1, 2])
|
||||
.input_f32(attn_mask.id, vec![0.0, 0.0, -1e4, 0.0, 0.0, 0.0]);
|
||||
|
||||
let kv_dim = config.kv_dim();
|
||||
for tensor in kv_cache.tensors() {
|
||||
fuzzer = fuzzer.input_f32(tensor.id, vec![0.0; SLOTS * kv_dim]);
|
||||
}
|
||||
for tensor in llama.parameter_tensors() {
|
||||
let elements = tensor
|
||||
.dims()
|
||||
.iter()
|
||||
.map(|dim| dim.to_usize().expect("tiny llama test uses static params"))
|
||||
.product::<usize>();
|
||||
let data = (0..elements)
|
||||
.map(|_| rng.random_range(-0.08f32..0.08f32))
|
||||
.collect::<Vec<_>>();
|
||||
fuzzer = fuzzer.input_f32(tensor.id, data);
|
||||
}
|
||||
|
||||
let report = fuzzer.run();
|
||||
eprintln!("llama search equivalence fuzz report: {report:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gemma_architecture_search_space_equivalence_fuzz() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
const SEQ: usize = 2;
|
||||
const HIDDEN: usize = 32;
|
||||
const Q_DIM: usize = 24;
|
||||
const INTERMEDIATE: usize = 64;
|
||||
const EPS: f32 = 1e-6;
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((SEQ, HIDDEN));
|
||||
let attn_norm_w = cx.tensor(HIDDEN);
|
||||
let post_attn_norm_w = cx.tensor(HIDDEN);
|
||||
let pre_ff_norm_w = cx.tensor(HIDDEN);
|
||||
let post_ff_norm_w = cx.tensor(HIDDEN);
|
||||
let proj_w = cx.tensor((Q_DIM, HIDDEN));
|
||||
let o_proj_w = cx.tensor((HIDDEN, Q_DIM));
|
||||
let w_gate = cx.tensor((INTERMEDIATE, HIDDEN));
|
||||
let w_up = cx.tensor((INTERMEDIATE, HIDDEN));
|
||||
let w_down = cx.tensor((HIDDEN, INTERMEDIATE));
|
||||
|
||||
let normed = rms_norm(input, attn_norm_w, EPS);
|
||||
let proj_out = normed.matmul(proj_w.t()).matmul(o_proj_w.t());
|
||||
let attn_normed = rms_norm(proj_out, post_attn_norm_w, EPS);
|
||||
let x = input + attn_normed;
|
||||
let ff_normed = rms_norm(x, pre_ff_norm_w, EPS);
|
||||
let mlp_out =
|
||||
(gemma_gelu(ff_normed.matmul(w_gate.t())) * ff_normed.matmul(w_up.t())).matmul(w_down.t());
|
||||
let mlp_normed = rms_norm(mlp_out, post_ff_norm_w, EPS);
|
||||
let out = (x + mlp_normed).output();
|
||||
|
||||
let report = CudaSearchEquivalenceFuzzer::new(&mut cx, &stream)
|
||||
.seed(0x6E4D_4DAA)
|
||||
.samples(SEARCH_EQUIV_SAMPLES)
|
||||
.generation_size(8)
|
||||
.mutations(3)
|
||||
.build_options(BuildSearchSpaceOptions::new().max_memory_mib(512))
|
||||
.input_f32(input.id, random_f32_vec(SEQ * HIDDEN, 101, -0.15, 0.15))
|
||||
.input_f32(attn_norm_w.id, random_f32_vec(HIDDEN, 102, 0.7, 1.3))
|
||||
.input_f32(post_attn_norm_w.id, random_f32_vec(HIDDEN, 103, 0.7, 1.3))
|
||||
.input_f32(pre_ff_norm_w.id, random_f32_vec(HIDDEN, 104, 0.7, 1.3))
|
||||
.input_f32(post_ff_norm_w.id, random_f32_vec(HIDDEN, 105, 0.7, 1.3))
|
||||
.input_f32(proj_w.id, random_f32_vec(Q_DIM * HIDDEN, 106, -0.08, 0.08))
|
||||
.input_f32(
|
||||
o_proj_w.id,
|
||||
random_f32_vec(HIDDEN * Q_DIM, 107, -0.08, 0.08),
|
||||
)
|
||||
.input_f32(
|
||||
w_gate.id,
|
||||
random_f32_vec(INTERMEDIATE * HIDDEN, 108, -0.08, 0.08),
|
||||
)
|
||||
.input_f32(
|
||||
w_up.id,
|
||||
random_f32_vec(INTERMEDIATE * HIDDEN, 109, -0.08, 0.08),
|
||||
)
|
||||
.input_f32(
|
||||
w_down.id,
|
||||
random_f32_vec(HIDDEN * INTERMEDIATE, 110, -0.08, 0.08),
|
||||
)
|
||||
.output_f32(out.id, "gemma_block", 5e-3, 5e-3)
|
||||
.run();
|
||||
eprintln!("gemma search equivalence fuzz report: {report:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn moe_architecture_search_space_equivalence_fuzz() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
const SEQ: usize = 2;
|
||||
const HIDDEN: usize = 16;
|
||||
const NUM_EXPERTS: usize = 8;
|
||||
const TOP_K: usize = 2;
|
||||
const MOE_INTERMEDIATE: usize = 6;
|
||||
const EPS: f32 = 1e-6;
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let router_input = cx.tensor(('s', HIDDEN));
|
||||
let expert_input = cx.tensor(('s', HIDDEN));
|
||||
let router_scale = cx.tensor(HIDDEN);
|
||||
let router_proj = cx.tensor((NUM_EXPERTS, HIDDEN));
|
||||
let per_expert_scale = cx.tensor(NUM_EXPERTS);
|
||||
let gate_up_weights = cx
|
||||
.tensor((NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = router_input.dims().len();
|
||||
let e_dim = *router_proj.dims().first().unwrap();
|
||||
let k_expr = Expression::from(TOP_K);
|
||||
|
||||
let router_hidden = router_input.std_norm(n - 1, EPS)
|
||||
* router_scale.expand_lhs(&router_input.dims()[..n - 1])
|
||||
* (HIDDEN as f32).sqrt().recip();
|
||||
let routing_weights = router_hidden.matmul(router_proj.t()).softmax(n - 1);
|
||||
|
||||
let top_k_indices = routing_weights.topk_indexes(TOP_K, n - 1);
|
||||
let row_offsets = router_input
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
let top_k_norm = top_k_values.sum(n - 1).expand_dim(n - 1, TOP_K);
|
||||
let top_k_weights = (top_k_values / top_k_norm) * per_expert_scale.gather(top_k_indices);
|
||||
|
||||
let gate_up_gathered =
|
||||
gather_experts(expert_input, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = expert_input.expand_dim(n - 1, TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gemma_gelu(gate) * up;
|
||||
|
||||
let down_gathered = gather_experts(expert_input, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let out = (down_out * weights_exp).sum(n - 1).output();
|
||||
cx.set_dim('s', SEQ);
|
||||
|
||||
let report = CudaSearchEquivalenceFuzzer::new(&mut cx, &stream)
|
||||
.seed(0x0DEE_55EE)
|
||||
.samples(SEARCH_EQUIV_SAMPLES)
|
||||
.generation_size(8)
|
||||
.mutations(3)
|
||||
.build_options(BuildSearchSpaceOptions::new().max_memory_mib(512))
|
||||
.input_f32(
|
||||
router_input.id,
|
||||
random_f32_vec(SEQ * HIDDEN, 201, -0.15, 0.15),
|
||||
)
|
||||
.input_f32(
|
||||
expert_input.id,
|
||||
random_f32_vec(SEQ * HIDDEN, 202, -0.15, 0.15),
|
||||
)
|
||||
.input_f32(router_scale.id, random_f32_vec(HIDDEN, 203, 0.7, 1.3))
|
||||
.input_f32(
|
||||
router_proj.id,
|
||||
random_f32_vec(NUM_EXPERTS * HIDDEN, 204, -0.2, 0.2),
|
||||
)
|
||||
.input_f32(
|
||||
per_expert_scale.id,
|
||||
random_f32_vec(NUM_EXPERTS, 205, 0.5, 1.5),
|
||||
)
|
||||
.input_bf16(
|
||||
gate_up_weights.id,
|
||||
random_bf16_vec(NUM_EXPERTS * MOE_INTERMEDIATE * 2 * HIDDEN, 206, -0.1, 0.1),
|
||||
)
|
||||
.input_bf16(
|
||||
down_weights.id,
|
||||
random_bf16_vec(NUM_EXPERTS * HIDDEN * MOE_INTERMEDIATE, 207, -0.1, 0.1),
|
||||
)
|
||||
.output_f32(out.id, "gemma_moe_block", 5e-2, 5e-2)
|
||||
.run();
|
||||
eprintln!("moe search equivalence fuzz report: {report:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn moe_architecture_native_reference_fuzz() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
const SEQ: usize = 2;
|
||||
const HIDDEN: usize = 16;
|
||||
const NUM_EXPERTS: usize = 8;
|
||||
const TOP_K: usize = 2;
|
||||
const MOE_INTERMEDIATE: usize = 6;
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor(('s', HIDDEN));
|
||||
let router = cx.tensor((NUM_EXPERTS, HIDDEN));
|
||||
let gate_up_weights = cx
|
||||
.tensor((NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = input.dims().len();
|
||||
let e_dim = *router.dims().first().unwrap();
|
||||
let k_expr = Expression::from(TOP_K);
|
||||
|
||||
let routing_weights = input.matmul(router.t()).softmax(n - 1);
|
||||
let top_k_indices = routing_weights.topk_indexes(TOP_K, n - 1);
|
||||
let row_offsets = input
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
let top_k_weights = top_k_values / top_k_values.sum(n - 1).expand_dim(n - 1, TOP_K);
|
||||
|
||||
let gate_up_gathered = gather_experts(input, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let input_exp = input.expand_dim(n - 1, TOP_K).unsqueeze(n);
|
||||
let gate_up_out = input_exp
|
||||
.matmul(gate_up_gathered.transpose(2, 3))
|
||||
.squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gate.silu() * up;
|
||||
|
||||
let down_gathered = gather_experts(input, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let out = (down_out * weights_exp).sum(n - 1).output();
|
||||
cx.set_dim('s', SEQ);
|
||||
|
||||
let report = CudaSearchEquivalenceFuzzer::new(&mut cx, &stream)
|
||||
.seed(0x51A7_E5ED)
|
||||
.samples(SEARCH_EQUIV_SAMPLES)
|
||||
.generation_size(8)
|
||||
.mutations(3)
|
||||
.build_options(BuildSearchSpaceOptions::new().max_memory_mib(512))
|
||||
.native_reference()
|
||||
.input_f32(input.id, random_f32_vec(SEQ * HIDDEN, 301, -0.15, 0.15))
|
||||
.input_f32(
|
||||
router.id,
|
||||
random_f32_vec(NUM_EXPERTS * HIDDEN, 302, -0.2, 0.2),
|
||||
)
|
||||
.input_bf16(
|
||||
gate_up_weights.id,
|
||||
random_bf16_vec(NUM_EXPERTS * MOE_INTERMEDIATE * 2 * HIDDEN, 303, -0.1, 0.1),
|
||||
)
|
||||
.input_bf16(
|
||||
down_weights.id,
|
||||
random_bf16_vec(NUM_EXPERTS * HIDDEN * MOE_INTERMEDIATE, 304, -0.1, 0.1),
|
||||
)
|
||||
.output_f32(out.id, "qwen_swiglu_moe_native_reference", 6e-2, 6e-2)
|
||||
.run();
|
||||
eprintln!("moe native-reference fuzz report: {report:?}");
|
||||
}
|
||||
@@ -2,8 +2,7 @@ use candle_core::{Device, Tensor, WithDType};
|
||||
use cudarc::driver::CudaContext;
|
||||
use half::{bf16, f16};
|
||||
use luminal::egglog_utils::{
|
||||
EGraphChoiceSet, egglog_to_llir, extract_generation, hash_choice_set, random_initial_choice,
|
||||
validate_choice_set,
|
||||
egglog_to_llir, extract_generation, hash_choice_set, random_initial_choice, validate_choice_set,
|
||||
};
|
||||
use luminal::prelude::*;
|
||||
use num_traits::{Num, Signed};
|
||||
@@ -129,399 +128,6 @@ pub fn get_cuda_stream() -> Option<Arc<cudarc::driver::CudaStream>> {
|
||||
Some(ctx.default_stream())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum CudaFuzzInput {
|
||||
F32(NodeIndex, Vec<f32>),
|
||||
Bf16(NodeIndex, Vec<bf16>),
|
||||
I32(NodeIndex, Vec<i32>),
|
||||
}
|
||||
|
||||
impl CudaFuzzInput {
|
||||
fn apply(&self, rt: &mut CudaRuntime) {
|
||||
match self {
|
||||
Self::F32(id, data) => rt.set_data(*id, data.clone()),
|
||||
Self::Bf16(id, data) => rt.set_data(*id, data.clone()),
|
||||
Self::I32(id, data) => rt.set_data(*id, data.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_native(&self, rt: &mut NativeRuntime) {
|
||||
match self {
|
||||
Self::F32(id, data) => rt.set_data(*id, data.clone()),
|
||||
Self::Bf16(id, data) => rt.set_data(*id, data.clone()),
|
||||
Self::I32(id, data) => rt.set_data(*id, data.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct F32OutputCheck {
|
||||
pub id: NodeIndex,
|
||||
pub name: String,
|
||||
pub rtol: f32,
|
||||
pub atol: f32,
|
||||
}
|
||||
|
||||
impl F32OutputCheck {
|
||||
pub fn new(id: NodeIndex, name: impl Into<String>, rtol: f32, atol: f32) -> Self {
|
||||
Self {
|
||||
id,
|
||||
name: name.into(),
|
||||
rtol,
|
||||
atol,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SearchEquivalenceFuzzConfig {
|
||||
pub seed: u64,
|
||||
pub samples: usize,
|
||||
pub generation_size: usize,
|
||||
pub mutations: usize,
|
||||
pub max_attempts: usize,
|
||||
pub build_options: BuildSearchSpaceOptions,
|
||||
pub reference: SearchEquivalenceReference,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum SearchEquivalenceReference {
|
||||
FirstCudaExtraction,
|
||||
NativeRuntime,
|
||||
}
|
||||
|
||||
impl Default for SearchEquivalenceFuzzConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
seed: 0,
|
||||
samples: 32,
|
||||
generation_size: 16,
|
||||
mutations: 2,
|
||||
max_attempts: 1_000,
|
||||
build_options: BuildSearchSpaceOptions::default(),
|
||||
reference: SearchEquivalenceReference::FirstCudaExtraction,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct SearchEquivalenceFuzzReport {
|
||||
pub tested: usize,
|
||||
pub skipped_invalid: usize,
|
||||
}
|
||||
|
||||
pub struct CudaSearchEquivalenceFuzzer<'a> {
|
||||
cx: &'a mut Graph,
|
||||
stream: &'a Arc<cudarc::driver::CudaStream>,
|
||||
inputs: Vec<CudaFuzzInput>,
|
||||
outputs: Vec<F32OutputCheck>,
|
||||
config: SearchEquivalenceFuzzConfig,
|
||||
}
|
||||
|
||||
impl<'a> CudaSearchEquivalenceFuzzer<'a> {
|
||||
pub fn new(cx: &'a mut Graph, stream: &'a Arc<cudarc::driver::CudaStream>) -> Self {
|
||||
Self {
|
||||
cx,
|
||||
stream,
|
||||
inputs: Vec::new(),
|
||||
outputs: Vec::new(),
|
||||
config: SearchEquivalenceFuzzConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn seed(mut self, seed: u64) -> Self {
|
||||
self.config.seed = seed;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn samples(mut self, samples: usize) -> Self {
|
||||
self.config.samples = samples;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn generation_size(mut self, generation_size: usize) -> Self {
|
||||
self.config.generation_size = generation_size;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn mutations(mut self, mutations: usize) -> Self {
|
||||
self.config.mutations = mutations;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build_options(mut self, build_options: BuildSearchSpaceOptions) -> Self {
|
||||
self.config.build_options = build_options;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn native_reference(mut self) -> Self {
|
||||
self.config.reference = SearchEquivalenceReference::NativeRuntime;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn input_f32(mut self, id: NodeIndex, data: Vec<f32>) -> Self {
|
||||
self.inputs.push(CudaFuzzInput::F32(id, data));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn input_bf16(mut self, id: NodeIndex, data: Vec<bf16>) -> Self {
|
||||
self.inputs.push(CudaFuzzInput::Bf16(id, data));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn input_i32(mut self, id: NodeIndex, data: Vec<i32>) -> Self {
|
||||
self.inputs.push(CudaFuzzInput::I32(id, data));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn output_f32(
|
||||
mut self,
|
||||
id: NodeIndex,
|
||||
name: impl Into<String>,
|
||||
rtol: f32,
|
||||
atol: f32,
|
||||
) -> Self {
|
||||
self.outputs.push(F32OutputCheck::new(id, name, rtol, atol));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn run(self) -> SearchEquivalenceFuzzReport {
|
||||
fuzz_cuda_search_space_equivalence(
|
||||
self.cx,
|
||||
self.stream,
|
||||
&self.inputs,
|
||||
&self.outputs,
|
||||
self.config,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// End-to-end search-space equivalence fuzzing for CUDA.
|
||||
///
|
||||
/// This builds the normal CUDA e-graph search space, extracts random selectable
|
||||
/// LLIR graphs, runs each with identical inputs, and verifies every requested
|
||||
/// f32 output matches the first valid extraction. The reference is intentionally
|
||||
/// another selected LLIR graph, not a hand-written CPU implementation: this
|
||||
/// catches cases where supposedly equivalent e-graph choices diverge.
|
||||
pub fn fuzz_cuda_search_space_equivalence(
|
||||
cx: &mut Graph,
|
||||
stream: &Arc<cudarc::driver::CudaStream>,
|
||||
inputs: &[CudaFuzzInput],
|
||||
outputs: &[F32OutputCheck],
|
||||
config: SearchEquivalenceFuzzConfig,
|
||||
) -> SearchEquivalenceFuzzReport {
|
||||
assert!(
|
||||
!outputs.is_empty(),
|
||||
"fuzz harness needs at least one output"
|
||||
);
|
||||
|
||||
let native_reference_outputs = if config.reference == SearchEquivalenceReference::NativeRuntime
|
||||
{
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut native_rng = StdRng::seed_from_u64(config.seed);
|
||||
let mut native_rt = cx.search_options(
|
||||
NativeRuntime::default(),
|
||||
SearchOptions::new(1),
|
||||
&mut native_rng,
|
||||
);
|
||||
for input in inputs {
|
||||
input.apply_native(&mut native_rt);
|
||||
}
|
||||
native_rt.execute(&cx.dyn_map);
|
||||
Some(
|
||||
outputs
|
||||
.iter()
|
||||
.map(|out| native_rt.get_f32(out.id).clone())
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
cx.build_search_space_with_options::<CudaRuntime>(config.build_options);
|
||||
|
||||
let egraph = cx.egraph().expect("search space should be built");
|
||||
let ops = cx.egglog_ops().expect("search ops should be built");
|
||||
let seed = if native_reference_outputs.is_some() {
|
||||
config.seed.wrapping_add(0xC0DA_C0DA)
|
||||
} else {
|
||||
config.seed
|
||||
};
|
||||
let mut rng = StdRng::seed_from_u64(seed);
|
||||
let mut prev_selected = FxHashSet::default();
|
||||
let mut base = random_initial_choice(egraph, &mut rng);
|
||||
prev_selected.insert(hash_choice_set(&base));
|
||||
|
||||
let mut skipped_invalid = 0usize;
|
||||
let reference_is_cuda = native_reference_outputs.is_none();
|
||||
let (reference_hash, reference_outputs, mut tested) =
|
||||
if let Some(reference_outputs) = native_reference_outputs {
|
||||
(0, reference_outputs, 0usize)
|
||||
} else {
|
||||
let mut attempts = 0usize;
|
||||
let (reference_hash, reference_outputs) = loop {
|
||||
attempts += 1;
|
||||
if attempts > config.max_attempts {
|
||||
panic!(
|
||||
"failed to extract a valid reference LLIR after {} attempts",
|
||||
config.max_attempts
|
||||
);
|
||||
}
|
||||
if validate_choice_set(egraph, &base, ops).is_err() {
|
||||
skipped_invalid += 1;
|
||||
} else {
|
||||
let hash = hash_choice_set(&base);
|
||||
match run_choice_outputs(cx, stream, inputs, outputs, &base) {
|
||||
Ok(values) => break (hash, values),
|
||||
Err(err) => {
|
||||
skipped_invalid += 1;
|
||||
eprintln!("skipping invalid reference candidate hash={hash}: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
base = random_initial_choice(egraph, &mut rng);
|
||||
prev_selected.insert(hash_choice_set(&base));
|
||||
};
|
||||
(reference_hash, reference_outputs, 1usize)
|
||||
};
|
||||
|
||||
let mut attempts = 0usize;
|
||||
while tested < config.samples && attempts < config.max_attempts {
|
||||
attempts += 1;
|
||||
let mut candidates = extract_generation(
|
||||
egraph,
|
||||
&base,
|
||||
config.generation_size,
|
||||
config.mutations,
|
||||
&mut prev_selected,
|
||||
&mut rng,
|
||||
);
|
||||
if candidates.is_empty() {
|
||||
let next = random_initial_choice(egraph, &mut rng);
|
||||
prev_selected.insert(hash_choice_set(&next));
|
||||
candidates.push(next);
|
||||
}
|
||||
|
||||
for candidate in candidates {
|
||||
if tested >= config.samples {
|
||||
break;
|
||||
}
|
||||
let candidate_hash = hash_choice_set(&candidate);
|
||||
if reference_is_cuda && candidate_hash == reference_hash {
|
||||
continue;
|
||||
}
|
||||
if validate_choice_set(egraph, &candidate, ops).is_err() {
|
||||
skipped_invalid += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
let candidate_outputs = run_choice_outputs(cx, stream, inputs, outputs, &candidate)
|
||||
.unwrap_or_else(|err| panic!("candidate hash={candidate_hash} failed: {err}"));
|
||||
assert_fuzz_outputs_close(
|
||||
outputs,
|
||||
&reference_outputs,
|
||||
&candidate_outputs,
|
||||
reference_hash,
|
||||
candidate_hash,
|
||||
);
|
||||
base = candidate;
|
||||
tested += 1;
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
tested, config.samples,
|
||||
"only tested {tested}/{} LLIR samples before exhausting attempts",
|
||||
config.samples
|
||||
);
|
||||
SearchEquivalenceFuzzReport {
|
||||
tested,
|
||||
skipped_invalid,
|
||||
}
|
||||
}
|
||||
|
||||
fn run_choice_outputs<'a>(
|
||||
cx: &'a Graph,
|
||||
stream: &Arc<cudarc::driver::CudaStream>,
|
||||
inputs: &[CudaFuzzInput],
|
||||
outputs: &[F32OutputCheck],
|
||||
choices: &EGraphChoiceSet<'a>,
|
||||
) -> Result<Vec<Vec<f32>>, String> {
|
||||
let egraph = cx.egraph().ok_or("search space was not built")?;
|
||||
let ops = cx.egglog_ops().ok_or("search ops were not built")?;
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
let mut llir_graph = egglog_to_llir(
|
||||
egraph,
|
||||
choices.clone(),
|
||||
ops,
|
||||
&cx.custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
unroll_loops_in_llir(&mut llir_graph);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.load_llir(&llir_graph);
|
||||
for input in inputs {
|
||||
input.apply(&mut rt);
|
||||
}
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
Ok(outputs.iter().map(|out| rt.get_f32(out.id)).collect())
|
||||
}
|
||||
|
||||
fn assert_fuzz_outputs_close(
|
||||
outputs: &[F32OutputCheck],
|
||||
expected: &[Vec<f32>],
|
||||
actual: &[Vec<f32>],
|
||||
reference_hash: u64,
|
||||
candidate_hash: u64,
|
||||
) {
|
||||
for ((spec, expected), actual) in outputs.iter().zip(expected.iter()).zip(actual.iter()) {
|
||||
assert_eq!(
|
||||
expected.len(),
|
||||
actual.len(),
|
||||
"output {} length mismatch for candidate hash={candidate_hash} reference hash={reference_hash}",
|
||||
spec.name
|
||||
);
|
||||
let mut max_abs = 0.0f32;
|
||||
let mut max_rel = 0.0f32;
|
||||
let mut worst = 0usize;
|
||||
for (i, (&a, &b)) in actual.iter().zip(expected.iter()).enumerate() {
|
||||
assert!(
|
||||
a.is_finite(),
|
||||
"output {} candidate hash={candidate_hash} produced non-finite value {a} at index {i}",
|
||||
spec.name
|
||||
);
|
||||
assert!(
|
||||
b.is_finite(),
|
||||
"output {} reference hash={reference_hash} produced non-finite value {b} at index {i}",
|
||||
spec.name
|
||||
);
|
||||
let abs = (a - b).abs();
|
||||
let rel = abs / b.abs().max(1e-12);
|
||||
if abs > max_abs {
|
||||
max_abs = abs;
|
||||
max_rel = rel;
|
||||
worst = i;
|
||||
}
|
||||
if abs > spec.atol + spec.rtol * b.abs() {
|
||||
panic!(
|
||||
"output {} mismatch candidate hash={candidate_hash} reference hash={reference_hash} index={i} actual={a} expected={b} abs={abs} rel={rel} tolerance={}",
|
||||
spec.name,
|
||||
spec.atol + spec.rtol * b.abs()
|
||||
);
|
||||
}
|
||||
}
|
||||
eprintln!(
|
||||
"fuzz output {} ok: candidate hash={candidate_hash} max_abs={max_abs} max_rel={max_rel} worst={worst}",
|
||||
spec.name
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the GPU compute capability as (major, minor).
|
||||
pub fn gpu_compute_cap() -> Option<(i32, i32)> {
|
||||
let ctx = CudaContext::new(0).ok()?;
|
||||
|
||||
@@ -1,21 +1,18 @@
|
||||
[package]
|
||||
name = "luminal_metal"
|
||||
version = "0.2.0"
|
||||
edition = "2024"
|
||||
edition = "2021"
|
||||
description = "Metal backend for luminal"
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
metal = { version = "0.31", features = ["mps"] }
|
||||
metal = "0.31"
|
||||
objc = "0.2"
|
||||
as-any = "0.3.2"
|
||||
itertools = "0.12.1"
|
||||
half = { version = "2.7.1", features = ["bytemuck"] }
|
||||
half = "2.7.1"
|
||||
tracing = "0.1.43"
|
||||
safetensors = "0.7.0"
|
||||
memmap2 = "0.9.9"
|
||||
bytemuck = "1.24.0"
|
||||
|
||||
[dev-dependencies]
|
||||
candle-core = "0.9.2-alpha.1"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//! [`DynBackend`] implementation for the Metal runtime.
|
||||
|
||||
use luminal::dtype::DType;
|
||||
use luminal::dyn_backend::{BackendCompileArgs, DynBackend, bytes_to_native_data, compile_backend};
|
||||
use luminal::dyn_backend::{bytes_to_native_data, compile_backend, BackendCompileArgs, DynBackend};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::runtime::MetalRuntime;
|
||||
|
||||
@@ -1,5 +1,227 @@
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum MPSMatrixLayout {
|
||||
RowMajor,
|
||||
TransposedRowMajor,
|
||||
use super::{MetalMulInfo, MetalSumReduceInfo};
|
||||
use luminal::prelude::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum MetalMatmulFamily {
|
||||
#[default]
|
||||
Naive,
|
||||
RegularTiled,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MatmulDescriptor {
|
||||
pub m: Expression,
|
||||
pub n: Expression,
|
||||
pub k: Expression,
|
||||
pub batch_shape: Vec<Expression>,
|
||||
pub lhs_strides: Vec<Expression>,
|
||||
pub rhs_strides: Vec<Expression>,
|
||||
pub out_strides: Vec<Expression>,
|
||||
pub transpose_lhs: bool,
|
||||
pub transpose_rhs: bool,
|
||||
}
|
||||
|
||||
impl MatmulDescriptor {
|
||||
pub fn from_mul_and_sum(
|
||||
mul_info: &MetalMulInfo,
|
||||
sum_info: &MetalSumReduceInfo,
|
||||
) -> Option<Self> {
|
||||
let zero = Expression::from(0);
|
||||
let z = Expression::from('z');
|
||||
|
||||
let is_simple_2d_matmul = mul_info.shape.len() == 3
|
||||
&& sum_info.shape.len() == 2
|
||||
&& mul_info.a_strides.len() == 3
|
||||
&& mul_info.b_strides.len() == 3
|
||||
&& sum_info.strides.len() == 2
|
||||
&& mul_info.shape[0] == sum_info.shape[0]
|
||||
&& mul_info.shape[1] == sum_info.shape[1]
|
||||
&& mul_info.shape[2] == sum_info.iters
|
||||
&& mul_info.a_strides[1] == zero
|
||||
&& mul_info.a_strides[2] == z
|
||||
&& mul_info.b_strides[0] == zero
|
||||
&& mul_info.b_strides[1] == z
|
||||
&& sum_info.strides[1] == z
|
||||
&& sum_info.iter_stride == z;
|
||||
|
||||
if !is_simple_2d_matmul {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(Self {
|
||||
m: sum_info.shape[0],
|
||||
n: sum_info.shape[1],
|
||||
k: sum_info.iters,
|
||||
batch_shape: Vec::new(),
|
||||
lhs_strides: mul_info.a_strides.clone(),
|
||||
rhs_strides: mul_info.b_strides.clone(),
|
||||
out_strides: sum_info.strides.clone(),
|
||||
transpose_lhs: false,
|
||||
transpose_rhs: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MatmulPlan {
|
||||
pub family: MetalMatmulFamily,
|
||||
pub m: Expression,
|
||||
pub n: Expression,
|
||||
pub k: Expression,
|
||||
pub lda: Expression,
|
||||
pub ldb: Expression,
|
||||
pub ldd: Expression,
|
||||
pub batch_size: u32,
|
||||
pub batch_stride_a: u32,
|
||||
pub batch_stride_b: u32,
|
||||
pub batch_stride_d: u32,
|
||||
pub bm: u16,
|
||||
pub bn: u16,
|
||||
pub bk: u16,
|
||||
pub wm: u16,
|
||||
pub wn: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, Copy)]
|
||||
pub struct MetalMatmulPlanner;
|
||||
|
||||
impl MetalMatmulPlanner {
|
||||
pub fn plan(&self, desc: &MatmulDescriptor) -> MatmulPlan {
|
||||
let family = if desc.batch_shape.is_empty()
|
||||
&& desc.m.as_num().is_some_and(|m| m >= 32)
|
||||
&& desc.n.as_num().is_some_and(|n| n >= 32)
|
||||
&& desc.k.as_num().is_some_and(|k| k >= 32)
|
||||
{
|
||||
MetalMatmulFamily::RegularTiled
|
||||
} else {
|
||||
MetalMatmulFamily::Naive
|
||||
};
|
||||
MatmulPlan {
|
||||
family,
|
||||
m: desc.m,
|
||||
n: desc.n,
|
||||
k: desc.k,
|
||||
lda: desc.lhs_strides[0],
|
||||
ldb: desc.rhs_strides[2],
|
||||
ldd: desc.out_strides[0],
|
||||
batch_size: 1,
|
||||
batch_stride_a: 0,
|
||||
batch_stride_b: 0,
|
||||
batch_stride_d: 0,
|
||||
bm: 16,
|
||||
bn: 16,
|
||||
bk: 8,
|
||||
wm: 2,
|
||||
wn: 2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn descriptor_recovers_simple_2d_matmul() {
|
||||
let mul = MetalMulInfo {
|
||||
shape: vec![
|
||||
Expression::from(4),
|
||||
Expression::from(8),
|
||||
Expression::from(16),
|
||||
],
|
||||
a_strides: vec![
|
||||
Expression::from('z') * 16,
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
],
|
||||
b_strides: vec![
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
Expression::from('z') * 8,
|
||||
],
|
||||
output_strides: vec![
|
||||
Expression::from('z') * 16,
|
||||
Expression::from('z') * 8,
|
||||
Expression::from('z'),
|
||||
],
|
||||
};
|
||||
let sum = MetalSumReduceInfo {
|
||||
shape: vec![Expression::from(4), Expression::from(8)],
|
||||
strides: vec![Expression::from('z') * 8, Expression::from('z')],
|
||||
iters: Expression::from(16),
|
||||
iter_stride: Expression::from('z'),
|
||||
};
|
||||
|
||||
let desc = MatmulDescriptor::from_mul_and_sum(&mul, &sum).unwrap();
|
||||
assert_eq!(desc.m, Expression::from(4));
|
||||
assert_eq!(desc.n, Expression::from(8));
|
||||
assert_eq!(desc.k, Expression::from(16));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn planner_keeps_small_problems_on_naive_path() {
|
||||
let desc = MatmulDescriptor {
|
||||
m: Expression::from(4),
|
||||
n: Expression::from(8),
|
||||
k: Expression::from(16),
|
||||
batch_shape: Vec::new(),
|
||||
lhs_strides: vec![
|
||||
Expression::from('z') * 16,
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
],
|
||||
rhs_strides: vec![
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
Expression::from('z') * 8,
|
||||
],
|
||||
out_strides: vec![Expression::from('z') * 8, Expression::from('z')],
|
||||
transpose_lhs: false,
|
||||
transpose_rhs: false,
|
||||
};
|
||||
|
||||
let planner = MetalMatmulPlanner;
|
||||
let plan = planner.plan(&desc);
|
||||
assert_eq!(plan.family, MetalMatmulFamily::Naive);
|
||||
assert_eq!(plan.bm, 16);
|
||||
assert_eq!(plan.bn, 16);
|
||||
assert_eq!(plan.bk, 8);
|
||||
assert_eq!(plan.wm, 2);
|
||||
assert_eq!(plan.wn, 2);
|
||||
assert_eq!(plan.lda, Expression::from('z') * 16);
|
||||
assert_eq!(plan.ldb, Expression::from('z') * 8);
|
||||
assert_eq!(plan.ldd, Expression::from('z') * 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn planner_promotes_large_problems_to_regular_tiled() {
|
||||
let desc = MatmulDescriptor {
|
||||
m: Expression::from(64),
|
||||
n: Expression::from(64),
|
||||
k: Expression::from(64),
|
||||
batch_shape: Vec::new(),
|
||||
lhs_strides: vec![
|
||||
Expression::from('z') * 64,
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
],
|
||||
rhs_strides: vec![
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
Expression::from('z') * 64,
|
||||
],
|
||||
out_strides: vec![Expression::from('z') * 64, Expression::from('z')],
|
||||
transpose_lhs: false,
|
||||
transpose_rhs: false,
|
||||
};
|
||||
|
||||
let planner = MetalMatmulPlanner;
|
||||
let plan = planner.plan(&desc);
|
||||
assert_eq!(plan.family, MetalMatmulFamily::RegularTiled);
|
||||
assert_eq!(plan.bm, 16);
|
||||
assert_eq!(plan.bn, 16);
|
||||
assert_eq!(plan.bk, 8);
|
||||
assert_eq!(plan.wm, 2);
|
||||
assert_eq!(plan.wn, 2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ pub use ops::*;
|
||||
use luminal::dtype::DType;
|
||||
use luminal::op::EgglogOp;
|
||||
use luminal::prelude::*;
|
||||
use metal::{Buffer, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineState, Device};
|
||||
use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, Device};
|
||||
|
||||
pub const DYN_SLOT_COUNT: usize = 26;
|
||||
|
||||
@@ -32,7 +32,7 @@ pub trait MetalKernelOp: EgglogOp {
|
||||
device: &Device,
|
||||
input_dtypes: &[DType],
|
||||
output_dtype: DType,
|
||||
) -> Option<ComputePipelineState>;
|
||||
) -> ComputePipelineState;
|
||||
|
||||
fn infer_output_dtype(&self, input_dtypes: &[DType]) -> DType {
|
||||
input_dtypes.first().copied().unwrap_or(DType::F32)
|
||||
@@ -40,7 +40,7 @@ pub trait MetalKernelOp: EgglogOp {
|
||||
|
||||
fn output_size(&self) -> Expression;
|
||||
|
||||
fn encode_compute(
|
||||
fn encode(
|
||||
&self,
|
||||
encoder: &ComputeCommandEncoderRef,
|
||||
pipeline: &ComputePipelineState,
|
||||
@@ -49,26 +49,6 @@ pub trait MetalKernelOp: EgglogOp {
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
);
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn encode(
|
||||
&self,
|
||||
command_buffer: &CommandBufferRef,
|
||||
pipeline: Option<&ComputePipelineState>,
|
||||
inputs: &[&Buffer],
|
||||
output: &Buffer,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
dyn_buffer: &Buffer,
|
||||
_input_dtypes: &[DType],
|
||||
_output_dtype: DType,
|
||||
) {
|
||||
let pipeline = pipeline.expect("compute pipeline not compiled");
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let dyn_idx = inputs.len() as u64 + 1;
|
||||
encoder.set_buffer(dyn_idx, Some(dyn_buffer), 0);
|
||||
self.encode_compute(encoder, pipeline, inputs, output, dyn_map);
|
||||
encoder.end_encoding();
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Performance Metrics for MBU/MFU Calculation
|
||||
// ========================================================================
|
||||
@@ -93,10 +73,6 @@ pub trait MetalKernelOp: EgglogOp {
|
||||
None
|
||||
}
|
||||
|
||||
fn output_aliases_input(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
|
||||
fn is_matmul(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,31 +1,21 @@
|
||||
use crate::kernel::{DYN_SLOT_COUNT, MetalKernelOp};
|
||||
use half::{bf16, f16};
|
||||
use crate::kernel::{
|
||||
MatmulDescriptor, MetalKernelOp, MetalMatmul, MetalMatmulPlanner, DYN_SLOT_COUNT,
|
||||
};
|
||||
use half::f16;
|
||||
use itertools::Itertools;
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
graph::{BucketLLIR, DimBucket, Graph, LLIRGraph},
|
||||
graph::LLIRGraph,
|
||||
hlir::{Input, NativeData, Output},
|
||||
op::{ExecutionStats, Runtime, RuntimeStats, TimingMethod},
|
||||
prelude::{
|
||||
petgraph::{algo::toposort, prelude::StableGraph, visit::EdgeRef, Direction},
|
||||
FxHashMap, NodeIndex, ToId,
|
||||
petgraph::{Direction, algo::toposort, prelude::StableGraph, visit::EdgeRef},
|
||||
},
|
||||
};
|
||||
use memmap2::MmapOptions;
|
||||
use metal::{Buffer, CommandQueue, ComputePipelineState, Device, MTLResourceOptions};
|
||||
use objc::rc::autoreleasepool;
|
||||
use objc::runtime::Object;
|
||||
use safetensors::{Dtype, SafeTensors};
|
||||
use std::{fs::File, time::Duration};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MetalCompiledBucket {
|
||||
bucket_indices: FxHashMap<char, usize>,
|
||||
llir_graph: LLIRGraph,
|
||||
node_dtypes: FxHashMap<NodeIndex, DType>,
|
||||
pipelines: FxHashMap<NodeIndex, ComputePipelineState>,
|
||||
output_alias_map: FxHashMap<NodeIndex, NodeIndex>,
|
||||
}
|
||||
use std::time::Duration;
|
||||
|
||||
pub struct MetalRuntime {
|
||||
device: Device,
|
||||
@@ -44,124 +34,83 @@ pub struct MetalRuntime {
|
||||
node_dtypes: FxHashMap<NodeIndex, DType>,
|
||||
/// Compiled pipeline states for each kernel node
|
||||
pipelines: FxHashMap<NodeIndex, ComputePipelineState>,
|
||||
/// LLIR output node -> input node whose buffer contains the output.
|
||||
output_alias_map: FxHashMap<NodeIndex, NodeIndex>,
|
||||
/// Bucket definitions for dynamic dimensions.
|
||||
dim_buckets: FxHashMap<char, Vec<DimBucket>>,
|
||||
/// Compiled LLIR variants, one per bucket combination.
|
||||
compiled_buckets: Vec<MetalCompiledBucket>,
|
||||
/// Currently active compiled bucket.
|
||||
active_bucket: usize,
|
||||
}
|
||||
|
||||
impl MetalRuntime {
|
||||
fn input_dtype(&self, id: NodeIndex) -> Option<DType> {
|
||||
self.llir_graph.node_indices().find_map(|node| {
|
||||
self.llir_graph[node]
|
||||
.to_op::<Input>()
|
||||
.and_then(|input| (input.node == id.index()).then_some(input.dtype))
|
||||
})
|
||||
}
|
||||
fn fuse_matmuls(llir_graph: &LLIRGraph) -> LLIRGraph {
|
||||
let mut graph = llir_graph.clone();
|
||||
let planner = MetalMatmulPlanner;
|
||||
let mut rewrites = Vec::new();
|
||||
|
||||
fn output_data_node(&self, id: NodeIndex) -> NodeIndex {
|
||||
let output_id = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.find(|n| {
|
||||
if let Some(Output { node }) = self.llir_graph[*n].to_op::<Output>() {
|
||||
*node == id.index()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.expect("Cannot find output tensor!");
|
||||
for sum_node in graph.node_indices().collect::<Vec<_>>() {
|
||||
let Some(sum_info) = graph[sum_node]
|
||||
.to_dialect::<dyn MetalKernelOp>()
|
||||
.and_then(|op| op.sum_reduce_info())
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
self.llir_graph
|
||||
.neighbors_directed(output_id, Direction::Incoming)
|
||||
.next()
|
||||
.unwrap()
|
||||
}
|
||||
let input_edges: Vec<_> = graph
|
||||
.edges_directed(sum_node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
if input_edges.len() != 1 {
|
||||
continue;
|
||||
}
|
||||
|
||||
fn follow_aliases(&self, mut node: NodeIndex) -> NodeIndex {
|
||||
while let Some(target) = self.output_alias_map.get(&node) {
|
||||
node = *target;
|
||||
let mul_node = input_edges[0];
|
||||
let Some(mul_info) = graph[mul_node]
|
||||
.to_dialect::<dyn MetalKernelOp>()
|
||||
.and_then(|op| op.mul_info())
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let Some(desc) = MatmulDescriptor::from_mul_and_sum(&mul_info, &sum_info) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let mul_inputs: Vec<_> = graph
|
||||
.edges_directed(mul_node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
if mul_inputs.len() != 2 {
|
||||
continue;
|
||||
}
|
||||
|
||||
rewrites.push((sum_node, mul_node, mul_inputs, planner.plan(&desc)));
|
||||
}
|
||||
node
|
||||
}
|
||||
|
||||
fn buffer_for_llir_node<'a>(
|
||||
&'a self,
|
||||
node: NodeIndex,
|
||||
llir_to_hlir: &FxHashMap<NodeIndex, NodeIndex>,
|
||||
) -> &'a Buffer {
|
||||
let data_node = self.follow_aliases(node);
|
||||
if let Some(hlir_node) = llir_to_hlir.get(&data_node) {
|
||||
self.hlir_buffers
|
||||
.get(hlir_node)
|
||||
.expect("Input buffer not set!")
|
||||
} else {
|
||||
self.buffers
|
||||
.get(&data_node)
|
||||
.expect("Intermediate buffer not found!")
|
||||
for (sum_node, mul_node, mul_inputs, plan) in rewrites {
|
||||
graph[sum_node] =
|
||||
luminal::op::LLIROp::new::<dyn MetalKernelOp>(Box::new(MetalMatmul {
|
||||
m: plan.m,
|
||||
n: plan.n,
|
||||
k: plan.k,
|
||||
lda: plan.lda,
|
||||
ldb: plan.ldb,
|
||||
ldd: plan.ldd,
|
||||
family: plan.family,
|
||||
bm: plan.bm,
|
||||
bn: plan.bn,
|
||||
bk: plan.bk,
|
||||
wm: plan.wm,
|
||||
wn: plan.wn,
|
||||
batch_size: plan.batch_size,
|
||||
batch_stride_a: plan.batch_stride_a,
|
||||
batch_stride_b: plan.batch_stride_b,
|
||||
batch_stride_d: plan.batch_stride_d,
|
||||
}));
|
||||
|
||||
graph.remove_node(mul_node);
|
||||
graph.add_edge(mul_inputs[0], sum_node, ());
|
||||
graph.add_edge(mul_inputs[1], sum_node, ());
|
||||
}
|
||||
}
|
||||
|
||||
fn buffer_from_slice<T>(&self, values: &[T]) -> Buffer {
|
||||
self.device.new_buffer_with_data(
|
||||
values.as_ptr() as *const _,
|
||||
std::mem::size_of_val(values) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
)
|
||||
graph
|
||||
}
|
||||
|
||||
fn buffer_from_safetensor(
|
||||
&self,
|
||||
tensor: &safetensors::tensor::TensorView<'_>,
|
||||
dtype: DType,
|
||||
) -> Buffer {
|
||||
match (tensor.dtype(), dtype) {
|
||||
(Dtype::F32, DType::F32) | (Dtype::F16, DType::F16) => {
|
||||
let data = tensor.data();
|
||||
self.device.new_buffer_with_data(
|
||||
data.as_ptr() as *const _,
|
||||
data.len() as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
)
|
||||
}
|
||||
(Dtype::F16, DType::F32) => {
|
||||
let values: Vec<f32> = bytemuck::cast_slice::<u8, f16>(tensor.data())
|
||||
.iter()
|
||||
.map(|v| v.to_f32())
|
||||
.collect();
|
||||
self.buffer_from_slice(&values)
|
||||
}
|
||||
(Dtype::BF16, DType::F32) => {
|
||||
let values: Vec<f32> = bytemuck::cast_slice::<u8, bf16>(tensor.data())
|
||||
.iter()
|
||||
.map(|v| v.to_f32())
|
||||
.collect();
|
||||
self.buffer_from_slice(&values)
|
||||
}
|
||||
(Dtype::F32, DType::F16) => {
|
||||
let values: Vec<f16> = bytemuck::cast_slice::<u8, f32>(tensor.data())
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect();
|
||||
self.buffer_from_slice(&values)
|
||||
}
|
||||
(Dtype::BF16, DType::F16) => {
|
||||
let values: Vec<f16> = bytemuck::cast_slice::<u8, bf16>(tensor.data())
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(v.to_f32()))
|
||||
.collect();
|
||||
self.buffer_from_slice(&values)
|
||||
}
|
||||
(tensor_dtype, dtype) => {
|
||||
panic!("Cannot load safetensor dtype {tensor_dtype:?} into Metal dtype {dtype:?}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn contains_matmul(&self) -> bool {
|
||||
self.llir_graph.node_indices().any(|node| {
|
||||
@@ -183,69 +132,29 @@ impl MetalRuntime {
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn load_safetensors(&mut self, cx: &Graph, file_path: &str) {
|
||||
let f = File::open(file_path).unwrap();
|
||||
let mmap = unsafe { MmapOptions::new().map(&f).unwrap() };
|
||||
let st = SafeTensors::deserialize(&mmap).unwrap();
|
||||
|
||||
for node in cx.graph.node_indices() {
|
||||
if let Some(input) = (*cx.graph[node]).as_any().downcast_ref::<Input>()
|
||||
&& let Ok(tensor) = st.tensor(&input.label)
|
||||
{
|
||||
let buffer = self.buffer_from_safetensor(&tensor, input.dtype);
|
||||
self.input_data.remove(&node);
|
||||
self.hlir_buffers.insert(node, buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_data(&mut self, id: impl ToId, data: impl Into<NativeData>) {
|
||||
let id = id.to_id();
|
||||
let data = data.into();
|
||||
if let Some(dtype) = self.input_dtype(id) {
|
||||
let buffer = self.create_input_buffer(&data, dtype);
|
||||
self.hlir_buffers.insert(id, buffer);
|
||||
}
|
||||
self.input_data.insert(id, data);
|
||||
}
|
||||
|
||||
pub fn set_zeros(&mut self, id: impl ToId, num_bytes: usize) {
|
||||
let id = id.to_id();
|
||||
let buffer = self
|
||||
.device
|
||||
.new_buffer(num_bytes as u64, MTLResourceOptions::StorageModeShared);
|
||||
unsafe {
|
||||
std::ptr::write_bytes(buffer.contents(), 0, num_bytes);
|
||||
}
|
||||
self.input_data.remove(&id);
|
||||
self.hlir_buffers.insert(id, buffer);
|
||||
}
|
||||
|
||||
pub fn remove_buffer(&mut self, id: impl ToId) -> Buffer {
|
||||
let data_id = self.follow_aliases(self.output_data_node(id.to_id()));
|
||||
|
||||
if let Some(buffer) = self.buffers.remove(&data_id) {
|
||||
return buffer;
|
||||
}
|
||||
|
||||
if let Some(Input { node, .. }) = self.llir_graph[data_id].to_op::<Input>() {
|
||||
return self
|
||||
.hlir_buffers
|
||||
.remove(&NodeIndex::new(*node))
|
||||
.expect("Cannot find input tensor in runtime!");
|
||||
}
|
||||
|
||||
panic!("Cannot find tensor in runtime!");
|
||||
}
|
||||
|
||||
pub fn set_buffer(&mut self, id: impl ToId, buffer: Buffer) {
|
||||
let id = id.to_id();
|
||||
self.input_data.remove(&id);
|
||||
self.hlir_buffers.insert(id, buffer);
|
||||
self.input_data.insert(id.to_id(), data.into());
|
||||
}
|
||||
|
||||
pub fn get_f32(&self, id: impl ToId) -> Vec<f32> {
|
||||
let data_id = self.follow_aliases(self.output_data_node(id.to_id()));
|
||||
let id = id.to_id();
|
||||
let output_id = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.find(|n| {
|
||||
if let Some(Output { node }) = self.llir_graph[*n].to_op::<Output>() {
|
||||
*node == id.index()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.expect("Cannot find output tensor!");
|
||||
|
||||
let data_id = self
|
||||
.llir_graph
|
||||
.neighbors_directed(output_id, Direction::Incoming)
|
||||
.next()
|
||||
.unwrap();
|
||||
|
||||
let buffer = self
|
||||
.buffers
|
||||
@@ -322,10 +231,6 @@ impl Runtime for MetalRuntime {
|
||||
llir_graph: StableGraph::default(),
|
||||
node_dtypes: FxHashMap::default(),
|
||||
pipelines: FxHashMap::default(),
|
||||
output_alias_map: FxHashMap::default(),
|
||||
dim_buckets: FxHashMap::default(),
|
||||
compiled_buckets: vec![],
|
||||
active_bucket: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -335,10 +240,52 @@ impl Runtime for MetalRuntime {
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
fn load_llir(&mut self, llir_graph: &LLIRGraph) {
|
||||
self.pipelines.clear();
|
||||
self.buffers.clear();
|
||||
self.dim_buckets.clear();
|
||||
self.compiled_buckets = vec![self.compile_bucket(FxHashMap::default(), llir_graph)];
|
||||
self.activate_bucket(0);
|
||||
self.hlir_buffers.clear();
|
||||
self.node_dtypes.clear();
|
||||
self.llir_graph = Self::fuse_matmuls(llir_graph);
|
||||
|
||||
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
|
||||
for node in topo_order {
|
||||
if let Some(input) = self.llir_graph[node].to_op::<Input>() {
|
||||
self.node_dtypes.insert(node, input.dtype);
|
||||
let hlir_id = NodeIndex::new(input.node);
|
||||
if let Some(data) = self.input_data.get(&hlir_id) {
|
||||
let buffer = self.create_input_buffer(data, input.dtype);
|
||||
self.hlir_buffers.insert(hlir_id, buffer);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if self.llir_graph[node].to_op::<Output>().is_some() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let input_nodes: Vec<NodeIndex> = self
|
||||
.llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
let input_dtypes: Vec<DType> = input_nodes
|
||||
.iter()
|
||||
.map(|n| {
|
||||
self.node_dtypes
|
||||
.get(n)
|
||||
.copied()
|
||||
.unwrap_or_else(|| panic!("Missing inferred dtype for node {n:?}"))
|
||||
})
|
||||
.collect();
|
||||
let output_dtype = kernel_op.infer_output_dtype(&input_dtypes);
|
||||
let pipeline = kernel_op.compile(&self.device, &input_dtypes, output_dtype);
|
||||
self.node_dtypes.insert(node, output_dtype);
|
||||
self.pipelines.insert(node, pipeline);
|
||||
} else {
|
||||
panic!("Metal runtime cannot execute unlowered LLIR node {node:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
@@ -366,105 +313,73 @@ impl Runtime for MetalRuntime {
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) -> Self::ExecReturn {
|
||||
autoreleasepool(|| {
|
||||
self.select_bucket(dyn_map);
|
||||
self.allocate_active_intermediate_buffers(dyn_map);
|
||||
|
||||
let llir_to_hlir: FxHashMap<NodeIndex, NodeIndex> = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.filter_map(|n| {
|
||||
if let Some(Input { node, .. }) = self.llir_graph[n].to_op::<Input>() {
|
||||
Some((n, NodeIndex::new(*node)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
|
||||
|
||||
self.update_dyn_buffer(dyn_map);
|
||||
let command_buffer = self.command_queue.new_command_buffer();
|
||||
|
||||
for node in topo_order {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some()
|
||||
|| self.llir_graph[node].to_op::<Output>().is_some()
|
||||
{
|
||||
continue;
|
||||
let llir_to_hlir: FxHashMap<NodeIndex, NodeIndex> = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.filter_map(|n| {
|
||||
if let Some(Input { node, .. }) = self.llir_graph[n].to_op::<Input>() {
|
||||
Some((n, NodeIndex::new(*node)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let pipeline = self.pipelines.get(&node);
|
||||
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
|
||||
|
||||
let input_nodes: Vec<NodeIndex> = self
|
||||
.llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
self.update_dyn_buffer(dyn_map);
|
||||
let command_buffer = self.command_queue.new_command_buffer();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
let input_buffers: Vec<&Buffer> = input_nodes
|
||||
.iter()
|
||||
.map(|&n| self.buffer_for_llir_node(n, &llir_to_hlir))
|
||||
.collect();
|
||||
let input_dtypes: Vec<DType> = input_nodes
|
||||
.iter()
|
||||
.map(|n| {
|
||||
self.node_dtypes
|
||||
.get(n)
|
||||
.copied()
|
||||
.unwrap_or_else(|| panic!("Missing inferred dtype for node {n:?}"))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output_buffer = if let Some(alias_idx) = kernel_op.output_aliases_input() {
|
||||
input_buffers[alias_idx]
|
||||
} else {
|
||||
self.buffers
|
||||
.get(&node)
|
||||
.expect("Output buffer not allocated!")
|
||||
};
|
||||
let output_dtype = self.node_dtypes.get(&node).copied().unwrap_or(DType::F32);
|
||||
|
||||
kernel_op.encode(
|
||||
command_buffer,
|
||||
pipeline,
|
||||
&input_buffers,
|
||||
output_buffer,
|
||||
dyn_map,
|
||||
&self.dyn_buffer,
|
||||
&input_dtypes,
|
||||
output_dtype,
|
||||
);
|
||||
}
|
||||
for node in topo_order {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some()
|
||||
|| self.llir_graph[node].to_op::<Output>().is_some()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
});
|
||||
}
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let pipeline = self.pipelines.get(&node).expect("Pipeline not compiled!");
|
||||
|
||||
fn clear_intermediate_buffers(&mut self) {
|
||||
self.buffers.clear();
|
||||
}
|
||||
let input_nodes: Vec<NodeIndex> = self
|
||||
.llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
|
||||
fn load_llir_buckets(
|
||||
&mut self,
|
||||
dim_buckets: &FxHashMap<char, Vec<DimBucket>>,
|
||||
bucket_llirs: &[BucketLLIR],
|
||||
) {
|
||||
self.buffers.clear();
|
||||
self.dim_buckets = dim_buckets.clone();
|
||||
self.compiled_buckets = bucket_llirs
|
||||
.iter()
|
||||
.map(|(bucket_indices, _, llir)| self.compile_bucket(bucket_indices.clone(), llir))
|
||||
.collect();
|
||||
assert!(
|
||||
!self.compiled_buckets.is_empty(),
|
||||
"Metal runtime received no bucketed LLIRs"
|
||||
);
|
||||
self.activate_bucket(0);
|
||||
let input_buffers: Vec<&Buffer> = input_nodes
|
||||
.iter()
|
||||
.map(|&n| {
|
||||
if let Some(hlir_node) = llir_to_hlir.get(&n) {
|
||||
self.hlir_buffers
|
||||
.get(hlir_node)
|
||||
.expect("Input buffer not set!")
|
||||
} else {
|
||||
self.buffers
|
||||
.get(&n)
|
||||
.expect("Intermediate buffer not found!")
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output_buffer = self
|
||||
.buffers
|
||||
.get(&node)
|
||||
.expect("Output buffer not allocated!");
|
||||
|
||||
// Bind dyn dims right after the output slot:
|
||||
// [inputs..., output, dyn, bytes...]
|
||||
let dyn_idx = input_buffers.len() as u64 + 1;
|
||||
encoder.set_buffer(dyn_idx, Some(&self.dyn_buffer), 0);
|
||||
|
||||
kernel_op.encode(encoder, pipeline, &input_buffers, output_buffer, dyn_map);
|
||||
}
|
||||
}
|
||||
|
||||
encoder.end_encoding();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -525,164 +440,23 @@ impl MetalRuntime {
|
||||
}
|
||||
|
||||
pub fn allocate_intermediate_buffers(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
self.select_bucket(dyn_map);
|
||||
self.allocate_active_intermediate_buffers(dyn_map);
|
||||
}
|
||||
|
||||
fn allocate_active_intermediate_buffers(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
let mut planned = Vec::new();
|
||||
|
||||
for node in self.llir_graph.node_indices() {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
if kernel_op.output_aliases_input().is_some() {
|
||||
continue;
|
||||
}
|
||||
let size = kernel_op.output_size().exec(dyn_map).unwrap();
|
||||
let dtype = self.node_dtypes.get(&node).copied().unwrap_or(DType::F32);
|
||||
let bytes = (size * dtype.bits().div_ceil(8)) as u64;
|
||||
let needs_buffer = self
|
||||
.buffers
|
||||
.get(&node)
|
||||
.is_none_or(|buffer| buffer.length() != bytes);
|
||||
|
||||
planned.push((node, bytes, needs_buffer));
|
||||
}
|
||||
}
|
||||
|
||||
for (node, bytes, needs_buffer) in planned {
|
||||
if needs_buffer {
|
||||
let buffer = self
|
||||
.device
|
||||
.new_buffer(bytes, MTLResourceOptions::StorageModeShared);
|
||||
let buffer = self.device.new_buffer(
|
||||
(size * dtype.bits().div_ceil(8)) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
self.buffers.insert(node, buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn compile_bucket(
|
||||
&self,
|
||||
bucket_indices: FxHashMap<char, usize>,
|
||||
llir_graph: &LLIRGraph,
|
||||
) -> MetalCompiledBucket {
|
||||
let mut node_dtypes = FxHashMap::default();
|
||||
let mut pipelines = FxHashMap::default();
|
||||
let mut output_alias_map = FxHashMap::default();
|
||||
let llir_graph = llir_graph.clone();
|
||||
|
||||
let topo_order = toposort(&llir_graph, None).expect("Graph has cycles!");
|
||||
for node in topo_order {
|
||||
if let Some(input) = llir_graph[node].to_op::<Input>() {
|
||||
node_dtypes.insert(node, input.dtype);
|
||||
continue;
|
||||
}
|
||||
|
||||
if llir_graph[node].to_op::<Output>().is_some() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(kernel_op) = llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let input_nodes: Vec<NodeIndex> = llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
let input_dtypes: Vec<DType> = input_nodes
|
||||
.iter()
|
||||
.map(|n| {
|
||||
node_dtypes
|
||||
.get(n)
|
||||
.copied()
|
||||
.unwrap_or_else(|| panic!("Missing inferred dtype for node {n:?}"))
|
||||
})
|
||||
.collect();
|
||||
let output_dtype = kernel_op.infer_output_dtype(&input_dtypes);
|
||||
let pipeline = kernel_op.compile(&self.device, &input_dtypes, output_dtype);
|
||||
node_dtypes.insert(node, output_dtype);
|
||||
if let Some(pipeline) = pipeline {
|
||||
pipelines.insert(node, pipeline);
|
||||
}
|
||||
if let Some(input_idx) = kernel_op.output_aliases_input()
|
||||
&& let Some(target) = input_nodes.get(input_idx).copied()
|
||||
{
|
||||
output_alias_map.insert(node, target);
|
||||
}
|
||||
} else {
|
||||
panic!("Metal runtime cannot execute unlowered LLIR node {node:?}");
|
||||
}
|
||||
}
|
||||
|
||||
MetalCompiledBucket {
|
||||
bucket_indices,
|
||||
llir_graph,
|
||||
node_dtypes,
|
||||
pipelines,
|
||||
output_alias_map,
|
||||
}
|
||||
}
|
||||
|
||||
fn activate_bucket(&mut self, index: usize) {
|
||||
let bucket = self
|
||||
.compiled_buckets
|
||||
.get(index)
|
||||
.unwrap_or_else(|| panic!("Metal bucket index {index} is not compiled"))
|
||||
.clone();
|
||||
self.active_bucket = index;
|
||||
self.llir_graph = bucket.llir_graph;
|
||||
self.node_dtypes = bucket.node_dtypes;
|
||||
self.pipelines = bucket.pipelines;
|
||||
self.output_alias_map = bucket.output_alias_map;
|
||||
self.refresh_input_data_buffers();
|
||||
self.buffers.clear();
|
||||
}
|
||||
|
||||
fn refresh_input_data_buffers(&mut self) {
|
||||
for node in self.llir_graph.node_indices() {
|
||||
if let Some(input) = self.llir_graph[node].to_op::<Input>() {
|
||||
let hlir_id = NodeIndex::new(input.node);
|
||||
if let Some(data) = self.input_data.get(&hlir_id) {
|
||||
let buffer = self.create_input_buffer(data, input.dtype);
|
||||
self.hlir_buffers.insert(hlir_id, buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn select_bucket(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
if self.compiled_buckets.len() <= 1 {
|
||||
return;
|
||||
}
|
||||
|
||||
let index = self.resolve_bucket(dyn_map);
|
||||
if index != self.active_bucket {
|
||||
self.activate_bucket(index);
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_bucket(&self, dyn_map: &FxHashMap<char, usize>) -> usize {
|
||||
self.compiled_buckets
|
||||
.iter()
|
||||
.position(|bucket| {
|
||||
self.dim_buckets.iter().all(|(dim, buckets)| {
|
||||
let value = dyn_map.get(dim).copied().unwrap_or(0);
|
||||
let bucket_index = bucket.bucket_indices.get(dim).copied().unwrap_or(0);
|
||||
buckets
|
||||
.get(bucket_index)
|
||||
.map(|bucket| bucket.contains(value))
|
||||
.unwrap_or(true)
|
||||
})
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
panic!(
|
||||
"No Metal bucket matches dyn_map {:?}. Defined buckets: {:?}",
|
||||
dyn_map, self.dim_buckets
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
fn update_dyn_buffer(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
let ptr = self.dyn_buffer.contents() as *mut i32;
|
||||
unsafe {
|
||||
@@ -702,99 +476,87 @@ impl MetalRuntime {
|
||||
|
||||
/// Execute and return GPU-side execution time in microseconds.
|
||||
fn execute_timed(&mut self, dyn_map: &FxHashMap<char, usize>) -> (f64, TimingMethod) {
|
||||
autoreleasepool(|| {
|
||||
self.select_bucket(dyn_map);
|
||||
self.allocate_active_intermediate_buffers(dyn_map);
|
||||
|
||||
let llir_to_hlir: FxHashMap<NodeIndex, NodeIndex> = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.filter_map(|n| {
|
||||
if let Some(Input { node, .. }) = self.llir_graph[n].to_op::<Input>() {
|
||||
Some((n, NodeIndex::new(*node)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
|
||||
|
||||
self.update_dyn_buffer(dyn_map);
|
||||
let command_buffer = self.command_queue.new_command_buffer();
|
||||
|
||||
for node in topo_order {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some()
|
||||
|| self.llir_graph[node].to_op::<Output>().is_some()
|
||||
{
|
||||
continue;
|
||||
let llir_to_hlir: FxHashMap<NodeIndex, NodeIndex> = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.filter_map(|n| {
|
||||
if let Some(Input { node, .. }) = self.llir_graph[n].to_op::<Input>() {
|
||||
Some((n, NodeIndex::new(*node)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let pipeline = self.pipelines.get(&node);
|
||||
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
|
||||
|
||||
let input_nodes: Vec<NodeIndex> = self
|
||||
.llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
self.update_dyn_buffer(dyn_map);
|
||||
let command_buffer = self.command_queue.new_command_buffer();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
let input_buffers: Vec<&Buffer> = input_nodes
|
||||
.iter()
|
||||
.map(|&n| self.buffer_for_llir_node(n, &llir_to_hlir))
|
||||
.collect();
|
||||
let input_dtypes: Vec<DType> = input_nodes
|
||||
.iter()
|
||||
.map(|n| {
|
||||
self.node_dtypes
|
||||
.get(n)
|
||||
.copied()
|
||||
.unwrap_or_else(|| panic!("Missing inferred dtype for node {n:?}"))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output_buffer = if let Some(alias_idx) = kernel_op.output_aliases_input() {
|
||||
input_buffers[alias_idx]
|
||||
} else {
|
||||
self.buffers
|
||||
.get(&node)
|
||||
.expect("Output buffer not allocated!")
|
||||
};
|
||||
let output_dtype = self.node_dtypes.get(&node).copied().unwrap_or(DType::F32);
|
||||
|
||||
kernel_op.encode(
|
||||
command_buffer,
|
||||
pipeline,
|
||||
&input_buffers,
|
||||
output_buffer,
|
||||
dyn_map,
|
||||
&self.dyn_buffer,
|
||||
&input_dtypes,
|
||||
output_dtype,
|
||||
);
|
||||
}
|
||||
for node in topo_order {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some()
|
||||
|| self.llir_graph[node].to_op::<Output>().is_some()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let pipeline = self.pipelines.get(&node).expect("Pipeline not compiled!");
|
||||
|
||||
// gpuStartTime and gpuEndTime are available on macOS 10.15+
|
||||
let gpu_start: f64 = unsafe {
|
||||
use objc::{msg_send, sel, sel_impl};
|
||||
let ptr = command_buffer as *const _ as *mut Object;
|
||||
msg_send![ptr, GPUStartTime]
|
||||
};
|
||||
let gpu_end: f64 = unsafe {
|
||||
use objc::{msg_send, sel, sel_impl};
|
||||
let ptr = command_buffer as *const _ as *mut Object;
|
||||
msg_send![ptr, GPUEndTime]
|
||||
};
|
||||
let input_nodes: Vec<NodeIndex> = self
|
||||
.llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
|
||||
let gpu_time_seconds = gpu_end - gpu_start;
|
||||
let gpu_time_us = gpu_time_seconds * 1_000_000.0;
|
||||
let input_buffers: Vec<&Buffer> = input_nodes
|
||||
.iter()
|
||||
.map(|&n| {
|
||||
if let Some(hlir_node) = llir_to_hlir.get(&n) {
|
||||
self.hlir_buffers
|
||||
.get(hlir_node)
|
||||
.expect("Input buffer not set!")
|
||||
} else {
|
||||
self.buffers
|
||||
.get(&n)
|
||||
.expect("Intermediate buffer not found!")
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
(gpu_time_us, TimingMethod::DeviceTimestamp)
|
||||
})
|
||||
let output_buffer = self
|
||||
.buffers
|
||||
.get(&node)
|
||||
.expect("Output buffer not allocated!");
|
||||
|
||||
let dyn_idx = input_buffers.len() as u64 + 1;
|
||||
encoder.set_buffer(dyn_idx, Some(&self.dyn_buffer), 0);
|
||||
|
||||
kernel_op.encode(encoder, pipeline, &input_buffers, output_buffer, dyn_map);
|
||||
}
|
||||
}
|
||||
|
||||
encoder.end_encoding();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
// gpuStartTime and gpuEndTime are available on macOS 10.15+
|
||||
let gpu_start: f64 = unsafe {
|
||||
use objc::{msg_send, sel, sel_impl};
|
||||
let ptr = command_buffer as *const _ as *mut Object;
|
||||
msg_send![ptr, GPUStartTime]
|
||||
};
|
||||
let gpu_end: f64 = unsafe {
|
||||
use objc::{msg_send, sel, sel_impl};
|
||||
let ptr = command_buffer as *const _ as *mut Object;
|
||||
msg_send![ptr, GPUEndTime]
|
||||
};
|
||||
|
||||
let gpu_time_seconds = gpu_end - gpu_start;
|
||||
let gpu_time_us = gpu_time_seconds * 1_000_000.0;
|
||||
|
||||
(gpu_time_us, TimingMethod::DeviceTimestamp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,8 @@
|
||||
use crate::{kernel::lower_expression_for_metal, runtime::MetalRuntime};
|
||||
use candle_core::{Device as CandleDevice, Tensor as CandleTensor};
|
||||
use half::{bf16, f16};
|
||||
use half::f16;
|
||||
use luminal::prelude::*;
|
||||
use proptest::prelude::*;
|
||||
use safetensors::{Dtype, tensor::TensorView};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
path::PathBuf,
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
};
|
||||
|
||||
static SAFETENSORS_TEST_FILE_ID: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
fn assert_close(actual: &[f32], expected: &[f32], tolerance: f32) {
|
||||
assert_eq!(
|
||||
@@ -34,32 +26,6 @@ fn assert_close(actual: &[f32], expected: &[f32], tolerance: f32) {
|
||||
}
|
||||
}
|
||||
|
||||
fn bytes_of<T: bytemuck::NoUninit>(values: &[T]) -> Vec<u8> {
|
||||
bytemuck::cast_slice(values).to_vec()
|
||||
}
|
||||
|
||||
fn write_test_safetensors(tensors: &[(&str, Dtype, Vec<usize>, Vec<u8>)]) -> PathBuf {
|
||||
let tensor_views: HashMap<String, TensorView<'_>> = tensors
|
||||
.iter()
|
||||
.map(|(name, dtype, shape, data)| {
|
||||
(
|
||||
(*name).to_string(),
|
||||
TensorView::new(*dtype, shape.clone(), data).unwrap(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
let serialized = safetensors::serialize(&tensor_views, None).unwrap();
|
||||
let id = SAFETENSORS_TEST_FILE_ID.fetch_add(1, Ordering::Relaxed);
|
||||
let mut path = std::env::temp_dir();
|
||||
path.push(format!(
|
||||
"luminal_metal_runtime_{}_{}.safetensors",
|
||||
std::process::id(),
|
||||
id
|
||||
));
|
||||
std::fs::write(&path, serialized).unwrap();
|
||||
path
|
||||
}
|
||||
|
||||
const TRANSFORMER_SEQ: usize = 4;
|
||||
const TRANSFORMER_HIDDEN: usize = 16;
|
||||
const TRANSFORMER_INTERMEDIATE: usize = 32;
|
||||
@@ -284,36 +250,6 @@ fn dynamic_dim_sum_reduce_runs() {
|
||||
assert_close(&out, &[9.0, 12.0], 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_bucketed_dynamic_dim_dispatches_correct_graph() {
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor(('s', 4));
|
||||
let output = (input + input).output();
|
||||
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
|
||||
cx.set_dim('s', 1);
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, vec![1.0f32; 4]);
|
||||
rt = cx.search(rt, 5);
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
let s1_input = vec![1.0, 2.0, 3.0, 4.0];
|
||||
rt.set_data(input, s1_input.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let s1_out = rt.get_f32(output);
|
||||
assert_close(&s1_out[..4], &[2.0, 4.0, 6.0, 8.0], 0.001);
|
||||
|
||||
cx.set_dim('s', 3);
|
||||
let s3_input: Vec<f32> = (0..12).map(|i| i as f32).collect();
|
||||
let s3_expected: Vec<f32> = s3_input.iter().map(|v| v * 2.0).collect();
|
||||
rt.set_data(input, s3_input);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let s3_out = rt.get_f32(output);
|
||||
assert_close(&s3_out[..12], &s3_expected, 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_int_arithmetic_preserves_large_values() {
|
||||
let mut cx = Graph::default();
|
||||
@@ -709,13 +645,8 @@ fn metal_regular_tiled_matmul_path() {
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("MPSMatmul")),
|
||||
"expected MPS matmul path, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
assert!(
|
||||
!kernels.iter().any(|k| k.contains("GenericMatmul")),
|
||||
"MPS-compatible matmul should not extract the generic fallback, kernels: {:?}",
|
||||
kernels.iter().any(|k| k.contains("family: RegularTiled")),
|
||||
"expected regular tiled matmul path, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
|
||||
@@ -733,287 +664,6 @@ fn metal_regular_tiled_matmul_path() {
|
||||
assert_close(&result, &expected, 2e-3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_mps_matmul_transposed_rhs_weight_layout() {
|
||||
let mut cx = Graph::default();
|
||||
let m = 7;
|
||||
let k = 11;
|
||||
let n = 13;
|
||||
let a = cx.tensor((m, k));
|
||||
let weight = cx.tensor((n, k));
|
||||
let output = a.matmul(weight.t()).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let a_data = seeded_data(m * k, 0.35, -0.17);
|
||||
let weight_data = seeded_data(n * k, 0.21, -0.09);
|
||||
|
||||
rt.set_data(a, &a_data);
|
||||
rt.set_data(weight, &weight_data);
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("transpose_rhs: true")),
|
||||
"expected MPS matmul to cover transposed row-major RHS, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
let device = CandleDevice::Cpu;
|
||||
let ref_a = CandleTensor::from_vec(a_data, (m, k), &device).unwrap();
|
||||
let ref_weight = CandleTensor::from_vec(weight_data, (n, k), &device).unwrap();
|
||||
let expected = ref_a.matmul(&ref_weight.t().unwrap()).unwrap();
|
||||
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 1e-3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_mps_matmul_transposed_lhs_layout() {
|
||||
let mut cx = Graph::default();
|
||||
let m = 5;
|
||||
let k = 9;
|
||||
let n = 6;
|
||||
let lhs_storage = cx.tensor((k, m));
|
||||
let rhs = cx.tensor((k, n));
|
||||
let output = lhs_storage.t().matmul(rhs).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let lhs_data = seeded_data(k * m, 0.31, -0.12);
|
||||
let rhs_data = seeded_data(k * n, 0.27, -0.08);
|
||||
|
||||
rt.set_data(lhs_storage, &lhs_data);
|
||||
rt.set_data(rhs, &rhs_data);
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("transpose_lhs: true")),
|
||||
"expected MPS matmul to cover transposed row-major LHS, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
let device = CandleDevice::Cpu;
|
||||
let ref_lhs = CandleTensor::from_vec(lhs_data, (k, m), &device)
|
||||
.unwrap()
|
||||
.t()
|
||||
.unwrap();
|
||||
let ref_rhs = CandleTensor::from_vec(rhs_data, (k, n), &device).unwrap();
|
||||
let expected = ref_lhs.matmul(&ref_rhs).unwrap();
|
||||
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 1e-3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_mps_batched_matmul_row_row_layout() {
|
||||
let mut cx = Graph::default();
|
||||
let batch = 3;
|
||||
let m = 4;
|
||||
let k = 5;
|
||||
let n = 6;
|
||||
let a = cx.tensor((batch, m, k));
|
||||
let b = cx.tensor((batch, k, n));
|
||||
let output = a.matmul(b).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let a_data = seeded_data(batch * m * k, 0.17, -0.08);
|
||||
let b_data = seeded_data(batch * k * n, 0.11, -0.05);
|
||||
rt.set_data(a, &a_data);
|
||||
rt.set_data(b, &b_data);
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("MPSBatchedMatmul")),
|
||||
"expected MPS batched matmul path, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
let mut expected = vec![0.0; batch * m * n];
|
||||
for batch_idx in 0..batch {
|
||||
for row in 0..m {
|
||||
for col in 0..n {
|
||||
let mut sum = 0.0;
|
||||
for inner in 0..k {
|
||||
sum += a_data[batch_idx * m * k + row * k + inner]
|
||||
* b_data[batch_idx * k * n + inner * n + col];
|
||||
}
|
||||
expected[batch_idx * m * n + row * n + col] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert_close(&result, &expected, 1e-3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_generic_matmul_covers_noncontiguous_merged_head_projection() {
|
||||
let mut cx = Graph::default();
|
||||
let heads = 3;
|
||||
let seq = 4;
|
||||
let head_dim = 5;
|
||||
let hidden = heads * head_dim;
|
||||
let out_dim = 7;
|
||||
let attn = cx.tensor((heads, seq, head_dim));
|
||||
let weight = cx.tensor((out_dim, hidden));
|
||||
let merged = attn.transpose(0, 1).merge_dims(1, 2);
|
||||
let output = merged.matmul(weight.t()).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let attn_data = seeded_data(heads * seq * head_dim, 0.19, -0.09);
|
||||
let weight_data = seeded_data(out_dim * hidden, 0.14, -0.06);
|
||||
rt.set_data(attn, &attn_data);
|
||||
rt.set_data(weight, &weight_data);
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("GenericMatmul")),
|
||||
"expected generic matmul fallback for non-contiguous merged-head projection, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
assert!(
|
||||
!kernels.iter().any(|k| {
|
||||
k.contains("MetalMul") && k.contains(&format!("shape: [{seq}, {out_dim}, {hidden}]"))
|
||||
}),
|
||||
"generic fallback should remove the broadcast multiply intermediate, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
let mut expected = vec![0.0; seq * out_dim];
|
||||
for token in 0..seq {
|
||||
for out_col in 0..out_dim {
|
||||
let mut sum = 0.0;
|
||||
for inner in 0..hidden {
|
||||
let head = inner / head_dim;
|
||||
let dim = inner % head_dim;
|
||||
let attn_idx = head * seq * head_dim + token * head_dim + dim;
|
||||
sum += attn_data[attn_idx] * weight_data[out_col * hidden + inner];
|
||||
}
|
||||
expected[token * out_dim + out_col] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
assert_close(&result, &expected, 1e-3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_mps_batched_matmul_transposed_rhs_layout() {
|
||||
let mut cx = Graph::default();
|
||||
let batch = 4;
|
||||
let m = 3;
|
||||
let k = 7;
|
||||
let n = 5;
|
||||
let a = cx.tensor((batch, m, k));
|
||||
let weight = cx.tensor((batch, n, k));
|
||||
let output = a.matmul(weight.permute((0, 2, 1))).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let a_data = seeded_data(batch * m * k, 0.13, -0.06);
|
||||
let weight_data = seeded_data(batch * n * k, 0.09, -0.04);
|
||||
rt.set_data(a, &a_data);
|
||||
rt.set_data(weight, &weight_data);
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels
|
||||
.iter()
|
||||
.any(|k| k.contains("MPSBatchedMatmul") && k.contains("transpose_rhs: true")),
|
||||
"expected MPS batched matmul transposed RHS path, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
let mut expected = vec![0.0; batch * m * n];
|
||||
for batch_idx in 0..batch {
|
||||
for row in 0..m {
|
||||
for col in 0..n {
|
||||
let mut sum = 0.0;
|
||||
for inner in 0..k {
|
||||
sum += a_data[batch_idx * m * k + row * k + inner]
|
||||
* weight_data[batch_idx * n * k + col * k + inner];
|
||||
}
|
||||
expected[batch_idx * m * n + row * n + col] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert_close(&result, &expected, 1e-3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_mps_matmul_f16_transposed_rhs_weight_layout() {
|
||||
let mut cx = Graph::default();
|
||||
let m = 6;
|
||||
let k = 10;
|
||||
let n = 7;
|
||||
let a = cx.tensor((m, k)).as_dtype(DType::F16);
|
||||
let weight = cx.tensor((n, k)).as_dtype(DType::F16);
|
||||
let output = a.matmul(weight.t()).cast(DType::F32).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let a_data = seeded_data(m * k, 0.22, -0.07);
|
||||
let weight_data = seeded_data(n * k, 0.18, -0.05);
|
||||
|
||||
rt.set_data(a, to_f16_vec(&a_data));
|
||||
rt.set_data(weight, to_f16_vec(&weight_data));
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("transpose_rhs: true")),
|
||||
"expected MPS F16 matmul to cover transposed row-major RHS, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
let device = CandleDevice::Cpu;
|
||||
let ref_a = CandleTensor::from_vec(a_data, (m, k), &device).unwrap();
|
||||
let ref_weight = CandleTensor::from_vec(weight_data, (n, k), &device).unwrap();
|
||||
let expected = ref_a.matmul(&ref_weight.t().unwrap()).unwrap();
|
||||
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 5e-3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_rms_norm() {
|
||||
let mut cx = Graph::default();
|
||||
@@ -1338,131 +988,6 @@ fn test_scatter_basic() {
|
||||
assert_close(&out, &[0.0, 10.0, 0.0, 20.0, 30.0], 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scatter_buffer_roundtrip() {
|
||||
let mut cx = Graph::default();
|
||||
let src = cx.tensor(1);
|
||||
let indexes = cx.tensor(1).as_dtype(DType::Int);
|
||||
let cache = cx.tensor(4).persist();
|
||||
let cache_out = src.scatter(indexes, cache);
|
||||
let read = cache_out.output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(src, &[0.0]);
|
||||
rt.set_data(indexes, &[0.0]);
|
||||
rt.set_zeros(cache, 4 * std::mem::size_of::<f32>());
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
for (pos, value, expected) in [
|
||||
(0, 10.0, [10.0, 0.0, 0.0, 0.0]),
|
||||
(1, 20.0, [10.0, 20.0, 0.0, 0.0]),
|
||||
(2, 30.0, [10.0, 20.0, 30.0, 0.0]),
|
||||
] {
|
||||
rt.set_data(src, &[value]);
|
||||
rt.set_data(indexes, &[pos as f32]);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
assert_close(&rt.get_f32(read), &expected, 0.001);
|
||||
|
||||
let updated_cache = rt.remove_buffer(cache_out);
|
||||
rt.set_buffer(cache, updated_cache);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_safetensors_f32_survives_search_and_overrides_input_data() {
|
||||
let mut cx = Graph::default();
|
||||
let weights = cx.named_tensor("weights", 3);
|
||||
let bias = cx.named_tensor("bias", 3);
|
||||
let out = (weights + bias).output();
|
||||
|
||||
let weight_values = [1.25f32, -2.5, 4.0];
|
||||
let tensors = [("weights", Dtype::F32, vec![3], bytes_of(&weight_values))];
|
||||
let path = write_test_safetensors(&tensors);
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(weights, &[99.0, 99.0, 99.0]);
|
||||
rt.set_data(bias, &[0.5, 1.0, -1.5]);
|
||||
rt.load_safetensors(&cx, path.to_str().unwrap());
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out), &[1.75, -1.5, 2.5], 0.001);
|
||||
std::fs::remove_file(path).ok();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_safetensors_converts_supported_float_dtypes() {
|
||||
let mut cx = Graph::default();
|
||||
let f16_to_f32 = cx.named_tensor("f16_to_f32", 2);
|
||||
let bf16_to_f32 = cx.named_tensor("bf16_to_f32", 2);
|
||||
let f16_to_f16 = cx.named_tensor("f16_to_f16", 2).as_dtype(DType::F16);
|
||||
let f32_to_f16 = cx.named_tensor("f32_to_f16", 2).as_dtype(DType::F16);
|
||||
let bf16_to_f16 = cx.named_tensor("bf16_to_f16", 2).as_dtype(DType::F16);
|
||||
|
||||
let f16_to_f32_out = (f16_to_f32 + 0.0).output();
|
||||
let bf16_to_f32_out = (bf16_to_f32 + 0.0).output();
|
||||
let f16_to_f16_out = f16_to_f16.cast(DType::F32).output();
|
||||
let f32_to_f16_out = f32_to_f16.cast(DType::F32).output();
|
||||
let bf16_to_f16_out = bf16_to_f16.cast(DType::F32).output();
|
||||
|
||||
let f16_to_f32_values = [f16::from_f32(1.5), f16::from_f32(-2.25)];
|
||||
let bf16_to_f32_values = [bf16::from_f32(3.5), bf16::from_f32(-4.25)];
|
||||
let f16_to_f16_values = [f16::from_f32(5.5), f16::from_f32(-6.25)];
|
||||
let f32_to_f16_values = [7.5f32, -8.25];
|
||||
let bf16_to_f16_values = [bf16::from_f32(9.5), bf16::from_f32(-10.25)];
|
||||
let tensors = [
|
||||
(
|
||||
"f16_to_f32",
|
||||
Dtype::F16,
|
||||
vec![2],
|
||||
bytes_of(&f16_to_f32_values),
|
||||
),
|
||||
(
|
||||
"bf16_to_f32",
|
||||
Dtype::BF16,
|
||||
vec![2],
|
||||
bytes_of(&bf16_to_f32_values),
|
||||
),
|
||||
(
|
||||
"f16_to_f16",
|
||||
Dtype::F16,
|
||||
vec![2],
|
||||
bytes_of(&f16_to_f16_values),
|
||||
),
|
||||
(
|
||||
"f32_to_f16",
|
||||
Dtype::F32,
|
||||
vec![2],
|
||||
bytes_of(&f32_to_f16_values),
|
||||
),
|
||||
(
|
||||
"bf16_to_f16",
|
||||
Dtype::BF16,
|
||||
vec![2],
|
||||
bytes_of(&bf16_to_f16_values),
|
||||
),
|
||||
];
|
||||
let path = write_test_safetensors(&tensors);
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.load_safetensors(&cx, path.to_str().unwrap());
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(f16_to_f32_out), &[1.5, -2.25], 0.001);
|
||||
assert_close(&rt.get_f32(bf16_to_f32_out), &[3.5, -4.25], 0.001);
|
||||
assert_close(&rt.get_f32(f16_to_f16_out), &[5.5, -6.25], 0.001);
|
||||
assert_close(&rt.get_f32(f32_to_f16_out), &[7.5, -8.25], 0.001);
|
||||
assert_close(&rt.get_f32(bf16_to_f16_out), &[9.5, -10.25], 0.001);
|
||||
std::fs::remove_file(path).ok();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gather_noncontiguous_data_uses_data_shape() {
|
||||
let mut cx = Graph::default();
|
||||
@@ -1499,12 +1024,6 @@ fn test_scatter_into_nonzero_dest() {
|
||||
rt.set_data(indexes, &[2f32]);
|
||||
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0, 5.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
|
||||
"expected no-copy scatter for consumed destination, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1512,89 +1031,6 @@ fn test_scatter_into_nonzero_dest() {
|
||||
assert_close(&out, &[1.0, 2.0, 99.0, 4.0, 5.0], 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scatter_no_copy_remove_buffer_aliases_dest() {
|
||||
let mut cx = Graph::default();
|
||||
let src = cx.tensor(2);
|
||||
let indexes = cx.tensor(2).as_dtype(DType::Int);
|
||||
let dest = cx.tensor(5);
|
||||
let result = src.scatter(indexes, dest).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(src, &[7.0, 8.0]);
|
||||
rt.set_data(indexes, &[1.0, 3.0]);
|
||||
rt.set_data(dest, &[10.0, 20.0, 30.0, 40.0, 50.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let moved = rt.remove_buffer(result);
|
||||
let moved_values = unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
moved.contents() as *const f32,
|
||||
moved.length() as usize / std::mem::size_of::<f32>(),
|
||||
)
|
||||
.to_vec()
|
||||
};
|
||||
assert_close(&moved_values, &[10.0, 7.0, 30.0, 8.0, 50.0], 0.001);
|
||||
rt.set_buffer(dest.id, moved);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scatter_no_copy_handles_2d_destination() {
|
||||
let mut cx = Graph::default();
|
||||
let src = cx.tensor(2);
|
||||
let indexes = cx.tensor(2).as_dtype(DType::Int);
|
||||
let dest = cx.tensor((2, 3));
|
||||
let result = src.scatter(indexes, dest).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(src, &[9.0, 8.0]);
|
||||
rt.set_data(indexes, &[2.0, 4.0]);
|
||||
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
|
||||
"expected no-copy scatter for 2D destination, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(result), &[1.0, 2.0, 9.0, 4.0, 8.0, 6.0], 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scatter_no_copy_not_selected_when_dest_has_another_consumer() {
|
||||
let mut cx = Graph::default();
|
||||
let src = cx.tensor(1);
|
||||
let indexes = cx.tensor(1).as_dtype(DType::Int);
|
||||
let dest = cx.tensor(4);
|
||||
let scatter = src.scatter(indexes, dest).output();
|
||||
let dest_plus_one = (dest + 1.0).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(src, &[99.0]);
|
||||
rt.set_data(indexes, &[1.0]);
|
||||
rt.set_data(dest, &[10.0, 20.0, 30.0, 40.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
!kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
|
||||
"no-copy scatter should not be selected when dest is also consumed, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(scatter), &[10.0, 99.0, 30.0, 40.0], 0.001);
|
||||
assert_close(&rt.get_f32(dest_plus_one), &[11.0, 21.0, 31.0, 41.0], 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scatter_all_positions() {
|
||||
let mut cx = Graph::default();
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "luminal_nn"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ impl MoE {
|
||||
mod tests {
|
||||
use super::MoE;
|
||||
use luminal::prelude::*;
|
||||
use rand::{Rng, rng};
|
||||
use rand::{rng, Rng};
|
||||
|
||||
fn random_vec(n: usize) -> Vec<f32> {
|
||||
let mut r = rng();
|
||||
|
||||
@@ -431,7 +431,7 @@ def main() -> None:
|
||||
tokenizer = WhisperTokenizer.from_pretrained(REPO_ID)
|
||||
|
||||
use_compiled = os.environ.get("LUMINAL_DISABLE", "0") != "1"
|
||||
max_new_tokens = 100
|
||||
max_new_tokens = int(os.environ.get("GEN_TOKENS", "100"))
|
||||
search_iters = int(os.environ.get("SEARCH_ITERATIONS", "10"))
|
||||
|
||||
if use_compiled:
|
||||
|
||||
@@ -8,7 +8,7 @@ echo "=========================================="
|
||||
echo " Luminal Python: Full Test Suite"
|
||||
echo "=========================================="
|
||||
|
||||
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py"
|
||||
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py tests/test_input_layout.py tests/test_dtype_boundary.py tests/test_mutation_alias_contract.py"
|
||||
CUDA_TESTS="tests/"
|
||||
|
||||
# ── Phase 1: Native Backend ─────────────────────────────────
|
||||
|
||||
@@ -16,7 +16,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
echo "Step 3: Running pytest..."
|
||||
# it is best not to add the full model tests, they end up running billion parameter models
|
||||
# on the CPU and it takes far to long
|
||||
uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
uv run pytest tests/test_hlir_ops.py tests/test_unary.py tests/test_input_layout.py tests/test_dtype_boundary.py tests/test_mutation_alias_contract.py -v
|
||||
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
|
||||
@@ -158,21 +158,17 @@ impl CompiledGraph {
|
||||
input_shape_exprs,
|
||||
dim_param_map,
|
||||
} = translation;
|
||||
let WeightData {
|
||||
weights,
|
||||
tensor_sizes,
|
||||
device_ptrs,
|
||||
} = weight_data;
|
||||
|
||||
// Build compile args from WeightData.
|
||||
// Build compile args from WeightData (convert TypedData -> raw bytes + dtype)
|
||||
let compile_args = BackendCompileArgs {
|
||||
search_iters,
|
||||
weights: weights
|
||||
weights: weight_data
|
||||
.weights
|
||||
.iter()
|
||||
.map(|(label, td)| (label.clone(), td.bytes.clone(), td.dtype))
|
||||
.collect(),
|
||||
tensor_sizes,
|
||||
device_ptrs,
|
||||
tensor_sizes: weight_data.tensor_sizes,
|
||||
device_ptrs: weight_data.device_ptrs,
|
||||
};
|
||||
|
||||
// Create backend via the factory directly
|
||||
@@ -391,7 +387,7 @@ impl CompiledGraph {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Register a weight from a device pointer (e.g. "fc1.weight"). Zero-copy on device.
|
||||
/// Set a weight from a device pointer (e.g. "fc1.weight"). Zero-copy on device.
|
||||
/// Requires a GPU backend.
|
||||
fn set_weight_device_ptr(
|
||||
&mut self,
|
||||
@@ -452,7 +448,7 @@ impl CompiledGraph {
|
||||
Ok(self.runtime.output_is_zero_copy(*node_id))
|
||||
}
|
||||
|
||||
/// Register a weight tensor from a CPU host pointer, matching by Input node label (dtype-aware).
|
||||
/// Set a weight tensor from a CPU host pointer, matching by Input node label (dtype-aware).
|
||||
/// `n_bytes` is the total byte count. `dtype_code` uses PT2 numbering (7=f32, 6=f16, 13=bf16, etc.).
|
||||
fn set_weight_from_ptr(
|
||||
&mut self,
|
||||
|
||||
@@ -3,7 +3,6 @@ pub mod typed_data;
|
||||
|
||||
// PT2 modules
|
||||
mod pt2_compiled_model;
|
||||
mod pt2_expr;
|
||||
mod pt2_parser;
|
||||
mod pt2_schema;
|
||||
mod pt2_util;
|
||||
|
||||
@@ -6,7 +6,6 @@ use pyo3::types::{PyCapsule, PyCapsuleMethods};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::compiled_graph::{CompiledGraph, DimParamMap, GraphTranslation, WeightData};
|
||||
use crate::pt2_expr::parse_sympy_expr;
|
||||
use crate::pt2_schema;
|
||||
use crate::translator;
|
||||
use crate::typed_data::TypedData;
|
||||
@@ -22,7 +21,7 @@ fn resolve_dim_sizes(
|
||||
sizes
|
||||
.iter()
|
||||
.map(|s| match s {
|
||||
pt2_schema::DimSize::Int(i) => Expression::from(i.as_int),
|
||||
pt2_schema::DimSize::Int(i) => Expression::from(i.as_int as usize),
|
||||
pt2_schema::DimSize::Expr(e) => {
|
||||
let s = e.as_expr.expr_str.trim();
|
||||
// Try the full sympy-style parse first so compound forms like
|
||||
@@ -46,7 +45,7 @@ fn resolve_dim_sizes(
|
||||
.hint
|
||||
.as_ref()
|
||||
.and_then(|h| h.as_int())
|
||||
.map(Expression::from)
|
||||
.map(|h| Expression::from(h as usize))
|
||||
})
|
||||
.unwrap_or_else(|| Expression::from(1usize))
|
||||
}
|
||||
@@ -54,6 +53,139 @@ fn resolve_dim_sizes(
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Parse a sympy `srepr`-style expression string into a luminal Expression.
|
||||
///
|
||||
/// Handles the subset of sympy heads PT2 actually emits for shape metadata:
|
||||
///
|
||||
/// * `Symbol('name', ...)` — bound to the corresponding luminal char if
|
||||
/// present in `sym_to_char`, or treated as a fresh constant 1 otherwise.
|
||||
/// * `Integer(N)` / `Number(N)` — concrete int.
|
||||
/// * `Mul(a, b, ...)` / `Add(a, b, ...)` — n-ary, folded into pairwise ops.
|
||||
///
|
||||
/// Returns `None` for anything else so the caller can fall back to a less
|
||||
/// precise representation rather than committing a wrong expression.
|
||||
fn parse_sympy_expr(s: &str, sym_to_char: &HashMap<String, char>) -> Option<Expression> {
|
||||
let s = s.trim();
|
||||
if s.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Bare integer literal — `srepr` doesn't usually emit this at the top
|
||||
// level (it wraps in `Integer(...)`), but accept it for robustness.
|
||||
if let Ok(n) = s.parse::<i64>() {
|
||||
return Some(Expression::from(n as usize));
|
||||
}
|
||||
|
||||
let (head, body) = split_head(s)?;
|
||||
match head {
|
||||
"Symbol" => {
|
||||
// Body is `'name', positive=True, integer=True` etc. Pull the
|
||||
// first quoted token as the name.
|
||||
let name = extract_first_quoted(body)?;
|
||||
sym_to_char.get(&name).map(|c| Expression::from(*c))
|
||||
}
|
||||
"Integer" | "Number" => {
|
||||
let n: i64 = body.trim().parse().ok()?;
|
||||
Some(Expression::from(n as usize))
|
||||
}
|
||||
"Mul" | "Add" => {
|
||||
let parts = split_top_level_args(body);
|
||||
if parts.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let mut iter = parts.into_iter();
|
||||
let mut acc = parse_sympy_expr(iter.next()?, sym_to_char)?;
|
||||
for p in iter {
|
||||
let rhs = parse_sympy_expr(p, sym_to_char)?;
|
||||
acc = if head == "Mul" { acc * rhs } else { acc + rhs };
|
||||
}
|
||||
Some(acc)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Split `Head(body)` into (head, body); returns None if not in that form.
|
||||
fn split_head(s: &str) -> Option<(&str, &str)> {
|
||||
let open = s.find('(')?;
|
||||
if !s.ends_with(')') {
|
||||
return None;
|
||||
}
|
||||
Some((&s[..open], &s[open + 1..s.len() - 1]))
|
||||
}
|
||||
|
||||
/// Pull out the first single- or double-quoted token from a sympy arg list,
|
||||
/// e.g. `'s77', positive=True` → `s77`.
|
||||
fn extract_first_quoted(s: &str) -> Option<String> {
|
||||
let bytes = s.as_bytes();
|
||||
let mut i = 0;
|
||||
while i < bytes.len() {
|
||||
let c = bytes[i] as char;
|
||||
if c == '\'' || c == '"' {
|
||||
let quote = c;
|
||||
let start = i + 1;
|
||||
i += 1;
|
||||
while i < bytes.len() && bytes[i] as char != quote {
|
||||
i += 1;
|
||||
}
|
||||
return Some(s[start..i].to_string());
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Split sympy-style argument list at top-level commas, respecting nested
|
||||
/// parens and quoted strings. Discards `key=value` kwargs (they don't carry
|
||||
/// dimensional information).
|
||||
fn split_top_level_args(s: &str) -> Vec<&str> {
|
||||
let mut out = Vec::new();
|
||||
let bytes = s.as_bytes();
|
||||
let mut depth = 0;
|
||||
let mut in_quote: Option<char> = None;
|
||||
let mut start = 0;
|
||||
for (i, &b) in bytes.iter().enumerate() {
|
||||
let c = b as char;
|
||||
match in_quote {
|
||||
Some(q) => {
|
||||
if c == q {
|
||||
in_quote = None;
|
||||
}
|
||||
}
|
||||
None => match c {
|
||||
'\'' | '"' => in_quote = Some(c),
|
||||
'(' | '[' => depth += 1,
|
||||
')' | ']' => depth -= 1,
|
||||
',' if depth == 0 => {
|
||||
let part = s[start..i].trim();
|
||||
// Drop `key=value` kwargs — they're metadata sympy uses
|
||||
// for pretty-printing, not arguments to the operator.
|
||||
if !part.is_empty() && !looks_like_kwarg(part) {
|
||||
out.push(part);
|
||||
}
|
||||
start = i + 1;
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
}
|
||||
}
|
||||
let part = s[start..].trim();
|
||||
if !part.is_empty() && !looks_like_kwarg(part) {
|
||||
out.push(part);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn looks_like_kwarg(part: &str) -> bool {
|
||||
if let Some(eq) = part.find('=') {
|
||||
let key = part[..eq].trim();
|
||||
// sympy kwargs are bare identifiers like `positive`, `integer`.
|
||||
!key.is_empty() && key.chars().all(|c| c.is_ascii_alphanumeric() || c == '_')
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (pt2_path, weights_path, search_iters, factory_capsule, weight_device_ptrs=None))]
|
||||
pub fn process_pt2(
|
||||
|
||||
@@ -1,713 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::prelude::*;
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::pt2_parser::SymDimMap;
|
||||
use crate::pt2_schema::RangeConstraint;
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
pub(crate) struct ExprBounds {
|
||||
pub(crate) min: Option<i64>,
|
||||
pub(crate) max: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
struct ParsedExpr {
|
||||
expr: Expression,
|
||||
bounds: ExprBounds,
|
||||
}
|
||||
|
||||
impl ParsedExpr {
|
||||
fn exact(expr: Expression, value: i64) -> Self {
|
||||
Self {
|
||||
expr,
|
||||
bounds: ExprBounds {
|
||||
min: Some(value),
|
||||
max: Some(value),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
struct BoundedExpr {
|
||||
expr: Expression,
|
||||
bounds: ExprBounds,
|
||||
}
|
||||
|
||||
/// Parse a sympy `srepr`-style expression string into a luminal `Expression`.
|
||||
///
|
||||
/// Supports the subset of sympy heads PT2 emits for symbolic shape metadata.
|
||||
pub(crate) fn parse_sympy_expr(
|
||||
expr: &str,
|
||||
sym_to_char: &HashMap<String, char>,
|
||||
) -> Option<Expression> {
|
||||
parse_sympy_expr_with_ranges(expr, sym_to_char, &HashMap::new())
|
||||
}
|
||||
|
||||
pub(crate) fn parse_sympy_expr_with_ranges(
|
||||
expr: &str,
|
||||
sym_to_char: &HashMap<String, char>,
|
||||
ranges: &HashMap<String, RangeConstraint>,
|
||||
) -> Option<Expression> {
|
||||
parse_sympy_expr_inner(expr, sym_to_char, ranges).map(|parsed| parsed.expr)
|
||||
}
|
||||
|
||||
pub(crate) fn sym_char_ranges(sym_map: &SymDimMap) -> FxHashMap<char, ExprBounds> {
|
||||
sym_map
|
||||
.sym_to_char
|
||||
.iter()
|
||||
.map(|(sym_name, sym_char)| {
|
||||
let range = sym_map.ranges.get(sym_name);
|
||||
let min = range
|
||||
.and_then(|range| range.min_val)
|
||||
.map(|min| min.max(0))
|
||||
.or(Some(0));
|
||||
let max = range
|
||||
.and_then(|range| range.max_val)
|
||||
.filter(|max| *max >= 0);
|
||||
(*sym_char, ExprBounds { min, max })
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub(crate) fn simplify_expr_with_ranges(
|
||||
expr: Expression,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) -> Expression {
|
||||
simplify_bound_expr(expr, sym_ranges).expr
|
||||
}
|
||||
|
||||
pub(crate) fn same_expr_with_ranges(
|
||||
lhs: Expression,
|
||||
rhs: Expression,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) -> bool {
|
||||
let lhs = simplify_bound_expr(lhs, sym_ranges);
|
||||
let rhs = simplify_bound_expr(rhs, sym_ranges);
|
||||
lhs.expr == rhs.expr
|
||||
|| lhs.expr.egglog_equal(rhs.expr)
|
||||
|| (exact_value(lhs) == exact_value(rhs) && exact_value(lhs).is_some())
|
||||
}
|
||||
|
||||
pub(crate) fn canonical_equal_expr(
|
||||
lhs: Expression,
|
||||
rhs: Expression,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) -> Option<Expression> {
|
||||
if !same_expr_with_ranges(lhs, rhs, sym_ranges) {
|
||||
return None;
|
||||
}
|
||||
let lhs_simplified = simplify_expr_with_ranges(lhs, sym_ranges);
|
||||
let rhs_simplified = simplify_expr_with_ranges(rhs, sym_ranges);
|
||||
Some(if lhs_simplified.len() <= rhs_simplified.len() {
|
||||
lhs_simplified
|
||||
} else {
|
||||
rhs_simplified
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_sympy_expr_inner(
|
||||
expr: &str,
|
||||
sym_to_char: &HashMap<String, char>,
|
||||
ranges: &HashMap<String, RangeConstraint>,
|
||||
) -> Option<ParsedExpr> {
|
||||
let expr = expr.trim();
|
||||
if expr.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Ok(value) = expr.parse::<i64>() {
|
||||
return Some(ParsedExpr::exact(Expression::from(value), value));
|
||||
}
|
||||
|
||||
let (head, body) = split_head(expr)?;
|
||||
match head {
|
||||
"Symbol" => {
|
||||
let name = extract_first_quoted(body)?;
|
||||
let bounds = infer_symbol_bounds(body, ranges.get(&name));
|
||||
sym_to_char.get(&name).map(|c| ParsedExpr {
|
||||
expr: Expression::from(*c),
|
||||
bounds,
|
||||
})
|
||||
}
|
||||
"Integer" | "Number" => {
|
||||
let value = body.trim().parse::<i64>().ok()?;
|
||||
Some(ParsedExpr::exact(Expression::from(value), value))
|
||||
}
|
||||
"NegativeOne" => Some(ParsedExpr::exact(Expression::from(-1i64), -1)),
|
||||
"Zero" => Some(ParsedExpr::exact(Expression::from(0i64), 0)),
|
||||
"One" => Some(ParsedExpr::exact(Expression::from(1i64), 1)),
|
||||
"Mul" | "Add" | "Min" | "Max" => {
|
||||
let parts = split_top_level_args(body);
|
||||
if parts.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let mut iter = parts.into_iter();
|
||||
let mut acc = parse_sympy_expr_inner(iter.next()?, sym_to_char, ranges)?;
|
||||
for part in iter {
|
||||
let rhs = parse_sympy_expr_inner(part, sym_to_char, ranges)?;
|
||||
acc = match head {
|
||||
"Mul" => ParsedExpr {
|
||||
expr: normalize_mul_expr(acc.expr, rhs.expr),
|
||||
bounds: mul_bounds(acc.bounds, rhs.bounds),
|
||||
},
|
||||
"Add" => ParsedExpr {
|
||||
expr: normalize_add_expr(acc.expr, rhs.expr),
|
||||
bounds: add_bounds(acc.bounds, rhs.bounds),
|
||||
},
|
||||
"Min" => reduce_min(acc, rhs),
|
||||
"Max" => reduce_max(acc, rhs),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
}
|
||||
Some(acc)
|
||||
}
|
||||
"FloorDiv" => {
|
||||
let mut parts = split_top_level_args(body).into_iter();
|
||||
let lhs = parse_sympy_expr_inner(parts.next()?, sym_to_char, ranges)?;
|
||||
let rhs = parse_sympy_expr_inner(parts.next()?, sym_to_char, ranges)?;
|
||||
if parts.next().is_some() {
|
||||
return None;
|
||||
}
|
||||
Some(ParsedExpr {
|
||||
expr: lhs.expr / rhs.expr,
|
||||
bounds: div_bounds(lhs.bounds, rhs.bounds),
|
||||
})
|
||||
}
|
||||
"Mod" => {
|
||||
let mut parts = split_top_level_args(body).into_iter();
|
||||
let lhs = parse_sympy_expr_inner(parts.next()?, sym_to_char, ranges)?;
|
||||
let rhs = parse_sympy_expr_inner(parts.next()?, sym_to_char, ranges)?;
|
||||
if parts.next().is_some() {
|
||||
return None;
|
||||
}
|
||||
Some(ParsedExpr {
|
||||
expr: lhs.expr % rhs.expr,
|
||||
bounds: mod_bounds(lhs.bounds, rhs.bounds),
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_symbol_bounds(body: &str, range: Option<&RangeConstraint>) -> ExprBounds {
|
||||
let mut bounds = ExprBounds::default();
|
||||
if body.contains("positive=True") {
|
||||
bounds.min = Some(1);
|
||||
} else if body.contains("nonnegative=True") {
|
||||
bounds.min = Some(0);
|
||||
}
|
||||
if let Some(range) = range {
|
||||
bounds.min = match (bounds.min, range.min_val) {
|
||||
(Some(lhs), Some(rhs)) => Some(lhs.max(rhs)),
|
||||
(None, Some(rhs)) => Some(rhs),
|
||||
(lhs, None) => lhs,
|
||||
};
|
||||
bounds.max = range.max_val;
|
||||
}
|
||||
bounds
|
||||
}
|
||||
|
||||
fn exact_expr(value: i64) -> BoundedExpr {
|
||||
BoundedExpr {
|
||||
expr: Expression::from(value),
|
||||
bounds: ExprBounds {
|
||||
min: Some(value),
|
||||
max: Some(value),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn exact_value(expr: BoundedExpr) -> Option<i64> {
|
||||
expr.expr.as_num().or({
|
||||
(expr.bounds.min == expr.bounds.max)
|
||||
.then_some(expr.bounds.min)
|
||||
.flatten()
|
||||
})
|
||||
}
|
||||
|
||||
fn exact_bound_value(bounds: ExprBounds) -> Option<i64> {
|
||||
(bounds.min == bounds.max).then_some(bounds.min).flatten()
|
||||
}
|
||||
|
||||
fn with_bounds(expr: Expression, bounds: ExprBounds) -> BoundedExpr {
|
||||
BoundedExpr { expr, bounds }
|
||||
}
|
||||
|
||||
fn bool_bounds() -> ExprBounds {
|
||||
ExprBounds {
|
||||
min: Some(0),
|
||||
max: Some(1),
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_expr(expr: Expression) -> Expression {
|
||||
if expr.len() <= 16 {
|
||||
expr.simplify()
|
||||
} else {
|
||||
expr
|
||||
}
|
||||
}
|
||||
|
||||
fn commutative_key(expr: Expression) -> (usize, String) {
|
||||
(expr.len(), format!("{expr:?}"))
|
||||
}
|
||||
|
||||
fn sort_commutative(lhs: Expression, rhs: Expression) -> (Expression, Expression) {
|
||||
if commutative_key(lhs) <= commutative_key(rhs) {
|
||||
(lhs, rhs)
|
||||
} else {
|
||||
(rhs, lhs)
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_add_expr(lhs: Expression, rhs: Expression) -> Expression {
|
||||
let (lhs, rhs) = sort_commutative(lhs, rhs);
|
||||
normalize_expr(lhs + rhs)
|
||||
}
|
||||
|
||||
fn normalize_mul_expr(lhs: Expression, rhs: Expression) -> Expression {
|
||||
let (lhs, rhs) = sort_commutative(lhs, rhs);
|
||||
normalize_expr(lhs * rhs)
|
||||
}
|
||||
|
||||
fn checked_add_opt(lhs: Option<i64>, rhs: Option<i64>) -> Option<i64> {
|
||||
lhs.zip(rhs).and_then(|(lhs, rhs)| lhs.checked_add(rhs))
|
||||
}
|
||||
|
||||
fn checked_sub_opt(lhs: Option<i64>, rhs: Option<i64>) -> Option<i64> {
|
||||
lhs.zip(rhs).and_then(|(lhs, rhs)| lhs.checked_sub(rhs))
|
||||
}
|
||||
|
||||
fn checked_mul_opt(lhs: Option<i64>, rhs: Option<i64>) -> Option<i64> {
|
||||
lhs.zip(rhs).and_then(|(lhs, rhs)| lhs.checked_mul(rhs))
|
||||
}
|
||||
|
||||
fn add_bounds(lhs: ExprBounds, rhs: ExprBounds) -> ExprBounds {
|
||||
ExprBounds {
|
||||
min: checked_add_opt(lhs.min, rhs.min),
|
||||
max: checked_add_opt(lhs.max, rhs.max),
|
||||
}
|
||||
}
|
||||
|
||||
fn mul_bounds(lhs: ExprBounds, rhs: ExprBounds) -> ExprBounds {
|
||||
if lhs.min.unwrap_or(i64::MIN) >= 0 && rhs.min.unwrap_or(i64::MIN) >= 0 {
|
||||
return ExprBounds {
|
||||
min: checked_mul_opt(lhs.min, rhs.min),
|
||||
max: checked_mul_opt(lhs.max, rhs.max),
|
||||
};
|
||||
}
|
||||
ExprBounds::default()
|
||||
}
|
||||
|
||||
fn sub_bounds(lhs: ExprBounds, rhs: ExprBounds) -> ExprBounds {
|
||||
ExprBounds {
|
||||
min: checked_sub_opt(lhs.min, rhs.max),
|
||||
max: checked_sub_opt(lhs.max, rhs.min),
|
||||
}
|
||||
}
|
||||
|
||||
fn div_bounds(lhs: ExprBounds, rhs: ExprBounds) -> ExprBounds {
|
||||
let (Some(rhs_min), Some(rhs_max)) = (rhs.min, rhs.max) else {
|
||||
return ExprBounds::default();
|
||||
};
|
||||
if rhs_min <= 0 || rhs_max <= 0 {
|
||||
return ExprBounds::default();
|
||||
}
|
||||
ExprBounds {
|
||||
min: lhs.min.and_then(|lhs_min| lhs_min.checked_div(rhs_max)),
|
||||
max: lhs.max.and_then(|lhs_max| lhs_max.checked_div(rhs_min)),
|
||||
}
|
||||
}
|
||||
|
||||
fn mod_bounds(lhs: ExprBounds, rhs: ExprBounds) -> ExprBounds {
|
||||
if lhs.min.unwrap_or(i64::MIN) < 0 {
|
||||
return ExprBounds::default();
|
||||
}
|
||||
match exact_bound_value(rhs) {
|
||||
Some(rhs_exact) if rhs_exact > 0 => ExprBounds {
|
||||
min: Some(0),
|
||||
max: rhs_exact.checked_sub(1),
|
||||
},
|
||||
_ => ExprBounds::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn reduce_min(lhs: ParsedExpr, rhs: ParsedExpr) -> ParsedExpr {
|
||||
if lhs.expr == rhs.expr || lhs.expr.egglog_equal(rhs.expr) {
|
||||
return ParsedExpr {
|
||||
expr: lhs.expr,
|
||||
bounds: min_bounds(lhs.bounds, rhs.bounds),
|
||||
};
|
||||
}
|
||||
if let (Some(lhs_max), Some(rhs_min)) = (lhs.bounds.max, rhs.bounds.min)
|
||||
&& lhs_max <= rhs_min
|
||||
{
|
||||
return lhs;
|
||||
}
|
||||
if let (Some(rhs_max), Some(lhs_min)) = (rhs.bounds.max, lhs.bounds.min)
|
||||
&& rhs_max <= lhs_min
|
||||
{
|
||||
return rhs;
|
||||
}
|
||||
if expr_is_offset_by_small_const(lhs.expr, rhs.expr) {
|
||||
return rhs;
|
||||
}
|
||||
if expr_is_offset_by_small_const(rhs.expr, lhs.expr) {
|
||||
return lhs;
|
||||
}
|
||||
ParsedExpr {
|
||||
expr: lhs.expr.min(rhs.expr),
|
||||
bounds: min_bounds(lhs.bounds, rhs.bounds),
|
||||
}
|
||||
}
|
||||
|
||||
fn reduce_max(lhs: ParsedExpr, rhs: ParsedExpr) -> ParsedExpr {
|
||||
if lhs.expr == rhs.expr || lhs.expr.egglog_equal(rhs.expr) {
|
||||
return ParsedExpr {
|
||||
expr: lhs.expr,
|
||||
bounds: max_bounds(lhs.bounds, rhs.bounds),
|
||||
};
|
||||
}
|
||||
if let (Some(lhs_max), Some(rhs_min)) = (lhs.bounds.max, rhs.bounds.min)
|
||||
&& lhs_max <= rhs_min
|
||||
{
|
||||
return rhs;
|
||||
}
|
||||
if let (Some(rhs_max), Some(lhs_min)) = (rhs.bounds.max, lhs.bounds.min)
|
||||
&& rhs_max <= lhs_min
|
||||
{
|
||||
return lhs;
|
||||
}
|
||||
if expr_is_offset_by_small_const(lhs.expr, rhs.expr) {
|
||||
return lhs;
|
||||
}
|
||||
if expr_is_offset_by_small_const(rhs.expr, lhs.expr) {
|
||||
return rhs;
|
||||
}
|
||||
ParsedExpr {
|
||||
expr: lhs.expr.max(rhs.expr),
|
||||
bounds: max_bounds(lhs.bounds, rhs.bounds),
|
||||
}
|
||||
}
|
||||
|
||||
fn min_bounds(lhs: ExprBounds, rhs: ExprBounds) -> ExprBounds {
|
||||
ExprBounds {
|
||||
min: match (lhs.min, rhs.min) {
|
||||
(Some(lhs), Some(rhs)) => Some(lhs.min(rhs)),
|
||||
_ => None,
|
||||
},
|
||||
max: match (lhs.max, rhs.max) {
|
||||
(Some(lhs), Some(rhs)) => Some(lhs.min(rhs)),
|
||||
_ => None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn max_bounds(lhs: ExprBounds, rhs: ExprBounds) -> ExprBounds {
|
||||
ExprBounds {
|
||||
min: match (lhs.min, rhs.min) {
|
||||
(Some(lhs), Some(rhs)) => Some(lhs.max(rhs)),
|
||||
_ => None,
|
||||
},
|
||||
max: match (lhs.max, rhs.max) {
|
||||
(Some(lhs), Some(rhs)) => Some(lhs.max(rhs)),
|
||||
_ => None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn expr_is_offset_by_small_const(lhs: Expression, rhs: Expression) -> bool {
|
||||
(1..=8).any(|delta| lhs.egglog_equal(rhs + delta))
|
||||
}
|
||||
|
||||
fn split_add_const(expr: Expression) -> Option<(i64, Expression)> {
|
||||
let terms = expr.terms.read();
|
||||
if terms.len() >= 3 && terms.last() == Some(&Term::Add) {
|
||||
if let Some(Term::Num(n)) = terms.first() {
|
||||
return Some((*n, Expression::new(terms[1..terms.len() - 1].to_vec())));
|
||||
}
|
||||
if let Some(Term::Num(n)) = terms.get(terms.len() - 2) {
|
||||
return Some((*n, Expression::new(terms[..terms.len() - 2].to_vec())));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn simplify_add(lhs: BoundedExpr, rhs: BoundedExpr) -> BoundedExpr {
|
||||
let expr = match (exact_value(lhs), exact_value(rhs)) {
|
||||
(Some(0), _) => rhs.expr,
|
||||
(_, Some(0)) => lhs.expr,
|
||||
(Some(lhs), Some(rhs)) => Expression::from(lhs + rhs),
|
||||
(_, Some(rhs)) => normalize_add_expr(lhs.expr, Expression::from(rhs)),
|
||||
(Some(lhs), _) => normalize_add_expr(Expression::from(lhs), rhs.expr),
|
||||
_ => normalize_add_expr(lhs.expr, rhs.expr),
|
||||
};
|
||||
with_bounds(expr, add_bounds(lhs.bounds, rhs.bounds))
|
||||
}
|
||||
|
||||
fn simplify_sub(
|
||||
lhs: BoundedExpr,
|
||||
rhs: BoundedExpr,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) -> BoundedExpr {
|
||||
if same_expr_with_ranges(lhs.expr, rhs.expr, sym_ranges) {
|
||||
return exact_expr(0);
|
||||
}
|
||||
let expr = match exact_value(rhs) {
|
||||
Some(0) => lhs.expr,
|
||||
Some(rhs_const) => {
|
||||
if let Some((lhs_const, lhs_base)) = split_add_const(lhs.expr) {
|
||||
normalize_expr(lhs_base + (lhs_const - rhs_const))
|
||||
} else {
|
||||
normalize_expr(lhs.expr - rhs_const)
|
||||
}
|
||||
}
|
||||
None => normalize_expr(lhs.expr - rhs.expr),
|
||||
};
|
||||
with_bounds(expr, sub_bounds(lhs.bounds, rhs.bounds))
|
||||
}
|
||||
|
||||
fn simplify_min(
|
||||
lhs: BoundedExpr,
|
||||
rhs: BoundedExpr,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) -> BoundedExpr {
|
||||
let bounds = min_bounds(lhs.bounds, rhs.bounds);
|
||||
if same_expr_with_ranges(lhs.expr, rhs.expr, sym_ranges) {
|
||||
return with_bounds(lhs.expr, bounds);
|
||||
}
|
||||
if let (Some(lhs_max), Some(rhs_min)) = (lhs.bounds.max, rhs.bounds.min)
|
||||
&& lhs_max <= rhs_min
|
||||
{
|
||||
return with_bounds(lhs.expr, bounds);
|
||||
}
|
||||
if let (Some(rhs_max), Some(lhs_min)) = (rhs.bounds.max, lhs.bounds.min)
|
||||
&& rhs_max <= lhs_min
|
||||
{
|
||||
return with_bounds(rhs.expr, bounds);
|
||||
}
|
||||
if let Some((lhs_const, lhs_base)) = split_add_const(lhs.expr)
|
||||
&& lhs_const >= 0
|
||||
&& same_expr_with_ranges(lhs_base, rhs.expr, sym_ranges)
|
||||
{
|
||||
return with_bounds(rhs.expr, bounds);
|
||||
}
|
||||
if let Some((rhs_const, rhs_base)) = split_add_const(rhs.expr)
|
||||
&& rhs_const >= 0
|
||||
&& same_expr_with_ranges(rhs_base, lhs.expr, sym_ranges)
|
||||
{
|
||||
return with_bounds(lhs.expr, bounds);
|
||||
}
|
||||
with_bounds(normalize_expr(lhs.expr.min(rhs.expr)), bounds)
|
||||
}
|
||||
|
||||
fn simplify_max(
|
||||
lhs: BoundedExpr,
|
||||
rhs: BoundedExpr,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) -> BoundedExpr {
|
||||
let bounds = max_bounds(lhs.bounds, rhs.bounds);
|
||||
if same_expr_with_ranges(lhs.expr, rhs.expr, sym_ranges) {
|
||||
return with_bounds(lhs.expr, bounds);
|
||||
}
|
||||
if let (Some(lhs_max), Some(rhs_min)) = (lhs.bounds.max, rhs.bounds.min)
|
||||
&& lhs_max <= rhs_min
|
||||
{
|
||||
return with_bounds(rhs.expr, bounds);
|
||||
}
|
||||
if let (Some(rhs_max), Some(lhs_min)) = (rhs.bounds.max, lhs.bounds.min)
|
||||
&& rhs_max <= lhs_min
|
||||
{
|
||||
return with_bounds(lhs.expr, bounds);
|
||||
}
|
||||
if let Some((lhs_const, lhs_base)) = split_add_const(lhs.expr)
|
||||
&& lhs_const >= 0
|
||||
&& same_expr_with_ranges(lhs_base, rhs.expr, sym_ranges)
|
||||
{
|
||||
return with_bounds(lhs.expr, bounds);
|
||||
}
|
||||
if let Some((rhs_const, rhs_base)) = split_add_const(rhs.expr)
|
||||
&& rhs_const >= 0
|
||||
&& same_expr_with_ranges(rhs_base, lhs.expr, sym_ranges)
|
||||
{
|
||||
return with_bounds(rhs.expr, bounds);
|
||||
}
|
||||
with_bounds(normalize_expr(lhs.expr.max(rhs.expr)), bounds)
|
||||
}
|
||||
|
||||
fn simplify_bound_expr(expr: Expression, sym_ranges: &FxHashMap<char, ExprBounds>) -> BoundedExpr {
|
||||
let mut stack: Vec<BoundedExpr> = Vec::new();
|
||||
let terms = expr.terms.read().clone();
|
||||
for term in terms {
|
||||
match term {
|
||||
Term::Num(n) => stack.push(exact_expr(n)),
|
||||
Term::Var(c) => stack.push(with_bounds(
|
||||
Expression::from(c),
|
||||
sym_ranges.get(&c).copied().unwrap_or_default(),
|
||||
)),
|
||||
Term::Add => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
stack.push(simplify_add(lhs, rhs));
|
||||
}
|
||||
Term::Sub => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
stack.push(simplify_sub(lhs, rhs, sym_ranges));
|
||||
}
|
||||
Term::Mul => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
let expr = match (exact_value(lhs), exact_value(rhs)) {
|
||||
(Some(0), _) | (_, Some(0)) => Expression::from(0),
|
||||
(Some(1), _) => rhs.expr,
|
||||
(_, Some(1)) => lhs.expr,
|
||||
(Some(lhs), Some(rhs)) => Expression::from(lhs * rhs),
|
||||
_ => normalize_mul_expr(lhs.expr, rhs.expr),
|
||||
};
|
||||
stack.push(with_bounds(expr, mul_bounds(lhs.bounds, rhs.bounds)));
|
||||
}
|
||||
Term::Div | Term::CeilDiv => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
let expr = match (term, exact_value(lhs), exact_value(rhs)) {
|
||||
(_, Some(0), _) => Expression::from(0),
|
||||
(_, _, Some(1)) => lhs.expr,
|
||||
(Term::Div, Some(lhs), Some(rhs)) if rhs != 0 => Expression::from(lhs / rhs),
|
||||
(Term::CeilDiv, Some(lhs), Some(rhs)) if rhs > 0 => {
|
||||
Expression::from(if lhs % rhs != 0 {
|
||||
lhs / rhs + 1
|
||||
} else {
|
||||
lhs / rhs
|
||||
})
|
||||
}
|
||||
(Term::Div, _, _) => normalize_expr(lhs.expr / rhs.expr),
|
||||
(Term::CeilDiv, _, _) => normalize_expr(lhs.expr.ceil_div(rhs.expr)),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
stack.push(with_bounds(expr, div_bounds(lhs.bounds, rhs.bounds)));
|
||||
}
|
||||
Term::Mod => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
let expr = match (exact_value(lhs), exact_value(rhs)) {
|
||||
(Some(0), _) | (_, Some(1)) => Expression::from(0),
|
||||
(Some(lhs), Some(rhs)) if rhs != 0 => Expression::from(lhs % rhs),
|
||||
_ => normalize_expr(lhs.expr % rhs.expr),
|
||||
};
|
||||
stack.push(with_bounds(expr, mod_bounds(lhs.bounds, rhs.bounds)));
|
||||
}
|
||||
Term::Min => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
stack.push(simplify_min(lhs, rhs, sym_ranges));
|
||||
}
|
||||
Term::Max => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
stack.push(simplify_max(lhs, rhs, sym_ranges));
|
||||
}
|
||||
term @ (Term::And | Term::Or | Term::Gte | Term::Lt) => {
|
||||
let lhs = stack.pop().unwrap();
|
||||
let rhs = stack.pop().unwrap();
|
||||
let expr = match (term, exact_value(lhs), exact_value(rhs)) {
|
||||
(Term::And, Some(lhs), Some(rhs)) => {
|
||||
Expression::from((lhs != 0 && rhs != 0) as i64)
|
||||
}
|
||||
(Term::And, _, _) => normalize_expr(lhs.expr & rhs.expr),
|
||||
(Term::Or, Some(lhs), Some(rhs)) => {
|
||||
Expression::from((lhs != 0 || rhs != 0) as i64)
|
||||
}
|
||||
(Term::Or, _, _) => normalize_expr(lhs.expr | rhs.expr),
|
||||
(Term::Gte, Some(lhs), Some(rhs)) => Expression::from((lhs >= rhs) as i64),
|
||||
(Term::Gte, _, _) => normalize_expr(lhs.expr.gte(rhs.expr)),
|
||||
(Term::Lt, Some(lhs), Some(rhs)) => Expression::from((lhs < rhs) as i64),
|
||||
(Term::Lt, _, _) => normalize_expr(lhs.expr.lt(rhs.expr)),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
stack.push(with_bounds(expr, bool_bounds()));
|
||||
}
|
||||
}
|
||||
}
|
||||
stack
|
||||
.pop()
|
||||
.unwrap_or(with_bounds(expr, ExprBounds::default()))
|
||||
}
|
||||
|
||||
/// Split `Head(body)` into `(head, body)`.
|
||||
fn split_head(expr: &str) -> Option<(&str, &str)> {
|
||||
let open = expr.find('(')?;
|
||||
if !expr.ends_with(')') {
|
||||
return None;
|
||||
}
|
||||
Some((&expr[..open], &expr[open + 1..expr.len() - 1]))
|
||||
}
|
||||
|
||||
/// Pull out the first single- or double-quoted token from a sympy arg list.
|
||||
fn extract_first_quoted(expr: &str) -> Option<String> {
|
||||
let bytes = expr.as_bytes();
|
||||
let mut i = 0;
|
||||
while i < bytes.len() {
|
||||
let c = bytes[i] as char;
|
||||
if c == '\'' || c == '"' {
|
||||
let quote = c;
|
||||
let start = i + 1;
|
||||
i += 1;
|
||||
while i < bytes.len() && bytes[i] as char != quote {
|
||||
i += 1;
|
||||
}
|
||||
return Some(expr[start..i].to_string());
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Split a sympy-style argument list at top-level commas, respecting nested
|
||||
/// parens and quoted strings. Drops `key=value` kwargs.
|
||||
fn split_top_level_args(expr: &str) -> Vec<&str> {
|
||||
let mut out = Vec::new();
|
||||
let bytes = expr.as_bytes();
|
||||
let mut depth = 0;
|
||||
let mut in_quote: Option<char> = None;
|
||||
let mut start = 0;
|
||||
for (i, &b) in bytes.iter().enumerate() {
|
||||
let c = b as char;
|
||||
match in_quote {
|
||||
Some(q) => {
|
||||
if c == q {
|
||||
in_quote = None;
|
||||
}
|
||||
}
|
||||
None => match c {
|
||||
'\'' | '"' => in_quote = Some(c),
|
||||
'(' | '[' => depth += 1,
|
||||
')' | ']' => depth -= 1,
|
||||
',' if depth == 0 => {
|
||||
let part = expr[start..i].trim();
|
||||
if !part.is_empty() && !looks_like_kwarg(part) {
|
||||
out.push(part);
|
||||
}
|
||||
start = i + 1;
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
}
|
||||
}
|
||||
let part = expr[start..].trim();
|
||||
if !part.is_empty() && !looks_like_kwarg(part) {
|
||||
out.push(part);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn looks_like_kwarg(part: &str) -> bool {
|
||||
if let Some(eq) = part.find('=') {
|
||||
let key = part[..eq].trim();
|
||||
return !key.is_empty() && key.chars().all(|c| c == '_' || c.is_ascii_alphanumeric());
|
||||
}
|
||||
false
|
||||
}
|
||||
@@ -1,9 +1,5 @@
|
||||
use luminal::prelude::*;
|
||||
|
||||
fn same_dim(lhs: Expression, rhs: Expression) -> bool {
|
||||
lhs == rhs || lhs.simplify() == rhs.simplify() || lhs.egglog_equal(rhs)
|
||||
}
|
||||
|
||||
/// Binary operation type.
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum BinaryOp {
|
||||
@@ -55,7 +51,7 @@ pub fn broadcast_binary(mut a: GraphTensor, mut b: GraphTensor) -> (GraphTensor,
|
||||
let a_dim = a.shape.dims[i];
|
||||
let b_dim = b.shape.dims[i];
|
||||
|
||||
if same_dim(a_dim, b_dim) {
|
||||
if a_dim == b_dim {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,40 +1,11 @@
|
||||
use anyhow::Result;
|
||||
use luminal::prelude::*;
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::pt2_expr::{ExprBounds, canonical_equal_expr, same_expr_with_ranges, sym_char_ranges};
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
fn normalize_equal_dims(
|
||||
a: &mut GraphTensor,
|
||||
b: &mut GraphTensor,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) {
|
||||
for i in 0..a.shape.len() {
|
||||
let lhs = a.shape.dims[i];
|
||||
let rhs = b.shape.dims[i];
|
||||
if let Some(canonical) = canonical_equal_expr(lhs, rhs, sym_ranges) {
|
||||
a.shape.dims[i] = canonical;
|
||||
b.shape.dims[i] = canonical;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn same_dims(
|
||||
lhs: &[Expression],
|
||||
rhs: &[Expression],
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) -> bool {
|
||||
lhs.len() == rhs.len()
|
||||
&& lhs
|
||||
.iter()
|
||||
.zip(rhs.iter())
|
||||
.all(|(lhs, rhs)| same_expr_with_ranges(*lhs, *rhs, sym_ranges))
|
||||
}
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_binary_op(&mut self, node: &Node, op: BinaryOp) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
@@ -42,18 +13,7 @@ impl<'a> Translator<'a> {
|
||||
if let Some(name) = arg1.as_tensor_name() {
|
||||
let b = self.get_tensor(name)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (mut a, mut b) = broadcast_binary(a, b);
|
||||
let sym_ranges = sym_char_ranges(&self.sym_map);
|
||||
normalize_equal_dims(&mut a, &mut b, &sym_ranges);
|
||||
let lhs_dims = a.dims();
|
||||
let rhs_dims = b.dims();
|
||||
if !same_dims(&lhs_dims, &rhs_dims, &sym_ranges) {
|
||||
anyhow::bail!(
|
||||
"binary op {} still has mismatched dims after broadcast: lhs={lhs_dims:?} rhs={rhs_dims:?} inputs={:?}",
|
||||
node.target,
|
||||
node.inputs
|
||||
);
|
||||
}
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
Ok(match op {
|
||||
BinaryOp::Add => a + b,
|
||||
BinaryOp::Mul => a * b,
|
||||
@@ -61,12 +21,6 @@ impl<'a> Translator<'a> {
|
||||
BinaryOp::Div => a / b,
|
||||
})
|
||||
} else {
|
||||
if let Some(f) = arg1.as_float() {
|
||||
return Ok(self.apply_scalar_op(a, f as f32, op));
|
||||
}
|
||||
if let Some(expr) = self.resolve_arg_as_expression(arg1) {
|
||||
return Ok(self.apply_symbolic_scalar_op(a, expr, op));
|
||||
}
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
Ok(self.apply_scalar_op(a, val, op))
|
||||
}
|
||||
@@ -78,13 +32,6 @@ impl<'a> Translator<'a> {
|
||||
op: BinaryOp,
|
||||
) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let arg1 = &node.inputs[1].arg;
|
||||
if let Some(f) = arg1.as_float() {
|
||||
return Ok(self.apply_scalar_op(a, f as f32, op));
|
||||
}
|
||||
if let Some(expr) = self.resolve_arg_as_expression(arg1) {
|
||||
return Ok(self.apply_symbolic_scalar_op(a, expr, op));
|
||||
}
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
Ok(self.apply_scalar_op(a, val, op))
|
||||
}
|
||||
@@ -107,47 +54,4 @@ impl<'a> Translator<'a> {
|
||||
BinaryOp::Div => a / scalar,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn apply_symbolic_scalar_op(
|
||||
&mut self,
|
||||
a: GraphTensor,
|
||||
val: Expression,
|
||||
op: BinaryOp,
|
||||
) -> GraphTensor {
|
||||
match op {
|
||||
BinaryOp::Add => a + val,
|
||||
BinaryOp::Mul => a * val,
|
||||
BinaryOp::Sub => a - val,
|
||||
BinaryOp::Div => a / val,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::pt2_expr::simplify_expr_with_ranges;
|
||||
|
||||
#[test]
|
||||
fn simplifies_mark_dynamic_slice_shapes_using_lower_bound() {
|
||||
let a = Expression::from('a');
|
||||
let lhs = (a.min(1) + a).min(a + 1) - 1;
|
||||
let rhs = (a.min(1) + a).min(a);
|
||||
let sym_ranges = [(
|
||||
'a',
|
||||
ExprBounds {
|
||||
min: Some(2),
|
||||
max: None,
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect::<FxHashMap<_, _>>();
|
||||
|
||||
let lhs_simplified = simplify_expr_with_ranges(lhs, &sym_ranges);
|
||||
let rhs_simplified = simplify_expr_with_ranges(rhs, &sym_ranges);
|
||||
|
||||
assert_eq!(lhs_simplified, Expression::from('a'));
|
||||
assert_eq!(rhs_simplified, Expression::from('a'));
|
||||
assert!(same_expr_with_ranges(lhs, rhs, &sym_ranges));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ use anyhow::{Context, Result};
|
||||
use luminal::graph::Graph;
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_expr::parse_sympy_expr_with_ranges;
|
||||
use crate::pt2_parser::{InputKind, ParsedPT2, SymDimMap};
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util;
|
||||
@@ -280,13 +279,13 @@ impl<'a> Translator<'a> {
|
||||
.with_context(|| format!("Node {} missing input {idx}", node.target))?
|
||||
.arg;
|
||||
if let Some(ints) = arg.as_ints() {
|
||||
return Ok(ints.iter().map(|&v| Expression::from(v)).collect());
|
||||
return Ok(ints.iter().map(|&v| Expression::from(v as usize)).collect());
|
||||
}
|
||||
if let Some(entries) = arg.as_sym_ints() {
|
||||
return entries
|
||||
.iter()
|
||||
.map(|entry| match entry {
|
||||
SymIntEntry::Int(i) => Ok(Expression::from(i.as_int)),
|
||||
SymIntEntry::Int(i) => Ok(Expression::from(i.as_int as usize)),
|
||||
SymIntEntry::Name(s) => self
|
||||
.resolve_sym_int(&s.as_name)
|
||||
.with_context(|| format!("Cannot resolve sym_int: {}", s.as_name)),
|
||||
@@ -319,13 +318,17 @@ impl<'a> Translator<'a> {
|
||||
|
||||
pub(crate) fn dim_size_to_expr(&self, dim: &DimSize) -> Result<Expression> {
|
||||
match dim {
|
||||
DimSize::Int(i) => Ok(Expression::from(i.as_int)),
|
||||
DimSize::Expr(e) => self.resolve_expr_value(&e.as_expr).with_context(|| {
|
||||
format!(
|
||||
"Cannot resolve symbolic dimension expression: {}",
|
||||
e.as_expr.expr_str
|
||||
)
|
||||
}),
|
||||
DimSize::Int(i) => Ok(Expression::from(i.as_int as usize)),
|
||||
DimSize::Expr(e) => {
|
||||
let sym_name = crate::pt2_parser::extract_symbol_name_pub(&e.as_expr.expr_str)
|
||||
.with_context(|| format!("Cannot parse symbol: {}", e.as_expr.expr_str))?;
|
||||
let c = self
|
||||
.sym_map
|
||||
.sym_to_char
|
||||
.get(&sym_name)
|
||||
.with_context(|| format!("Unknown symbol: {sym_name}"))?;
|
||||
Ok(Expression::from(*c))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -336,9 +339,10 @@ impl<'a> Translator<'a> {
|
||||
.get("as_expr")
|
||||
.and_then(|e| e.get("expr_str"))
|
||||
.and_then(|s| s.as_str())
|
||||
&& let Some(expr) = self.resolve_expr_str(expr_str)
|
||||
&& let Some(sym) = crate::pt2_parser::extract_symbol_name_pub(expr_str)
|
||||
&& let Some(&c) = self.sym_map.sym_to_char.get(&sym)
|
||||
{
|
||||
return Some(expr);
|
||||
return Some(Expression::from(c));
|
||||
}
|
||||
if let Some(hint) = val
|
||||
.get("as_expr")
|
||||
@@ -346,7 +350,7 @@ impl<'a> Translator<'a> {
|
||||
.and_then(|h| h.get("as_int"))
|
||||
.and_then(|v| v.as_i64())
|
||||
{
|
||||
return Some(Expression::from(hint));
|
||||
return Some(Expression::from(hint as usize));
|
||||
}
|
||||
}
|
||||
None
|
||||
@@ -354,32 +358,21 @@ impl<'a> Translator<'a> {
|
||||
|
||||
pub(crate) fn resolve_arg_as_expression(&self, arg: &Argument) -> Option<Expression> {
|
||||
if let Some(v) = arg.as_int() {
|
||||
return Some(Expression::from(v));
|
||||
return Some(Expression::from(v as usize));
|
||||
}
|
||||
if let Some(name) = arg.as_sym_int_name() {
|
||||
return self.resolve_sym_int(name);
|
||||
}
|
||||
if let Argument::Expr(e) = arg {
|
||||
return self.resolve_expr_value(&e.as_expr);
|
||||
if let Some(sym) = crate::pt2_parser::extract_symbol_name_pub(&e.as_expr.expr_str)
|
||||
&& let Some(&c) = self.sym_map.sym_to_char.get(&sym)
|
||||
{
|
||||
return Some(Expression::from(c));
|
||||
}
|
||||
if let Some(hint) = e.as_expr.hint.as_ref().and_then(|h| h.as_int()) {
|
||||
return Some(Expression::from(hint as usize));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub(crate) fn resolve_expr_str(&self, expr_str: &str) -> Option<Expression> {
|
||||
parse_sympy_expr_with_ranges(expr_str, &self.sym_map.sym_to_char, &self.sym_map.ranges)
|
||||
.or_else(|| {
|
||||
crate::pt2_parser::extract_symbol_name_pub(expr_str)
|
||||
.and_then(|sym| self.sym_map.sym_to_char.get(&sym).copied())
|
||||
.map(Expression::from)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn resolve_expr_value(&self, expr: &ExprValue) -> Option<Expression> {
|
||||
self.resolve_expr_str(&expr.expr_str).or_else(|| {
|
||||
expr.hint
|
||||
.as_ref()
|
||||
.and_then(|h| h.as_int())
|
||||
.map(Expression::from)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
use anyhow::{Context, Result, bail};
|
||||
use luminal::prelude::*;
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::pt2_expr::{ExprBounds, canonical_equal_expr, sym_char_ranges};
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util::*;
|
||||
|
||||
@@ -13,25 +11,6 @@ const SCATTER_DIM_ARG: usize = 1;
|
||||
const SCATTER_INDEX_ARG: usize = 2;
|
||||
const SCATTER_VALUE_ARG: usize = 3;
|
||||
|
||||
fn normalize_concat_dims(
|
||||
lhs: &mut GraphTensor,
|
||||
rhs: &mut GraphTensor,
|
||||
skip_dim: Option<usize>,
|
||||
sym_ranges: &FxHashMap<char, ExprBounds>,
|
||||
) {
|
||||
for i in 0..lhs.shape.len() {
|
||||
if Some(i) == skip_dim {
|
||||
continue;
|
||||
}
|
||||
let lhs_dim = lhs.shape.dims[i];
|
||||
let rhs_dim = rhs.shape.dims[i];
|
||||
if let Some(canonical) = canonical_equal_expr(lhs_dim, rhs_dim, sym_ranges) {
|
||||
lhs.shape.dims[i] = canonical;
|
||||
rhs.shape.dims[i] = canonical;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_reshape(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
@@ -222,17 +201,8 @@ impl<'a> Translator<'a> {
|
||||
|
||||
let dim = normalize_dim(dim, tensors[0].shape.len());
|
||||
let mut result = tensors[0];
|
||||
let sym_ranges = sym_char_ranges(&self.sym_map);
|
||||
for t in &tensors[1..] {
|
||||
let mut next = *t;
|
||||
normalize_concat_dims(&mut result, &mut next, Some(dim), &sym_ranges);
|
||||
|
||||
let lhs_axis = result.dims()[dim];
|
||||
let rhs_axis = next.dims()[dim];
|
||||
let mut lhs_padded = result.pad_along(0, rhs_axis, dim, 0.);
|
||||
let mut rhs_padded = next.pad_along(lhs_axis, 0, dim, 0.);
|
||||
normalize_concat_dims(&mut lhs_padded, &mut rhs_padded, None, &sym_ranges);
|
||||
result = lhs_padded + rhs_padded;
|
||||
result = result.concat_along(*t, dim);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
@@ -220,7 +220,6 @@ impl<'a> Translator<'a> {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let weight = self.get_input_tensor(node, 1)?;
|
||||
let offs = self.get_input_tensor(node, 2)?;
|
||||
let out_dtype = self.output_meta_dtype(node)?;
|
||||
|
||||
anyhow::ensure!(
|
||||
input.shape.len() == 2,
|
||||
@@ -275,15 +274,8 @@ impl<'a> Translator<'a> {
|
||||
let exp_within = within.expand_dim(0, s);
|
||||
let flat_idx = exp_base + exp_within;
|
||||
|
||||
// Gather → [S, K, N], then normalize both operands to the op's declared
|
||||
// output dtype before matmul. On real Qwen3-MoE bf16 checkpoints the FX
|
||||
// graph inserts casts on the activation path, and relying on the input
|
||||
// tensor's translated dtype can leave us with mixed F32/Bf16 operands
|
||||
// by the time matmul expands into elementwise Mul. Using the PT2 output
|
||||
// metadata keeps the matmul dtype aligned with the exported contract
|
||||
// without upcasting the full expert weight bank.
|
||||
let weight_gathered = weight.gather(flat_idx).cast(out_dtype);
|
||||
let input = input.cast(out_dtype);
|
||||
// Gather → [S, K, N], preserves weight's native dtype (bf16 stays bf16).
|
||||
let weight_gathered = weight.gather(flat_idx);
|
||||
|
||||
// Per-token matmul: [S, 1, K] @ [S, K, N] → [S, 1, N] → [S, N].
|
||||
// Operands stay in their native dtype — no F32 cast on the gathered
|
||||
@@ -295,7 +287,7 @@ impl<'a> Translator<'a> {
|
||||
// (cuBLASLt etc.) handle bf16 input with F32 accumulator internally.
|
||||
let result = input.unsqueeze(1).matmul(weight_gathered).squeeze(1);
|
||||
|
||||
Ok(result.cast(out_dtype))
|
||||
Ok(result.cast(input.dtype))
|
||||
}
|
||||
|
||||
/// Build the where-formula graph: `cond * x + (1 - cond) * y`, computed
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""CompiledModel wrapper for the Rust CompiledGraph."""
|
||||
|
||||
import warnings
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
@@ -8,6 +9,10 @@ from .dtype_util import code_to_torch_dtype
|
||||
from .dtype_util import torch_dtype_code as _torch_dtype_code
|
||||
|
||||
|
||||
class DTypeBoundaryWarning(UserWarning):
|
||||
"""Warns when the PyTorch boundary must cast input data before execution."""
|
||||
|
||||
|
||||
class CompiledModel:
|
||||
"""Wrapper around CompiledGraph that handles PyTorch tensor conversion."""
|
||||
|
||||
@@ -95,6 +100,15 @@ class CompiledModel:
|
||||
for name, tensor, expected_dtype in zip(
|
||||
self._input_names, user_inputs, self._input_dtypes
|
||||
):
|
||||
if tensor.dtype != expected_dtype:
|
||||
warnings.warn(
|
||||
"Luminal compiled input "
|
||||
f"'{name}' has dtype {tensor.dtype}, but the compiled graph "
|
||||
f"expects {expected_dtype}; converting at every call will "
|
||||
"allocate/copy input data.",
|
||||
DTypeBoundaryWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if self._supports_device_ptrs and tensor.is_cuda:
|
||||
t = tensor.detach().contiguous().to(expected_dtype)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
|
||||
215
crates/luminal_python/tests/test_dtype_boundary.py
Normal file
215
crates/luminal_python/tests/test_dtype_boundary.py
Normal file
@@ -0,0 +1,215 @@
|
||||
from dataclasses import dataclass
|
||||
import warnings
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from luminal import luminal_backend
|
||||
from luminal.compiled_model import DTypeBoundaryWarning
|
||||
|
||||
|
||||
class BoundaryNoopModel(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if x.dtype is torch.bool:
|
||||
return x | torch.zeros((), dtype=torch.bool, device=x.device)
|
||||
return x + torch.zeros((), dtype=x.dtype, device=x.device)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DTypeCase:
|
||||
name: str
|
||||
dtype: torch.dtype
|
||||
values: Callable[[], torch.Tensor]
|
||||
xfail_reason: str | None = None
|
||||
|
||||
|
||||
DTYPE_CASES = [
|
||||
DTypeCase(
|
||||
"bool",
|
||||
torch.bool,
|
||||
lambda: torch.tensor([True, False, True], dtype=torch.bool),
|
||||
),
|
||||
DTypeCase(
|
||||
"uint8",
|
||||
torch.uint8,
|
||||
lambda: torch.tensor([0, 127, 255], dtype=torch.uint8),
|
||||
),
|
||||
DTypeCase(
|
||||
"int8",
|
||||
torch.int8,
|
||||
lambda: torch.tensor([-128, -1, 127], dtype=torch.int8),
|
||||
),
|
||||
DTypeCase(
|
||||
"int16",
|
||||
torch.int16,
|
||||
lambda: torch.tensor([-32768, -1, 32767], dtype=torch.int16),
|
||||
),
|
||||
DTypeCase(
|
||||
"int32",
|
||||
torch.int32,
|
||||
lambda: torch.tensor(
|
||||
[-2147483648, -1, 2147483647],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
),
|
||||
DTypeCase(
|
||||
"int64_i32_range",
|
||||
torch.int64,
|
||||
lambda: torch.tensor(
|
||||
[-2147483648, -1, 2147483647],
|
||||
dtype=torch.int64,
|
||||
),
|
||||
),
|
||||
DTypeCase(
|
||||
"float16",
|
||||
torch.float16,
|
||||
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.float16),
|
||||
),
|
||||
DTypeCase(
|
||||
"bfloat16",
|
||||
torch.bfloat16,
|
||||
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.bfloat16),
|
||||
),
|
||||
DTypeCase(
|
||||
"float32",
|
||||
torch.float32,
|
||||
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.float32),
|
||||
),
|
||||
DTypeCase(
|
||||
"float64_f32_exact",
|
||||
torch.float64,
|
||||
lambda: torch.tensor([1.0, 1.5, float(2**40)], dtype=torch.float64),
|
||||
),
|
||||
DTypeCase(
|
||||
"int64_outside_i32_range",
|
||||
torch.int64,
|
||||
lambda: torch.tensor([-(2**40), -1, 2**40], dtype=torch.int64),
|
||||
xfail_reason=(
|
||||
"Luminal currently collapses integer inputs through i32 at the "
|
||||
"compiled boundary, so out-of-range int64 values lose information."
|
||||
),
|
||||
),
|
||||
DTypeCase(
|
||||
"float64_precision_sensitive",
|
||||
torch.float64,
|
||||
lambda: torch.tensor(
|
||||
[1.0, 1.0000000000000002, float(2**40) + 0.25],
|
||||
dtype=torch.float64,
|
||||
),
|
||||
xfail_reason=(
|
||||
"Luminal currently routes float64 no-op computation through f32 "
|
||||
"storage/outputs before restoring the PyTorch-visible dtype."
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _cuda_skip_reason() -> str | None:
|
||||
if not torch.cuda.is_available():
|
||||
return "CUDA is not available"
|
||||
|
||||
try:
|
||||
from luminal.luminal import _cuda_lite_factory_capsule
|
||||
|
||||
_cuda_lite_factory_capsule()
|
||||
except (ImportError, AttributeError, RuntimeError) as exc:
|
||||
return f"luminal_python was not built with CUDA support: {exc}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture(params=["cpu", "cuda"], ids=["cpu", "cuda"])
|
||||
def boundary_device(request) -> torch.device:
|
||||
device_name = request.param
|
||||
if device_name == "cuda":
|
||||
skip_reason = _cuda_skip_reason()
|
||||
if skip_reason is not None:
|
||||
pytest.skip(skip_reason)
|
||||
return torch.device(device_name)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
pytest.param(
|
||||
case,
|
||||
marks=pytest.mark.xfail(reason=case.xfail_reason, strict=True)
|
||||
if case.xfail_reason is not None
|
||||
else (),
|
||||
id=case.name,
|
||||
)
|
||||
for case in DTYPE_CASES
|
||||
],
|
||||
)
|
||||
def test_boundary_noop_preserves_dtype_and_values(
|
||||
boundary_device: torch.device,
|
||||
case: DTypeCase,
|
||||
) -> None:
|
||||
model = BoundaryNoopModel().to(boundary_device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
|
||||
x = case.values().to(boundary_device)
|
||||
expected = model(x)
|
||||
actual = compiled(x)
|
||||
|
||||
assert isinstance(actual, torch.Tensor)
|
||||
assert actual.dtype == expected.dtype
|
||||
assert torch.equal(actual.cpu(), expected.cpu())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
pytest.param(case, id=case.name)
|
||||
for case in DTYPE_CASES
|
||||
if case.name
|
||||
in {
|
||||
"uint8",
|
||||
"int8",
|
||||
"int16",
|
||||
"int64_i32_range",
|
||||
"int64_outside_i32_range",
|
||||
"float64_f32_exact",
|
||||
"float64_precision_sensitive",
|
||||
}
|
||||
],
|
||||
)
|
||||
def test_boundary_warns_when_input_dtype_requires_conversion(
|
||||
boundary_device: torch.device,
|
||||
case: DTypeCase,
|
||||
) -> None:
|
||||
model = BoundaryNoopModel().to(boundary_device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
x = case.values().to(boundary_device)
|
||||
|
||||
with pytest.warns(DTypeBoundaryWarning, match="allocate/copy input data"):
|
||||
compiled(x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
pytest.param(case, id=case.name)
|
||||
for case in DTYPE_CASES
|
||||
if case.name in {"bool", "int32", "float16", "bfloat16", "float32"}
|
||||
],
|
||||
)
|
||||
def test_boundary_does_not_warn_when_input_dtype_matches_graph(
|
||||
boundary_device: torch.device,
|
||||
case: DTypeCase,
|
||||
) -> None:
|
||||
model = BoundaryNoopModel().to(boundary_device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
x = case.values().to(boundary_device)
|
||||
|
||||
with warnings.catch_warnings(record=True) as records:
|
||||
warnings.simplefilter("always")
|
||||
compiled(x)
|
||||
|
||||
dtype_boundary_warnings = [
|
||||
record
|
||||
for record in records
|
||||
if issubclass(record.category, DTypeBoundaryWarning)
|
||||
]
|
||||
assert dtype_boundary_warnings == []
|
||||
@@ -9,9 +9,6 @@ PT2 export, and reuse a single compiled graph across shape changes.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
@@ -37,64 +34,6 @@ def _compile_with_dynamic_true(model, count_holder):
|
||||
return torch.compile(model, backend=wrapper, dynamic=True)
|
||||
|
||||
|
||||
def _compile_with_capture(model, count_holder, capture_holder):
|
||||
def wrapper(gm, example_inputs):
|
||||
out = luminal_backend(gm, example_inputs)
|
||||
count_holder.append(1)
|
||||
if "gm" not in capture_holder:
|
||||
capture_holder["gm"] = copy.deepcopy(gm).eval()
|
||||
capture_holder["example_inputs"] = example_inputs
|
||||
capture_holder["compiled_impl"] = out
|
||||
return out
|
||||
|
||||
return torch.compile(model, backend=wrapper)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _explicit_mark_dynamic_mode():
|
||||
prev_auto = torch._dynamo.config.automatic_dynamic_shapes
|
||||
prev_cache_limit = torch._dynamo.config.cache_size_limit
|
||||
torch._dynamo.reset()
|
||||
torch._dynamo.config.automatic_dynamic_shapes = False
|
||||
torch._dynamo.config.cache_size_limit = 8
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch._dynamo.config.automatic_dynamic_shapes = prev_auto
|
||||
torch._dynamo.config.cache_size_limit = prev_cache_limit
|
||||
torch._dynamo.reset()
|
||||
|
||||
|
||||
def _first_trace_dynamic_shapes(capture_holder):
|
||||
from luminal.pt2 import (
|
||||
_build_dynamic_shapes_from_gm,
|
||||
_reinternalize_lifted_params,
|
||||
_strip_symint_placeholders,
|
||||
)
|
||||
|
||||
gm = copy.deepcopy(capture_holder["gm"]).eval()
|
||||
example_inputs = capture_holder["example_inputs"]
|
||||
gm, user_inputs, _, _ = _reinternalize_lifted_params(gm, example_inputs)
|
||||
user_inputs, _, strip_ok = _strip_symint_placeholders(gm, user_inputs)
|
||||
dynamic_shapes = _build_dynamic_shapes_from_gm(gm) if strip_ok else None
|
||||
return strip_ok, dynamic_shapes
|
||||
|
||||
|
||||
def _assert_input_dynamic_dims(dynamic_shapes, input_index, expected_dims):
|
||||
args_spec = dynamic_shapes.get("args")
|
||||
assert args_spec is not None and len(args_spec) > input_index, (
|
||||
f"expected dynamic spec for input {input_index}, got {dynamic_shapes}"
|
||||
)
|
||||
spec = args_spec[input_index]
|
||||
assert spec is not None, (
|
||||
f"expected a per-dim dynamic spec for input {input_index}, got {dynamic_shapes}"
|
||||
)
|
||||
assert set(spec.keys()) == set(expected_dims), (
|
||||
f"expected dynamic dims {set(expected_dims)} for input {input_index}, "
|
||||
f"got {dynamic_shapes}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_automatic_dynamic():
|
||||
"""Make sure the tests run with Dynamo's automatic-dynamic detection on.
|
||||
@@ -267,202 +206,6 @@ def test_torch_compile_dynamic_true_single_compile(device: torch.device):
|
||||
)
|
||||
|
||||
|
||||
def test_mark_dynamic_seq_via_torch_compile_starts_dynamic(device: torch.device):
|
||||
"""Explicit `mark_dynamic` should skip the static-then-promote compile dance."""
|
||||
|
||||
class Mdl(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return (x.sin() + x.square()).sum(-1)
|
||||
|
||||
with _explicit_mark_dynamic_mode():
|
||||
model = Mdl().eval().to(device)
|
||||
counts: list[int] = []
|
||||
capture: dict[str, object] = {}
|
||||
compiled = _compile_with_capture(model, counts, capture)
|
||||
|
||||
first = torch.randn(2, 4, device=device)
|
||||
torch._dynamo.mark_dynamic(first, 1, min=2, max=16)
|
||||
|
||||
inputs = {
|
||||
4: first,
|
||||
6: torch.randn(2, 6, device=device),
|
||||
9: torch.randn(2, 9, device=device),
|
||||
}
|
||||
|
||||
for seq_len, x in inputs.items():
|
||||
ref = model(x)
|
||||
out = compiled(x)
|
||||
assert out.shape == ref.shape == (2,), (
|
||||
f"seq_len={seq_len}: got {out.shape}, expected {ref.shape}"
|
||||
)
|
||||
assert torch.allclose(out, ref, atol=1e-5), (
|
||||
f"seq_len={seq_len}: max_diff={torch.max(torch.abs(out - ref)).item():.2e}"
|
||||
)
|
||||
|
||||
compiled_impl = capture["compiled_impl"]
|
||||
assert compiled_impl.has_dynamic_dims
|
||||
assert len(compiled_impl.dim_params) == 1
|
||||
|
||||
strip_ok, dynamic_shapes = _first_trace_dynamic_shapes(capture)
|
||||
assert strip_ok, "Expected explicit mark_dynamic SymInts to be rewritten"
|
||||
assert dynamic_shapes is not None
|
||||
_assert_input_dynamic_dims(dynamic_shapes, 0, {1})
|
||||
|
||||
assert len(counts) == 1, (
|
||||
"Explicit mark_dynamic should produce one dynamic backend trace from the start, "
|
||||
f"got {len(counts)} backend invocations"
|
||||
)
|
||||
|
||||
|
||||
def test_mark_dynamic_seq_with_lifted_weights_single_compile(device: torch.device):
|
||||
"""Lifted parameters should compose with an explicitly dynamic token axis."""
|
||||
|
||||
class Mdl(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.embed = torch.nn.Embedding(128, 16)
|
||||
self.proj = torch.nn.Linear(16, 8)
|
||||
|
||||
def forward(self, input_ids):
|
||||
return self.proj(self.embed(input_ids)).sum(-1)
|
||||
|
||||
with _explicit_mark_dynamic_mode():
|
||||
model = Mdl().eval().to(device)
|
||||
counts: list[int] = []
|
||||
capture: dict[str, object] = {}
|
||||
compiled = _compile_with_capture(model, counts, capture)
|
||||
|
||||
first = torch.tensor([[1, 2, 3, 4]], device=device)
|
||||
torch._dynamo.mark_dynamic(first, 1, min=2, max=32)
|
||||
|
||||
inputs = {
|
||||
4: first,
|
||||
6: torch.arange(1, 7, device=device).unsqueeze(0),
|
||||
9: torch.arange(1, 10, device=device).unsqueeze(0),
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
for seq_len, input_ids in inputs.items():
|
||||
ref = model(input_ids)
|
||||
out = compiled(input_ids)
|
||||
assert out.shape == ref.shape == (1, seq_len), (
|
||||
f"seq_len={seq_len}: got {out.shape}, expected {ref.shape}"
|
||||
)
|
||||
assert torch.allclose(out, ref, atol=1e-5), (
|
||||
"seq_len="
|
||||
f"{seq_len}: max_diff={torch.max(torch.abs(out - ref)).item():.2e}"
|
||||
)
|
||||
|
||||
compiled_impl = capture["compiled_impl"]
|
||||
assert compiled_impl.has_dynamic_dims
|
||||
assert len(compiled_impl.dim_params) == 1
|
||||
|
||||
strip_ok, dynamic_shapes = _first_trace_dynamic_shapes(capture)
|
||||
assert strip_ok
|
||||
assert dynamic_shapes is not None
|
||||
_assert_input_dynamic_dims(dynamic_shapes, 0, {1})
|
||||
|
||||
assert len(counts) == 1, (
|
||||
"Explicit mark_dynamic should avoid a second compile for lifted-weight models, "
|
||||
f"got {len(counts)} backend invocations"
|
||||
)
|
||||
|
||||
|
||||
def test_mark_dynamic_seq_preserves_affine_output_shape(device: torch.device):
|
||||
"""Output-shape expressions like `2 * seq` should stay dynamic from call 1."""
|
||||
|
||||
class Mdl(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.cat([x, x], dim=1)
|
||||
|
||||
with _explicit_mark_dynamic_mode():
|
||||
model = Mdl().eval().to(device)
|
||||
counts: list[int] = []
|
||||
capture: dict[str, object] = {}
|
||||
compiled = _compile_with_capture(model, counts, capture)
|
||||
|
||||
first = torch.randn(2, 4, 3, device=device)
|
||||
torch._dynamo.mark_dynamic(first, 1, min=2, max=16)
|
||||
|
||||
inputs = {
|
||||
4: first,
|
||||
5: torch.randn(2, 5, 3, device=device),
|
||||
7: torch.randn(2, 7, 3, device=device),
|
||||
}
|
||||
|
||||
for seq_len, x in inputs.items():
|
||||
ref = model(x)
|
||||
out = compiled(x)
|
||||
assert out.shape == ref.shape == (2, 2 * seq_len, 3), (
|
||||
f"seq_len={seq_len}: got {out.shape}, expected {ref.shape}"
|
||||
)
|
||||
assert torch.allclose(out, ref, atol=1e-5), (
|
||||
f"seq_len={seq_len}: max_diff={torch.max(torch.abs(out - ref)).item():.2e}"
|
||||
)
|
||||
|
||||
compiled_impl = capture["compiled_impl"]
|
||||
assert compiled_impl.has_dynamic_dims
|
||||
assert len(compiled_impl.dim_params) == 1
|
||||
|
||||
strip_ok, dynamic_shapes = _first_trace_dynamic_shapes(capture)
|
||||
assert strip_ok
|
||||
assert dynamic_shapes is not None
|
||||
_assert_input_dynamic_dims(dynamic_shapes, 0, {1})
|
||||
|
||||
assert len(counts) == 1, (
|
||||
"Explicit mark_dynamic should keep affine output-shape models on one compile, "
|
||||
f"got {len(counts)} backend invocations"
|
||||
)
|
||||
|
||||
|
||||
def test_mark_dynamic_two_dim_via_torch_compile_starts_dynamic(device: torch.device):
|
||||
"""Marking both batch and seq dynamic should still compile only once."""
|
||||
|
||||
class Mdl(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x.mean(-1)
|
||||
|
||||
with _explicit_mark_dynamic_mode():
|
||||
model = Mdl().eval().to(device)
|
||||
counts: list[int] = []
|
||||
capture: dict[str, object] = {}
|
||||
compiled = _compile_with_capture(model, counts, capture)
|
||||
|
||||
first = torch.randn(2, 8, 4, device=device)
|
||||
torch._dynamo.mark_dynamic(first, 0, min=1, max=8)
|
||||
torch._dynamo.mark_dynamic(first, 1, min=2, max=16)
|
||||
|
||||
inputs = {
|
||||
(2, 8): first,
|
||||
(3, 9): torch.randn(3, 9, 4, device=device),
|
||||
(5, 11): torch.randn(5, 11, 4, device=device),
|
||||
}
|
||||
|
||||
for shape, x in inputs.items():
|
||||
ref = model(x)
|
||||
out = compiled(x)
|
||||
assert out.shape == ref.shape == shape, (
|
||||
f"shape={shape}: got {out.shape}, expected {ref.shape}"
|
||||
)
|
||||
assert torch.allclose(out, ref, atol=1e-5), (
|
||||
f"shape={shape}: max_diff={torch.max(torch.abs(out - ref)).item():.2e}"
|
||||
)
|
||||
|
||||
compiled_impl = capture["compiled_impl"]
|
||||
assert compiled_impl.has_dynamic_dims
|
||||
assert len(compiled_impl.dim_params) == 2
|
||||
|
||||
strip_ok, dynamic_shapes = _first_trace_dynamic_shapes(capture)
|
||||
assert strip_ok
|
||||
assert dynamic_shapes is not None
|
||||
_assert_input_dynamic_dims(dynamic_shapes, 0, {0, 1})
|
||||
|
||||
assert len(counts) == 1, (
|
||||
"Explicitly marked batch+seq dims should compile once from the first call, "
|
||||
f"got {len(counts)} backend invocations"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
|
||||
|
||||
142
crates/luminal_python/tests/test_input_layout.py
Normal file
142
crates/luminal_python/tests/test_input_layout.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import torch
|
||||
import pytest
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
class StrideSensitiveInputModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.register_buffer(
|
||||
"coeff",
|
||||
torch.tensor([1.0, 10.0, 100.0], dtype=torch.float32),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x @ self.coeff
|
||||
|
||||
|
||||
class TwoInputReadModel(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x * 2.0 + y * 3.0
|
||||
|
||||
|
||||
class ReturnInputModel(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x
|
||||
|
||||
|
||||
class ReturnInputAndComputedModel(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return x, x + 1.0
|
||||
|
||||
|
||||
class CloneThenMutateModel(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
y = x.clone()
|
||||
y.add_(1.0)
|
||||
return y, x * 2.0
|
||||
|
||||
|
||||
def _base_view(device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
base = torch.arange(12, dtype=torch.float32, device=device).reshape(3, 4)
|
||||
return base, base.t()
|
||||
|
||||
|
||||
def _assert_non_contiguous_storage_alias(base: torch.Tensor, view: torch.Tensor) -> None:
|
||||
assert not view.is_contiguous()
|
||||
assert view.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
|
||||
|
||||
|
||||
def _assert_same(actual, expected) -> None:
|
||||
if isinstance(expected, tuple):
|
||||
assert isinstance(actual, tuple)
|
||||
assert len(actual) == len(expected)
|
||||
for actual_item, expected_item in zip(actual, expected):
|
||||
_assert_same(actual_item, expected_item)
|
||||
return
|
||||
|
||||
assert torch.allclose(actual, expected)
|
||||
|
||||
|
||||
def _single_non_contiguous_view(device: torch.device):
|
||||
base, view = _base_view(device)
|
||||
_assert_non_contiguous_storage_alias(base, view)
|
||||
return StrideSensitiveInputModel().to(device), (view,), base
|
||||
|
||||
|
||||
def _same_view_twice(device: torch.device):
|
||||
base, view = _base_view(device)
|
||||
_assert_non_contiguous_storage_alias(base, view)
|
||||
return TwoInputReadModel().to(device), (view, view), base
|
||||
|
||||
|
||||
def _overlapping_views(device: torch.device):
|
||||
base = torch.arange(20, dtype=torch.float32, device=device).reshape(4, 5)
|
||||
x = base[:3, :4]
|
||||
y = base[1:, 1:]
|
||||
assert not x.is_contiguous()
|
||||
assert not y.is_contiguous()
|
||||
assert x.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
|
||||
assert y.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
|
||||
return TwoInputReadModel().to(device), (x, y), base
|
||||
|
||||
|
||||
def _return_input(device: torch.device):
|
||||
base, view = _base_view(device)
|
||||
_assert_non_contiguous_storage_alias(base, view)
|
||||
return ReturnInputModel().to(device), (view,), base
|
||||
|
||||
|
||||
def _return_input_and_computed(device: torch.device):
|
||||
base, view = _base_view(device)
|
||||
_assert_non_contiguous_storage_alias(base, view)
|
||||
return ReturnInputAndComputedModel().to(device), (view,), base
|
||||
|
||||
|
||||
def _internal_clone_inplace(device: torch.device):
|
||||
base, view = _base_view(device)
|
||||
_assert_non_contiguous_storage_alias(base, view)
|
||||
return CloneThenMutateModel().to(device), (view,), base
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"make_case",
|
||||
[
|
||||
pytest.param(
|
||||
_single_non_contiguous_view,
|
||||
id="single_non_contiguous_view_stride_sensitive_read",
|
||||
),
|
||||
pytest.param(_same_view_twice, id="same_view_passed_as_two_read_inputs"),
|
||||
pytest.param(_overlapping_views, id="overlapping_views_as_two_read_inputs"),
|
||||
pytest.param(_return_input, id="return_input_boundary_value"),
|
||||
pytest.param(
|
||||
_return_input_and_computed,
|
||||
id="return_input_boundary_value_and_computed_value",
|
||||
),
|
||||
pytest.param(_internal_clone_inplace, id="inplace_mutation_on_internal_clone"),
|
||||
],
|
||||
)
|
||||
def test_input_boundary_contiguous_materialization_cases(
|
||||
device: torch.device, make_case
|
||||
) -> None:
|
||||
model, inputs, base = make_case(device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
|
||||
base_before = base.clone()
|
||||
expected = model(*inputs)
|
||||
actual = compiled(*inputs)
|
||||
|
||||
_assert_same(actual, expected)
|
||||
assert torch.allclose(base, base_before)
|
||||
|
||||
|
||||
def test_non_contiguous_view_input_fails_if_raw_storage_order_is_used(
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
model, (view,), base = _single_non_contiguous_view(device)
|
||||
|
||||
wrong_if_storage_order_used = model(base.reshape(view.shape))
|
||||
expected = model(view)
|
||||
|
||||
assert not torch.allclose(wrong_if_storage_order_used, expected)
|
||||
@@ -97,87 +97,6 @@ def test_kv_cache_growing():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="dynamic-cache torch.compile reuse requires CUDA coverage",
|
||||
)
|
||||
@pytest.mark.slow
|
||||
def test_dynamic_kv_cache_torch_compile_matches_reference_and_reuses_decode_graph():
|
||||
"""End-to-end server-style path: torch.compile + DynamicCache on CUDA."""
|
||||
from transformers import DynamicCache, LlamaConfig, LlamaForCausalLM
|
||||
|
||||
backend_invocations = []
|
||||
|
||||
def counting_backend(gm, example_inputs, options=None):
|
||||
backend_invocations.append((gm, example_inputs))
|
||||
return luminal_backend(gm, example_inputs, options)
|
||||
|
||||
prev_auto = torch._dynamo.config.automatic_dynamic_shapes
|
||||
prev_cache_limit = torch._dynamo.config.cache_size_limit
|
||||
prev_recompile_limit = torch._dynamo.config.recompile_limit
|
||||
torch._dynamo.config.automatic_dynamic_shapes = True
|
||||
torch._dynamo.config.cache_size_limit = 16
|
||||
torch._dynamo.config.recompile_limit = 16
|
||||
|
||||
try:
|
||||
model = (
|
||||
LlamaForCausalLM(
|
||||
LlamaConfig(
|
||||
hidden_size=64,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
num_hidden_layers=1,
|
||||
intermediate_size=128,
|
||||
vocab_size=256,
|
||||
max_position_embeddings=128,
|
||||
use_cache=True,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
)
|
||||
.eval()
|
||||
.cuda()
|
||||
)
|
||||
compiled = torch.compile(model, backend=counting_backend, fullgraph=True)
|
||||
|
||||
ref_cache = DynamicCache(config=model.config)
|
||||
out_cache = DynamicCache(config=model.config)
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]], device="cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids=input_ids, past_key_values=ref_cache, use_cache=True)
|
||||
out = compiled(
|
||||
input_ids=input_ids,
|
||||
past_key_values=out_cache,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
for _ in range(4):
|
||||
ref_next = int(ref.logits[0, -1].argmax().item())
|
||||
out_next = int(out.logits[0, -1].argmax().item())
|
||||
assert out_next == ref_next
|
||||
with torch.no_grad():
|
||||
ref = model(
|
||||
input_ids=torch.tensor([[ref_next]], device="cuda"),
|
||||
past_key_values=ref.past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
out = compiled(
|
||||
input_ids=torch.tensor([[out_next]], device="cuda"),
|
||||
past_key_values=out.past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
assert len(backend_invocations) == 3, (
|
||||
"Expected prefill/static decode/dynamic decode traces only once each, "
|
||||
f"got {len(backend_invocations)} backend invocations"
|
||||
)
|
||||
finally:
|
||||
torch._dynamo.config.automatic_dynamic_shapes = prev_auto
|
||||
torch._dynamo.config.cache_size_limit = prev_cache_limit
|
||||
torch._dynamo.config.recompile_limit = prev_recompile_limit
|
||||
torch._dynamo.reset()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="R1 full-width 1-layer is too memory-heavy for CPU native backend",
|
||||
|
||||
@@ -450,138 +450,3 @@ def test_hf_llama38b_full(device: torch.device):
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-5), (
|
||||
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="Full Llama-3.1-8B dynamic-shape regression requires CUDA",
|
||||
)
|
||||
def test_hf_llama38b_mark_dynamic_seq_dim_before_compile(device: torch.device):
|
||||
"""Explicitly marking the token sequence dim dynamic should be honored end to end.
|
||||
|
||||
This exercises the real user path:
|
||||
1. wrap the pretrained 8B model with ``torch.compile(..., backend=luminal_backend)``
|
||||
2. mark ``input_ids.shape[1]`` dynamic before the first invocation
|
||||
3. verify the first backend trace is already dynamic on that axis
|
||||
4. reuse the same compiled graph for multiple sequence lengths
|
||||
"""
|
||||
import copy
|
||||
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
from luminal.pt2 import (
|
||||
_build_dynamic_shapes_from_gm,
|
||||
_reinternalize_lifted_params,
|
||||
_strip_symint_placeholders,
|
||||
)
|
||||
|
||||
backend_invocations = []
|
||||
capture = {}
|
||||
|
||||
def inspector_backend(gm, example_inputs, **kwargs):
|
||||
backend_invocations.append((gm, example_inputs, kwargs))
|
||||
if len(backend_invocations) == 1:
|
||||
capture["gm"] = copy.deepcopy(gm).eval()
|
||||
capture["example_inputs"] = example_inputs
|
||||
compiled_impl = luminal_backend(gm, example_inputs, **kwargs)
|
||||
if len(backend_invocations) == 1:
|
||||
capture["compiled_impl"] = compiled_impl
|
||||
return compiled_impl
|
||||
|
||||
prev_auto = torch._dynamo.config.automatic_dynamic_shapes
|
||||
prev_cache_limit = torch._dynamo.config.cache_size_limit
|
||||
torch._dynamo.reset()
|
||||
torch._dynamo.config.automatic_dynamic_shapes = False
|
||||
torch._dynamo.config.cache_size_limit = 8
|
||||
|
||||
try:
|
||||
config = AutoConfig.from_pretrained("NousResearch/Meta-Llama-3.1-8B-Instruct")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
model = (
|
||||
LlamaForCausalLM.from_pretrained(
|
||||
"NousResearch/Meta-Llama-3.1-8B-Instruct",
|
||||
config=config,
|
||||
torch_dtype=torch.float32,
|
||||
)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
compiled = torch.compile(model, backend=inspector_backend)
|
||||
|
||||
first_input_ids = torch.tensor([[1, 2, 3, 4]], device=device)
|
||||
torch._dynamo.mark_dynamic(first_input_ids, 1, min=2, max=16)
|
||||
|
||||
seq_inputs = {
|
||||
4: first_input_ids,
|
||||
6: torch.arange(1, 7, device=device).unsqueeze(0),
|
||||
9: torch.arange(1, 10, device=device).unsqueeze(0),
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
first_ref = model(first_input_ids)
|
||||
first_out = compiled(first_input_ids)
|
||||
|
||||
compiled_impl = capture["compiled_impl"]
|
||||
assert compiled_impl.has_dynamic_dims, (
|
||||
"explicit mark_dynamic on input_ids[:, 1] should produce a dynamic Luminal graph"
|
||||
)
|
||||
assert len(compiled_impl.dim_params) == 1, (
|
||||
f"expected exactly one dynamic dim param, got {compiled_impl.dim_params}"
|
||||
)
|
||||
|
||||
gm = capture["gm"]
|
||||
example_inputs = capture["example_inputs"]
|
||||
gm, user_inputs, _, _ = _reinternalize_lifted_params(gm, example_inputs)
|
||||
user_inputs, _, strip_ok = _strip_symint_placeholders(gm, user_inputs)
|
||||
dynamic_shapes = _build_dynamic_shapes_from_gm(gm) if strip_ok else None
|
||||
|
||||
assert strip_ok, "Expected explicit mark_dynamic SymInts to be rewritten"
|
||||
assert dynamic_shapes is not None, (
|
||||
"Expected the first backend trace to preserve a dynamic shape spec"
|
||||
)
|
||||
args_spec = dynamic_shapes.get("args")
|
||||
assert args_spec is not None and len(args_spec) == 1, (
|
||||
f"expected one user-input dynamic spec, got {dynamic_shapes}"
|
||||
)
|
||||
assert args_spec[0] is not None, (
|
||||
f"expected a per-dim dynamic spec for input_ids, got {dynamic_shapes}"
|
||||
)
|
||||
assert set(args_spec[0].keys()) == {1}, (
|
||||
"Expected only the token sequence axis (dim=1) to be dynamic, "
|
||||
f"got {dynamic_shapes}"
|
||||
)
|
||||
|
||||
first_diff = torch.max(torch.abs(first_out.logits - first_ref.logits)).item()
|
||||
assert torch.allclose(first_out.logits, first_ref.logits, atol=1e-3, rtol=0), (
|
||||
f"seq_len=4: max_diff={first_diff:.2e}"
|
||||
)
|
||||
|
||||
for seq_len, input_ids in seq_inputs.items():
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
out = first_out if seq_len == 4 else compiled(input_ids)
|
||||
assert (
|
||||
out.logits.shape
|
||||
== ref.logits.shape
|
||||
== (
|
||||
1,
|
||||
seq_len,
|
||||
config.vocab_size,
|
||||
)
|
||||
), f"seq_len={seq_len}: got {out.logits.shape}, expected {ref.logits.shape}"
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-3, rtol=0), (
|
||||
f"seq_len={seq_len}: "
|
||||
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
assert len(backend_invocations) == 1, (
|
||||
"Explicit mark_dynamic should produce one dynamic backend trace from the start, "
|
||||
f"got {len(backend_invocations)} backend invocations"
|
||||
)
|
||||
finally:
|
||||
torch._dynamo.config.automatic_dynamic_shapes = prev_auto
|
||||
torch._dynamo.config.cache_size_limit = prev_cache_limit
|
||||
torch._dynamo.reset()
|
||||
|
||||
138
crates/luminal_python/tests/test_mutation_alias_contract.py
Normal file
138
crates/luminal_python/tests/test_mutation_alias_contract.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Regression coverage for torch.compile mutation and alias contracts.
|
||||
|
||||
PyTorch backends are expected to preserve the semantics of the traced graph.
|
||||
After torch.export functionalization, input mutations are represented as
|
||||
leading mutation outputs before user outputs. Luminal currently treats every
|
||||
compiled graph output as a user output and also materializes inputs at the
|
||||
boundary, so caller-visible mutation and aliasing semantics are not preserved.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
class MutateInputThenCompute(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x.add_(1.0)
|
||||
return x * 2.0
|
||||
|
||||
|
||||
class MutateInputReturnAlias(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x.add_(1.0)
|
||||
return x
|
||||
|
||||
|
||||
class MutateOverlappingInputAlias(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
x.add_(10.0)
|
||||
return y * 2.0
|
||||
|
||||
|
||||
class ReturnInputView(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.t()
|
||||
|
||||
|
||||
def _assert_same_storage(a: torch.Tensor, b: torch.Tensor) -> None:
|
||||
assert a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("backend", ["eager", "aot_eager", "inductor"])
|
||||
def test_stock_torch_compile_preserves_input_mutation_writeback(backend: str) -> None:
|
||||
model = MutateInputThenCompute()
|
||||
expected_input = torch.arange(6, dtype=torch.float32).reshape(2, 3)
|
||||
actual_input = expected_input.clone()
|
||||
|
||||
expected = model(expected_input)
|
||||
compiled = torch.compile(model, backend=backend)
|
||||
actual = compiled(actual_input)
|
||||
|
||||
assert torch.equal(actual, expected)
|
||||
assert torch.equal(actual_input, expected_input)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("backend", ["eager", "aot_eager", "inductor"])
|
||||
def test_stock_torch_compile_preserves_mutated_return_alias(backend: str) -> None:
|
||||
model = MutateInputReturnAlias()
|
||||
x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
|
||||
|
||||
compiled = torch.compile(model, backend=backend)
|
||||
out = compiled(x)
|
||||
|
||||
assert torch.equal(x, torch.arange(1, 7, dtype=torch.float32).reshape(2, 3))
|
||||
_assert_same_storage(out, x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("backend", ["eager", "aot_eager", "inductor"])
|
||||
def test_stock_torch_compile_preserves_returned_view_alias(backend: str) -> None:
|
||||
model = ReturnInputView()
|
||||
x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
|
||||
|
||||
compiled = torch.compile(model, backend=backend)
|
||||
out = compiled(x)
|
||||
|
||||
assert torch.equal(out, x.t())
|
||||
assert out.stride() == (1, 3)
|
||||
_assert_same_storage(out, x)
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
strict=True,
|
||||
reason=(
|
||||
"Luminal currently treats functionalized input-mutation outputs as user "
|
||||
"outputs and does not copy mutation outputs back to caller inputs."
|
||||
),
|
||||
)
|
||||
def test_luminal_input_mutation_writeback_contract(device: torch.device) -> None:
|
||||
model = MutateInputThenCompute().to(device)
|
||||
x = torch.arange(6, dtype=torch.float32, device=device).reshape(2, 3)
|
||||
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
out = compiled(x)
|
||||
|
||||
expected_x = torch.arange(1, 7, dtype=torch.float32, device=device).reshape(2, 3)
|
||||
expected_out = expected_x * 2.0
|
||||
assert torch.equal(out, expected_out)
|
||||
assert torch.equal(x, expected_x)
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
strict=True,
|
||||
reason=(
|
||||
"Luminal does not preserve caller-visible overlapping input aliasing "
|
||||
"when one aliased input is mutated."
|
||||
),
|
||||
)
|
||||
def test_luminal_overlapping_input_alias_mutation_contract(
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
model = MutateOverlappingInputAlias().to(device)
|
||||
|
||||
eager_base = torch.arange(6, dtype=torch.float32, device=device)
|
||||
expected = model(eager_base[:4], eager_base[1:5])
|
||||
|
||||
base = torch.arange(6, dtype=torch.float32, device=device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
actual = compiled(base[:4], base[1:5])
|
||||
|
||||
assert torch.equal(actual, expected)
|
||||
assert torch.equal(base, eager_base)
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
strict=True,
|
||||
reason="Luminal materializes returned input views instead of preserving aliasing.",
|
||||
)
|
||||
def test_luminal_returned_view_alias_contract(device: torch.device) -> None:
|
||||
model = ReturnInputView().to(device)
|
||||
x = torch.arange(6, dtype=torch.float32, device=device).reshape(2, 3)
|
||||
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
out = compiled(x)
|
||||
|
||||
assert torch.equal(out, x.t())
|
||||
assert out.stride() == (1, 3)
|
||||
_assert_same_storage(out, x)
|
||||
@@ -152,57 +152,6 @@ def test_hf_qwen3_moe_medium(device: torch.device):
|
||||
_run_hf_qwen3_moe_test(config, device, atol=1e-4)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="bf16 grouped_mm coverage requires CUDA",
|
||||
)
|
||||
def test_hf_qwen3_moe_tiny_bf16(device: torch.device):
|
||||
"""HuggingFace Qwen3MoeForCausalLM — tiny bf16 path on CUDA.
|
||||
|
||||
Exercises the grouped-mm MoE lowering with bf16 weights/activations so we
|
||||
catch mixed-dtype compile regressions without paying the full 30B checkpoint
|
||||
cost. Like the full pretrained bf16 test below, this only asserts that the
|
||||
compiled path runs and stays numerically sane; tight bf16 equivalence is
|
||||
tracked separately.
|
||||
"""
|
||||
from transformers import Qwen3MoeForCausalLM
|
||||
|
||||
config = _make_qwen3_moe_config(
|
||||
hidden_size=32,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=1,
|
||||
num_hidden_layers=1,
|
||||
intermediate_size=64,
|
||||
moe_intermediate_size=64,
|
||||
num_experts=2,
|
||||
num_experts_per_tok=1,
|
||||
vocab_size=128,
|
||||
)
|
||||
|
||||
model = Qwen3MoeForCausalLM(config).eval().to(dtype=torch.bfloat16, device=device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]], device=device)
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
out = compiled(input_ids)
|
||||
|
||||
ref_logits = ref.logits.float()
|
||||
out_logits = out.logits.float()
|
||||
ref_max = ref_logits.abs().max().item()
|
||||
out_max = out_logits.abs().max().item()
|
||||
n_nan = int(out_logits.isnan().sum().item())
|
||||
n_inf = int(out_logits.isinf().sum().item())
|
||||
|
||||
assert n_nan == 0 and n_inf == 0, (
|
||||
f"compiled forward produced non-finite logits: {n_nan} NaNs, "
|
||||
f"{n_inf} Infs (eager max abs={ref_max:.2f}, compiled max abs={out_max:.2f})"
|
||||
)
|
||||
assert 0.1 * ref_max <= out_max <= 10.0 * ref_max, (
|
||||
f"compiled max abs={out_max:.2f} is out of band vs eager max abs={ref_max:.2f} "
|
||||
f"(>10x off in either direction); likely a numerical/scale bug"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_hf_qwen3_moe_real_config_1layer(device: torch.device):
|
||||
"""HuggingFace Qwen3MoeForCausalLM — real Qwen3-30B-A3B architecture, 1 layer.
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
[package]
|
||||
name = "example_common"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
rustc-hash = "2"
|
||||
@@ -1,167 +0,0 @@
|
||||
//! Shared helpers for the rust example binaries.
|
||||
//!
|
||||
//! - `--stdio` arg detection and READY/TOK/EOQ protocol used by the
|
||||
//! luminal-benchmarks harness to drive a long-lived subprocess.
|
||||
//! - Env-var parsing for benchmark knobs (`GEN_TOKENS`, `SEARCH_GRAPHS`).
|
||||
//! - `info!` routing — stderr in stdio mode, stdout otherwise.
|
||||
//! - Greedy sampling with a repetition penalty.
|
||||
|
||||
use rustc_hash::FxHashSet;
|
||||
|
||||
pub fn env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
pub fn env_bool(name: &str) -> bool {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.is_some_and(|s| matches!(s.as_str(), "1" | "true" | "TRUE" | "yes" | "YES"))
|
||||
}
|
||||
|
||||
pub fn has_arg(name: &str) -> bool {
|
||||
std::env::args().any(|a| a == name)
|
||||
}
|
||||
|
||||
/// Route an info message to stderr in stdio mode (so the protocol channel
|
||||
/// stays clean) or stdout otherwise.
|
||||
pub fn info(stdio_mode: bool, msg: impl AsRef<str>) {
|
||||
let msg = msg.as_ref();
|
||||
if stdio_mode {
|
||||
eprintln!("{msg}");
|
||||
} else {
|
||||
println!("{msg}");
|
||||
}
|
||||
}
|
||||
|
||||
/// Greedy argmax with a multiplicative repetition penalty applied to
|
||||
/// previously-seen tokens.
|
||||
pub fn sample_greedy_with_penalty(
|
||||
logits_row: &[f32],
|
||||
seen: &FxHashSet<u32>,
|
||||
repetition_penalty: f32,
|
||||
) -> u32 {
|
||||
let mut row = logits_row.to_vec();
|
||||
for &tok in seen {
|
||||
let logit = &mut row[tok as usize];
|
||||
if *logit > 0.0 {
|
||||
*logit /= repetition_penalty;
|
||||
} else {
|
||||
*logit *= repetition_penalty;
|
||||
}
|
||||
}
|
||||
row.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32
|
||||
}
|
||||
|
||||
/// Escape a token's UTF-8 for one-line TOK transport: \ → \\, \t → \\t,
|
||||
/// \n → \\n, \r → \\r. Inverted on the python side.
|
||||
pub fn escape_tok(s: &str) -> String {
|
||||
let mut out = String::with_capacity(s.len());
|
||||
for c in s.chars() {
|
||||
match c {
|
||||
'\\' => out.push_str("\\\\"),
|
||||
'\t' => out.push_str("\\t"),
|
||||
'\n' => out.push_str("\\n"),
|
||||
'\r' => out.push_str("\\r"),
|
||||
_ => out.push(c),
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
pub mod stdio {
|
||||
//! READY / TOK\t<text> / EOQ\t<n>\t<elapsed_ms> protocol shared with
|
||||
//! `luminal-benchmarks/sut/rust.py`.
|
||||
|
||||
use super::escape_tok;
|
||||
use std::io::{BufRead, Write};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Print the one-shot READY line that tells the harness init is done.
|
||||
pub fn ready() {
|
||||
let mut out = std::io::stdout().lock();
|
||||
let _ = writeln!(out, "READY");
|
||||
let _ = out.flush();
|
||||
}
|
||||
|
||||
/// One generated token, emitted as it's produced.
|
||||
pub fn emit_tok(text: &str) {
|
||||
let mut out = std::io::stdout().lock();
|
||||
let _ = writeln!(out, "TOK\t{}", escape_tok(text));
|
||||
let _ = out.flush();
|
||||
}
|
||||
|
||||
/// End-of-query marker with the total tokens produced for this prompt
|
||||
/// and the elapsed time. (The harness uses LoadGen's own timestamps,
|
||||
/// but the line is required to mark the boundary.)
|
||||
pub fn emit_eoq(n_tokens: usize, elapsed: Duration) {
|
||||
let mut out = std::io::stdout().lock();
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"EOQ\t{}\t{:.3}",
|
||||
n_tokens,
|
||||
elapsed.as_secs_f64() * 1e3
|
||||
);
|
||||
let _ = out.flush();
|
||||
}
|
||||
|
||||
/// Read one prompt line. Blank lines are skipped (the harness writes
|
||||
/// one prompt per non-empty line). Returns `None` on EOF / read error.
|
||||
pub fn next_prompt<R: BufRead>(reader: &mut R, buf: &mut String) -> Option<String> {
|
||||
loop {
|
||||
buf.clear();
|
||||
match reader.read_line(buf) {
|
||||
Ok(0) => return None,
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
eprintln!("stdio read error: {e}");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
let prompt = buf
|
||||
.trim_end_matches('\n')
|
||||
.trim_end_matches('\r')
|
||||
.to_string();
|
||||
if !prompt.is_empty() {
|
||||
return Some(prompt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Drive the per-prompt protocol: print READY, then for each stdin
|
||||
/// line call `run` with the prompt and post-call emit EOQ. `run`
|
||||
/// returns `(n_tokens_generated, prompt_elapsed)` and is expected to
|
||||
/// have called `emit_tok` once per generated token.
|
||||
pub fn serve(mut run: impl FnMut(&str) -> (usize, Duration)) {
|
||||
ready();
|
||||
let stdin = std::io::stdin();
|
||||
let mut handle = stdin.lock();
|
||||
let mut buf = String::new();
|
||||
while let Some(prompt) = next_prompt(&mut handle, &mut buf) {
|
||||
let (n, elapsed) = run(&prompt);
|
||||
emit_eoq(n, elapsed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Pull the standard benchmark knobs from env vars in one call.
|
||||
pub struct BenchEnv {
|
||||
pub gen_tokens: usize,
|
||||
pub search_graphs: usize,
|
||||
}
|
||||
|
||||
impl BenchEnv {
|
||||
pub fn from_env(default_gen_tokens: usize, default_search_graphs: usize) -> Self {
|
||||
Self {
|
||||
gen_tokens: env_usize("GEN_TOKENS", default_gen_tokens),
|
||||
search_graphs: env_usize("SEARCH_GRAPHS", default_search_graphs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
2
examples/flux2/.gitignore
vendored
2
examples/flux2/.gitignore
vendored
@@ -1,2 +0,0 @@
|
||||
*.png
|
||||
reference
|
||||
@@ -1,31 +0,0 @@
|
||||
[package]
|
||||
name = "flux2"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[[bin]]
|
||||
name = "flux2"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
luminal_tracing = { path = "../../crates/luminal_tracing" }
|
||||
tokenizers = "0.21"
|
||||
tracing = "0.1.43"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
||||
# HuggingFace model download
|
||||
hf-hub = { version = "0.4", default-features = false, features = ["rustls-tls", "ureq"] }
|
||||
safetensors = "0.7.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
half = { version = "2.7.1", features = ["bytemuck"] }
|
||||
bytemuck = "1.24.0"
|
||||
memmap2 = "0.9.9"
|
||||
|
||||
# PNG output + RNG
|
||||
png = "0.18"
|
||||
rand = "0.9"
|
||||
rand_distr = "0.5"
|
||||
@@ -1,62 +0,0 @@
|
||||
//! HuggingFace download helpers for the Flux 2 multi-folder repo layout.
|
||||
//!
|
||||
//! Unlike the LLM examples we deliberately do **not** combine shards into a
|
||||
//! single FP32 file: Flux 2's transformer weights are 70 GB on disk in BF16
|
||||
//! (~140 GB if upcast to F32) and the text encoder is ~48 GB. We download the
|
||||
//! original BF16 / F32 shards as-is and load them directly via
|
||||
//! `runtime.load_safetensors`, which already supports BF16 in luminal_cuda_lite.
|
||||
|
||||
use hf_hub::api::sync::Api;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
const REPO_ID: &str = "black-forest-labs/FLUX.2-dev";
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct SafetensorsIndex {
|
||||
weight_map: HashMap<String, String>,
|
||||
}
|
||||
|
||||
fn api() -> Result<hf_hub::api::sync::ApiRepo, Box<dyn std::error::Error>> {
|
||||
let api = Api::new()?;
|
||||
Ok(api.model(REPO_ID.to_string()))
|
||||
}
|
||||
|
||||
/// Download a single file from a sub-folder under the Flux 2 repo, returning
|
||||
/// the local cache path.
|
||||
pub fn fetch(path_in_repo: &str) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
Ok(api()?.get(path_in_repo)?)
|
||||
}
|
||||
|
||||
/// Resolve every shard listed in a sub-folder's safetensors index, downloading
|
||||
/// what's missing. Returns the absolute local paths in shard order.
|
||||
pub fn fetch_sharded(folder: &str) -> Result<Vec<PathBuf>, Box<dyn std::error::Error>> {
|
||||
let index_path = fetch(&format!(
|
||||
"{folder}/diffusion_pytorch_model.safetensors.index.json"
|
||||
))
|
||||
.or_else(|_| fetch(&format!("{folder}/model.safetensors.index.json")))?;
|
||||
|
||||
let raw = std::fs::read_to_string(&index_path)?;
|
||||
let idx: SafetensorsIndex = serde_json::from_str(&raw)?;
|
||||
|
||||
let mut files: Vec<String> = idx.weight_map.values().cloned().collect();
|
||||
files.sort();
|
||||
files.dedup();
|
||||
|
||||
let mut paths = Vec::with_capacity(files.len());
|
||||
for f in files {
|
||||
paths.push(fetch(&format!("{folder}/{f}"))?);
|
||||
}
|
||||
Ok(paths)
|
||||
}
|
||||
|
||||
/// Convenience wrapper for the small VAE (single safetensors file).
|
||||
pub fn fetch_vae() -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
fetch("vae/diffusion_pytorch_model.safetensors")
|
||||
}
|
||||
|
||||
/// Tokenizer JSON (Pixtral / Mistral tokenizer used by Flux 2's text encoder).
|
||||
pub fn fetch_tokenizer() -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
fetch("tokenizer/tokenizer.json")
|
||||
}
|
||||
@@ -1,556 +0,0 @@
|
||||
//! Flux 2 (`black-forest-labs/FLUX.2-dev`) text-to-image example.
|
||||
//!
|
||||
//! End-to-end pipeline:
|
||||
//! ```text
|
||||
//! prompt → tokenize → Mistral3 text encoder ─► text features (S_txt, 15360)
|
||||
//! noise latent (S_img, 128) ─► transformer (28× denoising) ─► clean latent
|
||||
//! latent ─► VAE decoder ─► (3, H, W) image ─► PNG
|
||||
//! ```
|
||||
//!
|
||||
//! ## Optional env
|
||||
//! * `FLUX2_NUM_LAYERS` / `FLUX2_NUM_SINGLE_LAYERS` (optional) — override
|
||||
//! the default 8 + 48 transformer block counts. The default count
|
||||
//! overflows the 96 GB GPU because there's no live-range buffer
|
||||
//! reuse in `CudaRuntime::allocate_intermediate_buffers` — every
|
||||
//! intermediate is alive for the whole forward pass. `1 + 1` runs
|
||||
//! the full pipeline end-to-end at 1024² in well under a minute and
|
||||
//! is the right setting for plumbing-validation. Higher counts
|
||||
//! (e.g. `8 + 16`) work but use proportionally more memory.
|
||||
//!
|
||||
//! ## Memory plan
|
||||
//! GPU is 96 GB; transformer (60 GB BF16) + text encoder (33 GB BF16) +
|
||||
//! VAE (336 MB) won't all fit. The full pipeline keeps **at most one** large
|
||||
//! model resident at a time:
|
||||
//! 1. Load text encoder, encode prompt, **drop the runtime** to free 33 GB.
|
||||
//! 2. Load transformer, run the diffusion loop, **drop the runtime**.
|
||||
//! 3. Load VAE, decode, dump PNG.
|
||||
|
||||
mod hf;
|
||||
#[allow(dead_code)]
|
||||
mod quant;
|
||||
mod scheduler;
|
||||
mod text_encoder;
|
||||
mod transformer;
|
||||
mod vae;
|
||||
|
||||
use std::fs::File;
|
||||
use std::io::BufWriter;
|
||||
use std::time::Instant;
|
||||
|
||||
use luminal::graph::BuildSearchSpaceOptions;
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use rand::{Rng, SeedableRng, rngs::StdRng};
|
||||
use rand_distr::StandardNormal;
|
||||
use scheduler::{SchedulerConfig, compute_mu, euler_step, make_schedule};
|
||||
use tokenizers::Tokenizer;
|
||||
use vae::{LATENT_CHANNELS, VAE_DOWNSAMPLE, VaeDecoder};
|
||||
|
||||
fn env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
fn env_f32(name: &str, default: f32) -> f32 {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
/// Override-able via `TEXT_LEN=N` for testing. Diffusers' Flux 2 pipeline
|
||||
/// pads to 512; smaller works for the text encoder's transformer compile to
|
||||
/// fit in fewer GPU temp buffers during search.
|
||||
const DEFAULT_TEXT_LEN: usize = 512;
|
||||
|
||||
fn text_len() -> usize {
|
||||
env_usize("TEXT_LEN", DEFAULT_TEXT_LEN)
|
||||
}
|
||||
const DEFAULT_GUIDANCE: f32 = 2.5;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let prompt = std::env::args()
|
||||
.nth(1)
|
||||
.unwrap_or_else(|| "a cat in a hat".to_string());
|
||||
let width = env_usize("WIDTH", 1024);
|
||||
let height = env_usize("HEIGHT", 1024);
|
||||
let steps = env_usize("STEPS", 28);
|
||||
let guidance = env_f32("GUIDANCE", DEFAULT_GUIDANCE);
|
||||
|
||||
println!("Prompt: {prompt}");
|
||||
println!("Resolution: {width}x{height}, steps={steps}, guidance={guidance}");
|
||||
|
||||
run_full_pipeline(&prompt, width, height, steps, guidance)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Text encoder path (compute the (S_txt, 15360) prompt features)
|
||||
// =============================================================================
|
||||
|
||||
fn tokenize_prompt(
|
||||
tokenizer: &Tokenizer,
|
||||
prompt: &str,
|
||||
text_len: usize,
|
||||
) -> Result<(Vec<i32>, usize), Box<dyn std::error::Error>> {
|
||||
// Format the chat template (system + user) the way Flux 2's pipeline does,
|
||||
// then tokenize. The Mistral 3 tokenizer treats `[SYSTEM_PROMPT]`,
|
||||
// `[/SYSTEM_PROMPT]`, `[INST]`, `[/INST]` as added tokens, so they
|
||||
// round-trip as single ids; the `<s>` BOS is added by the tokenizer
|
||||
// (`add_bos_token = true`).
|
||||
let formatted = text_encoder::format_chat(text_encoder::SYSTEM_MESSAGE, prompt);
|
||||
let encoded = tokenizer
|
||||
.encode(formatted, true)
|
||||
.map_err(|e| format!("tokenize failed: {e}"))?;
|
||||
let mut ids: Vec<i32> = encoded.get_ids().iter().map(|&i| i as i32).collect();
|
||||
let real_len = ids.len();
|
||||
if real_len > text_len {
|
||||
ids.truncate(text_len);
|
||||
} else {
|
||||
// Right-pad to `text_len` with Mistral's `<pad>` token (id 11).
|
||||
// The previous padding value of 0 (= `<unk>`) silently gave
|
||||
// every padding position a different embedding than diffusers
|
||||
// — divergence appeared at exactly position `real_len` and
|
||||
// compounded through 30 layers, leaving prompt_embeds with
|
||||
// cos_sim ≈ 0.65 against the reference. See
|
||||
// `tokenizer.json` added_tokens_decoder: id=11 is `<pad>`.
|
||||
ids.resize(text_len, 11);
|
||||
}
|
||||
Ok((ids, real_len.min(text_len)))
|
||||
}
|
||||
|
||||
fn run_text_encoder(prompt: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
|
||||
println!("\n[1/3] Resolving text encoder weights...");
|
||||
let tok_path = hf::fetch_tokenizer()?;
|
||||
let te_paths = hf::fetch_sharded("text_encoder")?;
|
||||
|
||||
println!("Loading tokenizer...");
|
||||
let tokenizer = Tokenizer::from_file(&tok_path).map_err(|e| format!("tokenizer: {e}"))?;
|
||||
let text_len = text_len();
|
||||
let (ids, real_len) = tokenize_prompt(&tokenizer, prompt, text_len)?;
|
||||
println!(
|
||||
" prompt → {} ids ({} real, padded to {})",
|
||||
ids.len(),
|
||||
real_len,
|
||||
text_len,
|
||||
);
|
||||
|
||||
println!("Building text encoder graph...");
|
||||
let mut cx = Graph::default();
|
||||
let input_ids = cx
|
||||
.named_tensor("__input_ids", text_len)
|
||||
.as_dtype(DType::Int);
|
||||
let pos_ids = cx.named_tensor("__pos_ids", text_len).as_dtype(DType::Int);
|
||||
// Attention mask: 1 for real tokens (positions 0..real_len), 0 for
|
||||
// padding. Mistral 3 self-attention masks padding keys so padding
|
||||
// queries only attend to the real prefix; without it our padding
|
||||
// hidden states drift wildly from diffusers (cos_sim ~0.65 on the
|
||||
// 15360-dim text features even when token IDs match exactly).
|
||||
let attention_mask = cx
|
||||
.named_tensor("__attention_mask", text_len)
|
||||
.as_dtype(DType::F32);
|
||||
let encoder = text_encoder::Mistral3TextEncoder::init(&mut cx);
|
||||
let features = encoder.forward(input_ids, pos_ids, attention_mask).output();
|
||||
// Memory-budget enforcement is opt-in (the estimator over-counts; see
|
||||
// the matching comment in `run_vae_only`). Set `TEXT_MEM_GIB` to opt in.
|
||||
if let Ok(g) = std::env::var("TEXT_MEM_GIB").and_then(|s| {
|
||||
s.parse::<usize>()
|
||||
.map_err(|_| std::env::VarError::NotPresent)
|
||||
}) {
|
||||
cx.build_search_space_with_options::<CudaRuntime>(
|
||||
BuildSearchSpaceOptions::new().max_memory_gib(g),
|
||||
);
|
||||
} else {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
}
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
|
||||
println!(
|
||||
"Loading {} text encoder shards (~48 GB BF16)...",
|
||||
te_paths.len()
|
||||
);
|
||||
let t0 = Instant::now();
|
||||
for p in &te_paths {
|
||||
runtime.load_safetensors(&cx, p.to_str().unwrap());
|
||||
}
|
||||
println!(" loaded in {:.1}s", t0.elapsed().as_secs_f64());
|
||||
|
||||
runtime.set_data(input_ids, ids);
|
||||
runtime.set_data(pos_ids, (0..text_len as i32).collect::<Vec<_>>());
|
||||
let mask: Vec<f32> = (0..text_len)
|
||||
.map(|i| if i < real_len { 1.0_f32 } else { 0.0_f32 })
|
||||
.collect();
|
||||
runtime.set_data(attention_mask, mask);
|
||||
|
||||
println!("Compiling text encoder...");
|
||||
let t0 = Instant::now();
|
||||
runtime = cx.search(runtime, env_usize("SEARCH_ITERS", 5));
|
||||
println!(" compile done in {:.1}s", t0.elapsed().as_secs_f64());
|
||||
|
||||
println!("Encoding prompt...");
|
||||
let t0 = Instant::now();
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let out = runtime.get_f32(features);
|
||||
println!(" encode done in {:.1}s", t0.elapsed().as_secs_f64());
|
||||
println!(
|
||||
" features: len={} (= {} × {})",
|
||||
out.len(),
|
||||
text_len,
|
||||
text_encoder::OUTPUT_DIM,
|
||||
);
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Full pipeline: text → diffusion → VAE → PNG
|
||||
// =============================================================================
|
||||
|
||||
fn run_full_pipeline(
|
||||
prompt: &str,
|
||||
width: usize,
|
||||
height: usize,
|
||||
steps: usize,
|
||||
guidance: f32,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// VAE latent grid: (LATENT_CHANNELS=32, h_lat, w_lat)
|
||||
let h_lat = height / VAE_DOWNSAMPLE;
|
||||
let w_lat = width / VAE_DOWNSAMPLE;
|
||||
// Transformer "pack" grid: (IN_CHANNELS=128, h_pack, w_pack) = (32*4, h_lat/2, w_lat/2).
|
||||
// The diffusers pipeline folds 2×2 spatial pixels into the channel axis
|
||||
// before the transformer (`_patchify_latents`) and undoes it after
|
||||
// (`_unpatchify_latents`), so the transformer sees `(S_img, 128)` tokens
|
||||
// where `S_img = (H/16) * (W/16)`.
|
||||
assert!(
|
||||
h_lat.is_multiple_of(2) && w_lat.is_multiple_of(2),
|
||||
"WIDTH and HEIGHT must be multiples of 16 (got {width}x{height})",
|
||||
);
|
||||
let h_pack = h_lat / 2;
|
||||
let w_pack = w_lat / 2;
|
||||
let s_img = h_pack * w_pack;
|
||||
let s_txt = text_len();
|
||||
|
||||
// ── 1. TEXT ENCODE ─────────────────────────────────────────────────────────
|
||||
let text_features = run_text_encoder(prompt)?;
|
||||
assert_eq!(text_features.len(), s_txt * text_encoder::OUTPUT_DIM);
|
||||
|
||||
// ── 2. DIFFUSION LOOP ──────────────────────────────────────────────────────
|
||||
println!("\n[2/3] Resolving transformer weights...");
|
||||
let tx_paths = hf::fetch_sharded("transformer")?;
|
||||
println!(
|
||||
" {} transformer shards downloaded ({:.1} GB total)",
|
||||
tx_paths.len(),
|
||||
tx_paths
|
||||
.iter()
|
||||
.map(|p| std::fs::metadata(p).map(|m| m.len()).unwrap_or(0))
|
||||
.sum::<u64>() as f64
|
||||
/ 1e9,
|
||||
);
|
||||
|
||||
let cfg = SchedulerConfig::default();
|
||||
let image_seq_len = s_img;
|
||||
let mu = compute_mu(&cfg, image_seq_len);
|
||||
let (sigmas, timesteps) = make_schedule(&cfg, steps, mu);
|
||||
println!(" scheduler: mu={mu:.4}, {} steps", timesteps.len());
|
||||
|
||||
// Pre-compute RoPE tables (host-side; these are constant per resolution).
|
||||
// Grid is the post-pack `(h_pack, w_pack)`, matching what the transformer
|
||||
// and diffusers' `_prepare_latent_ids` see.
|
||||
let (rope_cos, rope_sin) = transformer::build_rope_tables(s_txt, h_pack, w_pack);
|
||||
let s_total = s_txt + s_img;
|
||||
assert_eq!(rope_cos.len(), s_total * transformer::HEAD_DIM);
|
||||
|
||||
// Initial noise latent in (S_img, IN_CHANNELS) layout.
|
||||
let mut rng = StdRng::seed_from_u64(env_usize("SEED", 0) as u64);
|
||||
let mut latent: Vec<f32> = (0..s_img * transformer::IN_CHANNELS)
|
||||
.map(|_| rng.sample::<f32, _>(StandardNormal))
|
||||
.collect();
|
||||
|
||||
println!("Building transformer graph...");
|
||||
let mut cx = Graph::default();
|
||||
// Inputs that change per diffusion step.
|
||||
let latent_in = cx.named_tensor("__latent", (s_img, transformer::IN_CHANNELS));
|
||||
let timestep_in = cx.named_tensor("__timestep", 1);
|
||||
// Inputs that are constant across the whole diffusion loop. `.persist()`
|
||||
// marks them as outputs so their buffers survive between successive
|
||||
// `runtime.execute()` calls; without this the runtime treats them as
|
||||
// transient intermediates and a second `execute()` reads freed memory
|
||||
// (manifests as `CUDA_ERROR_ILLEGAL_ADDRESS` on the post-kernel sync).
|
||||
let text_in = cx
|
||||
.named_tensor("__text", (s_txt, text_encoder::OUTPUT_DIM))
|
||||
.persist();
|
||||
let cos_in = cx
|
||||
.named_tensor("__rope_cos", (s_total, transformer::HEAD_DIM))
|
||||
.persist();
|
||||
let sin_in = cx
|
||||
.named_tensor("__rope_sin", (s_total, transformer::HEAD_DIM))
|
||||
.persist();
|
||||
let guidance_in = cx.named_tensor("__guidance", 1).persist();
|
||||
|
||||
let model = transformer::Flux2Transformer::init(&mut cx);
|
||||
let velocity = model
|
||||
.forward(latent_in, text_in, cos_in, sin_in, timestep_in, guidance_in)
|
||||
.output();
|
||||
|
||||
println!("Building search space (this is the long step — many minutes for the full DiT)...");
|
||||
if let Ok(g) = std::env::var("TX_MEM_GIB").and_then(|s| {
|
||||
s.parse::<usize>()
|
||||
.map_err(|_| std::env::VarError::NotPresent)
|
||||
}) {
|
||||
cx.build_search_space_with_options::<CudaRuntime>(
|
||||
BuildSearchSpaceOptions::new().max_memory_gib(g),
|
||||
);
|
||||
} else {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
}
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
|
||||
println!(
|
||||
"Loading {} transformer shards (~{:.1} GB BF16)...",
|
||||
tx_paths.len(),
|
||||
tx_paths
|
||||
.iter()
|
||||
.map(|p| std::fs::metadata(p).map(|m| m.len()).unwrap_or(0))
|
||||
.sum::<u64>() as f64
|
||||
/ 1e9,
|
||||
);
|
||||
let t0 = Instant::now();
|
||||
for p in &tx_paths {
|
||||
runtime.load_safetensors(&cx, p.to_str().unwrap());
|
||||
}
|
||||
println!(" loaded in {:.1}s", t0.elapsed().as_secs_f64());
|
||||
|
||||
// Set the inputs that don't change across steps.
|
||||
runtime.set_data(text_in, text_features);
|
||||
runtime.set_data(cos_in, rope_cos);
|
||||
runtime.set_data(sin_in, rope_sin);
|
||||
// Match diffusers' transformer call signature:
|
||||
// `timestep=timestep / 1000` (0..1 range, sigma-like)
|
||||
// `guidance=guidance` (raw guidance_scale, e.g. 2.5)
|
||||
// The previous code multiplied both by 1000, making the
|
||||
// `timesteps_proj` argument saturate.
|
||||
runtime.set_data(guidance_in, vec![guidance]);
|
||||
|
||||
// First-step dummy values so search() has shapes/data to profile against.
|
||||
runtime.set_data(latent_in, latent.clone());
|
||||
runtime.set_data(timestep_in, vec![timesteps[0] / 1000.0]);
|
||||
|
||||
println!("Compiling transformer (search)...");
|
||||
let t0 = Instant::now();
|
||||
if let Ok(seed) = std::env::var("TX_SEARCH_SEED")
|
||||
.and_then(|s| s.parse::<u64>().map_err(|_| std::env::VarError::NotPresent))
|
||||
{
|
||||
use rand::SeedableRng;
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(seed);
|
||||
let opts = luminal::graph::SearchOptions::new(env_usize("SEARCH_ITERS", 5));
|
||||
runtime = cx.search_options(runtime, opts, &mut rng);
|
||||
} else {
|
||||
runtime = cx.search(runtime, env_usize("SEARCH_ITERS", 5));
|
||||
}
|
||||
println!(" compile done in {:.1}s", t0.elapsed().as_secs_f64());
|
||||
|
||||
println!("Running diffusion loop ({} steps)...", timesteps.len());
|
||||
for (i, &t) in timesteps.iter().enumerate() {
|
||||
let step_start = Instant::now();
|
||||
runtime.set_data(latent_in, latent.clone());
|
||||
runtime.set_data(timestep_in, vec![t / 1000.0]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let v = runtime.get_f32(velocity);
|
||||
// Euler integrate: latent += (sigma_next - sigma) * v
|
||||
euler_step(&mut latent, &v, sigmas[i], sigmas[i + 1]);
|
||||
println!(
|
||||
" step {:>2}/{}: t={:>8.2}, σ {:.4} → {:.4} ({:.1}s)",
|
||||
i + 1,
|
||||
timesteps.len(),
|
||||
t,
|
||||
sigmas[i],
|
||||
sigmas[i + 1],
|
||||
step_start.elapsed().as_secs_f64(),
|
||||
);
|
||||
}
|
||||
|
||||
// Drop the transformer runtime to free its weights before loading the VAE.
|
||||
drop(runtime);
|
||||
drop(cx);
|
||||
|
||||
// ── 3. VAE DECODE ──────────────────────────────────────────────────────────
|
||||
println!("\n[3/3] Decoding latent through VAE...");
|
||||
let vae_path = hf::fetch_vae()?;
|
||||
|
||||
// Convert the diffusion output `(S_img, 128)` to the VAE's input shape
|
||||
// `(32, h_lat, w_lat)` on the host. Mirrors the diffusers pipeline:
|
||||
// 1. _unpack_latents_with_ids: (S_img, 128) -> (128, h_pack, w_pack)
|
||||
// 2. BN inverse: x = x * bn_std + bn_mean (per-channel)
|
||||
// 3. _unpatchify_latents: (128, h_pack, w_pack) -> (32, h_lat, w_lat)
|
||||
let bn_mean = read_safetensors_f32(&vae_path, "bn.running_mean")?;
|
||||
let bn_var = read_safetensors_f32(&vae_path, "bn.running_var")?;
|
||||
assert_eq!(bn_mean.len(), transformer::IN_CHANNELS);
|
||||
assert_eq!(bn_var.len(), transformer::IN_CHANNELS);
|
||||
const BN_EPS: f32 = 1e-4; // matches vae/config.json batch_norm_eps=0.0001
|
||||
let bn_std: Vec<f32> = bn_var.iter().map(|v| (v + BN_EPS).sqrt()).collect();
|
||||
|
||||
let unpacked = unpack_packed_host(&latent, transformer::IN_CHANNELS, h_pack, w_pack);
|
||||
let denormed = bn_inverse_host(&unpacked, &bn_mean, &bn_std, transformer::IN_CHANNELS);
|
||||
let vae_input = unpatchify_host(&denormed, LATENT_CHANNELS, h_pack, w_pack);
|
||||
assert_eq!(vae_input.len(), LATENT_CHANNELS * h_lat * w_lat);
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let latent_in = cx.named_tensor("latent", (LATENT_CHANNELS, h_lat, w_lat));
|
||||
let decoder = VaeDecoder::new(&mut cx);
|
||||
let out = decoder.forward(latent_in).output();
|
||||
if let Ok(g) = std::env::var("VAE_MEM_GIB").and_then(|s| {
|
||||
s.parse::<usize>()
|
||||
.map_err(|_| std::env::VarError::NotPresent)
|
||||
}) {
|
||||
cx.build_search_space_with_options::<CudaRuntime>(
|
||||
BuildSearchSpaceOptions::new().max_memory_gib(g),
|
||||
);
|
||||
} else {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
}
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
runtime.load_safetensors(&cx, vae_path.to_str().unwrap());
|
||||
runtime.set_data(latent_in, vae_input);
|
||||
runtime = cx.search(runtime, env_usize("SEARCH_ITERS", 5));
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let img = runtime.get_f32(out);
|
||||
// VaeDecoder output is in roughly [-1, 1] range. Diffusers'
|
||||
// ImageProcessor.postprocess does `((x + 1) / 2).clamp(0, 1)` for
|
||||
// output_type="pt".
|
||||
save_png("out.png", &img, width, height)?;
|
||||
println!("Wrote out.png");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Host-side pipeline glue: pack/unpack/BN/unpatchify between the transformer
|
||||
// and the VAE. These mirror diffusers' Flux2Pipeline static methods exactly.
|
||||
// =============================================================================
|
||||
|
||||
/// Inverse of `_pack_latents`: `(S_img, C) -> (C, h_pack, w_pack)` row-major.
|
||||
fn unpack_packed_host(packed: &[f32], c: usize, h_pack: usize, w_pack: usize) -> Vec<f32> {
|
||||
let s_img = h_pack * w_pack;
|
||||
assert_eq!(packed.len(), s_img * c);
|
||||
let mut out = vec![0.0_f32; c * s_img];
|
||||
for hi in 0..h_pack {
|
||||
for wi in 0..w_pack {
|
||||
let token = hi * w_pack + wi;
|
||||
for ci in 0..c {
|
||||
out[ci * s_img + token] = packed[token * c + ci];
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// `latent[c, *] = latent[c, *] * std[c] + mean[c]`. In-place by-copy.
|
||||
fn bn_inverse_host(latent: &[f32], mean: &[f32], std: &[f32], c: usize) -> Vec<f32> {
|
||||
let hw = latent.len() / c;
|
||||
assert_eq!(mean.len(), c);
|
||||
assert_eq!(std.len(), c);
|
||||
let mut out = vec![0.0_f32; latent.len()];
|
||||
for ci in 0..c {
|
||||
let m = mean[ci];
|
||||
let s = std[ci];
|
||||
for i in 0..hw {
|
||||
out[ci * hw + i] = latent[ci * hw + i] * s + m;
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// `_unpatchify_latents`: `(C*4, h_pack, w_pack) -> (C, 2*h_pack, 2*w_pack)`.
|
||||
///
|
||||
/// Diffusers does:
|
||||
/// ```python
|
||||
/// latents.reshape(B, C, 2, 2, H, W).permute(0, 1, 4, 2, 5, 3).reshape(B, C, 2H, 2W)
|
||||
/// ```
|
||||
/// So input channel `c*4 + ph*2 + pw` (with ph, pw in {0,1}) maps to output
|
||||
/// position `(c, hi*2 + ph, wi*2 + pw)`.
|
||||
fn unpatchify_host(packed: &[f32], c_out: usize, h_pack: usize, w_pack: usize) -> Vec<f32> {
|
||||
assert_eq!(packed.len(), c_out * 4 * h_pack * w_pack);
|
||||
let h_lat = h_pack * 2;
|
||||
let w_lat = w_pack * 2;
|
||||
let mut out = vec![0.0_f32; c_out * h_lat * w_lat];
|
||||
for c in 0..c_out {
|
||||
for ph in 0..2 {
|
||||
for pw in 0..2 {
|
||||
let in_c = c * 4 + ph * 2 + pw;
|
||||
for hi in 0..h_pack {
|
||||
for wi in 0..w_pack {
|
||||
let in_idx = in_c * h_pack * w_pack + hi * w_pack + wi;
|
||||
let out_h = hi * 2 + ph;
|
||||
let out_w = wi * 2 + pw;
|
||||
let out_idx = c * h_lat * w_lat + out_h * w_lat + out_w;
|
||||
out[out_idx] = packed[in_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Read one F32 tensor by name from a single-file safetensors archive.
|
||||
fn read_safetensors_f32(
|
||||
path: &std::path::Path,
|
||||
name: &str,
|
||||
) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
|
||||
use std::io::{Read, Seek, SeekFrom};
|
||||
let mut file = std::fs::File::open(path)?;
|
||||
let mut header_len_bytes = [0u8; 8];
|
||||
file.read_exact(&mut header_len_bytes)?;
|
||||
let header_len = u64::from_le_bytes(header_len_bytes) as usize;
|
||||
let mut header_bytes = vec![0u8; header_len];
|
||||
file.read_exact(&mut header_bytes)?;
|
||||
let header: serde_json::Value = serde_json::from_slice(&header_bytes)?;
|
||||
let info = header
|
||||
.get(name)
|
||||
.ok_or_else(|| format!("safetensors: tensor '{name}' not found in {path:?}"))?;
|
||||
let dtype = info["dtype"].as_str().unwrap_or("");
|
||||
if dtype != "F32" {
|
||||
return Err(format!("safetensors: tensor '{name}' has dtype {dtype}, want F32").into());
|
||||
}
|
||||
let offsets = &info["data_offsets"];
|
||||
let start = offsets[0].as_u64().unwrap();
|
||||
let end = offsets[1].as_u64().unwrap();
|
||||
let n_bytes = (end - start) as usize;
|
||||
file.seek(SeekFrom::Start(8 + header_len as u64 + start))?;
|
||||
let mut buf = vec![0u8; n_bytes];
|
||||
file.read_exact(&mut buf)?;
|
||||
Ok(buf
|
||||
.chunks_exact(4)
|
||||
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn save_png(path: &str, chw: &[f32], w: usize, h: usize) -> Result<(), Box<dyn std::error::Error>> {
|
||||
assert_eq!(chw.len(), 3 * h * w, "save_png: shape mismatch");
|
||||
let mut bytes = vec![0u8; 3 * h * w];
|
||||
for y in 0..h {
|
||||
for x in 0..w {
|
||||
for c in 0..3 {
|
||||
let v = chw[c * h * w + y * w + x];
|
||||
let v = ((v + 1.0) * 0.5 * 255.0).clamp(0.0, 255.0);
|
||||
bytes[(y * w + x) * 3 + c] = v as u8;
|
||||
}
|
||||
}
|
||||
}
|
||||
let file = File::create(path)?;
|
||||
let bw = BufWriter::new(file);
|
||||
let mut encoder = png::Encoder::new(bw, w as u32, h as u32);
|
||||
encoder.set_color(png::ColorType::Rgb);
|
||||
encoder.set_depth(png::BitDepth::Eight);
|
||||
encoder.write_header()?.write_image_data(&bytes)?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,214 +0,0 @@
|
||||
//! NVFP4 dequant + linear, modelled entirely in HLIR.
|
||||
//!
|
||||
//! Mirrors the pattern used by `luminal_tron::dequant_matmul` for GPTQ: declare
|
||||
//! the packed weights as named tensors with their **logical** dtype, then
|
||||
//! express the dequantization as ordinary HLIR ops (cast, expand, repeat,
|
||||
//! multiply). No custom ops, no opaque kernels — the optimizer sees the entire
|
||||
//! `cast → broadcast → multiply → matmul` chain and can fuse it on its own.
|
||||
//!
|
||||
//! ## File layout (NVIDIA Model Optimizer NVFP4)
|
||||
//!
|
||||
//! For each quantized linear layer, four tensors are stored in the safetensors
|
||||
//! file:
|
||||
//!
|
||||
//! | name suffix | dtype | logical shape | meaning |
|
||||
//! |------------------|-----------|----------------|--------------------------------------|
|
||||
//! | `weight` | F4E2M1 | (out, in) | 4-bit signed weight, 2 packed/byte |
|
||||
//! | `weight_scale` | F8E4M3 | (out, in/16) | per-block FP8 scale (block size 16) |
|
||||
//! | `weight_scale_2` | F32 | (1,) | per-tensor outer scale |
|
||||
//! | `input_scale` | F32 | (1,) | for activation requant (unused here) |
|
||||
//!
|
||||
//! On disk the FP4 weight is recorded as `dtype=U8 shape=[out, in/2]` (two
|
||||
//! values per byte). luminal's bit-aware sizing computes byte counts from the
|
||||
//! **logical** dtype, so declaring it as `F4E2M1 shape=(out, in)` matches the
|
||||
//! same byte count and the safetensors loader uploads the raw bytes verbatim.
|
||||
//!
|
||||
//! ## Reconstruction (the math we model)
|
||||
//!
|
||||
//! real_W[o, i] = fp4_to_f(weight[o, i])
|
||||
//! * fp8_to_f(weight_scale[o, i / 16])
|
||||
//! * weight_scale_2
|
||||
//!
|
||||
//! Step by step in HLIR:
|
||||
//! 1. `cast(weight, target)` unpacks 4-bit packed bytes → target dtype via
|
||||
//! the existing sub-byte path in `KernelCast`.
|
||||
//! 2. `cast(weight_scale, target)` converts FP8_E4M3 → target dtype via the
|
||||
//! standard cast kernel (CUDA's `__nv_fp8_e4m3` has built-in conversion).
|
||||
//! 3. The (out, in/16) scale is broadcast to (out, in) by inserting a length-
|
||||
//! 16 axis (`expand_dim`) and merging it back, so each FP8 scale applies
|
||||
//! to a contiguous run of 16 input columns.
|
||||
//! 4. `weight_scale_2` (a 1-element tensor) is broadcast to (out, in) via
|
||||
//! `expand_lhs` + `repeat`.
|
||||
//! 5. The three are multiplied, and the result is fed to a standard matmul.
|
||||
|
||||
use luminal::dtype::DType;
|
||||
use luminal::graph::Graph;
|
||||
use luminal::prelude::GraphTensor;
|
||||
use luminal::shape::Expression;
|
||||
|
||||
/// NVFP4 block size along the input dimension (matches NVIDIA modelopt /
|
||||
/// Blackwell hardware).
|
||||
pub const NVFP4_BLOCK: usize = 16;
|
||||
|
||||
/// Persistent tensor handles for one NVFP4-quantized linear layer.
|
||||
pub struct Nvfp4Linear {
|
||||
/// Packed FP4 weight, declared shape `(out, in)`, dtype `F4E2M1`.
|
||||
pub weight: GraphTensor,
|
||||
/// Per-block FP8 scale, declared shape `(out, in / 16)`, dtype `F8E4M3`.
|
||||
pub weight_scale: GraphTensor,
|
||||
/// Per-tensor F32 outer scale, declared shape `(1,)`.
|
||||
pub weight_scale_2: GraphTensor,
|
||||
}
|
||||
|
||||
impl Nvfp4Linear {
|
||||
/// Declare the persistent inputs for one NVFP4 linear layer.
|
||||
///
|
||||
/// `out_dim` and `in_dim` are the **unpacked** weight dimensions (matching
|
||||
/// PyTorch `Linear.weight: (out, in)` semantics). `in_dim` must be a
|
||||
/// multiple of [`NVFP4_BLOCK`].
|
||||
pub fn new(prefix: &str, out_dim: usize, in_dim: usize, cx: &mut Graph) -> Self {
|
||||
assert!(
|
||||
in_dim.is_multiple_of(NVFP4_BLOCK),
|
||||
"in_dim ({in_dim}) must be a multiple of NVFP4 block size ({NVFP4_BLOCK})",
|
||||
);
|
||||
let in_blocks = in_dim / NVFP4_BLOCK;
|
||||
Self {
|
||||
weight: cx
|
||||
.named_tensor(format!("{prefix}.weight"), (out_dim, in_dim))
|
||||
.as_dtype(DType::F4E2M1)
|
||||
.persist(),
|
||||
weight_scale: cx
|
||||
.named_tensor(format!("{prefix}.weight_scale"), (out_dim, in_blocks))
|
||||
.as_dtype(DType::F8E4M3)
|
||||
.persist(),
|
||||
weight_scale_2: cx
|
||||
.named_tensor(format!("{prefix}.weight_scale_2"), 1)
|
||||
.as_dtype(DType::F32)
|
||||
.persist(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Reconstruct the dense weight `(out, in)` in the requested dtype using
|
||||
/// only HLIR ops.
|
||||
pub fn dequant(&self, target_dtype: DType) -> GraphTensor {
|
||||
let w_dims = self.weight.dims();
|
||||
let out_dim = w_dims[0];
|
||||
let in_dim = w_dims[1];
|
||||
|
||||
// 1. Cast the packed FP4 weights. KernelCast's bits<8 path extracts
|
||||
// each 4-bit field and runs CUDA's __nv_fp4_e2m1 → target_dtype
|
||||
// conversion, so the result already holds the correct numerical
|
||||
// FP4 values (in {±0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}) cast up to
|
||||
// the target dtype.
|
||||
let w = self.weight.cast(target_dtype);
|
||||
|
||||
// 2. Cast the per-block FP8 scales.
|
||||
let s = self.weight_scale.cast(target_dtype);
|
||||
|
||||
// 3. Broadcast each block scale to NVFP4_BLOCK consecutive columns:
|
||||
// (out, in/16) -> (out, in/16, 16) via expand_dim broadcast,
|
||||
// -> (out, in) via merge_dims. Element (o, i) ends up reading
|
||||
// weight_scale[o, i / 16].
|
||||
let s_blocked = s.expand_dim(2, NVFP4_BLOCK).merge_dims(1, 2);
|
||||
|
||||
// 4. Broadcast the scalar outer scale to (out, in). expand_lhs adds a
|
||||
// new outer axis of size out_dim (broadcast), then repeat extends
|
||||
// the original size-1 axis to in_dim (also broadcast since the
|
||||
// original dim was 1).
|
||||
let s2 = self
|
||||
.weight_scale_2
|
||||
.cast(target_dtype)
|
||||
.expand_lhs([out_dim])
|
||||
.repeat([Expression::from(1_usize), in_dim]);
|
||||
|
||||
// 5. Combined elementwise dequant.
|
||||
w * s_blocked * s2
|
||||
}
|
||||
|
||||
/// Standard linear forward: `y = x @ dequant(W)^T`. Dequant is performed
|
||||
/// in `x.dtype` so the matmul stays in a single dtype.
|
||||
pub fn forward(&self, x: GraphTensor) -> GraphTensor {
|
||||
let dequant = self.dequant(x.dtype);
|
||||
x.matmul(dequant.t())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
//! Reference dequant in pure Rust we can compare HLIR output against.
|
||||
//! See `src/quant_tests.rs` for an end-to-end CUDA round-trip; these
|
||||
//! tests cover only the Rust scalar reference math.
|
||||
|
||||
/// 16-entry FP4 E2M1 table (1 sign + 2 exponent + 1 mantissa, no NaN/Inf).
|
||||
/// Confirmed against the OCP MX spec / NVIDIA modelopt fp4 docs.
|
||||
pub const FP4_E2M1_LUT: [f32; 16] = [
|
||||
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, // sign=0
|
||||
-0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, // sign=1
|
||||
];
|
||||
|
||||
/// FP8 E4M3 (1 sign + 4 exponent + 3 mantissa, bias=7) — finite-only path
|
||||
/// matching CUDA's `__nv_fp8_e4m3` conversion. NaN at 0xFF / 0x7F is left
|
||||
/// unhandled here; we only care about the finite values that appear in
|
||||
/// modelopt's NVFP4 scales.
|
||||
pub fn fp8_e4m3_to_f32(byte: u8) -> f32 {
|
||||
let sign = ((byte >> 7) & 0x1) as u32;
|
||||
let exp = ((byte >> 3) & 0xF) as i32;
|
||||
let mant = (byte & 0x7) as u32;
|
||||
let f_bits = if exp == 0 {
|
||||
// subnormal: value = (-1)^sign * 2^-6 * mant/8
|
||||
if mant == 0 {
|
||||
sign << 31
|
||||
} else {
|
||||
let mant_f = mant as f32 / 8.0;
|
||||
let v = mant_f * (1.0_f32 / 64.0); // 2^-6
|
||||
let v = if sign == 1 { -v } else { v };
|
||||
v.to_bits()
|
||||
}
|
||||
} else {
|
||||
// normal: value = (-1)^sign * 2^(exp - 7) * (1 + mant/8)
|
||||
let unbiased = exp - 7;
|
||||
let v = (1.0 + mant as f32 / 8.0) * 2f32.powi(unbiased);
|
||||
let v = if sign == 1 { -v } else { v };
|
||||
v.to_bits()
|
||||
};
|
||||
f32::from_bits(f_bits)
|
||||
}
|
||||
|
||||
pub fn dequant_byte(packed: u8, lo_block_scale: f32, scale_2: f32) -> (f32, f32) {
|
||||
let lo = (packed & 0xF) as usize;
|
||||
let hi = ((packed >> 4) & 0xF) as usize;
|
||||
let lo_v = FP4_E2M1_LUT[lo] * lo_block_scale * scale_2;
|
||||
let hi_v = FP4_E2M1_LUT[hi] * lo_block_scale * scale_2;
|
||||
(lo_v, hi_v)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fp4_lut_signed_zero_and_six() {
|
||||
assert_eq!(FP4_E2M1_LUT[0], 0.0);
|
||||
assert_eq!(FP4_E2M1_LUT[8], -0.0);
|
||||
assert_eq!(FP4_E2M1_LUT[7], 6.0);
|
||||
assert_eq!(FP4_E2M1_LUT[15], -6.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fp8_e4m3_basic_values() {
|
||||
// 0x00 -> 0.0; 0x80 -> -0.0
|
||||
assert_eq!(fp8_e4m3_to_f32(0x00), 0.0);
|
||||
assert_eq!(fp8_e4m3_to_f32(0x80), -0.0);
|
||||
// 0x38 -> 1.0 (sign=0, exp=7, mant=0 -> 1.0 * 2^0)
|
||||
assert_eq!(fp8_e4m3_to_f32(0x38), 1.0);
|
||||
// 0x40 -> 2.0 (sign=0, exp=8, mant=0 -> 1.0 * 2^1)
|
||||
assert_eq!(fp8_e4m3_to_f32(0x40), 2.0);
|
||||
// 0x3C -> 1.5 (sign=0, exp=7, mant=4 -> (1 + 0.5) * 2^0)
|
||||
assert_eq!(fp8_e4m3_to_f32(0x3C), 1.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dequant_byte_combines_block_and_outer_scale() {
|
||||
// Byte 0x12 -> low nibble = 2 -> +1.0; high nibble = 1 -> +0.5.
|
||||
// With block scale 2.0 and outer scale 0.5, expect +1.0 and +0.5.
|
||||
let (lo, hi) = dequant_byte(0x12, 2.0, 0.5);
|
||||
assert!((lo - 1.0).abs() < 1e-6);
|
||||
assert!((hi - 0.5).abs() < 1e-6);
|
||||
}
|
||||
}
|
||||
@@ -1,148 +0,0 @@
|
||||
//! Pure-Rust port of `diffusers.FlowMatchEulerDiscreteScheduler`, the scheduler
|
||||
//! configured by Flux 2's `scheduler/scheduler_config.json`.
|
||||
//!
|
||||
//! Behaves exactly like the diffusers implementation when:
|
||||
//! - `use_dynamic_shifting = true`
|
||||
//! - `time_shift_type = "exponential"`
|
||||
//! - `invert_sigmas = false`
|
||||
//! - `shift_terminal = None`
|
||||
//! - karras / exponential / beta sigmas are all disabled
|
||||
//!
|
||||
//! Validated against `diffusers==0.36.x` for `image_seq_len = 4096`,
|
||||
//! `num_inference_steps = 8`, `mu = 1.15` (max difference < 1e-6).
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SchedulerConfig {
|
||||
pub num_train_timesteps: f32,
|
||||
pub sigma_max: f32,
|
||||
pub sigma_min: f32,
|
||||
pub base_image_seq_len: f32,
|
||||
pub max_image_seq_len: f32,
|
||||
pub base_shift: f32,
|
||||
pub max_shift: f32,
|
||||
}
|
||||
|
||||
impl Default for SchedulerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_train_timesteps: 1000.0,
|
||||
sigma_max: 1.0,
|
||||
sigma_min: 1e-3,
|
||||
base_image_seq_len: 256.0,
|
||||
max_image_seq_len: 4096.0,
|
||||
base_shift: 0.5,
|
||||
max_shift: 1.15,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Linear interpolation of the shift parameter over the configured image-sequence range.
|
||||
/// Matches the `mu` computation in `diffusers.FluxPipeline.calculate_shift`.
|
||||
pub fn compute_mu(cfg: &SchedulerConfig, image_seq_len: usize) -> f32 {
|
||||
let m = (cfg.max_shift - cfg.base_shift) / (cfg.max_image_seq_len - cfg.base_image_seq_len);
|
||||
m * (image_seq_len as f32 - cfg.base_image_seq_len) + cfg.base_shift
|
||||
}
|
||||
|
||||
/// Build the schedule of sigmas (length `num_inference_steps + 1`, ending in 0.0)
|
||||
/// and timesteps (length `num_inference_steps`) for one inference run.
|
||||
pub fn make_schedule(
|
||||
cfg: &SchedulerConfig,
|
||||
num_inference_steps: usize,
|
||||
mu: f32,
|
||||
) -> (Vec<f32>, Vec<f32>) {
|
||||
assert!(num_inference_steps >= 1);
|
||||
|
||||
// 1. Linearly spaced timesteps -> sigmas in [sigma_max, sigma_min].
|
||||
let n = num_inference_steps;
|
||||
let mut sigmas: Vec<f32> = (0..n)
|
||||
.map(|i| {
|
||||
let t_max = cfg.sigma_max * cfg.num_train_timesteps;
|
||||
let t_min = cfg.sigma_min * cfg.num_train_timesteps;
|
||||
let alpha = if n == 1 {
|
||||
0.0
|
||||
} else {
|
||||
i as f32 / (n - 1) as f32
|
||||
};
|
||||
let t = t_max + (t_min - t_max) * alpha;
|
||||
t / cfg.num_train_timesteps
|
||||
})
|
||||
.collect();
|
||||
|
||||
// 2. Resolution-dependent exponential time shift.
|
||||
let exp_mu = mu.exp();
|
||||
for s in sigmas.iter_mut() {
|
||||
// s' = exp(mu) / (exp(mu) + (1/s - 1))
|
||||
let rhs = exp_mu + (1.0 / *s - 1.0);
|
||||
*s = exp_mu / rhs;
|
||||
}
|
||||
|
||||
// 3. Timesteps = sigmas * num_train_timesteps before terminal append.
|
||||
let timesteps: Vec<f32> = sigmas.iter().map(|s| s * cfg.num_train_timesteps).collect();
|
||||
|
||||
// 4. Append terminal 0 sigma.
|
||||
sigmas.push(0.0);
|
||||
(sigmas, timesteps)
|
||||
}
|
||||
|
||||
/// One Euler integration step of the rectified-flow ODE.
|
||||
/// `sample_next = sample + (sigma_next - sigma) * model_output`.
|
||||
/// `sigmas[i]` is the current step's sigma, `sigmas[i + 1]` the next.
|
||||
pub fn euler_step(sample: &mut [f32], model_output: &[f32], sigma: f32, sigma_next: f32) {
|
||||
debug_assert_eq!(sample.len(), model_output.len());
|
||||
let dt = sigma_next - sigma;
|
||||
for (s, &m) in sample.iter_mut().zip(model_output) {
|
||||
*s += dt * m;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn close(a: f32, b: f32, tol: f32) -> bool {
|
||||
(a - b).abs() < tol
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn matches_diffusers_4096_steps_8() {
|
||||
// Reference output captured from diffusers 0.36 with the FLUX.2-dev config
|
||||
// (use_dynamic_shifting=True, time_shift_type=exponential, mu=1.15).
|
||||
let cfg = SchedulerConfig::default();
|
||||
let mu = compute_mu(&cfg, 4096);
|
||||
assert!(close(mu, 1.15, 1e-6), "mu={mu}");
|
||||
|
||||
let (sigmas, timesteps) = make_schedule(&cfg, 8, mu);
|
||||
let expected_sigmas = [
|
||||
1.0, 0.9499281, 0.887_723, 0.8083667, 0.7036315, 0.5590252, 0.3464282, 0.0031514, 0.0,
|
||||
];
|
||||
let expected_timesteps = [
|
||||
1000.0, 949.9281, 887.723, 808.3667, 703.6315, 559.0252, 346.4282, 3.1514,
|
||||
];
|
||||
assert_eq!(sigmas.len(), expected_sigmas.len());
|
||||
for (got, want) in sigmas.iter().zip(expected_sigmas.iter()) {
|
||||
assert!(
|
||||
close(*got, *want, 1e-4),
|
||||
"sigma mismatch: got {got} want {want}"
|
||||
);
|
||||
}
|
||||
assert_eq!(timesteps.len(), expected_timesteps.len());
|
||||
for (got, want) in timesteps.iter().zip(expected_timesteps.iter()) {
|
||||
assert!(
|
||||
close(*got, *want, 1e-1),
|
||||
"timestep mismatch: got {got} want {want}",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn euler_step_matches_formula() {
|
||||
// Trivial: prev = sample + (sigma_next - sigma) * out.
|
||||
let mut x = vec![1.0_f32, 2.0, 3.0];
|
||||
let out = vec![10.0_f32, -10.0, 0.0];
|
||||
euler_step(&mut x, &out, 0.5, 0.2);
|
||||
let dt = 0.2 - 0.5;
|
||||
assert!(close(x[0], 1.0 + dt * 10.0, 1e-6));
|
||||
assert!(close(x[1], 2.0 + dt * -10.0, 1e-6));
|
||||
assert!(close(x[2], 3.0 + dt * 0.0, 1e-6));
|
||||
}
|
||||
}
|
||||
@@ -1,398 +0,0 @@
|
||||
//! Mistral3 text encoder (text branch only) for the Flux 2 pipeline.
|
||||
//!
|
||||
//! ## What we need to produce
|
||||
//!
|
||||
//! Flux 2's text-conditioning is `joint_attention_dim = 15360 = 3 × 5120`,
|
||||
//! constructed by stacking the **post-residual hidden states** at layer
|
||||
//! indices 10, 20, 30 of the Mistral 3 Small text branch:
|
||||
//!
|
||||
//! ```text
|
||||
//! out = stack([hidden_states[10], hidden_states[20], hidden_states[30]], dim=1)
|
||||
//! # shape (B, 3, S, 5120)
|
||||
//! out = out.permute(0, 2, 1, 3).reshape(B, S, 15360)
|
||||
//! ```
|
||||
//!
|
||||
//! `hidden_states[k]` follows the HuggingFace convention: index 0 is the
|
||||
//! token embeddings (pre-layer-0), index k is the post-residual output of
|
||||
//! layer k-1. So `[10, 20, 30]` taps after layers 9, 19, 29 — meaning we
|
||||
//! only need to run **layers 0..30** (= 30 layers), not the full 40.
|
||||
//!
|
||||
//! ## Architecture (`text_encoder/config.json`)
|
||||
//!
|
||||
//! Standard Mistral / Llama-shape decoder-only LM:
|
||||
//! * 5120 hidden, 32 heads, 8 kv heads (GQA, kv_groups=4), head_dim 128
|
||||
//! * 32768 intermediate (SwiGLU MLP), RMSNorm eps 1e-5
|
||||
//! * RoPE theta 1e9, no sliding window
|
||||
//! * vocab 131072
|
||||
//! * BF16 storage, ten safetensors shards
|
||||
//! * On-disk weight prefix is `language_model.model.{embed_tokens, layers.i, norm}.*`
|
||||
//! because the parent class is multimodal — there's also a `vision_tower`
|
||||
//! and a `multi_modal_projector` we ignore.
|
||||
//!
|
||||
//! ## Memory
|
||||
//!
|
||||
//! 30 layers × ≈530 M params each + 671 M embedding ≈ 16.6 B params at BF16
|
||||
//! ≈ **33 GB** GPU memory for weights. Activations for 1 token-sequence of
|
||||
//! 512 tokens are negligible (a few hundred MB). So this fits comfortably on
|
||||
//! the 96 GB GH200, leaving 60 GB free — enough headroom to keep the VAE
|
||||
//! resident alongside (336 MB) and load the transformer separately afterwards.
|
||||
|
||||
use luminal::{dtype::DType, graph::Graph, prelude::*};
|
||||
|
||||
// ── Mistral 3 Small architecture constants for FLUX.2-dev ────────────────────
|
||||
pub const HIDDEN: usize = 5120;
|
||||
pub const NUM_HEADS: usize = 32;
|
||||
pub const NUM_KV_HEADS: usize = 8;
|
||||
pub const KV_GROUPS: usize = NUM_HEADS / NUM_KV_HEADS; // 4
|
||||
pub const HEAD_DIM: usize = 128;
|
||||
pub const Q_DIM: usize = NUM_HEADS * HEAD_DIM; // 4096 — wait, 32*128 = 4096 != 5120
|
||||
pub const KV_DIM: usize = NUM_KV_HEADS * HEAD_DIM; // 1024
|
||||
pub const INTERMEDIATE: usize = 32768;
|
||||
pub const RMS_EPS: f32 = 1e-5;
|
||||
pub const ROPE_THETA: f32 = 1.0e9;
|
||||
pub const VOCAB_SIZE: usize = 131072;
|
||||
/// We only need layers 0..29 to capture hidden_states[30] at the post-29
|
||||
/// residual. Layers 30..39 of the full Mistral 3 model are not loaded.
|
||||
pub const NUM_LAYERS_USED: usize = 30;
|
||||
/// Indices into `hidden_states` (HF convention, first item is the embedding)
|
||||
/// we tap and concatenate to get the 15360-dim Flux 2 text features.
|
||||
pub const TAP_LAYERS: [usize; 3] = [10, 20, 30];
|
||||
/// Concatenated channel dimension after stacking the 3 taps.
|
||||
pub const OUTPUT_DIM: usize = 3 * HIDDEN; // 15360 = joint_attention_dim
|
||||
|
||||
/// Storage dtype — Mistral 3 ships in BF16.
|
||||
pub const WEIGHT_DTYPE: DType = DType::Bf16;
|
||||
|
||||
// =============================================================================
|
||||
// Helpers (mirror the patterns in the existing `examples/qwen` & `gemma4_moe`)
|
||||
// =============================================================================
|
||||
|
||||
fn linear_no_bias(x: GraphTensor, w: GraphTensor) -> GraphTensor {
|
||||
// Direct mixed-precision kernel: F32 A × BF16 B^T → F32 (M, N), with the
|
||||
// BF16 → F32 conversion happening on each load inside the kernel rather
|
||||
// than as a separate cast op. This keeps the BF16 weight in memory as-is
|
||||
// (a 24 GB → 48 GB cast for the full encoder would not fit on the GPU)
|
||||
// and bypasses the egglog matmul lowering, where the cublaslt 2D rule
|
||||
// doesn't reliably fire for these shapes — see kernel::matmul2d's docs.
|
||||
//
|
||||
// Falls back to the standard `x.matmul(w.cast(x.dtype).t())` lowering
|
||||
// for ranks > 2 (e.g. attention's batched (heads, seq, head_dim) form),
|
||||
// since the custom kernel is only 2D.
|
||||
if x.shape.len() == 2 && w.shape.len() == 2 {
|
||||
luminal_cuda_lite::kernel::linear_no_bias_bf16_w(x, w)
|
||||
} else {
|
||||
x.matmul(w.cast(x.dtype).t())
|
||||
}
|
||||
}
|
||||
|
||||
fn rmsnorm(x: GraphTensor, weight: GraphTensor, eps: f32) -> GraphTensor {
|
||||
let w = if weight.dtype == DType::F32 {
|
||||
weight
|
||||
} else {
|
||||
weight.cast(DType::F32)
|
||||
};
|
||||
let x_rank = x.dims().len();
|
||||
let w_rank = w.dims().len();
|
||||
x.std_norm(x_rank - 1, eps) * w.expand_lhs(&x.dims()[..x_rank - w_rank])
|
||||
}
|
||||
|
||||
/// Rotary position embedding — half-rotation convention (`[x0, x1] →
|
||||
/// [x0*cos - x1*sin, x1*cos + x0*sin]` where `x0`, `x1` are the first and
|
||||
/// second halves of the head dim). Matches Llama / Mistral.
|
||||
///
|
||||
/// Inputs:
|
||||
/// * `x`: `(seq, n_heads, head_dim)`
|
||||
/// * `pos_ids`: `(seq,)` Int
|
||||
/// * `theta`: RoPE base
|
||||
fn apply_rope(x: GraphTensor, pos_ids: GraphTensor, n_heads: usize, theta: f32) -> GraphTensor {
|
||||
let cx = x.graph();
|
||||
let _seq = x.dims()[0];
|
||||
let half = HEAD_DIM / 2;
|
||||
|
||||
// Frequencies: theta^(-2i/D) for i in 0..D/2 — represented as 1 / theta^(2i/D)
|
||||
let exponents = cx.arange_options(0, HEAD_DIM, 2).cast(DType::F32) / HEAD_DIM as f32;
|
||||
use luminal::prelude::F32Pow;
|
||||
let inv_freqs = theta.pow(exponents).reciprocal();
|
||||
let emb = pos_ids
|
||||
.cast(DType::F32)
|
||||
.expand_dim(1, 1)
|
||||
.matmul(inv_freqs.expand_dim(0, 1)); // (seq, half)
|
||||
|
||||
let cos = emb.cos().expand_dim(1, n_heads); // (seq, n_heads, half)
|
||||
let sin = emb.sin().expand_dim(1, n_heads);
|
||||
|
||||
let x0 = x.slice((.., .., ..half));
|
||||
let x1 = x.slice((.., .., half..));
|
||||
let r0 = x0.cast(DType::F32) * cos - x1.cast(DType::F32) * sin;
|
||||
let r1 = x1.cast(DType::F32) * cos + x0.cast(DType::F32) * sin;
|
||||
r0.concat_along(r1, 2)
|
||||
}
|
||||
|
||||
/// Standard scaled dot-product attention over `(n_heads, seq_q, head_dim)`,
|
||||
/// `(n_heads, seq_k, head_dim)`, `(n_heads, seq_k, head_dim)` with a causal
|
||||
/// mask. Returns `(seq_q, n_heads * head_dim)`.
|
||||
///
|
||||
/// Routes the two batched matmuls through `kernel::matmul_3d_t` /
|
||||
/// `matmul_3d` rather than the egglog matmul lowering. The standard path
|
||||
/// has the same problem the VAE attention had (cublaslt batched rules
|
||||
/// fail to fire reliably; the broadcast Mul + SumReduce fallback creates
|
||||
/// a `(n_heads, M, N, K)` intermediate that scales O(seq²) and OOMs at
|
||||
/// seq_len ≥ ~256 even with BF16 weights elsewhere).
|
||||
fn causal_sdpa(
|
||||
q: GraphTensor,
|
||||
k: GraphTensor,
|
||||
v: GraphTensor,
|
||||
attention_mask: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let cx = q.graph();
|
||||
let n_heads = q.dims()[0];
|
||||
let seq = q.dims()[1];
|
||||
let scale = (HEAD_DIM as f32).sqrt().recip();
|
||||
// The kernel needs contiguous batches; a `* 1.0` after the upstream
|
||||
// transpose / GQA-expand chain materialises the strided view.
|
||||
let q = q * 1.0_f32;
|
||||
let k = k * 1.0_f32;
|
||||
let v = v * 1.0_f32;
|
||||
// Q @ K^T: (heads, seq, head_dim) @ (heads, seq, head_dim)^T = (heads, seq, seq).
|
||||
let scores = luminal_cuda_lite::kernel::matmul_3d_t(q, k) * scale;
|
||||
// Causal mask: positions where k_pos > q_pos are masked.
|
||||
let q_pos = cx.arange(seq).cast(DType::F32);
|
||||
let k_pos = cx.arange(seq).cast(DType::F32);
|
||||
let causal = k_pos.expand_dim(0, seq).gt(q_pos.expand_dim(1, seq));
|
||||
let causal = causal.cast(DType::F32);
|
||||
// Padding mask: keys at positions where attention_mask == 0 (padding
|
||||
// tokens) are masked regardless of the causal relation. Without this,
|
||||
// padding queries attend to prior padding keys via causal alone, and
|
||||
// every padding hidden state diverges from diffusers — surfaces as
|
||||
// cos_sim ≈ 0.65 on `prompt_embeds` even though tokens 0..real_len-1
|
||||
// match exactly. attention_mask has shape (seq,) with 1 for real and
|
||||
// 0 for padding tokens; broadcast as a per-key column to all queries.
|
||||
// (1 - mask[k]) is 1 for padding keys, 0 for real keys → adds -1e10
|
||||
// to every (q, padding_k) score.
|
||||
let pad_key = (attention_mask.cast(DType::F32) * (-1.0_f32) + 1.0_f32) // (seq,)
|
||||
.expand_dim(0, seq); // (seq_q=seq, seq_k=seq) — broadcast over q.
|
||||
// Combine: anywhere either causal or padding masks → -1e10.
|
||||
let mask = causal + pad_key;
|
||||
let mask = mask.expand_dim(0, n_heads);
|
||||
let masked = scores + mask * (-1e10_f32);
|
||||
let weights = masked.softmax(2);
|
||||
// attn = weights @ v: (heads, seq, seq) @ (heads, seq, head_dim) = (heads, seq, head_dim).
|
||||
let attn = luminal_cuda_lite::kernel::matmul_3d(weights, v);
|
||||
// `transpose(0, 1).merge_dims(1, 2)` produces the merge_dims
|
||||
// non-contiguous K stride `(((z/HEAD_DIM)*HEAD_DIM)*SEQ)+(z%HEAD_DIM)`.
|
||||
// The cublaslt 2D rule requires `K stride = MIter` (contiguous), so
|
||||
// without forcing materialization here the downstream o_proj matmul
|
||||
// falls through to a broadcast Mul whose `(SEQ, HIDDEN, KV_DIM)`
|
||||
// intermediate is ~20 GB BF16 and OOMs the GPU during search.
|
||||
attn.transpose(0, 1).merge_dims(1, 2) * 1.0_f32 // (seq_q, n_heads*head_dim)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// One Mistral 3 layer (RMSNorm → GQA self-attn + residual → RMSNorm → SwiGLU
|
||||
// MLP + residual). Identical in shape to the existing `examples/qwen`'s
|
||||
// `QwenLayer`.
|
||||
// =============================================================================
|
||||
|
||||
struct MistralLayer {
|
||||
attn_rms: GraphTensor, // (HIDDEN,)
|
||||
q_proj: GraphTensor, // (Q_DIM, HIDDEN) — Q dim = 32*128 = 4096
|
||||
k_proj: GraphTensor, // (KV_DIM, HIDDEN)
|
||||
v_proj: GraphTensor, // (KV_DIM, HIDDEN)
|
||||
o_proj: GraphTensor, // (HIDDEN, Q_DIM)
|
||||
mlp_rms: GraphTensor, // (HIDDEN,)
|
||||
gate_proj: GraphTensor, // (INTERMEDIATE, HIDDEN)
|
||||
up_proj: GraphTensor, // (INTERMEDIATE, HIDDEN)
|
||||
down_proj: GraphTensor, // (HIDDEN, INTERMEDIATE)
|
||||
}
|
||||
|
||||
impl MistralLayer {
|
||||
fn new(idx: usize, cx: &mut Graph) -> Self {
|
||||
let prefix = format!("language_model.model.layers.{idx}");
|
||||
let mk = |name: &str, shape: (usize, usize), cx: &mut Graph| -> GraphTensor {
|
||||
cx.named_tensor(format!("{prefix}.{name}"), shape)
|
||||
.as_dtype(WEIGHT_DTYPE)
|
||||
.persist()
|
||||
};
|
||||
let mk1 = |name: &str, n: usize, cx: &mut Graph| -> GraphTensor {
|
||||
cx.named_tensor(format!("{prefix}.{name}"), n)
|
||||
.as_dtype(WEIGHT_DTYPE)
|
||||
.persist()
|
||||
};
|
||||
Self {
|
||||
attn_rms: mk1("input_layernorm.weight", HIDDEN, cx),
|
||||
q_proj: mk("self_attn.q_proj.weight", (Q_DIM, HIDDEN), cx),
|
||||
k_proj: mk("self_attn.k_proj.weight", (KV_DIM, HIDDEN), cx),
|
||||
v_proj: mk("self_attn.v_proj.weight", (KV_DIM, HIDDEN), cx),
|
||||
o_proj: mk("self_attn.o_proj.weight", (HIDDEN, Q_DIM), cx),
|
||||
mlp_rms: mk1("post_attention_layernorm.weight", HIDDEN, cx),
|
||||
gate_proj: mk("mlp.gate_proj.weight", (INTERMEDIATE, HIDDEN), cx),
|
||||
up_proj: mk("mlp.up_proj.weight", (INTERMEDIATE, HIDDEN), cx),
|
||||
down_proj: mk("mlp.down_proj.weight", (HIDDEN, INTERMEDIATE), cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
x: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
attention_mask: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let h = rmsnorm(x, self.attn_rms, RMS_EPS);
|
||||
let q = linear_no_bias(h, self.q_proj);
|
||||
let k = linear_no_bias(h, self.k_proj);
|
||||
let v = linear_no_bias(h, self.v_proj);
|
||||
|
||||
// (seq, dim) → (seq, n_heads, head_dim) → ... → (n_heads, seq, head_dim)
|
||||
let q = q.split_dims(1, HEAD_DIM); // (seq, NUM_HEADS, HEAD_DIM)
|
||||
let k = k.split_dims(1, HEAD_DIM); // (seq, NUM_KV_HEADS, HEAD_DIM)
|
||||
let v = v.split_dims(1, HEAD_DIM);
|
||||
|
||||
let q = apply_rope(q, pos_ids, NUM_HEADS, ROPE_THETA);
|
||||
let k = apply_rope(k, pos_ids, NUM_KV_HEADS, ROPE_THETA);
|
||||
|
||||
// GQA expand: tile k, v along the kv_groups axis to match num_heads.
|
||||
let k = k
|
||||
.transpose(0, 1) // (NUM_KV_HEADS, seq, HEAD_DIM)
|
||||
.expand_dim(1, KV_GROUPS) // (NUM_KV_HEADS, KV_GROUPS, seq, HEAD_DIM)
|
||||
.merge_dims(0, 1); // (NUM_HEADS, seq, HEAD_DIM)
|
||||
let v = v.transpose(0, 1).expand_dim(1, KV_GROUPS).merge_dims(0, 1);
|
||||
let q = q.transpose(0, 1); // (NUM_HEADS, seq, HEAD_DIM)
|
||||
|
||||
let attn = causal_sdpa(q, k, v, attention_mask); // (seq, Q_DIM)
|
||||
let attn_out = linear_no_bias(attn, self.o_proj); // (seq, HIDDEN)
|
||||
let x = x + attn_out;
|
||||
|
||||
let h = rmsnorm(x, self.mlp_rms, RMS_EPS);
|
||||
let gate = linear_no_bias(h, self.gate_proj).silu();
|
||||
let up = linear_no_bias(h, self.up_proj);
|
||||
let mlp = linear_no_bias(gate * up, self.down_proj);
|
||||
x + mlp
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Top-level text encoder
|
||||
// =============================================================================
|
||||
|
||||
pub struct Mistral3TextEncoder {
|
||||
pub embed_tokens: GraphTensor, // (VOCAB_SIZE, HIDDEN) — used as a gather table
|
||||
layers: Vec<MistralLayer>,
|
||||
}
|
||||
|
||||
impl Mistral3TextEncoder {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
let embed_tokens = cx
|
||||
.named_tensor(
|
||||
"language_model.model.embed_tokens.weight",
|
||||
(VOCAB_SIZE, HIDDEN),
|
||||
)
|
||||
.as_dtype(WEIGHT_DTYPE)
|
||||
.persist();
|
||||
let layers = (0..NUM_LAYERS_USED)
|
||||
.map(|i| MistralLayer::new(i, cx))
|
||||
.collect();
|
||||
Self {
|
||||
embed_tokens,
|
||||
layers,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the prompt through the (truncated) text encoder and return the
|
||||
/// **stacked-and-flattened** `(seq, OUTPUT_DIM=15360)` text features the
|
||||
/// Flux 2 transformer's `context_embedder` consumes.
|
||||
///
|
||||
/// Steps mirror diffusers' `_get_mistral_3_small_prompt_embeds`:
|
||||
/// 1. Gather `embed_tokens[input_ids]` → `(seq, HIDDEN)`.
|
||||
/// 2. Run layers; capture `hidden_states[10/20/30]` (in HF convention,
|
||||
/// = post-residual at layers 9, 19, 29).
|
||||
/// 3. Stack along a new "tap" axis: `(seq, 3, HIDDEN)`.
|
||||
/// 4. Flatten the tap axis into the channel axis: `(seq, 3*HIDDEN)`.
|
||||
pub fn forward(
|
||||
&self,
|
||||
input_ids: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
attention_mask: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let seq = input_ids.dims1();
|
||||
// Token embedding lookup via gather. Mirror the qwen / llama pattern:
|
||||
// build a flat index table (id * HIDDEN + col) that picks the right
|
||||
// row from the embed_tokens (VOCAB_SIZE × HIDDEN) buffer. The source
|
||||
// is BF16 so the gathered slice is BF16 too — cast to F32 immediately
|
||||
// so the rest of the network runs in F32 with BF16 weights upcast at
|
||||
// each matmul (see `linear_no_bias`).
|
||||
let mut x = self.embed_tokens.gather(
|
||||
(input_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ input_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
x = x.cast(DType::F32);
|
||||
|
||||
// Run layers, taking snapshots at the right HF-convention layer indices.
|
||||
// hidden_states[10] = post-residual after layer 9, so we capture AFTER
|
||||
// running layer 9. Same for 19 and 29.
|
||||
let mut taps: Vec<GraphTensor> = Vec::with_capacity(TAP_LAYERS.len());
|
||||
for (idx, layer) in self.layers.iter().enumerate() {
|
||||
x = layer.forward(x, pos_ids, attention_mask);
|
||||
// Map: TAP_LAYERS = [10, 20, 30] meaning "post-layer 9/19/29".
|
||||
if TAP_LAYERS.iter().any(|&k| idx + 1 == k) {
|
||||
taps.push(x);
|
||||
}
|
||||
}
|
||||
|
||||
// Stack as (seq, n_taps, HIDDEN) then flatten last two dims.
|
||||
let mut stacked = taps[0].expand_dim(1, 1_usize); // (seq, 1, HIDDEN)
|
||||
for t in &taps[1..] {
|
||||
stacked = stacked.concat_along(t.expand_dim(1, 1_usize), 1);
|
||||
}
|
||||
// (seq, 3, HIDDEN) → (seq, 3*HIDDEN)
|
||||
stacked.merge_dims(1, 2)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Chat-template formatting (text-only path) — produces the byte string that
|
||||
// then gets fed to a tokenizer. Matches the Mistral 3 chat template applied
|
||||
// by diffusers' `_get_mistral_3_small_prompt_embeds`.
|
||||
// =============================================================================
|
||||
|
||||
/// The system message Flux 2's pipeline uses by default for txt2img.
|
||||
/// Verbatim from `diffusers.pipelines.flux2.system_messages.SYSTEM_MESSAGE`.
|
||||
pub const SYSTEM_MESSAGE: &str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.";
|
||||
|
||||
/// Format `(system, user)` into the wire-format string the tokenizer expects
|
||||
/// after `apply_chat_template(..., add_generation_prompt=False)`. The
|
||||
/// template inserts `<s>` (BOS) on its own, so we don't add it here — the
|
||||
/// tokenizer will emit it via `add_bos_token = true`.
|
||||
pub fn format_chat(system_message: &str, user_prompt: &str) -> String {
|
||||
// The Mistral 3 jinja template renders to:
|
||||
// <bos>[SYSTEM_PROMPT]{sys}[/SYSTEM_PROMPT][INST]{user}[/INST]
|
||||
// The bracketed tags are individual added-tokens in tokenizer.json, so
|
||||
// they'll round-trip through the tokenizer as single ids.
|
||||
format!("[SYSTEM_PROMPT]{system_message}[/SYSTEM_PROMPT][INST]{user_prompt}[/INST]")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn chat_template_matches_jinja_output() {
|
||||
// Sanity check: the result is the deterministic concatenation we
|
||||
// expect for a text-only prompt.
|
||||
let s = format_chat("hello world", "make a cat");
|
||||
assert_eq!(
|
||||
s,
|
||||
"[SYSTEM_PROMPT]hello world[/SYSTEM_PROMPT][INST]make a cat[/INST]"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn architecture_constants_consistent() {
|
||||
assert_eq!(NUM_HEADS * HEAD_DIM, Q_DIM);
|
||||
assert_eq!(NUM_KV_HEADS * HEAD_DIM, KV_DIM);
|
||||
assert!(NUM_HEADS.is_multiple_of(NUM_KV_HEADS));
|
||||
assert_eq!(KV_GROUPS, NUM_HEADS / NUM_KV_HEADS);
|
||||
assert_eq!(OUTPUT_DIM, TAP_LAYERS.len() * HIDDEN);
|
||||
// hidden_states[30] requires running 30 layers (0..29 inclusive).
|
||||
assert_eq!(NUM_LAYERS_USED, *TAP_LAYERS.iter().max().unwrap());
|
||||
}
|
||||
}
|
||||
@@ -1,917 +0,0 @@
|
||||
//! Flux2Transformer2DModel — the diffusion transformer / DiT — in pure HLIR.
|
||||
//!
|
||||
//! Mirrors the diffusers reference (`diffusers.models.transformers.transformer_flux2`)
|
||||
//! op-for-op. Architecture summary:
|
||||
//!
|
||||
//! ## Top-level forward (per denoising step)
|
||||
//!
|
||||
//! ```text
|
||||
//! latent (S_img, 128) ─┐
|
||||
//! ├─ x_embedder ──────► img (S_img, 6144)
|
||||
//! text (S_txt, 15360) ─┴─ context_embedder ► txt (S_txt, 6144)
|
||||
//!
|
||||
//! timestep, guidance ─► time_guidance_embed ► temb (6144)
|
||||
//! ├─ double_mod_img(temb) ─► (4096*9 = 36864) modulation
|
||||
//! ├─ double_mod_txt(temb) ─► (36864) modulation
|
||||
//! └─ single_mod(temb) ─► (18432) modulation
|
||||
//!
|
||||
//! img_ids (S_img, 4), txt_ids (S_txt, 4) ─► pos_embed ─► (cos, sin) of shape (S, 128)
|
||||
//! (concatenated txt then img)
|
||||
//!
|
||||
//! 8x DoubleStream: (img, txt) -> (img, txt) ◄── temb_mod_{img,txt}, rope
|
||||
//! concat: hidden = [txt, img] (length S_txt + S_img)
|
||||
//! 48x SingleStream: hidden -> hidden ◄── temb_mod, rope
|
||||
//! drop txt prefix: hidden = hidden[S_txt:]
|
||||
//!
|
||||
//! norm_out(hidden, temb) ─► proj_out ─► (S_img, 128)
|
||||
//! ```
|
||||
//!
|
||||
//! ## Per-block (DoubleStream)
|
||||
//!
|
||||
//! ```text
|
||||
//! mod_img split → (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp)
|
||||
//! mod_txt split → (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp)
|
||||
//!
|
||||
//! img' = LN(img) * (1+scale_msa) + shift_msa
|
||||
//! txt' = LN(txt) * (1+c_scale_msa) + c_shift_msa
|
||||
//! q_img, k_img, v_img = to_q(img'), to_k(img'), to_v(img')
|
||||
//! q_txt, k_txt, v_txt = add_q_proj(txt'), add_k_proj(txt'), add_v_proj(txt')
|
||||
//! q,k = RMSNorm_{qk}(reshape to (heads, head_dim))
|
||||
//! q = [norm_added_q(q_txt) ; norm_q(q_img)] along sequence axis
|
||||
//! k = [norm_added_k(k_txt) ; norm_k(k_img)]
|
||||
//! v = [v_txt ; v_img]
|
||||
//! q,k = apply_rotary(q,k, rope)
|
||||
//! attn = scaled_dot_product(q, k, v) // standard
|
||||
//! attn = flatten(heads, head_dim)
|
||||
//! attn_txt, attn_img = split(attn, [S_txt, S_img])
|
||||
//! img += gate_msa * to_out.0(attn_img)
|
||||
//! img += gate_mlp * FF(LN(img) * (1+scale_mlp) + shift_mlp)
|
||||
//! txt += c_gate_msa * to_add_out(attn_txt)
|
||||
//! txt += c_gate_mlp * FF_context(LN(txt) * (1+c_scale_mlp) + c_shift_mlp)
|
||||
//! ```
|
||||
//!
|
||||
//! ## Per-block (SingleStream — parallel attention + MLP)
|
||||
//!
|
||||
//! ```text
|
||||
//! mod split → (shift, scale, gate)
|
||||
//! h = LN(hidden) * (1+scale) + shift
|
||||
//! qkv_mlp = to_qkv_mlp_proj(h) // → 3*6144 + 2*mlp_hidden=2*18432
|
||||
//! qkv, mlp_in = split([3*6144, 2*18432])
|
||||
//! q,k,v = chunk(qkv, 3)
|
||||
//! q,k = RMSNorm + RoPE
|
||||
//! attn = sdpa(q,k,v); attn = flatten heads
|
||||
//! mlp = SwiGLU(mlp_in) // mlp_in has 2*mlp_hidden, halved
|
||||
//! out = to_out([attn; mlp])
|
||||
//! hidden += gate * out
|
||||
//! ```
|
||||
//!
|
||||
//! ## Status
|
||||
//!
|
||||
//! - **Architecture: complete.** Every weight in `flux2-dev`'s 7 BF16 shards
|
||||
//! has a place in this graph (see [`Flux2Transformer::init`]).
|
||||
//! - **Numerical validation: not yet done.** The transformer hasn't been run
|
||||
//! end-to-end against the diffusers reference — that requires downloading
|
||||
//! 60+ GB of weights and is the next step.
|
||||
//! - **Test coverage:** the FFN, modulation split, and 4D RoPE construction
|
||||
//! are unit-tested against a Rust scalar reference in the test module at
|
||||
//! the bottom of this file.
|
||||
|
||||
use luminal::{dtype::DType, graph::Graph, prelude::*};
|
||||
|
||||
// ── architecture constants for `black-forest-labs/FLUX.2-dev` ───────────────
|
||||
//
|
||||
// `FLUX2_NUM_LAYERS` / `FLUX2_NUM_SINGLE_LAYERS` env vars override the
|
||||
// counts at runtime. Reducing them is useful for end-to-end pipeline
|
||||
// validation with a much smaller compile-time cost — at the full
|
||||
// 8 + 48 layer count the egglog egraph for the transformer can blow
|
||||
// past 200 GB of CPU RAM.
|
||||
pub fn num_layers() -> usize {
|
||||
std::env::var("FLUX2_NUM_LAYERS")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(8)
|
||||
}
|
||||
pub fn num_single_layers() -> usize {
|
||||
std::env::var("FLUX2_NUM_SINGLE_LAYERS")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(48)
|
||||
}
|
||||
pub const NUM_HEADS: usize = 48;
|
||||
pub const HEAD_DIM: usize = 128;
|
||||
pub const HIDDEN: usize = NUM_HEADS * HEAD_DIM; // 6144
|
||||
pub const MLP_HIDDEN: usize = 18432;
|
||||
pub const JOINT_ATTENTION_DIM: usize = 15360;
|
||||
pub const TIMESTEP_GUIDANCE_CHANNELS: usize = 256;
|
||||
pub const IN_CHANNELS: usize = 128;
|
||||
pub const PATCH_SIZE: usize = 1;
|
||||
pub const RMS_EPS: f32 = 1e-6;
|
||||
pub const RMS_NORM_HEAD_EPS: f32 = 1e-6;
|
||||
pub const ROPE_THETA: f32 = 2000.0;
|
||||
pub const ROPE_AXES: [usize; 4] = [32, 32, 32, 32];
|
||||
|
||||
/// Storage dtype for transformer weights. The Flux 2 checkpoint ships in
|
||||
/// BF16; we keep it that way and cast to F32 only at the points where the
|
||||
/// numerics matter (matmul accumulation, normalization).
|
||||
pub const WEIGHT_DTYPE: DType = DType::Bf16;
|
||||
|
||||
// =============================================================================
|
||||
// Small helpers
|
||||
// =============================================================================
|
||||
|
||||
fn linear_no_bias(x: GraphTensor, w: GraphTensor) -> GraphTensor {
|
||||
// For 2D inputs we go through `kernel::linear_no_bias_bf16_w`, which
|
||||
// is a direct mixed-precision SGEMM (F32 A × BF16 B^T → F32) that
|
||||
// converts BF16 → F32 on each load instead of materializing a
|
||||
// separate F32 cast tensor. Two reasons we don't use the egglog
|
||||
// matmul lowering for these:
|
||||
// 1. The cublaslt 2D rule fails to fire reliably for some matmul
|
||||
// shapes (see kernel::matmul2d's docs); even one bad genome
|
||||
// pick on the broadcast Mul + SumReduce fallback creates an
|
||||
// `(M, N, K)` intermediate that OOMs the GPU.
|
||||
// 2. Explicitly casting all BF16 weights to F32 first would more
|
||||
// than double the transformer's working set (~120 GB) and
|
||||
// wouldn't fit. The kernel keeps weights as BF16 in memory.
|
||||
//
|
||||
// Higher-rank cases (3D batched matmul inside attention) fall
|
||||
// through to the standard matmul lowering — those go through the
|
||||
// separate `matmul_3d` / `matmul_3d_t` helpers in `sdpa` below.
|
||||
if x.shape.len() == 2 && w.shape.len() == 2 {
|
||||
luminal_cuda_lite::kernel::linear_no_bias_bf16_w(x, w)
|
||||
} else {
|
||||
x.matmul(w.cast(x.dtype).t())
|
||||
}
|
||||
}
|
||||
|
||||
/// Pre-norm RMSNorm over the trailing axis with weight (`scale`); no shift.
|
||||
fn rmsnorm(x: GraphTensor, weight: GraphTensor, eps: f32) -> GraphTensor {
|
||||
let w = if weight.dtype == DType::F32 {
|
||||
weight
|
||||
} else {
|
||||
weight.cast(DType::F32)
|
||||
};
|
||||
let x_rank = x.dims().len();
|
||||
let w_rank = w.dims().len();
|
||||
x.std_norm(x_rank - 1, eps) * w.expand_lhs(&x.dims()[..x_rank - w_rank])
|
||||
}
|
||||
|
||||
/// LayerNorm with no affine parameters (mean-norm + std-norm only).
|
||||
/// Matches `nn.LayerNorm(dim, elementwise_affine=False)` in PyTorch.
|
||||
fn layernorm_noaffine(x: GraphTensor, eps: f32) -> GraphTensor {
|
||||
let last = x.shape.last_axis();
|
||||
x.layer_norm(last, eps)
|
||||
}
|
||||
|
||||
/// Apply rotary embedding. `x` is `(S, H, D)` and `(cos, sin)` are `(S, D)`.
|
||||
///
|
||||
/// Diffusers Flux 2 uses `repeat_interleave_real=True`. The rotation pairs
|
||||
/// adjacent dims: `[x0, x1, x2, x3, ...]` → rotated `[-x1, x0, -x3, x2, ...]`.
|
||||
/// This matches `apply_rotary_emb(use_real_unbind_dim=-1)` with
|
||||
/// `freqs_repeat_interleave_real=True`.
|
||||
fn apply_rope(x: GraphTensor, cos: GraphTensor, sin: GraphTensor) -> GraphTensor {
|
||||
// x: (S, H, D); cos/sin: (S, D) -> explicitly broadcast to (S, H, D).
|
||||
let (_s, h, d_expr) = x.dims3();
|
||||
let d = d_expr.to_usize().expect("head_dim must be static");
|
||||
assert!(d % 2 == 0, "RoPE head_dim must be even");
|
||||
|
||||
let pairs = x.split_dims(2, 2_usize);
|
||||
let x_a = pairs.slice((.., .., .., ..1)).squeeze(3);
|
||||
let x_b = pairs.slice((.., .., .., 1..)).squeeze(3);
|
||||
|
||||
let neg_b = x_b * (-1.0_f32);
|
||||
let rotated_pairs = neg_b
|
||||
.expand_dim(3, 1_usize)
|
||||
.concat_along(x_a.expand_dim(3, 1_usize), 3);
|
||||
let x_rot = rotated_pairs.merge_dims(2, 3);
|
||||
|
||||
let cos_b = cos.expand_dim(1, h);
|
||||
let sin_b = sin.expand_dim(1, h);
|
||||
x.cast(DType::F32) * cos_b.cast(DType::F32) + x_rot.cast(DType::F32) * sin_b.cast(DType::F32)
|
||||
}
|
||||
|
||||
/// Scaled dot-product attention with NO mask, no causal: standard SDPA.
|
||||
/// q, k, v: `(H, S, D)`. Returns `(S, H, D)`.
|
||||
///
|
||||
/// Routes through the direct batched matmul kernels for the same reason
|
||||
/// the text encoder does — see `text_encoder::causal_sdpa` for context.
|
||||
fn sdpa(q: GraphTensor, k: GraphTensor, v: GraphTensor) -> GraphTensor {
|
||||
let head_dim = q.dims()[2].to_usize().expect("head_dim must be static");
|
||||
let scale = (head_dim as f32).sqrt().recip();
|
||||
// The kernel needs contiguous batches; materialize the strided views
|
||||
// produced upstream (transpose / split_dims chains).
|
||||
let q = q * 1.0_f32;
|
||||
let k = k * 1.0_f32;
|
||||
let v = v * 1.0_f32;
|
||||
let scores = luminal_cuda_lite::kernel::matmul_3d_t(q, k) * scale; // (H, S, S)
|
||||
let attn_w = scores.softmax(2);
|
||||
let attn = luminal_cuda_lite::kernel::matmul_3d(attn_w, v); // (H, S, D)
|
||||
attn.transpose(0, 1) // (S, H, D)
|
||||
}
|
||||
|
||||
/// SwiGLU: split `x` along last axis into `(x1, x2)`, return `silu(x1) * x2`.
|
||||
/// `x` shape `(..., 2 * mlp_hidden)`. Handles 2D and 3D inputs.
|
||||
fn swiglu(x: GraphTensor) -> GraphTensor {
|
||||
let dims = x.dims();
|
||||
let last = dims[dims.len() - 1].to_usize().expect("static");
|
||||
assert!(last.is_multiple_of(2));
|
||||
let half = last / 2;
|
||||
match dims.len() {
|
||||
2 => {
|
||||
let x1 = x.slice((.., ..half));
|
||||
let x2 = x.slice((.., half..));
|
||||
x1.silu() * x2
|
||||
}
|
||||
3 => {
|
||||
let x1 = x.slice((.., .., ..half));
|
||||
let x2 = x.slice((.., .., half..));
|
||||
x1.silu() * x2
|
||||
}
|
||||
n => panic!("swiglu: unsupported rank {n}"),
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Sinusoidal timestep embedding
|
||||
// =============================================================================
|
||||
|
||||
/// Build the sinusoidal positional embedding of `timestep` (a `(1,)` F32
|
||||
/// tensor — caller sets the value at runtime via `runtime.set_data`).
|
||||
/// Matches `diffusers.models.embeddings.Timesteps` with
|
||||
/// `flip_sin_to_cos=True`, `downscale_freq_shift=0`, `max_period=10000`,
|
||||
/// `scale=1`. Returns shape `(num_channels,)`.
|
||||
///
|
||||
/// Taking `timestep` as a graph tensor (not a Rust f32 constant) is what
|
||||
/// lets the whole transformer forward be compiled **once** and re-executed
|
||||
/// each diffusion step with a different timestep, instead of paying the
|
||||
/// minutes-long search cost per step.
|
||||
fn timesteps_proj(timestep: GraphTensor, num_channels: usize) -> GraphTensor {
|
||||
let cx = timestep.graph();
|
||||
let half = num_channels / 2;
|
||||
let exponents = cx.arange(half).cast(DType::F32) / half as f32;
|
||||
let log10000 = (10000.0_f32).ln();
|
||||
let freqs = (exponents * (-log10000)).exp(); // (half,)
|
||||
// Broadcast scalar timestep (shape (1,)) to (half,) by repeating along
|
||||
// the size-1 axis (stride substitution makes it a zero-stride broadcast).
|
||||
let t_broadcast = timestep.cast(DType::F32).repeat([half]);
|
||||
let arg = freqs * t_broadcast;
|
||||
// flip_sin_to_cos=True: cos first, then sin
|
||||
arg.cos().concat_along(arg.sin(), 0)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Modulation
|
||||
// =============================================================================
|
||||
|
||||
/// Modulation linear: `out = linear(silu(temb))`. Output dim = `dim * 3 * sets`.
|
||||
fn modulation(temb: GraphTensor, weight: GraphTensor) -> GraphTensor {
|
||||
let act = temb.silu();
|
||||
linear_no_bias(act, weight)
|
||||
}
|
||||
|
||||
/// Split modulation tensor (shape `(dim * 3 * sets,)`) into `sets` triples
|
||||
/// of `(shift, scale, gate)`, each `(dim,)`.
|
||||
fn split_modulation(
|
||||
mod_t: GraphTensor,
|
||||
sets: usize,
|
||||
) -> Vec<(GraphTensor, GraphTensor, GraphTensor)> {
|
||||
let total = mod_t.dims()[0]
|
||||
.to_usize()
|
||||
.expect("mod tensor dim must be static");
|
||||
let dim = total / (3 * sets);
|
||||
let mut out = Vec::with_capacity(sets);
|
||||
for i in 0..sets {
|
||||
let base = 3 * i * dim;
|
||||
let shift = mod_t.slice((base..base + dim,));
|
||||
let scale = mod_t.slice((base + dim..base + 2 * dim,));
|
||||
let gate = mod_t.slice((base + 2 * dim..base + 3 * dim,));
|
||||
out.push((shift, scale, gate));
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Apply (1 + scale) * x + shift, broadcasting scale/shift over the leading
|
||||
/// sequence axis. `x: (S, D)`, `scale, shift: (D,)`.
|
||||
fn ada_modulate(x: GraphTensor, scale: GraphTensor, shift: GraphTensor) -> GraphTensor {
|
||||
let s = x.dims()[0];
|
||||
let scale_b = scale.expand_lhs([s]);
|
||||
let shift_b = shift.expand_lhs([s]);
|
||||
x * (scale_b + 1.0_f32) + shift_b
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// FeedForward (used by double-stream blocks)
|
||||
// =============================================================================
|
||||
|
||||
struct FeedForward {
|
||||
linear_in: GraphTensor, // (mlp_hidden*2, dim)
|
||||
linear_out: GraphTensor, // (dim, mlp_hidden)
|
||||
}
|
||||
|
||||
impl FeedForward {
|
||||
fn new(prefix: &str, dim: usize, mlp_hidden: usize, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
linear_in: cx
|
||||
.named_tensor(format!("{prefix}.linear_in.weight"), (mlp_hidden * 2, dim))
|
||||
.as_dtype(WEIGHT_DTYPE)
|
||||
.persist(),
|
||||
linear_out: cx
|
||||
.named_tensor(format!("{prefix}.linear_out.weight"), (dim, mlp_hidden))
|
||||
.as_dtype(WEIGHT_DTYPE)
|
||||
.persist(),
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, x: GraphTensor) -> GraphTensor {
|
||||
let h = linear_no_bias(x, self.linear_in);
|
||||
let h = swiglu(h);
|
||||
linear_no_bias(h, self.linear_out)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Double-stream attention (img + txt joint attention)
|
||||
// =============================================================================
|
||||
|
||||
struct DoubleStreamAttn {
|
||||
to_q: GraphTensor,
|
||||
to_k: GraphTensor,
|
||||
to_v: GraphTensor,
|
||||
add_q_proj: GraphTensor,
|
||||
add_k_proj: GraphTensor,
|
||||
add_v_proj: GraphTensor,
|
||||
norm_q: GraphTensor, // (head_dim,)
|
||||
norm_k: GraphTensor, // (head_dim,)
|
||||
norm_added_q: GraphTensor, // (head_dim,)
|
||||
norm_added_k: GraphTensor, // (head_dim,)
|
||||
to_out: GraphTensor, // image-stream output projection
|
||||
to_add_out: GraphTensor, // text-stream output projection
|
||||
}
|
||||
|
||||
impl DoubleStreamAttn {
|
||||
fn new(prefix: &str, cx: &mut Graph) -> Self {
|
||||
let lin = |n: &str, cx: &mut Graph| -> GraphTensor {
|
||||
cx.named_tensor(format!("{prefix}.{n}"), (HIDDEN, HIDDEN))
|
||||
.as_dtype(WEIGHT_DTYPE)
|
||||
.persist()
|
||||
};
|
||||
Self {
|
||||
to_q: lin("to_q.weight", cx),
|
||||
to_k: lin("to_k.weight", cx),
|
||||
to_v: lin("to_v.weight", cx),
|
||||
add_q_proj: lin("add_q_proj.weight", cx),
|
||||
add_k_proj: lin("add_k_proj.weight", cx),
|
||||
add_v_proj: lin("add_v_proj.weight", cx),
|
||||
norm_q: cx
|
||||
.named_tensor(format!("{prefix}.norm_q.weight"), HEAD_DIM)
|
||||
.as_dtype(WEIGHT_DTYPE)
|
||||
.persist(),
|
||||
norm_k: cx
|
||||
.named_tensor(format!("{prefix}.norm_k.weight"), HEAD_DIM)
|
||||
.as_dtype(WEIGHT_DTYPE)
|
||||
.persist(),
|
||||
norm_added_q: cx
|
||||
.named_tensor(format!("{prefix}.norm_added_q.weight"), HEAD_DIM)
|
||||
.as_dtype(WEIGHT_DTYPE)
|
||||
.persist(),
|
||||
norm_added_k: cx
|
||||
.named_tensor(format!("{prefix}.norm_added_k.weight"), HEAD_DIM)
|
||||
.as_dtype(WEIGHT_DTYPE)
|
||||
.persist(),
|
||||
to_out: lin("to_out.0.weight", cx),
|
||||
to_add_out: lin("to_add_out.weight", cx),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `(img_out, txt_out)`.
|
||||
/// img / txt: `(S_img, HIDDEN)` / `(S_txt, HIDDEN)`. RoPE: `(cos, sin)` of
|
||||
/// shape `(S_txt + S_img, HEAD_DIM)`.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn forward(
|
||||
&self,
|
||||
img: GraphTensor,
|
||||
txt: GraphTensor,
|
||||
rope_cos: GraphTensor,
|
||||
rope_sin: GraphTensor,
|
||||
) -> (GraphTensor, GraphTensor) {
|
||||
let s_img = img.dims()[0].to_usize().expect("S_img static");
|
||||
let s_txt = txt.dims()[0].to_usize().expect("S_txt static");
|
||||
|
||||
// QKV projections.
|
||||
let q_img = linear_no_bias(img, self.to_q);
|
||||
let k_img = linear_no_bias(img, self.to_k);
|
||||
let v_img = linear_no_bias(img, self.to_v);
|
||||
let q_txt = linear_no_bias(txt, self.add_q_proj);
|
||||
let k_txt = linear_no_bias(txt, self.add_k_proj);
|
||||
let v_txt = linear_no_bias(txt, self.add_v_proj);
|
||||
|
||||
// Reshape to (S, H, D).
|
||||
let q_img = q_img.split_dims(1, HEAD_DIM); // (S_img, HEADS, HEAD_DIM)
|
||||
let k_img = k_img.split_dims(1, HEAD_DIM);
|
||||
let v_img = v_img.split_dims(1, HEAD_DIM);
|
||||
let q_txt = q_txt.split_dims(1, HEAD_DIM);
|
||||
let k_txt = k_txt.split_dims(1, HEAD_DIM);
|
||||
let v_txt = v_txt.split_dims(1, HEAD_DIM);
|
||||
|
||||
// QK norm per head.
|
||||
let q_img = rmsnorm(q_img, self.norm_q, RMS_NORM_HEAD_EPS);
|
||||
let k_img = rmsnorm(k_img, self.norm_k, RMS_NORM_HEAD_EPS);
|
||||
let q_txt = rmsnorm(q_txt, self.norm_added_q, RMS_NORM_HEAD_EPS);
|
||||
let k_txt = rmsnorm(k_txt, self.norm_added_k, RMS_NORM_HEAD_EPS);
|
||||
|
||||
// Concat along sequence (txt first, then img — matches diffusers).
|
||||
let q = q_txt.concat_along(q_img, 0); // (S_txt + S_img, H, D)
|
||||
let k = k_txt.concat_along(k_img, 0);
|
||||
let v = v_txt.concat_along(v_img, 0);
|
||||
|
||||
// RoPE on Q, K (V unchanged).
|
||||
let q = apply_rope(q, rope_cos, rope_sin);
|
||||
let k = apply_rope(k, rope_cos, rope_sin);
|
||||
|
||||
// SDPA expects (H, S, D).
|
||||
let q = q.transpose(0, 1);
|
||||
let k = k.transpose(0, 1);
|
||||
let v = v.transpose(0, 1);
|
||||
|
||||
let attn = sdpa(q, k, v); // (S_total, H, D)
|
||||
// `merge_dims(1, 2)` on (S, H, D) produces non-contiguous K
|
||||
// stride for the next matmul (the o_proj path). Without
|
||||
// `* 1.0` the cublaslt 2D rule can't match and the broadcast
|
||||
// Mul intermediate is ~36 GB BF16 at flux2 dimensions.
|
||||
let attn = attn.merge_dims(1, 2) * 1.0_f32; // (S_total, HIDDEN)
|
||||
|
||||
// Split back into txt + img streams.
|
||||
let attn_txt = attn.slice((..s_txt, ..));
|
||||
let attn_img = attn.slice((s_txt..s_txt + s_img, ..));
|
||||
|
||||
let img_out = linear_no_bias(attn_img, self.to_out);
|
||||
let txt_out = linear_no_bias(attn_txt, self.to_add_out);
|
||||
(img_out, txt_out)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Single-stream parallel attention (fused QKV + MLP-in, fused attn-out + MLP-out)
|
||||
// =============================================================================
|
||||
|
||||
struct SingleStreamAttn {
|
||||
to_qkv_mlp_proj: GraphTensor, // (3*HIDDEN + 2*MLP_HIDDEN, HIDDEN)
|
||||
norm_q: GraphTensor, // (HEAD_DIM,)
|
||||
norm_k: GraphTensor,
|
||||
to_out: GraphTensor, // (HIDDEN, HIDDEN + MLP_HIDDEN)
|
||||
}
|
||||
|
||||
impl SingleStreamAttn {
|
||||
fn new(prefix: &str, cx: &mut Graph) -> Self {
|
||||
let qkv_mlp_out = 3 * HIDDEN + 2 * MLP_HIDDEN; // 18432 + 36864 = 55296
|
||||
Self {
|
||||
to_qkv_mlp_proj: cx
|
||||
.named_tensor(
|
||||
format!("{prefix}.to_qkv_mlp_proj.weight"),
|
||||
(qkv_mlp_out, HIDDEN),
|
||||
)
|
||||
.as_dtype(WEIGHT_DTYPE)
|
||||
.persist(),
|
||||
norm_q: cx
|
||||
.named_tensor(format!("{prefix}.norm_q.weight"), HEAD_DIM)
|
||||
.as_dtype(WEIGHT_DTYPE)
|
||||
.persist(),
|
||||
norm_k: cx
|
||||
.named_tensor(format!("{prefix}.norm_k.weight"), HEAD_DIM)
|
||||
.as_dtype(WEIGHT_DTYPE)
|
||||
.persist(),
|
||||
to_out: cx
|
||||
.named_tensor(
|
||||
format!("{prefix}.to_out.weight"),
|
||||
(HIDDEN, HIDDEN + MLP_HIDDEN),
|
||||
)
|
||||
.as_dtype(WEIGHT_DTYPE)
|
||||
.persist(),
|
||||
}
|
||||
}
|
||||
|
||||
/// `hidden`: `(S, HIDDEN)`, `rope_cos/sin`: `(S, HEAD_DIM)`.
|
||||
fn forward(
|
||||
&self,
|
||||
hidden: GraphTensor,
|
||||
rope_cos: GraphTensor,
|
||||
rope_sin: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let projected = linear_no_bias(hidden, self.to_qkv_mlp_proj);
|
||||
let qkv_size = 3 * HIDDEN;
|
||||
let qkv = projected.slice((.., ..qkv_size));
|
||||
let mlp_in = projected.slice((.., qkv_size..));
|
||||
|
||||
let q = qkv.slice((.., ..HIDDEN));
|
||||
let k = qkv.slice((.., HIDDEN..2 * HIDDEN));
|
||||
let v = qkv.slice((.., 2 * HIDDEN..));
|
||||
|
||||
let q = q.split_dims(1, HEAD_DIM); // (S, H, D)
|
||||
let k = k.split_dims(1, HEAD_DIM);
|
||||
let v = v.split_dims(1, HEAD_DIM);
|
||||
|
||||
let q = rmsnorm(q, self.norm_q, RMS_NORM_HEAD_EPS);
|
||||
let k = rmsnorm(k, self.norm_k, RMS_NORM_HEAD_EPS);
|
||||
|
||||
let q = apply_rope(q, rope_cos, rope_sin);
|
||||
let k = apply_rope(k, rope_cos, rope_sin);
|
||||
|
||||
let q = q.transpose(0, 1);
|
||||
let k = k.transpose(0, 1);
|
||||
let v = v.transpose(0, 1);
|
||||
// `merge_dims(1, 2)` on (S, H, D) produces non-contiguous K
|
||||
// stride; force materialization so cublaslt can match the
|
||||
// downstream `to_out` matmul. See dual-stream block above.
|
||||
let attn = sdpa(q, k, v).merge_dims(1, 2) * 1.0_f32; // (S, HIDDEN)
|
||||
|
||||
let mlp = swiglu(mlp_in); // (S, MLP_HIDDEN)
|
||||
|
||||
let combined = attn.concat_along(mlp, 1); // (S, HIDDEN + MLP_HIDDEN)
|
||||
linear_no_bias(combined, self.to_out)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Double-stream block
|
||||
// =============================================================================
|
||||
|
||||
struct DoubleStreamBlock {
|
||||
attn: DoubleStreamAttn,
|
||||
ff: FeedForward,
|
||||
ff_context: FeedForward,
|
||||
}
|
||||
|
||||
impl DoubleStreamBlock {
|
||||
fn new(idx: usize, cx: &mut Graph) -> Self {
|
||||
let prefix = format!("transformer_blocks.{idx}");
|
||||
Self {
|
||||
attn: DoubleStreamAttn::new(&format!("{prefix}.attn"), cx),
|
||||
ff: FeedForward::new(&format!("{prefix}.ff"), HIDDEN, MLP_HIDDEN, cx),
|
||||
ff_context: FeedForward::new(&format!("{prefix}.ff_context"), HIDDEN, MLP_HIDDEN, cx),
|
||||
}
|
||||
}
|
||||
|
||||
/// img/txt: `(S_*, HIDDEN)`. mod tensors `(36864,)`. Returns `(img_out, txt_out)`.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn forward(
|
||||
&self,
|
||||
img: GraphTensor,
|
||||
txt: GraphTensor,
|
||||
mod_img: GraphTensor,
|
||||
mod_txt: GraphTensor,
|
||||
rope_cos: GraphTensor,
|
||||
rope_sin: GraphTensor,
|
||||
) -> (GraphTensor, GraphTensor) {
|
||||
let img_mods = split_modulation(mod_img, 2);
|
||||
let txt_mods = split_modulation(mod_txt, 2);
|
||||
let (shift_msa, scale_msa, gate_msa) = img_mods[0];
|
||||
let (shift_mlp, scale_mlp, gate_mlp) = img_mods[1];
|
||||
let (c_shift_msa, c_scale_msa, c_gate_msa) = txt_mods[0];
|
||||
let (c_shift_mlp, c_scale_mlp, c_gate_mlp) = txt_mods[1];
|
||||
|
||||
// Pre-attn norms + adaLN modulation.
|
||||
let img_n = ada_modulate(layernorm_noaffine(img, RMS_EPS), scale_msa, shift_msa);
|
||||
let txt_n = ada_modulate(layernorm_noaffine(txt, RMS_EPS), c_scale_msa, c_shift_msa);
|
||||
|
||||
let (attn_img, attn_txt) = self.attn.forward(img_n, txt_n, rope_cos, rope_sin);
|
||||
|
||||
let img = img + ada_gate(attn_img, gate_msa);
|
||||
let txt = txt + ada_gate(attn_txt, c_gate_msa);
|
||||
|
||||
// FF on each stream with second-set adaLN.
|
||||
let img_ff = self.ff.forward(ada_modulate(
|
||||
layernorm_noaffine(img, RMS_EPS),
|
||||
scale_mlp,
|
||||
shift_mlp,
|
||||
));
|
||||
let img = img + ada_gate(img_ff, gate_mlp);
|
||||
|
||||
let txt_ff = self.ff_context.forward(ada_modulate(
|
||||
layernorm_noaffine(txt, RMS_EPS),
|
||||
c_scale_mlp,
|
||||
c_shift_mlp,
|
||||
));
|
||||
let txt = txt + ada_gate(txt_ff, c_gate_mlp);
|
||||
|
||||
(img, txt)
|
||||
}
|
||||
}
|
||||
|
||||
fn ada_gate(x: GraphTensor, gate: GraphTensor) -> GraphTensor {
|
||||
let s = x.dims()[0];
|
||||
x * gate.expand_lhs([s])
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Single-stream block
|
||||
// =============================================================================
|
||||
|
||||
struct SingleStreamBlock {
|
||||
attn: SingleStreamAttn,
|
||||
}
|
||||
|
||||
impl SingleStreamBlock {
|
||||
fn new(idx: usize, cx: &mut Graph) -> Self {
|
||||
let prefix = format!("single_transformer_blocks.{idx}");
|
||||
Self {
|
||||
attn: SingleStreamAttn::new(&format!("{prefix}.attn"), cx),
|
||||
}
|
||||
}
|
||||
|
||||
/// hidden: `(S, HIDDEN)`. mod: `(18432,)`.
|
||||
fn forward(
|
||||
&self,
|
||||
hidden: GraphTensor,
|
||||
mod_t: GraphTensor,
|
||||
rope_cos: GraphTensor,
|
||||
rope_sin: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let mods = split_modulation(mod_t, 1);
|
||||
let (shift, scale, gate) = mods[0];
|
||||
|
||||
let h = ada_modulate(layernorm_noaffine(hidden, RMS_EPS), scale, shift);
|
||||
let attn_out = self.attn.forward(h, rope_cos, rope_sin);
|
||||
hidden + ada_gate(attn_out, gate)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Position-id construction + Flux2PosEmbed (RoPE freqs)
|
||||
// =============================================================================
|
||||
|
||||
/// Build the 4D position-id tensor for the concatenated (txt, img) sequence,
|
||||
/// matching `diffusers.pipelines.flux2.pipeline_flux2._prepare_text_ids` and
|
||||
/// `_prepare_latent_ids` exactly.
|
||||
///
|
||||
/// The 4 axes are interpreted as **(time, h, w, layer/sequence)**.
|
||||
///
|
||||
/// * `txt_ids`: shape `(S_txt, 4)`. Row `l` is `(0, 0, 0, l)` — text tokens
|
||||
/// vary only along the last axis (the "layer" / sequence index).
|
||||
/// * `img_ids`: shape `(S_img, 4)` where `S_img = h_pack * w_pack`. Row at
|
||||
/// `(hi, wi)` (cartesian product order) is `(0, hi, wi, 0)` — image
|
||||
/// tokens vary along the spatial axes 1 and 2.
|
||||
///
|
||||
/// `h_pack` and `w_pack` are the **post-pack** spatial dims that the
|
||||
/// transformer sees, i.e. `H/16` and `W/16` for an HxW pixel image. (The VAE
|
||||
/// 8× downsample plus the channel-pack 2× spatial fold give 16× total.)
|
||||
pub fn build_position_ids(s_txt: usize, h_pack: usize, w_pack: usize) -> (Vec<f32>, Vec<f32>) {
|
||||
let mut txt_ids = Vec::with_capacity(s_txt * 4);
|
||||
for l in 0..s_txt {
|
||||
txt_ids.extend_from_slice(&[0.0, 0.0, 0.0, l as f32]);
|
||||
}
|
||||
let mut img_ids = Vec::with_capacity(h_pack * w_pack * 4);
|
||||
for hi in 0..h_pack {
|
||||
for wi in 0..w_pack {
|
||||
img_ids.extend_from_slice(&[0.0, hi as f32, wi as f32, 0.0]);
|
||||
}
|
||||
}
|
||||
(txt_ids, img_ids)
|
||||
}
|
||||
|
||||
/// Pre-compute `(cos, sin)` flat tables for the concatenated `(txt, img)`
|
||||
/// position grid. Each is `S × HEAD_DIM` row-major. This mirrors
|
||||
/// `Flux2PosEmbed.forward` (calls `get_1d_rotary_pos_embed` per axis with
|
||||
/// `repeat_interleave_real=True`, then concatenates along the last dim).
|
||||
pub fn build_rope_tables(s_txt: usize, h_pack: usize, w_pack: usize) -> (Vec<f32>, Vec<f32>) {
|
||||
let (txt_ids, img_ids) = build_position_ids(s_txt, h_pack, w_pack);
|
||||
let s_total = s_txt + h_pack * w_pack;
|
||||
let head_dim = HEAD_DIM;
|
||||
debug_assert_eq!(ROPE_AXES.iter().sum::<usize>(), head_dim);
|
||||
|
||||
let mut cos_table = Vec::with_capacity(s_total * head_dim);
|
||||
let mut sin_table = Vec::with_capacity(s_total * head_dim);
|
||||
|
||||
let row = |row_ids: &[f32]| -> (Vec<f32>, Vec<f32>) {
|
||||
// For each axis, generate cos/sin of length axes_dim[i] (with
|
||||
// repeat_interleave_real=True meaning each freq is repeated twice).
|
||||
let mut row_cos = Vec::with_capacity(head_dim);
|
||||
let mut row_sin = Vec::with_capacity(head_dim);
|
||||
for (i, &dim) in ROPE_AXES.iter().enumerate() {
|
||||
let pos = row_ids[i];
|
||||
let half = dim / 2;
|
||||
for j in 0..half {
|
||||
let exponent = (2 * j) as f32 / dim as f32;
|
||||
let freq = 1.0_f32 / ROPE_THETA.powf(exponent);
|
||||
let arg = pos * freq;
|
||||
let c = arg.cos();
|
||||
let s = arg.sin();
|
||||
// repeat_interleave_real: cos cos sin sin pattern
|
||||
row_cos.push(c);
|
||||
row_cos.push(c);
|
||||
row_sin.push(s);
|
||||
row_sin.push(s);
|
||||
}
|
||||
}
|
||||
(row_cos, row_sin)
|
||||
};
|
||||
|
||||
for r in 0..s_txt {
|
||||
let (c, s) = row(&txt_ids[r * 4..(r + 1) * 4]);
|
||||
cos_table.extend(c);
|
||||
sin_table.extend(s);
|
||||
}
|
||||
for r in 0..h_pack * w_pack {
|
||||
let (c, s) = row(&img_ids[r * 4..(r + 1) * 4]);
|
||||
cos_table.extend(c);
|
||||
sin_table.extend(s);
|
||||
}
|
||||
(cos_table, sin_table)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Top-level transformer
|
||||
// =============================================================================
|
||||
|
||||
pub struct Flux2Transformer {
|
||||
// Embedders
|
||||
pub x_embedder: GraphTensor, // (HIDDEN, IN_CHANNELS)
|
||||
pub context_embedder: GraphTensor, // (HIDDEN, JOINT_ATTENTION_DIM)
|
||||
|
||||
// Time + guidance embedding
|
||||
pub time_t1_w: GraphTensor, // (HIDDEN, TIMESTEP_GUIDANCE_CHANNELS)
|
||||
pub time_t2_w: GraphTensor, // (HIDDEN, HIDDEN)
|
||||
pub guidance_t1_w: GraphTensor,
|
||||
pub guidance_t2_w: GraphTensor,
|
||||
|
||||
// Modulation tables
|
||||
pub mod_img: GraphTensor, // (HIDDEN*6, HIDDEN)
|
||||
pub mod_txt: GraphTensor, // (HIDDEN*6, HIDDEN)
|
||||
pub mod_single: GraphTensor, // (HIDDEN*3, HIDDEN)
|
||||
|
||||
// Output
|
||||
pub norm_out_lin: GraphTensor, // (HIDDEN*2, HIDDEN) for AdaLayerNormContinuous
|
||||
pub proj_out: GraphTensor, // (PATCH²*OUT_CHANNELS, HIDDEN)
|
||||
|
||||
// Blocks
|
||||
transformer_blocks: Vec<DoubleStreamBlock>,
|
||||
single_transformer_blocks: Vec<SingleStreamBlock>,
|
||||
}
|
||||
|
||||
impl Flux2Transformer {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
let bf16 = WEIGHT_DTYPE;
|
||||
let mk = |name: &str, shape: (usize, usize), cx: &mut Graph| -> GraphTensor {
|
||||
cx.named_tensor(name, shape).as_dtype(bf16).persist()
|
||||
};
|
||||
let mk1 = |name: &str, n: usize, cx: &mut Graph| -> GraphTensor {
|
||||
cx.named_tensor(name, n).as_dtype(bf16).persist()
|
||||
};
|
||||
|
||||
let x_embedder = mk("x_embedder.weight", (HIDDEN, IN_CHANNELS), cx);
|
||||
let context_embedder = mk("context_embedder.weight", (HIDDEN, JOINT_ATTENTION_DIM), cx);
|
||||
|
||||
let time_t1_w = mk(
|
||||
"time_guidance_embed.timestep_embedder.linear_1.weight",
|
||||
(HIDDEN, TIMESTEP_GUIDANCE_CHANNELS),
|
||||
cx,
|
||||
);
|
||||
let time_t2_w = mk(
|
||||
"time_guidance_embed.timestep_embedder.linear_2.weight",
|
||||
(HIDDEN, HIDDEN),
|
||||
cx,
|
||||
);
|
||||
let guidance_t1_w = mk(
|
||||
"time_guidance_embed.guidance_embedder.linear_1.weight",
|
||||
(HIDDEN, TIMESTEP_GUIDANCE_CHANNELS),
|
||||
cx,
|
||||
);
|
||||
let guidance_t2_w = mk(
|
||||
"time_guidance_embed.guidance_embedder.linear_2.weight",
|
||||
(HIDDEN, HIDDEN),
|
||||
cx,
|
||||
);
|
||||
|
||||
let mod_img = mk(
|
||||
"double_stream_modulation_img.linear.weight",
|
||||
(HIDDEN * 6, HIDDEN),
|
||||
cx,
|
||||
);
|
||||
let mod_txt = mk(
|
||||
"double_stream_modulation_txt.linear.weight",
|
||||
(HIDDEN * 6, HIDDEN),
|
||||
cx,
|
||||
);
|
||||
let mod_single = mk(
|
||||
"single_stream_modulation.linear.weight",
|
||||
(HIDDEN * 3, HIDDEN),
|
||||
cx,
|
||||
);
|
||||
|
||||
let norm_out_lin = mk("norm_out.linear.weight", (HIDDEN * 2, HIDDEN), cx);
|
||||
let proj_out = mk(
|
||||
"proj_out.weight",
|
||||
(PATCH_SIZE * PATCH_SIZE * IN_CHANNELS, HIDDEN),
|
||||
cx,
|
||||
);
|
||||
|
||||
let transformer_blocks = (0..num_layers())
|
||||
.map(|i| DoubleStreamBlock::new(i, cx))
|
||||
.collect();
|
||||
let single_transformer_blocks = (0..num_single_layers())
|
||||
.map(|i| SingleStreamBlock::new(i, cx))
|
||||
.collect();
|
||||
|
||||
let _ = mk1; // kept for parity if extra biases get added later
|
||||
Self {
|
||||
x_embedder,
|
||||
context_embedder,
|
||||
time_t1_w,
|
||||
time_t2_w,
|
||||
guidance_t1_w,
|
||||
guidance_t2_w,
|
||||
mod_img,
|
||||
mod_txt,
|
||||
mod_single,
|
||||
norm_out_lin,
|
||||
proj_out,
|
||||
transformer_blocks,
|
||||
single_transformer_blocks,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute `temb = timestep_emb + guidance_emb`. Both `timestep` and
|
||||
/// `guidance` are `(1,)` F32 graph tensors set per denoising step at
|
||||
/// runtime; the caller is responsible for the `* 1000` scaling that
|
||||
/// diffusers does in its forward.
|
||||
fn embed_time(&self, timestep: GraphTensor, guidance: GraphTensor) -> GraphTensor {
|
||||
// Diffusers' Flux2Transformer2DModel.forward multiplies its
|
||||
// timestep + guidance inputs by 1000 before passing them to
|
||||
// `time_guidance_embed`. The pipeline upstream divides the raw
|
||||
// scheduler timestep by 1000 to give the transformer a 0..1
|
||||
// scalar; the transformer multiplies it back to 0..1000 here so
|
||||
// the sin/cos `time_proj` argument range matches what the
|
||||
// model was trained on.
|
||||
//
|
||||
// Our `main.rs` feeds the same 0..1 sigma scalar (matching the
|
||||
// pipeline-level interface) and we mirror the *1000 here.
|
||||
let timestep = timestep * 1000.0_f32;
|
||||
let guidance = guidance * 1000.0_f32;
|
||||
let t_proj = timesteps_proj(timestep, TIMESTEP_GUIDANCE_CHANNELS);
|
||||
let t1 = linear_no_bias(t_proj, self.time_t1_w).silu();
|
||||
let t_emb = linear_no_bias(t1, self.time_t2_w);
|
||||
let g_proj = timesteps_proj(guidance, TIMESTEP_GUIDANCE_CHANNELS);
|
||||
let g1 = linear_no_bias(g_proj, self.guidance_t1_w).silu();
|
||||
let g_emb = linear_no_bias(g1, self.guidance_t2_w);
|
||||
t_emb + g_emb
|
||||
}
|
||||
|
||||
/// Single denoising-step forward, fully graph-tensorized so the same
|
||||
/// compiled graph runs for every step of the diffusion loop.
|
||||
///
|
||||
/// - `latent`: `(S_img, IN_CHANNELS=128)` already patched, F32. Updated
|
||||
/// each step.
|
||||
/// - `text_embed`: `(S_txt, JOINT_ATTENTION_DIM=15360)`, set once before
|
||||
/// the loop.
|
||||
/// - `rope_cos`, `rope_sin`: `(S_txt + S_img, HEAD_DIM=128)`, also set once.
|
||||
/// - `timestep`, `guidance`: `(1,)` F32 scalars set per step (already
|
||||
/// scaled by 1000).
|
||||
///
|
||||
/// Returns the model's velocity prediction the scheduler integrates.
|
||||
pub fn forward(
|
||||
&self,
|
||||
latent: GraphTensor,
|
||||
text_embed: GraphTensor,
|
||||
rope_cos: GraphTensor,
|
||||
rope_sin: GraphTensor,
|
||||
timestep: GraphTensor,
|
||||
guidance: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let temb = self.embed_time(timestep, guidance);
|
||||
let mod_img = modulation(temb, self.mod_img);
|
||||
let mod_txt = modulation(temb, self.mod_txt);
|
||||
let mod_single = modulation(temb, self.mod_single);
|
||||
|
||||
let mut img = linear_no_bias(latent, self.x_embedder);
|
||||
let mut txt = linear_no_bias(text_embed, self.context_embedder);
|
||||
|
||||
for block in self.transformer_blocks.iter() {
|
||||
let (i, t) = block.forward(img, txt, mod_img, mod_txt, rope_cos, rope_sin);
|
||||
img = i;
|
||||
txt = t;
|
||||
}
|
||||
|
||||
let s_img = img.dims()[0].to_usize().expect("S_img static");
|
||||
let s_txt = txt.dims()[0].to_usize().expect("S_txt static");
|
||||
let mut hidden = txt.concat_along(img, 0); // (S_txt + S_img, HIDDEN)
|
||||
|
||||
for block in self.single_transformer_blocks.iter() {
|
||||
hidden = block.forward(hidden, mod_single, rope_cos, rope_sin);
|
||||
}
|
||||
|
||||
// Drop text prefix.
|
||||
let img = hidden.slice((s_txt..s_txt + s_img, ..));
|
||||
|
||||
// AdaLayerNormContinuous: scale, shift = chunk(linear(silu(temb)), 2).
|
||||
let emb = linear_no_bias(temb.silu(), self.norm_out_lin);
|
||||
let half = HIDDEN;
|
||||
let scale = emb.slice((..half,));
|
||||
let shift = emb.slice((half..,));
|
||||
let normed = layernorm_noaffine(img, RMS_EPS);
|
||||
let modulated = ada_modulate(normed, scale, shift);
|
||||
|
||||
linear_no_bias(modulated, self.proj_out)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Tests
|
||||
// =============================================================================
|
||||
@@ -1,502 +0,0 @@
|
||||
//! AutoencoderKLFlux2 decoder, in pure HLIR.
|
||||
//!
|
||||
//! ## Status
|
||||
//!
|
||||
//! - All three primitives (`conv2d_bias`, `group_norm`, `nearest_upsample_2x`)
|
||||
//! are implemented and **individually validated** against numerical
|
||||
//! references — see the tests at the bottom of this file.
|
||||
//! - Stitching them into the full decoder currently hits a `luminal_cuda_lite`
|
||||
//! optimizer limit: chains of two prefix convs feeding a two-iteration
|
||||
//! resnet body with a residual back to the second conv's output cause the
|
||||
//! e-graph cleanup to discard the output's eclass ("No valid graphs present
|
||||
//! in the e-graph!"). See `deep_conv_chain_with_residual_compiles` (ignored)
|
||||
//! for the minimal reproducer. Every resnet block in the diffusers VAE has
|
||||
//! this shape, so the full decoder can't be lowered until that's resolved.
|
||||
//!
|
||||
//! ## Architecture (for reference once the optimizer is fixed)
|
||||
//!
|
||||
//! Pipeline (input image of side N pixels, latent stride 8):
|
||||
//! 1. `post_quant_conv` : 1×1 conv 32 → 32, latent at (N/8, N/8)
|
||||
//! 2. `decoder.conv_in` : 3×3 conv 32 → 512
|
||||
//! 3. `decoder.mid_block` : ResNet → SelfAttn → ResNet, all 512
|
||||
//! 4. `decoder.up_blocks[0..3]` : 3 resnets each + nearest-2× upsample
|
||||
//! (channel sequence 512 → 512 → 512 → 256 → 128; last block has no upsample)
|
||||
//! 5. `decoder.conv_norm_out` : GroupNorm(32 groups) + SiLU
|
||||
//! 6. `decoder.conv_out` : 3×3 conv 128 → 3 = (R,G,B) pixels
|
||||
//!
|
||||
//! Three building blocks that don't exist in `luminal_nn` get inlined here
|
||||
//! using only stock HLIR ops (no custom kernels):
|
||||
//!
|
||||
//! - **`conv2d_bias`** — unfold + matmul + bias, then a single explicit gather
|
||||
//! to reshape (H_out*W_out, C_out) into (C_out, H_out, W_out).
|
||||
//! - **`group_norm`** — flatten each group's volume into a single axis,
|
||||
//! `layer_norm` over that axis, reshape back, per-channel affine.
|
||||
//! - **`nearest_upsample_2x`** — `expand_dim(broadcast) + merge_dims` on each
|
||||
//! spatial axis, so each pixel is duplicated 2×2.
|
||||
|
||||
use luminal::{graph::Graph, prelude::*};
|
||||
|
||||
/// Standard AutoencoderKL constants for Flux 2.
|
||||
pub const LATENT_CHANNELS: usize = 32;
|
||||
pub const VAE_DOWNSAMPLE: usize = 8; // 3 spatial halvings on the encoder side.
|
||||
pub const NORM_NUM_GROUPS: usize = 32;
|
||||
pub const NORM_EPS: f32 = 1e-6;
|
||||
pub const BLOCK_OUT_CHANNELS: [usize; 4] = [128, 256, 512, 512];
|
||||
pub const LAYERS_PER_BLOCK: usize = 2; // diffusers config; the decoder uses 3 resnets/block (= layers_per_block + 1).
|
||||
pub const RESNETS_PER_BLOCK: usize = LAYERS_PER_BLOCK + 1;
|
||||
|
||||
// Decoder channel progression (reverse of encoder: deepest first).
|
||||
// up_blocks[i].in_channels = block_out_channels[max(reversed_idx - 1, 0)]
|
||||
// up_blocks[i].out_channels = block_out_channels[reversed_idx]
|
||||
// where reversed_idx walks block_out_channels from back to front.
|
||||
fn decoder_block_channels(block_idx: usize) -> (usize, usize) {
|
||||
let n = BLOCK_OUT_CHANNELS.len();
|
||||
let reversed = n - 1 - block_idx;
|
||||
let prev = if reversed + 1 < n {
|
||||
BLOCK_OUT_CHANNELS[reversed + 1]
|
||||
} else {
|
||||
BLOCK_OUT_CHANNELS[reversed]
|
||||
};
|
||||
let out = BLOCK_OUT_CHANNELS[reversed];
|
||||
let in_c = if block_idx == 0 {
|
||||
BLOCK_OUT_CHANNELS[n - 1] // mid block runs at the deepest channel count
|
||||
} else {
|
||||
prev
|
||||
};
|
||||
(in_c, out)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// HLIR primitive helpers
|
||||
// =============================================================================
|
||||
|
||||
/// 2D convolution with bias on a `(C_in, H, W)` input, weights stored as
|
||||
/// `(C_out, C_in, K, K)` flat-loaded, bias as `(C_out,)`. Returns
|
||||
/// `(C_out, H_out, W_out)` where `H_out = (H + 2*padding - kernel) / stride + 1`.
|
||||
///
|
||||
/// Wraps the direct conv kernel from [`luminal_cuda_lite::kernel::conv2d_bias`]
|
||||
/// (one CUDA thread per output element), which avoids materializing the
|
||||
/// `(H_out*W_out, C_in*K*K)` unfold intermediate that earlier HLIR-only
|
||||
/// implementations needed.
|
||||
fn conv2d_bias(
|
||||
x: GraphTensor,
|
||||
weight: GraphTensor,
|
||||
bias: GraphTensor,
|
||||
kernel: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
) -> GraphTensor {
|
||||
luminal_cuda_lite::kernel::conv2d_bias(x, weight, bias, kernel, stride, padding)
|
||||
}
|
||||
|
||||
/// PyTorch-style GroupNorm on a (C, H, W) tensor.
|
||||
///
|
||||
/// The channel axis is split into `(num_groups, group_size)`; the mean and
|
||||
/// variance are computed jointly over `(group_size, H, W)` per group; then
|
||||
/// the output is rescaled and shifted by per-channel `weight` and `bias`.
|
||||
///
|
||||
/// Implementation note: we flatten the per-group volume into a single axis
|
||||
/// before normalizing (rather than calling `layer_norm` over three axes at
|
||||
/// once). The single-axis form generates simpler egglog patterns and survives
|
||||
/// composition into deep conv chains, where the 3-axis form drops out of the
|
||||
/// e-graph during cleanup.
|
||||
fn group_norm(
|
||||
x: GraphTensor,
|
||||
weight: GraphTensor,
|
||||
bias: GraphTensor,
|
||||
num_groups: usize,
|
||||
eps: f32,
|
||||
) -> GraphTensor {
|
||||
let dims = x.dims();
|
||||
assert_eq!(dims.len(), 3, "group_norm expects (C, H, W)");
|
||||
let c = dims[0];
|
||||
let h = dims[1];
|
||||
let w = dims[2];
|
||||
|
||||
let c_const = c
|
||||
.to_usize()
|
||||
.expect("num_channels must be static for GroupNorm");
|
||||
let h_const = h.to_usize().expect("height must be static for GroupNorm");
|
||||
let w_const = w.to_usize().expect("width must be static for GroupNorm");
|
||||
assert!(
|
||||
c_const.is_multiple_of(num_groups),
|
||||
"num_channels ({c_const}) must be a multiple of num_groups ({num_groups})",
|
||||
);
|
||||
let group_size = c_const / num_groups;
|
||||
let group_volume = group_size * h_const * w_const;
|
||||
|
||||
// Reshape to (num_groups, group_size * H * W) — one flat axis per group.
|
||||
let flat = x.merge_dims(0, 1).merge_dims(0, 1); // (C*H*W,)
|
||||
let grouped = flat.split_dims(0, group_volume); // (num_groups, group_volume)
|
||||
|
||||
// LayerNorm over the single per-group axis.
|
||||
let normed = grouped.layer_norm(1, eps);
|
||||
|
||||
// Reshape (num_groups, group_volume) back to (C, H, W).
|
||||
let unshaped = normed
|
||||
.merge_dims(0, 1) // flat (C*H*W,)
|
||||
.split_dims(0, h_const * w_const) // (C, H*W)
|
||||
.split_dims(1, w_const); // (C, H, W)
|
||||
|
||||
// Per-channel affine: weight, bias both shape (C,) -> (C, H, W).
|
||||
let w_b = weight.expand_dim(1, h).expand_dim(2, w);
|
||||
let b_b = bias.expand_dim(1, h).expand_dim(2, w);
|
||||
unshaped * w_b + b_b
|
||||
}
|
||||
|
||||
/// Nearest-neighbour 2× spatial upsample on a (C, H, W) tensor.
|
||||
fn nearest_upsample_2x(x: GraphTensor) -> GraphTensor {
|
||||
// (C, H, W) -> (C, H, 2, W) -> (C, 2H, W) -> (C, 2H, W, 2) -> (C, 2H, 2W)
|
||||
let stage1 = x.expand_dim(2, 2_usize).merge_dims(1, 2);
|
||||
let stage2 = stage1.expand_dim(3, 2_usize).merge_dims(2, 3);
|
||||
// Materialize the broadcast view so subsequent ops see contiguous strides.
|
||||
stage2 + 0.0_f32
|
||||
}
|
||||
|
||||
/// SiLU = x * sigmoid(x).
|
||||
fn silu(x: GraphTensor) -> GraphTensor {
|
||||
x.silu()
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Decoder building blocks
|
||||
// =============================================================================
|
||||
|
||||
struct ResnetBlock {
|
||||
norm1_w: GraphTensor,
|
||||
norm1_b: GraphTensor,
|
||||
conv1_w: GraphTensor,
|
||||
conv1_b: GraphTensor,
|
||||
norm2_w: GraphTensor,
|
||||
norm2_b: GraphTensor,
|
||||
conv2_w: GraphTensor,
|
||||
conv2_b: GraphTensor,
|
||||
shortcut: Option<(GraphTensor, GraphTensor)>, // 1×1 conv when in_c != out_c
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
}
|
||||
|
||||
impl ResnetBlock {
|
||||
fn new(prefix: &str, in_c: usize, out_c: usize, cx: &mut Graph) -> Self {
|
||||
let shortcut = if in_c == out_c {
|
||||
None
|
||||
} else {
|
||||
Some((
|
||||
cx.named_tensor(format!("{prefix}.conv_shortcut.weight"), (out_c, in_c))
|
||||
.persist(),
|
||||
cx.named_tensor(format!("{prefix}.conv_shortcut.bias"), out_c)
|
||||
.persist(),
|
||||
))
|
||||
};
|
||||
Self {
|
||||
norm1_w: cx
|
||||
.named_tensor(format!("{prefix}.norm1.weight"), in_c)
|
||||
.persist(),
|
||||
norm1_b: cx
|
||||
.named_tensor(format!("{prefix}.norm1.bias"), in_c)
|
||||
.persist(),
|
||||
conv1_w: cx
|
||||
.named_tensor(format!("{prefix}.conv1.weight"), (out_c, in_c * 3 * 3))
|
||||
.persist(),
|
||||
conv1_b: cx
|
||||
.named_tensor(format!("{prefix}.conv1.bias"), out_c)
|
||||
.persist(),
|
||||
norm2_w: cx
|
||||
.named_tensor(format!("{prefix}.norm2.weight"), out_c)
|
||||
.persist(),
|
||||
norm2_b: cx
|
||||
.named_tensor(format!("{prefix}.norm2.bias"), out_c)
|
||||
.persist(),
|
||||
conv2_w: cx
|
||||
.named_tensor(format!("{prefix}.conv2.weight"), (out_c, out_c * 3 * 3))
|
||||
.persist(),
|
||||
conv2_b: cx
|
||||
.named_tensor(format!("{prefix}.conv2.bias"), out_c)
|
||||
.persist(),
|
||||
shortcut,
|
||||
in_channels: in_c,
|
||||
out_channels: out_c,
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, x: GraphTensor) -> GraphTensor {
|
||||
let h = group_norm(x, self.norm1_w, self.norm1_b, NORM_NUM_GROUPS, NORM_EPS);
|
||||
let h = silu(h);
|
||||
let h = conv2d_bias(h, self.conv1_w, self.conv1_b, 3, 1, 1);
|
||||
let h = group_norm(h, self.norm2_w, self.norm2_b, NORM_NUM_GROUPS, NORM_EPS);
|
||||
let h = silu(h);
|
||||
let h = conv2d_bias(h, self.conv2_w, self.conv2_b, 3, 1, 1);
|
||||
|
||||
let skip = if self.in_channels == self.out_channels {
|
||||
x
|
||||
} else {
|
||||
let (sw, sb) = self.shortcut.expect("shortcut required when in_c != out_c");
|
||||
conv2d_bias(x, sw, sb, 1, 1, 0)
|
||||
};
|
||||
skip + h
|
||||
}
|
||||
}
|
||||
|
||||
struct AttnBlock {
|
||||
group_norm_w: GraphTensor,
|
||||
group_norm_b: GraphTensor,
|
||||
to_q_w: GraphTensor,
|
||||
to_q_b: GraphTensor,
|
||||
to_k_w: GraphTensor,
|
||||
to_k_b: GraphTensor,
|
||||
to_v_w: GraphTensor,
|
||||
to_v_b: GraphTensor,
|
||||
to_out_w: GraphTensor,
|
||||
to_out_b: GraphTensor,
|
||||
channels: usize,
|
||||
}
|
||||
|
||||
impl AttnBlock {
|
||||
fn new(prefix: &str, channels: usize, cx: &mut Graph) -> Self {
|
||||
let lin =
|
||||
|name: &str, out: usize, inn: usize, cx: &mut Graph| -> (GraphTensor, GraphTensor) {
|
||||
(
|
||||
cx.named_tensor(format!("{prefix}.{name}.weight"), (out, inn))
|
||||
.persist(),
|
||||
cx.named_tensor(format!("{prefix}.{name}.bias"), out)
|
||||
.persist(),
|
||||
)
|
||||
};
|
||||
let (to_q_w, to_q_b) = lin("to_q", channels, channels, cx);
|
||||
let (to_k_w, to_k_b) = lin("to_k", channels, channels, cx);
|
||||
let (to_v_w, to_v_b) = lin("to_v", channels, channels, cx);
|
||||
let (to_out_w, to_out_b) = lin("to_out.0", channels, channels, cx);
|
||||
Self {
|
||||
group_norm_w: cx
|
||||
.named_tensor(format!("{prefix}.group_norm.weight"), channels)
|
||||
.persist(),
|
||||
group_norm_b: cx
|
||||
.named_tensor(format!("{prefix}.group_norm.bias"), channels)
|
||||
.persist(),
|
||||
to_q_w,
|
||||
to_q_b,
|
||||
to_k_w,
|
||||
to_k_b,
|
||||
to_v_w,
|
||||
to_v_b,
|
||||
to_out_w,
|
||||
to_out_b,
|
||||
channels,
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, x: GraphTensor) -> GraphTensor {
|
||||
let dims = x.dims();
|
||||
assert_eq!(dims.len(), 3, "AttnBlock expects (C, H, W)");
|
||||
let _h = dims[1];
|
||||
let w = dims[2];
|
||||
let residual = x;
|
||||
|
||||
// GroupNorm + reshape to (HW, C) for linear projections.
|
||||
let normed = group_norm(
|
||||
x,
|
||||
self.group_norm_w,
|
||||
self.group_norm_b,
|
||||
NORM_NUM_GROUPS,
|
||||
NORM_EPS,
|
||||
);
|
||||
// (C, H, W) -> (C, H*W) -> (H*W, C). The transpose at the end leaves
|
||||
// a column-major view, which the direct matmul kernels assume away;
|
||||
// `* 1.0` forces a contiguous row-major materialization.
|
||||
let merged = normed.merge_dims(1, 2).transpose(0, 1) * 1.0_f32;
|
||||
|
||||
// Q, K, V projections — direct kernel routes around the cublaslt
|
||||
// 2D rule, which silently fails to fire for some of these matmuls
|
||||
// and lets search occasionally pick the broadcast Mul + SumReduce
|
||||
// fallback. At 1024² the bad path on `q @ kᵀ` allocates a
|
||||
// `(HW, HW, C) = (16384, 16384, 512)` ≈ 524 GiB intermediate.
|
||||
let q = luminal_cuda_lite::kernel::linear_bias(merged, self.to_q_w, self.to_q_b);
|
||||
let k = luminal_cuda_lite::kernel::linear_bias(merged, self.to_k_w, self.to_k_b);
|
||||
let v = luminal_cuda_lite::kernel::linear_bias(merged, self.to_v_w, self.to_v_b);
|
||||
|
||||
// Standard scaled dot-product attention over the spatial axis.
|
||||
// `q @ kᵀ` with k stored row-major as `(HW, C)`: matmul_2d_t handles
|
||||
// the transpose without materialising k as a separate tensor.
|
||||
let scale = (self.channels as f32).sqrt().recip();
|
||||
let scores = luminal_cuda_lite::kernel::matmul_2d_t(q, k) * scale;
|
||||
let attn_w = scores.softmax(1);
|
||||
// attn_w is (HW, HW) row-major, v is (HW, C) row-major; plain matmul.
|
||||
let attn = luminal_cuda_lite::kernel::matmul_2d(attn_w, v);
|
||||
|
||||
let out = luminal_cuda_lite::kernel::linear_bias(attn, self.to_out_w, self.to_out_b);
|
||||
// (H*W, C) -> (C, H*W) -> (C, H, W)
|
||||
let out = out.transpose(0, 1).split_dims(1, w);
|
||||
residual + out
|
||||
}
|
||||
}
|
||||
|
||||
struct UpBlock {
|
||||
resnets: Vec<ResnetBlock>,
|
||||
upsampler: Option<(GraphTensor, GraphTensor)>, // 3×3 conv after nearest-2×
|
||||
}
|
||||
|
||||
impl UpBlock {
|
||||
fn new(prefix: &str, in_c: usize, out_c: usize, with_upsampler: bool, cx: &mut Graph) -> Self {
|
||||
let mut resnets = Vec::with_capacity(RESNETS_PER_BLOCK);
|
||||
for r in 0..RESNETS_PER_BLOCK {
|
||||
let resnet_in = if r == 0 { in_c } else { out_c };
|
||||
resnets.push(ResnetBlock::new(
|
||||
&format!("{prefix}.resnets.{r}"),
|
||||
resnet_in,
|
||||
out_c,
|
||||
cx,
|
||||
));
|
||||
}
|
||||
let upsampler = if with_upsampler {
|
||||
Some((
|
||||
cx.named_tensor(
|
||||
format!("{prefix}.upsamplers.0.conv.weight"),
|
||||
(out_c, out_c * 3 * 3),
|
||||
)
|
||||
.persist(),
|
||||
cx.named_tensor(format!("{prefix}.upsamplers.0.conv.bias"), out_c)
|
||||
.persist(),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Self { resnets, upsampler }
|
||||
}
|
||||
|
||||
fn forward(&self, mut x: GraphTensor) -> GraphTensor {
|
||||
for r in &self.resnets {
|
||||
x = r.forward(x);
|
||||
}
|
||||
if let Some((w, b)) = &self.upsampler {
|
||||
let up = nearest_upsample_2x(x);
|
||||
x = conv2d_bias(up, *w, *b, 3, 1, 1);
|
||||
}
|
||||
x
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VaeDecoder {
|
||||
post_quant_w: GraphTensor,
|
||||
post_quant_b: GraphTensor,
|
||||
conv_in_w: GraphTensor,
|
||||
conv_in_b: GraphTensor,
|
||||
mid_resnet_0: ResnetBlock,
|
||||
mid_attn: AttnBlock,
|
||||
mid_resnet_1: ResnetBlock,
|
||||
up_blocks: Vec<UpBlock>,
|
||||
norm_out_w: GraphTensor,
|
||||
norm_out_b: GraphTensor,
|
||||
conv_out_w: GraphTensor,
|
||||
conv_out_b: GraphTensor,
|
||||
}
|
||||
|
||||
impl VaeDecoder {
|
||||
pub fn new(cx: &mut Graph) -> Self {
|
||||
let post_quant_w = cx
|
||||
.named_tensor("post_quant_conv.weight", (LATENT_CHANNELS, LATENT_CHANNELS))
|
||||
.persist();
|
||||
let post_quant_b = cx
|
||||
.named_tensor("post_quant_conv.bias", LATENT_CHANNELS)
|
||||
.persist();
|
||||
|
||||
let mid = BLOCK_OUT_CHANNELS[BLOCK_OUT_CHANNELS.len() - 1];
|
||||
let conv_in_w = cx
|
||||
.named_tensor("decoder.conv_in.weight", (mid, LATENT_CHANNELS * 3 * 3))
|
||||
.persist();
|
||||
let conv_in_b = cx.named_tensor("decoder.conv_in.bias", mid).persist();
|
||||
|
||||
let mid_resnet_0 = ResnetBlock::new("decoder.mid_block.resnets.0", mid, mid, cx);
|
||||
let mid_attn = AttnBlock::new("decoder.mid_block.attentions.0", mid, cx);
|
||||
let mid_resnet_1 = ResnetBlock::new("decoder.mid_block.resnets.1", mid, mid, cx);
|
||||
|
||||
let mut up_blocks = Vec::with_capacity(BLOCK_OUT_CHANNELS.len());
|
||||
for b in 0..BLOCK_OUT_CHANNELS.len() {
|
||||
let (in_c, out_c) = decoder_block_channels(b);
|
||||
let with_upsampler = b < BLOCK_OUT_CHANNELS.len() - 1;
|
||||
up_blocks.push(UpBlock::new(
|
||||
&format!("decoder.up_blocks.{b}"),
|
||||
in_c,
|
||||
out_c,
|
||||
with_upsampler,
|
||||
cx,
|
||||
));
|
||||
}
|
||||
|
||||
let last_c = BLOCK_OUT_CHANNELS[0];
|
||||
let norm_out_w = cx
|
||||
.named_tensor("decoder.conv_norm_out.weight", last_c)
|
||||
.persist();
|
||||
let norm_out_b = cx
|
||||
.named_tensor("decoder.conv_norm_out.bias", last_c)
|
||||
.persist();
|
||||
let conv_out_w = cx
|
||||
.named_tensor("decoder.conv_out.weight", (3, last_c * 3 * 3))
|
||||
.persist();
|
||||
let conv_out_b = cx.named_tensor("decoder.conv_out.bias", 3).persist();
|
||||
|
||||
Self {
|
||||
post_quant_w,
|
||||
post_quant_b,
|
||||
conv_in_w,
|
||||
conv_in_b,
|
||||
mid_resnet_0,
|
||||
mid_attn,
|
||||
mid_resnet_1,
|
||||
up_blocks,
|
||||
norm_out_w,
|
||||
norm_out_b,
|
||||
conv_out_w,
|
||||
conv_out_b,
|
||||
}
|
||||
}
|
||||
|
||||
/// Decode a latent of shape (LATENT_CHANNELS, h, w) into an RGB image
|
||||
/// of shape (3, h * VAE_DOWNSAMPLE, w * VAE_DOWNSAMPLE) in the [-1, 1] range.
|
||||
pub fn forward(&self, latent: GraphTensor) -> GraphTensor {
|
||||
self.forward_partial(latent, usize::MAX)
|
||||
}
|
||||
|
||||
/// Run the decoder up to stage `stop_at` (used for incremental debugging).
|
||||
/// Stages: 0=post_quant only, 1=+conv_in, 2..=4=+mid (resnet, attn, resnet),
|
||||
/// 5..=8=+up_blocks[0..3], 9=+conv_norm_out+silu, 10=+conv_out (full).
|
||||
pub fn forward_partial(&self, latent: GraphTensor, stop_at: usize) -> GraphTensor {
|
||||
let mut x = conv2d_bias(latent, self.post_quant_w, self.post_quant_b, 1, 1, 0);
|
||||
if stop_at == 0 {
|
||||
return x;
|
||||
}
|
||||
x = conv2d_bias(x, self.conv_in_w, self.conv_in_b, 3, 1, 1);
|
||||
if stop_at == 1 {
|
||||
return x;
|
||||
}
|
||||
x = self.mid_resnet_0.forward(x);
|
||||
if stop_at == 2 {
|
||||
return x;
|
||||
}
|
||||
x = self.mid_attn.forward(x);
|
||||
if stop_at == 3 {
|
||||
return x;
|
||||
}
|
||||
x = self.mid_resnet_1.forward(x);
|
||||
if stop_at == 4 {
|
||||
return x;
|
||||
}
|
||||
for (i, blk) in self.up_blocks.iter().enumerate() {
|
||||
x = blk.forward(x);
|
||||
if stop_at == 5 + i {
|
||||
return x;
|
||||
}
|
||||
}
|
||||
x = group_norm(
|
||||
x,
|
||||
self.norm_out_w,
|
||||
self.norm_out_b,
|
||||
NORM_NUM_GROUPS,
|
||||
NORM_EPS,
|
||||
);
|
||||
x = silu(x);
|
||||
if stop_at == 9 {
|
||||
return x;
|
||||
}
|
||||
conv2d_bias(x, self.conv_out_w, self.conv_out_b, 3, 1, 1)
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "gemma"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
|
||||
@@ -22,4 +22,4 @@ serde_json = "1.0"
|
||||
half = {version = "2.7.1", features = ["bytemuck"]}
|
||||
bytemuck = "1.24.0"
|
||||
memmap2 = "0.9.9"
|
||||
rustc-hash = "2.1"
|
||||
rustc-hash = "2.1"
|
||||
@@ -1,7 +1,7 @@
|
||||
use half::{bf16, f16};
|
||||
use hf_hub::api::sync::Api;
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::{Dtype, SafeTensors, tensor::TensorView};
|
||||
use safetensors::{tensor::TensorView, Dtype, SafeTensors};
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
|
||||
@@ -13,10 +13,6 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
const REPO_ID: &str = "unsloth/gemma-3-4b-it";
|
||||
|
||||
fn gemma3_chat_prompt(user_prompt: &str) -> String {
|
||||
format!("<bos><start_of_turn>user\n{user_prompt}<end_of_turn>\n<start_of_turn>model\n")
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 500;
|
||||
@@ -35,12 +31,7 @@ fn main() {
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let chat_prompt = gemma3_chat_prompt(prompt);
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(chat_prompt.as_str(), false)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
|
||||
// Build graph
|
||||
let mut cx = Graph::default();
|
||||
@@ -69,21 +60,10 @@ fn main() {
|
||||
}
|
||||
|
||||
println!("Compiling...");
|
||||
let max_prefill = (prompt_tokens.len() + 16)
|
||||
.next_power_of_two()
|
||||
.min(max_seq_len);
|
||||
let search_s = 16.min(max_prefill).max(2);
|
||||
cx.set_dim_buckets(
|
||||
's',
|
||||
&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, max_prefill).representative(search_s),
|
||||
],
|
||||
);
|
||||
cx.set_dim('s', search_s);
|
||||
cx.set_dim('p', 0);
|
||||
runtime.set_data(input, vec![1; search_s]);
|
||||
runtime.set_data(token_ids, (0..search_s as i32).collect::<Vec<_>>());
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', 1);
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(token_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
@@ -91,85 +71,26 @@ fn main() {
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
let mut prev_seq = 1usize;
|
||||
let mut sentence = vec![prompt_tokens[0]];
|
||||
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
|
||||
let prompt_len = prompt_tokens.len();
|
||||
let mut prev_seq = 0usize;
|
||||
let mut fwd_durations = vec![];
|
||||
let mut seen_tokens = FxHashSet::default();
|
||||
let repetition_penalty: f32 = 1.05;
|
||||
|
||||
const EOS_TOKEN: u32 = 1; // <eos>
|
||||
const STOP_TOKEN: u32 = 106; // <end_of_turn>
|
||||
const EOS_TOKEN: u32 = 1;
|
||||
const STOP_TOKEN: u32 = 107;
|
||||
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
|
||||
let mut generated = 0usize;
|
||||
let mut sentence = Vec::new();
|
||||
|
||||
if gen_tokens > 0 && prompt_len > 0 {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
cx.set_dim('s', prompt_len);
|
||||
cx.set_dim('p', 0);
|
||||
|
||||
runtime.set_data(
|
||||
input,
|
||||
prompt_tokens.iter().map(|t| *t as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
runtime.set_data(token_ids, (0..prompt_len as i32).collect::<Vec<_>>());
|
||||
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
|
||||
// Round-trip KV cache
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
|
||||
prev_seq = prompt_len;
|
||||
fwd_durations.push(start.elapsed());
|
||||
|
||||
// Greedy decode with repetition penalty
|
||||
let row_start = (prompt_len - 1) * VOCAB_SIZE;
|
||||
let mut last_row = logits_data[row_start..row_start + VOCAB_SIZE].to_vec();
|
||||
for &tok in &seen_tokens {
|
||||
let logit = &mut last_row[tok as usize];
|
||||
if *logit > 0.0 {
|
||||
*logit /= repetition_penalty;
|
||||
} else {
|
||||
*logit *= repetition_penalty;
|
||||
}
|
||||
}
|
||||
let next_token = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
sentence = vec![next_token];
|
||||
seen_tokens.insert(next_token);
|
||||
generated = 1;
|
||||
|
||||
if next_token != EOS_TOKEN && next_token != STOP_TOKEN {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
while generated < gen_tokens && !sentence.is_empty() {
|
||||
for i in 0..total_steps {
|
||||
let start = std::time::Instant::now();
|
||||
let is_prefill = i < prompt_len - 1;
|
||||
let seq_len = sentence.len();
|
||||
let current_token = sentence[0];
|
||||
|
||||
if current_token == EOS_TOKEN || current_token == STOP_TOKEN {
|
||||
break;
|
||||
}
|
||||
|
||||
cx.set_dim('s', seq_len);
|
||||
cx.set_dim('p', prev_seq);
|
||||
@@ -197,6 +118,11 @@ fn main() {
|
||||
prev_seq += seq_len;
|
||||
fwd_durations.push(start.elapsed());
|
||||
|
||||
if is_prefill {
|
||||
sentence = vec![prompt_tokens[i + 1]];
|
||||
continue;
|
||||
}
|
||||
|
||||
// Greedy decode with repetition penalty
|
||||
let mut last_row = logits_data[logits_data.len() - VOCAB_SIZE..].to_vec();
|
||||
for &tok in &seen_tokens {
|
||||
@@ -215,7 +141,6 @@ fn main() {
|
||||
.0 as u32;
|
||||
sentence = vec![next_token];
|
||||
seen_tokens.insert(next_token);
|
||||
generated += 1;
|
||||
|
||||
if next_token == EOS_TOKEN || next_token == STOP_TOKEN {
|
||||
break;
|
||||
@@ -228,11 +153,15 @@ fn main() {
|
||||
println!();
|
||||
|
||||
// Benchmarks
|
||||
let decode_durations: Vec<_> = fwd_durations.iter().skip(1).collect();
|
||||
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
|
||||
if decode_durations.len() > 2 {
|
||||
println!(
|
||||
" TTFT: {:.2} ms",
|
||||
fwd_durations[..1].iter().sum::<Duration>().as_secs_f64() * 1e3
|
||||
fwd_durations[..prompt_len]
|
||||
.iter()
|
||||
.sum::<Duration>()
|
||||
.as_secs_f64()
|
||||
* 1e3
|
||||
);
|
||||
println!(
|
||||
" TPOT: {:.2} ms",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "gemma4_moe"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
|
||||
@@ -9,7 +9,6 @@ edition = "2024"
|
||||
luminal = { path = "../.." }
|
||||
luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
example_common = { path = "../example_common" }
|
||||
tokenizers = "0.22.2"
|
||||
rustc-hash = "2"
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use half::{bf16, f16};
|
||||
use hf_hub::api::sync::Api;
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::{Dtype, SafeTensors, tensor::TensorView};
|
||||
use safetensors::{tensor::TensorView, Dtype, SafeTensors};
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
use example_common::{BenchEnv, env_bool, has_arg, info, sample_greedy_with_penalty, stdio};
|
||||
use hf::prepare_hf_model;
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
@@ -11,140 +10,39 @@ use std::{io::Write, time::Duration};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const REPO_ID: &str = "google/gemma-4-26B-A4B";
|
||||
const STDIO_MAX_PREFILL: usize = 512;
|
||||
const DEFAULT_GEN_TOKENS: usize = 30;
|
||||
const DEFAULT_SEARCH_GRAPHS: usize = 50;
|
||||
const EOS_TOKEN: u32 = 1;
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_one_prompt(
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
tokenizer: &Tokenizer,
|
||||
input: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
max_seq_len: usize,
|
||||
prompt_tokens: &[u32],
|
||||
gen_tokens: usize,
|
||||
repetition_penalty: f32,
|
||||
emit_tok: &mut dyn FnMut(&str),
|
||||
) -> (usize, Duration) {
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
fn env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
let prompt_len = prompt_tokens.len();
|
||||
if prompt_len == 0 || gen_tokens == 0 {
|
||||
return (0, Duration::default());
|
||||
}
|
||||
|
||||
let mut seen_tokens: FxHashSet<u32> = FxHashSet::default();
|
||||
let mut generated = 0usize;
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
cx.set_dim('s', prompt_len);
|
||||
cx.set_dim('p', 0);
|
||||
runtime.set_data(
|
||||
input,
|
||||
prompt_tokens.iter().map(|t| *t as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
runtime.set_data(pos_ids, (0..prompt_len as i32).collect::<Vec<_>>());
|
||||
runtime.execute(&cx.dyn_map);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
let mut prev_seq = prompt_len;
|
||||
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let row_start = (prompt_len - 1) * VOCAB_SIZE;
|
||||
let mut next_token = sample_greedy_with_penalty(
|
||||
&logits_data[row_start..row_start + VOCAB_SIZE],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
seen_tokens.insert(next_token);
|
||||
generated += 1;
|
||||
if next_token != EOS_TOKEN {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
emit_tok(&decoded);
|
||||
}
|
||||
|
||||
while generated < gen_tokens {
|
||||
if next_token == EOS_TOKEN {
|
||||
break;
|
||||
}
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![next_token as i32]);
|
||||
runtime.set_data(pos_ids, vec![prev_seq as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
|
||||
prev_seq += 1;
|
||||
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
next_token = sample_greedy_with_penalty(
|
||||
&logits_data[..VOCAB_SIZE],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
seen_tokens.insert(next_token);
|
||||
generated += 1;
|
||||
|
||||
if next_token == EOS_TOKEN {
|
||||
break;
|
||||
}
|
||||
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
emit_tok(&decoded);
|
||||
}
|
||||
|
||||
(generated, start.elapsed())
|
||||
fn env_bool(name: &str) -> bool {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.is_some_and(|s| matches!(s.as_str(), "1" | "true" | "TRUE" | "yes" | "YES"))
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let max_seq_len = 4096;
|
||||
let bench = BenchEnv::from_env(DEFAULT_GEN_TOKENS, DEFAULT_SEARCH_GRAPHS);
|
||||
let stdio_mode = has_arg("--stdio");
|
||||
let max_seq_len = env_usize("MAX_SEQ_LEN", 4096);
|
||||
let gen_tokens = env_usize("GEN_TOKENS", 30);
|
||||
let search_graphs = env_usize("SEARCH_GRAPHS", 50);
|
||||
let prompt = std::env::var("PROMPT").unwrap_or_else(|_| "The capital of France is".to_string());
|
||||
let print_token_ids = env_bool("PRINT_TOKEN_IDS");
|
||||
|
||||
let log = |s: &str| info(stdio_mode, s);
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
let model_dir = prepare_hf_model(REPO_ID).expect("Failed to prepare model");
|
||||
log(&format!("Using model directory: {}", model_dir.display()));
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
|
||||
let (default_prompt_tokens, prompt_len) = if stdio_mode {
|
||||
(Vec::<u32>::new(), 0usize)
|
||||
} else {
|
||||
let toks = tokenizer
|
||||
.encode(prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let len = toks.len();
|
||||
(toks, len)
|
||||
};
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
@@ -157,10 +55,10 @@ fn main() {
|
||||
v_out.output();
|
||||
}
|
||||
|
||||
log("Building E-Graph...");
|
||||
println!("Building E-Graph...");
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
log("Loading weights...");
|
||||
println!("Loading weights...");
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
let weights_path = model_dir.join("model_combined.safetensors");
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
@@ -171,25 +69,12 @@ fn main() {
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
|
||||
log("Compiling...");
|
||||
let max_prefill = if stdio_mode {
|
||||
STDIO_MAX_PREFILL.min(max_seq_len)
|
||||
} else {
|
||||
(prompt_len + 16).next_power_of_two().min(max_seq_len)
|
||||
};
|
||||
let search_s = 16.min(max_prefill).max(2);
|
||||
cx.set_dim_buckets(
|
||||
's',
|
||||
&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, max_prefill).representative(search_s),
|
||||
],
|
||||
);
|
||||
cx.set_dim('s', search_s);
|
||||
cx.set_dim('p', 0);
|
||||
runtime.set_data(input, vec![1; search_s]);
|
||||
runtime.set_data(pos_ids, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime = cx.search(runtime, bench.search_graphs);
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', 1);
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(pos_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
@@ -197,80 +82,50 @@ fn main() {
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
|
||||
let repetition_penalty: f32 = 1.05;
|
||||
|
||||
if stdio_mode {
|
||||
stdio::serve(|user_prompt| {
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(user_prompt, true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
run_one_prompt(
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
&tokenizer,
|
||||
input,
|
||||
pos_ids,
|
||||
logits,
|
||||
&kv_cache,
|
||||
&cache_outputs,
|
||||
max_seq_len,
|
||||
&prompt_tokens,
|
||||
bench.gen_tokens,
|
||||
repetition_penalty,
|
||||
&mut |s| stdio::emit_tok(s),
|
||||
)
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Legacy single-prompt flow.
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
|
||||
let mut prev_seq = 0usize;
|
||||
let mut fwd_durations = vec![];
|
||||
let mut seen_tokens = FxHashSet::default();
|
||||
let mut generated_token_ids = vec![];
|
||||
let repetition_penalty: f32 = 1.05;
|
||||
|
||||
const EOS_TOKEN: u32 = 1;
|
||||
|
||||
let prefill_start = std::time::Instant::now();
|
||||
cx.set_dim('s', default_prompt_tokens.len());
|
||||
cx.set_dim('p', 0);
|
||||
runtime.set_data(
|
||||
input,
|
||||
default_prompt_tokens
|
||||
.iter()
|
||||
.map(|t| *t as i32)
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
runtime.set_data(
|
||||
pos_ids,
|
||||
(0..default_prompt_tokens.len() as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
for &token in &prompt_tokens {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![token as i32]);
|
||||
runtime.set_data(pos_ids, vec![prev_seq as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
|
||||
prev_seq += 1;
|
||||
}
|
||||
let mut prev_seq = default_prompt_tokens.len();
|
||||
let prefill_duration = prefill_start.elapsed();
|
||||
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let row_start = (default_prompt_tokens.len() - 1) * VOCAB_SIZE;
|
||||
let mut next_token = sample_greedy_with_penalty(
|
||||
&logits_data[row_start..row_start + VOCAB_SIZE],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
let last_row = &logits_data[..VOCAB_SIZE];
|
||||
let mut next_token = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
generated_token_ids.push(next_token);
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
for _ in 1..bench.gen_tokens {
|
||||
for _ in 1..gen_tokens {
|
||||
let start = std::time::Instant::now();
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
@@ -288,11 +143,21 @@ fn main() {
|
||||
prev_seq += 1;
|
||||
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
next_token = sample_greedy_with_penalty(
|
||||
&logits_data[..VOCAB_SIZE],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
let mut last_row = logits_data[..VOCAB_SIZE].to_vec();
|
||||
for &tok in &seen_tokens {
|
||||
let logit = &mut last_row[tok as usize];
|
||||
if *logit > 0.0 {
|
||||
*logit /= repetition_penalty;
|
||||
} else {
|
||||
*logit *= repetition_penalty;
|
||||
}
|
||||
}
|
||||
next_token = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
generated_token_ids.push(next_token);
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
@@ -312,7 +177,7 @@ fn main() {
|
||||
println!(
|
||||
" TTFT: {:.2} ms ({} prompt tokens)",
|
||||
prefill_duration.as_secs_f64() * 1e3,
|
||||
default_prompt_tokens.len()
|
||||
prompt_tokens.len()
|
||||
);
|
||||
if fwd_durations.len() > 1 {
|
||||
println!(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "llama"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
edition = "2021"
|
||||
|
||||
[[bin]]
|
||||
name = "llama"
|
||||
@@ -14,7 +14,6 @@ luminal = { path = "../.." }
|
||||
luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
luminal_tracing = {path="../../crates/luminal_tracing"}
|
||||
example_common = { path = "../example_common" }
|
||||
tokenizers = "0.15.2"
|
||||
tracing = "0.1.43"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
@@ -28,4 +27,3 @@ half = {version = "2.7.1", features = ["bytemuck"]}
|
||||
bytemuck = "1.24.0"
|
||||
memmap2 = "0.9.9"
|
||||
rustc-hash = "2.1"
|
||||
rand = "0.9.2"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use half::{bf16, f16};
|
||||
use hf_hub::api::sync::Api;
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::{Dtype, SafeTensors, tensor::TensorView};
|
||||
use safetensors::{tensor::TensorView, Dtype, SafeTensors};
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
@@ -17,22 +17,10 @@ struct SafetensorsIndex {
|
||||
weight_map: HashMap<String, String>,
|
||||
}
|
||||
|
||||
/// Stored tensor data with shape, dtype, and serialized bytes.
|
||||
/// Stored tensor data with shape and converted FP32 bytes
|
||||
struct StoredTensor {
|
||||
shape: Vec<usize>,
|
||||
dtype: Dtype,
|
||||
data: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum WeightFormat {
|
||||
Fp32,
|
||||
Fp8,
|
||||
}
|
||||
|
||||
pub struct PreparedModel {
|
||||
pub model_dir: PathBuf,
|
||||
pub weight_files: Vec<PathBuf>,
|
||||
data: Vec<f32>,
|
||||
}
|
||||
|
||||
/// Downloads model files from HuggingFace and returns the cache directory path.
|
||||
@@ -96,59 +84,6 @@ fn tensor_to_f32(tensor: &safetensors::tensor::TensorView) -> Vec<f32> {
|
||||
}
|
||||
}
|
||||
|
||||
fn tensor_to_f32_bytes(tensor: &safetensors::tensor::TensorView) -> Vec<u8> {
|
||||
let fp32 = tensor_to_f32(tensor);
|
||||
bytemuck::cast_slice(&fp32).to_vec()
|
||||
}
|
||||
|
||||
fn stored_tensor_from_view(
|
||||
tensor: &safetensors::tensor::TensorView,
|
||||
preserve_fp8: bool,
|
||||
) -> StoredTensor {
|
||||
let shape = tensor.shape().to_vec();
|
||||
let dtype = tensor.dtype();
|
||||
match dtype {
|
||||
Dtype::F32 if preserve_fp8 => StoredTensor {
|
||||
shape,
|
||||
dtype,
|
||||
data: tensor.data().to_vec(),
|
||||
},
|
||||
Dtype::F8_E4M3 | Dtype::F8_E5M2 | Dtype::F8_E8M0 if preserve_fp8 => StoredTensor {
|
||||
shape,
|
||||
dtype,
|
||||
data: tensor.data().to_vec(),
|
||||
},
|
||||
Dtype::F32 | Dtype::F16 | Dtype::BF16 => StoredTensor {
|
||||
shape,
|
||||
dtype: Dtype::F32,
|
||||
data: tensor_to_f32_bytes(tensor),
|
||||
},
|
||||
other => {
|
||||
panic!("Unsupported dtype for model preparation: {other:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn model_shard_files(model_dir: &Path) -> Result<Vec<PathBuf>, Box<dyn std::error::Error>> {
|
||||
let index_path = model_dir.join("model.safetensors.index.json");
|
||||
let single_shard_path = model_dir.join("model.safetensors");
|
||||
|
||||
if single_shard_path.exists() && !index_path.exists() {
|
||||
Ok(vec![single_shard_path])
|
||||
} else if index_path.exists() {
|
||||
let index_content = std::fs::read_to_string(&index_path)?;
|
||||
let index: SafetensorsIndex = serde_json::from_str(&index_content)?;
|
||||
|
||||
let mut files: Vec<String> = index.weight_map.values().cloned().collect();
|
||||
files.sort();
|
||||
files.dedup();
|
||||
|
||||
Ok(files.into_iter().map(|f| model_dir.join(f)).collect())
|
||||
} else {
|
||||
Err("No model.safetensors or model.safetensors.index.json found".into())
|
||||
}
|
||||
}
|
||||
|
||||
/// Combines sharded safetensors files into a single FP32 file.
|
||||
///
|
||||
/// This function:
|
||||
@@ -165,11 +100,29 @@ pub fn combine_safetensors_to_fp32(
|
||||
return Ok(output_path);
|
||||
}
|
||||
|
||||
let shard_files = model_shard_files(model_dir)?;
|
||||
info!(
|
||||
"Loading {} shard files (converting to FP32)...",
|
||||
shard_files.len()
|
||||
);
|
||||
let index_path = model_dir.join("model.safetensors.index.json");
|
||||
let single_shard_path = model_dir.join("model.safetensors");
|
||||
|
||||
// Determine which shard files to load
|
||||
let shard_files: Vec<PathBuf> = if single_shard_path.exists() && !index_path.exists() {
|
||||
info!("Single shard model detected, converting to FP32...");
|
||||
vec![single_shard_path]
|
||||
} else if index_path.exists() {
|
||||
let index_content = std::fs::read_to_string(&index_path)?;
|
||||
let index: SafetensorsIndex = serde_json::from_str(&index_content)?;
|
||||
|
||||
let mut files: Vec<String> = index.weight_map.values().cloned().collect();
|
||||
files.sort();
|
||||
files.dedup();
|
||||
|
||||
info!(
|
||||
"Loading {} shard files (converting to FP32)...",
|
||||
files.len()
|
||||
);
|
||||
files.into_iter().map(|f| model_dir.join(f)).collect()
|
||||
} else {
|
||||
return Err("No model.safetensors or model.safetensors.index.json found".into());
|
||||
};
|
||||
|
||||
// Load and convert all tensors
|
||||
let mut all_tensors: HashMap<String, StoredTensor> = HashMap::new();
|
||||
@@ -185,7 +138,16 @@ pub fn combine_safetensors_to_fp32(
|
||||
|
||||
for name in st.names() {
|
||||
let tensor = st.tensor(name)?;
|
||||
all_tensors.insert(name.to_string(), stored_tensor_from_view(&tensor, false));
|
||||
let shape: Vec<usize> = tensor.shape().to_vec();
|
||||
let fp32_data = tensor_to_f32(&tensor);
|
||||
|
||||
all_tensors.insert(
|
||||
name.to_string(),
|
||||
StoredTensor {
|
||||
shape,
|
||||
data: fp32_data,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -197,7 +159,8 @@ pub fn combine_safetensors_to_fp32(
|
||||
let tensor_views: HashMap<String, TensorView<'_>> = all_tensors
|
||||
.iter()
|
||||
.map(|(name, stored)| {
|
||||
let view = TensorView::new(stored.dtype, stored.shape.clone(), &stored.data).unwrap();
|
||||
let data_bytes: &[u8] = bytemuck::cast_slice(&stored.data);
|
||||
let view = TensorView::new(Dtype::F32, stored.shape.clone(), data_bytes).unwrap();
|
||||
(name.clone(), view)
|
||||
})
|
||||
.collect();
|
||||
@@ -211,81 +174,13 @@ pub fn combine_safetensors_to_fp32(
|
||||
Ok(output_path)
|
||||
}
|
||||
|
||||
/// Combines sharded safetensors files into one file while preserving FP8 tensors.
|
||||
///
|
||||
/// Non-FP8 tensors are converted to FP32 so the existing embedding, norm, and
|
||||
/// output-head graph inputs can still load without changing their dtypes.
|
||||
pub fn combine_safetensors_preserve_fp8(
|
||||
model_dir: &Path,
|
||||
) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
let output_path = model_dir.join("model_combined_fp8.safetensors");
|
||||
|
||||
if output_path.exists() {
|
||||
return Ok(output_path);
|
||||
}
|
||||
|
||||
let shard_files = model_shard_files(model_dir)?;
|
||||
info!(
|
||||
"Loading {} shard files (preserving FP8 tensors)...",
|
||||
shard_files.len()
|
||||
);
|
||||
|
||||
let mut all_tensors: HashMap<String, StoredTensor> = HashMap::new();
|
||||
|
||||
for shard_path in &shard_files {
|
||||
info!(
|
||||
" Loading {}...",
|
||||
shard_path.file_name().unwrap().to_string_lossy()
|
||||
);
|
||||
let file = File::open(shard_path)?;
|
||||
let mmap = unsafe { MmapOptions::new().map(&file)? };
|
||||
let st = SafeTensors::deserialize(&mmap)?;
|
||||
|
||||
for name in st.names() {
|
||||
let tensor = st.tensor(name)?;
|
||||
all_tensors.insert(name.to_string(), stored_tensor_from_view(&tensor, true));
|
||||
}
|
||||
}
|
||||
|
||||
info!("Extracted {} language model tensors", all_tensors.len());
|
||||
info!(
|
||||
"Saving mixed FP8/FP32 model to {}...",
|
||||
output_path.display()
|
||||
);
|
||||
|
||||
let tensor_views: HashMap<String, TensorView<'_>> = all_tensors
|
||||
.iter()
|
||||
.map(|(name, stored)| {
|
||||
let view = TensorView::new(stored.dtype, stored.shape.clone(), &stored.data).unwrap();
|
||||
(name.clone(), view)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let serialized = safetensors::serialize(&tensor_views, None)?;
|
||||
|
||||
let mut file = File::create(&output_path)?;
|
||||
file.write_all(&serialized)?;
|
||||
|
||||
info!("Combined mixed FP8/FP32 model saved successfully!");
|
||||
Ok(output_path)
|
||||
}
|
||||
|
||||
/// Downloads a model from HuggingFace and prepares it for use.
|
||||
///
|
||||
/// Returns the path to the model directory containing:
|
||||
/// - tokenizer.json
|
||||
/// - a combined safetensors file for the requested weight format
|
||||
pub fn prepare_hf_model(
|
||||
repo_id: &str,
|
||||
weight_format: WeightFormat,
|
||||
) -> Result<PreparedModel, Box<dyn std::error::Error>> {
|
||||
/// - model_combined.safetensors (FP32)
|
||||
pub fn prepare_hf_model(repo_id: &str) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
let model_dir = download_hf_model(repo_id)?;
|
||||
let weights_path = match weight_format {
|
||||
WeightFormat::Fp32 => combine_safetensors_to_fp32(&model_dir)?,
|
||||
WeightFormat::Fp8 => combine_safetensors_preserve_fp8(&model_dir)?,
|
||||
};
|
||||
Ok(PreparedModel {
|
||||
model_dir,
|
||||
weight_files: vec![weights_path],
|
||||
})
|
||||
combine_safetensors_to_fp32(&model_dir)?;
|
||||
Ok(model_dir)
|
||||
}
|
||||
|
||||
@@ -1,409 +1,51 @@
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
use example_common::{BenchEnv, info, sample_greedy_with_penalty, stdio};
|
||||
use hf::{WeightFormat, prepare_hf_model};
|
||||
use hf::prepare_hf_model;
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use luminal_tracing::*;
|
||||
use model::*;
|
||||
use rand::{SeedableRng, rngs::StdRng};
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{env, io::Write, time::Duration};
|
||||
use std::{io::Write, time::Duration};
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
const FP32_REPO_ID: &str = "NousResearch/Meta-Llama-3-8B-Instruct";
|
||||
const FP8_REPO_ID: &str = "nvidia/Llama-3.1-8B-Instruct-FP8";
|
||||
const MAX_SEQ_LEN: usize = 4096;
|
||||
const DEFAULT_GEN_TOKENS: usize = 500;
|
||||
const DEFAULT_SEARCH_GRAPHS: usize = 500;
|
||||
const STDIO_MAX_PREFILL: usize = 512;
|
||||
const SEARCH_TRIALS: usize = 1;
|
||||
const SEARCH_KEEP_BEST: usize = 4;
|
||||
const SEARCH_MEMORY_MIB: usize = 2048;
|
||||
const SEARCH_SEED: u64 = 0;
|
||||
const PROMPT: &str = "Explain what a neural network is in a paragraph.";
|
||||
|
||||
const EOS_TOKEN: u32 = 128009;
|
||||
const STOP_TOKEN: u32 = 128001;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum LlamaWeightMode {
|
||||
Fp32,
|
||||
Fp8,
|
||||
}
|
||||
|
||||
impl LlamaWeightMode {
|
||||
fn repo_id(self) -> &'static str {
|
||||
match self {
|
||||
Self::Fp32 => FP32_REPO_ID,
|
||||
Self::Fp8 => FP8_REPO_ID,
|
||||
}
|
||||
}
|
||||
|
||||
fn weight_format(self) -> WeightFormat {
|
||||
match self {
|
||||
Self::Fp32 => WeightFormat::Fp32,
|
||||
Self::Fp8 => WeightFormat::Fp8,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct CliArgs {
|
||||
weight_mode: LlamaWeightMode,
|
||||
stdio: bool,
|
||||
}
|
||||
|
||||
fn print_usage(program: &str) {
|
||||
println!("Usage: {program} [--fp8] [--stdio]");
|
||||
println!();
|
||||
println!(" --fp8 Use {FP8_REPO_ID} with FP8 projection weights");
|
||||
println!(" --stdio Long-lived stdio benchmark protocol (READY/TOK/EOQ)");
|
||||
println!(" -h,--help Show this help");
|
||||
}
|
||||
|
||||
fn parse_args() -> CliArgs {
|
||||
let mut weight_mode = LlamaWeightMode::Fp32;
|
||||
let mut stdio = false;
|
||||
let mut args = env::args();
|
||||
let program = args.next().unwrap_or_else(|| "llama".to_string());
|
||||
for arg in args {
|
||||
match arg.as_str() {
|
||||
"--fp8" => weight_mode = LlamaWeightMode::Fp8,
|
||||
"--stdio" => stdio = true,
|
||||
"-h" | "--help" => {
|
||||
print_usage(&program);
|
||||
std::process::exit(0);
|
||||
}
|
||||
_ => {
|
||||
eprintln!("Unknown argument: {arg}");
|
||||
print_usage(&program);
|
||||
std::process::exit(2);
|
||||
}
|
||||
}
|
||||
}
|
||||
CliArgs { weight_mode, stdio }
|
||||
}
|
||||
|
||||
fn llama3_chat_prompt(user_prompt: &str) -> String {
|
||||
format!(
|
||||
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
struct StepProfile {
|
||||
total: Duration,
|
||||
set_inputs: Duration,
|
||||
execute: Duration,
|
||||
get_logits: Duration,
|
||||
cache_roundtrip: Duration,
|
||||
sample: Duration,
|
||||
}
|
||||
|
||||
fn sum_profiles<'a>(profiles: impl Iterator<Item = &'a StepProfile>) -> StepProfile {
|
||||
profiles.fold(StepProfile::default(), |mut acc, p| {
|
||||
acc.total += p.total;
|
||||
acc.set_inputs += p.set_inputs;
|
||||
acc.execute += p.execute;
|
||||
acc.get_logits += p.get_logits;
|
||||
acc.cache_roundtrip += p.cache_roundtrip;
|
||||
acc.sample += p.sample;
|
||||
acc
|
||||
})
|
||||
}
|
||||
|
||||
fn avg_ms(duration: Duration, n: usize) -> f64 {
|
||||
if n == 0 {
|
||||
0.0
|
||||
} else {
|
||||
duration.as_secs_f64() * 1e3 / n as f64
|
||||
}
|
||||
}
|
||||
|
||||
fn print_profile(label: &str, profile: &StepProfile, n: usize) {
|
||||
println!(
|
||||
" {label}: n={n}, avg={:.2} ms [set={:.2}, exec={:.2}, logits_dtoh={:.2}, cache={:.2}, sample={:.2}]",
|
||||
avg_ms(profile.total, n),
|
||||
avg_ms(profile.set_inputs, n),
|
||||
avg_ms(profile.execute, n),
|
||||
avg_ms(profile.get_logits, n),
|
||||
avg_ms(profile.cache_roundtrip, n),
|
||||
avg_ms(profile.sample, n),
|
||||
);
|
||||
}
|
||||
|
||||
fn print_host_op_summary(runtime: &CudaRuntime, label: &str, quiet: bool) {
|
||||
if quiet {
|
||||
return;
|
||||
}
|
||||
let host_ops = runtime.host_ops();
|
||||
let debug_ops = host_ops
|
||||
.iter()
|
||||
.map(|op| format!("{op:?}"))
|
||||
.collect::<Vec<_>>();
|
||||
let cublaslt = debug_ops
|
||||
.iter()
|
||||
.filter(|op| op.contains("CuBlasLt"))
|
||||
.count();
|
||||
let fp8_cublaslt = debug_ops
|
||||
.iter()
|
||||
.filter(|op| {
|
||||
op.contains("CuBlasLt") && (op.contains("a_dtype: F8") || op.contains("b_dtype: F8"))
|
||||
})
|
||||
.count();
|
||||
let scaled_fp8_cublaslt = debug_ops
|
||||
.iter()
|
||||
.filter(|op| {
|
||||
op.contains("CuBlasLt")
|
||||
&& (op.contains("a_dtype: F8") || op.contains("b_dtype: F8"))
|
||||
&& op.contains("a_scale_input: true")
|
||||
&& op.contains("b_scale_input: true")
|
||||
})
|
||||
.count();
|
||||
println!(
|
||||
"Host op summary ({label}): total={}, cublasLt={}, fp8_cublasLt={}, scaled_fp8_cublasLt={}",
|
||||
debug_ops.len(),
|
||||
cublaslt,
|
||||
fp8_cublaslt,
|
||||
scaled_fp8_cublaslt
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_model_step(
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
input: GraphTensor,
|
||||
q_pos_t: GraphTensor,
|
||||
scatter_idx_t: GraphTensor,
|
||||
gather_idx_t: GraphTensor,
|
||||
attn_mask_t: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
tokens: &[u32],
|
||||
q_pos: &[i32],
|
||||
scatter_idx: &[i32],
|
||||
gather_idx: &[i32],
|
||||
attn_mask: &[f32],
|
||||
) -> (Vec<f32>, StepProfile) {
|
||||
let start = std::time::Instant::now();
|
||||
let seq_len = tokens.len();
|
||||
let mut profile = StepProfile::default();
|
||||
|
||||
cx.set_dim('s', seq_len);
|
||||
cx.set_dim('c', gather_idx.len());
|
||||
|
||||
let set_start = std::time::Instant::now();
|
||||
runtime.set_data(input, tokens.iter().map(|t| *t as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(q_pos_t, q_pos.to_vec());
|
||||
runtime.set_data(scatter_idx_t, scatter_idx.to_vec());
|
||||
runtime.set_data(gather_idx_t, gather_idx.to_vec());
|
||||
runtime.set_data(attn_mask_t, attn_mask.to_vec());
|
||||
profile.set_inputs = set_start.elapsed();
|
||||
|
||||
let execute_start = std::time::Instant::now();
|
||||
runtime.execute(&cx.dyn_map);
|
||||
profile.execute = execute_start.elapsed();
|
||||
|
||||
let logits_start = std::time::Instant::now();
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
profile.get_logits = logits_start.elapsed();
|
||||
|
||||
let cache_start = std::time::Instant::now();
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
profile.cache_roundtrip = cache_start.elapsed();
|
||||
profile.total = start.elapsed();
|
||||
|
||||
(logits_data, profile)
|
||||
}
|
||||
|
||||
fn causal_mask(q_pos: &[usize], context_len: usize) -> Vec<f32> {
|
||||
let mut mask = vec![-1e10f32; q_pos.len() * context_len];
|
||||
for (qi, &pos) in q_pos.iter().enumerate() {
|
||||
for ci in 0..context_len {
|
||||
if ci <= pos {
|
||||
mask[qi * context_len + ci] = 0.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
mask
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_one_prompt(
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
tokenizer: &Tokenizer,
|
||||
input: GraphTensor,
|
||||
q_pos_t: GraphTensor,
|
||||
scatter_idx_t: GraphTensor,
|
||||
gather_idx_t: GraphTensor,
|
||||
attn_mask_t: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
cache_bytes: usize,
|
||||
prompt_tokens: &[u32],
|
||||
gen_tokens: usize,
|
||||
repetition_penalty: f32,
|
||||
emit_tok: &mut dyn FnMut(&str),
|
||||
) -> (usize, Duration) {
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
let prompt_len = prompt_tokens.len();
|
||||
let mut seen_tokens: FxHashSet<u32> = FxHashSet::default();
|
||||
let mut context_len = 0usize;
|
||||
let mut generated = 0usize;
|
||||
let mut next_token: Option<u32> = None;
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
if gen_tokens > 0 && prompt_len > 0 {
|
||||
let positions: Vec<usize> = (0..prompt_len).collect();
|
||||
let q_pos: Vec<i32> = positions.iter().map(|&p| p as i32).collect();
|
||||
let scatter_idx = q_pos.clone();
|
||||
let gather_idx = q_pos.clone();
|
||||
let mask = causal_mask(&positions, prompt_len);
|
||||
let (logits_data, _profile) = run_model_step(
|
||||
cx,
|
||||
runtime,
|
||||
input,
|
||||
q_pos_t,
|
||||
scatter_idx_t,
|
||||
gather_idx_t,
|
||||
attn_mask_t,
|
||||
logits,
|
||||
kv_cache,
|
||||
cache_outputs,
|
||||
prompt_tokens,
|
||||
&q_pos,
|
||||
&scatter_idx,
|
||||
&gather_idx,
|
||||
&mask,
|
||||
);
|
||||
context_len = prompt_len;
|
||||
let token = sample_greedy_with_penalty(
|
||||
&logits_data[logits_data.len() - VOCAB_SIZE..],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
seen_tokens.insert(token);
|
||||
next_token = Some(token);
|
||||
generated = 1;
|
||||
if token != EOS_TOKEN && token != STOP_TOKEN {
|
||||
let decoded = tokenizer.decode(&[token], true).unwrap();
|
||||
emit_tok(&decoded);
|
||||
}
|
||||
}
|
||||
|
||||
while generated < gen_tokens {
|
||||
let current_token = match next_token {
|
||||
Some(token) if token != EOS_TOKEN && token != STOP_TOKEN => token,
|
||||
_ => break,
|
||||
};
|
||||
let (logits_data, _profile) = run_model_step(
|
||||
cx,
|
||||
runtime,
|
||||
input,
|
||||
q_pos_t,
|
||||
scatter_idx_t,
|
||||
gather_idx_t,
|
||||
attn_mask_t,
|
||||
logits,
|
||||
kv_cache,
|
||||
cache_outputs,
|
||||
&[current_token],
|
||||
&[context_len as i32],
|
||||
&[context_len as i32],
|
||||
&(0..=context_len as i32).collect::<Vec<_>>(),
|
||||
&causal_mask(&[context_len], context_len + 1),
|
||||
);
|
||||
context_len += 1;
|
||||
let token = sample_greedy_with_penalty(
|
||||
&logits_data[logits_data.len() - VOCAB_SIZE..],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
seen_tokens.insert(token);
|
||||
next_token = Some(token);
|
||||
generated += 1;
|
||||
if token == EOS_TOKEN || token == STOP_TOKEN {
|
||||
break;
|
||||
}
|
||||
let decoded = tokenizer.decode(&[token], true).unwrap();
|
||||
emit_tok(&decoded);
|
||||
}
|
||||
(generated, start.elapsed())
|
||||
}
|
||||
const REPO_ID: &str = "NousResearch/Meta-Llama-3-8B-Instruct";
|
||||
|
||||
fn main() {
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 500;
|
||||
let search_graphs = 500;
|
||||
let prompt = "Explain what a neural network is in a paragraph.";
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
|
||||
let cli = parse_args();
|
||||
let stdio_mode = cli.stdio;
|
||||
let weight_mode = cli.weight_mode;
|
||||
|
||||
let bench = BenchEnv::from_env(DEFAULT_GEN_TOKENS, DEFAULT_SEARCH_GRAPHS);
|
||||
let log = |s: &str| info(stdio_mode, s);
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
let prepared = prepare_hf_model(weight_mode.repo_id(), weight_mode.weight_format())
|
||||
.expect("Failed to prepare model");
|
||||
log(&format!("Using model: {}", weight_mode.repo_id()));
|
||||
log(&format!(
|
||||
"Using model directory: {}",
|
||||
prepared.model_dir.display()
|
||||
));
|
||||
let model_dir = prepare_hf_model(REPO_ID).expect("Failed to prepare model");
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(prepared.model_dir.join("tokenizer.json")).unwrap();
|
||||
|
||||
let (prompt_tokens_default, prompt_len) = if stdio_mode {
|
||||
(Vec::<u32>::new(), 0usize)
|
||||
} else {
|
||||
let chat_prompt = llama3_chat_prompt(PROMPT);
|
||||
let toks = tokenizer
|
||||
.encode(chat_prompt.as_str(), false)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let len = toks.len();
|
||||
(toks, len)
|
||||
};
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let chat_prompt = format!(
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
);
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(chat_prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
// Build graph
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
let q_pos_t = cx.named_tensor("q_pos", 's').as_dtype(DType::Int);
|
||||
let scatter_idx_t = cx.named_tensor("scatter_idx", 's').as_dtype(DType::Int);
|
||||
let gather_idx_t = cx.named_tensor("gather_idx", 'c').as_dtype(DType::Int);
|
||||
let attn_mask_t = cx.named_tensor("attn_mask", ('s', 'c'));
|
||||
let kv_cache = KVCache::new(&mut cx, MAX_SEQ_LEN);
|
||||
let llama = match weight_mode {
|
||||
LlamaWeightMode::Fp32 => Llama::init(&mut cx),
|
||||
LlamaWeightMode::Fp8 => Llama::init_fp8(&mut cx),
|
||||
};
|
||||
let (logits, cache_outputs) = llama.forward(
|
||||
input,
|
||||
q_pos_t,
|
||||
scatter_idx_t,
|
||||
gather_idx_t,
|
||||
attn_mask_t,
|
||||
&kv_cache,
|
||||
);
|
||||
let token_ids = cx.named_tensor("token_ids", 's').as_dtype(DType::Int);
|
||||
let kv_cache = KVCache::new(&mut cx, max_seq_len);
|
||||
let (logits, cache_outputs) = Llama::init(&mut cx).forward(input, token_ids, &kv_cache);
|
||||
let logits = logits.output();
|
||||
for (k_out, v_out) in &cache_outputs {
|
||||
k_out.output();
|
||||
@@ -411,236 +53,124 @@ fn main() {
|
||||
}
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('c', 1);
|
||||
cx.set_dim('p', 1);
|
||||
|
||||
log("Building E-Graph...");
|
||||
let egraph_start = std::time::Instant::now();
|
||||
println!("Building E-Graph...");
|
||||
cx.build_search_space_with_options::<CudaRuntime>(
|
||||
BuildSearchSpaceOptions::new().max_memory_mib(SEARCH_MEMORY_MIB),
|
||||
BuildSearchSpaceOptions::new().max_memory_mib(500),
|
||||
);
|
||||
log(&format!(
|
||||
" E-Graph build: {:.2} s",
|
||||
egraph_start.elapsed().as_secs_f64()
|
||||
));
|
||||
|
||||
log("Loading weights...");
|
||||
let load_start = std::time::Instant::now();
|
||||
println!("Loading weights...");
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
for weights_path in &prepared.weight_files {
|
||||
log(&format!(" Loading {}", weights_path.display()));
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
}
|
||||
log(&format!(
|
||||
" Weight load: {:.2} s",
|
||||
load_start.elapsed().as_secs_f64()
|
||||
));
|
||||
let weights_path = model_dir.join("model_combined.safetensors");
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
let cache_bytes = MAX_SEQ_LEN * KV_DIM * std::mem::size_of::<f32>();
|
||||
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
log("Compiling...");
|
||||
let compile_start = std::time::Instant::now();
|
||||
let max_prefill = if stdio_mode {
|
||||
STDIO_MAX_PREFILL.min(MAX_SEQ_LEN)
|
||||
} else {
|
||||
(prompt_len + 16).next_power_of_two().min(MAX_SEQ_LEN)
|
||||
};
|
||||
let search_s = 16.min(max_prefill).max(2);
|
||||
cx.set_dim_buckets(
|
||||
's',
|
||||
&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, max_prefill).representative(search_s),
|
||||
],
|
||||
);
|
||||
cx.set_dim('s', search_s);
|
||||
cx.set_dim('c', search_s);
|
||||
runtime.set_data(input, vec![1; search_s]);
|
||||
runtime.set_data(q_pos_t, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(scatter_idx_t, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(gather_idx_t, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(attn_mask_t, vec![0.0f32; search_s * search_s]);
|
||||
log(&format!(" Search seed: {SEARCH_SEED}"));
|
||||
log(&format!(" Search trials: {SEARCH_TRIALS}"));
|
||||
log(&format!(" Search keep-best: {SEARCH_KEEP_BEST}"));
|
||||
let mut rng = StdRng::seed_from_u64(SEARCH_SEED);
|
||||
runtime = cx.search_options(
|
||||
runtime,
|
||||
SearchOptions::new(bench.search_graphs)
|
||||
.trials(SEARCH_TRIALS)
|
||||
.keep_best(SEARCH_KEEP_BEST),
|
||||
&mut rng,
|
||||
);
|
||||
log(&format!(
|
||||
" Search/compile: {:.2} s",
|
||||
compile_start.elapsed().as_secs_f64()
|
||||
));
|
||||
print_host_op_summary(&runtime, "post-compile active bucket", stdio_mode);
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', 1);
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(token_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
let mut prev_seq = 1usize;
|
||||
let mut sentence = vec![prompt_tokens[0]];
|
||||
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
|
||||
let prompt_len = prompt_tokens.len();
|
||||
let mut fwd_durations = vec![];
|
||||
let mut seen_tokens = FxHashSet::default();
|
||||
let repetition_penalty: f32 = 1.05;
|
||||
|
||||
if stdio_mode {
|
||||
stdio::serve(|user_prompt| {
|
||||
let chat_prompt = llama3_chat_prompt(user_prompt);
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(chat_prompt.as_str(), false)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
run_one_prompt(
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
&tokenizer,
|
||||
input,
|
||||
q_pos_t,
|
||||
scatter_idx_t,
|
||||
gather_idx_t,
|
||||
attn_mask_t,
|
||||
logits,
|
||||
&kv_cache,
|
||||
&cache_outputs,
|
||||
cache_bytes,
|
||||
&prompt_tokens,
|
||||
bench.gen_tokens,
|
||||
repetition_penalty,
|
||||
&mut |s| stdio::emit_tok(s),
|
||||
)
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Non-stdio: legacy single-prompt flow with profiling output.
|
||||
let mut context_len = 0usize;
|
||||
let mut fwd_durations = vec![];
|
||||
let mut step_profiles = vec![];
|
||||
let mut seen_tokens = FxHashSet::default();
|
||||
const EOS_TOKEN: u32 = 128009;
|
||||
const STOP_TOKEN: u32 = 128001;
|
||||
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, bench.gen_tokens
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
|
||||
let mut generated = 0usize;
|
||||
let mut next_token = None;
|
||||
if bench.gen_tokens > 0 && prompt_len > 0 {
|
||||
let positions: Vec<usize> = (0..prompt_len).collect();
|
||||
let q_pos: Vec<i32> = positions.iter().map(|&p| p as i32).collect();
|
||||
let scatter_idx = q_pos.clone();
|
||||
let gather_idx = q_pos.clone();
|
||||
let mask = causal_mask(&positions, prompt_len);
|
||||
let (logits_data, mut profile) = run_model_step(
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
for i in 0..total_steps {
|
||||
let start = std::time::Instant::now();
|
||||
let is_prefill = i < prompt_len - 1;
|
||||
let seq_len = sentence.len();
|
||||
|
||||
cx.set_dim('s', seq_len);
|
||||
cx.set_dim('p', prev_seq);
|
||||
|
||||
runtime.set_data(
|
||||
input,
|
||||
q_pos_t,
|
||||
scatter_idx_t,
|
||||
gather_idx_t,
|
||||
attn_mask_t,
|
||||
logits,
|
||||
&kv_cache,
|
||||
&cache_outputs,
|
||||
&prompt_tokens_default,
|
||||
&q_pos,
|
||||
&scatter_idx,
|
||||
&gather_idx,
|
||||
&mask,
|
||||
sentence.iter().map(|t| *t as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
print_host_op_summary(&runtime, "after prefill", false);
|
||||
context_len = prompt_len;
|
||||
|
||||
let sample_start = std::time::Instant::now();
|
||||
let token = sample_greedy_with_penalty(
|
||||
&logits_data[logits_data.len() - VOCAB_SIZE..],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
runtime.set_data(
|
||||
token_ids,
|
||||
(prev_seq as i32..(seq_len + prev_seq) as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
profile.sample = sample_start.elapsed();
|
||||
profile.total += profile.sample;
|
||||
seen_tokens.insert(token);
|
||||
next_token = Some(token);
|
||||
generated = 1;
|
||||
|
||||
fwd_durations.push(profile.total);
|
||||
step_profiles.push(profile);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
|
||||
if token != EOS_TOKEN && token != STOP_TOKEN {
|
||||
let decoded = tokenizer.decode(&[token], true).unwrap();
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
// Round-trip KV cache
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
}
|
||||
|
||||
while generated < bench.gen_tokens {
|
||||
let current_token = match next_token {
|
||||
Some(token) if token != EOS_TOKEN && token != STOP_TOKEN => token,
|
||||
_ => break,
|
||||
};
|
||||
prev_seq += seq_len;
|
||||
fwd_durations.push(start.elapsed());
|
||||
|
||||
let (logits_data, mut profile) = run_model_step(
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
q_pos_t,
|
||||
scatter_idx_t,
|
||||
gather_idx_t,
|
||||
attn_mask_t,
|
||||
logits,
|
||||
&kv_cache,
|
||||
&cache_outputs,
|
||||
&[current_token],
|
||||
&[context_len as i32],
|
||||
&[context_len as i32],
|
||||
&(0..=context_len as i32).collect::<Vec<_>>(),
|
||||
&causal_mask(&[context_len], context_len + 1),
|
||||
);
|
||||
if generated == 1 {
|
||||
print_host_op_summary(&runtime, "after first decode", false);
|
||||
if is_prefill {
|
||||
sentence = vec![prompt_tokens[i + 1]];
|
||||
continue;
|
||||
}
|
||||
context_len += 1;
|
||||
|
||||
let sample_start = std::time::Instant::now();
|
||||
let token = sample_greedy_with_penalty(
|
||||
&logits_data[logits_data.len() - VOCAB_SIZE..],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
profile.sample = sample_start.elapsed();
|
||||
profile.total += profile.sample;
|
||||
seen_tokens.insert(token);
|
||||
next_token = Some(token);
|
||||
generated += 1;
|
||||
// Greedy decode with repetition penalty
|
||||
let mut last_row = logits_data[logits_data.len() - VOCAB_SIZE..].to_vec();
|
||||
for &tok in &seen_tokens {
|
||||
let logit = &mut last_row[tok as usize];
|
||||
if *logit > 0.0 {
|
||||
*logit /= repetition_penalty;
|
||||
} else {
|
||||
*logit *= repetition_penalty;
|
||||
}
|
||||
}
|
||||
let next_token = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
sentence = vec![next_token];
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
fwd_durations.push(profile.total);
|
||||
step_profiles.push(profile);
|
||||
|
||||
if token == EOS_TOKEN || token == STOP_TOKEN {
|
||||
if next_token == EOS_TOKEN || next_token == STOP_TOKEN {
|
||||
break;
|
||||
}
|
||||
|
||||
let decoded = tokenizer.decode(&[token], true).unwrap();
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
println!();
|
||||
|
||||
// Benchmarks
|
||||
let prefill_steps = usize::from(!step_profiles.is_empty());
|
||||
let ttft_steps = prefill_steps;
|
||||
let decode_durations: Vec<_> = fwd_durations.iter().skip(ttft_steps).collect();
|
||||
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
|
||||
if decode_durations.len() > 2 {
|
||||
println!(
|
||||
" TTFT: {:.2} ms",
|
||||
fwd_durations
|
||||
fwd_durations[..prompt_len]
|
||||
.iter()
|
||||
.take(ttft_steps)
|
||||
.sum::<Duration>()
|
||||
.as_secs_f64()
|
||||
* 1e3
|
||||
@@ -653,18 +183,4 @@ fn main() {
|
||||
* 1_000.
|
||||
);
|
||||
}
|
||||
println!("\nProfile breakdown:");
|
||||
let decode_steps = step_profiles.len().saturating_sub(ttft_steps);
|
||||
let prefill = sum_profiles(step_profiles.iter().take(prefill_steps));
|
||||
let decode = sum_profiles(step_profiles.iter().skip(ttft_steps));
|
||||
print_profile("batched prefill", &prefill, prefill_steps);
|
||||
print_profile("steady decode", &decode, decode_steps);
|
||||
if ttft_steps > 0 {
|
||||
let ttft = sum_profiles(step_profiles.iter().take(ttft_steps));
|
||||
println!(
|
||||
" TTFT attribution: {:.2} ms [batched prefill {:.2}]",
|
||||
ttft.total.as_secs_f64() * 1e3,
|
||||
prefill.total.as_secs_f64() * 1e3,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,8 +2,9 @@ use luminal::{
|
||||
dtype::DType,
|
||||
graph::Graph,
|
||||
prelude::{F32Pow, GraphTensor},
|
||||
shape::Expression,
|
||||
};
|
||||
use luminal_nn::{LayerNorm, gather_rows, scatter_rows};
|
||||
use luminal_nn::LayerNorm;
|
||||
|
||||
// Llama 3 8B hyperparams
|
||||
pub const LAYERS: usize = 32;
|
||||
@@ -13,80 +14,37 @@ pub const HEAD_DIM: usize = 128;
|
||||
pub const KV_GROUPS: usize = 4;
|
||||
pub const VOCAB_SIZE: usize = 128256;
|
||||
pub const N_KV_HEADS: usize = HIDDEN / HEAD_DIM / KV_GROUPS; // 8
|
||||
#[allow(dead_code)]
|
||||
pub const N_HEADS: usize = HIDDEN / HEAD_DIM; // 32
|
||||
pub const KV_DIM: usize = N_KV_HEADS * HEAD_DIM; // 1024
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct LlamaConfig {
|
||||
pub layers: usize,
|
||||
pub hidden: usize,
|
||||
pub intermediate: usize,
|
||||
pub head_dim: usize,
|
||||
pub kv_groups: usize,
|
||||
pub vocab_size: usize,
|
||||
}
|
||||
|
||||
impl Default for LlamaConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
layers: LAYERS,
|
||||
hidden: HIDDEN,
|
||||
intermediate: INTERMEDIATE,
|
||||
head_dim: HEAD_DIM,
|
||||
kv_groups: KV_GROUPS,
|
||||
vocab_size: VOCAB_SIZE,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LlamaConfig {
|
||||
pub fn n_heads(self) -> usize {
|
||||
self.hidden / self.head_dim
|
||||
}
|
||||
|
||||
pub fn n_kv_heads(self) -> usize {
|
||||
self.hidden / self.head_dim / self.kv_groups
|
||||
}
|
||||
|
||||
pub fn kv_dim(self) -> usize {
|
||||
self.n_kv_heads() * self.head_dim
|
||||
}
|
||||
}
|
||||
|
||||
pub struct KVCache {
|
||||
pub k_caches: Vec<GraphTensor>,
|
||||
pub v_caches: Vec<GraphTensor>,
|
||||
pub max_seq: usize,
|
||||
}
|
||||
|
||||
impl KVCache {
|
||||
pub fn new(cx: &mut Graph, num_slots: usize) -> Self {
|
||||
Self::new_with_config(cx, num_slots, LlamaConfig::default())
|
||||
}
|
||||
|
||||
pub fn new_with_config(cx: &mut Graph, num_slots: usize, config: LlamaConfig) -> Self {
|
||||
let kv_dim = config.kv_dim();
|
||||
let mut k_caches = Vec::with_capacity(config.layers);
|
||||
let mut v_caches = Vec::with_capacity(config.layers);
|
||||
for l in 0..config.layers {
|
||||
k_caches.push(cx.named_tensor(format!("kv_cache.{l}.k"), (num_slots, kv_dim)));
|
||||
v_caches.push(cx.named_tensor(format!("kv_cache.{l}.v"), (num_slots, kv_dim)));
|
||||
pub fn new(cx: &mut Graph, max_seq: usize) -> Self {
|
||||
let mut k_caches = Vec::with_capacity(LAYERS);
|
||||
let mut v_caches = Vec::with_capacity(LAYERS);
|
||||
for l in 0..LAYERS {
|
||||
let k = cx
|
||||
.named_tensor(format!("kv_cache.{l}.k"), (N_KV_HEADS, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
let v = cx
|
||||
.named_tensor(format!("kv_cache.{l}.v"), (N_KV_HEADS, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
k_caches.push(k);
|
||||
v_caches.push(v);
|
||||
}
|
||||
Self {
|
||||
k_caches,
|
||||
v_caches,
|
||||
max_seq,
|
||||
}
|
||||
Self { k_caches, v_caches }
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn tensors(&self) -> Vec<GraphTensor> {
|
||||
self.k_caches
|
||||
.iter()
|
||||
.chain(self.v_caches.iter())
|
||||
.copied()
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Llama {
|
||||
config: LlamaConfig,
|
||||
embedding: GraphTensor,
|
||||
layers: Vec<LlamaLayer>,
|
||||
lm_norm: LayerNorm,
|
||||
@@ -95,105 +53,60 @@ pub struct Llama {
|
||||
|
||||
impl Llama {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
Self::init_with_config(cx, LlamaConfig::default())
|
||||
}
|
||||
|
||||
pub fn init_fp8(cx: &mut Graph) -> Self {
|
||||
Self::init_with_config_and_fp8(cx, LlamaConfig::default(), true)
|
||||
}
|
||||
|
||||
pub fn init_with_config(cx: &mut Graph, config: LlamaConfig) -> Self {
|
||||
Self::init_with_config_and_fp8(cx, config, false)
|
||||
}
|
||||
|
||||
pub fn init_with_config_and_fp8(
|
||||
cx: &mut Graph,
|
||||
config: LlamaConfig,
|
||||
fp8_linears: bool,
|
||||
) -> Self {
|
||||
let mut layers = Vec::with_capacity(config.layers);
|
||||
for l in 0..config.layers {
|
||||
layers.push(LlamaLayer {
|
||||
config,
|
||||
up: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.up_proj"),
|
||||
(config.intermediate, config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
up_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.up_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
gate: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.gate_proj"),
|
||||
(config.intermediate, config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
gate_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.gate_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
down: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.down_proj"),
|
||||
(config.hidden, config.intermediate),
|
||||
fp8_linears,
|
||||
),
|
||||
down_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.down_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
q_proj: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.q_proj"),
|
||||
(config.hidden, config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
q_proj_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.q_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
k_proj: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.k_proj"),
|
||||
(config.kv_dim(), config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
k_proj_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.k_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
v_proj: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.v_proj"),
|
||||
(config.kv_dim(), config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
v_proj_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.v_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
o_proj: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.o_proj"),
|
||||
(config.hidden, config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
o_proj_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.o_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
let mut w = vec![];
|
||||
for l in 0..LAYERS {
|
||||
let up = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.up_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let gate = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.gate_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let down = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.down_proj.weight"),
|
||||
(HIDDEN, INTERMEDIATE),
|
||||
)
|
||||
.persist();
|
||||
let q_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_proj.weight"),
|
||||
(HIDDEN, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let k_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_proj.weight"),
|
||||
(HIDDEN / KV_GROUPS, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let v_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.v_proj.weight"),
|
||||
(HIDDEN / KV_GROUPS, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let o_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
w.push(LlamaLayer {
|
||||
up,
|
||||
gate,
|
||||
down,
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
attn_rms: LayerNorm::new(
|
||||
config.hidden,
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.input_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
@@ -201,7 +114,7 @@ impl Llama {
|
||||
cx,
|
||||
),
|
||||
mlp_rms: LayerNorm::new(
|
||||
config.hidden,
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
@@ -210,320 +123,189 @@ impl Llama {
|
||||
),
|
||||
});
|
||||
}
|
||||
let lm_norm = LayerNorm::new(HIDDEN, Some("model.norm.weight"), None, false, 1e-5, cx);
|
||||
let lm_head = cx
|
||||
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
let embedding = cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
Self {
|
||||
config,
|
||||
embedding: cx
|
||||
.named_tensor(
|
||||
"model.embed_tokens.weight",
|
||||
(config.vocab_size, config.hidden),
|
||||
)
|
||||
.persist(),
|
||||
layers,
|
||||
lm_head: cx
|
||||
.named_tensor("lm_head.weight", (config.vocab_size, config.hidden))
|
||||
.persist(),
|
||||
lm_norm: LayerNorm::new(
|
||||
config.hidden,
|
||||
Some("model.norm.weight"),
|
||||
None,
|
||||
false,
|
||||
1e-5,
|
||||
cx,
|
||||
),
|
||||
embedding,
|
||||
layers: w,
|
||||
lm_head,
|
||||
lm_norm,
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
input: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
token_ids: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = input.dims1();
|
||||
let hidden = self.config.hidden;
|
||||
let seq = token_ids.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(input * hidden).expand_dim(1, hidden)
|
||||
+ input.graph().arange(hidden).expand_dim(0, seq),
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
let mut cache_outputs = Vec::with_capacity(self.config.layers);
|
||||
let mut cache_outputs = Vec::with_capacity(LAYERS);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
x,
|
||||
q_pos,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
pos_ids,
|
||||
kv_cache.k_caches[i],
|
||||
kv_cache.v_caches[i],
|
||||
kv_cache.max_seq,
|
||||
);
|
||||
x = x_new;
|
||||
//x = x_new.graph_break();
|
||||
cache_outputs.push((k_out, v_out));
|
||||
}
|
||||
let logits = self.lm_norm.forward(x).matmul(self.lm_head.t());
|
||||
(logits, cache_outputs)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn parameter_tensors(&self) -> Vec<GraphTensor> {
|
||||
let mut tensors = Vec::new();
|
||||
tensors.push(self.embedding);
|
||||
for layer in &self.layers {
|
||||
tensors.extend(layer.parameter_tensors());
|
||||
}
|
||||
if let Some(weight) = self.lm_norm.weight {
|
||||
tensors.push(weight);
|
||||
}
|
||||
if let Some(bias) = self.lm_norm.bias {
|
||||
tensors.push(bias);
|
||||
}
|
||||
tensors.push(self.lm_head);
|
||||
tensors
|
||||
}
|
||||
}
|
||||
|
||||
struct LlamaLayer {
|
||||
config: LlamaConfig,
|
||||
up: GraphTensor,
|
||||
up_scales: Option<Fp8LinearScales>,
|
||||
gate: GraphTensor,
|
||||
gate_scales: Option<Fp8LinearScales>,
|
||||
down: GraphTensor,
|
||||
down_scales: Option<Fp8LinearScales>,
|
||||
q_proj: GraphTensor,
|
||||
q_proj_scales: Option<Fp8LinearScales>,
|
||||
k_proj: GraphTensor,
|
||||
k_proj_scales: Option<Fp8LinearScales>,
|
||||
v_proj: GraphTensor,
|
||||
v_proj_scales: Option<Fp8LinearScales>,
|
||||
o_proj: GraphTensor,
|
||||
o_proj_scales: Option<Fp8LinearScales>,
|
||||
attn_rms: LayerNorm,
|
||||
mlp_rms: LayerNorm,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct Fp8LinearScales {
|
||||
input: GraphTensor,
|
||||
weight: GraphTensor,
|
||||
}
|
||||
|
||||
fn linear_weight(
|
||||
cx: &mut Graph,
|
||||
prefix: impl ToString,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
fp8: bool,
|
||||
) -> GraphTensor {
|
||||
let tensor = cx.named_tensor(format!("{}.weight", prefix.to_string()), shape);
|
||||
if fp8 {
|
||||
tensor.as_dtype(DType::F8E4M3).persist()
|
||||
} else {
|
||||
tensor.persist()
|
||||
}
|
||||
}
|
||||
|
||||
fn fp8_linear_scales(cx: &mut Graph, prefix: impl ToString, fp8: bool) -> Option<Fp8LinearScales> {
|
||||
if !fp8 {
|
||||
return None;
|
||||
}
|
||||
let prefix = prefix.to_string();
|
||||
Some(Fp8LinearScales {
|
||||
input: cx
|
||||
.named_tensor(format!("{prefix}.input_scale"), ())
|
||||
.persist(),
|
||||
weight: cx
|
||||
.named_tensor(format!("{prefix}.weight_scale"), ())
|
||||
.persist(),
|
||||
})
|
||||
}
|
||||
|
||||
fn expand_scalar(scale: GraphTensor, like: GraphTensor) -> GraphTensor {
|
||||
scale.expand_rhs(like.dims())
|
||||
}
|
||||
|
||||
fn linear_matmul(
|
||||
input: GraphTensor,
|
||||
weight: GraphTensor,
|
||||
scales: Option<Fp8LinearScales>,
|
||||
) -> GraphTensor {
|
||||
if let Some(scales) = scales {
|
||||
let input_scale = expand_scalar(scales.input, input);
|
||||
let scaled_input = input / input_scale;
|
||||
let output = scaled_input
|
||||
.cast(DType::F8E4M3)
|
||||
.matmul(weight.t())
|
||||
.cast(DType::F32);
|
||||
let output_scale = expand_scalar(scales.input * scales.weight, output);
|
||||
output * output_scale
|
||||
} else {
|
||||
input.matmul(weight.t())
|
||||
}
|
||||
}
|
||||
|
||||
fn llama_rotary_embeddings(
|
||||
mut input: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
config: LlamaConfig,
|
||||
) -> GraphTensor {
|
||||
input = input.split_dims(1, config.head_dim).transpose(0, 1);
|
||||
fn llama_rotary_embeddings(mut input: GraphTensor, pos_ids: GraphTensor) -> GraphTensor {
|
||||
// Input: [seq, dim]
|
||||
input = input.split_dims(1, HEAD_DIM).transpose(0, 1); // n_heads, seq, head_dim
|
||||
|
||||
// Get freqs
|
||||
let freqs = input
|
||||
.graph()
|
||||
.arange_options(0, config.head_dim, 2)
|
||||
.arange_options(0, HEAD_DIM, 2)
|
||||
.cast(DType::F32)
|
||||
/ config.head_dim as f32;
|
||||
/ HEAD_DIM as f32;
|
||||
let inv_freqs = 500_000_f32.pow(freqs).reciprocal();
|
||||
let emb = pos_ids
|
||||
.cast(DType::F32)
|
||||
.expand_dim(1, 1)
|
||||
.matmul(inv_freqs.expand_dim(0, 1));
|
||||
|
||||
let x0 = input.slice((.., .., ..config.head_dim / 2));
|
||||
let x1 = input.slice((.., .., config.head_dim / 2..));
|
||||
// Split into first half and second half (Llama "half" rotation convention)
|
||||
let x0 = input.slice((.., .., ..HEAD_DIM / 2));
|
||||
let x1 = input.slice((.., .., HEAD_DIM / 2..));
|
||||
|
||||
// Apply sin and cos embeddings
|
||||
let cos = emb.cos().expand_dim(0, x0.dims()[0]);
|
||||
let sin = emb.sin().expand_dim(0, x0.dims()[0]);
|
||||
let x0_out = x0 * cos - x1 * sin;
|
||||
let x1_out = x1 * cos + x0 * sin;
|
||||
|
||||
// Combine back: [n_heads, seq, HEAD_DIM] -> [seq, n_heads, HEAD_DIM] -> [seq, dim]
|
||||
x0_out
|
||||
.concat_along(x1_out, 2)
|
||||
.transpose(0, 1)
|
||||
.merge_dims(1, 2)
|
||||
}
|
||||
|
||||
struct AttentionInputs {
|
||||
q_rope: GraphTensor,
|
||||
k_rope: GraphTensor,
|
||||
v: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
}
|
||||
|
||||
fn attention(
|
||||
AttentionInputs {
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
}: AttentionInputs,
|
||||
config: LlamaConfig,
|
||||
/// HLIR attention with pre-allocated KV cache using scatter.
|
||||
/// Returns (attn_output, k_cache_updated, v_cache_updated).
|
||||
fn hlir_attention(
|
||||
q_rope: GraphTensor, // [seq, HIDDEN]
|
||||
k_rope: GraphTensor, // [seq, HIDDEN/KV_GROUPS]
|
||||
v: GraphTensor, // [seq, HIDDEN/KV_GROUPS]
|
||||
k_cache_in: GraphTensor, // [N_KV_HEADS, max_seq, HEAD_DIM]
|
||||
v_cache_in: GraphTensor, // [N_KV_HEADS, max_seq, HEAD_DIM]
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let kv_dim = config.kv_dim();
|
||||
let k_cache_out = scatter_rows(k_rope, scatter_idx, k_cache, kv_dim);
|
||||
let v_cache_out = scatter_rows(v, scatter_idx, v_cache, kv_dim);
|
||||
let cx = q_rope.graph();
|
||||
let seq = q_rope.dims()[0]; // Expression 's'
|
||||
let prev = Expression::from('p');
|
||||
let total_seq = prev + seq;
|
||||
|
||||
let k = gather_rows(k_cache_out, gather_idx, kv_dim);
|
||||
let v_ctx = gather_rows(v_cache_out, gather_idx, kv_dim);
|
||||
// Reshape new K, V: [seq, kv_dim] -> [N_KV_HEADS, seq, HEAD_DIM]
|
||||
let k_new = k_rope.split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
let v_new = v.split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
|
||||
let q = (q_rope * 1.0)
|
||||
.split_dims(1, config.head_dim)
|
||||
.transpose(0, 1);
|
||||
let k = k.split_dims(1, config.head_dim).permute((1, 2, 0));
|
||||
let v_ctx = v_ctx.split_dims(1, config.head_dim).transpose(0, 1);
|
||||
// Build flat scatter indices for cache positions [prev..prev+seq]
|
||||
// Cache layout: [N_KV_HEADS, max_seq, HEAD_DIM], flat index = h*max_seq*HEAD_DIM + (prev+s)*HEAD_DIM + d
|
||||
let h_offset = cx.arange(N_KV_HEADS) * (max_seq * HEAD_DIM);
|
||||
let p_offset = (cx.arange(seq) + prev) * HEAD_DIM;
|
||||
let d_offset = cx.arange(HEAD_DIM);
|
||||
let scatter_idx = h_offset.expand_dim(1, seq).expand_dim(2, HEAD_DIM)
|
||||
+ p_offset.expand_dim(0, N_KV_HEADS).expand_dim(2, HEAD_DIM)
|
||||
+ d_offset.expand_dim(0, N_KV_HEADS).expand_dim(1, seq);
|
||||
|
||||
let k = k.expand_dim(1, config.kv_groups).merge_dims(0, 1) * 1.0;
|
||||
let v_ctx = v_ctx.expand_dim(1, config.kv_groups).merge_dims(0, 1) * 1.0;
|
||||
// Scatter new K/V into cache
|
||||
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
let scores = q.matmul(k) / (config.head_dim as f32).sqrt();
|
||||
let masked_scores = scores + attn_mask.expand_dim(0, config.n_heads());
|
||||
let weights = masked_scores.softmax(2);
|
||||
let out = weights.matmul(v_ctx);
|
||||
let attn_out = out.transpose(0, 1).merge_dims(1, 2);
|
||||
// Slice to valid range: [N_KV_HEADS, total_seq, HEAD_DIM]
|
||||
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
|
||||
k_full.shape.dims[1] = total_seq;
|
||||
v_full.shape.dims[1] = total_seq;
|
||||
|
||||
(attn_out, k_cache_out, v_cache_out)
|
||||
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
|
||||
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);
|
||||
let v_3d = v_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);
|
||||
|
||||
// Q: [seq, HIDDEN] -> [N_HEADS, seq, HEAD_DIM]
|
||||
let q = q_rope.split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
|
||||
// Attention scores: Q @ K^T / sqrt(d)
|
||||
// 3D matmul: [N_HEADS, seq, HEAD_DIM] x [N_HEADS, HEAD_DIM, total_seq] -> [N_HEADS, seq, total_seq]
|
||||
let scores = q.matmul(k_3d.transpose(1, 2)) / (HEAD_DIM as f32).sqrt();
|
||||
|
||||
// Causal mask: mask positions where k_pos > prev + q_local_pos
|
||||
let q_abs = cx.arange(seq).cast(DType::F32) + prev;
|
||||
let k_pos = cx.arange(total_seq).cast(DType::F32);
|
||||
let mask = k_pos.expand_dim(0, seq).gt(q_abs.expand_dim(1, total_seq));
|
||||
let mask_3d = mask.cast(DType::F32).expand_dim(0, N_HEADS);
|
||||
let masked_scores = scores + mask_3d * (-1e10f32);
|
||||
|
||||
// Softmax along key dimension
|
||||
let attn_weights = masked_scores.softmax(2);
|
||||
|
||||
// Weighted sum: [N_HEADS, seq, total_seq] x [N_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, seq, HEAD_DIM]
|
||||
let attn_out = attn_weights.matmul(v_3d);
|
||||
|
||||
// Reshape: [N_HEADS, seq, HEAD_DIM] -> [seq, N_HEADS, HEAD_DIM] -> [seq, HIDDEN]
|
||||
let out = attn_out.transpose(0, 1).merge_dims(1, 2);
|
||||
|
||||
(out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl LlamaLayer {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.attn_rms.forward(x);
|
||||
let q = linear_matmul(x_attn, self.q_proj, self.q_proj_scales);
|
||||
let k = linear_matmul(x_attn, self.k_proj, self.k_proj_scales);
|
||||
let v = linear_matmul(x_attn, self.v_proj, self.v_proj_scales);
|
||||
|
||||
let q_rope = llama_rotary_embeddings(q, q_pos, self.config);
|
||||
let k_rope = llama_rotary_embeddings(k, q_pos, self.config);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = attention(
|
||||
AttentionInputs {
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
},
|
||||
self.config,
|
||||
);
|
||||
x += linear_matmul(attn_out, self.o_proj, self.o_proj_scales);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
let q_rope = llama_rotary_embeddings(q, pos_ids);
|
||||
let k_rope = llama_rotary_embeddings(k, pos_ids);
|
||||
let (attn_out, k_cache_out, v_cache_out) =
|
||||
hlir_attention(q_rope, k_rope, v, k_cache_in, v_cache_in, max_seq);
|
||||
x += attn_out.matmul(self.o_proj.t());
|
||||
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
let mlp_out = linear_matmul(x_mlp, self.gate, self.gate_scales).swish()
|
||||
* linear_matmul(x_mlp, self.up, self.up_scales);
|
||||
let mlp_out = linear_matmul(mlp_out, self.down, self.down_scales);
|
||||
let mlp_out =
|
||||
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
|
||||
(x + mlp_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn parameter_tensors(&self) -> Vec<GraphTensor> {
|
||||
let mut tensors = vec![
|
||||
self.up,
|
||||
self.gate,
|
||||
self.down,
|
||||
self.q_proj,
|
||||
self.k_proj,
|
||||
self.v_proj,
|
||||
self.o_proj,
|
||||
];
|
||||
for scales in [
|
||||
self.up_scales,
|
||||
self.gate_scales,
|
||||
self.down_scales,
|
||||
self.q_proj_scales,
|
||||
self.k_proj_scales,
|
||||
self.v_proj_scales,
|
||||
self.o_proj_scales,
|
||||
]
|
||||
.into_iter()
|
||||
.flatten()
|
||||
{
|
||||
tensors.push(scales.input);
|
||||
tensors.push(scales.weight);
|
||||
}
|
||||
if let Some(weight) = self.attn_rms.weight {
|
||||
tensors.push(weight);
|
||||
}
|
||||
if let Some(bias) = self.attn_rms.bias {
|
||||
tensors.push(bias);
|
||||
}
|
||||
if let Some(weight) = self.mlp_rms.weight {
|
||||
tensors.push(weight);
|
||||
}
|
||||
if let Some(bias) = self.mlp_rms.bias {
|
||||
tensors.push(bias);
|
||||
}
|
||||
tensors
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "paged_llama"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
edition = "2021"
|
||||
|
||||
[[bin]]
|
||||
name = "paged_llama"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use half::{bf16, f16};
|
||||
use hf_hub::api::sync::Api;
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::{Dtype, SafeTensors, tensor::TensorView};
|
||||
use safetensors::{tensor::TensorView, Dtype, SafeTensors};
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
|
||||
@@ -13,12 +13,6 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
const REPO_ID: &str = "NousResearch/Meta-Llama-3-8B-Instruct";
|
||||
|
||||
fn llama3_chat_prompt(user_prompt: &str) -> String {
|
||||
format!(
|
||||
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
}
|
||||
|
||||
struct PageTable {
|
||||
tables: Vec<Vec<usize>>,
|
||||
next_free_slot: usize,
|
||||
@@ -141,15 +135,8 @@ fn tick(
|
||||
const EOS_TOKEN: u32 = 128009;
|
||||
const STOP_TOKEN: u32 = 128001;
|
||||
|
||||
fn env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let num_slots = env_usize("NUM_SLOTS", 8192);
|
||||
let num_slots = 8192;
|
||||
let search_graphs = 100;
|
||||
let gen_tokens = 30;
|
||||
let prompt_a = "Explain what a neural network is in a paragraph.";
|
||||
@@ -169,9 +156,11 @@ fn main() {
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
|
||||
let encode = |prompt: &str| -> Vec<u32> {
|
||||
let chat = llama3_chat_prompt(prompt);
|
||||
let chat = format!(
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
);
|
||||
tokenizer
|
||||
.encode(chat.as_str(), false)
|
||||
.encode(chat.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec()
|
||||
|
||||
@@ -3,7 +3,7 @@ use luminal::{
|
||||
graph::Graph,
|
||||
prelude::{F32Pow, GraphTensor},
|
||||
};
|
||||
use luminal_nn::{LayerNorm, gather_rows, scatter_rows};
|
||||
use luminal_nn::{gather_rows, scatter_rows, LayerNorm};
|
||||
|
||||
// Llama 3 8B hyperparams
|
||||
pub const LAYERS: usize = 32;
|
||||
|
||||
@@ -1,34 +1,29 @@
|
||||
[package]
|
||||
name = "qwen"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
edition = "2021"
|
||||
|
||||
[[bin]]
|
||||
name = "qwen"
|
||||
path = "src/main.rs"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
cuda = ["dep:luminal_cuda_lite"]
|
||||
metal = ["dep:luminal_metal"]
|
||||
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_tracing = { path = "../../crates/luminal_tracing" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite", optional = true }
|
||||
luminal_metal = { path = "../../crates/luminal_metal", optional = true }
|
||||
example_common = { path = "../example_common" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
luminal_tracing = {path="../../crates/luminal_tracing"}
|
||||
tokenizers = "0.22.2"
|
||||
tracing = "0.1.43"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
||||
# HuggingFace model download
|
||||
hf-hub = { version = "0.4", default-features = false, features = ["rustls-tls", "ureq"] }
|
||||
hf-hub = {version = "0.4", default-features = false, features = ["rustls-tls", "ureq"]}
|
||||
safetensors = "0.7.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde = {version = "1.0", features = ["derive"]}
|
||||
serde_json = "1.0"
|
||||
half = { version = "2.7.1", features = ["bytemuck"] }
|
||||
half = {version = "2.7.1", features = ["bytemuck"]}
|
||||
bytemuck = "1.24.0"
|
||||
memmap2 = "0.9.9"
|
||||
rustc-hash = "2.1"
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
# Qwen3 4B
|
||||
|
||||
Run Qwen3-4B through Luminal's CUDA backend:
|
||||
|
||||
```bash
|
||||
cargo run --release -p qwen --features cuda
|
||||
```
|
||||
|
||||
Run Qwen3-4B through Luminal's Metal backend on Apple targets:
|
||||
|
||||
```bash
|
||||
cargo run --release -p qwen --features metal
|
||||
```
|
||||
|
||||
The first run downloads `Qwen/Qwen3-4B`, converts the safetensors weights to a combined FP32 file, compiles the selected backend graph, and then generates text.
|
||||
@@ -1,7 +1,7 @@
|
||||
use half::{bf16, f16};
|
||||
use hf_hub::api::sync::Api;
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::{Dtype, SafeTensors, tensor::TensorView};
|
||||
use safetensors::{tensor::TensorView, Dtype, SafeTensors};
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
@@ -16,141 +16,10 @@ struct SafetensorsIndex {
|
||||
weight_map: HashMap<String, String>,
|
||||
}
|
||||
|
||||
/// Stored tensor data with shape and converted FP32 bytes
|
||||
struct StoredTensor {
|
||||
name: String,
|
||||
shape: Vec<usize>,
|
||||
data: Vec<u8>,
|
||||
}
|
||||
|
||||
fn format_bytes(bytes: u64) -> String {
|
||||
const GIB: f64 = 1024.0 * 1024.0 * 1024.0;
|
||||
const MIB: f64 = 1024.0 * 1024.0;
|
||||
if bytes >= 1024 * 1024 * 1024 {
|
||||
format!("{:.2} GiB", bytes as f64 / GIB)
|
||||
} else if bytes >= 1024 * 1024 {
|
||||
format!("{:.2} MiB", bytes as f64 / MIB)
|
||||
} else {
|
||||
format!("{bytes} bytes")
|
||||
}
|
||||
}
|
||||
|
||||
fn tensor_to_f32(tensor: &safetensors::tensor::TensorView<'_>) -> Vec<u8> {
|
||||
match tensor.dtype() {
|
||||
Dtype::F32 => tensor.data().to_vec(),
|
||||
Dtype::F16 => bytemuck::cast_slice::<f32, u8>(
|
||||
&bytemuck::cast_slice::<u8, f16>(tensor.data())
|
||||
.iter()
|
||||
.map(|v| v.to_f32())
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
.to_vec(),
|
||||
Dtype::BF16 => bytemuck::cast_slice::<f32, u8>(
|
||||
&bytemuck::cast_slice::<u8, bf16>(tensor.data())
|
||||
.iter()
|
||||
.map(|v| v.to_f32())
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
.to_vec(),
|
||||
dtype => panic!("Unsupported safetensors dtype {dtype:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn combine_safetensors_to_fp32(model_dir: &Path) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let output_path = model_dir.join("model_combined.safetensors");
|
||||
if output_path.exists() {
|
||||
let existing_bytes = std::fs::metadata(&output_path)?.len();
|
||||
println!(
|
||||
"Using existing combined FP32 model at {} ({})",
|
||||
output_path.display(),
|
||||
format_bytes(existing_bytes)
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let single_path = model_dir.join("model.safetensors");
|
||||
if single_path.exists() {
|
||||
let bytes = std::fs::metadata(&single_path)?.len();
|
||||
println!(
|
||||
"Using single safetensors model at {} ({})",
|
||||
single_path.display(),
|
||||
format_bytes(bytes)
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let index_path = model_dir.join("model.safetensors.index.json");
|
||||
let index_content = std::fs::read_to_string(&index_path)?;
|
||||
let index: SafetensorsIndex = serde_json::from_str(&index_content)?;
|
||||
let mut shard_files: Vec<String> = index.weight_map.values().cloned().collect();
|
||||
shard_files.sort();
|
||||
shard_files.dedup();
|
||||
let original_bytes = shard_files.iter().try_fold(0u64, |acc, shard_file| {
|
||||
Ok::<_, std::io::Error>(acc + std::fs::metadata(model_dir.join(shard_file))?.len())
|
||||
})?;
|
||||
|
||||
println!(
|
||||
"Loading {} shard files ({} original bytes, converting to FP32)...",
|
||||
shard_files.len(),
|
||||
format_bytes(original_bytes)
|
||||
);
|
||||
|
||||
let mut tensors = Vec::new();
|
||||
for shard_file in &shard_files {
|
||||
println!(" Loading {shard_file}...");
|
||||
let shard_path = model_dir.join(shard_file);
|
||||
let file = File::open(&shard_path)?;
|
||||
let mmap = unsafe { MmapOptions::new().map(&file)? };
|
||||
let safetensors = SafeTensors::deserialize(&mmap)?;
|
||||
|
||||
for name in safetensors.names() {
|
||||
let tensor = safetensors.tensor(name)?;
|
||||
tensors.push(StoredTensor {
|
||||
name: name.to_string(),
|
||||
shape: tensor.shape().to_vec(),
|
||||
data: tensor_to_f32(&tensor),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let total_params: usize = tensors
|
||||
.iter()
|
||||
.map(|tensor| tensor.shape.iter().product::<usize>())
|
||||
.sum();
|
||||
let raw_fp32_bytes: u64 = (total_params * std::mem::size_of::<f32>()) as u64;
|
||||
println!(
|
||||
"Extracted {} tensors: {} params, {} raw FP32 tensor payload",
|
||||
tensors.len(),
|
||||
total_params,
|
||||
format_bytes(raw_fp32_bytes)
|
||||
);
|
||||
|
||||
let tensor_map: HashMap<String, TensorView<'_>> = tensors
|
||||
.iter()
|
||||
.map(|tensor| {
|
||||
(
|
||||
tensor.name.clone(),
|
||||
TensorView::new(Dtype::F32, tensor.shape.clone(), &tensor.data).unwrap(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
let serialized = safetensors::serialize(&tensor_map, None)?;
|
||||
println!(
|
||||
"Serialized combined FP32 safetensors size: {}",
|
||||
format_bytes(serialized.len() as u64)
|
||||
);
|
||||
|
||||
println!("Removing original shards before saving combined FP32 model...");
|
||||
for shard_file in &shard_files {
|
||||
std::fs::remove_file(model_dir.join(shard_file))?;
|
||||
}
|
||||
std::fs::remove_file(&index_path)?;
|
||||
|
||||
println!("Saving combined FP32 model to {}...", output_path.display());
|
||||
let mut output_file = File::create(&output_path)?;
|
||||
output_file.write_all(&serialized)?;
|
||||
println!("Done!");
|
||||
|
||||
Ok(())
|
||||
data: Vec<f32>,
|
||||
}
|
||||
|
||||
/// Downloads model files from HuggingFace and returns the cache directory path.
|
||||
@@ -180,11 +49,121 @@ pub fn download_hf_model(repo_id: &str) -> Result<PathBuf, Box<dyn std::error::E
|
||||
Ok(model_dir)
|
||||
}
|
||||
|
||||
/// Convert tensor data to f32 vec
|
||||
fn tensor_to_f32(tensor: &safetensors::tensor::TensorView) -> Vec<f32> {
|
||||
let dtype = tensor.dtype();
|
||||
let data = tensor.data();
|
||||
|
||||
match dtype {
|
||||
Dtype::F32 => bytemuck::cast_slice::<u8, f32>(data).to_vec(),
|
||||
Dtype::F16 => {
|
||||
let f16_slice: &[f16] = bytemuck::cast_slice(data);
|
||||
f16_slice.iter().map(|x| x.to_f32()).collect()
|
||||
}
|
||||
Dtype::BF16 => {
|
||||
let bf16_slice: &[bf16] = bytemuck::cast_slice(data);
|
||||
bf16_slice.iter().map(|x| x.to_f32()).collect()
|
||||
}
|
||||
other => {
|
||||
panic!("Unsupported dtype for conversion: {other:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Combines sharded safetensors files into a single FP32 file.
|
||||
///
|
||||
/// This function:
|
||||
/// 1. Loads tensors from shard(s)
|
||||
/// 2. Converts all to FP32 and writes combined file
|
||||
pub fn combine_safetensors_to_fp32(
|
||||
model_dir: &Path,
|
||||
) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
let output_path = model_dir.join("model_combined.safetensors");
|
||||
|
||||
// Skip if already combined
|
||||
if output_path.exists() {
|
||||
return Ok(output_path);
|
||||
}
|
||||
|
||||
let index_path = model_dir.join("model.safetensors.index.json");
|
||||
let single_shard_path = model_dir.join("model.safetensors");
|
||||
|
||||
// Determine which shard files to load
|
||||
let shard_files: Vec<PathBuf> = if single_shard_path.exists() && !index_path.exists() {
|
||||
println!("Single shard model detected, converting to FP32...");
|
||||
vec![single_shard_path]
|
||||
} else if index_path.exists() {
|
||||
let index_content = std::fs::read_to_string(&index_path)?;
|
||||
let index: SafetensorsIndex = serde_json::from_str(&index_content)?;
|
||||
|
||||
let mut files: Vec<String> = index.weight_map.values().cloned().collect();
|
||||
files.sort();
|
||||
files.dedup();
|
||||
|
||||
println!(
|
||||
"Loading {} shard files (converting to FP32)...",
|
||||
files.len()
|
||||
);
|
||||
files.into_iter().map(|f| model_dir.join(f)).collect()
|
||||
} else {
|
||||
return Err("No model.safetensors or model.safetensors.index.json found".into());
|
||||
};
|
||||
|
||||
// Load and convert all tensors
|
||||
let mut all_tensors: HashMap<String, StoredTensor> = HashMap::new();
|
||||
|
||||
for shard_path in &shard_files {
|
||||
println!(
|
||||
" Loading {}...",
|
||||
shard_path.file_name().unwrap().to_string_lossy()
|
||||
);
|
||||
let file = File::open(shard_path)?;
|
||||
let mmap = unsafe { MmapOptions::new().map(&file)? };
|
||||
let st = SafeTensors::deserialize(&mmap)?;
|
||||
|
||||
for name in st.names() {
|
||||
let tensor = st.tensor(name)?;
|
||||
let shape: Vec<usize> = tensor.shape().to_vec();
|
||||
let fp32_data = tensor_to_f32(&tensor);
|
||||
|
||||
all_tensors.insert(
|
||||
name.to_string(),
|
||||
StoredTensor {
|
||||
shape,
|
||||
data: fp32_data,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
println!("Extracted {} tensors", all_tensors.len());
|
||||
|
||||
// Serialize to combined file
|
||||
println!("Saving combined FP32 model to {}...", output_path.display());
|
||||
|
||||
let tensor_views: HashMap<String, TensorView<'_>> = all_tensors
|
||||
.iter()
|
||||
.map(|(name, stored)| {
|
||||
let data_bytes: &[u8] = bytemuck::cast_slice(&stored.data);
|
||||
let view = TensorView::new(Dtype::F32, stored.shape.clone(), data_bytes).unwrap();
|
||||
(name.clone(), view)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let serialized = safetensors::serialize(&tensor_views, None)?;
|
||||
|
||||
let mut file = File::create(&output_path)?;
|
||||
file.write_all(&serialized)?;
|
||||
|
||||
println!("Combined FP32 model saved successfully!");
|
||||
Ok(output_path)
|
||||
}
|
||||
|
||||
/// Downloads a model from HuggingFace and prepares it for use.
|
||||
///
|
||||
/// Returns the path to the model directory containing:
|
||||
/// - tokenizer.json
|
||||
/// - model.safetensors or model_combined.safetensors
|
||||
/// - model_combined.safetensors (FP32)
|
||||
pub fn prepare_hf_model(repo_id: &str) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
let model_dir = download_hf_model(repo_id)?;
|
||||
combine_safetensors_to_fp32(&model_dir)?;
|
||||
|
||||
@@ -1,491 +0,0 @@
|
||||
pub mod hf;
|
||||
pub mod model;
|
||||
|
||||
use example_common::{BenchEnv, info, sample_greedy_with_penalty, stdio};
|
||||
use hf::prepare_hf_model;
|
||||
pub use luminal::prelude::Runtime;
|
||||
use luminal::prelude::*;
|
||||
use luminal_tracing::luminal_filter;
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{error::Error, io::Write, time::Duration};
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
const EOS_TOKEN: u32 = 151645; // <|im_end|>
|
||||
const STOP_TOKEN: u32 = 151643; // <|endoftext|>
|
||||
const STDIO_MAX_PREFILL: usize = 512;
|
||||
|
||||
pub struct QwenRunConfig {
|
||||
pub repo_id: String,
|
||||
pub max_seq_len: usize,
|
||||
pub gen_tokens: usize,
|
||||
pub search_graphs: usize,
|
||||
pub prompt: String,
|
||||
pub repetition_penalty: f32,
|
||||
pub layers: usize,
|
||||
pub stdio: bool,
|
||||
}
|
||||
|
||||
fn qwen3_chat_prompt(user_prompt: &str) -> String {
|
||||
format!(
|
||||
"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
||||
)
|
||||
}
|
||||
|
||||
impl Default for QwenRunConfig {
|
||||
fn default() -> Self {
|
||||
let bench = BenchEnv::from_env(500, 500);
|
||||
Self {
|
||||
repo_id: "Qwen/Qwen3-4B".to_string(),
|
||||
max_seq_len: 4096,
|
||||
gen_tokens: bench.gen_tokens,
|
||||
search_graphs: bench.search_graphs,
|
||||
prompt: "Explain what a neural network is in a paragraph.".to_string(),
|
||||
repetition_penalty: 1.05,
|
||||
layers: LAYERS,
|
||||
stdio: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait QwenRuntime: Runtime<ExecReturn = ()> {
|
||||
type Buffer;
|
||||
|
||||
fn load_safetensors(&mut self, cx: &Graph, file_path: &str);
|
||||
fn set_i32_data(&mut self, id: NodeIndex, data: Vec<i32>);
|
||||
fn set_zeros(&mut self, id: NodeIndex, num_bytes: usize);
|
||||
fn remove_buffer(&mut self, id: NodeIndex) -> Self::Buffer;
|
||||
fn set_buffer(&mut self, id: NodeIndex, buffer: Self::Buffer);
|
||||
fn get_f32(&self, id: NodeIndex) -> Vec<f32>;
|
||||
|
||||
fn prepare_execute(&mut self, _dyn_map: &FxHashMap<char, usize>) {}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
impl QwenRuntime for luminal_cuda_lite::runtime::CudaRuntime {
|
||||
type Buffer = luminal_cuda_lite::cudarc::driver::CudaSlice<u8>;
|
||||
|
||||
fn load_safetensors(&mut self, cx: &Graph, file_path: &str) {
|
||||
luminal_cuda_lite::runtime::CudaRuntime::load_safetensors(self, cx, file_path);
|
||||
}
|
||||
|
||||
fn set_i32_data(&mut self, id: NodeIndex, data: Vec<i32>) {
|
||||
luminal_cuda_lite::runtime::CudaRuntime::set_data(self, id, data);
|
||||
}
|
||||
|
||||
fn set_zeros(&mut self, id: NodeIndex, num_bytes: usize) {
|
||||
luminal_cuda_lite::runtime::CudaRuntime::set_zeros(self, id, num_bytes);
|
||||
}
|
||||
|
||||
fn remove_buffer(&mut self, id: NodeIndex) -> Self::Buffer {
|
||||
luminal_cuda_lite::runtime::CudaRuntime::remove_buffer(self, id)
|
||||
}
|
||||
|
||||
fn set_buffer(&mut self, id: NodeIndex, buffer: Self::Buffer) {
|
||||
luminal_cuda_lite::runtime::CudaRuntime::set_buffer(self, id, buffer);
|
||||
}
|
||||
|
||||
fn get_f32(&self, id: NodeIndex) -> Vec<f32> {
|
||||
luminal_cuda_lite::runtime::CudaRuntime::get_f32(self, id)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
impl QwenRuntime for luminal_metal::MetalRuntime {
|
||||
type Buffer = luminal_metal::Buffer;
|
||||
|
||||
fn load_safetensors(&mut self, cx: &Graph, file_path: &str) {
|
||||
luminal_metal::MetalRuntime::load_safetensors(self, cx, file_path);
|
||||
}
|
||||
|
||||
fn set_i32_data(&mut self, id: NodeIndex, data: Vec<i32>) {
|
||||
luminal_metal::MetalRuntime::set_data(self, id, data);
|
||||
}
|
||||
|
||||
fn set_zeros(&mut self, id: NodeIndex, num_bytes: usize) {
|
||||
luminal_metal::MetalRuntime::set_zeros(self, id, num_bytes);
|
||||
}
|
||||
|
||||
fn remove_buffer(&mut self, id: NodeIndex) -> Self::Buffer {
|
||||
luminal_metal::MetalRuntime::remove_buffer(self, id)
|
||||
}
|
||||
|
||||
fn set_buffer(&mut self, id: NodeIndex, buffer: Self::Buffer) {
|
||||
luminal_metal::MetalRuntime::set_buffer(self, id, buffer);
|
||||
}
|
||||
|
||||
fn get_f32(&self, id: NodeIndex) -> Vec<f32> {
|
||||
luminal_metal::MetalRuntime::get_f32(self, id)
|
||||
}
|
||||
|
||||
fn prepare_execute(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
luminal_metal::MetalRuntime::allocate_intermediate_buffers(self, dyn_map);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_one_prompt<R: QwenRuntime>(
|
||||
cx: &mut Graph,
|
||||
runtime: &mut R,
|
||||
tokenizer: &Tokenizer,
|
||||
input: GraphTensor,
|
||||
token_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
cache_bytes: usize,
|
||||
layers: usize,
|
||||
prompt_tokens: &[u32],
|
||||
gen_tokens: usize,
|
||||
repetition_penalty: f32,
|
||||
emit_tok: &mut dyn FnMut(&str),
|
||||
) -> Result<(usize, Duration), Box<dyn Error>> {
|
||||
for i in 0..layers {
|
||||
runtime.set_zeros(kv_cache.k_caches[i].id, cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i].id, cache_bytes);
|
||||
}
|
||||
let prompt_len = prompt_tokens.len();
|
||||
let mut prev_seq = 0usize;
|
||||
let mut seen_tokens: FxHashSet<u32> = FxHashSet::default();
|
||||
let mut generated = 0usize;
|
||||
let mut sentence: Vec<u32> = Vec::new();
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
if gen_tokens > 0 && prompt_len > 0 {
|
||||
cx.set_dim('s', prompt_len);
|
||||
cx.set_dim('p', 0);
|
||||
|
||||
runtime.set_i32_data(
|
||||
input.id,
|
||||
prompt_tokens.iter().map(|t| *t as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
runtime.set_i32_data(token_ids.id, (0..prompt_len as i32).collect::<Vec<_>>());
|
||||
runtime.prepare_execute(&cx.dyn_map);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let logits_data = runtime.get_f32(logits.id);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(k_out.id);
|
||||
let v_buf = runtime.remove_buffer(v_out.id);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx].id, k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx].id, v_buf);
|
||||
}
|
||||
prev_seq = prompt_len;
|
||||
|
||||
let row_start = (prompt_len - 1) * VOCAB_SIZE;
|
||||
let next_token = sample_greedy_with_penalty(
|
||||
&logits_data[row_start..row_start + VOCAB_SIZE],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
sentence = vec![next_token];
|
||||
seen_tokens.insert(next_token);
|
||||
generated = 1;
|
||||
|
||||
if next_token != EOS_TOKEN && next_token != STOP_TOKEN {
|
||||
let decoded = tokenizer
|
||||
.decode(&[next_token], true)
|
||||
.map_err(|err| err as Box<dyn Error>)?;
|
||||
emit_tok(&decoded);
|
||||
}
|
||||
}
|
||||
|
||||
while generated < gen_tokens && !sentence.is_empty() {
|
||||
let seq_len = sentence.len();
|
||||
let current_token = sentence[0];
|
||||
|
||||
if current_token == EOS_TOKEN || current_token == STOP_TOKEN {
|
||||
break;
|
||||
}
|
||||
|
||||
cx.set_dim('s', seq_len);
|
||||
cx.set_dim('p', prev_seq);
|
||||
|
||||
runtime.set_i32_data(
|
||||
input.id,
|
||||
sentence.iter().map(|t| *t as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
runtime.set_i32_data(
|
||||
token_ids.id,
|
||||
(prev_seq as i32..(seq_len + prev_seq) as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
runtime.prepare_execute(&cx.dyn_map);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let logits_data = runtime.get_f32(logits.id);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(k_out.id);
|
||||
let v_buf = runtime.remove_buffer(v_out.id);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx].id, k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx].id, v_buf);
|
||||
}
|
||||
|
||||
prev_seq += seq_len;
|
||||
|
||||
let next_token = sample_greedy_with_penalty(
|
||||
&logits_data[logits_data.len() - VOCAB_SIZE..],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
sentence = vec![next_token];
|
||||
seen_tokens.insert(next_token);
|
||||
generated += 1;
|
||||
|
||||
if next_token == EOS_TOKEN || next_token == STOP_TOKEN {
|
||||
break;
|
||||
}
|
||||
|
||||
let decoded = tokenizer
|
||||
.decode(&[next_token], true)
|
||||
.map_err(|err| err as Box<dyn Error>)?;
|
||||
emit_tok(&decoded);
|
||||
}
|
||||
|
||||
Ok((generated, start.elapsed()))
|
||||
}
|
||||
|
||||
pub fn run_qwen<R>(mut runtime: R, config: QwenRunConfig) -> Result<(), Box<dyn Error>>
|
||||
where
|
||||
R: QwenRuntime + 'static,
|
||||
{
|
||||
let _ = tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.try_init();
|
||||
|
||||
let stdio_mode = config.stdio;
|
||||
let log = |s: &str| info(stdio_mode, s);
|
||||
|
||||
let model_dir = prepare_hf_model(&config.repo_id)?;
|
||||
log(&format!("Using model directory: {}", model_dir.display()));
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json"))
|
||||
.map_err(|err| err as Box<dyn Error>)?;
|
||||
|
||||
let (default_prompt_tokens, prompt_len) = if stdio_mode {
|
||||
(Vec::<u32>::new(), 0usize)
|
||||
} else {
|
||||
let prompt = qwen3_chat_prompt(&config.prompt);
|
||||
let toks = tokenizer
|
||||
.encode(prompt.as_str(), false)
|
||||
.map_err(|err| err as Box<dyn Error>)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let len = toks.len();
|
||||
(toks, len)
|
||||
};
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
let token_ids = cx.named_tensor("token_ids", 's').as_dtype(DType::Int);
|
||||
let kv_cache = KVCache::new(&mut cx, config.max_seq_len, config.layers);
|
||||
let (logits, cache_outputs) =
|
||||
Qwen::init(&mut cx, config.layers).forward(input, token_ids, &kv_cache);
|
||||
let logits = logits.output();
|
||||
for (k_out, v_out) in &cache_outputs {
|
||||
k_out.output();
|
||||
v_out.output();
|
||||
}
|
||||
|
||||
log("Building E-Graph...");
|
||||
cx.build_search_space::<R>();
|
||||
|
||||
log("Loading weights...");
|
||||
let weights_path = model_dir.join("model_combined.safetensors");
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
let cache_bytes = N_KV_HEADS * config.max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..config.layers {
|
||||
runtime.set_zeros(kv_cache.k_caches[i].id, cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i].id, cache_bytes);
|
||||
}
|
||||
|
||||
log("Compiling...");
|
||||
let max_prefill = if stdio_mode {
|
||||
STDIO_MAX_PREFILL.min(config.max_seq_len)
|
||||
} else {
|
||||
(prompt_len + 16)
|
||||
.next_power_of_two()
|
||||
.min(config.max_seq_len)
|
||||
};
|
||||
let search_s = 16.min(max_prefill).max(2);
|
||||
cx.set_dim_buckets(
|
||||
's',
|
||||
&[
|
||||
DimBucket::new(1, 1),
|
||||
DimBucket::new(2, max_prefill).representative(search_s),
|
||||
],
|
||||
);
|
||||
cx.set_dim('s', search_s);
|
||||
cx.set_dim('p', 0);
|
||||
runtime.set_i32_data(input.id, vec![1; search_s]);
|
||||
runtime.set_i32_data(token_ids.id, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime = cx.search(runtime, config.search_graphs);
|
||||
|
||||
for i in 0..config.layers {
|
||||
runtime.set_zeros(kv_cache.k_caches[i].id, cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i].id, cache_bytes);
|
||||
}
|
||||
|
||||
if stdio_mode {
|
||||
stdio::serve(|user_prompt| {
|
||||
let chat_prompt = qwen3_chat_prompt(user_prompt);
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(chat_prompt.as_str(), false)
|
||||
.expect("tokenize failed")
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
run_one_prompt(
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
&tokenizer,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&kv_cache,
|
||||
&cache_outputs,
|
||||
cache_bytes,
|
||||
config.layers,
|
||||
&prompt_tokens,
|
||||
config.gen_tokens,
|
||||
config.repetition_penalty,
|
||||
&mut |s| stdio::emit_tok(s),
|
||||
)
|
||||
.expect("decode failed")
|
||||
});
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let prompt_tokens = default_prompt_tokens;
|
||||
let mut prev_seq = 0usize;
|
||||
let mut fwd_durations = vec![];
|
||||
let mut seen_tokens = FxHashSet::default();
|
||||
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, config.gen_tokens
|
||||
);
|
||||
|
||||
let mut generated = 0usize;
|
||||
let mut sentence = Vec::new();
|
||||
|
||||
if config.gen_tokens > 0 && prompt_len > 0 {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
cx.set_dim('s', prompt_len);
|
||||
cx.set_dim('p', 0);
|
||||
|
||||
runtime.set_i32_data(
|
||||
input.id,
|
||||
prompt_tokens.iter().map(|t| *t as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
runtime.set_i32_data(token_ids.id, (0..prompt_len as i32).collect::<Vec<_>>());
|
||||
runtime.prepare_execute(&cx.dyn_map);
|
||||
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let logits_data = runtime.get_f32(logits.id);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(k_out.id);
|
||||
let v_buf = runtime.remove_buffer(v_out.id);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx].id, k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx].id, v_buf);
|
||||
}
|
||||
|
||||
prev_seq = prompt_len;
|
||||
fwd_durations.push(start.elapsed());
|
||||
|
||||
let row_start = (prompt_len - 1) * VOCAB_SIZE;
|
||||
let next_token = sample_greedy_with_penalty(
|
||||
&logits_data[row_start..row_start + VOCAB_SIZE],
|
||||
&seen_tokens,
|
||||
config.repetition_penalty,
|
||||
);
|
||||
sentence = vec![next_token];
|
||||
seen_tokens.insert(next_token);
|
||||
generated = 1;
|
||||
|
||||
if next_token != EOS_TOKEN && next_token != STOP_TOKEN {
|
||||
let decoded = tokenizer
|
||||
.decode(&[next_token], true)
|
||||
.map_err(|err| err as Box<dyn Error>)?;
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
|
||||
while generated < config.gen_tokens && !sentence.is_empty() {
|
||||
let start = std::time::Instant::now();
|
||||
let seq_len = sentence.len();
|
||||
let current_token = sentence[0];
|
||||
|
||||
if current_token == EOS_TOKEN || current_token == STOP_TOKEN {
|
||||
break;
|
||||
}
|
||||
|
||||
cx.set_dim('s', seq_len);
|
||||
cx.set_dim('p', prev_seq);
|
||||
|
||||
runtime.set_i32_data(
|
||||
input.id,
|
||||
sentence.iter().map(|t| *t as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
runtime.set_i32_data(
|
||||
token_ids.id,
|
||||
(prev_seq as i32..(seq_len + prev_seq) as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
runtime.prepare_execute(&cx.dyn_map);
|
||||
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let logits_data = runtime.get_f32(logits.id);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(k_out.id);
|
||||
let v_buf = runtime.remove_buffer(v_out.id);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx].id, k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx].id, v_buf);
|
||||
}
|
||||
|
||||
prev_seq += seq_len;
|
||||
fwd_durations.push(start.elapsed());
|
||||
|
||||
let next_token = sample_greedy_with_penalty(
|
||||
&logits_data[logits_data.len() - VOCAB_SIZE..],
|
||||
&seen_tokens,
|
||||
config.repetition_penalty,
|
||||
);
|
||||
sentence = vec![next_token];
|
||||
seen_tokens.insert(next_token);
|
||||
generated += 1;
|
||||
|
||||
if next_token == EOS_TOKEN || next_token == STOP_TOKEN {
|
||||
break;
|
||||
}
|
||||
|
||||
let decoded = tokenizer
|
||||
.decode(&[next_token], true)
|
||||
.map_err(|err| err as Box<dyn Error>)?;
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
println!();
|
||||
|
||||
let decode_durations: Vec<_> = fwd_durations.iter().skip(1).collect();
|
||||
if decode_durations.len() > 2 {
|
||||
println!(
|
||||
" TTFT: {:.2} ms",
|
||||
fwd_durations[..1].iter().sum::<Duration>().as_secs_f64() * 1e3
|
||||
);
|
||||
println!(
|
||||
" TPOT: {:.2} ms",
|
||||
(decode_durations.iter().skip(1).copied().sum::<Duration>()
|
||||
/ (decode_durations.len() - 1) as u32)
|
||||
.as_secs_f64()
|
||||
* 1_000.
|
||||
);
|
||||
}
|
||||
|
||||
let _ = generated;
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,60 +1,174 @@
|
||||
#[cfg(all(feature = "cuda", feature = "metal"))]
|
||||
compile_error!("features `cuda` and `metal` are mutually exclusive");
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
#[cfg(all(feature = "cuda", feature = "metal"))]
|
||||
fn main() {}
|
||||
|
||||
#[cfg(all(feature = "cuda", not(feature = "metal")))]
|
||||
use hf::prepare_hf_model;
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
#[cfg(all(feature = "metal", not(feature = "cuda"), target_vendor = "apple"))]
|
||||
use luminal_metal::MetalRuntime;
|
||||
#[cfg(any(
|
||||
all(feature = "cuda", not(feature = "metal")),
|
||||
all(feature = "metal", not(feature = "cuda"), target_vendor = "apple")
|
||||
))]
|
||||
use qwen::{QwenRunConfig, Runtime, run_qwen};
|
||||
use luminal_tracing::*;
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
#[cfg(any(
|
||||
all(feature = "cuda", not(feature = "metal")),
|
||||
all(feature = "metal", not(feature = "cuda"), target_vendor = "apple")
|
||||
))]
|
||||
fn parse_cli() -> QwenRunConfig {
|
||||
let mut cfg = QwenRunConfig::default();
|
||||
for arg in std::env::args().skip(1) {
|
||||
match arg.as_str() {
|
||||
"--stdio" => cfg.stdio = true,
|
||||
"-h" | "--help" => {
|
||||
println!("Usage: qwen [--stdio]");
|
||||
std::process::exit(0);
|
||||
}
|
||||
other => {
|
||||
eprintln!("Unknown argument: {other}");
|
||||
std::process::exit(2);
|
||||
}
|
||||
}
|
||||
}
|
||||
cfg
|
||||
}
|
||||
const REPO_ID: &str = "Qwen/Qwen3-4B";
|
||||
|
||||
#[cfg(all(feature = "cuda", not(feature = "metal")))]
|
||||
fn main() {
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 500;
|
||||
let search_graphs = 500;
|
||||
let prompt = "Explain what a neural network is in a paragraph.";
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
run_qwen(CudaRuntime::initialize(stream), parse_cli()).unwrap();
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "metal", not(feature = "cuda"), target_vendor = "apple"))]
|
||||
fn main() {
|
||||
run_qwen(MetalRuntime::initialize(()), parse_cli()).unwrap();
|
||||
}
|
||||
let model_dir = prepare_hf_model(REPO_ID).expect("Failed to prepare model");
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
#[cfg(all(feature = "metal", not(feature = "cuda"), not(target_vendor = "apple")))]
|
||||
fn main() {
|
||||
eprintln!("qwen --features metal requires an Apple target with Metal support.");
|
||||
}
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
|
||||
#[cfg(not(any(feature = "cuda", feature = "metal")))]
|
||||
fn main() {
|
||||
eprintln!("select exactly one backend with `--features cuda` or `--features metal`.");
|
||||
std::process::exit(2);
|
||||
// Build graph
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
let token_ids = cx.named_tensor("token_ids", 's').as_dtype(DType::Int);
|
||||
let kv_cache = KVCache::new(&mut cx, max_seq_len);
|
||||
let (logits, cache_outputs) = Qwen::init(&mut cx).forward(input, token_ids, &kv_cache);
|
||||
let logits = logits.output();
|
||||
for (k_out, v_out) in &cache_outputs {
|
||||
k_out.output();
|
||||
v_out.output();
|
||||
}
|
||||
|
||||
println!("Building E-Graph...");
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
println!("Loading weights...");
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
let weights_path = model_dir.join("model_combined.safetensors");
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', 1);
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(token_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
let mut prev_seq = 1usize;
|
||||
let mut sentence = vec![prompt_tokens[0]];
|
||||
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
|
||||
let prompt_len = prompt_tokens.len();
|
||||
let mut fwd_durations = vec![];
|
||||
let mut seen_tokens = FxHashSet::default();
|
||||
let repetition_penalty: f32 = 1.05;
|
||||
|
||||
const EOS_TOKEN: u32 = 151645; // <|endoftext|>
|
||||
const STOP_TOKEN: u32 = 151643; // <|end|>
|
||||
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
|
||||
for i in 0..total_steps {
|
||||
let start = std::time::Instant::now();
|
||||
let is_prefill = i < prompt_len - 1;
|
||||
let seq_len = sentence.len();
|
||||
|
||||
cx.set_dim('s', seq_len);
|
||||
cx.set_dim('p', prev_seq);
|
||||
|
||||
runtime.set_data(
|
||||
input,
|
||||
sentence.iter().map(|t| *t as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
runtime.set_data(
|
||||
token_ids,
|
||||
(prev_seq as i32..(seq_len + prev_seq) as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
|
||||
// Round-trip KV cache
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
|
||||
prev_seq += seq_len;
|
||||
fwd_durations.push(start.elapsed());
|
||||
|
||||
if is_prefill {
|
||||
sentence = vec![prompt_tokens[i + 1]];
|
||||
continue;
|
||||
}
|
||||
|
||||
// Greedy decode with repetition penalty
|
||||
let mut last_row = logits_data[logits_data.len() - VOCAB_SIZE..].to_vec();
|
||||
for &tok in &seen_tokens {
|
||||
let logit = &mut last_row[tok as usize];
|
||||
if *logit > 0.0 {
|
||||
*logit /= repetition_penalty;
|
||||
} else {
|
||||
*logit *= repetition_penalty;
|
||||
}
|
||||
}
|
||||
let next_token = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
sentence = vec![next_token];
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
if next_token == EOS_TOKEN || next_token == STOP_TOKEN {
|
||||
break;
|
||||
}
|
||||
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
println!();
|
||||
|
||||
// Benchmarks
|
||||
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
|
||||
if decode_durations.len() > 2 {
|
||||
println!(
|
||||
" TTFT: {:.2} ms",
|
||||
fwd_durations[..prompt_len]
|
||||
.iter()
|
||||
.sum::<Duration>()
|
||||
.as_secs_f64()
|
||||
* 1e3
|
||||
);
|
||||
println!(
|
||||
" TPOT: {:.2} ms",
|
||||
(decode_durations.iter().skip(1).copied().sum::<Duration>()
|
||||
/ (decode_durations.len() - 1) as u32)
|
||||
.as_secs_f64()
|
||||
* 1_000.
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,14 +26,10 @@ pub struct KVCache {
|
||||
}
|
||||
|
||||
impl KVCache {
|
||||
pub fn new(cx: &mut Graph, max_seq: usize, layers: usize) -> Self {
|
||||
assert!(
|
||||
layers <= LAYERS,
|
||||
"requested {layers} layers, but model has {LAYERS}"
|
||||
);
|
||||
let mut k_caches = Vec::with_capacity(layers);
|
||||
let mut v_caches = Vec::with_capacity(layers);
|
||||
for l in 0..layers {
|
||||
pub fn new(cx: &mut Graph, max_seq: usize) -> Self {
|
||||
let mut k_caches = Vec::with_capacity(LAYERS);
|
||||
let mut v_caches = Vec::with_capacity(LAYERS);
|
||||
for l in 0..LAYERS {
|
||||
let k = cx
|
||||
.named_tensor(format!("kv_cache.{l}.k"), (N_KV_HEADS, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
@@ -58,13 +54,9 @@ pub struct Qwen {
|
||||
}
|
||||
|
||||
impl Qwen {
|
||||
pub fn init(cx: &mut Graph, layers: usize) -> Self {
|
||||
assert!(
|
||||
layers <= LAYERS,
|
||||
"requested {layers} layers, but model has {LAYERS}"
|
||||
);
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
let mut w = vec![];
|
||||
for l in 0..layers {
|
||||
for l in 0..LAYERS {
|
||||
let up = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.up_proj.weight"),
|
||||
@@ -177,7 +169,7 @@ impl Qwen {
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
let mut cache_outputs = Vec::with_capacity(self.layers.len());
|
||||
let mut cache_outputs = Vec::with_capacity(LAYERS);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
x,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "qwen3_moe"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
|
||||
@@ -9,7 +9,6 @@ edition = "2024"
|
||||
luminal = { path = "../.." }
|
||||
luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
example_common = { path = "../example_common" }
|
||||
tokenizers = "0.22.2"
|
||||
rustc-hash = "2"
|
||||
|
||||
@@ -21,4 +20,3 @@ serde_json = "1.0"
|
||||
half = {version = "2.7.1", features = ["bytemuck"]}
|
||||
bytemuck = "1.24.0"
|
||||
memmap2 = "0.9.9"
|
||||
rand = "0.9.2"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user