mirror of
https://git.teahaven.kr/Rust-related/luminal.git
synced 2026-06-04 08:39:48 +09:00
cargo examples (#325)
* cargo examples * Fix commit message generation for diff context * Generalize GLUMoE search-space checks and harden NaN tests
This commit is contained in:
@@ -1,3 +1,6 @@
|
||||
[alias]
|
||||
examples = "run --release --bin examples-perf --"
|
||||
|
||||
[target.aarch64-unknown-linux-gnu]
|
||||
rustflags = [
|
||||
"-Ctarget-feature=+fp16,+fhm"
|
||||
|
||||
185
ci/examples_perf.py
Normal file
185
ci/examples_perf.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from example_output import validate_output
|
||||
|
||||
|
||||
DEFAULT_EXAMPLES = ["llama", "gemma", "qwen", "qwen3_moe", "gemma4_moe", "whisper"]
|
||||
|
||||
EXAMPLE_CARGO_ARGS = {
|
||||
"llama": ["run", "--release", "-p", "llama"],
|
||||
"gemma": ["run", "--release", "-p", "gemma"],
|
||||
"qwen": ["run", "--release", "-p", "qwen", "--features", "cuda"],
|
||||
"qwen3_moe": ["run", "--release", "-p", "qwen3_moe"],
|
||||
"gemma4_moe": ["run", "--release", "-p", "gemma4_moe"],
|
||||
"whisper": ["run", "--release", "-p", "whisper"],
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Metrics:
|
||||
ttft_ms: float | None = None
|
||||
tpot_ms: float | None = None
|
||||
tps: float | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExampleResult:
|
||||
name: str
|
||||
ok: bool
|
||||
metrics: Metrics = field(default_factory=Metrics)
|
||||
wall_s: float = 0.0
|
||||
error: str | None = None
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = [arg for arg in sys.argv[1:] if arg != "--"]
|
||||
if any(arg in {"-h", "--help"} for arg in args):
|
||||
print_help()
|
||||
return
|
||||
if "--list" in args:
|
||||
print("\n".join(DEFAULT_EXAMPLES))
|
||||
return
|
||||
|
||||
examples = args or DEFAULT_EXAMPLES
|
||||
results = [run_example(example) for example in examples]
|
||||
print_table(results)
|
||||
if any(not result.ok for result in results):
|
||||
raise SystemExit(1)
|
||||
|
||||
|
||||
def print_help() -> None:
|
||||
print(
|
||||
"Run validated Luminal examples, validate textual output, and summarize perf.\n"
|
||||
"\n"
|
||||
"Usage:\n"
|
||||
" cargo examples\n"
|
||||
" cargo examples llama qwen whisper\n"
|
||||
"\n"
|
||||
"Options:\n"
|
||||
" --list Print the default validated examples\n"
|
||||
" -h, --help\n"
|
||||
"\n"
|
||||
f"The default set matches the Modal examples CI: {', '.join(DEFAULT_EXAMPLES)}."
|
||||
)
|
||||
|
||||
|
||||
def run_example(example: str) -> ExampleResult:
|
||||
cargo_args = EXAMPLE_CARGO_ARGS.get(example)
|
||||
if cargo_args is None:
|
||||
known = ", ".join(DEFAULT_EXAMPLES)
|
||||
return ExampleResult(example, False, error=f"unknown example; known examples: {known}")
|
||||
|
||||
print(f"\n=== Running {example} ===")
|
||||
print(f"$ cargo {' '.join(cargo_args)}")
|
||||
started = time.monotonic()
|
||||
env = os.environ.copy()
|
||||
env.setdefault("CUDARC_CUDA_VERSION", "12080")
|
||||
process = subprocess.Popen(
|
||||
["cargo", *cargo_args],
|
||||
cwd=repo_root(),
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert process.stdout is not None
|
||||
|
||||
chunks: list[bytes] = []
|
||||
while True:
|
||||
chunk = process.stdout.read1(4096)
|
||||
if not chunk:
|
||||
break
|
||||
sys.stdout.buffer.write(chunk)
|
||||
sys.stdout.buffer.flush()
|
||||
chunks.append(chunk)
|
||||
|
||||
return_code = process.wait()
|
||||
output = b"".join(chunks).decode("utf-8", errors="replace")
|
||||
wall_s = time.monotonic() - started
|
||||
metrics = parse_metrics(output)
|
||||
|
||||
if return_code:
|
||||
return ExampleResult(
|
||||
example,
|
||||
False,
|
||||
metrics=metrics,
|
||||
wall_s=wall_s,
|
||||
error=f"process exited with code {return_code}",
|
||||
)
|
||||
|
||||
try:
|
||||
validate_output(example, output)
|
||||
except Exception as exc:
|
||||
return ExampleResult(example, False, metrics=metrics, wall_s=wall_s, error=str(exc))
|
||||
|
||||
return ExampleResult(example, True, metrics=metrics, wall_s=wall_s)
|
||||
|
||||
|
||||
def repo_root() -> str:
|
||||
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def parse_metrics(output: str) -> Metrics:
|
||||
metrics = Metrics()
|
||||
for line in output.splitlines():
|
||||
if "TTFT:" in line:
|
||||
metrics.ttft_ms = parse_number_after(line, "TTFT:")
|
||||
if "TPOT:" in line:
|
||||
metrics.tpot_ms = parse_number_after(line, "TPOT:")
|
||||
if "tok/s" in line:
|
||||
metrics.tps = parse_tok_per_second(line)
|
||||
if metrics.tps is None and metrics.tpot_ms:
|
||||
metrics.tps = 1000.0 / metrics.tpot_ms
|
||||
return metrics
|
||||
|
||||
|
||||
def parse_number_after(line: str, marker: str) -> float | None:
|
||||
tail = line.split(marker, 1)[1].lstrip()
|
||||
chars = []
|
||||
for char in tail:
|
||||
if char.isdigit() or char == ".":
|
||||
chars.append(char)
|
||||
else:
|
||||
break
|
||||
if not chars:
|
||||
return None
|
||||
return float("".join(chars))
|
||||
|
||||
|
||||
def parse_tok_per_second(line: str) -> float | None:
|
||||
head = line.split("tok/s", 1)[0].rstrip(" (")
|
||||
parts = head.split()
|
||||
if not parts:
|
||||
return None
|
||||
try:
|
||||
return float(parts[-1])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def print_table(results: list[ExampleResult]) -> None:
|
||||
print("\nSummary")
|
||||
print(f"{'example':<14} {'status':<8} {'TTFT ms':>10} {'TPOT ms':>10} {'tok/s':>10} {'wall s':>10}")
|
||||
print("-" * 68)
|
||||
for result in results:
|
||||
status = "ok" if result.ok else "failed"
|
||||
print(
|
||||
f"{result.name:<14} {status:<8} "
|
||||
f"{format_metric(result.metrics.ttft_ms):>10} "
|
||||
f"{format_metric(result.metrics.tpot_ms):>10} "
|
||||
f"{format_metric(result.metrics.tps):>10} "
|
||||
f"{result.wall_s:>10.1f}"
|
||||
)
|
||||
if result.error:
|
||||
print(f" error: {result.error}")
|
||||
|
||||
|
||||
def format_metric(value: float | None) -> str:
|
||||
return "-" if value is None else f"{value:.2f}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,258 +0,0 @@
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{EXPRESSION, OP_KIND, STRING},
|
||||
extract_expr,
|
||||
},
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::{
|
||||
tracing::{Level, span, trace},
|
||||
*,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
cudarc::{
|
||||
cublas::{
|
||||
CudaBlas,
|
||||
sys::{cublasOperation_t, cublasSetStream_v2, cublasSgemm_v2, cublasStatus_t},
|
||||
},
|
||||
driver::CudaStream,
|
||||
},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
};
|
||||
|
||||
/// Global shared cuBLAS handle to avoid per-operation workspace allocation
|
||||
static SHARED_CUBLAS: OnceLock<Arc<CudaBlas>> = OnceLock::new();
|
||||
|
||||
/// Parse cuBLAS operation from egglog string (e.g., "\"T\"" -> CUBLAS_OP_T)
|
||||
pub fn parse_cublas_op(s: &str) -> cublasOperation_t {
|
||||
// Strip quotes if present (egglog strings are stored with quotes)
|
||||
let stripped = s.trim_matches('"');
|
||||
match stripped {
|
||||
"T" => cublasOperation_t::CUBLAS_OP_T,
|
||||
"N" => cublasOperation_t::CUBLAS_OP_N,
|
||||
"C" => cublasOperation_t::CUBLAS_OP_C,
|
||||
other => panic!("Unknown cuBLAS operation: '{other}' (original: '{s}')"),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub struct CuBlasSgemmV2 {
|
||||
m: Expression,
|
||||
n: Expression,
|
||||
k: Expression,
|
||||
a_layout: cublasOperation_t,
|
||||
b_layout: cublasOperation_t,
|
||||
lda: Expression,
|
||||
ldb: Expression,
|
||||
ldc: Expression,
|
||||
/// Lazily initialized cuBLAS handle - created on first execute
|
||||
cublas: OnceLock<Arc<CudaBlas>>,
|
||||
}
|
||||
|
||||
// Useless default for IntoEgglogOp
|
||||
impl Default for CuBlasSgemmV2 {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
m: Expression::default(),
|
||||
n: Expression::default(),
|
||||
k: Expression::default(),
|
||||
a_layout: cublasOperation_t::CUBLAS_OP_N, // IGNORE NOT REAL
|
||||
b_layout: cublasOperation_t::CUBLAS_OP_T, // IGNORE NOT REAL
|
||||
lda: Expression::default(),
|
||||
ldb: Expression::default(),
|
||||
ldc: Expression::default(),
|
||||
cublas: OnceLock::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for CuBlasSgemmV2 {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"cublasSgemmV2",
|
||||
&[
|
||||
("m", EXPRESSION),
|
||||
("n", EXPRESSION),
|
||||
("k", EXPRESSION),
|
||||
("a_layout", STRING),
|
||||
("b_layout", STRING),
|
||||
("lda", EXPRESSION),
|
||||
("ldb", EXPRESSION),
|
||||
("ldc", EXPRESSION),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
Rule::raw(include_str!["sgemm_v2_RmRm_rewrite.egg"]), // row row
|
||||
Rule::raw(include_str!["sgemm_v2_RmCm_rewrite.egg"]), // row col
|
||||
Rule::raw(include_str!["sgemm_v2_CmRm_rewrite.egg"]), // col row
|
||||
Rule::raw(include_str!["sgemm_v2_CmCm_rewrite.egg"]), // col col
|
||||
]
|
||||
}
|
||||
|
||||
#[allow(unused_variables)]
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
// Extract dimensions from egglog
|
||||
let m = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
|
||||
let n = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
|
||||
let k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
|
||||
|
||||
// Extract layout strings from egglog
|
||||
let a_layout_str = &egraph.enodes[kind_children[3]].0;
|
||||
let b_layout_str = &egraph.enodes[kind_children[4]].0;
|
||||
let a_layout = parse_cublas_op(a_layout_str);
|
||||
let b_layout = parse_cublas_op(b_layout_str);
|
||||
|
||||
// Extract leading dimensions from egglog
|
||||
let lda = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
|
||||
let ldb = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
|
||||
let ldc = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
|
||||
|
||||
let extracted_state = Self {
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
a_layout,
|
||||
b_layout,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
cublas: OnceLock::new(),
|
||||
};
|
||||
trace!(?extracted_state);
|
||||
|
||||
let extracted = LLIROp::new::<dyn HostOp>(Box::new(extracted_state) as Box<dyn HostOp>);
|
||||
|
||||
(extracted, input_enodes)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for CuBlasSgemmV2 {
|
||||
fn execute(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
self_node: NodeIndex,
|
||||
inputs: &[NodeIndex],
|
||||
buffers: &FxHashMap<NodeIndex, DeviceBuffer>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
// GEMM parameters
|
||||
let m = self.m.exec(dyn_map).unwrap() as i32;
|
||||
let n = self.n.exec(dyn_map).unwrap() as i32;
|
||||
let k = self.k.exec(dyn_map).unwrap() as i32;
|
||||
let a_layout = self.a_layout;
|
||||
let b_layout = self.b_layout;
|
||||
let lda = self.lda.exec(dyn_map).unwrap() as i32;
|
||||
let ldb = self.ldb.exec(dyn_map).unwrap() as i32;
|
||||
let ldc = self.ldc.exec(dyn_map).unwrap() as i32;
|
||||
|
||||
let alpha = 1.0f32;
|
||||
let beta = 0.0f32;
|
||||
|
||||
// Get buffers: output is self_node, inputs are from graph edges
|
||||
let c_buf = buffers[&self_node];
|
||||
let a_buf = buffers[&inputs[0]];
|
||||
let b_buf = buffers[&inputs[1]];
|
||||
|
||||
// Get device pointers
|
||||
let a_ptr = a_buf.ptr();
|
||||
let b_ptr = b_buf.ptr();
|
||||
let c_ptr = c_buf.ptr();
|
||||
|
||||
// Debug: Check buffer sizes
|
||||
trace!(
|
||||
"buffer_validation {}=={},{}=={},{}=={}",
|
||||
a_buf.len(),
|
||||
m * k * 4,
|
||||
b_buf.len(),
|
||||
k * n * 4,
|
||||
c_buf.len(),
|
||||
m * n * 4
|
||||
);
|
||||
let _sgemm_span = span!(
|
||||
Level::TRACE,
|
||||
"cuBLAS_SGEMM_V2",
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
beta,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
?a_layout,
|
||||
?b_layout,
|
||||
)
|
||||
.entered();
|
||||
|
||||
// Use shared cuBLAS handle to avoid per-operation workspace allocation
|
||||
let cublas = SHARED_CUBLAS.get_or_init(|| Arc::new(CudaBlas::new(stream.clone()).unwrap()));
|
||||
|
||||
// Set the stream for this operation (cuBLAS handle can work with any stream)
|
||||
// The CUstream types from cublas::sys and driver::sys are compatible, just cast
|
||||
unsafe {
|
||||
cublasSetStream_v2(*cublas.handle(), stream.cu_stream() as _);
|
||||
}
|
||||
|
||||
let status = unsafe {
|
||||
cublasSgemm_v2(
|
||||
*cublas.handle(),
|
||||
a_layout,
|
||||
b_layout,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
&alpha as *const f32,
|
||||
a_ptr as *const f32,
|
||||
lda,
|
||||
b_ptr as *const f32,
|
||||
ldb,
|
||||
&beta as *const f32,
|
||||
c_ptr as *mut f32,
|
||||
ldc,
|
||||
)
|
||||
};
|
||||
stream.synchronize().unwrap();
|
||||
|
||||
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
|
||||
return Err(anyhow::anyhow!(
|
||||
"cuBLAS SGEMM TN failed with status: {:?}",
|
||||
status
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.m * self.n
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
// CuBlasSgemmV2 is F32 only (Sgemm = Single precision)
|
||||
self.output_size() * 4
|
||||
}
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
; Column-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(let ?sgemm (Op (cublasSgemmV2
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?k ; lda = k (column-major B[k,n])
|
||||
?m ; ldb = m (column-major A[m,k])
|
||||
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm column-major × column-major"
|
||||
)
|
||||
@@ -1,73 +0,0 @@
|
||||
; Column-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Row-major B[k,n] ≡ column-major B^T[n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × row-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(let ?sgemm (Op (cublasSgemmV2
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
|
||||
?m ; ldb = m (column-major A[m,k])
|
||||
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm column-major × row-major"
|
||||
)
|
||||
@@ -1,73 +0,0 @@
|
||||
; Row-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major A^T[k,m] with lda=k
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(let ?sgemm (Op (cublasSgemmV2
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major, need B^T)
|
||||
"N" ; transb = No transpose
|
||||
?k ; lda = k (column-major B[k,n])
|
||||
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
|
||||
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm row-major × column-major"
|
||||
)
|
||||
@@ -1,73 +0,0 @@
|
||||
; Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major [k,m] with lda=k
|
||||
; Row-major B[k,n] ≡ column-major [n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major [n,m] with ldc=n
|
||||
;
|
||||
; cuBLAS computes: C_col[n,m] = B_col[n,k] × A_col[k,m]
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major C = A × B with cuBLAS (column-major):
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(let ?sgemm (Op (cublasSgemmV2
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose
|
||||
"N" ; transb = No transpose
|
||||
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
|
||||
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
|
||||
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name "cublas sgemm row-major"
|
||||
)
|
||||
@@ -11,11 +11,13 @@
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
; Match the generic matmul produced from Mul -> Sum.
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
@@ -77,8 +79,12 @@
|
||||
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
|
||||
@@ -11,11 +11,13 @@
|
||||
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
; Match the generic matmul produced from Mul -> Sum.
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
@@ -77,8 +79,12 @@
|
||||
; B row-major per batch: b_n_stride=MIter, b_m_stride=0
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
|
||||
@@ -11,11 +11,13 @@
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
; Match the generic matmul produced from Mul -> Sum.
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
@@ -77,8 +79,12 @@
|
||||
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
|
||||
@@ -11,11 +11,13 @@
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
; Match the generic matmul produced from Mul -> Sum.
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
@@ -79,8 +81,12 @@
|
||||
; Leading dimensions may differ from k/n when batch slices are non-contiguous.
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Output shape: [batch, m, n]
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
|
||||
@@ -25,8 +25,12 @@
|
||||
(ICons ?input_scale (INil))))
|
||||
(= ?a (Op (Cast ?a_size ?a_dtype) (ICons ?scaled_activation (INil))))
|
||||
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
(= ?scale_product (Op (Mul (ENil) (ENil) (ENil) (ENil))
|
||||
(ICons ?input_scale (ICons ?weight_scale (INil)))))
|
||||
@@ -96,8 +100,12 @@
|
||||
(ICons ?input_scale (INil))))
|
||||
(= ?a (Op (Cast ?a_size ?a_dtype) (ICons ?scaled_activation (INil))))
|
||||
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
(= ?scale_product (Op (Mul (ENil) (ENil) (ENil) (ENil))
|
||||
(ICons ?input_scale (ICons ?weight_scale (INil)))))
|
||||
@@ -368,8 +376,12 @@
|
||||
(ICons ?input_scale (INil))))
|
||||
(= ?a (Op (Cast ?a_size ?a_dtype) (ICons ?scaled_activation (INil))))
|
||||
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
(= ?scale_product (Op (Mul (ENil) (ENil) (ENil) (ENil))
|
||||
(ICons ?input_scale (ICons ?weight_scale (INil)))))
|
||||
@@ -440,8 +452,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
@@ -489,8 +505,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
@@ -538,8 +558,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
@@ -587,8 +611,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
@@ -650,8 +678,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
@@ -713,8 +745,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
(= ?cast (Op (Cast ?size (F32)) (ICons ?sum (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
|
||||
@@ -5,8 +5,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
@@ -54,8 +58,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
@@ -103,8 +111,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
@@ -152,8 +164,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
@@ -201,8 +217,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
@@ -264,8 +284,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
@@ -327,8 +351,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
@@ -390,8 +418,12 @@
|
||||
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?out_shape ?mul_shape ?k
|
||||
?a_stride ?b_stride
|
||||
?sum_in_stride ?k_stride ?sum_out_stride
|
||||
?matmul_dtype)
|
||||
(ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
|
||||
@@ -35,10 +35,20 @@ use crate::{
|
||||
},
|
||||
driver::{CudaStream, DevicePtr},
|
||||
},
|
||||
host::{DeviceBuffer, HostOp, cublas::parse_cublas_op},
|
||||
host::{DeviceBuffer, HostOp},
|
||||
try_create_cublaslt,
|
||||
};
|
||||
|
||||
fn parse_cublas_op(s: &str) -> cublasOperation_t {
|
||||
let stripped = s.trim_matches('"');
|
||||
match stripped {
|
||||
"T" => cublasOperation_t::CUBLAS_OP_T,
|
||||
"N" => cublasOperation_t::CUBLAS_OP_N,
|
||||
"C" => cublasOperation_t::CUBLAS_OP_C,
|
||||
other => panic!("Unknown cuBLAS operation: '{other}' (original: '{s}')"),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub struct CuBlasLt {
|
||||
@@ -189,50 +199,50 @@ impl EgglogOp for CuBlasLt {
|
||||
Rule::raw(include_str!["cublaslt_beta_rewrite.egg"]),
|
||||
Rule::raw(include_str!["cublaslt_epilogue_rewrite.egg"]),
|
||||
Rule::raw(include_str!["cublaslt_row_order_rewrite.egg"]),
|
||||
// Delete the matmul-broadcast Mul eclass when the consuming Sum
|
||||
// eclass has a `cublaslt` or `KernelBatchMatMul` alternative. The
|
||||
// cuBLASLt / batched-matmul rewrite rules only union those enodes
|
||||
// into the Sum eclass after the broadcast pattern check passes,
|
||||
// so their presence is the matmul-broadcast signal — no further
|
||||
// stride-form check needed.
|
||||
//
|
||||
// Delete the HLIR `Mul` fallback from the Mul eclass. Emptying that
|
||||
// eclass lets the empty-eclass cascade prune the downstream Sum /
|
||||
// KernelSum fallback. cuBLAS, TileMatmulFullSplit, KernelBatchMatVec,
|
||||
// and KernelBatchMatMul all take original (a, b) inputs rather than
|
||||
// the Mul eclass, so they survive the cascade and remain as the
|
||||
// matmul output alternative.
|
||||
// cuBLASLt now specializes GenericMatmul, so cleanup should prune
|
||||
// the matmul output alternatives directly. Do not delete the
|
||||
// broadcast Mul here; it may still have non-matmul consumers.
|
||||
Rule::raw("(rule
|
||||
((= ?mul (Op (Mul ?shape ?as ?bs ?os) ?inputs))
|
||||
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
((= ?sum (Op (Sum ?shape ?k ?sis ?sks ?sos) ?inputs))
|
||||
(= ?sum (Op (cublaslt ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?ci)))
|
||||
((delete (Op (Mul ?shape ?as ?bs ?os) ?inputs)))
|
||||
((delete (Op (Sum ?shape ?k ?sis ?sks ?sos) ?inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"delete-sum-when-cublaslt-exists\"
|
||||
)"),
|
||||
Rule::raw("(rule
|
||||
((= ?mul (Op (Mul ?shape ?as ?bs ?os) ?inputs))
|
||||
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (KernelBatchMatMul ?bos ?bk ?bas ?baks ?bbs ?bbks ?bouts ?bdt) ?bi)))
|
||||
((delete (Op (Mul ?shape ?as ?bs ?os) ?inputs)))
|
||||
:ruleset cleanup
|
||||
)"),
|
||||
// Also remove any generic fusion wrapper that was unioned with the
|
||||
// broadcast Mul. This is deliberately a separate rule: requiring a
|
||||
// FusionEnd in the same eclass made cleanup miss valid cuBLASLt
|
||||
// matmuls when fusion wrapping was absent.
|
||||
Rule::raw("(rule
|
||||
((= ?mul (Op (FusionEnd ?fshape ?fos ?fdt) ?finputs))
|
||||
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
((= ?sum (Op (KernelSum ?shape ?k ?sis ?sks ?sos ?dt) ?inputs))
|
||||
(= ?sum (Op (cublaslt ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?ci)))
|
||||
((delete (Op (FusionEnd ?fshape ?fos ?fdt) ?finputs)))
|
||||
((delete (Op (KernelSum ?shape ?k ?sis ?sks ?sos ?dt) ?inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"delete-kernel-sum-when-cublaslt-exists\"
|
||||
)"),
|
||||
Rule::raw("(rule
|
||||
((= ?mul (Op (FusionEnd ?fshape ?fos ?fdt) ?finputs))
|
||||
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (KernelBatchMatMul ?bos ?bk ?bas ?baks ?bbs ?bbks ?bouts ?bdt) ?bi)))
|
||||
((delete (Op (FusionEnd ?fshape ?fos ?fdt) ?finputs)))
|
||||
((= ?sum (Op (Sum ?shape ?k ?sis ?sks ?sos) ?inputs))
|
||||
(= ?sum (Op (cublaslt_scaled ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?ci)))
|
||||
((delete (Op (Sum ?shape ?k ?sis ?sks ?sos) ?inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"delete-sum-when-scaled-cublaslt-exists\"
|
||||
)"),
|
||||
Rule::raw("(rule
|
||||
((= ?sum (Op (KernelSum ?shape ?k ?sis ?sks ?sos ?dt) ?inputs))
|
||||
(= ?sum (Op (cublaslt_scaled ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?ci)))
|
||||
((delete (Op (KernelSum ?shape ?k ?sis ?sks ?sos ?dt) ?inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"delete-kernel-sum-when-scaled-cublaslt-exists\"
|
||||
)"),
|
||||
Rule::raw("(rule
|
||||
((= ?sum (Op (GenericMatmul ?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt) ?generic_inputs))
|
||||
(= ?sum (Op (cublaslt ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?cublas_inputs)))
|
||||
((delete (Op (GenericMatmul ?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt) ?generic_inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"prefer-cublaslt-over-generic-matmul\"
|
||||
)"),
|
||||
Rule::raw("(rule
|
||||
((= ?sum (Op (GenericMatmul ?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt) ?generic_inputs))
|
||||
(= ?sum (Op (cublaslt_scaled ?cm ?cn ?ck ?cta ?ctb ?cao ?cbo ?cco ?cdo ?clda ?cldb ?cldc ?cldd ?cbc ?csa ?csb ?csc ?csd ?cadt ?cbdt ?ccdt ?cddt ?ccompute ?cscale ?calpha ?cbeta ?cepilogue) ?cublas_inputs)))
|
||||
((delete (Op (GenericMatmul ?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt) ?generic_inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"prefer-scaled-cublaslt-over-generic-matmul\"
|
||||
)"),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -2,13 +2,11 @@ use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use crate::cudarc::driver::{CudaStream, DriverError, result};
|
||||
use luminal::{op::EgglogOp, prelude::*};
|
||||
mod cublas;
|
||||
mod cublaslt;
|
||||
pub mod flashinfer;
|
||||
pub mod moe;
|
||||
|
||||
pub type Ops = (
|
||||
// cublas::CuBlasSgemmV2,
|
||||
cublaslt::CuBlasLt,
|
||||
cublaslt::CuBlasLtScaled,
|
||||
moe::GLUMoE,
|
||||
|
||||
@@ -309,6 +309,61 @@ impl EgglogOp for FusionEnd {
|
||||
// `Cycle(NodeIndex(_))`. Grow rules already compose adjacent regions
|
||||
// correctly without dissolve.
|
||||
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (CudaUnaryElementwise ?op ?inner_shape ?inner_in_s ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_shape ?inner_shape)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-unary-shape\")",
|
||||
));
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (CudaUnaryElementwise ?op ?inner_shape ?inner_in_s ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_s ?inner_s)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-unary-strides\")",
|
||||
));
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (CudaBinaryElementwise ?op ?inner_shape ?a_s ?b_s ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_shape ?inner_shape)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-binary-shape\")",
|
||||
));
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (CudaBinaryElementwise ?op ?inner_shape ?a_s ?b_s ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_s ?inner_s)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-binary-strides\")",
|
||||
));
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (FusionEnd ?inner_shape ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_shape ?inner_shape)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-nested-shape\")",
|
||||
));
|
||||
rules.push(Rule::raw(
|
||||
"(rule (
|
||||
(= ?fe (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
(= ?inner (Op (FusionEnd ?inner_shape ?inner_s ?dt) ?inner_inputs))
|
||||
(!= ?fe_s ?inner_s)
|
||||
) (
|
||||
(delete (Op (FusionEnd ?fe_shape ?fe_s ?dt) (ICons ?inner (INil))))
|
||||
) :ruleset cleanup :name \"delete-malformed-FE-nested-strides\")",
|
||||
));
|
||||
|
||||
rules
|
||||
}
|
||||
|
||||
|
||||
319
crates/luminal_cuda_lite/src/kernel/generic_matmul.rs
Normal file
319
crates/luminal_cuda_lite/src/kernel/generic_matmul.rs
Normal file
@@ -0,0 +1,319 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
compile_module_image_for_current_device, cuda_dtype,
|
||||
kernel::{
|
||||
KernelOp,
|
||||
hlir::{dtype_includes, generate_dyn_dims_defines},
|
||||
},
|
||||
};
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, EXPRESSION, OP_KIND},
|
||||
extract_dtype, extract_expr, extract_expr_list,
|
||||
},
|
||||
op::*,
|
||||
prelude::*,
|
||||
shape::flatten_strides,
|
||||
};
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct GenericMatmul {
|
||||
out_shape: Vec<Expression>,
|
||||
mul_shape: Vec<Expression>,
|
||||
k: Expression,
|
||||
lhs_strides: Vec<Expression>,
|
||||
rhs_strides: Vec<Expression>,
|
||||
sum_input_strides: Vec<Expression>,
|
||||
sum_iter_stride: Expression,
|
||||
out_strides: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for GenericMatmul {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"GenericMatmul",
|
||||
&[
|
||||
("out_shape", ELIST),
|
||||
("mul_shape", ELIST),
|
||||
("k", EXPRESSION),
|
||||
("lhs_strides", ELIST),
|
||||
("rhs_strides", ELIST),
|
||||
("sum_input_strides", ELIST),
|
||||
("sum_iter_stride", EXPRESSION),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?lhs_strides ?rhs_strides ?mul_out_strides)
|
||||
(ICons ?lhs (ICons ?rhs (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides)
|
||||
(ICons ?mul (INil))))
|
||||
(= ?dt (dtype ?sum))
|
||||
)
|
||||
(
|
||||
(let ?generic (Op (GenericMatmul
|
||||
?out_shape
|
||||
?mul_shape
|
||||
?k
|
||||
?lhs_strides
|
||||
?rhs_strides
|
||||
?sum_input_strides
|
||||
?sum_iter_stride
|
||||
?out_strides
|
||||
?dt)
|
||||
(ICons ?lhs (ICons ?rhs (INil)))))
|
||||
(union ?sum ?generic)
|
||||
(set (dtype ?generic) ?dt)
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name \"generic-matmul-cuda-mul-sum\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?lhs_strides ?rhs_strides ?mul_out_strides)
|
||||
(ICons ?lhs (ICons ?rhs (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides)
|
||||
(ICons ?mul (INil))))
|
||||
(= ?sum (Op (GenericMatmul
|
||||
?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt)
|
||||
?generic_inputs))
|
||||
)
|
||||
(
|
||||
(delete (Op (Sum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides)
|
||||
(ICons ?mul (INil))))
|
||||
)
|
||||
:ruleset cleanup
|
||||
:name \"delete-sum-when-generic-matmul-exists\"
|
||||
)",
|
||||
),
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?kernel_sum (Op (KernelSum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides ?dt)
|
||||
?sum_inputs))
|
||||
(= ?kernel_sum (Op (GenericMatmul
|
||||
?go ?gm ?gk ?gls ?grs ?gsis ?gsit ?gos ?gdt)
|
||||
?generic_inputs))
|
||||
)
|
||||
((delete (Op (KernelSum ?out_shape ?k ?sum_input_strides ?sum_iter_stride ?out_strides ?dt)
|
||||
?sum_inputs)))
|
||||
:ruleset cleanup
|
||||
:name \"delete-kernel-sum-when-generic-matmul-exists\"
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
mul_shape: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
k: extract_expr(egraph, kind_children[2], expr_cache).unwrap(),
|
||||
lhs_strides: extract_expr_list(egraph, kind_children[3], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
rhs_strides: extract_expr_list(egraph, kind_children[4], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
sum_input_strides: extract_expr_list(
|
||||
egraph,
|
||||
kind_children[5],
|
||||
list_cache,
|
||||
expr_cache,
|
||||
)
|
||||
.unwrap(),
|
||||
sum_iter_stride: extract_expr(egraph, kind_children[6], expr_cache).unwrap(),
|
||||
out_strides: extract_expr_list(egraph, kind_children[7], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[8]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for GenericMatmul {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars = self.all_dyn_vars();
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
let n_outputs = self.output_size();
|
||||
let sum_base_idx = flatten_strides(&self.out_shape, &self.sum_input_strides).to_kernel();
|
||||
let iter_offset = self.sum_iter_stride.to_kernel().replace("const_z", "i");
|
||||
let lhs_idx = flatten_strides(&self.mul_shape, &self.lhs_strides)
|
||||
.to_kernel()
|
||||
.replace("const_z", "mul_idx");
|
||||
let rhs_idx = flatten_strides(&self.mul_shape, &self.rhs_strides)
|
||||
.to_kernel()
|
||||
.replace("const_z", "mul_idx");
|
||||
let out_idx = flatten_strides(&self.out_shape, &self.out_strides).to_kernel();
|
||||
let k = self.k.to_kernel();
|
||||
|
||||
let kernel = format!(
|
||||
"{includes}
|
||||
#define WARP_SIZE 32
|
||||
#define THREADS_PER_BLOCK 256
|
||||
#define FULL_MASK 0xffffffff
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void generic_matmul({dtype} *out, const {dtype} *lhs, const {dtype} *rhs{dyn_dims_param}) {{
|
||||
__shared__ float warp_sums[THREADS_PER_BLOCK / WARP_SIZE];
|
||||
long long const_z = blockIdx.x;
|
||||
if (const_z >= {n_outputs}) return;
|
||||
|
||||
int tid = threadIdx.x;
|
||||
int lane_id = tid % WARP_SIZE;
|
||||
int warp_id = tid / WARP_SIZE;
|
||||
|
||||
long long base_idx = {sum_base_idx};
|
||||
long long iters = {k};
|
||||
|
||||
float partial = 0.0f;
|
||||
for (long long i = tid; i < iters; i += THREADS_PER_BLOCK) {{
|
||||
long long mul_idx = base_idx + {iter_offset};
|
||||
partial += static_cast<float>(lhs[{lhs_idx}]) * static_cast<float>(rhs[{rhs_idx}]);
|
||||
}}
|
||||
|
||||
#pragma unroll
|
||||
for (int s = WARP_SIZE / 2; s > 0; s >>= 1) {{
|
||||
partial += __shfl_down_sync(FULL_MASK, partial, s);
|
||||
}}
|
||||
|
||||
if (lane_id == 0) {{
|
||||
warp_sums[warp_id] = partial;
|
||||
}}
|
||||
__syncthreads();
|
||||
|
||||
if (warp_id == 0) {{
|
||||
float block_sum = tid < (THREADS_PER_BLOCK / WARP_SIZE) ? warp_sums[tid] : 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int s = (THREADS_PER_BLOCK / WARP_SIZE) / 2; s > 0; s >>= 1) {{
|
||||
block_sum += __shfl_down_sync(FULL_MASK, block_sum, s);
|
||||
}}
|
||||
|
||||
if (tid == 0) {{
|
||||
out[{out_idx}] = ({dtype})block_sum;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}",
|
||||
n_outputs = n_outputs.to_kernel(),
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("generic_matmul").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(n_outputs, 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
32.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.max(Expression::from(1))
|
||||
}
|
||||
|
||||
fn all_dyn_vars(&self) -> FxHashSet<char> {
|
||||
self.out_shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.mul_shape.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.k.dyn_vars())
|
||||
.chain(self.lhs_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.rhs_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.sum_input_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.sum_iter_stride.dyn_vars())
|
||||
.chain(self.out_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
(self.output_size() * self.k * self.dtype.bits() * 2).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.output_size() * self.k * 2
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"GenericMatmul"
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,7 @@ use uuid::Uuid;
|
||||
pub mod conv2d;
|
||||
pub mod cuda_graph;
|
||||
pub mod fusion;
|
||||
pub mod generic_matmul;
|
||||
pub mod hlir;
|
||||
pub mod matmul2d;
|
||||
pub mod other_ops;
|
||||
@@ -19,13 +20,20 @@ pub mod rope;
|
||||
|
||||
pub use conv2d::KernelConv2D;
|
||||
pub use cuda_graph::*;
|
||||
pub use generic_matmul::GenericMatmul;
|
||||
pub use matmul2d::{
|
||||
Matmul2DCustom, Matmul2DKernel, linear_bias, linear_no_bias_bf16_w, matmul_2d, matmul_2d_t,
|
||||
matmul_3d, matmul_3d_t,
|
||||
};
|
||||
pub use rope::{RoPECustom, RoPEKernel, apply_rope};
|
||||
|
||||
pub type Ops = (hlir::Ops, other_ops::Ops, conv2d::KernelConv2D, fusion::Ops);
|
||||
pub type Ops = (
|
||||
hlir::Ops,
|
||||
other_ops::Ops,
|
||||
conv2d::KernelConv2D,
|
||||
GenericMatmul,
|
||||
fusion::Ops,
|
||||
);
|
||||
|
||||
/// Build a mapping from interned string IDs to their string values for a given sequence.
|
||||
fn build_interned_strings(trace: &schema::Trace) -> std::collections::HashMap<(u32, u64), String> {
|
||||
|
||||
@@ -17,13 +17,7 @@ use luminal::{
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
pub type Ops = (
|
||||
KernelMeanReduce,
|
||||
KernelBatchMatVec,
|
||||
KernelBatchMatMul,
|
||||
KernelScatterNoCopy,
|
||||
KernelSoftmax,
|
||||
);
|
||||
pub type Ops = (KernelMeanReduce, KernelScatterNoCopy, KernelSoftmax);
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
|
||||
@@ -619,569 +613,6 @@ extern \"C\" {{
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// KernelBatchMatVec: Fused batched matrix-vector product for attention
|
||||
// Matches: Mul(broadcast) + Sum pattern for [B, 1, K] x [B, K, N] -> [B, 1, N]
|
||||
// or [B, M, K] x [B, K, N] -> [B, M, N] with small M
|
||||
// Replaces the broadcast elementwise Mul + single-threaded KernelSumReduce pipeline
|
||||
// =============================================================================
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelBatchMatVec {
|
||||
// Output shape: the final reduced shape [B..., M, N]
|
||||
out_shape: Vec<Expression>,
|
||||
// K: the reduction dimension (was the Sum iters)
|
||||
k_dim: Expression,
|
||||
// Strides for input A (with K dim removed)
|
||||
a_stride: Vec<Expression>,
|
||||
a_k_stride: Expression,
|
||||
// Strides for input B (with K dim removed)
|
||||
b_stride: Vec<Expression>,
|
||||
b_k_stride: Expression,
|
||||
// Output strides
|
||||
out_stride: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelBatchMatVec {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelBatchMatVec",
|
||||
&[
|
||||
("out_shape", ELIST),
|
||||
("k_dim", EXPRESSION),
|
||||
("a_stride", ELIST),
|
||||
("a_k_stride", EXPRESSION),
|
||||
("b_stride", ELIST),
|
||||
("b_k_stride", EXPRESSION),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
; Match Mul node (broadcast multiply)
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Output shape must have 3+ dimensions (batched)
|
||||
(= ?out_shape (ECons ?batch_or_d0 (ECons ?d1 (ECons ?d2 ?rest))))
|
||||
|
||||
; k_stride must be contiguous
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Get A's k-dimension stride (second from end in Mul's a_stride)
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 1))
|
||||
|
||||
; Get B's k-dimension stride (second from end in Mul's b_stride)
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 1))
|
||||
|
||||
; A's k stride must be contiguous (row-major A)
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; B's k stride must be contiguous (col-major B)
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
; Must be F32
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
; Remove the k-dimension from A strides for the kernel
|
||||
(let ?a_kern_stride (RemoveNthFromEnd ?a_stride 1))
|
||||
; Remove the k-dimension from B strides
|
||||
(let ?b_kern_stride (RemoveNthFromEnd ?b_stride 1))
|
||||
|
||||
(let ?bmv (Op (KernelBatchMatVec
|
||||
?out_shape ?k
|
||||
?a_kern_stride ?a_k_stride
|
||||
?b_kern_stride ?b_k_stride
|
||||
?sum_out_stride (F32)) (ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?bmv)
|
||||
(set (dtype ?bmv) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name \"batch mat-vec\"
|
||||
)"
|
||||
)]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
k_dim: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
|
||||
a_stride: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
a_k_stride: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
|
||||
b_stride: extract_expr_list(egraph, kind_children[4], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
b_k_stride: extract_expr(egraph, kind_children[5], expr_cache).unwrap(),
|
||||
out_stride: extract_expr_list(egraph, kind_children[6], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[7]),
|
||||
})),
|
||||
input_enodes, // A, B
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelBatchMatVec {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars: FxHashSet<char> = self
|
||||
.out_shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.a_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.b_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.out_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.k_dim.dyn_vars())
|
||||
.chain(self.a_k_stride.dyn_vars())
|
||||
.chain(self.b_k_stride.dyn_vars())
|
||||
.collect();
|
||||
|
||||
let n_outputs: Expression = self.out_shape.iter().copied().product();
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
// Each output element is a dot product of length K.
|
||||
// We launch one block of 256 threads per output element.
|
||||
// Threads cooperatively reduce K using warp shuffles.
|
||||
let a_idx = flatten_strides(&self.out_shape, &self.a_stride).to_kernel();
|
||||
let b_idx = flatten_strides(&self.out_shape, &self.b_stride).to_kernel();
|
||||
let out_idx = flatten_strides(&self.out_shape, &self.out_stride).to_kernel();
|
||||
let k_expr = self.k_dim.to_kernel();
|
||||
let a_k_stride_expr = self
|
||||
.a_k_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel();
|
||||
let b_k_stride_expr = self
|
||||
.b_k_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel();
|
||||
|
||||
let kernel = format!(
|
||||
"
|
||||
#define WARP_SIZE 32
|
||||
#define THREADS_PER_BLOCK 256
|
||||
#define FULL_MASK 0xffffffff
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void batch_matvec(float *out, const float *A, const float *B{dyn_dims_param}) {{
|
||||
__shared__ float warp_sums[THREADS_PER_BLOCK / WARP_SIZE];
|
||||
long long const_z = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
int lane_id = tid % WARP_SIZE;
|
||||
int warp_id = tid / WARP_SIZE;
|
||||
|
||||
long long a_base = {a_idx};
|
||||
long long b_base = {b_idx};
|
||||
long long K = {k_expr};
|
||||
long long a_k_stride = {a_k_stride_expr};
|
||||
long long b_k_stride = {b_k_stride_expr};
|
||||
|
||||
float partial = 0.0f;
|
||||
for (long long k = tid; k < K; k += THREADS_PER_BLOCK) {{
|
||||
partial += A[a_base + k * a_k_stride] * B[b_base + k * b_k_stride];
|
||||
}}
|
||||
|
||||
#pragma unroll
|
||||
for (int s = WARP_SIZE / 2; s > 0; s /= 2) {{
|
||||
partial += __shfl_down_sync(FULL_MASK, partial, s);
|
||||
}}
|
||||
|
||||
if (lane_id == 0) {{
|
||||
warp_sums[warp_id] = partial;
|
||||
}}
|
||||
__syncthreads();
|
||||
|
||||
if (warp_id == 0) {{
|
||||
int cnt = THREADS_PER_BLOCK / WARP_SIZE;
|
||||
float block_sum = tid < cnt ? warp_sums[tid] : 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int s = cnt / 2; s > 0; s /= 2) {{
|
||||
block_sum += __shfl_down_sync(FULL_MASK, block_sum, s);
|
||||
}}
|
||||
|
||||
if (tid == 0) {{
|
||||
out[{out_idx}] = block_sum;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("batch_matvec").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(n_outputs, 1.into(), 1.into()), // grid: one block per output
|
||||
(256.into(), 1.into(), 1.into()), // block: 256 threads
|
||||
32.into(), // shared mem for warp_sums
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
let n = self.output_size();
|
||||
// Each output loads K elements from A and K elements from B
|
||||
n * self.k_dim * 2 * 4
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// Each output: K multiply-adds = 2*K FLOPs
|
||||
self.output_size() * self.k_dim * 2
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"BatchMatVec"
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// KernelBatchMatMul: General batched matmul with arbitrary strides
|
||||
// Like KernelBatchMatVec but handles non-contiguous K strides (e.g., transposed
|
||||
// inputs) and non-uniform batch strides (e.g., GQA expansion). One block of 256
|
||||
// threads per output element; threads cooperatively reduce along K.
|
||||
// =============================================================================
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelBatchMatMul {
|
||||
out_shape: Vec<Expression>,
|
||||
k_dim: Expression,
|
||||
a_stride: Vec<Expression>,
|
||||
a_k_stride: Expression,
|
||||
b_stride: Vec<Expression>,
|
||||
b_k_stride: Expression,
|
||||
out_stride: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelBatchMatMul {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelBatchMatMul",
|
||||
&[
|
||||
("out_shape", ELIST),
|
||||
("k_dim", EXPRESSION),
|
||||
("a_stride", ELIST),
|
||||
("a_k_stride", EXPRESSION),
|
||||
("b_stride", ELIST),
|
||||
("b_k_stride", EXPRESSION),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
; Match Mul node (broadcast multiply)
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Output shape must have 3+ dimensions (batched)
|
||||
(= ?out_shape (ECons ?batch_or_d0 (ECons ?d1 (ECons ?d2 ?rest))))
|
||||
|
||||
; k_stride must be contiguous in the Sum output
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; K must be > 1 (K=1 is a degenerate outer product, not a real matmul)
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
; Get A's and B's k-dimension strides (no contiguity requirement)
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 1))
|
||||
|
||||
; One of A's non-k strides must be 0 (broadcast along n)
|
||||
(= (MNum 0) (nth_from_end ?a_stride 0))
|
||||
|
||||
; One of B's non-k strides must be 0 (broadcast along m)
|
||||
(= (MNum 0) (nth_from_end ?b_stride 2))
|
||||
|
||||
; Must be F32
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?a_kern_stride (RemoveNthFromEnd ?a_stride 1))
|
||||
(let ?b_kern_stride (RemoveNthFromEnd ?b_stride 1))
|
||||
|
||||
(let ?bmm (Op (KernelBatchMatMul
|
||||
?out_shape ?k
|
||||
?a_kern_stride ?a_k_stride
|
||||
?b_kern_stride ?b_k_stride
|
||||
?sum_out_stride (F32)) (ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?bmm)
|
||||
(set (dtype ?bmm) (F32))
|
||||
)
|
||||
:ruleset matmul_backend
|
||||
:name \"batch matmul\"
|
||||
)"
|
||||
)]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
k_dim: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
|
||||
a_stride: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
a_k_stride: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
|
||||
b_stride: extract_expr_list(egraph, kind_children[4], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
b_k_stride: extract_expr(egraph, kind_children[5], expr_cache).unwrap(),
|
||||
out_stride: extract_expr_list(egraph, kind_children[6], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[7]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelBatchMatMul {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars: FxHashSet<char> = self
|
||||
.out_shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.a_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.b_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.out_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.k_dim.dyn_vars())
|
||||
.chain(self.a_k_stride.dyn_vars())
|
||||
.chain(self.b_k_stride.dyn_vars())
|
||||
.collect();
|
||||
|
||||
let n_outputs: Expression = self.out_shape.iter().copied().product();
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
let a_idx = flatten_strides(&self.out_shape, &self.a_stride).to_kernel();
|
||||
let b_idx = flatten_strides(&self.out_shape, &self.b_stride).to_kernel();
|
||||
let out_idx = flatten_strides(&self.out_shape, &self.out_stride).to_kernel();
|
||||
let k_expr = self.k_dim.to_kernel();
|
||||
let a_k_stride_expr = self
|
||||
.a_k_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel();
|
||||
let b_k_stride_expr = self
|
||||
.b_k_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel();
|
||||
|
||||
let kernel = format!(
|
||||
"
|
||||
#define WARP_SIZE 32
|
||||
#define THREADS_PER_BLOCK 256
|
||||
#define FULL_MASK 0xffffffff
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void batch_matmul(float *out, const float *A, const float *B{dyn_dims_param}) {{
|
||||
__shared__ float warp_sums[THREADS_PER_BLOCK / WARP_SIZE];
|
||||
long long const_z = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
int lane_id = tid % WARP_SIZE;
|
||||
int warp_id = tid / WARP_SIZE;
|
||||
|
||||
long long a_base = {a_idx};
|
||||
long long b_base = {b_idx};
|
||||
long long K = {k_expr};
|
||||
long long a_k_stride = {a_k_stride_expr};
|
||||
long long b_k_stride = {b_k_stride_expr};
|
||||
|
||||
float partial = 0.0f;
|
||||
for (long long k = tid; k < K; k += THREADS_PER_BLOCK) {{
|
||||
partial += A[a_base + k * a_k_stride] * B[b_base + k * b_k_stride];
|
||||
}}
|
||||
|
||||
#pragma unroll
|
||||
for (int s = WARP_SIZE / 2; s > 0; s /= 2) {{
|
||||
partial += __shfl_down_sync(FULL_MASK, partial, s);
|
||||
}}
|
||||
|
||||
if (lane_id == 0) {{
|
||||
warp_sums[warp_id] = partial;
|
||||
}}
|
||||
__syncthreads();
|
||||
|
||||
if (warp_id == 0) {{
|
||||
int cnt = THREADS_PER_BLOCK / WARP_SIZE;
|
||||
float block_sum = tid < cnt ? warp_sums[tid] : 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int s = cnt / 2; s > 0; s /= 2) {{
|
||||
block_sum += __shfl_down_sync(FULL_MASK, block_sum, s);
|
||||
}}
|
||||
|
||||
if (tid == 0) {{
|
||||
out[{out_idx}] = block_sum;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("batch_matmul").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(n_outputs, 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
32.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
let n = self.output_size();
|
||||
n * self.k_dim * 2 * 4
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.output_size() * self.k_dim * 2
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"BatchMatMul"
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// KernelSoftmax: Fused softmax over last dimension
|
||||
// Matches: Mul(Recip(Sum(Exp2(Sub(x, Max(x))))), Exp2(Sub(x, Max(x))))
|
||||
|
||||
@@ -342,8 +342,7 @@ impl CudaGraphOp {
|
||||
"Constant" | "Iota" => Some(0),
|
||||
"MaxReduce" | "MeanReduce" | "SumReduce" | "Cast" | "Exp" | "Exp2" | "Log2" | "Sin"
|
||||
| "Recip" | "Sigmoid" | "Softmax" | "Sqrt" => Some(1),
|
||||
"Add" | "BatchMatMul" | "BatchMatVec" | "Embed" | "Gather" | "LessThan" | "Mod"
|
||||
| "Mul" => Some(2),
|
||||
"Add" | "Embed" | "Gather" | "GenericMatmul" | "LessThan" | "Mod" | "Mul" => Some(2),
|
||||
"Scatter" | "ScatterNoCopy" => Some(3),
|
||||
_ => None,
|
||||
}
|
||||
|
||||
@@ -80,6 +80,14 @@ struct PlannedBuffer {
|
||||
end: usize,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct NonFiniteBufferReport {
|
||||
pub(crate) node: NodeIndex,
|
||||
pub(crate) index: usize,
|
||||
pub(crate) value: f32,
|
||||
}
|
||||
|
||||
/// Per-bucket compiled state. Each bucket holds its own executable graph,
|
||||
/// explicit runtime metadata, intermediate buffers, and node mappings.
|
||||
/// Weights (hlir_buffers) are shared.
|
||||
@@ -106,6 +114,9 @@ pub(crate) struct CompiledBucket {
|
||||
pub(crate) bucket_indices: FxHashMap<char, usize>,
|
||||
/// Whether HLIR pointers have been synced into this bucket's cached_buffer_ptrs
|
||||
pub(crate) hlir_synced: bool,
|
||||
/// Test/debug mode: give every intermediate a distinct arena range so
|
||||
/// post-execution diagnostics can inspect expired nodes without reuse noise.
|
||||
pub(crate) preserve_intermediate_buffers_for_debug: bool,
|
||||
}
|
||||
|
||||
impl CompiledBucket {
|
||||
@@ -130,6 +141,7 @@ impl CompiledBucket {
|
||||
intermediate_buffer_dims: FxHashSet::default(),
|
||||
bucket_indices: FxHashMap::default(),
|
||||
hlir_synced: false,
|
||||
preserve_intermediate_buffers_for_debug: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -228,6 +240,93 @@ impl CudaRuntime {
|
||||
dst
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn first_nonfinite_f32_buffer_in_nodes(
|
||||
&self,
|
||||
nodes: impl IntoIterator<Item = NodeIndex>,
|
||||
) -> Option<NonFiniteBufferReport> {
|
||||
let _ = self.cuda_stream.synchronize();
|
||||
let bucket = self.active();
|
||||
let mut checked = FxHashSet::default();
|
||||
|
||||
for node in nodes {
|
||||
let spec_node = resolve_logical_buffer_node(
|
||||
node,
|
||||
&bucket.logical_buffer_bytes,
|
||||
&bucket.output_alias_map,
|
||||
)
|
||||
.unwrap_or(node);
|
||||
if !checked.insert(spec_node) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(spec) = bucket.buffer_specs.get(&spec_node) else {
|
||||
continue;
|
||||
};
|
||||
if !matches!(spec.dtype, DType::F32) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(buf) = Self::resolve_runtime_buffer(
|
||||
bucket,
|
||||
&self.cuda_stream,
|
||||
&self.hlir_buffers,
|
||||
&self.external_buffers,
|
||||
&self.external_output_buffers,
|
||||
spec_node,
|
||||
) else {
|
||||
continue;
|
||||
};
|
||||
if buf.is_empty() || buf.len() % std::mem::size_of::<f32>() != 0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let host_bytes = match buf.clone_dtoh(&self.cuda_stream) {
|
||||
Ok(bytes) => bytes,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let values: &[f32] = bytemuck::cast_slice(&host_bytes);
|
||||
if let Some((index, value)) = values
|
||||
.iter()
|
||||
.copied()
|
||||
.enumerate()
|
||||
.find(|(_, value)| !value.is_finite())
|
||||
{
|
||||
return Some(NonFiniteBufferReport {
|
||||
node: spec_node,
|
||||
index,
|
||||
value,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn first_nonfinite_f32_buffer(&self) -> Option<NonFiniteBufferReport> {
|
||||
let bucket = self.active();
|
||||
self.first_nonfinite_f32_buffer_in_nodes(
|
||||
bucket
|
||||
.buffer_specs
|
||||
.keys()
|
||||
.copied()
|
||||
.sorted_by_key(|node| node.index()),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn preserve_intermediate_buffers_for_debug(&mut self) {
|
||||
for bucket in &mut self.compiled_buckets {
|
||||
bucket.preserve_intermediate_buffers_for_debug = true;
|
||||
bucket.logical_buffer_offsets.clear();
|
||||
bucket.logical_buffer_bytes.clear();
|
||||
bucket.cached_buffer_ptrs.clear();
|
||||
bucket.arena = None;
|
||||
bucket.arena_bytes = 0;
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_runtime_buffer(
|
||||
bucket: &CompiledBucket,
|
||||
stream: &Arc<CudaStream>,
|
||||
@@ -898,6 +997,32 @@ impl CudaRuntime {
|
||||
let planned_logical_bytes = planned.iter().map(|buf| buf.bytes).sum::<usize>();
|
||||
let logical_peak = logical_interval_peak(&planned);
|
||||
|
||||
if bucket.preserve_intermediate_buffers_for_debug {
|
||||
planned.sort_by_key(|buf| buf.node.index());
|
||||
let mut arena_end = 0usize;
|
||||
for buf in &planned {
|
||||
let offset = align_up(arena_end, ARENA_ALIGNMENT);
|
||||
bucket.logical_buffer_offsets.insert(buf.node, offset);
|
||||
bucket.logical_buffer_bytes.insert(buf.node, buf.bytes);
|
||||
arena_end = offset + align_up(buf.bytes, ARENA_ALIGNMENT);
|
||||
}
|
||||
bucket.arena_bytes = arena_end;
|
||||
|
||||
if std::env::var_os("LUMINAL_CUDA_MEMORY_DEBUG").is_some() {
|
||||
eprintln!(
|
||||
" CUDA memory plan specs={total_spec_count} used={planned_logical_count} skipped={} spec_bytes={} used_bytes={} skipped_bytes={} logical_peak={} preserved_arena={} allocations={}",
|
||||
total_spec_count.saturating_sub(planned_logical_count),
|
||||
total_spec_bytes,
|
||||
planned_logical_bytes,
|
||||
total_spec_bytes.saturating_sub(planned_logical_bytes),
|
||||
logical_peak,
|
||||
bucket.arena_bytes,
|
||||
bucket.logical_buffer_offsets.len(),
|
||||
);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
let mut arena_end = 0usize;
|
||||
let mut placed: Vec<(usize, usize, usize, usize)> = Vec::with_capacity(planned.len());
|
||||
let mut placement_order = planned.iter().collect_vec();
|
||||
@@ -1409,6 +1534,35 @@ impl Runtime for CudaRuntime {
|
||||
.filter(|n| n.to_dialect::<dyn HostOp>().is_some())
|
||||
.count()
|
||||
);
|
||||
let display = if std::env::var_os("LUMINAL_SEARCH_OP_NAMES").is_some() {
|
||||
let mut kernel_counts = std::collections::BTreeMap::<&'static str, usize>::new();
|
||||
let mut host_counts = std::collections::BTreeMap::<String, usize>::new();
|
||||
for node in llir_graph.node_weights() {
|
||||
if let Some(kernel) = node.to_dialect::<dyn KernelOp>() {
|
||||
*kernel_counts.entry(kernel.kernel_name()).or_default() += 1;
|
||||
}
|
||||
if let Some(host) = node.to_dialect::<dyn HostOp>() {
|
||||
let debug = format!("{:?}", host.as_ref().as_ref());
|
||||
let name = debug
|
||||
.split([' ', '{', '('])
|
||||
.next()
|
||||
.unwrap_or("HostOp")
|
||||
.to_string();
|
||||
*host_counts.entry(name).or_default() += 1;
|
||||
}
|
||||
}
|
||||
let kernel_summary = kernel_counts
|
||||
.iter()
|
||||
.map(|(name, count)| format!("{name}:{count}"))
|
||||
.join(",");
|
||||
let host_summary = host_counts
|
||||
.iter()
|
||||
.map(|(name, count)| format!("{name}:{count}"))
|
||||
.join(",");
|
||||
format!("{display} [Kernels: {kernel_summary}] [Hosts: {host_summary}]")
|
||||
} else {
|
||||
display
|
||||
};
|
||||
|
||||
(duration, display)
|
||||
}
|
||||
@@ -1534,6 +1688,21 @@ impl Runtime for CudaRuntime {
|
||||
exec_op.internal.stats_name().unwrap_or("unknown")
|
||||
);
|
||||
});
|
||||
|
||||
#[cfg(test)]
|
||||
if std::env::var_os("LUMINAL_CUDA_CHECK_NONFINITE_INTERNAL").is_some() {
|
||||
let mut produced_nodes = exec_op.internal.extra_buffer_nodes();
|
||||
produced_nodes.push(exec_op.output);
|
||||
if let Some(report) = self.first_nonfinite_f32_buffer_in_nodes(produced_nodes) {
|
||||
panic!(
|
||||
"CUDA execute produced non-finite buffer after {:?}: node={} index={} value={}",
|
||||
exec_op.internal.stats_name().unwrap_or("unknown"),
|
||||
report.node.index(),
|
||||
report.index,
|
||||
report.value
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Single sync at end - CUDA stream ordering guarantees sequential execution
|
||||
self.cuda_stream.synchronize().unwrap();
|
||||
|
||||
169
crates/luminal_cuda_lite/src/tests/generic_matmul_rewrite.rs
Normal file
169
crates/luminal_cuda_lite/src/tests/generic_matmul_rewrite.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
NodeId, SerializedEGraph, egglog_to_llir, random_initial_choice, validate_choice_set,
|
||||
},
|
||||
prelude::*,
|
||||
};
|
||||
use rand::{SeedableRng, rngs::StdRng};
|
||||
|
||||
use crate::{kernel::KernelOp, runtime::CudaRuntime};
|
||||
|
||||
use super::utilities::{assert_close, get_cuda_stream};
|
||||
|
||||
#[test]
|
||||
fn generic_matmul_covers_noncontiguous_merged_head_projection() {
|
||||
let mut cx = Graph::default();
|
||||
let heads = 3;
|
||||
let seq = 4;
|
||||
let head_dim = 5;
|
||||
let hidden = heads * head_dim;
|
||||
let out_dim = 7;
|
||||
|
||||
let attn = cx.tensor((heads, seq, head_dim));
|
||||
let weight = cx.tensor((out_dim, hidden));
|
||||
let merged = attn.transpose(0, 1).merge_dims(1, 2);
|
||||
merged.matmul(weight.t()).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let llir = extract_forced_kernel_llir(&mut cx, "GenericMatmul");
|
||||
let names = llir_kernel_names(&llir);
|
||||
|
||||
assert!(
|
||||
names.contains(&"GenericMatmul"),
|
||||
"expected generic matmul fallback, kernels: {names:?}"
|
||||
);
|
||||
assert!(
|
||||
!names.contains(&"Mul") && !names.contains(&"SumReduce"),
|
||||
"generic matmul should prune the broadcast multiply/sum fallback, kernels: {names:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generic_matmul_executes_noncontiguous_merged_head_projection() {
|
||||
let mut cx = Graph::default();
|
||||
let heads = 3;
|
||||
let seq = 4;
|
||||
let head_dim = 5;
|
||||
let hidden = heads * head_dim;
|
||||
let out_dim = 7;
|
||||
|
||||
let attn = cx.tensor((heads, seq, head_dim));
|
||||
let weight = cx.tensor((out_dim, hidden));
|
||||
let merged = attn.transpose(0, 1).merge_dims(1, 2);
|
||||
let output = merged.matmul(weight.t()).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let stream = get_cuda_stream().expect("CUDA device required for GenericMatmul execution test");
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
let attn_data = seeded_data(heads * seq * head_dim, 0.19, -0.09);
|
||||
let weight_data = seeded_data(out_dim * hidden, 0.14, -0.06);
|
||||
rt.set_data(attn, attn_data.as_slice());
|
||||
rt.set_data(weight, weight_data.as_slice());
|
||||
|
||||
rt = cx.search(rt, 1);
|
||||
assert!(
|
||||
rt.kernel_names().contains(&"GenericMatmul"),
|
||||
"expected GenericMatmul to be selected, kernels: {:?}",
|
||||
rt.kernel_names()
|
||||
);
|
||||
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(output.id);
|
||||
|
||||
let mut expected = vec![0.0; seq * out_dim];
|
||||
for token in 0..seq {
|
||||
for out_col in 0..out_dim {
|
||||
let mut sum = 0.0;
|
||||
for inner in 0..hidden {
|
||||
let head = inner / head_dim;
|
||||
let dim = inner % head_dim;
|
||||
let attn_idx = head * seq * head_dim + token * head_dim + dim;
|
||||
sum += attn_data[attn_idx] * weight_data[out_col * hidden + inner];
|
||||
}
|
||||
expected[token * out_dim + out_col] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
assert_close(&result, &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
fn seeded_data(len: usize, scale: f32, bias: f32) -> Vec<f32> {
|
||||
(0..len)
|
||||
.map(|i| {
|
||||
let x = ((i * 37 + 11) % 97) as f32 / 97.0;
|
||||
x * scale + bias
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn extract_forced_kernel_llir(cx: &mut Graph, kernel_name: &str) -> LLIRGraph {
|
||||
let egraph = cx.egraph().expect("search space should have an e-graph");
|
||||
let ops = cx
|
||||
.egglog_ops()
|
||||
.expect("search space should have registered egglog ops");
|
||||
let kernel_nodes = op_ir_nodes(egraph, kernel_name);
|
||||
assert!(
|
||||
!kernel_nodes.is_empty(),
|
||||
"expected at least one {kernel_name} candidate"
|
||||
);
|
||||
|
||||
for (idx, kernel_node) in kernel_nodes.iter().enumerate() {
|
||||
let mut rng = StdRng::seed_from_u64(0x9E_EE_0000 + idx as u64);
|
||||
let mut choices = random_initial_choice(egraph, &mut rng);
|
||||
let kernel_class = &egraph.node_to_class[*kernel_node];
|
||||
choices.insert(kernel_class, kernel_node);
|
||||
|
||||
if validate_choice_set(egraph, &choices, ops).is_err() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
let llir = egglog_to_llir(
|
||||
egraph,
|
||||
choices,
|
||||
ops,
|
||||
&cx.custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
if llir_kernel_names(&llir).contains(&kernel_name) {
|
||||
return llir;
|
||||
}
|
||||
}
|
||||
|
||||
panic!("could not extract a valid {kernel_name} candidate");
|
||||
}
|
||||
|
||||
fn llir_kernel_names(llir: &LLIRGraph) -> Vec<&'static str> {
|
||||
llir.node_indices()
|
||||
.filter_map(|node| {
|
||||
llir[node]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.map(|kernel| kernel.kernel_name())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn op_ir_nodes<'a>(egraph: &'a SerializedEGraph, kind_label: &str) -> Vec<&'a NodeId> {
|
||||
let op_kind_classes = egraph
|
||||
.enodes
|
||||
.iter()
|
||||
.filter(|(_, (label, _))| label == kind_label)
|
||||
.map(|(node, _)| egraph.node_to_class[node].clone())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
egraph
|
||||
.enodes
|
||||
.iter()
|
||||
.filter_map(|(node, (label, children))| {
|
||||
(label == "Op"
|
||||
&& children
|
||||
.first()
|
||||
.is_some_and(|kind| op_kind_classes.contains(kind)))
|
||||
.then_some(node)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
@@ -13,6 +13,8 @@ mod flashinfer;
|
||||
#[cfg(test)]
|
||||
mod fusion;
|
||||
#[cfg(test)]
|
||||
mod generic_matmul_rewrite;
|
||||
#[cfg(test)]
|
||||
mod model_fuzz;
|
||||
#[cfg(test)]
|
||||
mod op_functional_tests;
|
||||
|
||||
@@ -2,10 +2,7 @@ use half::bf16;
|
||||
use luminal::{dtype::DType, prelude::*, shape::Expression};
|
||||
|
||||
use super::utilities::{assert_close, get_cuda_stream, random_f32_vec};
|
||||
use crate::{
|
||||
host::moe::{GLUMoE, GLUMoEMode},
|
||||
runtime::CudaRuntime,
|
||||
};
|
||||
use crate::{host::moe::GLUMoE, runtime::CudaRuntime};
|
||||
|
||||
const SEQ: usize = 2;
|
||||
const HIDDEN: usize = 32;
|
||||
@@ -173,25 +170,44 @@ fn gemma_gelu(x: GraphTensor) -> GraphTensor {
|
||||
x * scaled.sigmoid()
|
||||
}
|
||||
|
||||
fn glumoe_modes(rt: &CudaRuntime) -> Vec<GLUMoEMode> {
|
||||
rt.host_ops()
|
||||
.into_iter()
|
||||
.filter_map(|op| {
|
||||
op.as_any()
|
||||
.downcast_ref::<GLUMoE>()
|
||||
.map(|glumoe| glumoe.mode)
|
||||
})
|
||||
.collect()
|
||||
fn search_space_contains(cx: &Graph, op_name: &str) -> bool {
|
||||
let egraph = cx.egraph().expect("test should build an e-graph");
|
||||
|
||||
for (label, children) in egraph.enodes.values() {
|
||||
if label != "Op" {
|
||||
continue;
|
||||
}
|
||||
let Some(kind_eclass) = children.first() else {
|
||||
continue;
|
||||
};
|
||||
let Some((_, kind_enodes)) = egraph.eclasses.get(kind_eclass) else {
|
||||
continue;
|
||||
};
|
||||
if kind_enodes
|
||||
.iter()
|
||||
.any(|kind_node| egraph.enodes[kind_node].0 == op_name)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn run_qwen_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
|
||||
fn assert_glumoe_in_search_space(cx: &Graph) {
|
||||
assert!(
|
||||
search_space_contains(cx, "GLUMoE"),
|
||||
"GLUMoE was not in the e-graph search space"
|
||||
);
|
||||
}
|
||||
|
||||
fn run_qwen_moe(include_glumoe: bool) -> Vec<f32> {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return (vec![], vec![]);
|
||||
return vec![];
|
||||
};
|
||||
|
||||
let mut model = build_qwen_moe_graph();
|
||||
model.graph.set_dim('s', SEQ);
|
||||
if use_glumoe {
|
||||
if include_glumoe {
|
||||
model.graph.build_search_space::<CudaRuntime>();
|
||||
} else {
|
||||
model
|
||||
@@ -218,17 +234,17 @@ fn run_qwen_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
|
||||
rt = model.graph.search(rt, 10);
|
||||
rt.execute(&model.graph.dyn_map);
|
||||
|
||||
(rt.get_f32(model.output.id), glumoe_modes(&rt))
|
||||
rt.get_f32(model.output.id)
|
||||
}
|
||||
|
||||
fn run_gemma_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
|
||||
fn run_gemma_moe(include_glumoe: bool) -> Vec<f32> {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return (vec![], vec![]);
|
||||
return vec![];
|
||||
};
|
||||
|
||||
let mut model = build_gemma_moe_graph();
|
||||
model.graph.set_dim('s', SEQ);
|
||||
if use_glumoe {
|
||||
if include_glumoe {
|
||||
model.graph.build_search_space::<CudaRuntime>();
|
||||
} else {
|
||||
model
|
||||
@@ -261,51 +277,51 @@ fn run_gemma_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
|
||||
rt = model.graph.search(rt, 10);
|
||||
rt.execute(&model.graph.dyn_map);
|
||||
|
||||
(rt.get_f32(model.output.id), glumoe_modes(&rt))
|
||||
rt.get_f32(model.output.id)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_matches_qwen_swiglu_pattern() {
|
||||
let (_result, modes) = run_qwen_moe(true);
|
||||
if modes.is_empty() {
|
||||
if get_cuda_stream().is_none() {
|
||||
return;
|
||||
}
|
||||
|
||||
assert_eq!(modes, vec![GLUMoEMode::SwiGLUNormalized]);
|
||||
let mut model = build_qwen_moe_graph();
|
||||
model.graph.set_dim('s', SEQ);
|
||||
model.graph.build_search_space::<CudaRuntime>();
|
||||
assert_glumoe_in_search_space(&model.graph);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_matches_gemma_gelu_pattern() {
|
||||
let (_result, modes) = run_gemma_moe(true);
|
||||
if modes.is_empty() {
|
||||
if get_cuda_stream().is_none() {
|
||||
return;
|
||||
}
|
||||
|
||||
assert_eq!(modes, vec![GLUMoEMode::GemmaGELU]);
|
||||
let mut model = build_gemma_moe_graph();
|
||||
model.graph.set_dim('s', SEQ);
|
||||
model.graph.build_search_space::<CudaRuntime>();
|
||||
assert_glumoe_in_search_space(&model.graph);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_swiglu_matches_unfused_output() {
|
||||
let (expected, baseline_modes) = run_qwen_moe(false);
|
||||
let expected = run_qwen_moe(false);
|
||||
if expected.is_empty() {
|
||||
return;
|
||||
}
|
||||
assert!(baseline_modes.is_empty());
|
||||
|
||||
let (actual, fused_modes) = run_qwen_moe(true);
|
||||
assert_eq!(fused_modes, vec![GLUMoEMode::SwiGLUNormalized]);
|
||||
let actual = run_qwen_moe(true);
|
||||
assert_close(&actual, &expected, 3e-2, 3e-2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_gemma_gelu_matches_unfused_output() {
|
||||
let (expected, baseline_modes) = run_gemma_moe(false);
|
||||
let expected = run_gemma_moe(false);
|
||||
if expected.is_empty() {
|
||||
return;
|
||||
}
|
||||
assert!(baseline_modes.is_empty());
|
||||
|
||||
let (actual, fused_modes) = run_gemma_moe(true);
|
||||
assert_eq!(fused_modes, vec![GLUMoEMode::GemmaGELU]);
|
||||
let actual = run_gemma_moe(true);
|
||||
assert_close(&actual, &expected, 3e-2, 3e-2);
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
//!
|
||||
//! These tests do not compare against a hand-written reference. They assert the
|
||||
//! stronger search invariant: every selectable LLIR graph from the same e-graph
|
||||
//! must produce the same outputs for the same runtime inputs.
|
||||
//! must produce finite, numerically close outputs for the same runtime inputs.
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[path = "../../../../examples/llama/src/model.rs"]
|
||||
@@ -93,7 +93,7 @@ fn llama_architecture_search_space_equivalence_fuzz() {
|
||||
.generation_size(8)
|
||||
.mutations(3)
|
||||
.build_options(BuildSearchSpaceOptions::new().max_memory_mib(512))
|
||||
.output_f32(logits.id, "logits", 3e-3, 3e-3);
|
||||
.output_f32(logits.id, "logits", 5e-2, 5e-2);
|
||||
for (layer, (k_out, v_out)) in cache_outputs.into_iter().enumerate() {
|
||||
let k_out = k_out.output();
|
||||
let v_out = v_out.output();
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
use candle_core::{Device, Tensor, WithDType};
|
||||
use cudarc::driver::CudaContext;
|
||||
use half::{bf16, f16};
|
||||
use itertools::Itertools;
|
||||
use luminal::egglog_utils::{
|
||||
EGraphChoiceSet, egglog_to_llir, extract_generation, hash_choice_set, random_initial_choice,
|
||||
validate_choice_set,
|
||||
};
|
||||
use luminal::prelude::*;
|
||||
use luminal::prelude::{
|
||||
petgraph::{Direction, algo::toposort, visit::EdgeRef},
|
||||
*,
|
||||
};
|
||||
use num_traits::{Num, Signed};
|
||||
use rand::{Rng, SeedableRng, rngs::StdRng};
|
||||
use std::sync::Arc;
|
||||
@@ -210,6 +214,11 @@ pub struct SearchEquivalenceFuzzReport {
|
||||
pub skipped_invalid: usize,
|
||||
}
|
||||
|
||||
struct ChoiceRun {
|
||||
outputs: Vec<Vec<f32>>,
|
||||
llir_summary: String,
|
||||
}
|
||||
|
||||
pub struct CudaSearchEquivalenceFuzzer<'a> {
|
||||
cx: &'a mut Graph,
|
||||
stream: &'a Arc<cudarc::driver::CudaStream>,
|
||||
@@ -302,7 +311,8 @@ impl<'a> CudaSearchEquivalenceFuzzer<'a> {
|
||||
/// LLIR graphs, runs each with identical inputs, and verifies every requested
|
||||
/// f32 output matches the first valid extraction. The reference is intentionally
|
||||
/// another selected LLIR graph, not a hand-written CPU implementation: this
|
||||
/// catches cases where supposedly equivalent e-graph choices diverge.
|
||||
/// catches cases where supposedly equivalent e-graph choices diverge, including
|
||||
/// candidates that produce non-finite outputs.
|
||||
pub fn fuzz_cuda_search_space_equivalence(
|
||||
cx: &mut Graph,
|
||||
stream: &Arc<cudarc::driver::CudaStream>,
|
||||
@@ -354,12 +364,12 @@ pub fn fuzz_cuda_search_space_equivalence(
|
||||
|
||||
let mut skipped_invalid = 0usize;
|
||||
let reference_is_cuda = native_reference_outputs.is_none();
|
||||
let (reference_hash, reference_outputs, mut tested) =
|
||||
let (reference_hash, reference_outputs, reference_llir_summary, mut tested) =
|
||||
if let Some(reference_outputs) = native_reference_outputs {
|
||||
(0, reference_outputs, 0usize)
|
||||
(0, reference_outputs, None, 0usize)
|
||||
} else {
|
||||
let mut attempts = 0usize;
|
||||
let (reference_hash, reference_outputs) = loop {
|
||||
let (reference_hash, reference_run) = loop {
|
||||
attempts += 1;
|
||||
if attempts > config.max_attempts {
|
||||
panic!(
|
||||
@@ -372,17 +382,19 @@ pub fn fuzz_cuda_search_space_equivalence(
|
||||
} else {
|
||||
let hash = hash_choice_set(&base);
|
||||
match run_choice_outputs(cx, stream, inputs, outputs, &base) {
|
||||
Ok(values) => break (hash, values),
|
||||
Err(err) => {
|
||||
skipped_invalid += 1;
|
||||
eprintln!("skipping invalid reference candidate hash={hash}: {err}");
|
||||
}
|
||||
Ok(run) => break (hash, run),
|
||||
Err(err) => panic!("reference candidate hash={hash} failed: {err}"),
|
||||
}
|
||||
}
|
||||
base = random_initial_choice(egraph, &mut rng);
|
||||
prev_selected.insert(hash_choice_set(&base));
|
||||
};
|
||||
(reference_hash, reference_outputs, 1usize)
|
||||
(
|
||||
reference_hash,
|
||||
reference_run.outputs,
|
||||
Some(reference_run.llir_summary),
|
||||
1usize,
|
||||
)
|
||||
};
|
||||
|
||||
let mut attempts = 0usize;
|
||||
@@ -415,12 +427,14 @@ pub fn fuzz_cuda_search_space_equivalence(
|
||||
continue;
|
||||
}
|
||||
|
||||
let candidate_outputs = run_choice_outputs(cx, stream, inputs, outputs, &candidate)
|
||||
let candidate_run = run_choice_outputs(cx, stream, inputs, outputs, &candidate)
|
||||
.unwrap_or_else(|err| panic!("candidate hash={candidate_hash} failed: {err}"));
|
||||
assert_fuzz_outputs_close(
|
||||
outputs,
|
||||
&reference_outputs,
|
||||
&candidate_outputs,
|
||||
&candidate_run.outputs,
|
||||
&candidate_run.llir_summary,
|
||||
reference_llir_summary.as_deref(),
|
||||
reference_hash,
|
||||
candidate_hash,
|
||||
);
|
||||
@@ -446,7 +460,7 @@ fn run_choice_outputs<'a>(
|
||||
inputs: &[CudaFuzzInput],
|
||||
outputs: &[F32OutputCheck],
|
||||
choices: &EGraphChoiceSet<'a>,
|
||||
) -> Result<Vec<Vec<f32>>, String> {
|
||||
) -> Result<ChoiceRun, String> {
|
||||
let egraph = cx.egraph().ok_or("search space was not built")?;
|
||||
let ops = cx.egglog_ops().ok_or("search ops were not built")?;
|
||||
let mut list_cache = FxHashMap::default();
|
||||
@@ -461,21 +475,86 @@ fn run_choice_outputs<'a>(
|
||||
None,
|
||||
);
|
||||
unroll_loops_in_llir(&mut llir_graph);
|
||||
let llir_summary = summarize_llir(&llir_graph);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.load_llir(&llir_graph);
|
||||
rt.preserve_intermediate_buffers_for_debug();
|
||||
for input in inputs {
|
||||
input.apply(&mut rt);
|
||||
}
|
||||
if std::env::var_os("LUMINAL_FUZZ_DUMP_LAST_LLIR").is_some() {
|
||||
let _ = std::fs::write("/tmp/luminal_fuzz_last_candidate_llir.txt", &llir_summary);
|
||||
}
|
||||
rt.execute(&cx.dyn_map);
|
||||
let topo_order = toposort(&llir_graph, None).map_err(|cycle| {
|
||||
format!(
|
||||
"extracted LLIR contains cycle at node {:?}",
|
||||
cycle.node_id()
|
||||
)
|
||||
})?;
|
||||
if let Some(report) = rt.first_nonfinite_f32_buffer_in_nodes(topo_order) {
|
||||
let dump_path = "/tmp/luminal_fuzz_bad_candidate_llir.txt";
|
||||
let _ = std::fs::write(dump_path, &llir_summary);
|
||||
let op = llir_graph
|
||||
.node_weight(report.node)
|
||||
.map(|op| format!("{op:?}"))
|
||||
.unwrap_or_else(|| "unknown op".to_string());
|
||||
return Err(format!(
|
||||
"LLIR produced non-finite F32 buffer node={} index={} value={} op={}; llir={dump_path}",
|
||||
report.node.index(),
|
||||
report.index,
|
||||
report.value,
|
||||
op
|
||||
));
|
||||
}
|
||||
|
||||
Ok(outputs.iter().map(|out| rt.get_f32(out.id)).collect())
|
||||
let values = outputs
|
||||
.iter()
|
||||
.map(|out| rt.get_f32(out.id))
|
||||
.collect::<Vec<_>>();
|
||||
for (spec, values) in outputs.iter().zip(&values) {
|
||||
if let Some((idx, value)) = values
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, value)| !value.is_finite())
|
||||
{
|
||||
let dump_path = "/tmp/luminal_fuzz_bad_candidate_llir.txt";
|
||||
let _ = std::fs::write(dump_path, &llir_summary);
|
||||
let internal = rt
|
||||
.first_nonfinite_f32_buffer()
|
||||
.map(|report| {
|
||||
let op = llir_graph
|
||||
.node_weight(report.node)
|
||||
.map(|op| format!("{op:?}"))
|
||||
.unwrap_or_else(|| "unknown op".to_string());
|
||||
format!(
|
||||
"; first observed non-finite buffer node={} index={} value={} op={}",
|
||||
report.node.index(),
|
||||
report.index,
|
||||
report.value,
|
||||
op
|
||||
)
|
||||
})
|
||||
.unwrap_or_default();
|
||||
return Err(format!(
|
||||
"output {} produced non-finite value {value} at index {idx}{internal}; llir={dump_path}",
|
||||
spec.name
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(ChoiceRun {
|
||||
outputs: values,
|
||||
llir_summary,
|
||||
})
|
||||
}
|
||||
|
||||
fn assert_fuzz_outputs_close(
|
||||
outputs: &[F32OutputCheck],
|
||||
expected: &[Vec<f32>],
|
||||
actual: &[Vec<f32>],
|
||||
candidate_llir_summary: &str,
|
||||
reference_llir_summary: Option<&str>,
|
||||
reference_hash: u64,
|
||||
candidate_hash: u64,
|
||||
) {
|
||||
@@ -508,8 +587,16 @@ fn assert_fuzz_outputs_close(
|
||||
worst = i;
|
||||
}
|
||||
if abs > spec.atol + spec.rtol * b.abs() {
|
||||
let dump_path = "/tmp/luminal_fuzz_bad_candidate_llir.txt";
|
||||
let _ = std::fs::write(dump_path, candidate_llir_summary);
|
||||
if let Some(reference_llir_summary) = reference_llir_summary {
|
||||
let _ = std::fs::write(
|
||||
"/tmp/luminal_fuzz_bad_reference_llir.txt",
|
||||
reference_llir_summary,
|
||||
);
|
||||
}
|
||||
panic!(
|
||||
"output {} mismatch candidate hash={candidate_hash} reference hash={reference_hash} index={i} actual={a} expected={b} abs={abs} rel={rel} tolerance={}",
|
||||
"output {} mismatch candidate hash={candidate_hash} reference hash={reference_hash} index={i} actual={a} expected={b} abs={abs} rel={rel} tolerance={} candidate_llir={dump_path}",
|
||||
spec.name,
|
||||
spec.atol + spec.rtol * b.abs()
|
||||
);
|
||||
@@ -522,6 +609,22 @@ fn assert_fuzz_outputs_close(
|
||||
}
|
||||
}
|
||||
|
||||
fn summarize_llir(llir_graph: &LLIRGraph) -> String {
|
||||
llir_graph
|
||||
.node_indices()
|
||||
.map(|idx| {
|
||||
let inputs = llir_graph
|
||||
.edges_directed(idx, Direction::Incoming)
|
||||
.sorted_by_key(|edge| edge.id())
|
||||
.map(|edge| edge.source().index().to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
format!("{} <- [{}]: {:?}", idx.index(), inputs, &llir_graph[idx])
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
/// Get the GPU compute capability as (major, minor).
|
||||
pub fn gpu_compute_cap() -> Option<(i32, i32)> {
|
||||
let ctx = CudaContext::new(0).ok()?;
|
||||
|
||||
@@ -17,7 +17,7 @@ const FP8_REPO_ID: &str = "nvidia/Llama-3.1-8B-Instruct-FP8";
|
||||
const MAX_SEQ_LEN: usize = 4096;
|
||||
const GEN_TOKENS: usize = 500;
|
||||
const SEARCH_GRAPHS: usize = 500;
|
||||
const SEARCH_TRIALS: usize = 1;
|
||||
const SEARCH_TRIALS: usize = 10;
|
||||
const SEARCH_KEEP_BEST: usize = 4;
|
||||
const SEARCH_MEMORY_MIB: usize = 2048;
|
||||
const SEARCH_SEED: u64 = 0;
|
||||
|
||||
18
src/bin/examples-perf.rs
Normal file
18
src/bin/examples-perf.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
use std::{
|
||||
env,
|
||||
path::Path,
|
||||
process::{Command, ExitCode},
|
||||
};
|
||||
|
||||
fn main() -> ExitCode {
|
||||
let repo_root = env!("CARGO_MANIFEST_DIR");
|
||||
let script = Path::new(repo_root).join("ci/examples_perf.py");
|
||||
let status = Command::new("python3")
|
||||
.arg(script)
|
||||
.args(env::args_os().skip(1))
|
||||
.current_dir(repo_root)
|
||||
.status()
|
||||
.expect("failed to run python3 ci/examples_perf.py");
|
||||
|
||||
ExitCode::from(status.code().unwrap_or(1) as u8)
|
||||
}
|
||||
Reference in New Issue
Block a user