mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
118 Commits
pytest-cla
...
codex-lumi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
79d00a4827 | ||
|
|
acad3a625a | ||
|
|
4d1ff217be | ||
|
|
07ad11d101 | ||
|
|
98f4f2102b | ||
|
|
44b293bee0 | ||
|
|
f9b9657c1c | ||
|
|
6db0f716d5 | ||
|
|
d03ab816d8 | ||
|
|
61904fbc76 | ||
|
|
f461fca3da | ||
|
|
5f199e94c6 | ||
|
|
93fb02c495 | ||
|
|
16de9638fc | ||
|
|
f08d24e73f | ||
|
|
aba9627563 | ||
|
|
7d68b62aa8 | ||
|
|
13c870de86 | ||
|
|
f8b742d718 | ||
|
|
3555d169bd | ||
|
|
be74153c12 | ||
|
|
75535c93f0 | ||
|
|
84f13cae00 | ||
|
|
703c2d9ea4 | ||
|
|
2e3158c48e | ||
|
|
8af22776aa | ||
|
|
cd8c01f620 | ||
|
|
461b746937 | ||
|
|
38e467aa6c | ||
|
|
7429ac163b | ||
|
|
07c151dd70 | ||
|
|
c0f7f1f054 | ||
|
|
df96fe5110 | ||
|
|
18a550dd15 | ||
|
|
254680001d | ||
|
|
2920011897 | ||
|
|
d879376697 | ||
|
|
2be30c18cd | ||
|
|
48f921d2a1 | ||
|
|
f55e7e0589 | ||
|
|
db2027d345 | ||
|
|
9a5032bfc9 | ||
|
|
c665b01c4e | ||
|
|
883508e682 | ||
|
|
080b99b69e | ||
|
|
0bd19289ea | ||
|
|
a3b7f6ecc1 | ||
|
|
438ae460bf | ||
|
|
da440fdef0 | ||
|
|
586365be4d | ||
|
|
3c962a9df8 | ||
|
|
1a460bac96 | ||
|
|
ce06a901cc | ||
|
|
c97288cdae | ||
|
|
d66b3f2643 | ||
|
|
66b0807462 | ||
|
|
c24ea4a7a5 | ||
|
|
c309d9b4ed | ||
|
|
745c071ee5 | ||
|
|
896c4b7c7e | ||
|
|
56ffe8bbb3 | ||
|
|
13dbdcb53b | ||
|
|
0134aa425a | ||
|
|
c8ad5f8b75 | ||
|
|
51c6596f6a | ||
|
|
aef4c68537 | ||
|
|
1ac423c36c | ||
|
|
59c38b3c88 | ||
|
|
9b3b2f5244 | ||
|
|
aed7b86aad | ||
|
|
e3c6d98f36 | ||
|
|
10971d7d05 | ||
|
|
4b0bfa5669 | ||
|
|
2c0c3bb988 | ||
|
|
ca6fac8f78 | ||
|
|
900fee4d67 | ||
|
|
59901c8b12 | ||
|
|
a860a2cb6b | ||
|
|
52b2a45c62 | ||
|
|
0af1c186fd | ||
|
|
e6d13a3979 | ||
|
|
86b2784b51 | ||
|
|
773935b91b | ||
|
|
afb8d7ae4d | ||
|
|
fb23b80a01 | ||
|
|
d6a3171b7b | ||
|
|
59edd0b179 | ||
|
|
8a2fd832b6 | ||
|
|
76c0d43aa0 | ||
|
|
f99f1e10cb | ||
|
|
a5b26100ba | ||
|
|
a40f5dd386 | ||
|
|
efe746ba39 | ||
|
|
d91dce41d4 | ||
|
|
11d59a351c | ||
|
|
6d66f80340 | ||
|
|
2da5cdaa30 | ||
|
|
44520a8100 | ||
|
|
53c58576fc | ||
|
|
64e4eedcc6 | ||
|
|
cc1b448c90 | ||
|
|
63afb602b0 | ||
|
|
985e7752aa | ||
|
|
3fd7831e6d | ||
|
|
4c8bed686f | ||
|
|
cbf1ef5fc4 | ||
|
|
7a53d39852 | ||
|
|
3786977f01 | ||
|
|
1a4662ec3b | ||
|
|
2963278637 | ||
|
|
97f11a78bf | ||
|
|
27faf0819c | ||
|
|
c225d3affb | ||
|
|
ac10f82308 | ||
|
|
f2f5944f47 | ||
|
|
f9865ae2a3 | ||
|
|
46ebc58334 | ||
|
|
412147ea78 |
6
.github/workflows/modal-examples.yml
vendored
6
.github/workflows/modal-examples.yml
vendored
@@ -3,7 +3,7 @@ name: Modal Examples
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
pull_request_target:
|
||||
branches: ["main"]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
if: >-
|
||||
github.event_name == 'push'
|
||||
|| github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request'
|
||||
|| (github.event_name == 'pull_request_target'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
|
||||
name: "${{ matrix.example }} (Modal ${{ matrix.gpu.type }})"
|
||||
runs-on: ubuntu-latest
|
||||
@@ -30,6 +30,8 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
|
||||
6
.github/workflows/test-cuda.yml
vendored
6
.github/workflows/test-cuda.yml
vendored
@@ -3,7 +3,7 @@ name: Test CUDA
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
pull_request_target:
|
||||
branches: ["main"]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
if: >-
|
||||
github.event_name == 'push'
|
||||
|| github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request'
|
||||
|| (github.event_name == 'pull_request_target'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
|
||||
name: Cuda Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
@@ -22,6 +22,8 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
|
||||
6
.github/workflows/test-python-cuda.yml
vendored
6
.github/workflows/test-python-cuda.yml
vendored
@@ -3,7 +3,7 @@ name: Test Python CUDA
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
pull_request_target:
|
||||
branches: ["main"]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
if: >-
|
||||
github.event_name == 'push'
|
||||
|| github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request'
|
||||
|| (github.event_name == 'pull_request_target'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
|
||||
name: Python CUDA Tests
|
||||
runs-on: ubuntu-latest
|
||||
@@ -25,6 +25,8 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
|
||||
@@ -32,6 +32,7 @@ pretty-duration = "0.1.1"
|
||||
anyhow = "1.0"
|
||||
graphviz-rust = { version = "0.9", default-features = false}
|
||||
lru = "0.16.2"
|
||||
rayon = "1.10"
|
||||
|
||||
[workspace.package]
|
||||
edition = "2024"
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
<img href="luminal.com" alt="Screenshot 2025-08-14 at 9 18 54 PM" src="https://github.com/user-attachments/assets/c5832634-55d5-45b7-ba65-6efe36afce4a" />
|
||||
<img href="luminal.com" alt="Screenshot 2025-08-14 at 9 18 54 PM" src="https://github.com/luminal-ai/luminal/blob/main/docs/logo/inference_at_the_speed_of_light.png" />
|
||||
|
||||
<h3 align="center">
|
||||
Luminal is a high-performance general-purpose inference compiler.
|
||||
</h3>
|
||||
|
||||
[](https://github.com/jafioti/luminal/actions)
|
||||
[](https://github.com/luminal-ai/luminal/actions)
|
||||
[](https://docs.luminalai.com)
|
||||
[](https://crates.io/crates/luminal)
|
||||
[](https://discord.gg/APjuwHAbGy)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import modal
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
|
||||
gpu_type = os.environ.get("GPU_TYPE", "T4")
|
||||
CUDARC_CUDA_VERSION = "12080"
|
||||
@@ -46,8 +45,10 @@ def run_cargo_test():
|
||||
|
||||
subprocess.run(
|
||||
[
|
||||
"cargo", "test",
|
||||
"-p", "luminal_cuda_lite",
|
||||
"cargo",
|
||||
"test",
|
||||
"-p",
|
||||
"luminal_cuda_lite",
|
||||
"--verbose",
|
||||
"--",
|
||||
"--test-threads=1",
|
||||
|
||||
@@ -106,13 +106,13 @@ impl Case {
|
||||
let out = match self {
|
||||
Case::Mul => {
|
||||
let x = cx.tensor(size);
|
||||
x.clone() * x
|
||||
x * x
|
||||
}
|
||||
Case::Sigmoid => cx.tensor(size).sigmoid(),
|
||||
Case::Tanh => cx.tensor(size).tanh(),
|
||||
Case::GeluInner => {
|
||||
let x = cx.tensor(size);
|
||||
(0.797_884_560_8_f32 * x.clone() * (1. + 0.044_715_f32 * x.clone() * x)).tanh()
|
||||
(0.797_884_6_f32 * x * (1. + 0.044_715_f32 * x * x)).tanh()
|
||||
}
|
||||
Case::Gelu => cx.tensor(size).gelu(),
|
||||
Case::LayerNorm => {
|
||||
@@ -447,10 +447,10 @@ where
|
||||
if let Some(ref backend) = backend_analysis {
|
||||
print_lowering_analysis(backend);
|
||||
}
|
||||
} else if !args.inspect_ops.is_empty() {
|
||||
if let Some(ref backend) = backend_analysis {
|
||||
print_lowering_analysis(backend);
|
||||
}
|
||||
} else if !args.inspect_ops.is_empty()
|
||||
&& let Some(ref backend) = backend_analysis
|
||||
{
|
||||
print_lowering_analysis(backend);
|
||||
}
|
||||
|
||||
// Trace facts for explicit variables.
|
||||
|
||||
75
crates/luminal_cuda_lite/src/dyn_backend.rs
Normal file
75
crates/luminal_cuda_lite/src/dyn_backend.rs
Normal file
@@ -0,0 +1,75 @@
|
||||
//! [`DynBackend`] implementation for the CUDA lite runtime.
|
||||
|
||||
use luminal::dtype::DType;
|
||||
use luminal::dyn_backend::{BackendCompileArgs, DynBackend, compile_backend};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::cudarc::driver::CudaContext;
|
||||
use crate::runtime::CudaRuntime;
|
||||
|
||||
/// [`DynBackend`] wrapper for [`CudaRuntime`].
|
||||
pub struct CudaLiteDynBackend {
|
||||
pub runtime: CudaRuntime,
|
||||
}
|
||||
|
||||
impl DynBackend for CudaLiteDynBackend {
|
||||
fn name(&self) -> &str {
|
||||
"cuda_lite"
|
||||
}
|
||||
fn device_type(&self) -> &str {
|
||||
"cuda"
|
||||
}
|
||||
|
||||
fn set_data_bytes(&mut self, node: NodeIndex, bytes: Vec<u8>, _dtype: DType) {
|
||||
self.runtime.set_data(node, bytes);
|
||||
}
|
||||
fn set_data_f32(&mut self, node: NodeIndex, data: Vec<f32>) {
|
||||
self.runtime.set_data(node, data);
|
||||
}
|
||||
fn get_output_f32(&self, node: NodeIndex) -> Vec<f32> {
|
||||
self.runtime.get_f32(node)
|
||||
}
|
||||
fn get_output_i32(&self, node: NodeIndex) -> Vec<i32> {
|
||||
self.runtime.get_i32(node)
|
||||
}
|
||||
fn get_output_bool(&self, node: NodeIndex) -> Vec<bool> {
|
||||
self.runtime.get_bool(node)
|
||||
}
|
||||
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
self.runtime.execute(dyn_map);
|
||||
}
|
||||
|
||||
fn supports_device_ptrs(&self) -> bool {
|
||||
true
|
||||
}
|
||||
unsafe fn set_device_ptr(&mut self, node: NodeIndex, ptr: u64, n: usize) {
|
||||
unsafe { self.runtime.set_device_ptr(node, ptr, n) }
|
||||
}
|
||||
unsafe fn set_output_device_ptr(&mut self, node: NodeIndex, ptr: u64, n: usize) {
|
||||
unsafe { self.runtime.set_output_device_ptr(node, ptr, n) }
|
||||
}
|
||||
fn output_is_zero_copy(&self, node: NodeIndex) -> bool {
|
||||
self.runtime.output_is_zero_copy(node)
|
||||
}
|
||||
unsafe fn copy_output_to_device_ptr(&self, node: NodeIndex, ptr: u64, n: usize) {
|
||||
unsafe { self.runtime.copy_output_to_device_ptr(node, ptr, n) }
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cuda_lite_factory(
|
||||
graph: &mut Graph,
|
||||
args: BackendCompileArgs,
|
||||
) -> Result<Box<dyn DynBackend>, String> {
|
||||
let cuda_ctx = CudaContext::new(0).map_err(|e| format!("CUDA init failed: {e}"))?;
|
||||
let stream = cuda_ctx.default_stream();
|
||||
compile_backend::<CudaRuntime>(
|
||||
graph,
|
||||
args,
|
||||
|| Ok(CudaRuntime::initialize(stream)),
|
||||
|rt, node, bytes, _dtype| {
|
||||
rt.set_data(node, bytes);
|
||||
},
|
||||
Some(&|rt, node, ptr, n| unsafe { rt.set_device_ptr(node, ptr, n) }),
|
||||
|rt| Box::new(CudaLiteDynBackend { runtime: rt }),
|
||||
)
|
||||
}
|
||||
@@ -32,6 +32,7 @@ use crate::{
|
||||
driver::{CudaSlice, CudaStream, DevicePtr},
|
||||
},
|
||||
host::{HostOp, cublas::parse_cublas_op},
|
||||
try_create_cublaslt,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -248,6 +249,19 @@ fn dtype_to_cuda_types(dtype: DType) -> (cudaDataType, cublasComputeType_t, cuda
|
||||
}
|
||||
}
|
||||
|
||||
impl CuBlasLt {
|
||||
fn get_cublaslt(&self, stream: &Arc<CudaStream>) -> anyhow::Result<Arc<CudaBlasLT>> {
|
||||
if let Some(cublaslt) = self.cublaslt.get() {
|
||||
return Ok(cublaslt.clone());
|
||||
}
|
||||
let created = try_create_cublaslt(stream.clone()).map_err(|message| {
|
||||
anyhow::anyhow!("cuBLASLt unavailable on this machine: {message}")
|
||||
})?;
|
||||
let _ = self.cublaslt.set(created.clone());
|
||||
Ok(created)
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for CuBlasLt {
|
||||
fn execute(
|
||||
&self,
|
||||
@@ -324,9 +338,7 @@ impl HostOp for CuBlasLt {
|
||||
)
|
||||
.entered();
|
||||
|
||||
let cublaslt = self
|
||||
.cublaslt
|
||||
.get_or_init(|| Arc::new(CudaBlasLT::new(stream.clone()).unwrap()));
|
||||
let cublaslt = self.get_cublaslt(stream)?;
|
||||
|
||||
let mut matmul_desc: cublasLtMatmulDesc_t = std::ptr::null_mut();
|
||||
let mut a_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
|
||||
@@ -461,7 +473,8 @@ impl HostOp for CuBlasLt {
|
||||
cublasLtMatmulDescDestroy(matmul_desc);
|
||||
}
|
||||
|
||||
stream.synchronize()?;
|
||||
// No stream.synchronize() here — CUDA stream ordering guarantees
|
||||
// sequential execution. The runtime syncs once at the end of execute().
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -1,128 +1,213 @@
|
||||
; GLUMoE: Match the expert computation subgraph of a Gated MoE (SwiGLU variant).
|
||||
; GLUMoE: Match the expert computation subgraph of a gated MoE.
|
||||
;
|
||||
; This matches the pattern produced by QwenMoE::forward() starting from the
|
||||
; expert gathers through to the final weighted sum, and replaces it with a
|
||||
; fused GLUMoE HostOp.
|
||||
; One fused op supports two activation modes:
|
||||
; mode=0: Qwen-style SwiGLU (silu(gate) * up)
|
||||
; mode=1: Gemma-style GELU (gate * sigmoid(1.595769 * gate * (1 + 0.044715 * gate^2)))
|
||||
;
|
||||
; Inputs extracted:
|
||||
; ?x - input activations [s, H] F32
|
||||
; ?topk_idx - top-k expert indices [s, k] Int (from argsort+slice)
|
||||
; ?topk_vals - top-k routing values [s, k] F32 (from gather on softmax)
|
||||
; ?gate_up_w - stacked gate+up expert weights [E, intermediate*2, H] BF16
|
||||
; ?down_w - stacked down expert weights [E, H, intermediate] BF16
|
||||
;
|
||||
; The pattern captures:
|
||||
; 1. Gate-up expert gather (Iota, Mul, Cast, Iota, Cast, Add, Cast, Gather)
|
||||
; 2. Cast BF16→F32 of gathered gate-up weights
|
||||
; 3. Gate-up batched matmul (Mul + SumReduce)
|
||||
; 4. Gate/Up split via Iota+Gather (slice semantics)
|
||||
; 5. SwiGLU: silu(gate) * up
|
||||
; 6. Down expert gather (same pattern as gate-up)
|
||||
; 7. Cast BF16→F32 of gathered down weights
|
||||
; 8. Down batched matmul (Mul + SumReduce)
|
||||
; 9. Weighted sum: (down_out * topk_values) summed over k
|
||||
;
|
||||
; Variables with ? prefix are egglog pattern variables.
|
||||
; We use wildcards (?_xxx) for shapes/strides we don't extract.
|
||||
; To keep matching fast, we stage through marker states:
|
||||
; 1) Shared gate-up matmul marker
|
||||
; 2) Activation marker (separate swiglu / gemma_gelu paths)
|
||||
; 3) Down matmul marker (separate swiglu / gemma_gelu paths)
|
||||
; 4) Final GLUMoE fusion (separate swiglu / gemma_gelu rules)
|
||||
|
||||
(datatype*
|
||||
(GLUMoEGateUpState
|
||||
(MkGLUMoEGateUpState Expression Expression Expression IR IR IR)
|
||||
)
|
||||
(GLUMoESwiGLUState
|
||||
(MkGLUMoESwiGLUState GLUMoEGateUpState)
|
||||
)
|
||||
(GLUMoEGemmaGELUState
|
||||
(MkGLUMoEGemmaGELUState GLUMoEGateUpState)
|
||||
)
|
||||
(GLUMoESwiGLUDownState
|
||||
(MkGLUMoESwiGLUDownState Expression Expression Expression GLUMoESwiGLUState IR IR)
|
||||
)
|
||||
(GLUMoEGemmaDownState
|
||||
(MkGLUMoEGemmaDownState Expression Expression Expression GLUMoEGemmaGELUState IR IR)
|
||||
)
|
||||
)
|
||||
|
||||
(function glumoe_gate_up (IR) GLUMoEGateUpState :merge new)
|
||||
(function glumoe_swiglu (IR) GLUMoESwiGLUState :merge new)
|
||||
(function glumoe_gemma_gelu (IR) GLUMoEGemmaGELUState :merge new)
|
||||
(function glumoe_swiglu_down (IR) GLUMoESwiGLUDownState :merge new)
|
||||
(function glumoe_gemma_down (IR) GLUMoEGemmaDownState :merge new)
|
||||
|
||||
(rule
|
||||
(
|
||||
; ===== Gate-up expert gather =====
|
||||
; t51: Iota for base index (expert_idx * io_gu)
|
||||
(= ?gu_iota_base (Op (Iota ?gu_io ?gu_iota_base_range) (INil)))
|
||||
; t52: Mul topk_indices * io → base offsets [s, k]
|
||||
(= ?gu_mul_base (Op (Mul ?gu_mul_base_shape ?gu_mul_base_a_stride ?gu_mul_base_b_stride ?gu_mul_base_out_stride) (ICons ?topk_idx (ICons ?gu_iota_base (INil)))))
|
||||
; t53: Cast to F32
|
||||
(= ?gu_cast_base (Op (Cast ?gu_cast_base_size (F32)) (ICons ?gu_mul_base (INil))))
|
||||
; t54: Iota for within-expert index
|
||||
(= ?gu_iota_within (Op (Iota (MIter) ?gu_iota_within_range) (INil)))
|
||||
; t55: Cast within to F32
|
||||
(= ?gu_cast_within (Op (Cast ?gu_cast_within_size (F32)) (ICons ?gu_iota_within (INil))))
|
||||
; t56: Add base + within → flat gather indices
|
||||
(= ?gu_add_idx (Op (Add ?gu_add_shape ?gu_add_a_stride ?gu_add_b_stride ?gu_add_out_stride) (ICons ?gu_cast_base (ICons ?gu_cast_within (INil)))))
|
||||
; t57: Cast to Int
|
||||
(= ?gu_cast_idx (Op (Cast ?gu_cast_idx_size (Int)) (ICons ?gu_add_idx (INil))))
|
||||
; t58: Gather gate_up weights
|
||||
(= ?gu_gathered (Op (Gather ?gu_gather_idx_shape ?gu_gather_idx_stride ?gu_gather_data_shape ?gu_gather_data_stride) (ICons ?gu_cast_idx (ICons ?gate_up_w (INil)))))
|
||||
(= ?gu_add_idx (Op (Add ?gu_add_shape ?gu_add_a_stride ?gu_add_b_stride ?gu_add_out_stride) (ICons ?gu_mul_base (ICons ?gu_iota_within (INil)))))
|
||||
(= ?gu_gathered (Op (Gather ?gu_gather_idx_shape ?gu_gather_idx_stride ?gu_gather_data_shape ?gu_gather_data_stride) (ICons ?gu_add_idx (ICons ?gate_up_w (INil)))))
|
||||
|
||||
; ===== Cast BF16→F32 =====
|
||||
; t59: Cast gathered gate_up to F32
|
||||
(= ?gu_f32 (Op (Cast ?gu_f32_size (F32)) (ICons ?gu_gathered (INil))))
|
||||
|
||||
; ===== Gate-up batched matmul =====
|
||||
; t60: Mul x * gathered_gu (broadcast multiply)
|
||||
(= ?gu_matmul_mul (Op (Mul ?gu_matmul_mul_shape ?gu_matmul_a_stride ?gu_matmul_b_stride ?gu_matmul_mul_out_stride) (ICons ?x (ICons ?gu_f32 (INil)))))
|
||||
; t61: SumReduce over K dimension
|
||||
(= ?gu_matmul (Op (Sum ?gu_matmul_out_shape ?gu_matmul_k ?gu_matmul_in_stride ?gu_matmul_k_stride ?gu_matmul_out_stride) (ICons ?gu_matmul_mul (INil))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_gate_up ?gu_matmul)
|
||||
(MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_iota_within_range ?x ?topk_idx ?gate_up_w))
|
||||
)
|
||||
:name "GLUMoE gate-up matmul marker"
|
||||
)
|
||||
|
||||
; ===== SwiGLU activation marker =====
|
||||
(rule
|
||||
(
|
||||
(= ?gate_up_state (glumoe_gate_up ?gu_matmul))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
; ===== Up slice via Iota+Gather =====
|
||||
; t62: Iota with complex expression (slicing the "up" half)
|
||||
(= ?up_iota (Op (Iota ?up_iota_expr ?up_iota_range) (INil)))
|
||||
; t63: Gather to select up portion from matmul result
|
||||
(= ?up_slice (Op (Gather ?up_gather_idx_shape ?up_gather_idx_stride ?up_gather_data_shape ?up_gather_data_stride) (ICons ?up_iota (ICons ?gu_matmul (INil)))))
|
||||
|
||||
; ===== SwiGLU: silu(gate) * up =====
|
||||
; t64: Constant(-1)
|
||||
(= ?neg1 (Op (Constant -1.000000) (INil)))
|
||||
; t65: gate * -1
|
||||
(= ?neg_gate (Op (Mul ?silu_shape1 ?silu_a_stride1 ?silu_b_stride1 ?silu_out_stride1) (ICons ?gu_matmul (ICons ?neg1 (INil)))))
|
||||
; t66: Constant(log2e)
|
||||
(= ?log2e (Op (Constant 1.442695) (INil)))
|
||||
; t67: neg_gate * log2e
|
||||
(= ?scaled (Op (Mul ?silu_shape2 ?silu_a_stride2 ?silu_b_stride2 ?silu_out_stride2) (ICons ?neg_gate (ICons ?log2e (INil)))))
|
||||
; t68: exp2
|
||||
(= ?exp2_val (Op (Exp2 ?silu_shape3 ?silu_in_stride3 ?silu_out_stride3) (ICons ?scaled (INil))))
|
||||
; t69: Constant(1)
|
||||
(= ?one (Op (Constant 1.000000) (INil)))
|
||||
; t70: exp2 + 1
|
||||
(= ?plus1 (Op (Add ?silu_shape4 ?silu_a_stride4 ?silu_b_stride4 ?silu_out_stride4) (ICons ?exp2_val (ICons ?one (INil)))))
|
||||
; t71: recip
|
||||
(= ?sigmoid (Op (Recip ?silu_shape5 ?silu_in_stride5 ?silu_out_stride5) (ICons ?plus1 (INil))))
|
||||
; t72: gate * sigmoid(gate) = silu(gate)
|
||||
(= ?silu_out (Op (Mul ?silu_shape6 ?silu_a_stride6 ?silu_b_stride6 ?silu_out_stride6) (ICons ?gu_matmul (ICons ?sigmoid (INil)))))
|
||||
; t73: silu(gate) * up
|
||||
(= ?swiglu_out (Op (Mul ?swiglu_shape ?swiglu_a_stride ?swiglu_b_stride ?swiglu_out_stride) (ICons ?silu_out (ICons ?up_slice (INil)))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_swiglu ?swiglu_out) (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
)
|
||||
:name "GLUMoE swiglu marker"
|
||||
)
|
||||
|
||||
; ===== Gemma GELU activation marker =====
|
||||
(rule
|
||||
(
|
||||
(= ?gate_up_state (glumoe_gate_up ?gu_matmul))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
(= ?up_iota (Op (Iota ?up_iota_expr ?up_iota_range) (INil)))
|
||||
(= ?up_slice (Op (Gather ?up_gather_idx_shape ?up_gather_idx_stride ?up_gather_data_shape ?up_gather_data_stride) (ICons ?up_iota (ICons ?gu_matmul (INil)))))
|
||||
|
||||
(= ?gelu_coeff_inner (Op (Constant 0.044715) (INil)))
|
||||
(= ?gelu_inner_scaled (Op (Mul ?gelu_inner_scaled_shape ?gelu_inner_scaled_a_stride ?gelu_inner_scaled_b_stride ?gelu_inner_scaled_out_stride) (ICons ?gu_matmul (ICons ?gelu_coeff_inner (INil)))))
|
||||
(= ?gelu_inner_quad (Op (Mul ?gelu_inner_quad_shape ?gelu_inner_quad_a_stride ?gelu_inner_quad_b_stride ?gelu_inner_quad_out_stride) (ICons ?gelu_inner_scaled (ICons ?gu_matmul (INil)))))
|
||||
(= ?gelu_one (Op (Constant 1.000000) (INil)))
|
||||
(= ?gelu_poly (Op (Add ?gelu_poly_shape ?gelu_poly_a_stride ?gelu_poly_b_stride ?gelu_poly_out_stride) (ICons ?gelu_inner_quad (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_coeff_outer (Op (Constant 1.595769) (INil)))
|
||||
(= ?gelu_outer_scaled (Op (Mul ?gelu_outer_scaled_shape ?gelu_outer_scaled_a_stride ?gelu_outer_scaled_b_stride ?gelu_outer_scaled_out_stride) (ICons ?gu_matmul (ICons ?gelu_coeff_outer (INil)))))
|
||||
(= ?gelu_scaled (Op (Mul ?gelu_scaled_shape ?gelu_scaled_a_stride ?gelu_scaled_b_stride ?gelu_scaled_out_stride) (ICons ?gelu_outer_scaled (ICons ?gelu_poly (INil)))))
|
||||
(= ?neg1 (Op (Constant -1.000000) (INil)))
|
||||
(= ?gelu_neg (Op (Mul ?gelu_neg_shape ?gelu_neg_a_stride ?gelu_neg_b_stride ?gelu_neg_out_stride) (ICons ?gelu_scaled (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant 1.442695) (INil)))
|
||||
(= ?gelu_exp_scaled (Op (Mul ?gelu_exp_scaled_shape ?gelu_exp_scaled_a_stride ?gelu_exp_scaled_b_stride ?gelu_exp_scaled_out_stride) (ICons ?gelu_neg (ICons ?log2e (INil)))))
|
||||
(= ?gelu_exp2_val (Op (Exp2 ?gelu_exp_shape ?gelu_exp_in_stride ?gelu_exp_out_stride) (ICons ?gelu_exp_scaled (INil))))
|
||||
(= ?gelu_plus1 (Op (Add ?gelu_plus1_shape ?gelu_plus1_a_stride ?gelu_plus1_b_stride ?gelu_plus1_out_stride) (ICons ?gelu_exp2_val (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_sigmoid (Op (Recip ?gelu_sigmoid_shape ?gelu_sigmoid_in_stride ?gelu_sigmoid_out_stride) (ICons ?gelu_plus1 (INil))))
|
||||
(= ?gelu_out (Op (Mul ?gelu_out_shape ?gelu_out_a_stride ?gelu_out_b_stride ?gelu_out_out_stride) (ICons ?gu_matmul (ICons ?gelu_sigmoid (INil)))))
|
||||
(= ?gemma_out (Op (Mul ?geglu_shape ?geglu_a_stride ?geglu_b_stride ?geglu_out_stride) (ICons ?gelu_out (ICons ?up_slice (INil)))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_gemma_gelu ?gemma_out) (MkGLUMoEGemmaGELUState ?gate_up_state))
|
||||
)
|
||||
:name "GLUMoE gemma gelu marker"
|
||||
)
|
||||
|
||||
; ===== SwiGLU down marker =====
|
||||
(rule
|
||||
(
|
||||
(= ?swiglu_state (glumoe_swiglu ?swiglu_out))
|
||||
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
|
||||
; ===== Down expert gather =====
|
||||
; t74: Iota for base index (expert_idx * io_down)
|
||||
(= ?dn_iota_base (Op (Iota ?dn_io ?dn_iota_base_range) (INil)))
|
||||
; t75: Mul topk_indices * io_down
|
||||
(= ?dn_mul_base (Op (Mul ?dn_mul_base_shape ?dn_mul_base_a_stride ?dn_mul_base_b_stride ?dn_mul_base_out_stride) (ICons ?topk_idx (ICons ?dn_iota_base (INil)))))
|
||||
; t76: Cast to F32
|
||||
(= ?dn_cast_base (Op (Cast ?dn_cast_base_size (F32)) (ICons ?dn_mul_base (INil))))
|
||||
; t77: Iota for within-expert index
|
||||
(= ?dn_iota_within (Op (Iota (MIter) ?dn_iota_within_range) (INil)))
|
||||
; t78: Cast within to F32
|
||||
(= ?dn_cast_within (Op (Cast ?dn_cast_within_size (F32)) (ICons ?dn_iota_within (INil))))
|
||||
; t79: Add base + within
|
||||
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_cast_base (ICons ?dn_cast_within (INil)))))
|
||||
; t80: Cast to Int
|
||||
(= ?dn_cast_idx (Op (Cast ?dn_cast_idx_size (Int)) (ICons ?dn_add_idx (INil))))
|
||||
; t81: Gather down weights
|
||||
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_cast_idx (ICons ?down_w (INil)))))
|
||||
|
||||
; ===== Cast BF16→F32 =====
|
||||
; t82: Cast gathered down to F32
|
||||
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_mul_base (ICons ?dn_iota_within (INil)))))
|
||||
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_add_idx (ICons ?down_w (INil)))))
|
||||
(= ?dn_f32 (Op (Cast ?dn_f32_size (F32)) (ICons ?dn_gathered (INil))))
|
||||
|
||||
; ===== Down batched matmul =====
|
||||
; t83: Mul swiglu_out * gathered_down (broadcast multiply)
|
||||
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?swiglu_out (ICons ?dn_f32 (INil)))))
|
||||
; t84: SumReduce
|
||||
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_swiglu_down ?dn_matmul)
|
||||
(MkGLUMoESwiGLUDownState ?dn_io ?dn_matmul_k ?dn_iota_within_range ?swiglu_state ?topk_idx ?down_w))
|
||||
)
|
||||
:name "GLUMoE swiglu down marker"
|
||||
)
|
||||
|
||||
; ===== Gemma GELU down marker =====
|
||||
(rule
|
||||
(
|
||||
(= ?gemma_state (glumoe_gemma_gelu ?gemma_out))
|
||||
(= ?gemma_state (MkGLUMoEGemmaGELUState ?gate_up_state))
|
||||
|
||||
(= ?dn_iota_base (Op (Iota ?dn_io ?dn_iota_base_range) (INil)))
|
||||
(= ?dn_mul_base (Op (Mul ?dn_mul_base_shape ?dn_mul_base_a_stride ?dn_mul_base_b_stride ?dn_mul_base_out_stride) (ICons ?topk_idx (ICons ?dn_iota_base (INil)))))
|
||||
(= ?dn_iota_within (Op (Iota (MIter) ?dn_iota_within_range) (INil)))
|
||||
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_mul_base (ICons ?dn_iota_within (INil)))))
|
||||
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_add_idx (ICons ?down_w (INil)))))
|
||||
(= ?dn_f32 (Op (Cast ?dn_f32_size (F32)) (ICons ?dn_gathered (INil))))
|
||||
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?gemma_out (ICons ?dn_f32 (INil)))))
|
||||
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_gemma_down ?dn_matmul)
|
||||
(MkGLUMoEGemmaDownState ?dn_io ?dn_matmul_k ?dn_iota_within_range ?gemma_state ?topk_idx ?down_w))
|
||||
)
|
||||
:name "GLUMoE gemma down marker"
|
||||
)
|
||||
|
||||
; ===== Final fusion: mode 0 (SwiGLU) =====
|
||||
(rule
|
||||
(
|
||||
(= ?down_state (glumoe_swiglu_down ?dn_matmul))
|
||||
(= ?down_state (MkGLUMoESwiGLUDownState ?dn_io ?dn_matmul_k ?dn_within_range ?swiglu_state ?topk_idx ?down_w))
|
||||
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
; ===== Weighted sum over k experts =====
|
||||
; t85: Mul down_out * topk_values
|
||||
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?topk_vals (INil)))))
|
||||
; t86: SumReduce over k dimension → [s, H]
|
||||
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
)
|
||||
(
|
||||
(let ?glumoe (Op (GLUMoE
|
||||
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
|
||||
?gu_iota_within_range ?dn_iota_within_range)
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (INil))))))))
|
||||
?gu_within_range ?dn_within_range (MNum 0))
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?topk_vals (INil)))))))))
|
||||
(union ?output ?glumoe)
|
||||
)
|
||||
:name "GLUMoE fused expert computation"
|
||||
:name "GLUMoE fused expert computation (swiglu)"
|
||||
)
|
||||
|
||||
; ===== Final fusion: mode 1 (Gemma GELU) =====
|
||||
(rule
|
||||
(
|
||||
(= ?down_state (glumoe_gemma_down ?dn_matmul))
|
||||
(= ?down_state (MkGLUMoEGemmaDownState ?dn_io ?dn_matmul_k ?dn_within_range ?gemma_state ?topk_idx ?down_w))
|
||||
(= ?gemma_state (MkGLUMoEGemmaGELUState ?gate_up_state))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
; Gemma expert weights: topk_weights = normed_topk * per_expert_scale.gather(topk_idx)
|
||||
(= ?per_expert_vals (Op (Gather ?scale_gather_idx_shape ?scale_gather_idx_stride ?scale_gather_data_shape ?scale_gather_data_stride) (ICons ?topk_idx (ICons ?per_expert_scale (INil)))))
|
||||
(= ?topk_row_offsets (Op (Iota ?topk_row_offsets_expr ?topk_row_offsets_range) (INil)))
|
||||
(= ?topk_flat_idx (Op (Add ?topk_flat_idx_shape ?topk_flat_idx_a_stride ?topk_flat_idx_b_stride ?topk_flat_idx_out_stride) (ICons ?topk_row_offsets (ICons ?topk_idx (INil)))))
|
||||
(= ?topk_vals (Op (Gather ?topk_vals_gather_idx_shape ?topk_vals_gather_idx_stride ?topk_vals_gather_data_shape ?topk_vals_gather_data_stride) (ICons ?topk_flat_idx (ICons ?routing_weights (INil)))))
|
||||
(= ?topk_norm (Op (Sum ?topk_norm_shape ?output_k ?topk_norm_in_stride ?topk_norm_k_stride ?topk_norm_out_stride) (ICons ?topk_vals (INil))))
|
||||
(= ?topk_norm_factor (Op (Recip ?topk_norm_recip_shape ?topk_norm_recip_in_stride ?topk_norm_recip_out_stride) (ICons ?topk_norm (INil))))
|
||||
(= ?normed_topk (Op (Mul ?normed_topk_shape ?normed_topk_a_stride ?normed_topk_b_stride ?normed_topk_out_stride) (ICons ?topk_vals (ICons ?topk_norm_factor (INil)))))
|
||||
(= ?expert_weights (Op (Mul ?expert_weights_shape ?expert_weights_a_stride ?expert_weights_b_stride ?expert_weights_out_stride) (ICons ?normed_topk (ICons ?per_expert_vals (INil)))))
|
||||
|
||||
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?expert_weights (INil)))))
|
||||
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
)
|
||||
(
|
||||
(let ?glumoe (Op (GLUMoE
|
||||
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
|
||||
?gu_within_range ?dn_within_range (MNum 1))
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?per_expert_scale (INil)))))))))
|
||||
(union ?output ?glumoe)
|
||||
)
|
||||
:name "GLUMoE fused expert computation (gemma_gelu)"
|
||||
)
|
||||
|
||||
@@ -33,14 +33,15 @@ use crate::{
|
||||
},
|
||||
},
|
||||
host::HostOp,
|
||||
try_create_cublaslt,
|
||||
};
|
||||
|
||||
const WORKSPACE_SIZE: usize = 32 * 1024 * 1024; // 32 MiB
|
||||
|
||||
/// Fused GLU-MoE HostOp matched via egglog pattern.
|
||||
///
|
||||
/// Replaces the expert computation subgraph (expert gathers + matmuls + SwiGLU
|
||||
/// + weighted sum) with an efficient cuBLASLt implementation.
|
||||
/// Replaces the expert computation subgraph (expert gathers + matmuls + gated
|
||||
/// activation + weighted sum) with an efficient cuBLASLt implementation.
|
||||
///
|
||||
/// Inputs (graph edges, in order):
|
||||
/// 0: x [seq, hidden] F32
|
||||
@@ -48,9 +49,13 @@ const WORKSPACE_SIZE: usize = 32 * 1024 * 1024; // 32 MiB
|
||||
/// 2: topk_values [seq, k] F32
|
||||
/// 3: gate_up_w [E, gate_up_dim, hidden] BF16
|
||||
/// 4: down_w [E, hidden, intermediate] BF16
|
||||
/// 5: mode_aux
|
||||
/// - SwiGLU: ignored (rewriter wires `topk_values` again)
|
||||
/// - GemmaGELU: per_expert_scale [E] F32
|
||||
///
|
||||
/// Output: [seq, hidden] F32
|
||||
pub struct GLUMoE {
|
||||
pub(crate) mode: GLUMoEMode,
|
||||
/// Product of gate_up weight dimensions per expert (gate_up_dim * hidden) used for gather stride
|
||||
gu_io: Expression,
|
||||
/// Product of down weight dimensions per expert (hidden * intermediate) used for gather stride
|
||||
@@ -69,9 +74,35 @@ pub struct GLUMoE {
|
||||
module: OnceLock<(Arc<CudaModule>, CudaFunction, CudaFunction)>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) enum GLUMoEMode {
|
||||
SwiGLU,
|
||||
GemmaGELU,
|
||||
}
|
||||
|
||||
impl GLUMoEMode {
|
||||
fn from_mode_id(mode_id: usize) -> Self {
|
||||
match mode_id {
|
||||
0 => Self::SwiGLU,
|
||||
1 => Self::GemmaGELU,
|
||||
other => {
|
||||
panic!("Unknown GLUMoE mode id: {other}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn activation_kernel_mode(self) -> i32 {
|
||||
match self {
|
||||
Self::SwiGLU => 0,
|
||||
Self::GemmaGELU => 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for GLUMoE {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
mode: GLUMoEMode::SwiGLU,
|
||||
gu_io: Expression::default(),
|
||||
dn_io: Expression::default(),
|
||||
gu_matmul_k: Expression::default(),
|
||||
@@ -88,6 +119,7 @@ impl Default for GLUMoE {
|
||||
impl std::fmt::Debug for GLUMoE {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("GLUMoE")
|
||||
.field("mode", &self.mode)
|
||||
.field("gu_io", &self.gu_io)
|
||||
.field("dn_io", &self.dn_io)
|
||||
.field("gu_matmul_k", &self.gu_matmul_k)
|
||||
@@ -100,6 +132,7 @@ impl std::fmt::Debug for GLUMoE {
|
||||
impl Clone for GLUMoE {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
mode: self.mode,
|
||||
gu_io: self.gu_io,
|
||||
dn_io: self.dn_io,
|
||||
gu_matmul_k: self.gu_matmul_k,
|
||||
@@ -114,9 +147,15 @@ impl Clone for GLUMoE {
|
||||
}
|
||||
|
||||
impl GLUMoE {
|
||||
fn get_cublaslt(&self, stream: &Arc<CudaStream>) -> &Arc<CudaBlasLT> {
|
||||
self.cublaslt
|
||||
.get_or_init(|| Arc::new(CudaBlasLT::new(stream.clone()).unwrap()))
|
||||
fn get_cublaslt(&self, stream: &Arc<CudaStream>) -> anyhow::Result<Arc<CudaBlasLT>> {
|
||||
if let Some(cublaslt) = self.cublaslt.get() {
|
||||
return Ok(cublaslt.clone());
|
||||
}
|
||||
let created = try_create_cublaslt(stream.clone()).map_err(|message| {
|
||||
anyhow::anyhow!("cuBLASLt unavailable on this machine: {message}")
|
||||
})?;
|
||||
let _ = self.cublaslt.set(created.clone());
|
||||
Ok(created)
|
||||
}
|
||||
|
||||
fn get_kernels(
|
||||
@@ -134,23 +173,34 @@ extern "C" __global__ void f32_to_bf16(unsigned long long in_ptr, unsigned long
|
||||
if (i < n) out[i] = __float2bfloat16(in_[i]);
|
||||
}
|
||||
|
||||
extern "C" __global__ void swiglu_bf16(unsigned long long gate_up_ptr, unsigned long long out_ptr, int intermediate) {
|
||||
extern "C" __global__ void glu_activation_bf16(
|
||||
unsigned long long gate_up_ptr,
|
||||
unsigned long long out_ptr,
|
||||
int intermediate,
|
||||
int mode
|
||||
) {
|
||||
const __nv_bfloat16* gate_up = (const __nv_bfloat16*)gate_up_ptr;
|
||||
__nv_bfloat16* out = (__nv_bfloat16*)out_ptr;
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < intermediate) {
|
||||
float gate = __bfloat162float(gate_up[i]);
|
||||
float up = __bfloat162float(gate_up[i + intermediate]);
|
||||
float silu = gate / (1.0f + expf(-gate));
|
||||
out[i] = __float2bfloat16(silu * up);
|
||||
float activated;
|
||||
if (mode == 0) {
|
||||
activated = gate / (1.0f + expf(-gate));
|
||||
} else {
|
||||
float scaled = 1.5957691216f * gate * (1.0f + 0.044715f * gate * gate);
|
||||
activated = gate / (1.0f + expf(-scaled));
|
||||
}
|
||||
out[i] = __float2bfloat16(activated * up);
|
||||
}
|
||||
}
|
||||
"#;
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), src).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let f32_to_bf16 = module.load_function("f32_to_bf16").unwrap();
|
||||
let swiglu = module.load_function("swiglu_bf16").unwrap();
|
||||
(module, f32_to_bf16, swiglu)
|
||||
let activation = module.load_function("glu_activation_bf16").unwrap();
|
||||
(module, f32_to_bf16, activation)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -168,12 +218,27 @@ impl EgglogOp for GLUMoE {
|
||||
("output_k", EXPRESSION),
|
||||
("gu_within_range", EXPRESSION),
|
||||
("dn_within_range", EXPRESSION),
|
||||
("mode", EXPRESSION),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?e (Op (GLUMoE ?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k ?gu_within_range ?dn_within_range ?mode) ?inputs))
|
||||
)
|
||||
(
|
||||
(set (dtype ?e) (F32))
|
||||
)
|
||||
:ruleset dtype_prop
|
||||
)",
|
||||
)]
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
5
|
||||
6
|
||||
}
|
||||
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
@@ -195,8 +260,14 @@ impl EgglogOp for GLUMoE {
|
||||
let output_k = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
|
||||
let gu_within_range = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
|
||||
let dn_within_range = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
|
||||
let mode_expr = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
|
||||
let mode_id = mode_expr
|
||||
.to_usize()
|
||||
.unwrap_or_else(|| panic!("GLUMoE mode must be static, got expression: {mode_expr}"));
|
||||
let mode = GLUMoEMode::from_mode_id(mode_id);
|
||||
|
||||
let extracted = GLUMoE {
|
||||
mode,
|
||||
gu_io,
|
||||
dn_io,
|
||||
gu_matmul_k,
|
||||
@@ -209,7 +280,7 @@ impl EgglogOp for GLUMoE {
|
||||
};
|
||||
|
||||
let op = LLIROp::new::<dyn HostOp>(Box::new(extracted) as Box<dyn HostOp>);
|
||||
// Return the 5 IR inputs: x, topk_idx, topk_vals, gate_up_w, down_w
|
||||
// Return the 6 IR inputs: x, topk_idx, topk_values, gate_up_w, down_w, mode_aux
|
||||
(op, input_enodes)
|
||||
}
|
||||
|
||||
@@ -230,9 +301,9 @@ impl HostOp for GLUMoE {
|
||||
// Resolve dimensions
|
||||
let hidden = self.gu_matmul_k.exec(dyn_map).unwrap();
|
||||
let intermediate = self.dn_matmul_k.exec(dyn_map).unwrap();
|
||||
let top_k = self.output_k.exec(dyn_map).unwrap();
|
||||
let top_k_expected = self.output_k.exec(dyn_map).unwrap();
|
||||
let gate_up_dim = self.gu_io.exec(dyn_map).unwrap() / hidden; // gate_up_dim = gu_io / hidden
|
||||
let _num_experts = self.gu_within_range.exec(dyn_map).unwrap() / (gate_up_dim * hidden);
|
||||
let num_experts = self.gu_within_range.exec(dyn_map).unwrap() / (gate_up_dim * hidden);
|
||||
|
||||
// Derive seq from x buffer size: x is [seq, hidden] F32 → seq = len / (hidden * 4)
|
||||
let x_buf = buffers[&inputs[0]];
|
||||
@@ -243,6 +314,7 @@ impl HostOp for GLUMoE {
|
||||
let topk_vals_buf = buffers[&inputs[2]]; // [seq, k] F32
|
||||
let gate_up_buf = buffers[&inputs[3]]; // [E, gate_up_dim, hidden] BF16
|
||||
let down_buf = buffers[&inputs[4]]; // [E, hidden, intermediate] BF16
|
||||
let mode_aux_buf = buffers[&inputs[5]];
|
||||
let output_buf = buffers[&self_node]; // [seq, hidden] F32
|
||||
|
||||
// Get raw device pointer addresses
|
||||
@@ -251,14 +323,59 @@ impl HostOp for GLUMoE {
|
||||
let down_ptr = buf_ptr(down_buf, stream);
|
||||
let output_ptr = buf_ptr(output_buf, stream);
|
||||
|
||||
let cublaslt = self.get_cublaslt(stream);
|
||||
let (_, f32_to_bf16_fn, swiglu_fn) = self.get_kernels(stream);
|
||||
let cublaslt = self.get_cublaslt(stream)?;
|
||||
let (_, f32_to_bf16_fn, activation_fn) = self.get_kernels(stream);
|
||||
|
||||
// Read topk indices and values from GPU
|
||||
// Read top-k routing values from GPU
|
||||
let topk_idx_host: Vec<u8> = stream.clone_dtoh(topk_idx_buf)?;
|
||||
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host);
|
||||
let topk_vals_host: Vec<u8> = stream.clone_dtoh(topk_vals_buf)?;
|
||||
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host);
|
||||
let idx_k = topk_idx_i32
|
||||
.len()
|
||||
.checked_div(seq)
|
||||
.unwrap_or(top_k_expected);
|
||||
let val_k = topk_vals_f32
|
||||
.len()
|
||||
.checked_div(seq)
|
||||
.unwrap_or(top_k_expected);
|
||||
let top_k = idx_k.min(val_k);
|
||||
if seq > 0 && top_k == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Mode-dependent expert weights used for the final reduction:
|
||||
// - SwiGLU: direct topk values
|
||||
// - GemmaGELU: normalize topk values and scale by per-expert factors
|
||||
let mut expert_weights_storage: Vec<f32> = Vec::new();
|
||||
let expert_weights_f32: &[f32] = match self.mode {
|
||||
GLUMoEMode::SwiGLU => topk_vals_f32,
|
||||
GLUMoEMode::GemmaGELU => {
|
||||
let per_expert_scale_host: Vec<u8> = stream.clone_dtoh(mode_aux_buf)?;
|
||||
let per_expert_scale_f32: &[f32] = bytemuck::cast_slice(&per_expert_scale_host);
|
||||
debug_assert!(per_expert_scale_f32.len() >= num_experts);
|
||||
expert_weights_storage.resize(seq * top_k, 0.0);
|
||||
for t in 0..seq {
|
||||
let base = t * top_k;
|
||||
let vals = &topk_vals_f32[base..base + top_k];
|
||||
let norm = vals.iter().copied().sum::<f32>();
|
||||
let inv_norm = if norm != 0.0 { norm.recip() } else { 0.0 };
|
||||
for i in 0..top_k {
|
||||
let expert_idx = topk_idx_i32[base + i] as usize;
|
||||
if expert_idx >= per_expert_scale_f32.len() {
|
||||
anyhow::bail!(
|
||||
"GLUMoE Gemma mode expert index {} out of bounds {}",
|
||||
expert_idx,
|
||||
per_expert_scale_f32.len()
|
||||
);
|
||||
}
|
||||
let scale = per_expert_scale_f32[expert_idx];
|
||||
expert_weights_storage[base + i] = vals[i] * inv_norm * scale;
|
||||
}
|
||||
}
|
||||
&expert_weights_storage
|
||||
}
|
||||
};
|
||||
|
||||
// Allocate temp buffers
|
||||
let x_bf16_buf = unsafe { stream.alloc::<u8>(seq * hidden * 2)? }; // BF16
|
||||
@@ -291,22 +408,10 @@ impl HostOp for GLUMoE {
|
||||
let gu_stride = (gate_up_dim * hidden * 2) as u64; // bytes per expert gate_up (BF16)
|
||||
let down_stride = (hidden * intermediate * 2) as u64; // bytes per expert down (BF16)
|
||||
|
||||
// Normalize top-k values per token (norm_topk_prob=true)
|
||||
let mut normalized_vals = topk_vals_f32.to_vec();
|
||||
for t in 0..seq {
|
||||
let row = &mut normalized_vals[t * top_k..(t + 1) * top_k];
|
||||
let sum: f32 = row.iter().sum();
|
||||
if sum > 0.0 {
|
||||
for v in row.iter_mut() {
|
||||
*v /= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for t in 0..seq {
|
||||
let x_t_ptr = xbf16_ptr + (t * hidden * 2) as u64; // BF16
|
||||
let expert_indices = &topk_idx_i32[t * top_k..(t + 1) * top_k];
|
||||
let weights = &normalized_vals[t * top_k..(t + 1) * top_k];
|
||||
let weights = &expert_weights_f32[t * top_k..(t + 1) * top_k];
|
||||
|
||||
for (i, (&expert_idx, &weight)) in expert_indices.iter().zip(weights.iter()).enumerate()
|
||||
{
|
||||
@@ -316,7 +421,7 @@ impl HostOp for GLUMoE {
|
||||
let expert_gu_ptr = gate_up_ptr + expert_idx as u64 * gu_stride;
|
||||
cublas_matmul(
|
||||
stream,
|
||||
cublaslt,
|
||||
&cublaslt,
|
||||
ws_ptr,
|
||||
gate_up_dim as u64,
|
||||
1,
|
||||
@@ -335,17 +440,19 @@ impl HostOp for GLUMoE {
|
||||
0.0f32,
|
||||
)?;
|
||||
|
||||
// b. SwiGLU kernel (BF16 → BF16)
|
||||
// b. Mode-specific gated activation (BF16 → BF16)
|
||||
let moe_int = intermediate as i32;
|
||||
let swiglu_blocks = (moe_int as u32).div_ceil(256);
|
||||
let activation_mode = self.mode.activation_kernel_mode();
|
||||
let activation_blocks = (moe_int as u32).div_ceil(256);
|
||||
unsafe {
|
||||
stream
|
||||
.launch_builder(swiglu_fn)
|
||||
.launch_builder(activation_fn)
|
||||
.arg(&gu_out_ptr)
|
||||
.arg(&hid_ptr)
|
||||
.arg(&moe_int)
|
||||
.arg(&activation_mode)
|
||||
.launch(LaunchConfig {
|
||||
grid_dim: (swiglu_blocks, 1, 1),
|
||||
grid_dim: (activation_blocks, 1, 1),
|
||||
block_dim: (256, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
})?;
|
||||
@@ -358,7 +465,7 @@ impl HostOp for GLUMoE {
|
||||
let beta = if i == 0 { 0.0f32 } else { 1.0f32 };
|
||||
cublas_matmul_mixed(
|
||||
stream,
|
||||
cublaslt,
|
||||
&cublaslt,
|
||||
ws_ptr,
|
||||
hidden as u64,
|
||||
1,
|
||||
|
||||
@@ -653,4 +653,53 @@ mod tests {
|
||||
}
|
||||
assert_close(&rt.get_f32(output), &expected, 1e-2, 1e-2);
|
||||
}
|
||||
|
||||
/// Test that CUDA graphs produce correct results when dynamic dimensions
|
||||
/// change incrementally across many executions (simulating a decode loop
|
||||
/// where position offset increments each step).
|
||||
#[test]
|
||||
fn test_cuda_graph_incremental_dim_changes() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor('s');
|
||||
let b = cx.tensor('s');
|
||||
let c = ((a + b) * a).output();
|
||||
|
||||
let initial_size = 128;
|
||||
cx.set_dim('s', initial_size);
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let data_a = random_f32_vec(initial_size, 42, -0.5, 0.5);
|
||||
let data_b = random_f32_vec(initial_size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
|
||||
// Initial execution
|
||||
rt.execute(&cx.dyn_map);
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
let tol = eps * TOLERANCE_SAFETY_FACTOR;
|
||||
let expected: Vec<f32> = data_a
|
||||
.iter()
|
||||
.zip(&data_b)
|
||||
.map(|(a, b)| (a + b) * a)
|
||||
.collect();
|
||||
assert_close(&rt.get_f32(c), &expected, tol, tol);
|
||||
|
||||
// Incrementally change the dynamic dimension 10 times,
|
||||
// simulating decode steps where position offset grows.
|
||||
for step in 1..=10usize {
|
||||
let size = initial_size + step;
|
||||
cx.set_dim('s', size);
|
||||
let da = random_f32_vec(size, 100 + step as u64, -0.5, 0.5);
|
||||
let db = random_f32_vec(size, 200 + step as u64, -0.5, 0.5);
|
||||
rt.set_data(a, da.clone());
|
||||
rt.set_data(b, db.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let expected: Vec<f32> = da.iter().zip(&db).map(|(a, b)| (a + b) * a).collect();
|
||||
assert_close(&rt.get_f32(c), &expected, tol, tol);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -634,8 +634,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(), // No per-module constants needed
|
||||
)
|
||||
@@ -797,8 +797,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -990,12 +990,13 @@ extern \"C\" {{
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let out_size = self.out_shape.iter().copied().product::<Expression>();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(self.out_shape.iter().copied().product(), 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -1615,8 +1616,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -1769,8 +1770,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -1923,8 +1924,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -2077,8 +2078,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -2231,8 +2232,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -2392,8 +2393,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -2567,8 +2568,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
|
||||
@@ -10,7 +10,7 @@ use itertools::Itertools;
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, EXPRESSION, OP_KIND},
|
||||
base::{DTYPE, ELIST, EXPRESSION, OP_KIND, STRING},
|
||||
extract_dtype, extract_expr, extract_expr_list,
|
||||
},
|
||||
op::*,
|
||||
@@ -25,6 +25,7 @@ pub type Ops = (
|
||||
KernelSoftmax,
|
||||
KernelExp,
|
||||
KernelSigmoid,
|
||||
KernelFusedElementwise,
|
||||
);
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -1544,8 +1545,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -1730,8 +1731,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -1766,3 +1767,283 @@ extern \"C\" {{
|
||||
"Sigmoid"
|
||||
}
|
||||
}
|
||||
|
||||
/// A unary math function that can appear inside a fused elementwise kernel.
|
||||
/// Each variant has a stable string name (used both as the egglog token in
|
||||
/// the rule-generated ops string and as the `kernel_name()` of the source
|
||||
/// unary kernel op).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum UnaryFn {
|
||||
Sin,
|
||||
Sqrt,
|
||||
Exp2,
|
||||
Log2,
|
||||
Recip,
|
||||
}
|
||||
|
||||
impl UnaryFn {
|
||||
pub fn name(self) -> &'static str {
|
||||
match self {
|
||||
UnaryFn::Sin => "Sin",
|
||||
UnaryFn::Sqrt => "Sqrt",
|
||||
UnaryFn::Exp2 => "Exp2",
|
||||
UnaryFn::Log2 => "Log2",
|
||||
UnaryFn::Recip => "Recip",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_name(name: &str) -> Self {
|
||||
match name {
|
||||
"Sin" => UnaryFn::Sin,
|
||||
"Sqrt" => UnaryFn::Sqrt,
|
||||
"Exp2" => UnaryFn::Exp2,
|
||||
"Log2" => UnaryFn::Log2,
|
||||
"Recip" => UnaryFn::Recip,
|
||||
_ => panic!("invalid UnaryFn name: {name}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An LLIR-only op created by fusing a chain of unary elementwise kernels.
|
||||
/// Only fires when every op in the chain shares the same stride pattern,
|
||||
/// so reads and writes use a single `strides` field.
|
||||
///
|
||||
/// The `ops` sequence is carried as a comma-separated egglog `String`
|
||||
/// (e.g. `"Sin,Sqrt,Exp2"`) — it's pure codegen metadata that egglog never
|
||||
/// reasons about, and `String` is a primitive sort, so this avoids
|
||||
/// introducing a new datatype/sort just to carry the list.
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelFusedElementwise {
|
||||
shape: Vec<Expression>,
|
||||
strides: Vec<Expression>,
|
||||
ops: Vec<UnaryFn>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl KernelFusedElementwise {
|
||||
pub fn ops(&self) -> &[UnaryFn] {
|
||||
&self.ops
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelFusedElementwise {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelFusedElementwise",
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("strides", ELIST),
|
||||
("ops", STRING),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
let unaries = [
|
||||
("KernelSin", UnaryFn::Sin),
|
||||
("KernelSqrt", UnaryFn::Sqrt),
|
||||
("KernelExp2", UnaryFn::Exp2),
|
||||
("KernelLog2", UnaryFn::Log2),
|
||||
("KernelRecip", UnaryFn::Recip),
|
||||
];
|
||||
let mut rules = Vec::with_capacity(unaries.len() * unaries.len() + unaries.len());
|
||||
|
||||
// Pair fusion: two adjacent pure-elementwise unaries -> Fused[a, b].
|
||||
for (a_name, a_fn) in unaries {
|
||||
for (b_name, b_fn) in unaries {
|
||||
let (a_str, b_str) = (a_fn.name(), b_fn.name());
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule
|
||||
(
|
||||
(= ?a (Op ({a_name} ?shape ?strides ?strides ?dt) (ICons ?inp (INil))))
|
||||
(= ?b (Op ({b_name} ?shape ?strides ?strides ?dt) (ICons ?a (INil))))
|
||||
)
|
||||
(
|
||||
(let ?fused (Op (KernelFusedElementwise ?shape ?strides
|
||||
\"{a_str},{b_str}\" ?dt)
|
||||
(ICons ?inp (INil))))
|
||||
(union ?b ?fused)
|
||||
)
|
||||
:name \"fuse-{a_name}-{b_name}\"
|
||||
)"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// Chain extend: Fused[ops] -> unary -> Fused[ops + \",<new>\"]. One
|
||||
// rule per outer unary. `+` is the builtin variadic string concat,
|
||||
// so this is O(1) per firing and handles chains of any length
|
||||
// without recursion.
|
||||
for (b_name, b_fn) in unaries {
|
||||
let b_str = b_fn.name();
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule
|
||||
(
|
||||
(= ?fused (Op (KernelFusedElementwise ?shape ?strides ?ops ?dt)
|
||||
(ICons ?inp (INil))))
|
||||
(= ?next (Op ({b_name} ?shape ?strides ?strides ?dt)
|
||||
(ICons ?fused (INil))))
|
||||
)
|
||||
(
|
||||
(let ?new_ops (+ ?ops \",{b_str}\"))
|
||||
(let ?new_fused (Op (KernelFusedElementwise ?shape ?strides ?new_ops ?dt)
|
||||
(ICons ?inp (INil))))
|
||||
(union ?next ?new_fused)
|
||||
)
|
||||
:name \"extend-Fused-{b_name}\"
|
||||
)"
|
||||
)));
|
||||
}
|
||||
|
||||
rules
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
// The `ops` field is a String enode; its label is the quoted
|
||||
// literal (e.g. `"Sin,Sqrt"`), so strip the quotes and split.
|
||||
let ops_str = egraph.enodes[kind_children[2]].0.replace('"', "");
|
||||
let ops = if ops_str.is_empty() {
|
||||
Vec::new()
|
||||
} else {
|
||||
ops_str.split(',').map(UnaryFn::from_name).collect()
|
||||
};
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap(),
|
||||
strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
ops,
|
||||
dtype: extract_dtype(egraph, kind_children[3]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelFusedElementwise {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars = self
|
||||
.shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect::<FxHashSet<_>>();
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let n_elements = self
|
||||
.shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.to_kernel();
|
||||
let idx = flatten_strides(&self.shape, &self.strides).to_kernel();
|
||||
let ops_body = self
|
||||
.ops
|
||||
.iter()
|
||||
.map(|op| match op {
|
||||
UnaryFn::Sin => "val = sinf(val);",
|
||||
UnaryFn::Sqrt => "val = sqrtf(val);",
|
||||
UnaryFn::Exp2 => "val = exp2f(val);",
|
||||
UnaryFn::Log2 => "val = log2f(val);",
|
||||
UnaryFn::Recip => "val = 1.0f / val;",
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n ");
|
||||
let kernel = format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void fused_elementwise_k({dtype} *out, const {dtype} *in{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {n_elements}) return;
|
||||
long long idx = {idx};
|
||||
{dtype} val = in[idx];
|
||||
{ops_body}
|
||||
out[idx] = val;
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("fused_elementwise_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let out_size = self.shape.iter().copied().product::<Expression>();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.output_size() * (self.ops.len() as i32)
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"FusedElementwise"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -302,8 +302,10 @@ impl CudaGraphOp {
|
||||
kernel.internal_bufs = kernel.kernel_op.allocate_internal_buffers(stream, dyn_map);
|
||||
}
|
||||
}
|
||||
// Force full rebuild when dims change (debug: testing if update_kernel_node is the issue)
|
||||
if dyn_map_changed || needs_internal_realloc {
|
||||
// Only force full rebuild when internal buffer sizes change.
|
||||
// Dim-only changes (e.g. position offset `p` incrementing each decode step) are
|
||||
// handled by updating the dyn_dims device buffer + kernel node params in-place.
|
||||
if needs_internal_realloc {
|
||||
state.cuda_graph = None;
|
||||
state.cuda_graph_exec = None;
|
||||
state.node_to_graph_node.clear();
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod dyn_backend;
|
||||
pub mod host;
|
||||
pub mod kernel;
|
||||
pub mod runtime;
|
||||
@@ -9,6 +10,8 @@ use std::{
|
||||
|
||||
pub use cudarc;
|
||||
|
||||
use cudarc::{cublaslt::CudaBlasLT, driver::CudaStream};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
@@ -137,6 +140,25 @@ fn cuda_driver_diagnostics() -> (Option<i32>, Option<i32>) {
|
||||
(driver_version, None)
|
||||
}
|
||||
|
||||
pub(crate) fn try_create_cublaslt(
|
||||
stream: Arc<CudaStream>,
|
||||
) -> std::result::Result<Arc<CudaBlasLT>, String> {
|
||||
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| CudaBlasLT::new(stream))) {
|
||||
Ok(Ok(handle)) => Ok(Arc::new(handle)),
|
||||
Ok(Err(err)) => Err(err.to_string()),
|
||||
Err(payload) => {
|
||||
let message = if let Some(message) = payload.downcast_ref::<String>() {
|
||||
message.clone()
|
||||
} else if let Some(message) = payload.downcast_ref::<&str>() {
|
||||
message.to_string()
|
||||
} else {
|
||||
"cuBLASLt initialization panicked".to_string()
|
||||
};
|
||||
Err(message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn cuda_nvrtc_compile_options(target_arch: &str) -> Vec<String> {
|
||||
let mut options = cuda_nvrtc_include_paths()
|
||||
.into_iter()
|
||||
@@ -186,9 +208,9 @@ fn get_cubin(program: nvrtc_sys::nvrtcProgram) -> Result<Vec<u8>, NvrtcError> {
|
||||
}
|
||||
|
||||
let mut cubin = Vec::with_capacity(cubin_size);
|
||||
cubin.resize(cubin_size, 0);
|
||||
unsafe { nvrtc_sys::nvrtcGetCUBIN(program, cubin.as_mut_ptr()) }.result()?;
|
||||
Ok(cubin.into_iter().map(|byte| byte as u8).collect())
|
||||
cubin.resize(cubin_size, 0u8);
|
||||
unsafe { nvrtc_sys::nvrtcGetCUBIN(program, cubin.as_mut_ptr() as *mut _) }.result()?;
|
||||
Ok(cubin)
|
||||
}
|
||||
|
||||
pub(crate) fn compile_module_image_for_current_device<S: AsRef<str>>(
|
||||
|
||||
@@ -120,13 +120,17 @@ pub struct CudaRuntime {
|
||||
/// Bucket definitions per dimension (empty = single-bucket mode)
|
||||
dim_buckets: FxHashMap<char, Vec<DimBucket>>,
|
||||
|
||||
/// HLIR nodes that should never be consumed after execute().
|
||||
/// Used for weight tensors shared via external device pointers.
|
||||
persistent_hlir_nodes: FxHashSet<NodeIndex>,
|
||||
|
||||
/// Non-owning CudaSlice wrappers for external device pointers.
|
||||
/// ManuallyDrop prevents cuMemFree — the external allocator (e.g. PyTorch) owns the memory.
|
||||
external_buffers: FxHashMap<NodeIndex, std::mem::ManuallyDrop<CudaSlice<u8>>>,
|
||||
|
||||
/// Pending output pointer registrations: HLIR output id -> (device_ptr, n_bytes)
|
||||
/// Set by python before execute(), consumed at start of execute()
|
||||
output_ptr_registrations: FxHashMap<NodeIndex, (u64, usize)>,
|
||||
|
||||
/// Non-owning CudaSlice views of external output pointers, keyed by LLIR data node
|
||||
/// ManuallyDrop prevents cuMemFree -- Pytorch owns the memory
|
||||
external_output_buffers: FxHashMap<NodeIndex, std::mem::ManuallyDrop<CudaSlice<u8>>>,
|
||||
}
|
||||
|
||||
impl CudaRuntime {
|
||||
@@ -228,9 +232,25 @@ impl CudaRuntime {
|
||||
self.changed_hlir.insert(id);
|
||||
}
|
||||
|
||||
/// Mark an HLIR node as persistent — its buffer won't be consumed after execute().
|
||||
pub fn persist_hlir_node(&mut self, id: impl ToId) {
|
||||
self.persistent_hlir_nodes.insert(id.to_id());
|
||||
/// Register an external device pointer for an output tensor (zero-copy output).
|
||||
/// The pointer is stored lazily — resolution to LLIR nodes happens in execute().
|
||||
///
|
||||
/// # Safety
|
||||
/// The device pointer must point to a valid CUDA allocation with at least `n_bytes` bytes,
|
||||
/// and must remain valid through the next execute() call.
|
||||
pub unsafe fn set_output_device_ptr(&mut self, id: impl ToId, device_ptr: u64, n_bytes: usize) {
|
||||
debug_assert!(
|
||||
device_ptr != 0,
|
||||
"set_output_device_ptr called with null pointer"
|
||||
);
|
||||
self.output_ptr_registrations
|
||||
.insert(id.to_id(), (device_ptr, n_bytes));
|
||||
}
|
||||
|
||||
pub fn output_is_zero_copy(&self, id: impl ToId) -> bool {
|
||||
let producer = self.find_producer_node(id);
|
||||
let data_node = self.follow_aliases(producer);
|
||||
self.external_output_buffers.contains_key(&data_node)
|
||||
}
|
||||
|
||||
/// Find the LLIR producing node for an output tensor.
|
||||
@@ -390,6 +410,50 @@ impl CudaRuntime {
|
||||
self.cuda_stream.synchronize().unwrap();
|
||||
}
|
||||
|
||||
/// Resolve pending output pointer registrations into external_output_buffers.
|
||||
/// Called at the start of execute(), after buffer allocation and HLIR sync.
|
||||
fn apply_output_ptr_registrations(&mut self) {
|
||||
// clear stale external output buffers from previous execution
|
||||
self.external_output_buffers.clear();
|
||||
|
||||
if self.output_ptr_registrations.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Collect registrations to avoid borrow conflict (drain borrows self mutably,
|
||||
// but find_producer_node/follow_aliases need &self).
|
||||
|
||||
let registrations: Vec<_> = self.output_ptr_registrations.drain().collect();
|
||||
|
||||
for (hlir_id, (device_ptr, n_bytes)) in registrations {
|
||||
// Resolve HLIR output id -> LLIR producer -> follow aliases -> data node
|
||||
let producer = self.find_producer_node(hlir_id);
|
||||
let data_node = self.follow_aliases(producer);
|
||||
|
||||
// If data_node is an HLIR input (aliased output), skip — can't substitute
|
||||
if self.compiled_buckets[self.active_bucket]
|
||||
.llir_to_hlir
|
||||
.contains_key(&data_node)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Create non-owning CudaSlice view of PyTorch's buffer
|
||||
let slice = unsafe {
|
||||
self.cuda_stream
|
||||
.upgrade_device_ptr::<u8>(device_ptr, n_bytes)
|
||||
};
|
||||
|
||||
self.external_output_buffers
|
||||
.insert(data_node, std::mem::ManuallyDrop::new(slice));
|
||||
|
||||
// Update cached_buffer_ptrs so CudaGraphOp picks up the new pointer
|
||||
self.compiled_buckets[self.active_bucket]
|
||||
.cached_buffer_ptrs
|
||||
.insert(data_node, device_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_f32(&self, id: impl ToId) -> Vec<f32> {
|
||||
let bytes = self.get_output_data(id);
|
||||
let bytes = bytes.leak();
|
||||
@@ -790,11 +854,16 @@ impl Runtime for CudaRuntime {
|
||||
compiled_buckets: vec![CompiledBucket::new()],
|
||||
active_bucket: 0,
|
||||
dim_buckets: FxHashMap::default(),
|
||||
persistent_hlir_nodes: FxHashSet::default(),
|
||||
output_ptr_registrations: FxHashMap::default(),
|
||||
external_output_buffers: FxHashMap::default(),
|
||||
external_buffers: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn aggregate_profile_metrics(metrics: &[Self::ProfileMetric]) -> Self::ProfileMetric {
|
||||
metrics.iter().copied().sum()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
fn load_llir(&mut self, llir_graph: &LLIRGraph) {
|
||||
// Sync before clearing old data to ensure all operations complete
|
||||
@@ -827,15 +896,13 @@ impl Runtime for CudaRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
fn allocate_dummy_input(&mut self, node_index: usize, num_elements: usize) {
|
||||
// Use small non-zero values (ones) instead of zeros so that NaN-producing
|
||||
// graph variants are detected during profiling. Zero inputs often hide
|
||||
// numerical issues that appear with real data.
|
||||
let host_data = vec![1.0f32; num_elements];
|
||||
let buf = self
|
||||
.cuda_stream
|
||||
.clone_htod(bytemuck::cast_slice::<f32, u8>(&host_data))
|
||||
.unwrap();
|
||||
fn allocate_dummy_input(&mut self, node_index: usize, num_bytes: usize) {
|
||||
// Boundary scratch buffers are sized in raw bytes and may represent
|
||||
// non-float tensors such as gather/scatter indices. Initialize with zero
|
||||
// bytes so integer boundaries stay in-range and the raw allocation size
|
||||
// matches the requested tensor storage.
|
||||
let host_data = vec![0u8; num_bytes];
|
||||
let buf = self.cuda_stream.clone_htod(&host_data).unwrap();
|
||||
let id = NodeIndex::new(node_index);
|
||||
self.hlir_buffers.insert(id, CudaInput::Buffer(buf));
|
||||
self.changed_hlir.insert(id);
|
||||
@@ -1013,6 +1080,9 @@ impl Runtime for CudaRuntime {
|
||||
// Ensure all CUDA graphs are built (handles first execute and any missing graphs)
|
||||
self.prebuild_graphs(dyn_map);
|
||||
|
||||
// Resolve external output pointer registrations (zero-copy output path)
|
||||
self.apply_output_ptr_registrations();
|
||||
|
||||
let total_start = std::time::Instant::now();
|
||||
let bucket = &self.compiled_buckets[self.active_bucket];
|
||||
|
||||
@@ -1022,8 +1092,11 @@ impl Runtime for CudaRuntime {
|
||||
|
||||
// Build buffer map for the HostOp interface
|
||||
let mut buffer_map: FxHashMap<NodeIndex, &CudaSlice<u8>> = FxHashMap::default();
|
||||
// Add output buffer
|
||||
if let Some(buf) = bucket.buffers.get(&exec_op.output) {
|
||||
|
||||
// Add output buffer -- prefer external output pointer if registered (zero copy)
|
||||
if let Some(ext) = self.external_output_buffers.get(&exec_op.output) {
|
||||
buffer_map.insert(exec_op.output, &**ext);
|
||||
} else if let Some(buf) = bucket.buffers.get(&exec_op.output) {
|
||||
buffer_map.insert(exec_op.output, buf);
|
||||
}
|
||||
// Add input buffers (prefer HLIR weight buffers over intermediate placeholders)
|
||||
@@ -1053,7 +1126,9 @@ impl Runtime for CudaRuntime {
|
||||
let extra_nodes = exec_op.internal.extra_buffer_nodes();
|
||||
for extra_node in extra_nodes {
|
||||
if let Entry::Vacant(e) = buffer_map.entry(extra_node) {
|
||||
if let Some(buf) = bucket.buffers.get(&extra_node) {
|
||||
if let Some(ext) = self.external_output_buffers.get(&extra_node) {
|
||||
e.insert(&**ext);
|
||||
} else if let Some(buf) = bucket.buffers.get(&extra_node) {
|
||||
e.insert(buf);
|
||||
} else if let Some(hlir_node) = bucket.llir_to_hlir.get(&extra_node) {
|
||||
match self.hlir_buffers.get(hlir_node) {
|
||||
@@ -1138,11 +1213,6 @@ impl Runtime for CudaRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
// Final sync to ensure all operations completed successfully
|
||||
self.cuda_stream
|
||||
.synchronize()
|
||||
.expect("Final sync failed in execute");
|
||||
|
||||
// Consume input buffers
|
||||
if self.profiling {
|
||||
return;
|
||||
@@ -1190,7 +1260,6 @@ impl Runtime for CudaRuntime {
|
||||
.hlir_buffers
|
||||
.keys()
|
||||
.filter(|hlir_node| !inputs_with_outputs.contains(hlir_node))
|
||||
.filter(|hlir_node| !self.persistent_hlir_nodes.contains(hlir_node))
|
||||
.copied()
|
||||
.collect();
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ fn test_bucket_dispatch_simple() {
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 5, &mut rng);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
|
||||
// Test bucket 1: s=1
|
||||
cx.set_dim('s', 1);
|
||||
@@ -85,7 +85,7 @@ fn test_bucket_matmul_dynamic() {
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 5, &mut rng);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
|
||||
// Execute at s=1
|
||||
cx.set_dim('s', 1);
|
||||
@@ -140,7 +140,7 @@ fn test_bucket_results_match_unbucketed() {
|
||||
let input_data = random_f32_vec(12, seed, -1.0, 1.0);
|
||||
rt1.set_data(a1, input_data.clone());
|
||||
let mut rng1 = SmallRng::seed_from_u64(seed);
|
||||
rt1 = cx1.search_rng(rt1, 5, &mut rng1);
|
||||
rt1 = cx1.search_options(rt1, SearchOptions::new(5), &mut rng1);
|
||||
rt1.set_data(a1, input_data.clone());
|
||||
rt1.execute(&cx1.dyn_map);
|
||||
let result_unbucketed = rt1.get_f32(b1);
|
||||
@@ -153,7 +153,7 @@ fn test_bucket_results_match_unbucketed() {
|
||||
let mut rt2 = CudaRuntime::initialize(stream.clone());
|
||||
rt2.set_data(a2, input_data.clone());
|
||||
let mut rng2 = SmallRng::seed_from_u64(seed);
|
||||
rt2 = cx2.search_rng(rt2, 5, &mut rng2);
|
||||
rt2 = cx2.search_options(rt2, SearchOptions::new(5), &mut rng2);
|
||||
rt2.set_data(a2, input_data.clone());
|
||||
rt2.execute(&cx2.dyn_map);
|
||||
let result_bucketed = rt2.get_f32(b2);
|
||||
@@ -179,7 +179,7 @@ fn test_bucket_out_of_range_panics() {
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 3, &mut rng);
|
||||
rt = cx.search_options(rt, SearchOptions::new(3), &mut rng);
|
||||
|
||||
// s=10 is outside all buckets — should panic
|
||||
cx.set_dim('s', 10);
|
||||
@@ -204,7 +204,7 @@ fn test_bucket_no_buckets_backward_compat() {
|
||||
let input_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
||||
rt.set_data(a, input_data.clone());
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 3, &mut rng);
|
||||
rt = cx.search_options(rt, SearchOptions::new(3), &mut rng);
|
||||
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -249,7 +249,7 @@ fn test_bucket_switch_preserves_weights() {
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 5, &mut rng);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
|
||||
// Execute with bucket 1 (s=1)
|
||||
cx.set_dim('s', 1);
|
||||
@@ -305,7 +305,7 @@ fn test_bucket_multiple_executions_same_bucket() {
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 3, &mut rng);
|
||||
rt = cx.search_options(rt, SearchOptions::new(3), &mut rng);
|
||||
|
||||
// Execute at different sizes within the same bucket
|
||||
for s in [1, 2, 4, 8] {
|
||||
|
||||
@@ -301,9 +301,8 @@ fn test_scatter_kv_cache_roundtrip() {
|
||||
}
|
||||
|
||||
/// Test scatter with TWO cache buffers and dual outputs (closer to llama K+V pattern).
|
||||
/// Also verifies graph_break interaction.
|
||||
#[test]
|
||||
fn test_scatter_dual_cache_with_graph_break() {
|
||||
fn test_scatter_dual_cache() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
@@ -348,7 +347,7 @@ fn test_scatter_dual_cache_with_graph_break() {
|
||||
// Use seeded search for deterministic scatter variant selection.
|
||||
// Seed 0 reliably selects Scatter (not ScatterNoCopy) for both caches.
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_rng(rt, 5, &mut rng);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
|
||||
// Print selected variants
|
||||
for node in rt.llir_graph().node_weights() {
|
||||
|
||||
318
crates/luminal_cuda_lite/src/tests/fusion.rs
Normal file
318
crates/luminal_cuda_lite/src/tests/fusion.rs
Normal file
@@ -0,0 +1,318 @@
|
||||
use as_any::Downcast;
|
||||
use luminal::egglog_utils::{egglog_to_llir, random_initial_choice};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::kernel::KernelOp;
|
||||
use crate::kernel::other_ops::{KernelFusedElementwise, UnaryFn};
|
||||
use crate::runtime::CudaRuntime;
|
||||
use crate::tests::utilities::{random_f32_vec, test_unary_cuda};
|
||||
|
||||
/// Return every distinct kernel_name that appears across many random extractions
|
||||
/// of the search space. Used to check whether fusion produces a reachable
|
||||
/// `KernelFusedElementwise` node (or, negatively, that it never does).
|
||||
fn extract_all_kernel_names(cx: &mut Graph) -> Vec<String> {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let egraph = cx.egraph().expect("egraph not built");
|
||||
let ops = cx.egglog_ops().expect("ops not built");
|
||||
let custom_ops = &cx.custom_ops;
|
||||
|
||||
let mut all_names = Vec::new();
|
||||
for _ in 0..50 {
|
||||
let choices = random_initial_choice(egraph, &mut rand::rng());
|
||||
let mut list_cache = Default::default();
|
||||
let mut expr_cache = Default::default();
|
||||
let llir = egglog_to_llir(
|
||||
egraph,
|
||||
choices,
|
||||
ops,
|
||||
custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
for op in llir.node_weights() {
|
||||
if let Some(k) = op.to_dialect::<dyn KernelOp>() {
|
||||
let name = k.kernel_name().to_string();
|
||||
if !all_names.contains(&name) {
|
||||
all_names.push(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
all_names
|
||||
}
|
||||
|
||||
/// Return every distinct `Vec<UnaryFn>` that appears inside a reachable
|
||||
/// `KernelFusedElementwise` across many random extractions. Used to verify
|
||||
/// that a specific fused configuration (e.g. a 3-op chain) is reachable.
|
||||
fn extract_all_fused_configs(cx: &mut Graph) -> Vec<Vec<UnaryFn>> {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let egraph = cx.egraph().expect("egraph not built");
|
||||
let ops = cx.egglog_ops().expect("ops not built");
|
||||
let custom_ops = &cx.custom_ops;
|
||||
|
||||
let mut all_configs: Vec<Vec<UnaryFn>> = Vec::new();
|
||||
for _ in 0..200 {
|
||||
let choices = random_initial_choice(egraph, &mut rand::rng());
|
||||
let mut list_cache = Default::default();
|
||||
let mut expr_cache = Default::default();
|
||||
let llir = egglog_to_llir(
|
||||
egraph,
|
||||
choices,
|
||||
ops,
|
||||
custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
for op in llir.node_weights() {
|
||||
if let Some(kop) = op.to_dialect::<dyn KernelOp>()
|
||||
&& let Some(fused) = (***kop).downcast_ref::<KernelFusedElementwise>()
|
||||
{
|
||||
let cfg = fused.ops().to_vec();
|
||||
if !all_configs.contains(&cfg) {
|
||||
all_configs.push(cfg);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
all_configs
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_two_unary_ops_fuse() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(8);
|
||||
let _b = a.sin().sqrt().output();
|
||||
|
||||
let names = extract_all_kernel_names(&mut cx);
|
||||
assert!(
|
||||
names.iter().any(|n| n == "FusedElementwise"),
|
||||
"expected KernelSin→KernelSqrt on contiguous strides to be fusable into \
|
||||
a single FusedElementwise kernel, but reachable kernels were: {names:?}",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stride_mismatch_prevents_fusion() {
|
||||
// A permute between sin and sqrt gives sqrt a non-contiguous view of sin's
|
||||
// contiguous output, so sqrt's in_strides != its out_strides and the
|
||||
// non-linear `?strides` match in the fusion rule can't fire.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((3, 4));
|
||||
let _b = a.sin().permute((1, 0)).sqrt().output();
|
||||
|
||||
let names = extract_all_kernel_names(&mut cx);
|
||||
assert!(
|
||||
!names.iter().any(|n| n == "FusedElementwise"),
|
||||
"a permute between sin and sqrt must prevent fusion, but \
|
||||
FusedElementwise appeared in reachable kernels: {names:?}",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reduction_prevents_unary_fusion() {
|
||||
// A reduction between two unaries is not elementwise, so the fusion rule
|
||||
// (which only matches unary+unary pairs) must not fire.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((4, 4));
|
||||
let _b = a.sin().sum(1).sqrt().output();
|
||||
|
||||
let names = extract_all_kernel_names(&mut cx);
|
||||
assert!(
|
||||
!names.iter().any(|n| n == "FusedElementwise"),
|
||||
"a reduction between sin and sqrt must prevent fusion, but \
|
||||
FusedElementwise appeared in reachable kernels: {names:?}",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unary_fusion_preserves_output() {
|
||||
// End-to-end numerical check: sqrt(sin(x)) must produce the same values
|
||||
// whether or not the fusion rule fired. Runs on GPU when available;
|
||||
// silently no-ops otherwise via get_cuda_stream().
|
||||
let seed = 0xC0FFEEu64;
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, 0.0, 1.0);
|
||||
test_unary_cuda::<f32>(
|
||||
8,
|
||||
|a| a.sin().sqrt(),
|
||||
|a| a.sin().unwrap().sqrt().unwrap(),
|
||||
gen_lambda,
|
||||
seed,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_three_unary_ops_fuse() {
|
||||
// A chain of 3 pure-elementwise unaries with matching strides should be
|
||||
// reachable as a single FusedElementwise containing all three ops.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(16);
|
||||
let _b = a.sin().sqrt().exp2().output();
|
||||
|
||||
let configs = extract_all_fused_configs(&mut cx);
|
||||
let expected = vec![UnaryFn::Sin, UnaryFn::Sqrt, UnaryFn::Exp2];
|
||||
assert!(
|
||||
configs.contains(&expected),
|
||||
"expected a Fused[Sin, Sqrt, Exp2] in reachable configs, got: {configs:?}",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_four_unary_ops_fuse() {
|
||||
// 4-op chain should collapse into a single Fused containing all four ops.
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(16);
|
||||
let _b = a.sin().sqrt().exp2().log2().output();
|
||||
|
||||
let configs = extract_all_fused_configs(&mut cx);
|
||||
let expected = vec![UnaryFn::Sin, UnaryFn::Sqrt, UnaryFn::Exp2, UnaryFn::Log2];
|
||||
assert!(
|
||||
configs.contains(&expected),
|
||||
"expected a Fused[Sin, Sqrt, Exp2, Log2] in reachable configs, got: {configs:?}",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_three_unary_chain_preserves_output() {
|
||||
// End-to-end numerical check for a 3-op chain.
|
||||
// Uses sin→sqrt→sin because candle lacks exp2/log2 and this still exercises
|
||||
// a 3-link chain. The structural tests above cover the distinct-ops shape.
|
||||
let seed = 0xBEEFu64;
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, 0.0, 1.0);
|
||||
test_unary_cuda::<f32>(
|
||||
16,
|
||||
|a| a.sin().sqrt().sin(),
|
||||
|a| a.sin().unwrap().sqrt().unwrap().sin().unwrap(),
|
||||
gen_lambda,
|
||||
seed,
|
||||
);
|
||||
}
|
||||
|
||||
/// Isolated per-kernel microbenchmark: time two unfused kernels
|
||||
/// (`sqrt_k` then `recip_k`) vs one fused kernel (`fused_k` that does
|
||||
/// `1.0f / sqrtf(x)` in a single launch) on a fixed-size input, using
|
||||
/// CUDA events for device-side timing.
|
||||
///
|
||||
/// Ignored by default — run with
|
||||
/// `cargo test -p luminal_cuda_lite -- --ignored bench_fused_vs_unfused_sqrt_recip --nocapture`.
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn bench_fused_vs_unfused_sqrt_recip() {
|
||||
use crate::compile_module_image_for_current_device;
|
||||
use cudarc::driver::{CudaContext, LaunchConfig, PushKernelArg};
|
||||
|
||||
const N: usize = 1 << 20; // 1M elements
|
||||
const WARMUP: usize = 100;
|
||||
const TRIALS: usize = 2000;
|
||||
|
||||
let ctx = match CudaContext::new(0) {
|
||||
Ok(c) => c,
|
||||
Err(_) => return, // no GPU available, skip
|
||||
};
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
// Prepare input (values in (0, 1] so sqrt/recip are well-defined).
|
||||
let host_input: Vec<f32> = (0..N).map(|i| (i as f32 + 1.0) / (N as f32)).collect();
|
||||
let d_in = stream.clone_htod(&host_input).unwrap();
|
||||
let mut d_scratch = stream.alloc_zeros::<f32>(N).unwrap();
|
||||
let mut d_out = stream.alloc_zeros::<f32>(N).unwrap();
|
||||
|
||||
let compile = |src: &str, name: &str| {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), src).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
module.load_function(name).unwrap()
|
||||
};
|
||||
|
||||
let sqrt_k = compile(
|
||||
r#"
|
||||
extern "C" __global__ void sqrt_k(float* out, const float* in, long long n) {
|
||||
long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= n) return;
|
||||
out[i] = sqrtf(in[i]);
|
||||
}
|
||||
"#,
|
||||
"sqrt_k",
|
||||
);
|
||||
let recip_k = compile(
|
||||
r#"
|
||||
extern "C" __global__ void recip_k(float* out, const float* in, long long n) {
|
||||
long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= n) return;
|
||||
out[i] = 1.0f / in[i];
|
||||
}
|
||||
"#,
|
||||
"recip_k",
|
||||
);
|
||||
let fused_k = compile(
|
||||
r#"
|
||||
extern "C" __global__ void fused_k(float* out, const float* in, long long n) {
|
||||
long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= n) return;
|
||||
float v = in[i];
|
||||
v = sqrtf(v);
|
||||
v = 1.0f / v;
|
||||
out[i] = v;
|
||||
}
|
||||
"#,
|
||||
"fused_k",
|
||||
);
|
||||
|
||||
let cfg = LaunchConfig::for_num_elems(N as u32);
|
||||
let n_arg: i64 = N as i64;
|
||||
|
||||
let launch_unfused = |d_out: &mut cudarc::driver::CudaSlice<f32>,
|
||||
d_scratch: &mut cudarc::driver::CudaSlice<f32>| {
|
||||
let mut b = stream.launch_builder(&sqrt_k);
|
||||
b.arg(&mut *d_scratch).arg(&d_in).arg(&n_arg);
|
||||
unsafe { b.launch(cfg) }.unwrap();
|
||||
let mut b = stream.launch_builder(&recip_k);
|
||||
b.arg(d_out).arg(&*d_scratch).arg(&n_arg);
|
||||
unsafe { b.launch(cfg) }.unwrap();
|
||||
};
|
||||
let launch_fused = |d_out: &mut cudarc::driver::CudaSlice<f32>| {
|
||||
let mut b = stream.launch_builder(&fused_k);
|
||||
b.arg(d_out).arg(&d_in).arg(&n_arg);
|
||||
unsafe { b.launch(cfg) }.unwrap();
|
||||
};
|
||||
|
||||
// Warmup
|
||||
for _ in 0..WARMUP {
|
||||
launch_unfused(&mut d_out, &mut d_scratch);
|
||||
launch_fused(&mut d_out);
|
||||
}
|
||||
stream.synchronize().unwrap();
|
||||
|
||||
let start = ctx.new_event(None).unwrap();
|
||||
let end = ctx.new_event(None).unwrap();
|
||||
|
||||
// Time unfused
|
||||
start.record(&stream).unwrap();
|
||||
for _ in 0..TRIALS {
|
||||
launch_unfused(&mut d_out, &mut d_scratch);
|
||||
}
|
||||
end.record(&stream).unwrap();
|
||||
end.synchronize().unwrap();
|
||||
let unfused_total_ms = start.elapsed_ms(&end).unwrap();
|
||||
|
||||
// Time fused
|
||||
start.record(&stream).unwrap();
|
||||
for _ in 0..TRIALS {
|
||||
launch_fused(&mut d_out);
|
||||
}
|
||||
end.record(&stream).unwrap();
|
||||
end.synchronize().unwrap();
|
||||
let fused_total_ms = start.elapsed_ms(&end).unwrap();
|
||||
|
||||
let unfused_us = unfused_total_ms as f64 * 1_000.0 / TRIALS as f64;
|
||||
let fused_us = fused_total_ms as f64 * 1_000.0 / TRIALS as f64;
|
||||
let speedup = unfused_us / fused_us;
|
||||
|
||||
println!(
|
||||
"\n[fusion microbench, N={N}, trials={TRIALS}]\n\
|
||||
unfused (sqrt_k; recip_k): {unfused_us:8.3} us/iter ({unfused_total_ms:.2} ms total)\n\
|
||||
fused (sqrtf; 1.0f/): {fused_us:8.3} us/iter ({fused_total_ms:.2} ms total)\n\
|
||||
speedup: {speedup:.2}x"
|
||||
);
|
||||
}
|
||||
@@ -5,10 +5,14 @@ mod bucket_tests;
|
||||
#[cfg(test)]
|
||||
mod consumed_buffer_tests;
|
||||
#[cfg(test)]
|
||||
mod fusion;
|
||||
#[cfg(test)]
|
||||
mod model_fuzz;
|
||||
#[cfg(test)]
|
||||
mod op_functional_tests;
|
||||
#[cfg(test)]
|
||||
mod performance_tests;
|
||||
#[cfg(test)]
|
||||
mod qwen3_moe_rewrite;
|
||||
#[cfg(test)]
|
||||
mod transformer;
|
||||
|
||||
314
crates/luminal_cuda_lite/src/tests/qwen3_moe_rewrite.rs
Normal file
314
crates/luminal_cuda_lite/src/tests/qwen3_moe_rewrite.rs
Normal file
@@ -0,0 +1,314 @@
|
||||
use half::bf16;
|
||||
use luminal::{dtype::DType, prelude::*, shape::Expression};
|
||||
|
||||
use super::utilities::{assert_close, get_cuda_stream, random_f32_vec};
|
||||
use crate::{
|
||||
host::{
|
||||
HostOp,
|
||||
moe::{GLUMoE, GLUMoEMode},
|
||||
},
|
||||
runtime::CudaRuntime,
|
||||
};
|
||||
|
||||
const SEQ: usize = 2;
|
||||
const HIDDEN: usize = 16;
|
||||
const NUM_EXPERTS: usize = 8;
|
||||
const TOP_K: usize = 2;
|
||||
const MOE_INTERMEDIATE: usize = 6;
|
||||
const RMS_NORM_EPS: f32 = 1e-6;
|
||||
|
||||
struct QwenMoeGraph {
|
||||
graph: Graph,
|
||||
x: GraphTensor,
|
||||
router: GraphTensor,
|
||||
gate_up_weights: GraphTensor,
|
||||
down_weights: GraphTensor,
|
||||
output: GraphTensor,
|
||||
}
|
||||
|
||||
struct GemmaMoeGraph {
|
||||
graph: Graph,
|
||||
router_input: GraphTensor,
|
||||
expert_input: GraphTensor,
|
||||
router_scale: GraphTensor,
|
||||
router_proj: GraphTensor,
|
||||
per_expert_scale: GraphTensor,
|
||||
gate_up_weights: GraphTensor,
|
||||
down_weights: GraphTensor,
|
||||
output: GraphTensor,
|
||||
}
|
||||
|
||||
fn build_qwen_moe_graph() -> QwenMoeGraph {
|
||||
let mut cx = Graph::default();
|
||||
let x = cx.tensor(('s', HIDDEN));
|
||||
let router = cx.tensor((NUM_EXPERTS, HIDDEN));
|
||||
let gate_up_weights = cx
|
||||
.tensor((NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = x.dims().len();
|
||||
let e_dim = *router.dims().first().unwrap();
|
||||
let k_expr = Expression::from(TOP_K);
|
||||
|
||||
let routing_weights = x.matmul(router.t()).softmax(n - 1);
|
||||
let top_k_indices = routing_weights.topk_indexes(TOP_K, n - 1);
|
||||
|
||||
let row_offsets = x
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
|
||||
let gate_up_gathered = gather_experts(x, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = x.expand_dim(n - 1, TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gate.silu() * up;
|
||||
|
||||
let down_gathered = gather_experts(x, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let output = (down_out * top_k_values.unsqueeze(top_k_values.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
|
||||
QwenMoeGraph {
|
||||
graph: cx,
|
||||
x,
|
||||
router,
|
||||
gate_up_weights,
|
||||
down_weights,
|
||||
output,
|
||||
}
|
||||
}
|
||||
|
||||
fn build_gemma_moe_graph() -> GemmaMoeGraph {
|
||||
let mut cx = Graph::default();
|
||||
let router_input = cx.tensor(('s', HIDDEN));
|
||||
let expert_input = cx.tensor(('s', HIDDEN));
|
||||
let router_scale = cx.tensor(HIDDEN);
|
||||
let router_proj = cx.tensor((NUM_EXPERTS, HIDDEN));
|
||||
let per_expert_scale = cx.tensor(NUM_EXPERTS);
|
||||
let gate_up_weights = cx
|
||||
.tensor((NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = router_input.dims().len();
|
||||
let e_dim = *router_proj.dims().first().unwrap();
|
||||
let k_expr = Expression::from(TOP_K);
|
||||
|
||||
let router_hidden = router_input.std_norm(n - 1, RMS_NORM_EPS)
|
||||
* router_scale.expand_lhs(&router_input.dims()[..n - 1])
|
||||
* (HIDDEN as f32).sqrt().recip();
|
||||
let routing_weights = router_hidden.matmul(router_proj.t()).softmax(n - 1);
|
||||
|
||||
let top_k_indices = routing_weights.topk_indexes(TOP_K, n - 1);
|
||||
let row_offsets = router_input
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
let top_k_norm = top_k_values.sum(n - 1).expand_dim(n - 1, TOP_K);
|
||||
let top_k_weights = (top_k_values / top_k_norm) * per_expert_scale.gather(top_k_indices);
|
||||
|
||||
let gate_up_gathered =
|
||||
gather_experts(expert_input, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = expert_input.expand_dim(n - 1, TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gemma_gelu(gate) * up;
|
||||
|
||||
let down_gathered = gather_experts(expert_input, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let output = (down_out * top_k_weights.unsqueeze(top_k_weights.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
|
||||
GemmaMoeGraph {
|
||||
graph: cx,
|
||||
router_input,
|
||||
expert_input,
|
||||
router_scale,
|
||||
router_proj,
|
||||
per_expert_scale,
|
||||
gate_up_weights,
|
||||
down_weights,
|
||||
output,
|
||||
}
|
||||
}
|
||||
|
||||
fn gather_experts(
|
||||
graph_source: GraphTensor,
|
||||
top_k_indices: GraphTensor,
|
||||
weights: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let (_, d1, d2) = weights.dims3();
|
||||
let io = d1 * d2;
|
||||
let base = top_k_indices * io;
|
||||
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
|
||||
let n_base = base.dims().len();
|
||||
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
|
||||
let mut exp_within = within;
|
||||
for (axis, dim) in base.dims().iter().enumerate() {
|
||||
exp_within = exp_within.expand_dim(axis, *dim);
|
||||
}
|
||||
let expert_flat_idx = exp_base + exp_within;
|
||||
weights.gather(expert_flat_idx)
|
||||
}
|
||||
|
||||
#[allow(clippy::excessive_precision)]
|
||||
fn gemma_gelu(x: GraphTensor) -> GraphTensor {
|
||||
let scaled = 1.5957691216 * x * (1. + 0.044715 * x * x);
|
||||
x * scaled.sigmoid()
|
||||
}
|
||||
|
||||
fn glumoe_modes(rt: &CudaRuntime) -> Vec<GLUMoEMode> {
|
||||
rt.llir_graph()
|
||||
.node_weights()
|
||||
.filter_map(|node| {
|
||||
let op = node.to_dialect::<dyn HostOp>()?;
|
||||
op.as_any()
|
||||
.downcast_ref::<GLUMoE>()
|
||||
.map(|glumoe| glumoe.mode)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn run_qwen_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return (vec![], vec![]);
|
||||
};
|
||||
|
||||
let mut model = build_qwen_moe_graph();
|
||||
model.graph.set_dim('s', SEQ);
|
||||
if use_glumoe {
|
||||
model.graph.build_search_space::<CudaRuntime>();
|
||||
} else {
|
||||
model
|
||||
.graph
|
||||
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>();
|
||||
}
|
||||
|
||||
let x_data = random_f32_vec(SEQ * HIDDEN, 11, -0.15, 0.15);
|
||||
let router_data = random_f32_vec(NUM_EXPERTS * HIDDEN, 12, -0.2, 0.2);
|
||||
let gate_up_data = random_f32_vec(NUM_EXPERTS * MOE_INTERMEDIATE * 2 * HIDDEN, 13, -0.1, 0.1)
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect::<Vec<_>>();
|
||||
let down_data = random_f32_vec(NUM_EXPERTS * HIDDEN * MOE_INTERMEDIATE, 14, -0.1, 0.1)
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(model.x, x_data);
|
||||
rt.set_data(model.router, router_data);
|
||||
rt.set_data(model.gate_up_weights, gate_up_data);
|
||||
rt.set_data(model.down_weights, down_data);
|
||||
rt = model.graph.search(rt, 10);
|
||||
rt.execute(&model.graph.dyn_map);
|
||||
|
||||
(rt.get_f32(model.output.id), glumoe_modes(&rt))
|
||||
}
|
||||
|
||||
fn run_gemma_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return (vec![], vec![]);
|
||||
};
|
||||
|
||||
let mut model = build_gemma_moe_graph();
|
||||
model.graph.set_dim('s', SEQ);
|
||||
if use_glumoe {
|
||||
model.graph.build_search_space::<CudaRuntime>();
|
||||
} else {
|
||||
model
|
||||
.graph
|
||||
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>();
|
||||
}
|
||||
|
||||
let router_input_data = random_f32_vec(SEQ * HIDDEN, 21, -0.15, 0.15);
|
||||
let expert_input_data = random_f32_vec(SEQ * HIDDEN, 22, -0.15, 0.15);
|
||||
let router_scale_data = random_f32_vec(HIDDEN, 23, 0.7, 1.3);
|
||||
let router_proj_data = random_f32_vec(NUM_EXPERTS * HIDDEN, 24, -0.2, 0.2);
|
||||
let per_expert_scale_data = random_f32_vec(NUM_EXPERTS, 25, 0.5, 1.5);
|
||||
let gate_up_data = random_f32_vec(NUM_EXPERTS * MOE_INTERMEDIATE * 2 * HIDDEN, 26, -0.1, 0.1)
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect::<Vec<_>>();
|
||||
let down_data = random_f32_vec(NUM_EXPERTS * HIDDEN * MOE_INTERMEDIATE, 27, -0.1, 0.1)
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(model.router_input, router_input_data);
|
||||
rt.set_data(model.expert_input, expert_input_data);
|
||||
rt.set_data(model.router_scale, router_scale_data);
|
||||
rt.set_data(model.router_proj, router_proj_data);
|
||||
rt.set_data(model.per_expert_scale, per_expert_scale_data);
|
||||
rt.set_data(model.gate_up_weights, gate_up_data);
|
||||
rt.set_data(model.down_weights, down_data);
|
||||
rt = model.graph.search(rt, 10);
|
||||
rt.execute(&model.graph.dyn_map);
|
||||
|
||||
(rt.get_f32(model.output.id), glumoe_modes(&rt))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_matches_qwen_swiglu_pattern() {
|
||||
let (_result, modes) = run_qwen_moe(true);
|
||||
if modes.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
assert_eq!(modes, vec![GLUMoEMode::SwiGLU]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_matches_gemma_gelu_pattern() {
|
||||
let (_result, modes) = run_gemma_moe(true);
|
||||
if modes.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
assert_eq!(modes, vec![GLUMoEMode::GemmaGELU]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_swiglu_matches_unfused_output() {
|
||||
let (expected, baseline_modes) = run_qwen_moe(false);
|
||||
if expected.is_empty() {
|
||||
return;
|
||||
}
|
||||
assert!(baseline_modes.is_empty());
|
||||
|
||||
let (actual, fused_modes) = run_qwen_moe(true);
|
||||
assert_eq!(fused_modes, vec![GLUMoEMode::SwiGLU]);
|
||||
assert_close(&actual, &expected, 3e-2, 3e-2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_gemma_gelu_matches_unfused_output() {
|
||||
let (expected, baseline_modes) = run_gemma_moe(false);
|
||||
if expected.is_empty() {
|
||||
return;
|
||||
}
|
||||
assert!(baseline_modes.is_empty());
|
||||
|
||||
let (actual, fused_modes) = run_gemma_moe(true);
|
||||
assert_eq!(fused_modes, vec![GLUMoEMode::GemmaGELU]);
|
||||
assert_close(&actual, &expected, 3e-2, 3e-2);
|
||||
}
|
||||
@@ -300,7 +300,7 @@ fn test_mini_transformer_two_layers() {
|
||||
let input = cx.tensor((SEQ, HIDDEN));
|
||||
let layer1 = MiniTransformerLayer::init(&mut cx);
|
||||
let layer2 = MiniTransformerLayer::init(&mut cx);
|
||||
let x = layer1.forward(input).graph_break();
|
||||
let x = layer1.forward(input);
|
||||
let out = layer2.forward(x).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
@@ -508,3 +508,32 @@ fn test_swiglu_mlp_cuda() {
|
||||
|
||||
assert_close(&result, &expected, 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
/// Body=1, trips=3 chain of scalar Muls plus a residual back to the
|
||||
/// chain's initial value. Auto-rolling sees this as a state-carrying loop
|
||||
/// with state at input position 0; the rolled HLIR must round-trip through
|
||||
/// egglog (rolled body Mul + LoopStart/LoopInput/LoopEnd markers) and
|
||||
/// `unroll_loops_in_llir` must reconstruct the flat 3-mul chain plus
|
||||
/// rewire the residual edge to reference the chain's initial input
|
||||
/// (outside the body) — not a per-iter clone.
|
||||
#[test]
|
||||
fn test_rolled_chained_scalar_muls() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
let mut cx = Graph::default();
|
||||
let x = cx.tensor((1, 4, 32));
|
||||
let chained = ((x * 2.0_f32) * 3.0_f32) * 5.0_f32;
|
||||
let out = (chained + x).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let x_data = random_f32_vec(4 * 32, 101, -0.5, 0.5);
|
||||
rt.set_data(x, x_data.clone());
|
||||
rt = cx.search(rt, 3);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(out);
|
||||
let expected: Vec<f32> = x_data.iter().map(|v| v * 2.0 * 3.0 * 5.0 + v).collect();
|
||||
assert_close(&result, &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
@@ -468,7 +468,7 @@ pub fn fuzz_genomes<T: TestDType>(
|
||||
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
let llir_graph = egglog_to_llir(
|
||||
let mut llir_graph = egglog_to_llir(
|
||||
egraph,
|
||||
genome.clone(),
|
||||
ops,
|
||||
@@ -477,6 +477,12 @@ pub fn fuzz_genomes<T: TestDType>(
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
// Same finalization as `Graph::search` performs on the chosen
|
||||
// best LLIR: collapse the rolled body's loop markers into a
|
||||
// fully-unrolled LLIR. The runtime cannot execute LoopStart /
|
||||
// LoopEnd / LoopInput / LoopOutput markers — they exist only as
|
||||
// a search-time scaffold the auto-roll prepass introduces.
|
||||
unroll_loops_in_llir(&mut llir_graph);
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.load_llir(&llir_graph);
|
||||
|
||||
48
crates/luminal_metal/src/dyn_backend.rs
Normal file
48
crates/luminal_metal/src/dyn_backend.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
//! [`DynBackend`] implementation for the Metal runtime.
|
||||
|
||||
use luminal::dtype::DType;
|
||||
use luminal::dyn_backend::{bytes_to_native_data, compile_backend, BackendCompileArgs, DynBackend};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::runtime::MetalRuntime;
|
||||
|
||||
/// [`DynBackend`] wrapper for [`MetalRuntime`].
|
||||
pub struct MetalDynBackend {
|
||||
pub runtime: MetalRuntime,
|
||||
}
|
||||
|
||||
impl DynBackend for MetalDynBackend {
|
||||
fn name(&self) -> &str {
|
||||
"metal"
|
||||
}
|
||||
|
||||
fn set_data_bytes(&mut self, node: NodeIndex, bytes: Vec<u8>, dtype: DType) {
|
||||
self.runtime
|
||||
.set_data(node, bytes_to_native_data(bytes, dtype));
|
||||
}
|
||||
fn set_data_f32(&mut self, node: NodeIndex, data: Vec<f32>) {
|
||||
self.runtime.set_data(node, data);
|
||||
}
|
||||
fn get_output_f32(&self, node: NodeIndex) -> Vec<f32> {
|
||||
self.runtime.get_f32(node)
|
||||
}
|
||||
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
self.runtime.execute(dyn_map);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn metal_factory(
|
||||
graph: &mut Graph,
|
||||
args: BackendCompileArgs,
|
||||
) -> Result<Box<dyn DynBackend>, String> {
|
||||
compile_backend::<MetalRuntime>(
|
||||
graph,
|
||||
args,
|
||||
|| Ok(MetalRuntime::initialize(())),
|
||||
|rt, node, bytes, dtype| {
|
||||
rt.set_data(node, bytes_to_native_data(bytes, dtype));
|
||||
},
|
||||
None,
|
||||
|rt| Box::new(MetalDynBackend { runtime: rt }),
|
||||
)
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod dyn_backend;
|
||||
pub mod kernel;
|
||||
pub mod runtime;
|
||||
|
||||
|
||||
@@ -234,6 +234,10 @@ impl Runtime for MetalRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
fn aggregate_profile_metrics(metrics: &[Self::ProfileMetric]) -> Self::ProfileMetric {
|
||||
metrics.iter().copied().sum()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
fn load_llir(&mut self, llir_graph: &LLIRGraph) {
|
||||
self.pipelines.clear();
|
||||
|
||||
@@ -24,7 +24,7 @@ consult before writing new egglog rules, CUDA kernels, or optimizer passes.
|
||||
## Testing Best Practices
|
||||
|
||||
### Overview
|
||||
The luminal_python crate provides a bridge between PyTorch models and the luminal library via ONNX. Tests should verify this integration end-to-end by testing the actual user workflow: PyTorch model → torch.compile → luminal backend.
|
||||
The luminal_python crate provides a bridge between PyTorch models and the luminal library via the PT2 Export pipeline. Tests should verify this integration end-to-end by testing the actual user workflow: PyTorch model → torch.compile → luminal backend.
|
||||
|
||||
### Test Pattern (CORRECT)
|
||||
|
||||
@@ -67,11 +67,11 @@ class AddTestModel(torch.nn.Module):
|
||||
|
||||
### What NOT to Do
|
||||
|
||||
**❌ DO NOT create ONNX files directly in tests:**
|
||||
**❌ DO NOT create pt2 files directly in tests:**
|
||||
```python
|
||||
# WRONG - bypasses the PyTorch integration
|
||||
model_path = create_onnx_model(...)
|
||||
graph_result = luminal.process_onnx(model_path, backend='native')
|
||||
model_path = create_pt2_model(...)
|
||||
graph_result = luminal.process_pt(model_path, backend='native')
|
||||
```
|
||||
|
||||
**✓ DO create PyTorch models and use torch.compile:**
|
||||
@@ -83,16 +83,16 @@ model_compiled = torch.compile(model, backend=luminal_backend)
|
||||
|
||||
### Rationale
|
||||
|
||||
- **End-to-end testing**: Tests verify the complete PyTorch → ONNX → luminal pipeline
|
||||
- **End-to-end testing**: Tests verify the complete PyTorch → Pt2 → luminal pipeline
|
||||
- **User-facing API**: Tests use the same API that users will use (torch.compile)
|
||||
- **Correctness**: Comparing compiled vs original PyTorch output ensures correctness
|
||||
- **Maintainability**: Consistent pattern across all tests makes the codebase easier to understand
|
||||
- **Simplicity**: No manual ONNX file creation, no tempfile cleanup, no numpy comparisons
|
||||
- **Simplicity**: No manual Pt2 file creation, no tempfile cleanup, no numpy comparisons
|
||||
|
||||
### Special Cases
|
||||
|
||||
**Testing constants:**
|
||||
Use inline tensor literals in the forward method - PyTorch exports these as ONNX Constant nodes:
|
||||
Use inline tensor literals in the forward method - these are exported as constant tensors:
|
||||
```python
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([1.0, 2.0, 3.0])
|
||||
@@ -100,14 +100,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
```
|
||||
|
||||
**Testing type casts:**
|
||||
Use `.to(dtype)` method - PyTorch exports these as ONNX Cast nodes:
|
||||
Use `.to(dtype)` method - these are exported as type cast operations:
|
||||
```python
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
```
|
||||
|
||||
**Testing complex operations:**
|
||||
Chain operations naturally in PyTorch - ONNX export handles the conversion:
|
||||
Chain operations naturally in PyTorch - the export pipeline handles the conversion:
|
||||
```python
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
transposed = x.transpose(0, 1)
|
||||
|
||||
@@ -756,3 +756,29 @@ identical across all attempts (dtype issue) vs varying (actual numerical issue).
|
||||
3. **Why hard**: Per-operation error was ~1e-7 but compounded over 16 layers × ~25 extra materializations. The egglog `Exp` rewrite depends on exact constant format matching.
|
||||
4. **Fix**: Added `KernelExp` (uses `expf()`), `KernelSigmoid` (uses `1/(1+expf(-x))`), and Kahan summation in SumReduce. Each uses both `kernel_rewrite` and a direct egglog pattern match with range checks (e.g., `(> ?val 1.44) (< ?val 1.45)`) to bypass constant format dependency.
|
||||
5. **Principle**: When decomposed CUDA kernel chains cause precision loss, add fused kernels via `kernel_rewrite`. For robustness, add BOTH the logical-op rewrite path AND a direct HLIR pattern match — the constant format in egglog can be fragile.
|
||||
|
||||
## 2026-04-26 — Loop unroll-union rules silently disabled in full egglog stage
|
||||
|
||||
1. **Symptom**: Python `test_llama_transformer_block` (CUDA backend) produced output ~1e-2 off from PyTorch (atol=1e-4) on the `loop_rolling` branch. All component tests (RMSNorm, attention, SwiGLU, RoPE) passed. The diff pattern was suspicious: row 0 of the (1,4,32) output matched exactly, rows 1–3 differed slightly. Disabling rolling fixed it.
|
||||
2. **Root cause**: The auto-roll prepass folds three sequential scalar muls in PyTorch's `pow(2)` decomposition (`exp2(log2(x) * 0.693 * 2.0 * 1.442)` — the last constant is `log2(e)`). The kernel `direct-exp-fusion` egglog rule rewrites `Mul(?x, log2_e_const) → Exp2(...)` into `KernelExp(?x)` (single `expf()` instead of separate exp2f + multiply by truncated log2(e)). Without rolling, this fusion fires and the float chain stays stable; with rolling the fusion can't see through the `LoopStart`/`LoopEnd` markers, so the chain stays as `KernelMul → KernelExp2`, and the truncated `log2(e)` constant accumulates ~1e-7 error per layer that compounds into ~1e-2 over the full block.
|
||||
|
||||
The unroll-union rules I'd added (`Mul`/`Add`/etc. binary-op rules that union a rolled body with its fully-unrolled equivalent) were registered only in `EgglogOp::early_rewrites()`, not `rewrites()`. The egglog driver feeds `early_rewrites` only into the early-stage program and `rewrites` only into the full-stage program. So the unrolled chain materialised in the early egraph, the early→full extract picked the (cheaper) rolled form, the unrolled chain was lost, and `direct-exp-fusion` (which runs in the full stage) had nothing to match against.
|
||||
3. **Why hard**: The post-unroll LLIR for the rolled vs un-rolled paths *looked* nearly identical when scanned visually — both had the Log2 → Mul × 3 → Exp2 chain. The diff was 2 extra Muls vs no-rolling, and the actual semantic gap was visible only in op-name counts: WITH-rolling had 3 `KernelExp2` and 0 `KernelExp`, WITHOUT-rolling had 1 `KernelExp2` and 2 `KernelExp`. Tracking the missing fusion to the early/full ruleset split required reading the egglog driver carefully and noticing that `OpTextParts` builds `early_rewrites` and `full_rewrites` from disjoint method calls.
|
||||
4. **Fix**: Register `binary_op_unroll_rules` in BOTH `early_rewrites()` (so fusion patterns like GLUMoE can match before the early-stage extract, which is what fixed `test_glumoe_gemma_gelu_matches_unfused_output` earlier in the session) AND `rewrites()` (so kernel-level rewrites like `direct-exp-fusion` can match in the full stage on the unrolled chain). One block per binary op (`Add`, `Mul`, `Mod`, `LessThan`).
|
||||
5. **Principle**: When egglog has multiple stages (early/full) with disjoint rule sets, any rewrite that materialises new HLIR/IR enodes (rather than just lowering to LLIR) needs to fire in BOTH stages if downstream rewrites in BOTH stages might want to see the new structure. Putting "preparatory" rewrites only in `early_rewrites` means their effect is lost across the early→full handoff. The narrow rule of thumb: if your rule's outputs are intended to enable matches by other rules, audit which stages those other rules run in and register accordingly.
|
||||
|
||||
## 2026-04-26 — `unroll_loops_in_llir` panicked on iteration-invariant body producers
|
||||
|
||||
1. **Symptom**: Modal CI/CD job for the gemma example panicked at `src/graph.rs:1867` with `no entry found for key`. The line is `clone_map[i - 1][&body_producer]` inside `unroll_loops_in_llir`'s `resolve_src` closure — `body_producer` (the LoopEnd's incoming source for that slot) wasn't a key in the per-iteration clone map. cuda_lite/python tests didn't repro: only triggered by the specific genome and graph shapes that gemma's longer search settles on.
|
||||
2. **Root cause**: `body_nodes` is computed by walking *forward* from each LoopStart/LoopInput/LoopInputStatic outgoing edge, stopping at markers and `Output` ops. Some egglog-extracted LLIRs land a `body_producer` that isn't reachable via that forward walk — i.e., its only ancestors are non-marker (a constant, an external input, or an op whose chain was congruence-merged off the marker chain by rules like `LoopInputStatic inline`). Semantically this is a degenerate "iteration-invariant body": every iter computes the same value, so the loop's state never changes. The per-iter clone path needed a fallback for that case.
|
||||
3. **Why hard**: cuda_lite and python tests don't generate genomes that produce this shape, so local runs always pass. The forward-walk-only definition of `body_nodes` is *almost* always right — only specific extraction shapes from longer searches expose the gap. Test-driven debugging has limited reach when the failure mode depends on a search trajectory the local fuzzers don't explore.
|
||||
4. **Fix**: in `unroll_loops_in_llir::resolve_src`, when the LoopStart-resolved `body_producer` isn't in `body_nodes`, return `body_producer` itself for iter > 0 instead of indexing `clone_map[i - 1]`. The body op didn't depend on the loop variable, so every iter > 0 carries the same value forward — using `body_producer` directly is semantically correct. Mirrored the same `unwrap_or(body_producer)` fallback in the post-loop substitution map (`marker_post_sub` for LoopEnd / LoopOutputSelect). Added a backward-walk-from-end-markers backfill in `collapse_loops_to_first_iter` so its body-node iteration also covers these nodes (it doesn't have a clone_map, but does need to rewire body ops' incoming edges before deleting markers).
|
||||
5. **Principle**: When a graph-walk-derived set is used as a hashmap key requirement, every code path that *could* produce a key outside that set needs a graceful fallback — not just a defensive `expect`. For loop unrolling specifically, the rule is: `body_nodes` is the set of "ops that participate in per-iter computation"; ops on the LoopEnd's path that *don't* participate (iteration-invariant) are still legitimate, and need a "no clone, share across iters" path through `resolve_src` and `marker_post_sub`. Forward-walk-only `body_nodes` is correct only when extraction never produces iteration-invariant body producers — and in an egglog-driven search, that's not a guarantee you can make.
|
||||
|
||||
## 2026-04-26 — Iteration-invariant state slots are a first-class concept, not a defensive fallback
|
||||
|
||||
1. **Symptom + fix recap**: gemma Modal CI panicked at `clone_map[i-1][&body_producer]` because some state slots' `body_producer` (LoopEnd's incoming) isn't in `body_nodes` (forward walk from input markers). The first commit pair (16de9638 / 93fb02c4) caught this with `.unwrap_or(body_producer)` — which works but reads as "defensive, unclear *why* this case exists."
|
||||
2. **What's actually happening**: extracted LLIR from gemma legitimately puts a `KernelConstant` at LoopEnd's incoming for some state slots. e.g. for one slot of gemma's body=104 trips=5 rolling: `initial = KernelConstant 1.442695` (log2 e), `body_producer = same node`. For another: `body_producer = KernelConstant 9.21034` (ln 10000, RoPE's frequency base after `Log2 * ln(2)` simplification). egglog's kernel-level rewrites legitimately union body-slot eclasses with these constants when the body chain provably reduces to them. The state really is iteration-invariant — every iter sees the same value.
|
||||
3. **Why "defensive fallback" framing is misleading**: it implies the LLIR is broken. It isn't. The forward-walk-only `body_nodes` definition just doesn't cover this case, because the case requires no per-iter cloning at all. A *node not reachable from any loop input marker has no input-marker ancestor*, so by construction its value doesn't depend on the loop's per-iter state.
|
||||
4. **Cleaner formulation**: name the concept. Compute an `iteration_invariant_slots: HashSet<LoopStart>` set at the same time `start_meta` is built, with the rule `body_producer ∉ body_nodes ⇒ iteration_invariant`. `resolve_src` and `marker_post_sub` then have explicit branches: if the slot is invariant, use `body_producer` directly; otherwise the standard per-iter clone lookup. The behavior is the same as the `unwrap_or` band-aid, but the code now documents that this is a real, sound case the unroll handles correctly — not a panic suppressor.
|
||||
5. **Principle**: when an `unwrap_or` papers over a case that turns out to be semantically valid, the right cleanup isn't to keep the `unwrap_or` and add a comment — it's to name the case. Hoist the predicate into a set or enum and branch on it explicitly. The compiler then enforces that every consumer of the per-iter cloning machinery has an opinion on iteration-invariant slots, instead of silently relying on a `Map::get` returning `None` at the right moment.
|
||||
|
||||
@@ -186,7 +186,7 @@ class TestRunner:
|
||||
env = os.environ.copy()
|
||||
existing = env.get("PYTHONPATH")
|
||||
env["PYTHONPATH"] = f"{SRC_PATH}:{existing}" if existing else SRC_PATH
|
||||
env["LUMINAL_BACKEND"] = "cuda"
|
||||
env["LUMINAL_TEST_DEVICE"] = "cuda"
|
||||
env["UV_PROJECT_ENVIRONMENT"] = VENV_PATH
|
||||
env["MATURIN_PEP517_ARGS"] = "--features cuda --profile release"
|
||||
env["CUDARC_CUDA_VERSION"] = CUDARC_CUDA_VERSION
|
||||
|
||||
@@ -7,8 +7,6 @@ requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"numpy>=2.0.2",
|
||||
"torch>=2.10.0",
|
||||
"onnx",
|
||||
"onnxscript",
|
||||
"safetensors",
|
||||
]
|
||||
|
||||
@@ -47,6 +45,5 @@ dev = [
|
||||
"pytest-randomly>=4.0.1",
|
||||
"transformers>=4.40.0",
|
||||
"diffusers>=0.35.0",
|
||||
"onnxsim",
|
||||
"modal>=1.3.5",
|
||||
]
|
||||
|
||||
@@ -16,13 +16,9 @@ rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
|
||||
echo ""
|
||||
echo "--- 1a: Native + ONNX ---"
|
||||
echo "--- 1a: Native backend tests ---"
|
||||
uv run pytest $NATIVE_TESTS -v
|
||||
|
||||
echo ""
|
||||
echo "--- 1b: Native + PT2 ---"
|
||||
LUMINAL_EXPORT_MODE=pt2 uv run pytest $NATIVE_TESTS -v
|
||||
|
||||
# ── Phase 2: CUDA Backend ───────────────────────────────────
|
||||
|
||||
echo ""
|
||||
@@ -31,12 +27,8 @@ rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
echo ""
|
||||
echo "--- 2a: CUDA + ONNX ---"
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest $CUDA_TESTS -m "not slow" -v
|
||||
|
||||
echo ""
|
||||
echo "--- 2b: CUDA + PT2 ---"
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda LUMINAL_EXPORT_MODE=pt2 uv run pytest $CUDA_TESTS -m "not slow" -v
|
||||
echo "--- 2a: CUDA ---"
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest $CUDA_TESTS -m "not slow" -v
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
echo "=== Luminal Python Test Runner (PT2 Export Mode) ==="
|
||||
echo ""
|
||||
|
||||
# Force clean rebuild of Rust extension
|
||||
echo "Step 1: Cleaning previous builds..."
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
|
||||
# Rebuild in development mode (faster compilation)
|
||||
echo "Step 2: Building Rust extension..."
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
|
||||
# Run pytest with PT2 export mode
|
||||
echo "Step 3: Running pytest with PT2 export mode..."
|
||||
LUMINAL_EXPORT_MODE=pt2 uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
@@ -14,7 +14,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
# Run pytest with CUDA backend
|
||||
echo "Step 3: Running pytest with CUDA backend..."
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
echo "=== Luminal Python Test Runner (CUDA + PT2 Export Mode) ==="
|
||||
echo ""
|
||||
|
||||
# Force clean rebuild of Rust extension
|
||||
echo "Step 1: Cleaning previous builds..."
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
|
||||
# Rebuild in development mode (faster compilation)
|
||||
echo "Step 2: Building Rust extension..."
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
# Run pytest with CUDA backend and PT2 export mode
|
||||
echo "Step 3: Running pytest with CUDA backend + PT2 export mode..."
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda LUMINAL_EXPORT_MODE=pt2 uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
@@ -12,8 +12,6 @@ path = "src/lib.rs"
|
||||
cuda = ["dep:luminal_cuda_lite"]
|
||||
|
||||
[dependencies]
|
||||
onnx-protobuf = "0.2"
|
||||
protobuf = "~3.4"
|
||||
rustc-hash = "2.1.1"
|
||||
luminal = {path= "../../.."}
|
||||
luminal_cuda_lite = {path="../../luminal_cuda_lite", optional = true}
|
||||
|
||||
@@ -1,32 +1,51 @@
|
||||
#[cfg(feature = "cuda")]
|
||||
use luminal::prelude::tracing::{trace, warn};
|
||||
use luminal::{prelude::*, shape::Expression, visualization::ToDot};
|
||||
use luminal::{
|
||||
dyn_backend::{BackendCompileArgs, BackendFactory, DynBackend},
|
||||
prelude::*,
|
||||
shape::Expression,
|
||||
visualization::ToDot,
|
||||
};
|
||||
use pyo3::prelude::*;
|
||||
use std::collections::HashMap;
|
||||
#[cfg(feature = "cuda")]
|
||||
use std::collections::HashSet;
|
||||
|
||||
use crate::{runtime::RuntimeBackend, util::DimParamMap};
|
||||
use crate::typed_data::TypedData;
|
||||
|
||||
/// Common intermediate result from translating a model graph (ONNX or FX).
|
||||
/// Maps symbolic dimension parameter names (e.g. "seq_len") to luminal Expression variable chars.
|
||||
pub type DimParamMap = HashMap<String, char>;
|
||||
|
||||
/// Convert luminal DType to PT2 dtype integer code (for python interop)
|
||||
/// Types without a direct Pytorch equivalent map to the closest safe representation
|
||||
fn luminal_dtype_to_pt2_code(dtype: DType) -> u32 {
|
||||
match dtype {
|
||||
DType::U8 => 1,
|
||||
DType::I8 => 2,
|
||||
DType::I16 => 3,
|
||||
DType::Int => 4, // i32
|
||||
DType::U16 => 4, // u16 -> i32 (Pytorch has no u16 in older versions)
|
||||
DType::F16 => 6,
|
||||
DType::F32 | DType::TF32 => 7,
|
||||
DType::F64 => 8,
|
||||
DType::Bool => 12,
|
||||
DType::Bf16 => 13,
|
||||
_ => panic!("luminal_dtype_to_pt2_code: unsupported dtype {:?}", dtype),
|
||||
}
|
||||
}
|
||||
|
||||
/// Common intermediate result from translating a model graph.
|
||||
pub struct GraphTranslation {
|
||||
pub graph: Graph,
|
||||
pub tensor_ids: HashMap<String, NodeIndex>,
|
||||
pub input_names: Vec<String>,
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub output_dtypes: Vec<DType>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
|
||||
/// Pre-loaded weight data from any model format.
|
||||
///
|
||||
/// NOTE: Currently assumes all data is F32. When the type system branch lands
|
||||
/// with proper multi-dtype support, this struct (and all callers) will need
|
||||
/// updating to carry dtype metadata alongside the raw data.
|
||||
/// Pre-loaded weight data from any model format (dtype-aware).
|
||||
pub struct WeightData {
|
||||
/// (Input node label, f32 data) for weights and constants.
|
||||
pub weights: Vec<(String, Vec<f32>)>,
|
||||
/// (Input node label, typed data) for weights and constants.
|
||||
pub weights: Vec<(String, TypedData)>,
|
||||
/// label → element count for ALL Input nodes (for CUDA dummy data sizing).
|
||||
pub tensor_sizes: HashMap<String, usize>,
|
||||
/// label → (device_ptr, n_bytes) for zero-copy CUDA weight sharing.
|
||||
@@ -36,7 +55,7 @@ pub struct WeightData {
|
||||
#[pyclass(unsendable)]
|
||||
pub struct CompiledGraph {
|
||||
pub graph: Graph,
|
||||
pub runtime: RuntimeBackend,
|
||||
pub runtime: Box<dyn DynBackend>,
|
||||
pub tensor_ids: HashMap<String, NodeIndex>,
|
||||
/// Cached label → NodeIndex map for O(1) lookups in set_weight_* methods.
|
||||
label_map: HashMap<String, NodeIndex>,
|
||||
@@ -44,20 +63,21 @@ pub struct CompiledGraph {
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shapes: Vec<Vec<usize>>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub output_dtypes: Vec<DType>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
|
||||
impl CompiledGraph {
|
||||
/// Shared compilation pipeline for both ONNX and FX/PT2 graphs.
|
||||
/// Compilation pipeline for PT2/FX graphs.
|
||||
///
|
||||
/// Takes a format-neutral `GraphTranslation` (produced by `translate_onnx` or
|
||||
/// `translate_pt2`) and `WeightData`, builds the backend, loads weights, and
|
||||
/// Takes a `GraphTranslation` (produced by `translate_pt2`) and `WeightData`,
|
||||
/// builds the backend via the global registry, loads weights, and
|
||||
/// returns a ready-to-execute `CompiledGraph`.
|
||||
pub fn parse_graph(
|
||||
translation: GraphTranslation,
|
||||
weight_data: WeightData,
|
||||
backend: &str,
|
||||
factory: BackendFactory,
|
||||
search_iters: usize,
|
||||
) -> Result<CompiledGraph, String> {
|
||||
let GraphTranslation {
|
||||
@@ -66,49 +86,34 @@ impl CompiledGraph {
|
||||
input_names,
|
||||
output_names,
|
||||
output_shape_exprs,
|
||||
output_dtypes,
|
||||
input_shape_exprs,
|
||||
dim_param_map,
|
||||
} = translation;
|
||||
|
||||
let rt = match backend {
|
||||
#[cfg(feature = "cuda")]
|
||||
"cuda" | "gpu" => {
|
||||
CompiledGraph::build_cuda_backend(&mut graph, &weight_data, search_iters)?
|
||||
}
|
||||
"native" | "cpu" => {
|
||||
CompiledGraph::build_native_backend(&mut graph, &weight_data, search_iters)?
|
||||
}
|
||||
_ => {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
return Err(format!(
|
||||
"Invalid backend '{}'. Must be 'native' or 'cuda'",
|
||||
backend
|
||||
));
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
{
|
||||
if backend == "cuda" {
|
||||
return Err(
|
||||
"CUDA backend requested, but this luminal extension was built without the `cuda` feature. Rebuild with `maturin develop --features cuda -r` or use backend='native'."
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
return Err(format!(
|
||||
"Invalid backend '{}'. This build only supports 'native'. Rebuild with the `cuda` feature to enable 'cuda'.",
|
||||
backend
|
||||
));
|
||||
}
|
||||
}
|
||||
// Build compile args from WeightData (convert TypedData -> raw bytes + dtype)
|
||||
let compile_args = BackendCompileArgs {
|
||||
search_iters,
|
||||
weights: weight_data
|
||||
.weights
|
||||
.iter()
|
||||
.map(|(label, td)| (label.clone(), td.bytes.clone(), td.dtype))
|
||||
.collect(),
|
||||
tensor_sizes: weight_data.tensor_sizes,
|
||||
device_ptrs: weight_data.device_ptrs,
|
||||
};
|
||||
|
||||
// Create backend via the factory directly
|
||||
let rt =
|
||||
luminal::dyn_backend::compile_backend_from_factory(factory, &mut graph, compile_args)?;
|
||||
|
||||
// Resolve concrete output shapes from expressions
|
||||
let output_shapes: Vec<Vec<usize>> = output_shape_exprs
|
||||
.iter()
|
||||
.map(|exprs| exprs.iter().map(|e| e.to_usize().unwrap_or(1)).collect())
|
||||
.collect();
|
||||
|
||||
let label_map = CompiledGraph::build_label_map(&graph);
|
||||
let label_map = luminal::dyn_backend::build_label_map(&graph);
|
||||
|
||||
Ok(CompiledGraph {
|
||||
graph,
|
||||
@@ -119,160 +124,11 @@ impl CompiledGraph {
|
||||
output_names,
|
||||
output_shapes,
|
||||
output_shape_exprs,
|
||||
output_dtypes,
|
||||
input_shape_exprs,
|
||||
dim_param_map,
|
||||
})
|
||||
}
|
||||
|
||||
/// Build a label → NodeIndex map for all Input nodes in the graph.
|
||||
/// Used for efficient weight loading by label matching.
|
||||
fn build_label_map(graph: &Graph) -> HashMap<String, NodeIndex> {
|
||||
graph
|
||||
.graph
|
||||
.node_indices()
|
||||
.filter_map(|node_id| {
|
||||
(*graph.graph[node_id])
|
||||
.as_any()
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
.map(|input| (input.label.clone(), node_id))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn build_cuda_backend(
|
||||
graph: &mut Graph,
|
||||
weight_data: &WeightData,
|
||||
search_iters: usize,
|
||||
) -> Result<RuntimeBackend, String> {
|
||||
let device_ptrs = &weight_data.device_ptrs;
|
||||
use luminal_cuda_lite::cudarc::driver::CudaContext;
|
||||
use luminal_cuda_lite::runtime::CudaRuntime;
|
||||
|
||||
let cuda_ctx = CudaContext::new(0).map_err(|e| format!("CUDA context init failed: {e}"))?;
|
||||
let stream = cuda_ctx.default_stream();
|
||||
|
||||
graph.build_search_space::<CudaRuntime>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
// Build label → NodeIndex map for device pointer matching.
|
||||
let label_map = CompiledGraph::build_label_map(graph);
|
||||
|
||||
// For weights with device pointers: use them directly (zero-copy).
|
||||
// This avoids allocating ~N GB of dummy data during search.
|
||||
// The pointers survive search because profiling mode skips buffer consumption,
|
||||
// and persist_hlir_node ensures they survive post-search execution too.
|
||||
let mut device_ptr_nodes: HashSet<NodeIndex> = HashSet::new();
|
||||
let mut matched_count = 0usize;
|
||||
let mut missed_labels: Vec<String> = Vec::new();
|
||||
for (label, &(ptr, n_bytes)) in device_ptrs {
|
||||
if let Some(&node_id) = label_map.get(label) {
|
||||
unsafe { rt.set_device_ptr(node_id, ptr, n_bytes) };
|
||||
rt.persist_hlir_node(node_id);
|
||||
device_ptr_nodes.insert(node_id);
|
||||
matched_count += 1;
|
||||
} else {
|
||||
missed_labels.push(label.clone());
|
||||
}
|
||||
}
|
||||
let total_device_bytes: usize = device_ptrs.values().map(|(_, n)| *n).sum();
|
||||
trace!(
|
||||
"[CUDA BUILD] Device pointers: {} matched, {} missed out of {} total ({:.3} GiB)",
|
||||
matched_count,
|
||||
missed_labels.len(),
|
||||
device_ptrs.len(),
|
||||
total_device_bytes as f64 / (1024.0 * 1024.0 * 1024.0),
|
||||
);
|
||||
if !missed_labels.is_empty() {
|
||||
warn!(
|
||||
"[CUDA BUILD] {} device-ptr labels did not match any Input node (first 10): {:?}",
|
||||
missed_labels.len(),
|
||||
&missed_labels[..missed_labels.len().min(10)]
|
||||
);
|
||||
let available: Vec<&String> = label_map.keys().take(10).collect();
|
||||
warn!(
|
||||
"[CUDA BUILD] Available label_map keys (first 10): {:?}",
|
||||
available
|
||||
);
|
||||
}
|
||||
|
||||
// Set dummy 1.0 data for remaining Input nodes (user inputs, constants without
|
||||
// device pointers) for safe search profiling.
|
||||
// IMPORTANT: Must use 1.0, NOT 0.0. Zero inputs cause NaN in many ops:
|
||||
// - fmod(0, 0) = NaN (Mod)
|
||||
// - recip(0) = inf → weight * inf = NaN (Div)
|
||||
// - log(0) = -inf (Pow)
|
||||
// - chain ops with zero produce NaN (Erf)
|
||||
let mut dummy_total_elements = 0usize;
|
||||
let mut dummy_count = 0usize;
|
||||
for node_id in graph.graph.node_indices() {
|
||||
if device_ptr_nodes.contains(&node_id) {
|
||||
continue;
|
||||
}
|
||||
if let Some(input) = (*graph.graph[node_id])
|
||||
.as_any()
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
{
|
||||
if let Some(&n) = weight_data.tensor_sizes.get(&input.label) {
|
||||
if n > 0 {
|
||||
dummy_total_elements += n;
|
||||
dummy_count += 1;
|
||||
rt.set_data(node_id, vec![1.0f32; n]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
trace!(
|
||||
"[CUDA BUILD] Dummy data: {} nodes, {} elements ({:.3} GiB as f32)",
|
||||
dummy_count,
|
||||
dummy_total_elements,
|
||||
(dummy_total_elements * 4) as f64 / (1024.0 * 1024.0 * 1024.0),
|
||||
);
|
||||
|
||||
// Search (device-pointer weights are used directly; dummy data for the rest)
|
||||
let mut rt = graph.search(rt, search_iters);
|
||||
|
||||
// Load real weight data for non-device-ptr weights (constants from PT2 archive, etc.)
|
||||
let mut loaded_weight_elements = 0usize;
|
||||
let mut loaded_weight_count = 0usize;
|
||||
for (label, data) in &weight_data.weights {
|
||||
if !device_ptrs.contains_key(label) {
|
||||
if let Some(&node_id) = label_map.get(label) {
|
||||
loaded_weight_elements += data.len();
|
||||
loaded_weight_count += 1;
|
||||
rt.set_data(node_id, data.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
trace!(
|
||||
"[CUDA BUILD] Post-search weight load: {} weights, {} elements ({:.3} GiB as f32)",
|
||||
loaded_weight_count,
|
||||
loaded_weight_elements,
|
||||
(loaded_weight_elements * 4) as f64 / (1024.0 * 1024.0 * 1024.0),
|
||||
);
|
||||
|
||||
Ok(RuntimeBackend::Cuda(Box::new(rt)))
|
||||
}
|
||||
|
||||
fn build_native_backend(
|
||||
graph: &mut Graph,
|
||||
weight_data: &WeightData,
|
||||
search_iters: usize,
|
||||
) -> Result<RuntimeBackend, String> {
|
||||
graph.build_search_space::<NativeRuntime>();
|
||||
let mut rt = graph.search(NativeRuntime::default(), search_iters);
|
||||
|
||||
// Load weight data after search
|
||||
let label_map = CompiledGraph::build_label_map(graph);
|
||||
for (label, data) in &weight_data.weights {
|
||||
if let Some(&node_id) = label_map.get(label) {
|
||||
rt.set_data(node_id, data.clone());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(RuntimeBackend::Native(rt))
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
@@ -283,6 +139,24 @@ impl CompiledGraph {
|
||||
self.input_names.clone()
|
||||
}
|
||||
|
||||
/// Get the PT2 dtype codes for all inputs (in order of input_names).
|
||||
#[getter]
|
||||
fn input_dtypes(&self) -> Vec<u32> {
|
||||
self.input_names
|
||||
.iter()
|
||||
.map(|name| {
|
||||
if let Some(&node_id) = self.tensor_ids.get(name)
|
||||
&& let Some(input) = (*self.graph.graph[node_id])
|
||||
.as_any()
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
{
|
||||
return luminal_dtype_to_pt2_code(input.dtype);
|
||||
}
|
||||
7 // default to f32
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get the list of output tensor names.
|
||||
#[getter]
|
||||
fn output_names(&self) -> Vec<String> {
|
||||
@@ -301,12 +175,24 @@ impl CompiledGraph {
|
||||
self.tensor_ids.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Get the name of the active backend (native or cuda).
|
||||
/// Get the name of the active backend.
|
||||
#[getter]
|
||||
fn backend(&self) -> &'static str {
|
||||
fn backend(&self) -> &str {
|
||||
self.runtime.name()
|
||||
}
|
||||
|
||||
/// The device type this backend operates on (e.g. "cpu", "cuda").
|
||||
#[getter]
|
||||
fn device_type(&self) -> &str {
|
||||
self.runtime.device_type()
|
||||
}
|
||||
|
||||
/// Whether the active backend supports device pointer operations (zero-copy GPU I/O).
|
||||
#[getter]
|
||||
fn supports_device_ptrs(&self) -> bool {
|
||||
self.runtime.supports_device_ptrs()
|
||||
}
|
||||
|
||||
/// Whether this graph has dynamic (symbolic) dimensions.
|
||||
#[getter]
|
||||
fn has_dynamic_dims(&self) -> bool {
|
||||
@@ -371,100 +257,136 @@ impl CompiledGraph {
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Set input tensor data by name.
|
||||
/// Set input tensor data by name (f32, for backward compatibility).
|
||||
fn set_input(&mut self, name: &str, data: Vec<f32>) -> PyResult<()> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
|
||||
})?;
|
||||
self.runtime.set_data(*node_id, data);
|
||||
self.runtime.set_data_f32(*node_id, data);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set input tensor data from a CPU host memory pointer (avoids Python list conversion).
|
||||
/// The pointer must point to contiguous f32 data (from tensor.data_ptr() on a CPU float32 tensor).
|
||||
fn set_input_from_ptr(&mut self, name: &str, ptr: u64, n_elements: usize) -> PyResult<()> {
|
||||
/// Set input tensor data from a CPU host memory pointer (dtype-aware).
|
||||
/// The pointer must point to contiguous data. `n_bytes` is the total byte count.
|
||||
/// `dtype_code` uses PT2 numbering (7=f32, 6=f16, 13=bf16, etc.).
|
||||
/// Converts source format to luminal's native format (e.g., i64→i32, f64→f32).
|
||||
fn set_input_from_ptr(
|
||||
&mut self,
|
||||
name: &str,
|
||||
ptr: u64,
|
||||
n_bytes: usize,
|
||||
dtype_code: u32,
|
||||
) -> PyResult<()> {
|
||||
debug_assert!(ptr != 0, "set_input_from_ptr called with null pointer");
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
|
||||
})?;
|
||||
let data: Vec<f32> =
|
||||
unsafe { std::slice::from_raw_parts(ptr as *const f32, n_elements).to_vec() };
|
||||
self.runtime.set_data(*node_id, data);
|
||||
let raw_bytes = unsafe { std::slice::from_raw_parts(ptr as *const u8, n_bytes).to_vec() };
|
||||
let typed = TypedData::from_pytorch_bytes(raw_bytes, dtype_code);
|
||||
self.runtime
|
||||
.set_data_bytes(*node_id, typed.bytes, typed.dtype);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set input from a CUDA device pointer. Zero-copy on device.
|
||||
/// The pointer must be a valid CUDA device allocation with at least n_bytes bytes.
|
||||
#[cfg(feature = "cuda")]
|
||||
/// Set input from a device pointer. Zero-copy on device.
|
||||
/// The pointer must be a valid device allocation with at least n_bytes bytes.
|
||||
/// Requires a GPU backend (e.g. CUDA).
|
||||
fn set_input_device_ptr(
|
||||
&mut self,
|
||||
name: &str,
|
||||
device_ptr: u64,
|
||||
n_bytes: usize,
|
||||
) -> PyResult<()> {
|
||||
if !self.runtime.supports_device_ptrs() {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"set_input_device_ptr requires a GPU backend",
|
||||
));
|
||||
}
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
|
||||
})?;
|
||||
match &mut self.runtime {
|
||||
RuntimeBackend::Cuda(rt) => unsafe { rt.set_device_ptr(*node_id, device_ptr, n_bytes) },
|
||||
_ => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"set_input_device_ptr requires CUDA backend",
|
||||
));
|
||||
}
|
||||
}
|
||||
unsafe { self.runtime.set_device_ptr(*node_id, device_ptr, n_bytes) };
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Mark an input tensor as persistent (survives execute() calls).
|
||||
/// Call this for weight tensors that should not be consumed after each execution.
|
||||
fn persist_input(&mut self, name: &str) -> PyResult<()> {
|
||||
let _node_id = *self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
|
||||
})?;
|
||||
match &mut self.runtime {
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(rt) => rt.persist_hlir_node(_node_id),
|
||||
RuntimeBackend::Native(_) => {} // Native: persist is handled at graph level
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set a weight tensor from a CUDA device pointer, matching by Input node label.
|
||||
/// Also marks the weight as persistent. For PT2 weights (e.g. "fc1.weight").
|
||||
#[cfg(feature = "cuda")]
|
||||
/// Set a weight from a device pointer (e.g. "fc1.weight"). Zero-copy on device.
|
||||
/// Requires a GPU backend.
|
||||
fn set_weight_device_ptr(
|
||||
&mut self,
|
||||
label: &str,
|
||||
device_ptr: u64,
|
||||
n_bytes: usize,
|
||||
) -> PyResult<()> {
|
||||
if !self.runtime.supports_device_ptrs() {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"set_weight_device_ptr requires a GPU backend",
|
||||
));
|
||||
}
|
||||
let &node_id = self.label_map.get(label).ok_or_else(|| {
|
||||
pyo3::exceptions::PyKeyError::new_err(format!("No Input node with label: {}", label))
|
||||
})?;
|
||||
match &mut self.runtime {
|
||||
RuntimeBackend::Cuda(rt) => {
|
||||
unsafe { rt.set_device_ptr(node_id, device_ptr, n_bytes) };
|
||||
rt.persist_hlir_node(node_id);
|
||||
}
|
||||
_ => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"set_weight_device_ptr requires CUDA backend",
|
||||
));
|
||||
}
|
||||
}
|
||||
unsafe { self.runtime.set_device_ptr(node_id, device_ptr, n_bytes) };
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set a weight tensor from a CPU host pointer, matching by Input node label.
|
||||
fn set_weight_from_ptr(&mut self, label: &str, ptr: u64, n_elements: usize) -> PyResult<()> {
|
||||
/// Register an external device pointer for an output tensor (zero-copy output).
|
||||
/// Call before run() — the runtime will write kernel results directly into this buffer.
|
||||
/// For aliased outputs (in-place ops), falls back to DtoD copy; check output_is_zero_copy() after run().
|
||||
/// Requires a GPU backend.
|
||||
fn set_output_device_ptr(
|
||||
&mut self,
|
||||
name: &str,
|
||||
device_ptr: u64,
|
||||
n_bytes: usize,
|
||||
) -> PyResult<()> {
|
||||
if !self.runtime.supports_device_ptrs() {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"set_output_device_ptr requires a GPU backend",
|
||||
));
|
||||
}
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
unsafe {
|
||||
self.runtime
|
||||
.set_output_device_ptr(*node_id, device_ptr, n_bytes)
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check whether an output tensor was zero-copied (written directly to the registered pointer).
|
||||
/// Returns false for aliased outputs that need a fallback DtoD copy, or if no GPU backend.
|
||||
/// Must be called after run().
|
||||
fn output_is_zero_copy(&self, name: &str) -> PyResult<bool> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
Ok(self.runtime.output_is_zero_copy(*node_id))
|
||||
}
|
||||
|
||||
/// Set a weight tensor from a CPU host pointer, matching by Input node label (dtype-aware).
|
||||
/// `n_bytes` is the total byte count. `dtype_code` uses PT2 numbering (7=f32, 6=f16, 13=bf16, etc.).
|
||||
fn set_weight_from_ptr(
|
||||
&mut self,
|
||||
label: &str,
|
||||
ptr: u64,
|
||||
n_bytes: usize,
|
||||
dtype_code: u32,
|
||||
) -> PyResult<()> {
|
||||
debug_assert!(ptr != 0, "set_weight_from_ptr called with null pointer");
|
||||
let &node_id = self.label_map.get(label).ok_or_else(|| {
|
||||
pyo3::exceptions::PyKeyError::new_err(format!("No Input node with label: {}", label))
|
||||
})?;
|
||||
let data: Vec<f32> =
|
||||
unsafe { std::slice::from_raw_parts(ptr as *const f32, n_elements).to_vec() };
|
||||
self.runtime.set_data(node_id, data);
|
||||
let bytes = unsafe { std::slice::from_raw_parts(ptr as *const u8, n_bytes).to_vec() };
|
||||
let typed = TypedData::from_pytorch_bytes(bytes, dtype_code);
|
||||
self.runtime
|
||||
.set_data_bytes(node_id, typed.bytes, typed.dtype);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -480,7 +402,16 @@ impl CompiledGraph {
|
||||
})
|
||||
}
|
||||
|
||||
/// Get output tensor data by name (copies to host).
|
||||
/// Get the PT2 dtype codes for all outputs (in order).
|
||||
#[getter]
|
||||
fn output_dtypes(&self) -> Vec<u32> {
|
||||
self.output_dtypes
|
||||
.iter()
|
||||
.map(|d| luminal_dtype_to_pt2_code(*d))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get output tensor data by name as f32 (copies to host).
|
||||
fn get_output(&self, name: &str) -> PyResult<Vec<f32>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
@@ -488,27 +419,50 @@ impl CompiledGraph {
|
||||
name
|
||||
))
|
||||
})?;
|
||||
Ok(self.runtime.get_f32(*node_id))
|
||||
Ok(self.runtime.get_output_f32(*node_id))
|
||||
}
|
||||
|
||||
/// Copy output tensor data directly to a CUDA device pointer (DtoD).
|
||||
/// Avoids the DtoH + HtoD round-trip of get_output() + .to(device).
|
||||
#[cfg(feature = "cuda")]
|
||||
fn copy_output_to_device_ptr(&self, name: &str, dest_ptr: u64, n_bytes: usize) -> PyResult<()> {
|
||||
/// Get output tensor data by name as i32 (copies to host).
|
||||
fn get_output_i32(&self, name: &str) -> PyResult<Vec<i32>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
match &self.runtime {
|
||||
RuntimeBackend::Cuda(rt) => {
|
||||
unsafe { rt.copy_output_to_device_ptr(*node_id, dest_ptr, n_bytes) };
|
||||
Ok(())
|
||||
}
|
||||
_ => Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"copy_output_to_device_ptr requires CUDA backend",
|
||||
)),
|
||||
Ok(self.runtime.get_output_i32(*node_id))
|
||||
}
|
||||
|
||||
/// Get output tensor data by name as bool (copies to host).
|
||||
fn get_output_bool(&self, name: &str) -> PyResult<Vec<bool>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
Ok(self.runtime.get_output_bool(*node_id))
|
||||
}
|
||||
|
||||
/// Copy output tensor data directly to a device pointer (DtoD).
|
||||
/// Avoids the DtoH + HtoD round-trip of get_output() + .to(device).
|
||||
/// Requires a GPU backend.
|
||||
fn copy_output_to_device_ptr(&self, name: &str, dest_ptr: u64, n_bytes: usize) -> PyResult<()> {
|
||||
if !self.runtime.supports_device_ptrs() {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"copy_output_to_device_ptr requires a GPU backend",
|
||||
));
|
||||
}
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
unsafe {
|
||||
self.runtime
|
||||
.copy_output_to_device_ptr(*node_id, dest_ptr, n_bytes)
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,248 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::{prelude::*, shape::Expression};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::ops_parse::*;
|
||||
|
||||
pub fn process_onnx_nodes(
|
||||
nodes: &[NodeProto],
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
for node in nodes {
|
||||
match node.op_type.as_str() {
|
||||
"Add" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Add",
|
||||
|a, b| a + b,
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Mod" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Mod",
|
||||
|a, b| a % b,
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Sub" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Sub",
|
||||
|a, b| a - b,
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Mul" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Mul",
|
||||
|a, b| a * b,
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Div" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Div",
|
||||
|a, b| a / b,
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Sqrt" => parse_unary_op(node, tensors, "Sqrt", |a| a.sqrt())?,
|
||||
"Transpose" => parse_transpose_node(node, tensors)?,
|
||||
"Concat" => parse_concat_node(node, tensors, shape_exprs, known_values)?,
|
||||
"Floor" => parse_floor_node(node, tensors)?,
|
||||
"Ceil" => parse_ceil_node(node, tensors)?,
|
||||
"Sin" => parse_unary_op(node, tensors, "Sin", |a| a.sin())?,
|
||||
"Neg" => parse_unary_op(node, tensors, "Neg", |a| -a)?,
|
||||
"Cos" => parse_unary_op(node, tensors, "Cos", |a| a.cos())?,
|
||||
"Pow" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Pow",
|
||||
|a, b| a.pow(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Sigmoid" => parse_unary_op(node, tensors, "Sigmoid", |a| a.sigmoid())?,
|
||||
"Tanh" => parse_unary_op(node, tensors, "Tanh", |a| a.tanh())?,
|
||||
"Relu" => parse_unary_op(node, tensors, "Relu", |a| a.relu())?,
|
||||
"Softmax" => parse_softmax_node(node, tensors)?,
|
||||
"Abs" => parse_unary_op(node, tensors, "Abs", |a| a.abs())?,
|
||||
"Reciprocal" => parse_unary_op(node, tensors, "Reciprocal", |a| a.reciprocal())?,
|
||||
"Clip" => parse_clip_node(node, tensors, known_values)?,
|
||||
"Equal" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Equal",
|
||||
|a, b| a.eq(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Where" => parse_where_node(node, tensors)?,
|
||||
"Constant" => {
|
||||
parse_constant_node(node, tensors, cx, weight_data, known_values, shape_exprs)?
|
||||
}
|
||||
"ConstantOfShape" => {
|
||||
parse_constant_of_shape(node, tensors, cx, weight_data, known_values, shape_exprs)?
|
||||
}
|
||||
"Cast" => parse_cast_node(node, tensors, weight_data, known_values, shape_exprs)?,
|
||||
"MatMul" => parse_matmul_node(node, tensors)?,
|
||||
"Reshape" => parse_reshape_node(node, tensors, known_values, shape_exprs)?,
|
||||
"Shape" => parse_shape_node(node, tensors, cx, weight_data, known_values, shape_exprs)?,
|
||||
"Gather" => {
|
||||
parse_gather_node(node, tensors, cx, weight_data, known_values, shape_exprs)?
|
||||
}
|
||||
"GatherND" => parse_gathernd_node(node, tensors, cx, weight_data, known_values)?,
|
||||
"Less" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Less",
|
||||
|a, b| a.lt(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Greater" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Greater",
|
||||
|a, b| b.lt(a),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"LessOrEqual" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"LessOrEqual",
|
||||
|a, b| a.le(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"GreaterOrEqual" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"GreaterOrEqual",
|
||||
|a, b| a.ge(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Not" => parse_not_node(node, tensors)?,
|
||||
"And" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"And",
|
||||
|a, b| a.cast(DType::F32) * b.cast(DType::F32),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Or" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Or",
|
||||
|a, b| (a.cast(DType::F32) + b.cast(DType::F32)).minimum_f32(1.0),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Xor" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Xor",
|
||||
|a, b| a.ne(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Min" => parse_variadic_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Min",
|
||||
|a, b| a.minimum(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Max" => parse_variadic_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Max",
|
||||
|a, b| a.maximum(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Identity" => parse_identity(node, tensors, known_values, shape_exprs)?,
|
||||
"Unsqueeze" => parse_unsqueeze_node(node, tensors, known_values, shape_exprs)?,
|
||||
"Squeeze" => parse_squeeze_node(node, tensors, known_values, shape_exprs)?,
|
||||
"ReduceSum" => parse_reduce_op(
|
||||
node,
|
||||
tensors,
|
||||
known_values,
|
||||
"ReduceSum",
|
||||
|t, axes| t.sum(axes),
|
||||
|flat, _n| flat.sum(1),
|
||||
)?,
|
||||
"ReduceMax" => parse_reduce_op(
|
||||
node,
|
||||
tensors,
|
||||
known_values,
|
||||
"ReduceMax",
|
||||
|t, axes| t.max(axes),
|
||||
|flat, _n| flat.max(1),
|
||||
)?,
|
||||
"ReduceMin" => parse_reduce_op(
|
||||
node,
|
||||
tensors,
|
||||
known_values,
|
||||
"ReduceMin",
|
||||
|t, axes| t.min(axes),
|
||||
|flat, _n| flat.min(1),
|
||||
)?,
|
||||
"ReduceMean" => parse_reduce_op(
|
||||
node,
|
||||
tensors,
|
||||
known_values,
|
||||
"ReduceMean",
|
||||
|t, axes| t.mean(axes),
|
||||
|flat, n| flat.sum(1) / n as f32,
|
||||
)?,
|
||||
"Trilu" => parse_trilu_node(node, tensors, cx, known_values)?,
|
||||
"GatherElements" => parse_gather_elements_node(node, tensors)?,
|
||||
"ScatterElements" => parse_scatter_elements_node(node, tensors)?,
|
||||
"ScatterND" => parse_scatter_nd_node(node, tensors)?,
|
||||
"Expand" => parse_expand_node(node, tensors, known_values, shape_exprs)?,
|
||||
"IsNaN" => parse_unary_op(node, tensors, "IsNaN", |a| a.ne(a))?,
|
||||
"LayerNormalization" => parse_layernorm_node(node, tensors)?,
|
||||
"Gemm" => parse_gemm_node(node, tensors)?,
|
||||
"Erf" => parse_erf_node(node, tensors)?,
|
||||
"Slice" => parse_slice_node(node, tensors, known_values, shape_exprs)?,
|
||||
"Split" => parse_split_node(node, tensors, known_values)?,
|
||||
"TopK" => parse_topk_node(node, tensors, known_values)?,
|
||||
"OneHot" => parse_onehot_node(node, tensors, known_values)?,
|
||||
"Range" => parse_range_node(node, tensors, cx, weight_data, known_values, shape_exprs)?,
|
||||
"CumSum" => parse_cumsum_node(node, tensors, known_values)?,
|
||||
"Gelu" => parse_unary_op(node, tensors, "Gelu", |a| a.gelu())?,
|
||||
"Conv" => parse_conv_node(node, tensors)?,
|
||||
"Pad" => parse_pad_node(node, tensors, known_values)?,
|
||||
"Resize" => parse_resize_node(node, tensors, known_values)?,
|
||||
"Tile" => parse_tile_node(node, tensors, known_values)?,
|
||||
"ReduceL2" => parse_reduce_op(
|
||||
node,
|
||||
tensors,
|
||||
known_values,
|
||||
"ReduceL2",
|
||||
|t, axes| (t * t).sum(axes).sqrt(),
|
||||
|flat, _n| (flat * flat).sum(1).sqrt(),
|
||||
)?,
|
||||
"GroupNormalization" => parse_group_norm_node(node, tensors)?,
|
||||
_ => {
|
||||
panic!("Missing Node {}", node.op_type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,9 +1,5 @@
|
||||
mod compiled_graph;
|
||||
mod dispatch;
|
||||
mod onnx_translator;
|
||||
mod ops_parse;
|
||||
mod runtime;
|
||||
mod util;
|
||||
pub mod typed_data;
|
||||
|
||||
// PT2 modules
|
||||
mod pt2_compiled_model;
|
||||
@@ -15,59 +11,40 @@ mod translator;
|
||||
use compiled_graph::CompiledGraph;
|
||||
use pt2_compiled_model::process_pt2;
|
||||
use pyo3::prelude::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn validate_backend(backend: &str) -> PyResult<()> {
|
||||
match backend {
|
||||
"native" => Ok(()),
|
||||
#[cfg(feature = "cuda")]
|
||||
"cuda" => Ok(()),
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
"cuda" => Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"CUDA backend requested, but this luminal extension was built without the `cuda` feature. Rebuild with `maturin develop --features cuda -r` or use backend='native'.",
|
||||
)),
|
||||
_ => {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
Err(pyo3::exceptions::PyValueError::new_err(format!(
|
||||
"Invalid backend '{}'. Must be 'native' or 'cuda'",
|
||||
backend
|
||||
)))
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
{
|
||||
Err(pyo3::exceptions::PyValueError::new_err(format!(
|
||||
"Invalid backend '{}'. This build only supports 'native'. Rebuild with the `cuda` feature to enable 'cuda'.",
|
||||
backend
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (path, backend="native", search_iters=10, weight_device_ptrs=None))]
|
||||
fn process_onnx(
|
||||
path: &str,
|
||||
backend: &str,
|
||||
search_iters: usize,
|
||||
weight_device_ptrs: Option<HashMap<String, (u64, usize)>>,
|
||||
) -> PyResult<CompiledGraph> {
|
||||
validate_backend(backend)?;
|
||||
|
||||
onnx_translator::compile_onnx(
|
||||
path,
|
||||
backend,
|
||||
weight_device_ptrs.unwrap_or_default(),
|
||||
search_iters,
|
||||
)
|
||||
.map_err(pyo3::exceptions::PyRuntimeError::new_err)
|
||||
}
|
||||
use pyo3::types::PyCapsule;
|
||||
|
||||
#[pymodule]
|
||||
fn luminal(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(process_onnx, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(process_pt2, m)?)?;
|
||||
m.add_class::<CompiledGraph>()?;
|
||||
m.add_function(wrap_pyfunction!(_native_factory_capsule, m)?)?;
|
||||
#[cfg(feature = "cuda")]
|
||||
m.add_function(wrap_pyfunction!(_cuda_lite_factory_capsule, m)?)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Factory capsule helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Wrapper to put a function pointer into a PyCapsule.
|
||||
#[allow(dead_code)]
|
||||
struct FnPtrWrapper(pub *const std::ffi::c_void);
|
||||
unsafe impl Send for FnPtrWrapper {}
|
||||
|
||||
/// PyCapsule wrapping the native (CPU) backend factory.
|
||||
#[pyfunction]
|
||||
fn _native_factory_capsule<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyCapsule>> {
|
||||
let fptr = ::luminal::dyn_backend::native_factory as *const std::ffi::c_void;
|
||||
let name = ::luminal::dyn_backend::BACKEND_FACTORY_CAPSULE_NAME.to_owned();
|
||||
PyCapsule::new(py, FnPtrWrapper(fptr), Some(name))
|
||||
}
|
||||
|
||||
/// PyCapsule wrapping the cuda_lite backend factory.
|
||||
#[cfg(feature = "cuda")]
|
||||
#[pyfunction]
|
||||
fn _cuda_lite_factory_capsule<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyCapsule>> {
|
||||
let fptr = luminal_cuda_lite::dyn_backend::cuda_lite_factory as *const std::ffi::c_void;
|
||||
let name = ::luminal::dyn_backend::BACKEND_FACTORY_CAPSULE_NAME.to_owned();
|
||||
PyCapsule::new(py, FnPtrWrapper(fptr), Some(name))
|
||||
}
|
||||
|
||||
@@ -1,283 +0,0 @@
|
||||
use luminal::{
|
||||
prelude::{
|
||||
tracing::{Level, span, trace},
|
||||
*,
|
||||
},
|
||||
shape::Expression,
|
||||
};
|
||||
use onnx_protobuf::ModelProto;
|
||||
use protobuf::Message;
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
fs,
|
||||
path::Path,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
compiled_graph::{CompiledGraph, GraphTranslation, WeightData},
|
||||
dispatch::process_onnx_nodes,
|
||||
util::{
|
||||
DimParamMap, get_shape_for_onnx_value, get_shape_for_onnx_value_expr,
|
||||
load_all_tensor_floats, load_initializer_as_f32,
|
||||
},
|
||||
};
|
||||
|
||||
/// Load, validate, translate, and compile an ONNX model.
|
||||
///
|
||||
/// This is the ONNX counterpart of `pt2_compiled_model::compile_pt2()`.
|
||||
pub fn compile_onnx(
|
||||
path: &str,
|
||||
backend: &str,
|
||||
weight_device_ptrs: HashMap<String, (u64, usize)>,
|
||||
search_iters: usize,
|
||||
) -> Result<CompiledGraph, String> {
|
||||
let data = fs::read(path).map_err(|e| format!("Failed to read file: {}", e))?;
|
||||
let model_directory = Path::new(path).parent().unwrap_or(Path::new("."));
|
||||
let model = ModelProto::parse_from_bytes(&data)
|
||||
.map_err(|e| format!("Failed to parse ONNX model: {}", e))?;
|
||||
|
||||
let opset_version = model
|
||||
.opset_import
|
||||
.iter()
|
||||
.find(|entry| entry.domain.is_empty())
|
||||
.map(|entry| entry.version);
|
||||
|
||||
match opset_version {
|
||||
Some(20) => {}
|
||||
Some(v) => {
|
||||
return Err(format!(
|
||||
"Unsupported ONNX opset version {v}. Only opset 20 is supported."
|
||||
));
|
||||
}
|
||||
None => {
|
||||
return Err(
|
||||
"No ONNX opset version found in model. Only opset 20 is supported.".to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let (translation, mut weights) = translate_onnx(model, model_directory)?;
|
||||
weights.device_ptrs = weight_device_ptrs;
|
||||
CompiledGraph::parse_graph(translation, weights, backend, search_iters)
|
||||
}
|
||||
|
||||
/// Translate an ONNX model into a format-neutral GraphTranslation + WeightData.
|
||||
pub fn translate_onnx(
|
||||
model: ModelProto,
|
||||
model_directory: &Path,
|
||||
) -> Result<(GraphTranslation, WeightData), String> {
|
||||
let _span = span!(Level::TRACE, "ONNX Graph Translation").entered();
|
||||
let onnx_graph = &model.graph;
|
||||
let mut cx = Graph::new();
|
||||
let mut tensors: HashMap<String, GraphTensor> = HashMap::new();
|
||||
|
||||
// Dynamic dimension tracking
|
||||
let mut dim_param_map: DimParamMap = HashMap::new();
|
||||
let mut next_char = 'a';
|
||||
|
||||
// Separate initializers (weights) from true user inputs
|
||||
let initializer_names: HashSet<&str> = onnx_graph
|
||||
.initializer
|
||||
.iter()
|
||||
.map(|t| t.name.as_str())
|
||||
.collect();
|
||||
|
||||
let input_names: Vec<String> = onnx_graph
|
||||
.input
|
||||
.iter()
|
||||
.filter(|inp| !initializer_names.contains(inp.name.as_str()))
|
||||
.map(|inp| inp.name.clone())
|
||||
.collect();
|
||||
|
||||
// Create input tensors with dynamic dimension support
|
||||
for input in &onnx_graph.input {
|
||||
let shape_exprs = get_shape_for_onnx_value_expr(input, &mut dim_param_map, &mut next_char);
|
||||
if shape_exprs.is_empty() {
|
||||
let shape = get_shape_for_onnx_value(input);
|
||||
if shape.is_empty() {
|
||||
trace!("Input {} skipped because it is empty", input.name.clone());
|
||||
continue;
|
||||
}
|
||||
let tensor = cx.named_tensor(input.name.clone(), shape);
|
||||
trace!("Input {} added to tensors", input.name.clone());
|
||||
tensors.insert(input.name.clone(), tensor);
|
||||
continue;
|
||||
}
|
||||
let tensor = cx.named_tensor(input.name.clone(), shape_exprs);
|
||||
trace!("Input {} added to tensors", input.name.clone());
|
||||
tensors.insert(input.name.clone(), tensor);
|
||||
}
|
||||
|
||||
// Create initializer (weight) tensors
|
||||
for init in &onnx_graph.initializer {
|
||||
if !tensors.contains_key(&init.name) {
|
||||
let mut shape: Vec<usize> = init.dims.iter().map(|&d| d as usize).collect();
|
||||
if shape.is_empty() {
|
||||
shape = vec![1];
|
||||
}
|
||||
let tensor = cx.named_tensor(init.name.clone(), shape);
|
||||
tensors.insert(init.name.clone(), tensor);
|
||||
}
|
||||
}
|
||||
|
||||
// Load small constants for constant folding
|
||||
let mut known_values: HashMap<String, Vec<f32>> = HashMap::new();
|
||||
for init in &onnx_graph.initializer {
|
||||
let n_elements: usize = init
|
||||
.dims
|
||||
.iter()
|
||||
.map(|&d| d as usize)
|
||||
.product::<usize>()
|
||||
.max(1);
|
||||
if n_elements <= 32 {
|
||||
if let Some(floats) = load_initializer_as_f32(init) {
|
||||
known_values.insert(init.name.clone(), floats);
|
||||
} else {
|
||||
panic!("Unable to load initializer values for {:?}", init.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shape expressions for propagating symbolic shapes through ONNX graphs
|
||||
let mut shape_exprs: HashMap<String, Vec<Expression>> = HashMap::new();
|
||||
|
||||
// Accumulates constant node data from process_onnx_nodes
|
||||
let mut constant_data: Vec<(String, Vec<f32>)> = Vec::new();
|
||||
|
||||
// Process computation nodes
|
||||
process_onnx_nodes(
|
||||
&onnx_graph.node,
|
||||
&mut tensors,
|
||||
&mut cx,
|
||||
&mut constant_data,
|
||||
&mut known_values,
|
||||
&mut shape_exprs,
|
||||
)
|
||||
.map_err(|e| format!("process_onnx_nodes failed: {}", e))?;
|
||||
|
||||
// Mark weight/constant tensors as persistent so their buffers survive execute()
|
||||
for (name, gt) in &tensors {
|
||||
if !input_names.contains(name) {
|
||||
gt.persist();
|
||||
}
|
||||
}
|
||||
|
||||
// Mark graph outputs (must happen before build_search_space)
|
||||
let mut output_names = Vec::new();
|
||||
let mut output_shape_exprs = Vec::new();
|
||||
for output_vi in &onnx_graph.output {
|
||||
if let Some(>) = tensors.get(&output_vi.name) {
|
||||
// Force contiguous if the shape tracker is a non-contiguous view
|
||||
let gt = if gt.shape != gt.shape.contiguous() {
|
||||
let contiguous = gt * 1.0;
|
||||
tensors.insert(output_vi.name.clone(), contiguous);
|
||||
contiguous
|
||||
} else {
|
||||
gt
|
||||
};
|
||||
gt.output();
|
||||
let dims = gt.dims();
|
||||
output_shape_exprs.push(dims.clone());
|
||||
|
||||
let shape: Vec<usize> = dims.iter().map(|d| d.to_usize().unwrap_or(1)).collect();
|
||||
if shape.is_empty() {
|
||||
return Err(format!(
|
||||
"Output tensor '{}' has no shape information in the ONNX model",
|
||||
output_vi.name
|
||||
));
|
||||
}
|
||||
output_names.push(output_vi.name.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Set initial dynamic dimension values from example input shapes
|
||||
let has_dynamic = !dim_param_map.is_empty();
|
||||
if has_dynamic {
|
||||
for input in &onnx_graph.input {
|
||||
if initializer_names.contains(input.name.as_str()) {
|
||||
continue;
|
||||
}
|
||||
let concrete_shape = get_shape_for_onnx_value(input);
|
||||
let expr_shape =
|
||||
get_shape_for_onnx_value_expr(input, &mut dim_param_map, &mut next_char);
|
||||
for (expr, concrete) in expr_shape.iter().zip(concrete_shape.iter()) {
|
||||
if expr.to_usize().is_none()
|
||||
&& let Some(ch) = dim_param_map
|
||||
.values()
|
||||
.find(|&&ch| Expression::from(ch) == *expr)
|
||||
{
|
||||
cx.set_dim(*ch, *concrete);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build weight data: initializers + constants from process_onnx_nodes
|
||||
let mut weights: Vec<(String, Vec<f32>)> = Vec::new();
|
||||
for (name, floats) in load_all_tensor_floats(&onnx_graph.initializer, model_directory) {
|
||||
if let Some(f) = floats {
|
||||
weights.push((name, f));
|
||||
}
|
||||
}
|
||||
weights.extend(constant_data);
|
||||
|
||||
// Build tensor sizes for CUDA dummy data allocation
|
||||
let mut tensor_sizes: HashMap<String, usize> = HashMap::new();
|
||||
for input in &onnx_graph.input {
|
||||
if !initializer_names.contains(input.name.as_str()) {
|
||||
let shape = get_shape_for_onnx_value(input);
|
||||
let n: usize = shape.iter().product::<usize>().max(1);
|
||||
tensor_sizes.insert(input.name.clone(), n);
|
||||
}
|
||||
}
|
||||
for init in &onnx_graph.initializer {
|
||||
let n: usize = init
|
||||
.dims
|
||||
.iter()
|
||||
.map(|&d| d as usize)
|
||||
.product::<usize>()
|
||||
.max(1);
|
||||
tensor_sizes.insert(init.name.clone(), n);
|
||||
}
|
||||
for (name, data) in &weights {
|
||||
if !tensor_sizes.contains_key(name) {
|
||||
tensor_sizes.insert(name.clone(), data.len());
|
||||
}
|
||||
}
|
||||
|
||||
// Collect tensor name → NodeIndex mapping
|
||||
let tensor_ids: HashMap<String, NodeIndex> = tensors
|
||||
.iter()
|
||||
.map(|(name, gt)| (name.clone(), gt.id))
|
||||
.collect();
|
||||
|
||||
// Build input_shape_exprs for user inputs (needed for auto-dim detection)
|
||||
let input_shape_exprs: Vec<Vec<Expression>> = input_names
|
||||
.iter()
|
||||
.map(|name| {
|
||||
if let Some(>) = tensors.get(name) {
|
||||
gt.dims()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let translation = GraphTranslation {
|
||||
graph: cx,
|
||||
tensor_ids,
|
||||
input_names,
|
||||
output_names,
|
||||
output_shape_exprs,
|
||||
input_shape_exprs,
|
||||
dim_param_map,
|
||||
};
|
||||
|
||||
let weight_data = WeightData {
|
||||
weights,
|
||||
tensor_sizes,
|
||||
device_ptrs: HashMap::new(),
|
||||
};
|
||||
|
||||
Ok((translation, weight_data))
|
||||
}
|
||||
@@ -1,187 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::{
|
||||
prelude::{tracing::trace, *},
|
||||
shape::Expression,
|
||||
};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::{broadcast_to_expr, compute_broadcast_shape_expr};
|
||||
|
||||
/// Handle Where node: conditional select — output[i] = condition[i] ? x[i] : y[i]
|
||||
///
|
||||
/// ONNX Where uses numpy-style broadcasting across all three inputs.
|
||||
pub fn parse_where_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
assert!(node.input.len() == 3, "Where should have 3 inputs");
|
||||
let condition = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Where: missing condition tensor '{}'", node.input[0]))?;
|
||||
let x = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("Where: missing X tensor '{}'", node.input[1]))?;
|
||||
let y = *tensors
|
||||
.get(&node.input[2])
|
||||
.ok_or_else(|| format!("Where: missing Y tensor '{}'", node.input[2]))?;
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// ONNX Where broadcasts all 3 inputs to a common shape
|
||||
let bc_shape = compute_broadcast_shape_expr(
|
||||
&condition.dims(),
|
||||
&compute_broadcast_shape_expr(&x.dims(), &y.dims()),
|
||||
);
|
||||
let condition = broadcast_to_expr(condition, &bc_shape);
|
||||
let x = broadcast_to_expr(x, &bc_shape);
|
||||
let y = broadcast_to_expr(y, &bc_shape);
|
||||
|
||||
let result = x.cond(condition, y);
|
||||
tensors.insert(output_name.clone(), result);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn parse_binary_broadcast_op(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
op_name: &str,
|
||||
op: impl Fn(GraphTensor, GraphTensor) -> GraphTensor,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
known_values: &HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: {} Node", op_name);
|
||||
assert!(
|
||||
node.input.len() == 2,
|
||||
"{} should have 2 inputs, got {}",
|
||||
op_name,
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"{} should have 1 output, got {}",
|
||||
op_name,
|
||||
node.output.len()
|
||||
);
|
||||
// Shape-only path: if any input is shape-only (not in tensors), do Expression arithmetic
|
||||
let a_missing = !tensors.contains_key(&node.input[0]);
|
||||
let b_missing = !tensors.contains_key(&node.input[1]);
|
||||
if a_missing || b_missing {
|
||||
// At least one input is shape-only. Do shape_exprs arithmetic and return.
|
||||
let se_a = shape_exprs.get(&node.input[0]).cloned().or_else(|| {
|
||||
known_values
|
||||
.get(&node.input[0])
|
||||
.map(|kv| kv.iter().map(|&v| Expression::from(v as usize)).collect())
|
||||
});
|
||||
let se_b = shape_exprs.get(&node.input[1]).cloned().or_else(|| {
|
||||
known_values
|
||||
.get(&node.input[1])
|
||||
.map(|kv| kv.iter().map(|&v| Expression::from(v as usize)).collect())
|
||||
});
|
||||
if let (Some(se_a), Some(se_b)) = (se_a, se_b)
|
||||
&& se_a.len() == 1
|
||||
&& se_b.len() == 1
|
||||
{
|
||||
let result_expr = match op_name {
|
||||
"Add" => Some(se_a[0] + se_b[0]),
|
||||
"Sub" => Some(se_a[0] - se_b[0]),
|
||||
"Mul" => Some(se_a[0] * se_b[0]),
|
||||
"Div" => Some(se_a[0] / se_b[0]),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(expr) = result_expr {
|
||||
shape_exprs.insert(node.output[0].clone(), vec![expr]);
|
||||
}
|
||||
}
|
||||
trace!("Finished parse: {} Node (shape-only)", op_name);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("{}: missing input '{}'", op_name, node.input[0]))?;
|
||||
let b = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("{}: missing input '{}'", op_name, node.input[1]))?;
|
||||
let broadcast_shape = compute_broadcast_shape_expr(&a.dims(), &b.dims());
|
||||
let a_bc = broadcast_to_expr(a, &broadcast_shape);
|
||||
let b_bc = broadcast_to_expr(b, &broadcast_shape);
|
||||
let result = op(a_bc, b_bc);
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
|
||||
// Propagate shape_exprs for scalar shape arithmetic (e.g., Add(1, seq_len))
|
||||
// At least one input must be in shape_exprs; the other can come from known_values.
|
||||
let has_shape_expr =
|
||||
shape_exprs.contains_key(&node.input[0]) || shape_exprs.contains_key(&node.input[1]);
|
||||
if has_shape_expr {
|
||||
let se_a = shape_exprs.get(&node.input[0]).cloned().or_else(|| {
|
||||
known_values
|
||||
.get(&node.input[0])
|
||||
.map(|kv| kv.iter().map(|&v| Expression::from(v as usize)).collect())
|
||||
});
|
||||
let se_b = shape_exprs.get(&node.input[1]).cloned().or_else(|| {
|
||||
known_values
|
||||
.get(&node.input[1])
|
||||
.map(|kv| kv.iter().map(|&v| Expression::from(v as usize)).collect())
|
||||
});
|
||||
if let (Some(se_a), Some(se_b)) = (se_a, se_b)
|
||||
&& se_a.len() == 1
|
||||
&& se_b.len() == 1
|
||||
{
|
||||
let result_expr = match op_name {
|
||||
"Add" => Some(se_a[0] + se_b[0]),
|
||||
"Sub" => Some(se_a[0] - se_b[0]),
|
||||
"Mul" => Some(se_a[0] * se_b[0]),
|
||||
"Div" => Some(se_a[0] / se_b[0]),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(expr) = result_expr {
|
||||
shape_exprs.insert(node.output[0].clone(), vec![expr]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trace!("Finished parse: {} Node", op_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn parse_variadic_broadcast_op(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
op_name: &str,
|
||||
op: impl Fn(GraphTensor, GraphTensor) -> GraphTensor,
|
||||
_shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
_known_values: &HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: {} Node", op_name);
|
||||
assert!(
|
||||
node.input.len() >= 2,
|
||||
"{} needs at least two inputs, got {}",
|
||||
op_name,
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"{} nodes only have one output, got {}",
|
||||
op_name,
|
||||
node.output.len()
|
||||
);
|
||||
|
||||
let mut result = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("{}: missing input tensor '{}'", op_name, node.input[0]))?;
|
||||
|
||||
for input_name in &node.input[1..] {
|
||||
let rhs = *tensors
|
||||
.get(input_name)
|
||||
.ok_or_else(|| format!("{}: missing input tensor '{}'", op_name, input_name))?;
|
||||
let broadcast_shape = compute_broadcast_shape_expr(&result.dims(), &rhs.dims());
|
||||
let lhs_bc = broadcast_to_expr(result, &broadcast_shape);
|
||||
let rhs_bc = broadcast_to_expr(rhs, &broadcast_shape);
|
||||
result = op(lhs_bc, rhs_bc);
|
||||
}
|
||||
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: {} Node", op_name);
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,194 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::{
|
||||
prelude::{tracing::trace, *},
|
||||
shape::Expression,
|
||||
};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::get_int_attr;
|
||||
|
||||
/// Get an integer-list attribute from a node, with a default value applied per element.
|
||||
fn get_ints_attr(node: &NodeProto, name: &str, default_elem: i64, spatial: usize) -> Vec<usize> {
|
||||
for attr in &node.attribute {
|
||||
if attr.name == name {
|
||||
return attr.ints.iter().map(|&v| v as usize).collect();
|
||||
}
|
||||
}
|
||||
vec![default_elem as usize; spatial]
|
||||
}
|
||||
|
||||
/// Parse an ONNX Conv node.
|
||||
///
|
||||
/// Supports N-dimensional convolution (1D, 2D, 3D) with group=1.
|
||||
/// Uses the unfold-based approach from `luminal_nn::ConvND`.
|
||||
///
|
||||
/// Input layout: [batch, C_in, spatial...]
|
||||
/// Weight layout: [C_out, C_in/group, kernel...]
|
||||
/// Optional bias: [C_out]
|
||||
pub fn parse_conv_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Conv Node");
|
||||
|
||||
assert!(
|
||||
node.input.len() >= 2,
|
||||
"Conv needs at least 2 inputs (X, W), got {}",
|
||||
node.input.len()
|
||||
);
|
||||
|
||||
let x = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Conv: missing input X '{}'", node.input[0]))?;
|
||||
let w = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("Conv: missing weight W '{}'", node.input[1]))?;
|
||||
let bias = if node.input.len() > 2 && !node.input[2].is_empty() {
|
||||
Some(
|
||||
*tensors
|
||||
.get(&node.input[2])
|
||||
.ok_or_else(|| format!("Conv: missing bias B '{}'", node.input[2]))?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let x_dims = x.dims();
|
||||
let w_dims = w.dims();
|
||||
let rank = x_dims.len();
|
||||
assert!(
|
||||
rank >= 3,
|
||||
"Conv: input must be at least 3D (batch, channels, spatial...), got {rank}D"
|
||||
);
|
||||
|
||||
let spatial = rank - 2; // number of spatial dimensions
|
||||
|
||||
// Parse attributes
|
||||
let kernel_shape = get_ints_attr(node, "kernel_shape", 1, spatial);
|
||||
let strides = get_ints_attr(node, "strides", 1, spatial);
|
||||
let dilations = get_ints_attr(node, "dilations", 1, spatial);
|
||||
let group = get_int_attr(node, "group", 1) as usize;
|
||||
|
||||
// Parse pads: ONNX format is [begin_0, begin_1, ..., end_0, end_1, ...]
|
||||
let pads_flat = get_ints_attr(node, "pads", 0, 2 * spatial);
|
||||
let mut pads_begin = vec![0usize; spatial];
|
||||
let mut pads_end = vec![0usize; spatial];
|
||||
if pads_flat.len() == 2 * spatial {
|
||||
pads_begin[..spatial].copy_from_slice(&pads_flat[..spatial]);
|
||||
pads_end[..spatial].copy_from_slice(&pads_flat[spatial..(spatial + spatial)]);
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
group, 1,
|
||||
"Conv: only group=1 is currently supported, got {group}"
|
||||
);
|
||||
|
||||
// Get channel dimensions
|
||||
let ch_out = w_dims[0]
|
||||
.to_usize()
|
||||
.ok_or("Conv: weight C_out must be concrete")?;
|
||||
let ch_in = x_dims[1]
|
||||
.to_usize()
|
||||
.ok_or("Conv: input C_in must be concrete")?;
|
||||
|
||||
let kernel_product: usize = kernel_shape.iter().product();
|
||||
|
||||
// Reshape weight from ONNX [C_out, C_in, *kernel] to [C_out, C_in * kernel_product]
|
||||
let w_reshaped = {
|
||||
let mut wt = w;
|
||||
wt.shape = ShapeTracker::new(vec![ch_out, ch_in * kernel_product]);
|
||||
wt
|
||||
};
|
||||
|
||||
// Pad spatial dimensions
|
||||
let mut padding: Vec<(Expression, Expression)> = vec![(0.into(), 0.into()); rank];
|
||||
for i in 0..spatial {
|
||||
let axis = 2 + i; // batch=0, channel=1, spatial starts at 2
|
||||
padding[axis] = (
|
||||
Expression::from(pads_begin[i]),
|
||||
Expression::from(pads_end[i]),
|
||||
);
|
||||
}
|
||||
let padded = x.pad(padding, 0.0);
|
||||
|
||||
// Build unfold parameters (ones for batch/channel, actual for spatial)
|
||||
let mut kernel_full = vec![1usize; rank];
|
||||
let mut stride_full = vec![1usize; rank];
|
||||
let mut dilation_full = vec![1usize; rank];
|
||||
for i in 0..spatial {
|
||||
let axis = 2 + i;
|
||||
kernel_full[axis] = kernel_shape[i];
|
||||
stride_full[axis] = strides[i];
|
||||
dilation_full[axis] = dilations[i];
|
||||
}
|
||||
|
||||
let unfolded = padded.unfold(kernel_full, stride_full, dilation_full);
|
||||
// unfolded shape: [win_N, win_C, win_spatial..., k_batch=1, k_chan=1, k_spatial...]
|
||||
// (2*rank dimensions total)
|
||||
|
||||
// Step 1: Permute to [N, win_spatial..., C_in, k_batch, k_chan, k_spatial...]
|
||||
// This groups: batch | output spatial | channel+kernel (for merging)
|
||||
let mut perm: Vec<usize> = Vec::with_capacity(2 * rank);
|
||||
perm.push(0); // win_N (batch)
|
||||
perm.extend(2..2 + spatial); // win_spatial dims
|
||||
perm.push(1); // win_C (= C_in)
|
||||
perm.extend(rank..2 * rank); // all kernel dims: k_batch=1, k_chan=1, k_spatial...
|
||||
let permuted = unfolded.permute(perm);
|
||||
|
||||
// Step 2: Capture output spatial dimensions (win_spatial sizes)
|
||||
let output_spatial_dims: Vec<Expression> = permuted.dims()[1..1 + spatial].to_vec();
|
||||
|
||||
// Step 3: Merge all channel+kernel dims into one (C_in * kernel_product)
|
||||
// From index (1+spatial) to end there are (1 + 2 + spatial) dims to merge
|
||||
let mut patches = permuted;
|
||||
let target_before_spatial_merge = 2 + spatial; // [N, spatial..., merged_patch]
|
||||
while patches.dims().len() > target_before_spatial_merge {
|
||||
let last = patches.dims().len();
|
||||
patches = patches.merge_dims(last - 2, last - 1);
|
||||
}
|
||||
// patches: [N, spatial_0, ..., spatial_{s-1}, C_in * kernel_product]
|
||||
|
||||
// Step 4: Merge spatial dims into one
|
||||
for _ in 1..spatial {
|
||||
patches = patches.merge_dims(1, 2);
|
||||
}
|
||||
// patches: [N, spatial_product, C_in * kernel_product]
|
||||
|
||||
// Step 5: Matmul with weight
|
||||
let mut out = patches.matmul(w_reshaped.permute((1, 0)));
|
||||
// out: [N, spatial_product, C_out]
|
||||
|
||||
// Step 6: Restore spatial dimensions via split_dims
|
||||
// Split from innermost spatial dim first (reverse order, skip outermost)
|
||||
for i in (1..spatial).rev() {
|
||||
out = out.split_dims(1, output_spatial_dims[i]);
|
||||
}
|
||||
// out: [N, spatial_0, spatial_1, ..., spatial_{s-1}, C_out]
|
||||
|
||||
// Step 7: Move C_out from last position to position 1 (after batch)
|
||||
let mut final_order: Vec<usize> = Vec::with_capacity(2 + spatial);
|
||||
final_order.push(0); // batch
|
||||
final_order.push(1 + spatial); // C_out
|
||||
final_order.extend(1..1 + spatial); // spatial dims
|
||||
out = out.permute(final_order);
|
||||
// out: [N, C_out, spatial_0, ..., spatial_{s-1}]
|
||||
|
||||
// Add bias if present: bias shape [C_out], broadcast to [1, C_out, 1, 1, ...]
|
||||
if let Some(b) = bias {
|
||||
let mut bias_expanded = b;
|
||||
// Expand to [1, C_out, 1, 1, ...]
|
||||
bias_expanded = bias_expanded.expand_dim(0, 1); // batch dim
|
||||
for i in 0..spatial {
|
||||
let out_dims = out.dims();
|
||||
let spatial_size = out_dims[2 + i];
|
||||
bias_expanded = bias_expanded.expand_dim(2 + i, spatial_size);
|
||||
}
|
||||
out += bias_expanded;
|
||||
}
|
||||
|
||||
tensors.insert(node.output[0].clone(), out);
|
||||
|
||||
trace!("Finished parse: Conv Node");
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::prelude::{tracing::trace, *};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::{broadcast_to_expr, get_float_attr, get_int_attr};
|
||||
|
||||
pub fn parse_matmul_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Started parse: MatMul Node");
|
||||
assert!(node.input.len() == 2, "MatMul should have exactly 2 inputs");
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("MatMul: missing input tensor '{}'", node.input[0]))?;
|
||||
let b = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("MatMul: missing input tensor '{}'", node.input[1]))?;
|
||||
|
||||
//TODO: enforce some kind of check here that they are broadcastable
|
||||
let result = a.matmul(b);
|
||||
let output_name = &node.output[0];
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: MatMul Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Gemm node: Y = alpha * (transA ? A.T : A) @ (transB ? B.T : B) + beta * C
|
||||
///
|
||||
/// Attributes: transA (default 0), transB (default 0), alpha (default 1.0), beta (default 1.0)
|
||||
/// Input C (bias) is optional.
|
||||
pub fn parse_gemm_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Started parse: Gemm Node");
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Gemm: missing input A '{}'", node.input[0]))?;
|
||||
let b = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("Gemm: missing input B '{}'", node.input[1]))?;
|
||||
|
||||
let trans_a = get_int_attr(node, "transA", 0) != 0;
|
||||
let trans_b = get_int_attr(node, "transB", 0) != 0;
|
||||
let alpha = get_float_attr(node, "alpha", 1.0);
|
||||
let beta = get_float_attr(node, "beta", 1.0);
|
||||
|
||||
let a_mat = if trans_a { a.permute(vec![1, 0]) } else { a };
|
||||
let b_mat = if trans_b { b.permute(vec![1, 0]) } else { b };
|
||||
|
||||
let mut result = a_mat.matmul(b_mat);
|
||||
if alpha != 1.0 {
|
||||
result *= alpha;
|
||||
}
|
||||
|
||||
if node.input.len() > 2 && !node.input[2].is_empty() {
|
||||
let c = *tensors
|
||||
.get(&node.input[2])
|
||||
.ok_or_else(|| format!("Gemm: missing bias C '{}'", node.input[2]))?;
|
||||
let c_scaled = if beta != 1.0 { c * beta } else { c };
|
||||
let result_shape = result.dims();
|
||||
result += broadcast_to_expr(c_scaled, &result_shape);
|
||||
}
|
||||
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: Gemm Node");
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
pub mod binary;
|
||||
pub mod convolution;
|
||||
pub mod matmul;
|
||||
pub mod movement;
|
||||
pub mod reduction;
|
||||
pub mod tensor;
|
||||
pub mod unary;
|
||||
|
||||
pub use binary::*;
|
||||
pub use convolution::*;
|
||||
pub use matmul::*;
|
||||
pub use movement::*;
|
||||
pub use reduction::*;
|
||||
pub use tensor::*;
|
||||
pub use unary::*;
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,172 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::prelude::{tracing::trace, *};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::get_int_attr;
|
||||
|
||||
/// Handle TopK node: return the top-k values and indices along an axis.
|
||||
///
|
||||
/// output[0] = values (F32), output[1] = indices (Int, can be empty/unused).
|
||||
/// For largest=true (default): uses topk_indexes + gather_elements.
|
||||
/// For largest=false: uses argsort(ascending).slice_along(..k) + gather_elements.
|
||||
/// Indices output is stored as-is (Int dtype); downstream Cast handles F32 conversion.
|
||||
/// The "sorted" attribute is ignored — output is always sorted.
|
||||
pub fn parse_topk_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
let x = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("TopK: missing input '{}'", node.input[0]))?;
|
||||
let k = known_values
|
||||
.get(&node.input[1])
|
||||
.ok_or("TopK: k must be constant")?[0] as usize;
|
||||
|
||||
let rank = x.dims().len() as i64;
|
||||
let raw_axis = get_int_attr(node, "axis", -1);
|
||||
let axis = if raw_axis < 0 {
|
||||
(raw_axis + rank) as usize
|
||||
} else {
|
||||
raw_axis as usize
|
||||
};
|
||||
|
||||
let largest = get_int_attr(node, "largest", 1) != 0;
|
||||
|
||||
// Compute full argsort, then gather all sorted values, then slice both to top-k.
|
||||
// This avoids passing a non-contiguous sliced index tensor into gather_elements,
|
||||
// which triggers a CUDA kernel bug when data and index sizes differ along the axis.
|
||||
let full_argsort = x.argsort(axis, largest);
|
||||
let indices = full_argsort.slice_along(..k, axis);
|
||||
let values = x.gather_elements(full_argsort, axis).slice_along(..k, axis);
|
||||
|
||||
// ONNX output[0] = values, output[1] = indices
|
||||
if !node.output[0].is_empty() {
|
||||
tensors.insert(node.output[0].clone(), values);
|
||||
}
|
||||
if node.output.len() > 1 && !node.output[1].is_empty() {
|
||||
// Force materialization of Int indices; downstream Cast(INT64→FLOAT) handles the
|
||||
// F32 conversion via the *1.0 workaround in parse_cast_node.
|
||||
tensors.insert(node.output[1].clone(), indices * 1.0);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn parse_reduce_op(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
op_name: &str,
|
||||
reduce_op: impl Fn(GraphTensor, Vec<usize>) -> GraphTensor,
|
||||
all_axes_op: impl Fn(GraphTensor, usize) -> GraphTensor,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: {} Node", op_name);
|
||||
assert!(
|
||||
!node.input.is_empty(),
|
||||
"{} should have at least 1 input",
|
||||
op_name
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"{} should have exactly 1 output",
|
||||
op_name
|
||||
);
|
||||
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("{}: missing input tensor '{}'", op_name, node.input[0]))?;
|
||||
|
||||
let keepdims = get_int_attr(node, "keepdims", 1) != 0;
|
||||
let noop_with_empty_axes = get_int_attr(node, "noop_with_empty_axes", 0) != 0;
|
||||
|
||||
let ndim = input.dims().len();
|
||||
|
||||
// Resolve axes from second input (opset 13+) or from attribute (opset 11)
|
||||
let raw_axes: Vec<i64> = if node.input.len() > 1 && !node.input[1].is_empty() {
|
||||
let axes_vals = known_values.get(&node.input[1]).ok_or_else(|| {
|
||||
format!(
|
||||
"{}: axes input '{}' must be a known constant",
|
||||
op_name, node.input[1]
|
||||
)
|
||||
})?;
|
||||
axes_vals.iter().map(|&v| v as i64).collect()
|
||||
} else if let Some(attr) = node.attribute.iter().find(|a| a.name == "axes") {
|
||||
attr.ints.clone()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Handle empty axes: noop or reduce all
|
||||
let raw_axes: Vec<i64> = if raw_axes.is_empty() {
|
||||
if noop_with_empty_axes {
|
||||
tensors.insert(output_name.clone(), input);
|
||||
trace!("Finished parse: {} Node (noop)", op_name);
|
||||
return Ok(());
|
||||
} else {
|
||||
(0..ndim as i64).collect()
|
||||
}
|
||||
} else {
|
||||
raw_axes
|
||||
};
|
||||
|
||||
// Normalize negative axes and convert to usize
|
||||
let mut normalized_axes: Vec<usize> = raw_axes
|
||||
.iter()
|
||||
.map(|&a| {
|
||||
if a < 0 {
|
||||
(ndim as i64 + a) as usize
|
||||
} else {
|
||||
a as usize
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
normalized_axes.sort();
|
||||
normalized_axes.dedup();
|
||||
|
||||
// Save original sorted axes for keepdims unsqueeze bookkeeping
|
||||
let sorted_axes = normalized_axes.clone();
|
||||
|
||||
let input_dims = input.dims();
|
||||
|
||||
if normalized_axes.len() == ndim {
|
||||
// All-axes reduction: flatten to [1, N] and reduce axis 1 → [1].
|
||||
// luminal's Expression::product() returns 0 for empty iterators, so a reduce
|
||||
// producing a 0-dim tensor causes CUDA to launch with grid (0,1,1), which is
|
||||
// invalid. Using [1, N] → reduce(1) → [1] avoids this entirely.
|
||||
let total: usize = input_dims
|
||||
.iter()
|
||||
.map(|d| d.to_usize().expect("reduce: dim must be concrete"))
|
||||
.product();
|
||||
let mut flat = input;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
let mut result = all_axes_op(flat, total);
|
||||
|
||||
if keepdims {
|
||||
// Insert (ndim-1) additional size-1 dims to produce [1]*ndim
|
||||
for i in 1..ndim {
|
||||
result = result.unsqueeze(i);
|
||||
}
|
||||
}
|
||||
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: {} Node (all-axes)", op_name);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Partial reduction: luminal's ToAxes API handles axis shifting internally
|
||||
let mut result = reduce_op(input, normalized_axes);
|
||||
|
||||
// Re-insert size-1 dims at original positions (ascending order keeps positions correct)
|
||||
if keepdims {
|
||||
for &axis in &sorted_axes {
|
||||
result = result.unsqueeze(axis);
|
||||
}
|
||||
}
|
||||
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: {} Node", op_name);
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,453 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::{
|
||||
prelude::{tracing::trace, *},
|
||||
shape::Expression,
|
||||
};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::{broadcast_to_expr, get_int_attr};
|
||||
|
||||
/// Handle Constant node: creates a tensor from embedded data in the node attributes.
|
||||
///
|
||||
/// Supports FLOAT, INT64, INT32, and FLOAT64 data types (all converted to f32).
|
||||
/// The resulting tensor is registered as a known constant for downstream folding.
|
||||
pub fn parse_constant_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Constant Node");
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Constant should have exactly one output"
|
||||
);
|
||||
|
||||
// Find the "value" attribute (type TENSOR)
|
||||
let value_attr = node
|
||||
.attribute
|
||||
.iter()
|
||||
.find(|a| a.name == "value")
|
||||
.ok_or_else(|| "Constant node missing 'value' attribute".to_string())?;
|
||||
|
||||
let tensor_proto = value_attr
|
||||
.t
|
||||
.as_ref()
|
||||
.ok_or_else(|| "Constant 'value' attribute has no TensorProto".to_string())?;
|
||||
|
||||
// Determine shape: empty dims = scalar = [1] for luminal
|
||||
let shape: Vec<usize> = if tensor_proto.dims.is_empty() {
|
||||
vec![1]
|
||||
} else {
|
||||
tensor_proto.dims.iter().map(|&d| d as usize).collect()
|
||||
};
|
||||
|
||||
// Extract float data based on data_type
|
||||
let floats: Vec<f32> = match tensor_proto.data_type {
|
||||
1 => {
|
||||
// FLOAT (f32)
|
||||
if !tensor_proto.float_data.is_empty() {
|
||||
tensor_proto.float_data.clone()
|
||||
} else {
|
||||
tensor_proto
|
||||
.raw_data
|
||||
.chunks_exact(4)
|
||||
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
6 => {
|
||||
// INT32
|
||||
if !tensor_proto.int32_data.is_empty() {
|
||||
tensor_proto.int32_data.iter().map(|&v| v as f32).collect()
|
||||
} else {
|
||||
tensor_proto
|
||||
.raw_data
|
||||
.chunks_exact(4)
|
||||
.map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
7 => {
|
||||
// INT64
|
||||
if !tensor_proto.int64_data.is_empty() {
|
||||
tensor_proto.int64_data.iter().map(|&v| v as f32).collect()
|
||||
} else {
|
||||
tensor_proto
|
||||
.raw_data
|
||||
.chunks_exact(8)
|
||||
.map(|c| {
|
||||
i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
dt => return Err(format!("Constant node: unsupported data_type {}", dt)),
|
||||
};
|
||||
|
||||
let output_name = &node.output[0];
|
||||
let tensor = cx.named_tensor(output_name.clone(), shape);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), floats.clone());
|
||||
// Also propagate as concrete shape_exprs for downstream shape computation chains
|
||||
shape_exprs.insert(
|
||||
output_name.clone(),
|
||||
floats
|
||||
.iter()
|
||||
.map(|&v| Expression::from(v as usize))
|
||||
.collect(),
|
||||
);
|
||||
weight_data.push((output_name.clone(), floats));
|
||||
|
||||
trace!("Finished parse: Constant Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Shape node: extract the shape of the input tensor as a 1D constant.
|
||||
///
|
||||
/// For static shapes, stores as known_values. For dynamic shapes (containing
|
||||
/// Expression variables), stores in shape_exprs for downstream shape computation chains.
|
||||
pub fn parse_shape_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Started parse: Shape");
|
||||
assert!(node.input.len() == 1, "Shape should have exactly 1 input");
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Shape: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
let all_dims = input.dims();
|
||||
|
||||
// Handle start/end attributes (ONNX Shape opset 15+: extract a slice of dims)
|
||||
let start = get_int_attr(node, "start", 0) as usize;
|
||||
let end_attr = get_int_attr(node, "end", all_dims.len() as i64);
|
||||
let end = if end_attr < 0 {
|
||||
(all_dims.len() as i64 + end_attr) as usize
|
||||
} else {
|
||||
(end_attr as usize).min(all_dims.len())
|
||||
};
|
||||
let dims: Vec<Expression> = all_dims[start..end].to_vec();
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Always store in shape_exprs (supports both concrete and symbolic dims)
|
||||
shape_exprs.insert(output_name.clone(), dims.clone());
|
||||
|
||||
// For concrete dims, also store in known_values for backward compat
|
||||
let all_concrete = dims.iter().all(|d| d.to_usize().is_some());
|
||||
let shape_values: Vec<f32> = dims
|
||||
.iter()
|
||||
.map(|d| d.to_usize().unwrap_or(1) as f32)
|
||||
.collect();
|
||||
|
||||
if all_concrete {
|
||||
// Concrete shape: create tensor + known_values + weight_data
|
||||
let tensor = cx.named_tensor(output_name.clone(), vec![shape_values.len()]);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), shape_values.clone());
|
||||
weight_data.push((output_name.clone(), shape_values));
|
||||
}
|
||||
// For symbolic shapes, don't create a tensor — it's shape-only
|
||||
|
||||
trace!("Finished parse: Shape");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle ConstantOfShape node: creates a tensor of a given shape filled with a constant value.
|
||||
///
|
||||
/// The shape is taken from the input tensor (which must be a known constant).
|
||||
/// The fill value comes from the "value" attribute (default 0.0).
|
||||
pub fn parse_constant_of_shape(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: ConstantOfShape Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"ConstantOfShape should have exactly one input (shape)"
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"ConstantOfShape should have exactly one output"
|
||||
);
|
||||
|
||||
// Extract fill value from "value" attribute (TensorProto scalar), default 0.0
|
||||
let fill_value: f32 = node
|
||||
.attribute
|
||||
.iter()
|
||||
.find(|a| a.name == "value")
|
||||
.and_then(|attr| attr.t.as_ref())
|
||||
.map(|tp| {
|
||||
if !tp.float_data.is_empty() {
|
||||
tp.float_data[0]
|
||||
} else if !tp.int32_data.is_empty() {
|
||||
tp.int32_data[0] as f32
|
||||
} else if !tp.raw_data.is_empty() {
|
||||
match tp.data_type {
|
||||
1 => f32::from_le_bytes([
|
||||
tp.raw_data[0],
|
||||
tp.raw_data[1],
|
||||
tp.raw_data[2],
|
||||
tp.raw_data[3],
|
||||
]),
|
||||
6 => i32::from_le_bytes([
|
||||
tp.raw_data[0],
|
||||
tp.raw_data[1],
|
||||
tp.raw_data[2],
|
||||
tp.raw_data[3],
|
||||
]) as f32,
|
||||
7 => i64::from_le_bytes([
|
||||
tp.raw_data[0],
|
||||
tp.raw_data[1],
|
||||
tp.raw_data[2],
|
||||
tp.raw_data[3],
|
||||
tp.raw_data[4],
|
||||
tp.raw_data[5],
|
||||
tp.raw_data[6],
|
||||
tp.raw_data[7],
|
||||
]) as f32,
|
||||
_ => 0.0,
|
||||
}
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
})
|
||||
.unwrap_or(0.0);
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Try shape_exprs first (for dynamic shapes), then known_values
|
||||
if let Some(se) = shape_exprs.get(&node.input[0]) {
|
||||
let shape: Vec<Expression> = se.clone();
|
||||
|
||||
// Check if all dims are concrete
|
||||
if let Some(concrete) = shape
|
||||
.iter()
|
||||
.map(|e| e.to_usize())
|
||||
.collect::<Option<Vec<usize>>>()
|
||||
{
|
||||
// Fully concrete: create named tensor with weight data
|
||||
let numel: usize = concrete.iter().product();
|
||||
let floats: Vec<f32> = vec![fill_value; numel];
|
||||
let tensor = cx.named_tensor(output_name.clone(), concrete);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), floats.clone());
|
||||
weight_data.push((output_name.clone(), floats));
|
||||
} else {
|
||||
// Dynamic shape: create scalar constant and broadcast to symbolic shape.
|
||||
// The scalar always has concrete data (1 element), and the shape is
|
||||
// resolved at runtime via ShapeTracker/dyn_map. Broadcast uses stride-0
|
||||
// expansion, so only 1 float is needed in the backing buffer.
|
||||
let scalar = cx.constant_float(fill_value);
|
||||
let result = broadcast_to_expr(scalar, se);
|
||||
tensors.insert(output_name.clone(), result);
|
||||
}
|
||||
} else {
|
||||
let shape_values = known_values.get(&node.input[0]).ok_or_else(|| {
|
||||
format!(
|
||||
"ConstantOfShape: shape input '{}' must be a known constant or shape_expr",
|
||||
node.input[0]
|
||||
)
|
||||
})?;
|
||||
let shape: Vec<usize> = shape_values.iter().map(|&v| v as usize).collect();
|
||||
let numel: usize = shape.iter().product();
|
||||
let floats: Vec<f32> = vec![fill_value; numel];
|
||||
|
||||
let tensor = cx.named_tensor(output_name.clone(), shape);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), floats.clone());
|
||||
weight_data.push((output_name.clone(), floats));
|
||||
}
|
||||
|
||||
trace!("Finished parse: ConstantOfShape Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Identity node: output is a direct alias of the input tensor.
|
||||
///
|
||||
/// Propagates known constant values for downstream constant folding.
|
||||
pub fn parse_identity(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Identity Node");
|
||||
assert!(node.input.len() == 1, "Identity should only have one input");
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Identity: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Identity should only have a single output"
|
||||
);
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Force materialization using Expression-aware broadcast
|
||||
let dims = a.dims();
|
||||
let one = a.graph().constant_float(1.0);
|
||||
let one_expanded = broadcast_to_expr(one, &dims);
|
||||
let result = a * one_expanded;
|
||||
tensors.insert(output_name.clone(), result);
|
||||
|
||||
// Propagate known values
|
||||
if let Some(vals) = known_values.get(&node.input[0]).cloned() {
|
||||
known_values.insert(output_name.clone(), vals);
|
||||
}
|
||||
// Propagate shape_exprs
|
||||
if let Some(se) = shape_exprs.get(&node.input[0]).cloned() {
|
||||
shape_exprs.insert(output_name.clone(), se);
|
||||
}
|
||||
|
||||
trace!("Finished parse: Identity Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Range node: creates a 1D tensor [start, start+delta, start+2*delta, ...] up to limit.
|
||||
///
|
||||
/// Used by dynamo ONNX export for generating position indices (arange).
|
||||
/// Supports Expression-based limits for dynamic sequence lengths.
|
||||
pub fn parse_range_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Range Node");
|
||||
assert!(
|
||||
node.input.len() == 3,
|
||||
"Range needs 3 inputs: start, limit, delta"
|
||||
);
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Try to get concrete values from known_values first
|
||||
let start_val = known_values
|
||||
.get(&node.input[0])
|
||||
.and_then(|v| v.first().copied());
|
||||
let limit_val = known_values
|
||||
.get(&node.input[1])
|
||||
.and_then(|v| v.first().copied());
|
||||
let delta_val = known_values
|
||||
.get(&node.input[2])
|
||||
.and_then(|v| v.first().copied());
|
||||
|
||||
// Also check shape_exprs for symbolic limit
|
||||
let limit_expr = shape_exprs
|
||||
.get(&node.input[1])
|
||||
.and_then(|v| v.first().cloned());
|
||||
|
||||
let start = start_val.unwrap_or(0.0);
|
||||
let delta = delta_val.unwrap_or(1.0);
|
||||
|
||||
if start == 0.0 && delta == 1.0 {
|
||||
// Simple arange case — most common for position indices
|
||||
if let Some(expr) = limit_expr {
|
||||
// Dynamic limit: create arange with symbolic length
|
||||
let tensor = cx.arange(expr);
|
||||
// Cast to F32 (luminal arange returns Int dtype)
|
||||
let result = tensor.cast(DType::F32);
|
||||
tensors.insert(output_name.clone(), result);
|
||||
shape_exprs.insert(output_name.clone(), vec![expr]);
|
||||
} else if let Some(limit) = limit_val {
|
||||
let n = limit as usize;
|
||||
let floats: Vec<f32> = (0..n).map(|i| i as f32).collect();
|
||||
let tensor = cx.named_tensor(output_name.clone(), vec![n]);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), floats.clone());
|
||||
weight_data.push((output_name.clone(), floats));
|
||||
} else {
|
||||
return Err("Range: limit must be known or symbolic".to_string());
|
||||
}
|
||||
} else if let (Some(s), Some(l), Some(d)) = (start_val, limit_val, delta_val) {
|
||||
// Fully concrete range
|
||||
let mut floats = Vec::new();
|
||||
let mut v = s;
|
||||
while (d > 0.0 && v < l) || (d < 0.0 && v > l) {
|
||||
floats.push(v);
|
||||
v += d;
|
||||
}
|
||||
let tensor = cx.named_tensor(output_name.clone(), vec![floats.len()]);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), floats.clone());
|
||||
weight_data.push((output_name.clone(), floats));
|
||||
} else {
|
||||
return Err("Range: cannot handle non-trivial dynamic ranges yet".to_string());
|
||||
}
|
||||
|
||||
trace!("Finished parse: Range Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle CumSum node: cumulative sum along an axis.
|
||||
///
|
||||
/// For the simple case of axis=0 on a 1D tensor [0, 1, 2, ...] (position indices),
|
||||
/// the cumsum is equivalent to [0, 1, 3, 6, ...]. For dynamic ONNX graphs,
|
||||
/// this is typically used for position_ids computation.
|
||||
pub fn parse_cumsum_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: CumSum Node");
|
||||
assert!(node.input.len() >= 2, "CumSum needs at least 2 inputs");
|
||||
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("CumSum: missing input '{}'", node.input[0]))?;
|
||||
|
||||
let axis_val = known_values
|
||||
.get(&node.input[1])
|
||||
.and_then(|v| v.first().copied())
|
||||
.unwrap_or(0.0) as i64;
|
||||
|
||||
let dims = input.dims();
|
||||
let ndim = dims.len();
|
||||
let _axis = if axis_val < 0 {
|
||||
(ndim as i64 + axis_val) as usize
|
||||
} else {
|
||||
axis_val as usize
|
||||
};
|
||||
|
||||
// For constant folding
|
||||
if let Some(vals) = known_values.get(&node.input[0]).cloned() {
|
||||
let output_name = &node.output[0];
|
||||
let mut cumsum = vals.clone();
|
||||
// Simple 1D cumsum
|
||||
if ndim == 1 {
|
||||
for i in 1..cumsum.len() {
|
||||
cumsum[i] += cumsum[i - 1];
|
||||
}
|
||||
}
|
||||
known_values.insert(output_name.clone(), cumsum);
|
||||
// Just alias the tensor (same shape)
|
||||
tensors.insert(output_name.clone(), input);
|
||||
trace!("Finished parse: CumSum Node (constant folded)");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// For dynamic: cumsum is hard to express in luminal primitives.
|
||||
// For the specific pattern used in Llama position_ids (cumsum of ones = arange),
|
||||
// we just pass through since arange is already handled by Range node.
|
||||
let output_name = &node.output[0];
|
||||
tensors.insert(output_name.clone(), input);
|
||||
|
||||
trace!("Finished parse: CumSum Node");
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,440 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::{
|
||||
prelude::{tracing::trace, *},
|
||||
shape::Expression,
|
||||
};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::{broadcast_to_expr, get_float_attr, get_int_attr};
|
||||
|
||||
/// Handle Softmax node: output = softmax(input[0], axis)
|
||||
///
|
||||
/// ONNX axis attribute defaults to -1 (last dimension, opset 13+).
|
||||
/// Negative axis is normalized against the input rank.
|
||||
pub fn parse_softmax_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Softmax Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"Softmax nodes need to have one input, {} where present",
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Softmax nodes only have one output, {} where present",
|
||||
node.output.len(),
|
||||
);
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Softmax: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
let ndim = a.dims().len();
|
||||
let raw_axis = get_int_attr(node, "axis", -1);
|
||||
let axis = if raw_axis < 0 {
|
||||
(ndim as i64 + raw_axis) as usize
|
||||
} else {
|
||||
raw_axis as usize
|
||||
};
|
||||
|
||||
let result = a.softmax(axis);
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Softmax Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Not node: logical NOT — output = 1.0 - input[0]
|
||||
pub fn parse_not_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Not Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"Not nodes need to have one input {} where present",
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Not nodes only have one output, {} where present",
|
||||
node.output.len()
|
||||
);
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Not: missing input tensor '{}'", node.input[0]))?;
|
||||
let a_f32 = a.cast(DType::F32);
|
||||
let result = 1.0_f32 - a_f32;
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: Not Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Clip node: output = clip(input[0], min, max)
|
||||
///
|
||||
/// Equivalent to torch.clamp. min and max are optional tensor inputs
|
||||
/// (typically constants) residing in known_values.
|
||||
pub fn parse_clip_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Clip Node");
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Clip: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
// input[1] = min (optional), input[2] = max (optional)
|
||||
let min_name = node.input.get(1).map(String::as_str).unwrap_or("");
|
||||
let max_name = node.input.get(2).map(String::as_str).unwrap_or("");
|
||||
|
||||
let min_val = if min_name.is_empty() {
|
||||
None
|
||||
} else {
|
||||
known_values.get(min_name).map(|v| v[0])
|
||||
};
|
||||
let max_val = if max_name.is_empty() {
|
||||
None
|
||||
} else {
|
||||
known_values.get(max_name).map(|v| v[0])
|
||||
};
|
||||
|
||||
let result = match (min_val, max_val) {
|
||||
(Some(lo), Some(hi)) => a.clip(lo, hi),
|
||||
(Some(lo), None) => a.maximum_f32(lo),
|
||||
(None, Some(hi)) => a.minimum_f32(hi),
|
||||
(None, None) => a,
|
||||
};
|
||||
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Clip Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Floor node: output = floor(input[0])
|
||||
///
|
||||
/// Implemented as: trunc(x) - (x < trunc(x) ? 1 : 0)
|
||||
/// where trunc is truncation toward zero via cast to Int then back to F32.
|
||||
/// This correctly handles negative non-integer values (e.g. floor(-1.5) = -2).
|
||||
pub fn parse_floor_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Floor Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"Floor nodes need to have one input {} where present",
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Floor nodes only have one output, {} where present",
|
||||
node.output.len(),
|
||||
);
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Floor: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
// trunc(x): truncation toward zero
|
||||
let trunc = a.cast(DType::Int).cast(DType::F32);
|
||||
// For negative non-integers, x < trunc(x), so subtract 1
|
||||
// Cast lt result (Bool) to F32 before arithmetic
|
||||
let adjustment = a.lt(trunc).cast(DType::F32);
|
||||
let result = trunc - adjustment;
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Floor Node");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Ceil node: output = ceil(input[0])
|
||||
///
|
||||
/// Implemented as: trunc(x) + (x > trunc(x) ? 1 : 0)
|
||||
/// where trunc is truncation toward zero via cast to Int then back to F32.
|
||||
/// This correctly handles positive non-integer values (e.g. ceil(1.5) = 2).
|
||||
pub fn parse_ceil_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Ceil Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"Ceil nodes need to have one input {} where present",
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Ceil nodes only have one output, {} where present",
|
||||
node.output.len(),
|
||||
);
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Ceil: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
// trunc(x): truncation toward zero
|
||||
let trunc = a.cast(DType::Int).cast(DType::F32);
|
||||
// For positive non-integers, x > trunc(x), so add 1
|
||||
let adjustment = a.gt(trunc).cast(DType::F32);
|
||||
let result = trunc + adjustment;
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Ceil Node");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn parse_cast_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Cast Node");
|
||||
assert!(node.input.len() == 1, "Cast should have exactly 1 input");
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Cast: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
// ONNX data type enum → luminal DType
|
||||
let to = get_int_attr(node, "to", 1);
|
||||
let dtype = match to {
|
||||
1 => DType::F32, // FLOAT
|
||||
10 => DType::F16, // FLOAT16
|
||||
16 => DType::Bf16, // BFLOAT16
|
||||
6 | 7 => DType::Int, // INT32, INT64
|
||||
9 => DType::F32, // BOOL → treat as F32 (0.0/1.0)
|
||||
11 => DType::F32, // DOUBLE → F32 (downcast)
|
||||
_ => DType::F32, // fallback
|
||||
};
|
||||
|
||||
let cast_result = input.cast(dtype);
|
||||
let output_name = &node.output[0];
|
||||
|
||||
let result = if cast_result.id == input.id {
|
||||
input
|
||||
} else {
|
||||
cast_result
|
||||
};
|
||||
|
||||
tensors.insert(output_name.clone(), result);
|
||||
|
||||
// Propagate known values (cast is a no-op for our f32 storage)
|
||||
if let Some(vals) = known_values.get(&node.input[0]).cloned() {
|
||||
let folded = if to == 9 {
|
||||
vals.iter()
|
||||
.map(|&v| if v != 0.0 { 1.0 } else { 0.0 })
|
||||
.collect()
|
||||
} else if to == 6 || to == 7 {
|
||||
vals.iter().map(|&v| (v as i64) as f32).collect()
|
||||
} else {
|
||||
vals
|
||||
};
|
||||
known_values.insert(output_name.clone(), folded.clone());
|
||||
weight_data.push((output_name.clone(), folded));
|
||||
}
|
||||
// Propagate shape_exprs
|
||||
if let Some(se) = shape_exprs.get(&node.input[0]).cloned() {
|
||||
shape_exprs.insert(output_name.clone(), se);
|
||||
}
|
||||
|
||||
trace!("Finished parse: Cast Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn parse_unary_op(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
op_name: &str,
|
||||
op: impl Fn(GraphTensor) -> GraphTensor,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: {} Node", op_name);
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"{} should have 1 input, got {}",
|
||||
op_name,
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"{} should have 1 output, got {}",
|
||||
op_name,
|
||||
node.output.len()
|
||||
);
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("{}: missing input tensor '{}'", op_name, node.input[0]))?;
|
||||
let result = op(a);
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: {} Node", op_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Erf node: output = erf(input[0])
|
||||
///
|
||||
/// Uses the Abramowitz & Stegun 7.1.26 polynomial approximation (max error < 1.5e-7):
|
||||
/// For x ≥ 0: erf(x) ≈ 1 - (a1·t + a2·t² + a3·t³ + a4·t⁴ + a5·t⁵) · exp(-x²)
|
||||
/// where t = 1 / (1 + 0.3275911·x)
|
||||
/// a1 = 0.254829592
|
||||
/// a2 = -0.284496736
|
||||
/// a3 = 1.421413741
|
||||
/// a4 = -1.453152027
|
||||
/// a5 = 1.061405429
|
||||
/// Extended to all x via odd symmetry: erf(-x) = -erf(x).
|
||||
pub fn parse_erf_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
parse_unary_op(node, tensors, "Erf", |x| {
|
||||
let a = x.abs();
|
||||
let t = (1.0_f32 + 0.3275911_f32 * a).reciprocal();
|
||||
// Horner evaluation of a1*t + a2*t² + a3*t³ + a4*t⁴ + a5*t⁵
|
||||
// poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + a5*t))))
|
||||
let h = t * 1.061_405_4_f32 - 1.453_152_1_f32; // a4 + a5*t
|
||||
let h = t * h + 1.421_413_8_f32;
|
||||
let h = t * h - 0.284_496_72_f32;
|
||||
let h = t * h + 0.254_829_6_f32;
|
||||
let poly = t * h;
|
||||
let erf_abs = 1.0_f32 - poly * (-a * a).exp();
|
||||
x.sign() * erf_abs
|
||||
})
|
||||
}
|
||||
|
||||
/// Handle LayerNormalization node (opset 17).
|
||||
///
|
||||
/// Inputs: X (required), scale (required), bias (optional)
|
||||
/// Attributes: axis (default -1), epsilon (default 1e-5)
|
||||
/// Normalizes over axes [axis, axis+1, ..., rank-1], then applies scale and bias.
|
||||
/// Only output 0 (the normalized result) is wired; outputs 1/2 (mean, inv_std_var)
|
||||
/// are training-only and not supported for inference.
|
||||
pub fn parse_layernorm_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: LayerNormalization Node");
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("LayerNorm: missing input '{}'", node.input[0]))?;
|
||||
let scale = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("LayerNorm: missing scale '{}'", node.input[1]))?;
|
||||
|
||||
let ndim = input.dims().len();
|
||||
let axis_raw = get_int_attr(node, "axis", -1);
|
||||
let axis = if axis_raw < 0 {
|
||||
(ndim as i64 + axis_raw) as usize
|
||||
} else {
|
||||
axis_raw as usize
|
||||
};
|
||||
let epsilon = get_float_attr(node, "epsilon", 1e-5);
|
||||
let axes: Vec<usize> = (axis..ndim).collect();
|
||||
|
||||
let mut result = input.layer_norm(axes, epsilon);
|
||||
|
||||
// Apply scale (broadcast to input shape using Expression-aware broadcast)
|
||||
let input_shape = input.dims();
|
||||
result *= broadcast_to_expr(scale, &input_shape);
|
||||
|
||||
// Apply optional bias
|
||||
if node.input.len() > 2 && !node.input[2].is_empty() {
|
||||
let bias = *tensors
|
||||
.get(&node.input[2])
|
||||
.ok_or_else(|| format!("LayerNorm: missing bias '{}'", node.input[2]))?;
|
||||
result += broadcast_to_expr(bias, &input_shape);
|
||||
}
|
||||
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: LayerNormalization Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle GroupNormalization node (opset 18).
|
||||
///
|
||||
/// Inputs: X [N, C, spatial...], scale [num_groups], bias [num_groups]
|
||||
/// Attributes: num_groups (required), epsilon (default 1e-5)
|
||||
///
|
||||
/// Normalizes over channels-per-group and spatial dims, then applies per-group scale/bias.
|
||||
/// Decomposed into: reshape [N, G, C/G, spatial...] -> layer_norm over [C/G, spatial...] ->
|
||||
/// reshape back to [N, C, spatial...] -> scale + bias (broadcast).
|
||||
pub fn parse_group_norm_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: GroupNormalization Node");
|
||||
|
||||
assert!(
|
||||
node.input.len() >= 3,
|
||||
"GroupNormalization needs 3 inputs (X, scale, bias), got {}",
|
||||
node.input.len()
|
||||
);
|
||||
|
||||
let x = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("GroupNorm: missing input X '{}'", node.input[0]))?;
|
||||
let scale = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("GroupNorm: missing scale '{}'", node.input[1]))?;
|
||||
let bias = *tensors
|
||||
.get(&node.input[2])
|
||||
.ok_or_else(|| format!("GroupNorm: missing bias '{}'", node.input[2]))?;
|
||||
|
||||
let x_dims = x.dims();
|
||||
let ndim = x_dims.len();
|
||||
assert!(
|
||||
ndim >= 3,
|
||||
"GroupNorm: input must be at least 3D [N, C, spatial...], got {ndim}D"
|
||||
);
|
||||
|
||||
let num_groups = get_int_attr(node, "num_groups", 1) as usize;
|
||||
let epsilon = get_float_attr(node, "epsilon", 1e-5);
|
||||
|
||||
let n = x_dims[0]
|
||||
.to_usize()
|
||||
.expect("GroupNorm: batch must be concrete");
|
||||
let c = x_dims[1]
|
||||
.to_usize()
|
||||
.expect("GroupNorm: channels must be concrete");
|
||||
assert_eq!(
|
||||
c % num_groups,
|
||||
0,
|
||||
"GroupNorm: channels {c} must be divisible by num_groups {num_groups}"
|
||||
);
|
||||
let cpg = c / num_groups; // channels per group
|
||||
|
||||
// Reshape X from [N, C, spatial...] to [N, G, C/G, spatial...]
|
||||
let spatial_dims: Vec<Expression> = x_dims[2..].to_vec();
|
||||
let mut reshaped = x;
|
||||
let mut new_shape = vec![n, num_groups, cpg];
|
||||
for d in &spatial_dims {
|
||||
new_shape.push(
|
||||
d.to_usize()
|
||||
.expect("GroupNorm: spatial dims must be concrete"),
|
||||
);
|
||||
}
|
||||
reshaped.shape = ShapeTracker::new(new_shape.clone());
|
||||
|
||||
// Normalize over axes [2, 3, ..., ndim] (C/G + spatial dims)
|
||||
let norm_axes: Vec<usize> = (2..new_shape.len()).collect();
|
||||
let mut normed = reshaped.layer_norm(norm_axes, epsilon);
|
||||
|
||||
// Reshape back to [N, C, spatial...]
|
||||
let mut orig_shape = vec![n, c];
|
||||
for d in &spatial_dims {
|
||||
orig_shape.push(d.to_usize().unwrap());
|
||||
}
|
||||
normed *= 1.0;
|
||||
normed.shape = ShapeTracker::new(orig_shape.clone());
|
||||
|
||||
// Apply scale and bias (both shape [C], broadcast to [N, C, spatial...])
|
||||
let target_shape: Vec<Expression> = orig_shape.iter().map(|&d| Expression::from(d)).collect();
|
||||
let result =
|
||||
normed * broadcast_to_expr(scale, &target_shape) + broadcast_to_expr(bias, &target_shape);
|
||||
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: GroupNormalization Node");
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,15 +1,70 @@
|
||||
use luminal::dyn_backend::BackendFactory;
|
||||
use luminal::prelude::tracing::warn;
|
||||
use luminal::prelude::*;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::{PyAny, PyCapsule, PyCapsuleMethods, PyDict};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::compiled_graph::{CompiledGraph, GraphTranslation, WeightData};
|
||||
use crate::pt2_parser;
|
||||
use crate::compiled_graph::{CompiledGraph, DimParamMap, GraphTranslation, WeightData};
|
||||
use crate::pt2_schema;
|
||||
use crate::translator;
|
||||
use crate::util::DimParamMap;
|
||||
use crate::typed_data::TypedData;
|
||||
use crate::{pt2_parser, pt2_util};
|
||||
|
||||
/// Pre-loaded weight/constant data paired with tensor sizes.
|
||||
type PreloadResult = (Vec<(String, Vec<f32>)>, HashMap<String, usize>);
|
||||
type PreloadResult = (Vec<(String, TypedData)>, HashMap<String, usize>);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct CompileOptions {
|
||||
search_iterations: usize,
|
||||
}
|
||||
|
||||
impl Default for CompileOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
search_iterations: 10,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CompileOptions {
|
||||
fn from_py(options: Option<&Bound<'_, PyAny>>) -> PyResult<Self> {
|
||||
let mut parsed = Self::default();
|
||||
|
||||
let Some(options) = options else {
|
||||
return Ok(parsed);
|
||||
};
|
||||
|
||||
let options = options.cast::<PyDict>().map_err(|_| {
|
||||
pyo3::exceptions::PyTypeError::new_err("luminal backend options must be a dict")
|
||||
})?;
|
||||
|
||||
for (key, value) in options.iter() {
|
||||
let key = key.extract::<String>().map_err(|_| {
|
||||
pyo3::exceptions::PyTypeError::new_err(
|
||||
"luminal backend option keys must be strings",
|
||||
)
|
||||
})?;
|
||||
|
||||
match key.as_str() {
|
||||
"search_iterations" => {
|
||||
parsed.search_iterations = value.extract::<usize>().map_err(|_| {
|
||||
pyo3::exceptions::PyTypeError::new_err(
|
||||
"luminal backend option 'search_iterations' must be an integer",
|
||||
)
|
||||
})?;
|
||||
}
|
||||
other => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
||||
"Unsupported luminal backend option '{other}'. Supported options: search_iterations",
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(parsed)
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_dim_sizes(
|
||||
sizes: &[pt2_schema::DimSize],
|
||||
@@ -35,20 +90,55 @@ fn resolve_dim_sizes(
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (pt2_path, weights_path, backend, search_iters, weight_device_ptrs=None))]
|
||||
#[pyo3(signature = (pt2_path, weights_path, factory_capsule, weight_device_ptrs=None, options=None))]
|
||||
pub fn process_pt2(
|
||||
pt2_path: &str,
|
||||
weights_path: &str,
|
||||
backend: &str,
|
||||
search_iters: usize,
|
||||
factory_capsule: &Bound<'_, PyCapsule>,
|
||||
weight_device_ptrs: Option<HashMap<String, (u64, usize)>>,
|
||||
options: Option<&Bound<'_, PyAny>>,
|
||||
) -> PyResult<CompiledGraph> {
|
||||
let options = CompileOptions::from_py(options)?;
|
||||
let factory: BackendFactory = {
|
||||
let expected = ::luminal::dyn_backend::BACKEND_FACTORY_CAPSULE_NAME;
|
||||
match factory_capsule.name()? {
|
||||
Some(name) => {
|
||||
// SAFETY: the &CStr is used immediately (for a byte-wise
|
||||
// comparison) and never stored; the capsule is borrowed for
|
||||
// the duration of this function, so the name pointer stays
|
||||
// valid for as long as we read it here.
|
||||
let actual = unsafe { name.as_cstr() };
|
||||
if actual != expected {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
||||
"factory_capsule has wrong name: expected {:?}, got {:?}",
|
||||
expected, actual,
|
||||
)));
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"factory_capsule has no name; expected \"luminal.backend_factory\"",
|
||||
));
|
||||
}
|
||||
}
|
||||
let wrapper_ptr = factory_capsule
|
||||
.pointer_checked(Some(expected))
|
||||
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?
|
||||
.as_ptr() as *const *const std::ffi::c_void;
|
||||
let fn_ptr = unsafe { *wrapper_ptr };
|
||||
if fn_ptr.is_null() {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"factory_capsule inner function pointer is null",
|
||||
));
|
||||
}
|
||||
unsafe { std::mem::transmute(fn_ptr) }
|
||||
};
|
||||
compile_pt2(
|
||||
pt2_path,
|
||||
weights_path,
|
||||
backend,
|
||||
search_iters,
|
||||
&options,
|
||||
weight_device_ptrs.unwrap_or_default(),
|
||||
factory,
|
||||
)
|
||||
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e:#}")))
|
||||
}
|
||||
@@ -56,14 +146,14 @@ pub fn process_pt2(
|
||||
fn compile_pt2(
|
||||
pt2_path: &str,
|
||||
weights_path: &str,
|
||||
backend: &str,
|
||||
search_iters: usize,
|
||||
options: &CompileOptions,
|
||||
weight_device_ptrs: HashMap<String, (u64, usize)>,
|
||||
factory: BackendFactory,
|
||||
) -> anyhow::Result<CompiledGraph> {
|
||||
let (translation, mut weights) = translate_pt2(pt2_path, weights_path)?;
|
||||
weights.device_ptrs = weight_device_ptrs;
|
||||
|
||||
CompiledGraph::parse_graph(translation, weights, backend, search_iters)
|
||||
CompiledGraph::parse_graph(translation, weights, factory, options.search_iterations)
|
||||
.map_err(|e| anyhow::anyhow!(e))
|
||||
}
|
||||
|
||||
@@ -83,7 +173,7 @@ pub fn translate_pt2(
|
||||
}
|
||||
}
|
||||
|
||||
// Compute shape expressions from PT2 tensor metadata
|
||||
// Compute shape expressions and dtypes from PT2 tensor metadata
|
||||
let output_shape_exprs: Vec<Vec<Expression>> = translated
|
||||
.output_ids
|
||||
.iter()
|
||||
@@ -95,6 +185,17 @@ pub fn translate_pt2(
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output_dtypes: Vec<DType> = translated
|
||||
.output_ids
|
||||
.iter()
|
||||
.map(|(name, _id)| {
|
||||
parsed
|
||||
.tensor_meta(name)
|
||||
.map(|meta| pt2_util::torch_dtype_int_to_luminal(meta.dtype))
|
||||
.unwrap_or(DType::F32)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let input_names: Vec<String> = translated
|
||||
.user_input_ids
|
||||
.iter()
|
||||
@@ -127,7 +228,7 @@ pub fn translate_pt2(
|
||||
}
|
||||
|
||||
// Pre-load weights and compute tensor sizes for CUDA dummy data
|
||||
let mut weights: Vec<(String, Vec<f32>)> = Vec::new();
|
||||
let mut weights: Vec<(String, TypedData)> = Vec::new();
|
||||
let mut tensor_sizes: HashMap<String, usize> = HashMap::new();
|
||||
|
||||
// Load safetensors weights
|
||||
@@ -189,6 +290,7 @@ pub fn translate_pt2(
|
||||
tensor_ids,
|
||||
input_names,
|
||||
output_names,
|
||||
output_dtypes,
|
||||
output_shape_exprs,
|
||||
input_shape_exprs,
|
||||
dim_param_map,
|
||||
@@ -235,8 +337,8 @@ fn preload_safetensors(graph: &Graph, file_path: &str) -> anyhow::Result<Preload
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
&& let Ok(tensor) = st.tensor(&input.label)
|
||||
{
|
||||
let f32s = bytes_to_f32(tensor.data(), safetensors_dtype_to_pt2(tensor.dtype()));
|
||||
weights.push((input.label.clone(), f32s));
|
||||
let types = bytes_to_typed(tensor.data(), safetensors_dtype_to_pt2(tensor.dtype()));
|
||||
weights.push((input.label.clone(), types));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -273,15 +375,12 @@ fn preload_constants(
|
||||
) {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
eprintln!(
|
||||
"[luminal] Warning: failed to load constant '{}': {:#}",
|
||||
name, e
|
||||
);
|
||||
warn!("failed to load constant '{}': {:#}", name, e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let f32_data = bytes_to_f32(&raw_bytes, entry.tensor_meta.dtype);
|
||||
weights.push((name.clone(), f32_data));
|
||||
let typed_data = bytes_to_typed(&raw_bytes, entry.tensor_meta.dtype);
|
||||
weights.push((name.clone(), typed_data));
|
||||
}
|
||||
|
||||
Ok((weights, sizes))
|
||||
@@ -308,49 +407,121 @@ fn safetensors_dtype_to_pt2(dtype: safetensors::Dtype) -> u32 {
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert raw bytes to f32 using PT2 dtype numbering.
|
||||
fn bytes_to_f32(bytes: &[u8], dtype: u32) -> Vec<f32> {
|
||||
/// Convert raw bytes to TypedData using PT2 dtype numbering.
|
||||
/// Preserves native byte format for types luminal supports directly (f32, f16, bf16, i32, bool, u8, i8).
|
||||
/// Converts i64/f64/i16 to the closest luminal-native representation.
|
||||
fn bytes_to_typed(bytes: &[u8], dtype: u32) -> TypedData {
|
||||
match dtype {
|
||||
7 => bytes
|
||||
.chunks_exact(4)
|
||||
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
|
||||
.collect(),
|
||||
6 => bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
|
||||
.collect(),
|
||||
13 => bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| half::bf16::from_le_bytes([b[0], b[1]]).to_f32())
|
||||
.collect(),
|
||||
8 => bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32)
|
||||
.collect(),
|
||||
5 => bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32)
|
||||
.collect(),
|
||||
4 => bytes
|
||||
.chunks_exact(4)
|
||||
.map(|b| i32::from_le_bytes([b[0], b[1], b[2], b[3]]) as f32)
|
||||
.collect(),
|
||||
3 => bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| i16::from_le_bytes([b[0], b[1]]) as f32)
|
||||
.collect(),
|
||||
2 => bytes.iter().map(|&b| (b as i8) as f32).collect(),
|
||||
1 => bytes.iter().map(|&b| b as f32).collect(),
|
||||
12 => bytes
|
||||
.iter()
|
||||
.map(|&b| if b != 0 { 1.0 } else { 0.0 })
|
||||
.collect(),
|
||||
// Types that map directly — preserve raw bytes
|
||||
7 => TypedData::from_raw(bytes.to_vec(), DType::F32),
|
||||
6 => TypedData::from_raw(bytes.to_vec(), DType::F16),
|
||||
13 => TypedData::from_raw(bytes.to_vec(), DType::Bf16),
|
||||
4 => TypedData::from_raw(bytes.to_vec(), DType::Int), // i32
|
||||
1 => TypedData::from_raw(bytes.to_vec(), DType::U8),
|
||||
2 => TypedData::from_raw(bytes.to_vec(), DType::I8),
|
||||
12 => TypedData::from_raw(bytes.to_vec(), DType::Bool),
|
||||
|
||||
// i64 → i32 (truncate, matching luminal's Int type)
|
||||
5 => {
|
||||
let i32s: Vec<i32> = bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as i32
|
||||
})
|
||||
.collect();
|
||||
TypedData::from_i32_vec(i32s)
|
||||
}
|
||||
// f64 → f32 (downcast, luminal has no F64 in practice for most ops)
|
||||
8 => {
|
||||
let f32s: Vec<f32> = bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
|
||||
})
|
||||
.collect();
|
||||
TypedData::from_f32_vec(f32s)
|
||||
}
|
||||
// i16 → i32 (widen to luminal's Int)
|
||||
3 => {
|
||||
let i32s: Vec<i32> = bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| i16::from_le_bytes([b[0], b[1]]) as i32)
|
||||
.collect();
|
||||
TypedData::from_i32_vec(i32s)
|
||||
}
|
||||
_ => {
|
||||
eprintln!("[luminal] Warning: unrecognized dtype {dtype}, interpreting as f32");
|
||||
bytes
|
||||
.chunks_exact(4)
|
||||
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
|
||||
.collect()
|
||||
let luminal_dtype = pt2_util::torch_dtype_int_to_luminal(dtype);
|
||||
warn!("Unrecognized dtype {dtype}, interpreting as {luminal_dtype:?}");
|
||||
TypedData::from_raw(bytes.to_vec(), luminal_dtype)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::CompileOptions;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyDict;
|
||||
use std::sync::Once;
|
||||
|
||||
fn with_python(f: impl FnOnce(Python<'_>)) {
|
||||
static INIT: Once = Once::new();
|
||||
INIT.call_once(Python::initialize);
|
||||
Python::attach(f);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_defaults_apply() {
|
||||
let options = CompileOptions::from_py(None).unwrap();
|
||||
assert_eq!(options.search_iterations, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_dict_overlays_defaults() {
|
||||
with_python(|py| {
|
||||
let options = PyDict::new(py);
|
||||
options.set_item("search_iterations", 3).unwrap();
|
||||
|
||||
let parsed = CompileOptions::from_py(Some(options.as_any())).unwrap();
|
||||
assert_eq!(parsed.search_iterations, 3);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_reject_unknown_keys() {
|
||||
with_python(|py| {
|
||||
let options = PyDict::new(py);
|
||||
options.set_item("unknown", 1).unwrap();
|
||||
|
||||
let err = CompileOptions::from_py(Some(options.as_any())).unwrap_err();
|
||||
assert!(err.is_instance_of::<pyo3::exceptions::PyValueError>(py));
|
||||
assert!(
|
||||
err.to_string()
|
||||
.contains("Unsupported luminal backend option 'unknown'")
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_reject_non_dict() {
|
||||
with_python(|py| {
|
||||
let options = 123usize.into_pyobject(py).unwrap();
|
||||
|
||||
let err = CompileOptions::from_py(Some(options.as_any())).unwrap_err();
|
||||
assert!(err.is_instance_of::<pyo3::exceptions::PyTypeError>(py));
|
||||
assert!(err.to_string().contains("options must be a dict"));
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_reject_bad_search_iterations_type() {
|
||||
with_python(|py| {
|
||||
let options = PyDict::new(py);
|
||||
options.set_item("search_iterations", "fast").unwrap();
|
||||
|
||||
let err = CompileOptions::from_py(Some(options.as_any())).unwrap_err();
|
||||
assert!(err.is_instance_of::<pyo3::exceptions::PyTypeError>(py));
|
||||
assert!(err.to_string().contains("search_iterations"));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,6 +77,7 @@ pub enum Argument {
|
||||
SymInts(SymIntsArg),
|
||||
SymInt(SymIntArg),
|
||||
Expr(ExprArg),
|
||||
#[allow(dead_code)]
|
||||
ScalarType(ScalarTypeArg),
|
||||
Tensors(TensorsArg),
|
||||
OptionalTensors(OptionalTensorsArg),
|
||||
@@ -168,6 +169,7 @@ pub struct NoneArg {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct ScalarTypeArg {
|
||||
pub as_scalar_type: u32,
|
||||
}
|
||||
@@ -224,6 +226,7 @@ impl Argument {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn as_scalar_type(&self) -> Option<u32> {
|
||||
match self {
|
||||
Argument::ScalarType(s) => Some(s.as_scalar_type),
|
||||
|
||||
@@ -16,6 +16,7 @@ pub enum ReductionOp {
|
||||
Mean,
|
||||
Max,
|
||||
Min,
|
||||
Prod,
|
||||
}
|
||||
|
||||
/// Normalize a potentially negative dimension index.
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
use luminal::prelude::*;
|
||||
#[cfg(feature = "cuda")]
|
||||
use luminal_cuda_lite::cudarc::driver::{CudaContext, CudaStream};
|
||||
#[cfg(feature = "cuda")]
|
||||
use luminal_cuda_lite::runtime::CudaRuntime;
|
||||
use rustc_hash::FxHashMap;
|
||||
#[cfg(feature = "cuda")]
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Enum wrapper for runtime backends allowing runtime selection.
|
||||
pub enum RuntimeBackend {
|
||||
Native(NativeRuntime),
|
||||
#[cfg(feature = "cuda")]
|
||||
Cuda(Box<CudaRuntime>),
|
||||
}
|
||||
|
||||
impl RuntimeBackend {
|
||||
/// Set input data for a tensor node.
|
||||
pub fn set_data(&mut self, node: NodeIndex, data: Vec<f32>) {
|
||||
match self {
|
||||
RuntimeBackend::Native(rt) => rt.set_data(node, data),
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(rt) => rt.set_data(node, data),
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute the compiled graph.
|
||||
pub fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
match self {
|
||||
RuntimeBackend::Native(rt) => rt.execute(dyn_map),
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(rt) => rt.execute(dyn_map),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get output data from a tensor node.
|
||||
pub fn get_f32(&self, node: NodeIndex) -> Vec<f32> {
|
||||
match self {
|
||||
RuntimeBackend::Native(rt) => rt.get_f32(node).to_vec(),
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(rt) => rt.get_f32(node),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the name of the active backend.
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
RuntimeBackend::Native(_) => "native",
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(_) => "cuda",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Two-phase initialization for CUDA (required because profiling executes graph)
|
||||
// ============================================================================
|
||||
|
||||
/// Prepare CUDA runtime: build search space and create runtime, but don't search yet.
|
||||
/// Returns the unoptimized runtime that can have data set on it.
|
||||
///
|
||||
/// Use this with `finalize_cuda` for proper CUDA initialization:
|
||||
/// 1. Call `prepare_cuda` to get the runtime
|
||||
/// 2. Set data on the runtime using `rt.set_data(node_id, data)`
|
||||
/// 3. Call `finalize_cuda` to run profiling with data available
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn prepare_cuda(context: &mut Graph) -> Result<(CudaRuntime, Arc<CudaStream>), String> {
|
||||
let cuda_ctx =
|
||||
CudaContext::new(0).map_err(|e| format!("Failed to init CUDA context: {}", e))?;
|
||||
let stream = cuda_ctx.default_stream();
|
||||
context.build_search_space::<CudaRuntime>();
|
||||
let rt = CudaRuntime::initialize(stream.clone());
|
||||
Ok((rt, stream))
|
||||
}
|
||||
|
||||
/// Finalize CUDA runtime: run search with data already set.
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn finalize_cuda(context: &mut Graph, rt: CudaRuntime) -> RuntimeBackend {
|
||||
let optimized_rt = context.search(rt, 10);
|
||||
RuntimeBackend::Cuda(Box::new(optimized_rt))
|
||||
}
|
||||
@@ -12,6 +12,7 @@ impl<'a> Translator<'a> {
|
||||
let arg1 = &node.inputs[1].arg;
|
||||
if let Some(name) = arg1.as_tensor_name() {
|
||||
let b = self.get_tensor(name)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
Ok(match op {
|
||||
BinaryOp::Add => a + b,
|
||||
|
||||
407
crates/luminal_python/rust/src/translator/conv.rs
Normal file
407
crates/luminal_python/rust/src/translator/conv.rs
Normal file
@@ -0,0 +1,407 @@
|
||||
use anyhow::Result;
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_schema::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
const CONV_INPUT_ARG: usize = 0;
|
||||
const CONV_WEIGHT_ARG: usize = 1;
|
||||
const CONV_BIAS_ARG: usize = 2;
|
||||
const CONV_STRIDE_ARG: usize = 3;
|
||||
const CONV_PADDING_ARG: usize = 4;
|
||||
const CONV_DILATION_ARG: usize = 5;
|
||||
const CONV_GROUPS_ARG: usize = 6;
|
||||
|
||||
const CONVOLUTION_TRANSPOSED_ARG: usize = 6;
|
||||
const CONVOLUTION_OUTPUT_PADDING_ARG: usize = 7;
|
||||
const CONVOLUTION_GROUPS_ARG: usize = 8;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
/// Translate aten.conv{1,2,3}d.default and aten.convolution.default.
|
||||
///
|
||||
/// The PT2 export may omit defaulted trailing arguments entirely. In practice this means
|
||||
/// conv{N}d.default can show up as just `(input, weight)` for the no-bias, stride=1,
|
||||
/// padding=0, dilation=1, groups=1 case.
|
||||
pub(crate) fn translate_conv(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, CONV_INPUT_ARG)?;
|
||||
let weight = self.get_input_tensor(node, CONV_WEIGHT_ARG)?;
|
||||
let bias = self.get_input_tensor(node, CONV_BIAS_ARG).ok();
|
||||
|
||||
let x_dims = input.dims();
|
||||
let w_dims = weight.dims();
|
||||
let rank = x_dims.len();
|
||||
let spatial = rank - 2;
|
||||
|
||||
let stride = self
|
||||
.get_ints_arg(node, CONV_STRIDE_ARG)
|
||||
.unwrap_or_else(|_| vec![1; spatial]);
|
||||
let padding = self
|
||||
.get_ints_arg(node, CONV_PADDING_ARG)
|
||||
.unwrap_or_else(|_| vec![0; spatial]);
|
||||
let mut dilation = self
|
||||
.get_ints_arg(node, CONV_DILATION_ARG)
|
||||
.unwrap_or_else(|_| vec![1; spatial]);
|
||||
let groups = if node.target == "torch.ops.aten.convolution.default" {
|
||||
let transposed = self
|
||||
.get_bool_arg(node, CONVOLUTION_TRANSPOSED_ARG)
|
||||
.unwrap_or(false);
|
||||
anyhow::ensure!(
|
||||
!transposed,
|
||||
"conv: ConvTranspose / transposed=true is not supported yet"
|
||||
);
|
||||
let output_padding = self
|
||||
.get_ints_arg(node, CONVOLUTION_OUTPUT_PADDING_ARG)
|
||||
.unwrap_or_else(|_| vec![0; spatial]);
|
||||
anyhow::ensure!(
|
||||
output_padding.iter().all(|&v| v == 0),
|
||||
"conv: output_padding is not supported for non-transposed convolution"
|
||||
);
|
||||
self.get_int_arg(node, CONVOLUTION_GROUPS_ARG).unwrap_or(1) as usize
|
||||
} else {
|
||||
self.get_int_arg(node, CONV_GROUPS_ARG).unwrap_or(1) as usize
|
||||
};
|
||||
if dilation.len() != spatial {
|
||||
dilation = vec![1; spatial];
|
||||
}
|
||||
|
||||
let ch_out = w_dims[0]
|
||||
.to_usize()
|
||||
.ok_or_else(|| anyhow::anyhow!("conv: weight C_out must be concrete"))?;
|
||||
let ch_in = x_dims[1]
|
||||
.to_usize()
|
||||
.ok_or_else(|| anyhow::anyhow!("conv: input C_in must be concrete"))?;
|
||||
anyhow::ensure!(
|
||||
stride.len() == spatial && padding.len() == spatial && dilation.len() == spatial,
|
||||
"conv: stride/padding/dilation rank must match spatial rank {spatial}"
|
||||
);
|
||||
anyhow::ensure!(
|
||||
groups > 0 && ch_in % groups == 0 && ch_out % groups == 0,
|
||||
"conv: invalid group configuration (C_in={ch_in}, C_out={ch_out}, groups={groups})"
|
||||
);
|
||||
let ch_per_group = ch_in / groups;
|
||||
|
||||
let kernel_shape: Vec<usize> = w_dims[2..]
|
||||
.iter()
|
||||
.map(|d| {
|
||||
d.to_usize()
|
||||
.ok_or_else(|| anyhow::anyhow!("conv: kernel dims must be concrete"))
|
||||
})
|
||||
.collect::<Result<_>>()?;
|
||||
let kernel_product: usize = kernel_shape.iter().product();
|
||||
|
||||
// ATen uses symmetric padding (same begin/end)
|
||||
let stride_u: Vec<usize> = stride.iter().map(|&v| v as usize).collect();
|
||||
let padding_u: Vec<usize> = padding.iter().map(|&v| v as usize).collect();
|
||||
let dilation_u: Vec<usize> = dilation.iter().map(|&v| v as usize).collect();
|
||||
|
||||
let mut out = if groups > 1 {
|
||||
let group_out = ch_out / groups;
|
||||
|
||||
if ch_per_group == 1 {
|
||||
// Depthwise (including channel multiplier > 1): avoid per-channel slicing.
|
||||
depthwise_conv(
|
||||
input,
|
||||
weight,
|
||||
&kernel_shape,
|
||||
&stride_u,
|
||||
&dilation_u,
|
||||
&padding_u,
|
||||
&padding_u,
|
||||
ch_in,
|
||||
group_out,
|
||||
kernel_product,
|
||||
spatial,
|
||||
)
|
||||
} else {
|
||||
// General grouped: pre-pad full input then slice per group
|
||||
let padded_input = {
|
||||
let mut pad_spec: Vec<(Expression, Expression)> =
|
||||
vec![(0.into(), 0.into()); 2 + spatial];
|
||||
for i in 0..spatial {
|
||||
pad_spec[2 + i] = (padding_u[i].into(), padding_u[i].into());
|
||||
}
|
||||
input.pad(pad_spec, 0.0)
|
||||
};
|
||||
|
||||
let no_pad = vec![0usize; spatial];
|
||||
let mut group_outputs = Vec::with_capacity(groups);
|
||||
for g in 0..groups {
|
||||
let x_g = slice_channel_group(padded_input, g, ch_per_group, spatial);
|
||||
let w_g =
|
||||
slice_weight_group(weight, g, group_out, ch_per_group * kernel_product);
|
||||
group_outputs.push(conv_unfold(
|
||||
x_g,
|
||||
w_g,
|
||||
&kernel_shape,
|
||||
&stride_u,
|
||||
&dilation_u,
|
||||
&no_pad,
|
||||
&no_pad,
|
||||
ch_per_group,
|
||||
group_out,
|
||||
spatial,
|
||||
));
|
||||
}
|
||||
|
||||
let mut result = group_outputs[0];
|
||||
for g_out in &group_outputs[1..] {
|
||||
result = result.concat_along(*g_out, 1);
|
||||
}
|
||||
result
|
||||
}
|
||||
} else {
|
||||
let mut w_flat = weight;
|
||||
w_flat.shape = ShapeTracker::new_with_element_bits(
|
||||
vec![ch_out, ch_in * kernel_product],
|
||||
weight.dtype.bits(),
|
||||
);
|
||||
|
||||
conv_unfold(
|
||||
input,
|
||||
w_flat,
|
||||
&kernel_shape,
|
||||
&stride_u,
|
||||
&dilation_u,
|
||||
&padding_u,
|
||||
&padding_u,
|
||||
ch_in,
|
||||
ch_out,
|
||||
spatial,
|
||||
)
|
||||
};
|
||||
|
||||
if let Some(b) = bias {
|
||||
let out_dims = out.dims();
|
||||
let mut b_expanded = b.expand_dim(0, 1);
|
||||
for i in 0..spatial {
|
||||
b_expanded = b_expanded.expand_dim(2 + i, out_dims[2 + i]);
|
||||
}
|
||||
out += b_expanded;
|
||||
}
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
/// Slice input channels for one group.
|
||||
/// Caller must pre-pad `x` so no additional padding is applied to the slice.
|
||||
fn slice_channel_group(
|
||||
x: GraphTensor,
|
||||
g: usize,
|
||||
ch_per_group: usize,
|
||||
spatial: usize,
|
||||
) -> GraphTensor {
|
||||
let start = g * ch_per_group;
|
||||
let end = start + ch_per_group;
|
||||
let dims = x.dims();
|
||||
let rank = 2 + spatial;
|
||||
let mut slices: Vec<(Expression, Expression)> = Vec::with_capacity(rank);
|
||||
slices.push((0.into(), dims[0]));
|
||||
slices.push((start.into(), end.into()));
|
||||
for dim in dims.iter().take(rank).skip(2) {
|
||||
slices.push((0.into(), *dim));
|
||||
}
|
||||
x.slice(slices)
|
||||
}
|
||||
|
||||
/// Slice and flatten weight for one group.
|
||||
fn slice_weight_group(
|
||||
w: GraphTensor,
|
||||
g: usize,
|
||||
group_out: usize,
|
||||
flat_inner: usize,
|
||||
) -> GraphTensor {
|
||||
let start = g * group_out;
|
||||
let end = start + group_out;
|
||||
let w_dims = w.dims();
|
||||
let mut slices: Vec<(Expression, Expression)> = Vec::with_capacity(w_dims.len());
|
||||
slices.push((start.into(), end.into()));
|
||||
for dim in w_dims.iter().skip(1) {
|
||||
slices.push((0.into(), *dim));
|
||||
}
|
||||
// Materialize through Add: binary op outputs are contiguous in Luminal, which makes the
|
||||
// following flatten safe for the sliced weight buffer.
|
||||
let w_sliced = w.slice(slices) + 0.0;
|
||||
let mut w_flat = w_sliced;
|
||||
w_flat.shape =
|
||||
ShapeTracker::new_with_element_bits(vec![group_out, flat_inner], w_sliced.dtype.bits());
|
||||
w_flat
|
||||
}
|
||||
|
||||
/// Core unfold-based convolution for a single group.
|
||||
///
|
||||
/// `x`: [batch, ch_in, spatial...]
|
||||
/// `w_flat`: [ch_out, ch_in * kernel_product] (already reshaped)
|
||||
/// Returns: [batch, ch_out, out_spatial...]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn conv_unfold(
|
||||
x: GraphTensor,
|
||||
w_flat: GraphTensor,
|
||||
kernel_shape: &[usize],
|
||||
strides: &[usize],
|
||||
dilations: &[usize],
|
||||
pads_begin: &[usize],
|
||||
pads_end: &[usize],
|
||||
_ch_in: usize,
|
||||
_ch_out: usize,
|
||||
spatial: usize,
|
||||
) -> GraphTensor {
|
||||
let rank = 2 + spatial;
|
||||
|
||||
// Pad spatial dimensions (skip if all padding is zero)
|
||||
let needs_pad = pads_begin.iter().any(|&p| p > 0) || pads_end.iter().any(|&p| p > 0);
|
||||
let padded = if needs_pad {
|
||||
let mut padding: Vec<(Expression, Expression)> = vec![(0.into(), 0.into()); rank];
|
||||
for i in 0..spatial {
|
||||
padding[2 + i] = (pads_begin[i].into(), pads_end[i].into());
|
||||
}
|
||||
x.pad(padding, 0.0)
|
||||
} else {
|
||||
x
|
||||
};
|
||||
|
||||
// Build full-rank unfold parameters (1 for batch/channel, actual for spatial)
|
||||
let mut kernel_full = vec![1usize; rank];
|
||||
let mut stride_full = vec![1usize; rank];
|
||||
let mut dilation_full = vec![1usize; rank];
|
||||
kernel_full[2..(spatial + 2)].copy_from_slice(&kernel_shape[..spatial]);
|
||||
stride_full[2..(spatial + 2)].copy_from_slice(&strides[..spatial]);
|
||||
dilation_full[2..(spatial + 2)].copy_from_slice(&dilations[..spatial]);
|
||||
|
||||
let unfolded = padded.unfold(kernel_full, stride_full, dilation_full);
|
||||
// Shape: [win_N, win_C, win_spatial..., k_N=1, k_C=1, k_spatial...]
|
||||
|
||||
// Permute to [N, win_spatial..., C_in, k_N, k_C, k_spatial...]
|
||||
let mut perm: Vec<usize> = Vec::with_capacity(2 * rank);
|
||||
perm.push(0);
|
||||
perm.extend(2..2 + spatial);
|
||||
perm.push(1);
|
||||
perm.extend(rank..2 * rank);
|
||||
let permuted = unfolded.permute(perm);
|
||||
|
||||
let output_spatial_dims: Vec<Expression> = permuted.dims()[1..1 + spatial].to_vec();
|
||||
|
||||
// Merge all channel+kernel dims into [N, spatial..., ch_in * kernel_product]
|
||||
let mut patches = permuted;
|
||||
let target = 2 + spatial;
|
||||
while patches.dims().len() > target {
|
||||
let last = patches.dims().len();
|
||||
patches = patches.merge_dims(last - 2, last - 1);
|
||||
}
|
||||
|
||||
// Merge spatial dims into one
|
||||
for _ in 1..spatial {
|
||||
patches = patches.merge_dims(1, 2);
|
||||
}
|
||||
// patches: [N, spatial_product, ch_in * kernel_product]
|
||||
|
||||
let mut out = patches.matmul(w_flat.permute((1, 0)));
|
||||
// out: [N, spatial_product, ch_out]
|
||||
|
||||
// Restore spatial dimensions
|
||||
for i in (1..spatial).rev() {
|
||||
out = out.split_dims(1, output_spatial_dims[i]);
|
||||
}
|
||||
|
||||
// Move ch_out from last to position 1: [N, ch_out, spatial...]
|
||||
let mut final_order: Vec<usize> = Vec::with_capacity(2 + spatial);
|
||||
final_order.push(0);
|
||||
final_order.push(1 + spatial);
|
||||
final_order.extend(1..1 + spatial);
|
||||
out.permute(final_order)
|
||||
}
|
||||
|
||||
/// Depthwise convolution: groups == in_channels, ch_per_group == 1.
|
||||
///
|
||||
/// Processes all channels simultaneously using element-wise multiply + reduce,
|
||||
/// avoiding per-channel input slicing which can cause index-expression bugs in luminal.
|
||||
///
|
||||
/// out[n, c, oh, ow] = sum_k patches[n, c, oh, ow, k] * weight[c, k]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn depthwise_conv(
|
||||
x: GraphTensor,
|
||||
w: GraphTensor, // [C, 1, *kernel]
|
||||
kernel_shape: &[usize],
|
||||
strides: &[usize],
|
||||
dilations: &[usize],
|
||||
pads_begin: &[usize],
|
||||
pads_end: &[usize],
|
||||
ch: usize,
|
||||
group_out: usize,
|
||||
kernel_product: usize,
|
||||
spatial: usize,
|
||||
) -> GraphTensor {
|
||||
let rank = 2 + spatial;
|
||||
|
||||
let needs_pad = pads_begin.iter().any(|&p| p > 0) || pads_end.iter().any(|&p| p > 0);
|
||||
let padded = if needs_pad {
|
||||
let mut padding: Vec<(Expression, Expression)> = vec![(0.into(), 0.into()); rank];
|
||||
for i in 0..spatial {
|
||||
padding[2 + i] = (pads_begin[i].into(), pads_end[i].into());
|
||||
}
|
||||
x.pad(padding, 0.0)
|
||||
} else {
|
||||
x
|
||||
};
|
||||
|
||||
// Unfold the full [N, C, H+2p, W+2p] with kernel [1, 1, kH, kW]
|
||||
let mut kernel_full = vec![1usize; rank];
|
||||
let mut stride_full = vec![1usize; rank];
|
||||
let mut dilation_full = vec![1usize; rank];
|
||||
kernel_full[2..(spatial + 2)].copy_from_slice(&kernel_shape[..spatial]);
|
||||
stride_full[2..(spatial + 2)].copy_from_slice(&strides[..spatial]);
|
||||
dilation_full[2..(spatial + 2)].copy_from_slice(&dilations[..spatial]);
|
||||
|
||||
let unfolded = padded.unfold(kernel_full, stride_full, dilation_full);
|
||||
// Shape: [N, C, out_H, out_W, 1, 1, kH, kW]
|
||||
|
||||
// Permute to [N, C, out_spatial..., k_all...]
|
||||
let mut perm: Vec<usize> = Vec::with_capacity(2 * rank);
|
||||
perm.push(0); // N
|
||||
perm.push(1); // C
|
||||
perm.extend(2..2 + spatial); // win_spatial
|
||||
perm.extend(rank..2 * rank); // all kernel dims
|
||||
let permuted = unfolded.permute(perm);
|
||||
|
||||
let out_spatial_dims: Vec<Expression> = permuted.dims()[2..2 + spatial].to_vec();
|
||||
|
||||
// Merge all kernel dims (including 1-size k_N, k_C) into kernel_product
|
||||
let target = 3 + spatial; // [N, C, spatial..., K]
|
||||
let mut patches = permuted;
|
||||
while patches.dims().len() > target {
|
||||
let last = patches.dims().len();
|
||||
patches = patches.merge_dims(last - 2, last - 1);
|
||||
}
|
||||
// patches: [N, C, out_H, ..., out_W, kernel_product]
|
||||
|
||||
// Merge spatial into one: [N, C, out_spatial_product, kernel_product]
|
||||
for _ in 1..spatial {
|
||||
patches = patches.merge_dims(2, 3);
|
||||
}
|
||||
|
||||
// Weight [C * group_out, 1, *kernel] -> [C, group_out, kernel_product]
|
||||
let mut w_flat = w;
|
||||
w_flat.shape =
|
||||
ShapeTracker::new_with_element_bits(vec![ch, group_out, kernel_product], w.dtype.bits());
|
||||
|
||||
// patches: [N, C, out_spatial_product, kernel_product]
|
||||
// Expand to [N, C, group_out, out_spatial_product, kernel_product]
|
||||
let patches = patches.expand_dim(2, group_out);
|
||||
|
||||
// Expand weight for broadcast: [1, C, group_out, out_spatial_product, kernel_product]
|
||||
let w_expanded = w_flat.expand_dim(0, 1).expand_dim(3, patches.dims()[3]);
|
||||
|
||||
// Element-wise multiply and sum over kernel dim
|
||||
let product = patches * w_expanded;
|
||||
let mut out = product.sum(vec![4]).merge_dims(1, 2);
|
||||
// out: [N, C * group_out, out_spatial_product]
|
||||
|
||||
// Restore spatial dimensions
|
||||
for i in (1..spatial).rev() {
|
||||
out = out.split_dims(2, out_spatial_dims[i]);
|
||||
}
|
||||
// out: [N, C, out_spatial_0, ..., out_spatial_{s-1}]
|
||||
|
||||
out
|
||||
}
|
||||
@@ -51,6 +51,7 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.sub.Scalar" => self.translate_binary_scalar_op(node, BinaryOp::Sub)?,
|
||||
"torch.ops.aten.div.Tensor" => self.translate_binary_op(node, BinaryOp::Div)?,
|
||||
"torch.ops.aten.div.Scalar" => self.translate_binary_scalar_op(node, BinaryOp::Div)?,
|
||||
"torch.ops.aten.div.Tensor_mode" => self.translate_div_tensor_mode(node)?,
|
||||
|
||||
// Unary ops
|
||||
"torch.ops.aten.neg.default" => self.translate_unary_op(node, |a| a * (-1.0))?,
|
||||
@@ -66,74 +67,75 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
"torch.ops.aten.sigmoid.default" => self.translate_unary_op(node, |a| a.sigmoid())?,
|
||||
"torch.ops.aten.relu.default" => self.translate_unary_op(node, |a| a.relu())?,
|
||||
"torch.ops.aten.silu.default" => self.translate_unary_op(node, |a| a.swish())?,
|
||||
"torch.ops.aten.tanh.default" => self.translate_unary_op(node, |a| a.tanh())?,
|
||||
"torch.ops.aten.abs.default" => self.translate_unary_op(node, |a| a.abs())?,
|
||||
"torch.ops.aten.log.default" => self.translate_unary_op(node, |a| a.log())?,
|
||||
"torch.ops.aten.log2.default" => self.translate_unary_op(node, |a| a.log2())?,
|
||||
"torch.ops.aten.exp2.default" => self.translate_unary_op(node, |a| a.exp2())?,
|
||||
"torch.ops.aten.sign.default" => self.translate_sign(node)?,
|
||||
"torch.ops.aten.bitwise_not.default" => self.translate_bitwise_not(node)?,
|
||||
|
||||
// Cast
|
||||
"torch.ops.aten._to_copy.default" => self.translate_to_copy(node)?,
|
||||
"torch.ops.aten.to.dtype" => self.translate_to_dtype(node)?,
|
||||
"torch.ops.aten.to.dtype_layout" => self.translate_to_dtype_layout(node)?,
|
||||
|
||||
// No-op pass-throughs
|
||||
"torch.ops.aten.alias.default"
|
||||
| "torch.ops.aten.detach_.default"
|
||||
| "torch.ops.aten.lift_fresh_copy.default" => self.get_input_tensor(node, 0)?,
|
||||
"torch.ops.aten.dropout.default" => self.get_input_tensor(node, 0)?,
|
||||
// No-op
|
||||
"torch.ops.aten.alias.default" => self.get_input_tensor(node, 0)?,
|
||||
|
||||
// Shape ops
|
||||
"torch.ops.aten.view.default"
|
||||
| "torch.ops.aten.reshape.default"
|
||||
| "torch.ops.aten._unsafe_view.default" => self.translate_reshape(node)?,
|
||||
"torch.ops.aten.view.default" => self.translate_reshape(node)?,
|
||||
"torch.ops.aten.permute.default" => self.translate_permute(node)?,
|
||||
"torch.ops.aten.transpose.int" => self.translate_transpose(node)?,
|
||||
"torch.ops.aten.t.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
a.t()
|
||||
}
|
||||
"torch.ops.aten.unsqueeze.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len() + 1);
|
||||
a.unsqueeze(dim)
|
||||
}
|
||||
"torch.ops.aten.squeeze.dim" | "torch.ops.aten.squeeze.default" => {
|
||||
"torch.ops.aten.squeeze.dims" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
if node.inputs.len() > 1 {
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
a.squeeze(dim)
|
||||
} else {
|
||||
let mut result = a;
|
||||
let dims = a.shape.dims;
|
||||
let mut offset = 0;
|
||||
for (i, d) in dims.iter().enumerate() {
|
||||
if d.to_usize() == Some(1) {
|
||||
result = result.squeeze(i - offset);
|
||||
offset += 1;
|
||||
}
|
||||
let dims = self.get_ints_arg(node, 1)?;
|
||||
let ndim = a.shape.len();
|
||||
let mut sorted_dims: Vec<usize> =
|
||||
dims.iter().map(|&d| normalize_dim(d, ndim)).collect();
|
||||
sorted_dims.sort();
|
||||
let mut result = a;
|
||||
let mut offset = 0;
|
||||
for d in sorted_dims {
|
||||
if result.shape.dims[d - offset].to_usize() == Some(1) {
|
||||
result = result.squeeze(d - offset);
|
||||
offset += 1;
|
||||
}
|
||||
result
|
||||
}
|
||||
result
|
||||
}
|
||||
"torch.ops.aten.expand.default" => self.translate_expand(node)?,
|
||||
"torch.ops.aten.contiguous.default" | "torch.ops.aten.clone.default" => {
|
||||
"torch.ops.aten.clone.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
if !a.shape.is_contiguous() { a + 0.0 } else { a }
|
||||
}
|
||||
"torch.ops.aten.argsort.default" => self.translate_argsort(node)?,
|
||||
|
||||
// Matmul
|
||||
"torch.ops.aten.mm.default"
|
||||
| "torch.ops.aten.bmm.default"
|
||||
| "torch.ops.aten.matmul.default" => {
|
||||
"torch.ops.aten.mm.default" | "torch.ops.aten.bmm.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
a.matmul(b)
|
||||
}
|
||||
|
||||
// Linear
|
||||
"torch.ops.aten.linear.default" => self.translate_linear(node)?,
|
||||
// addmm: beta*input + alpha*(mat1 @ mat2)
|
||||
"torch.ops.aten.addmm.default" => {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let mat1 = self.get_input_tensor(node, 1)?;
|
||||
let mat2 = self.get_input_tensor(node, 2)?;
|
||||
let beta = self.get_float_arg(node, 3).unwrap_or(1.0) as f32;
|
||||
let alpha = self.get_float_arg(node, 4).unwrap_or(1.0) as f32;
|
||||
let mm = mat1.matmul(mat2);
|
||||
let (input, mm) = broadcast_binary(input, mm);
|
||||
input * beta + mm * alpha
|
||||
}
|
||||
|
||||
// Convolution
|
||||
"torch.ops.aten.convolution.default" => self.translate_conv(node)?,
|
||||
|
||||
// Reduction ops
|
||||
"torch.ops.aten.sum.dim_IntList" => self.translate_reduction(node, ReductionOp::Sum)?,
|
||||
@@ -142,16 +144,14 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Slice/index ops
|
||||
"torch.ops.aten.slice.Tensor" => self.translate_slice(node)?,
|
||||
"torch.ops.aten.select.int" => self.translate_select(node)?,
|
||||
"torch.ops.aten.cat.default" => self.translate_cat(node)?,
|
||||
"torch.ops.aten.index_select.default" => self.translate_index_select(node)?,
|
||||
"torch.ops.aten.index.Tensor" => self.translate_index_tensor(node)?,
|
||||
|
||||
// Embedding
|
||||
"torch.ops.aten.embedding.default" => self.translate_embedding(node)?,
|
||||
|
||||
// Softmax
|
||||
"torch.ops.aten._softmax.default" | "torch.ops.aten.softmax.int" => {
|
||||
"torch.ops.aten._softmax.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
@@ -159,11 +159,12 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
// LayerNorm
|
||||
"torch.ops.aten.layer_norm.default" => self.translate_layer_norm(node)?,
|
||||
"torch.ops.aten.native_layer_norm.default" => self.translate_layer_norm(node)?,
|
||||
|
||||
// Where
|
||||
"torch.ops.aten.where.self" => self.translate_where(node)?,
|
||||
"torch.ops.aten.where.ScalarOther" => self.translate_where_scalar_other(node)?,
|
||||
"torch.ops.aten.masked_fill.Scalar" => self.translate_masked_fill_scalar(node)?,
|
||||
|
||||
// Pow
|
||||
"torch.ops.aten.pow.Tensor_Scalar" => {
|
||||
@@ -179,18 +180,13 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
// Creation ops
|
||||
"torch.ops.aten.arange.default" | "torch.ops.aten.arange.start" => {
|
||||
self.translate_arange(node)?
|
||||
}
|
||||
"torch.ops.aten.arange.start_step" => self.translate_arange(node)?,
|
||||
"torch.ops.aten.full.default" => self.translate_full(node)?,
|
||||
"torch.ops.aten.zeros.default" | "torch.ops.aten.zeros_like.default" => {
|
||||
self.translate_zeros(node)?
|
||||
"torch.ops.aten.full_like.default" => self.translate_full_like(node)?,
|
||||
"torch.ops.aten.scalar_tensor.default" => {
|
||||
let val = self.get_float_arg(node, 0)? as f32;
|
||||
self.graph.constant_float(val)
|
||||
}
|
||||
"torch.ops.aten.ones.default" | "torch.ops.aten.ones_like.default" => {
|
||||
self.translate_ones(node)?
|
||||
}
|
||||
"torch.ops.aten.new_ones.default" => self.translate_new_ones(node)?,
|
||||
|
||||
// Scalar comparisons
|
||||
"torch.ops.aten.gt.Scalar" => self.translate_scalar_comparison(node, |a, s| a.gt(s))?,
|
||||
"torch.ops.aten.lt.Scalar" => self.translate_scalar_comparison(node, |a, s| a.lt(s))?,
|
||||
@@ -222,7 +218,7 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.le(b)
|
||||
}
|
||||
"torch.ops.aten.__and__.Tensor" | "torch.ops.aten.logical_and.default" => {
|
||||
"torch.ops.aten.bitwise_and.Tensor" | "torch.ops.aten.logical_and.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
@@ -248,9 +244,7 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
// Clamp
|
||||
"torch.ops.aten.clamp.default" | "torch.ops.aten.clamp_min.default" => {
|
||||
self.translate_clamp(node)?
|
||||
}
|
||||
"torch.ops.aten.clamp.default" => self.translate_clamp(node)?,
|
||||
|
||||
// Cumsum
|
||||
"torch.ops.aten.cumsum.default" => {
|
||||
@@ -265,9 +259,6 @@ impl<'a> Translator<'a> {
|
||||
a.cumsum(dim)
|
||||
}
|
||||
|
||||
// Diff
|
||||
"torch.ops.aten.diff.default" => self.translate_diff(node)?,
|
||||
|
||||
// Floor / Ceil / Erf (approximations)
|
||||
"torch.ops.aten.floor.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
@@ -352,45 +343,12 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.gt(b)
|
||||
}
|
||||
"torch.ops.aten.ne.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.ne(b)
|
||||
}
|
||||
|
||||
// Reductions without dim arg (full reduce)
|
||||
// Flatten to [1, N] and reduce axis 1 to avoid multi-step HLIR
|
||||
// that CUDA can't schedule (grid (0,1,1) invalid launch).
|
||||
"torch.ops.aten.sum.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let total = concrete_numel(&a)?;
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
flat.sum(vec![1])
|
||||
}
|
||||
"torch.ops.aten.mean.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let total = concrete_numel(&a)?;
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
flat.sum(vec![1]) / total as f32
|
||||
}
|
||||
"torch.ops.aten.max.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let total = concrete_numel(&a)?;
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
flat.max(vec![1])
|
||||
}
|
||||
"torch.ops.aten.min.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let total = concrete_numel(&a)?;
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
flat.min(vec![1])
|
||||
}
|
||||
// Full-reduce variants (no dim arg) — handled by translate_reduction fallback
|
||||
"torch.ops.aten.sum.default" => self.translate_reduction(node, ReductionOp::Sum)?,
|
||||
"torch.ops.aten.mean.default" => self.translate_reduction(node, ReductionOp::Mean)?,
|
||||
"torch.ops.aten.max.default" => self.translate_reduction(node, ReductionOp::Max)?,
|
||||
"torch.ops.aten.min.default" => self.translate_reduction(node, ReductionOp::Min)?,
|
||||
"torch.ops.aten.amin.default" => self.translate_reduction(node, ReductionOp::Min)?,
|
||||
|
||||
// Gather (axis-aware)
|
||||
@@ -398,7 +356,13 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Scatter ops
|
||||
"torch.ops.aten.scatter.src" => self.translate_scatter_src(node)?,
|
||||
"torch.ops.aten.index_put_.default" => self.translate_index_put(node)?,
|
||||
"torch.ops.aten.scatter.value" => self.translate_scatter_value(node)?,
|
||||
"torch.ops.aten.index_put_.default" | "torch.ops.aten.index_put.default" => {
|
||||
self.translate_index_put(node)?
|
||||
}
|
||||
|
||||
// Integer routing math
|
||||
"torch.ops.aten.floor_divide.default" => self.translate_floor_divide(node)?,
|
||||
|
||||
// Triangular
|
||||
"torch.ops.aten.tril.default" => self.translate_tril(node)?,
|
||||
@@ -410,13 +374,14 @@ impl<'a> Translator<'a> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Split
|
||||
"torch.ops.aten.split.Tensor" | "torch.ops.aten.split_with_sizes.default" => {
|
||||
self.translate_split(node)?
|
||||
// Sort — handles its own output storage, returns early
|
||||
"torch.ops.aten.sort.default" => {
|
||||
self.translate_sort(node)?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// One-hot
|
||||
"torch.ops.aten.one_hot.default" => self.translate_one_hot(node)?,
|
||||
// Split
|
||||
"torch.ops.aten.split_with_sizes.default" => self.translate_split_with_sizes(node)?,
|
||||
|
||||
// Fmod
|
||||
"torch.ops.aten.fmod.Tensor" => {
|
||||
@@ -425,12 +390,8 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a % b
|
||||
}
|
||||
"torch.ops.aten.fmod.Scalar" | "torch.ops.aten.remainder.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
let b = self.graph.constant_float(val).expand_rhs(a.shape);
|
||||
a % b
|
||||
}
|
||||
// Prod reduction
|
||||
"torch.ops.aten.prod.dim_int" => self.translate_reduction(node, ReductionOp::Prod)?,
|
||||
|
||||
other => {
|
||||
bail!("Unsupported ATen op: {other}");
|
||||
@@ -444,15 +405,6 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute total element count, returning an error if any dimension is symbolic.
|
||||
fn concrete_numel(a: &GraphTensor) -> Result<usize> {
|
||||
a.dims().iter().try_fold(1usize, |acc, d| {
|
||||
d.to_usize().map(|v| acc * v).ok_or_else(|| {
|
||||
anyhow::anyhow!("Full reduction requires concrete dimensions, got symbolic dim")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
fn translate_scalar_comparison(
|
||||
&mut self,
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util::broadcast_binary;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_linear(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let weight = self.get_input_tensor(node, 1)?;
|
||||
let result = input.matmul(weight.t());
|
||||
|
||||
if node.inputs.len() > 2
|
||||
&& let Ok(bias) = self.get_input_tensor(node, 2)
|
||||
{
|
||||
let (result, bias) = broadcast_binary(result, bias);
|
||||
return Ok(result + bias);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
@@ -3,8 +3,8 @@
|
||||
//! Walks the parsed PT2 graph and constructs an equivalent Luminal computation graph.
|
||||
|
||||
mod binary;
|
||||
mod conv;
|
||||
mod dispatch;
|
||||
mod matmul;
|
||||
mod movement;
|
||||
mod reduction;
|
||||
mod tensor;
|
||||
@@ -18,6 +18,7 @@ use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_parser::{InputKind, ParsedPT2, SymDimMap};
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util;
|
||||
|
||||
/// Result of translating a PT2 graph to a Luminal graph.
|
||||
pub struct TranslatedGraph {
|
||||
@@ -76,7 +77,13 @@ impl<'a> Translator<'a> {
|
||||
let output_names = self.parsed.output_names();
|
||||
for name in &output_names {
|
||||
let tensor = self.get_tensor(name)?;
|
||||
let tensor = tensor + 0.0;
|
||||
let tensor = if tensor.dtype == DType::Bool {
|
||||
tensor.cast(DType::Int).cast(DType::Bool)
|
||||
} else if tensor.dtype == DType::Int {
|
||||
tensor
|
||||
} else {
|
||||
tensor + 0.0
|
||||
};
|
||||
tensor.output();
|
||||
self.output_ids.push((name.clone(), tensor.id));
|
||||
}
|
||||
@@ -97,7 +104,12 @@ impl<'a> Translator<'a> {
|
||||
.tensor_meta(graph_name)
|
||||
.with_context(|| format!("Missing tensor meta for param {graph_name}"))?;
|
||||
let shape = self.tensor_meta_to_shape(meta)?;
|
||||
let tensor = self.graph.named_tensor(original_name, shape);
|
||||
let dtype = pt2_util::torch_dtype_int_to_luminal(meta.dtype);
|
||||
let tensor = self
|
||||
.graph
|
||||
.named_tensor(original_name, shape)
|
||||
.as_dtype(dtype);
|
||||
tensor.persist();
|
||||
self.tensors.insert(graph_name.clone(), tensor);
|
||||
}
|
||||
InputKind::Buffer {
|
||||
@@ -109,7 +121,12 @@ impl<'a> Translator<'a> {
|
||||
.tensor_meta(graph_name)
|
||||
.with_context(|| format!("Missing tensor meta for buffer {graph_name}"))?;
|
||||
let shape = self.tensor_meta_to_shape(meta)?;
|
||||
let tensor = self.graph.named_tensor(original_name, shape);
|
||||
let dtype = pt2_util::torch_dtype_int_to_luminal(meta.dtype);
|
||||
let tensor = self
|
||||
.graph
|
||||
.named_tensor(original_name, shape)
|
||||
.as_dtype(dtype);
|
||||
tensor.persist();
|
||||
self.tensors.insert(graph_name.clone(), tensor);
|
||||
}
|
||||
InputKind::UserInput { graph_name } => {
|
||||
@@ -118,7 +135,8 @@ impl<'a> Translator<'a> {
|
||||
.tensor_meta(graph_name)
|
||||
.with_context(|| format!("Missing tensor meta for input {graph_name}"))?;
|
||||
let shape = self.tensor_meta_to_shape(meta)?;
|
||||
let tensor = self.graph.named_tensor(graph_name, shape);
|
||||
let dtype = pt2_util::torch_dtype_int_to_luminal(meta.dtype);
|
||||
let tensor = self.graph.named_tensor(graph_name, shape).as_dtype(dtype);
|
||||
self.user_input_ids.push((graph_name.clone(), tensor.id));
|
||||
self.tensors.insert(graph_name.clone(), tensor);
|
||||
}
|
||||
@@ -138,7 +156,6 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// --- Helper methods ---
|
||||
|
||||
/// Look up tensor metadata by name, checking subgraph extras first.
|
||||
pub(crate) fn tensor_meta(&self, name: &str) -> Option<&TensorMeta> {
|
||||
self.extra_tensor_values
|
||||
.get(name)
|
||||
|
||||
@@ -6,6 +6,11 @@ use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
const SCATTER_INPUT_ARG: usize = 0;
|
||||
const SCATTER_DIM_ARG: usize = 1;
|
||||
const SCATTER_INDEX_ARG: usize = 2;
|
||||
const SCATTER_VALUE_ARG: usize = 3;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_reshape(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
@@ -49,15 +54,6 @@ impl<'a> Translator<'a> {
|
||||
Ok(a.permute(axes))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_transpose(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim0 = self.get_int_arg(node, 1)?;
|
||||
let dim1 = self.get_int_arg(node, 2)?;
|
||||
let dim0 = normalize_dim(dim0, a.shape.len());
|
||||
let dim1 = normalize_dim(dim1, a.shape.len());
|
||||
Ok(a.transpose(dim0, dim1))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_expand(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let mut a = self.get_input_tensor(node, 0)?;
|
||||
let neg1_expr = Expression::from(-1i32);
|
||||
@@ -124,20 +120,6 @@ impl<'a> Translator<'a> {
|
||||
Ok(a.slice_along(start..end, dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_select(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let index = self.get_int_arg(node, 2)?;
|
||||
let index = if index < 0 {
|
||||
bail!("Negative select index not yet supported");
|
||||
} else {
|
||||
index as usize
|
||||
};
|
||||
|
||||
Ok(a.slice_along(index..index + 1, dim).squeeze(dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_cat(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let tensors: Vec<GraphTensor> = if let Some(names) = node.inputs[0].arg.as_tensors() {
|
||||
names
|
||||
@@ -184,31 +166,6 @@ impl<'a> Translator<'a> {
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_index_select(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let indices = self.get_input_tensor(node, 2)?.cast(DType::Int);
|
||||
let src_dims = a.shape.dims;
|
||||
let idx_len = indices.shape.dims[0];
|
||||
|
||||
// Reshape 1D indices [K] → [1,..,K,..,1] with K at position `dim`
|
||||
let mut idx = indices;
|
||||
for _ in 0..dim {
|
||||
idx = idx.unsqueeze(0);
|
||||
}
|
||||
for _ in (dim + 1)..src_dims.len() {
|
||||
idx = idx.expand_dim(idx.shape.len(), Expression::from(1usize));
|
||||
}
|
||||
|
||||
// Expand to output shape: src_dims with dim replaced by idx_len
|
||||
let mut target: Vec<Expression> = src_dims.to_vec();
|
||||
target[dim] = idx_len;
|
||||
idx.shape.expand(target);
|
||||
|
||||
Ok(a.gather_elements(idx, dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_embedding(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let weight = self.get_input_tensor(node, 0)?;
|
||||
let indices = self.get_input_tensor(node, 1)?;
|
||||
@@ -407,6 +364,29 @@ impl<'a> Translator<'a> {
|
||||
Ok(a.scatter_elements(indices.cast(DType::Int), src, dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_scatter_value(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, SCATTER_INPUT_ARG)?;
|
||||
let dim = self.get_int_arg(node, SCATTER_DIM_ARG)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let indices = self.get_input_tensor(node, SCATTER_INDEX_ARG)?;
|
||||
let value_arg = &node
|
||||
.inputs
|
||||
.get(SCATTER_VALUE_ARG)
|
||||
.context("scatter.value missing value input")?
|
||||
.arg;
|
||||
let value = if let Some(b) = value_arg.as_bool() {
|
||||
self.graph.constant(if b { 1 } else { 0 }).cast(a.dtype)
|
||||
} else if let Some(i) = value_arg.as_int() {
|
||||
self.graph.constant(i).cast(a.dtype)
|
||||
} else if let Some(f) = value_arg.as_float() {
|
||||
self.graph.constant_float(f as f32).cast(a.dtype)
|
||||
} else {
|
||||
bail!("scatter.value: unsupported scalar argument {:?}", value_arg);
|
||||
}
|
||||
.expand_rhs(indices.shape);
|
||||
Ok(a.scatter_elements(indices.cast(DType::Int), value, dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_index_put(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let index_names = node.inputs[1]
|
||||
@@ -430,9 +410,9 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn translate_split(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
pub(crate) fn translate_split_with_sizes(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let split_size = self.get_int_arg(node, 1)? as usize;
|
||||
let sizes = self.get_ints_arg(node, 1)?;
|
||||
let dim = if node.inputs.len() > 2 {
|
||||
self.get_int_arg(node, 2).unwrap_or(0)
|
||||
} else {
|
||||
@@ -440,35 +420,32 @@ impl<'a> Translator<'a> {
|
||||
};
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
|
||||
let dim_size = a.shape.dims[dim];
|
||||
if let Some(total) = dim_size.to_usize() {
|
||||
// Collect output names from as_tensors (multi-output) or as_tensor (single)
|
||||
let output_names: Vec<String> = node
|
||||
.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensors.as_ref())
|
||||
.map(|ts| ts.iter().map(|t| t.name.clone()).collect())
|
||||
.unwrap_or_else(|| {
|
||||
node.outputs
|
||||
.iter()
|
||||
.filter_map(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
|
||||
.collect()
|
||||
});
|
||||
let output_names: Vec<String> = node
|
||||
.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensors.as_ref())
|
||||
.map(|ts| ts.iter().map(|t| t.name.clone()).collect())
|
||||
.unwrap_or_else(|| {
|
||||
node.outputs
|
||||
.iter()
|
||||
.filter_map(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
|
||||
.collect()
|
||||
});
|
||||
|
||||
// Store each chunk under its output name
|
||||
for (i, out_name) in output_names.iter().enumerate() {
|
||||
let start = i * split_size;
|
||||
let end = ((i + 1) * split_size).min(total);
|
||||
if start < total {
|
||||
let chunk = a.slice_along(start..end, dim);
|
||||
self.tensors.insert(out_name.clone(), chunk);
|
||||
}
|
||||
let mut offset = 0usize;
|
||||
let mut first_chunk = None;
|
||||
for (i, &size) in sizes.iter().enumerate() {
|
||||
let size = size as usize;
|
||||
let chunk = a.slice_along(offset..offset + size, dim);
|
||||
if let Some(name) = output_names.get(i) {
|
||||
self.tensors.insert(name.clone(), chunk);
|
||||
}
|
||||
|
||||
// Return the first chunk
|
||||
Ok(a.slice_along(0..split_size.min(total), dim))
|
||||
} else {
|
||||
Ok(a.slice_along(0..split_size, dim))
|
||||
if i == 0 {
|
||||
first_chunk = Some(chunk);
|
||||
}
|
||||
offset += size;
|
||||
}
|
||||
|
||||
first_chunk.ok_or_else(|| anyhow::anyhow!("split_with_sizes: empty sizes list"))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,15 @@ use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
/// Compute total element count, returning an error if any dimension is symbolic.
|
||||
fn concrete_numel(a: &GraphTensor) -> Result<usize> {
|
||||
a.dims().iter().try_fold(1usize, |acc, d| {
|
||||
d.to_usize().map(|v| acc * v).ok_or_else(|| {
|
||||
anyhow::anyhow!("Full reduction requires concrete dimensions, got symbolic dim")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_reduction(
|
||||
&mut self,
|
||||
@@ -13,21 +22,42 @@ impl<'a> Translator<'a> {
|
||||
op: ReductionOp,
|
||||
) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dims = self.get_ints_arg(node, 1)?;
|
||||
let keepdim = if node.inputs.len() > 2 {
|
||||
self.get_bool_arg(node, 2).unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
let ndim = a.shape.len();
|
||||
let axes: Vec<usize> = dims.iter().map(|&d| normalize_dim(d, ndim)).collect();
|
||||
// Try to get dims arg; if missing or empty, fall back to full reduce
|
||||
let dims_result = self.get_ints_arg(node, 1);
|
||||
let (axes, keepdim) = match dims_result {
|
||||
Ok(ref dims) if !dims.is_empty() => {
|
||||
let ndim = a.shape.len();
|
||||
let axes: Vec<usize> = dims.iter().map(|&d| normalize_dim(d, ndim)).collect();
|
||||
let keepdim = if node.inputs.len() > 2 {
|
||||
self.get_bool_arg(node, 2).unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
(axes, keepdim)
|
||||
}
|
||||
_ => {
|
||||
// Full reduce: flatten to [1, N] and reduce axis 1
|
||||
let total = concrete_numel(&a)?;
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
let result = match op {
|
||||
ReductionOp::Sum => flat.sum(vec![1]),
|
||||
ReductionOp::Mean => flat.sum(vec![1]) / total as f32,
|
||||
ReductionOp::Max => flat.max(vec![1]),
|
||||
ReductionOp::Min => flat.min(vec![1]),
|
||||
ReductionOp::Prod => flat.prod(vec![1]),
|
||||
};
|
||||
return Ok(result);
|
||||
}
|
||||
};
|
||||
|
||||
let mut result = match op {
|
||||
ReductionOp::Sum => a.sum(axes.clone()),
|
||||
ReductionOp::Mean => a.mean(axes.clone()),
|
||||
ReductionOp::Max => a.max(axes.clone()),
|
||||
ReductionOp::Min => a.min(axes.clone()),
|
||||
ReductionOp::Prod => a.prod(axes.clone()),
|
||||
};
|
||||
|
||||
if keepdim {
|
||||
|
||||
@@ -6,6 +6,27 @@ use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
const FULL_SHAPE_ARG: usize = 0;
|
||||
const FULL_VALUE_ARG: usize = 1;
|
||||
|
||||
const FULL_LIKE_INPUT_ARG: usize = 0;
|
||||
const FULL_LIKE_VALUE_ARG: usize = 1;
|
||||
|
||||
const TOPK_INPUT_ARG: usize = 0;
|
||||
const TOPK_K_ARG: usize = 1;
|
||||
const TOPK_DIM_ARG: usize = 2;
|
||||
|
||||
const SORT_INPUT_ARG: usize = 0;
|
||||
const SORT_DIM_ARG: usize = 1;
|
||||
const SORT_DESCENDING_ARG: usize = 2;
|
||||
|
||||
const WHERE_COND_ARG: usize = 0;
|
||||
const WHERE_X_ARG: usize = 1;
|
||||
const WHERE_OTHER_ARG: usize = 2;
|
||||
|
||||
const TRIANGULAR_INPUT_ARG: usize = 0;
|
||||
const TRIANGULAR_DIAGONAL_ARG: usize = 1;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_arange(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let positional_args: Vec<Expression> = node
|
||||
@@ -18,31 +39,57 @@ impl<'a> Translator<'a> {
|
||||
match positional_args.len() {
|
||||
0 => anyhow::bail!("arange: no positional args found"),
|
||||
1 => Ok(self.graph.arange(positional_args[0])),
|
||||
_ => Ok(self
|
||||
2 => Ok(self
|
||||
.graph
|
||||
.arange_options(positional_args[0], positional_args[1], 1)),
|
||||
_ => Ok(self.graph.arange_options(
|
||||
positional_args[0],
|
||||
positional_args[1],
|
||||
positional_args[2],
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn translate_full(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let shape = self.get_exprs_arg(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
Ok(self.graph.constant_float(val).expand_rhs(shape))
|
||||
let shape = self.get_exprs_arg(node, FULL_SHAPE_ARG)?;
|
||||
// fill_value can be float, int, or bool after decomposition
|
||||
let val = if let Ok(f) = self.get_float_arg(node, FULL_VALUE_ARG) {
|
||||
f as f32
|
||||
} else if let Ok(b) = self.get_bool_arg(node, FULL_VALUE_ARG) {
|
||||
if b { 1.0 } else { 0.0 }
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"full: unsupported fill value type: {:?}",
|
||||
node.inputs.get(FULL_VALUE_ARG)
|
||||
);
|
||||
};
|
||||
let dtype = self.output_meta_dtype(node)?;
|
||||
let value = self.graph.constant_float(val).cast(dtype);
|
||||
Ok(if shape.is_empty() {
|
||||
value
|
||||
} else {
|
||||
value.expand_rhs(shape)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_zeros(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
self.translate_constant_fill(node, 0.0)
|
||||
pub(crate) fn translate_full_like(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let reference = self.get_input_tensor(node, FULL_LIKE_INPUT_ARG)?;
|
||||
let val = if let Ok(f) = self.get_float_arg(node, FULL_LIKE_VALUE_ARG) {
|
||||
f as f32
|
||||
} else if let Ok(b) = self.get_bool_arg(node, FULL_LIKE_VALUE_ARG) {
|
||||
if b { 1.0 } else { 0.0 }
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"full_like: unsupported fill value type: {:?}",
|
||||
node.inputs.get(FULL_LIKE_VALUE_ARG)
|
||||
);
|
||||
};
|
||||
let dtype = self.output_meta_dtype(node)?;
|
||||
let value = self.graph.constant_float(val).cast(dtype);
|
||||
Ok(value.expand_rhs(reference.shape))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_ones(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
self.translate_constant_fill(node, 1.0)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_new_ones(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
self.translate_constant_fill(node, 1.0)
|
||||
}
|
||||
|
||||
fn translate_constant_fill(&mut self, node: &Node, val: f32) -> Result<GraphTensor> {
|
||||
fn output_meta_dtype(&self, node: &Node) -> Result<DType> {
|
||||
let output_name = node
|
||||
.outputs
|
||||
.first()
|
||||
@@ -51,32 +98,31 @@ impl<'a> Translator<'a> {
|
||||
.unwrap_or_default();
|
||||
let meta = self
|
||||
.tensor_meta(&output_name)
|
||||
.context("Missing tensor meta for constant fill output")?;
|
||||
let shape = self.tensor_meta_to_shape(meta)?;
|
||||
if shape.is_empty() {
|
||||
Ok(self.graph.constant_float(val))
|
||||
} else {
|
||||
Ok(self.graph.constant_float(val).expand_rhs(shape))
|
||||
}
|
||||
.context("Missing tensor meta for output dtype")?;
|
||||
Ok(torch_dtype_int_to_luminal(meta.dtype))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_where(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let cond = self.get_input_tensor(node, 0)?;
|
||||
let x = self.get_input_tensor(node, 1)?;
|
||||
let y = self.get_input_tensor(node, 2)?;
|
||||
// Ensure x and y have the same dtype
|
||||
let (x, y) = ensure_same_dtype(x, y);
|
||||
// Broadcast all three tensors to a common shape first
|
||||
let (cond_b, x_b) = broadcast_binary(cond, x);
|
||||
let (cond_bc, y_b) = broadcast_binary(cond_b, y);
|
||||
let (x_bc, y_bc) = broadcast_binary(x_b, y_b);
|
||||
let c = cond_bc.cast(DType::F32);
|
||||
let x_f = x_bc.cast(DType::F32);
|
||||
let y_f = y_bc.cast(DType::F32);
|
||||
let one = self.graph.constant_float(1.0).expand_rhs(c.shape);
|
||||
Ok(c * x_bc + (one - c) * y_bc)
|
||||
Ok(c * x_f + (one - c) * y_f)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_where_scalar_other(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let cond = self.get_input_tensor(node, 0)?;
|
||||
let x = self.get_input_tensor(node, 1)?;
|
||||
let other_val = self.get_float_arg(node, 2)? as f32;
|
||||
let cond = self.get_input_tensor(node, WHERE_COND_ARG)?;
|
||||
let x = self.get_input_tensor(node, WHERE_X_ARG)?;
|
||||
let other_val = self.get_float_arg(node, WHERE_OTHER_ARG)? as f32;
|
||||
// Broadcast cond and x to a common shape
|
||||
let (cond_b, x_b) = broadcast_binary(cond, x);
|
||||
let c = cond_b.cast(DType::F32);
|
||||
@@ -85,33 +131,6 @@ impl<'a> Translator<'a> {
|
||||
Ok(c * x_b + (one - c) * other)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_diff(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let dim = if node.inputs.len() > 2 {
|
||||
self.get_int_arg(node, 2).unwrap_or(-1)
|
||||
} else {
|
||||
-1
|
||||
};
|
||||
let dim = normalize_dim(dim, input.shape.len());
|
||||
|
||||
let prepend = if node.inputs.len() > 3 {
|
||||
self.get_input_tensor(node, 3).ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let x = if let Some(prep) = prepend {
|
||||
prep.concat_along(input, dim)
|
||||
} else {
|
||||
input
|
||||
};
|
||||
|
||||
let dim_size = x.shape.dims[dim];
|
||||
let front = x.slice_along(Expression::from(1)..dim_size, dim);
|
||||
let back = x.slice_along(Expression::from(0)..dim_size - 1, dim);
|
||||
Ok(front - back)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_tril(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
self.translate_triangular(node, false)
|
||||
}
|
||||
@@ -121,9 +140,9 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
fn translate_triangular(&mut self, node: &Node, upper: bool) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let diagonal = if node.inputs.len() > 1 {
|
||||
self.get_int_arg(node, 1).unwrap_or(0) as i32
|
||||
let a = self.get_input_tensor(node, TRIANGULAR_INPUT_ARG)?;
|
||||
let diagonal = if node.inputs.len() > TRIANGULAR_DIAGONAL_ARG {
|
||||
self.get_int_arg(node, TRIANGULAR_DIAGONAL_ARG).unwrap_or(0) as i32
|
||||
} else {
|
||||
0
|
||||
};
|
||||
@@ -154,10 +173,10 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
pub(crate) fn translate_topk(&mut self, node: &Node) -> Result<()> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let k = self.get_int_arg(node, 1)? as usize;
|
||||
let dim = if node.inputs.len() > 2 {
|
||||
self.get_int_arg(node, 2).unwrap_or(-1)
|
||||
let a = self.get_input_tensor(node, TOPK_INPUT_ARG)?;
|
||||
let k = self.get_int_arg(node, TOPK_K_ARG)? as usize;
|
||||
let dim = if node.inputs.len() > TOPK_DIM_ARG {
|
||||
self.get_int_arg(node, TOPK_DIM_ARG).unwrap_or(-1)
|
||||
} else {
|
||||
-1
|
||||
};
|
||||
@@ -177,13 +196,10 @@ impl<'a> Translator<'a> {
|
||||
None
|
||||
};
|
||||
|
||||
// Use full argsort then slice, rather than topk_indexes/topk_values directly.
|
||||
// This avoids a CUDA gather kernel bug when data and index shapes differ
|
||||
// along the gather axis (topk_indexes returns a sliced tensor).
|
||||
let full_argsort = a.argsort(dim, true);
|
||||
// Build top-k outputs from a full stable argsort, then slice to k.
|
||||
let full_argsort = a.stable_argsort(dim, true);
|
||||
|
||||
// Only build each branch when its output is consumed.
|
||||
// Dead nodes in the graph can confuse the CUDA optimizer.
|
||||
// Only build the outputs that are consumed.
|
||||
if let Some(val_name) = values_name
|
||||
&& !val_name.is_empty()
|
||||
{
|
||||
@@ -191,8 +207,7 @@ impl<'a> Translator<'a> {
|
||||
self.tensors.insert(val_name, values);
|
||||
}
|
||||
if let Some(idx_name) = indices_name {
|
||||
// Materialize Int indices as F32 with `* 1.0` to force a contiguous copy.
|
||||
// Without this, CUDA can't correctly read the sliced Int view.
|
||||
// Materialize the sliced indices through a copy before storing them.
|
||||
let indices = full_argsort.slice_along(..k, dim) * 1.0;
|
||||
self.tensors.insert(idx_name, indices);
|
||||
}
|
||||
@@ -200,19 +215,49 @@ impl<'a> Translator<'a> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn translate_one_hot(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let num_classes = self.get_int_arg(node, 1)? as usize;
|
||||
// one_hot: output[..., i] = 1 if input[...] == i else 0
|
||||
let a_int = a.cast(DType::Int);
|
||||
let classes = self.graph.arange(num_classes);
|
||||
// Expand a to [..., 1] and classes to [..., num_classes]
|
||||
let a_expanded = a_int.expand_dim(a.shape.len(), num_classes);
|
||||
let mut classes_expanded = classes;
|
||||
for d in a.shape.dims.iter().rev() {
|
||||
classes_expanded = classes_expanded.expand_dim(0, *d);
|
||||
pub(crate) fn translate_sort(&mut self, node: &Node) -> Result<()> {
|
||||
let a = self.get_input_tensor(node, SORT_INPUT_ARG)?;
|
||||
let dim = if node.inputs.len() > SORT_DIM_ARG {
|
||||
self.get_int_arg(node, SORT_DIM_ARG).unwrap_or(-1)
|
||||
} else {
|
||||
-1
|
||||
};
|
||||
let descending = if node.inputs.len() > SORT_DESCENDING_ARG {
|
||||
self.get_bool_arg(node, SORT_DESCENDING_ARG)
|
||||
.unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
|
||||
// Determine output names (sort returns (values, indices))
|
||||
let values_name = node
|
||||
.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()));
|
||||
let indices_name =
|
||||
if let Some(ts) = node.outputs.first().and_then(|o| o.as_tensors.as_ref()) {
|
||||
ts.get(1).map(|t| t.name.clone())
|
||||
} else if node.outputs.len() > 1 {
|
||||
node.outputs[1].as_tensor.as_ref().map(|t| t.name.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let full_argsort = a.stable_argsort(dim, descending);
|
||||
|
||||
if let Some(val_name) = values_name
|
||||
&& !val_name.is_empty()
|
||||
{
|
||||
let values = a.gather_elements(full_argsort, dim);
|
||||
self.tensors.insert(val_name, values);
|
||||
}
|
||||
Ok(a_expanded.eq(classes_expanded).cast(DType::Int))
|
||||
if let Some(idx_name) = indices_name {
|
||||
let indices = full_argsort * 1.0;
|
||||
self.tensors.insert(idx_name, indices);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn translate_wrap_set_grad(&mut self, node: &Node) -> Result<()> {
|
||||
|
||||
@@ -6,7 +6,38 @@ use crate::pt2_util::{broadcast_binary, torch_dtype_int_to_luminal};
|
||||
|
||||
use super::Translator;
|
||||
|
||||
const ARGSORT_INPUT_ARG: usize = 0;
|
||||
const ARGSORT_DIM_ARG: usize = 1;
|
||||
const ARGSORT_DESCENDING_ARG: usize = 2;
|
||||
|
||||
const MASKED_FILL_INPUT_ARG: usize = 0;
|
||||
const MASKED_FILL_MASK_ARG: usize = 1;
|
||||
const MASKED_FILL_VALUE_ARG: usize = 2;
|
||||
|
||||
const FLOOR_DIVIDE_INPUT_ARG: usize = 0;
|
||||
const FLOOR_DIVIDE_OTHER_ARG: usize = 1;
|
||||
|
||||
const DIV_MODE_INPUT_ARG: usize = 0;
|
||||
const DIV_MODE_OTHER_ARG: usize = 1;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_argsort(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, ARGSORT_INPUT_ARG)?;
|
||||
let dim = if node.inputs.len() > ARGSORT_DIM_ARG {
|
||||
self.get_int_arg(node, ARGSORT_DIM_ARG).unwrap_or(-1)
|
||||
} else {
|
||||
-1
|
||||
};
|
||||
let descending = if node.inputs.len() > ARGSORT_DESCENDING_ARG {
|
||||
self.get_bool_arg(node, ARGSORT_DESCENDING_ARG)
|
||||
.unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let dim = crate::pt2_util::normalize_dim(dim, a.shape.len());
|
||||
Ok(a.stable_argsort(dim, descending))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_unary_op(
|
||||
&mut self,
|
||||
node: &Node,
|
||||
@@ -17,43 +48,17 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
pub(crate) fn translate_to_copy(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
for input in &node.inputs {
|
||||
if input.name == "dtype"
|
||||
&& let Some(dtype_int) = input.arg.as_int()
|
||||
{
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
|
||||
return Ok(a.cast(dtype));
|
||||
}
|
||||
}
|
||||
Ok(a)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_to_dtype(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
if let Some(dtype_int) = node.inputs.get(1).and_then(|i| i.arg.as_scalar_type()) {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int);
|
||||
Ok(a.cast(dtype))
|
||||
} else if let Some(dtype_int) = node.inputs.get(1).and_then(|i| i.arg.as_int()) {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
|
||||
Ok(a.cast(dtype))
|
||||
} else {
|
||||
Ok(a)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn translate_to_dtype_layout(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
for input in &node.inputs {
|
||||
if input.name == "dtype" {
|
||||
if let Some(dtype_int) = input.arg.as_scalar_type() {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int);
|
||||
return Ok(a.cast(dtype));
|
||||
}
|
||||
if let Some(dtype_int) = input.arg.as_int() {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
|
||||
return Ok(a.cast(dtype));
|
||||
}
|
||||
if let Some(dtype_int) = input.arg.as_scalar_type() {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int);
|
||||
return Ok(a.cast(dtype));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(a)
|
||||
@@ -90,6 +95,155 @@ impl<'a> Translator<'a> {
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_sign(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let zero = self
|
||||
.graph
|
||||
.constant_float(0.0)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
let pos = a.gt(zero).cast(DType::Int);
|
||||
let neg = a.lt(zero).cast(DType::Int);
|
||||
let signed = pos - neg;
|
||||
Ok(if a.dtype == DType::Int {
|
||||
signed
|
||||
} else {
|
||||
signed.cast(a.dtype)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_bitwise_not(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
Ok(match a.dtype {
|
||||
DType::Bool => {
|
||||
let one = self
|
||||
.graph
|
||||
.constant_float(1.0)
|
||||
.cast(DType::Int)
|
||||
.expand_rhs(a.shape);
|
||||
(one - a.cast(DType::Int)).cast(DType::Bool)
|
||||
}
|
||||
DType::Int => (a + 1) * -1.0,
|
||||
other => {
|
||||
anyhow::bail!("bitwise_not only supports Bool/Int routing tensors, got {other:?}")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_masked_fill_scalar(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, MASKED_FILL_INPUT_ARG)?;
|
||||
let mask = self.get_input_tensor(node, MASKED_FILL_MASK_ARG)?;
|
||||
let fill = self.get_float_arg(node, MASKED_FILL_VALUE_ARG)? as f32;
|
||||
let (input, mask) = broadcast_binary(input, mask);
|
||||
let work_dtype = if input.dtype == DType::Bool {
|
||||
DType::Int
|
||||
} else {
|
||||
input.dtype
|
||||
};
|
||||
let input_work = if input.dtype == DType::Bool {
|
||||
input.cast(DType::Int)
|
||||
} else {
|
||||
input
|
||||
};
|
||||
let mask_work = mask.cast(work_dtype);
|
||||
let fill_work = self
|
||||
.graph
|
||||
.constant_float(fill)
|
||||
.cast(work_dtype)
|
||||
.expand_rhs(input_work.shape);
|
||||
let one = self
|
||||
.graph
|
||||
.constant_float(1.0)
|
||||
.cast(work_dtype)
|
||||
.expand_rhs(input_work.shape);
|
||||
let result = mask_work * fill_work + (one - mask_work) * input_work;
|
||||
Ok(if input.dtype == DType::Bool {
|
||||
result.cast(DType::Bool)
|
||||
} else {
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_floor_divide(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, FLOOR_DIVIDE_INPUT_ARG)?;
|
||||
let b = if let Some(name) = node
|
||||
.inputs
|
||||
.get(FLOOR_DIVIDE_OTHER_ARG)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
{
|
||||
self.get_tensor(name)?
|
||||
} else {
|
||||
let scalar = self.get_float_arg(node, FLOOR_DIVIDE_OTHER_ARG)? as f32;
|
||||
self.graph
|
||||
.constant_float(scalar)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape)
|
||||
};
|
||||
let (a, b) = crate::pt2_util::ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
let quotient = a.cast(DType::F32) / b.cast(DType::F32);
|
||||
let trunc = quotient.cast(DType::Int).cast(DType::F32);
|
||||
let adjust = quotient.lt(trunc).cast(DType::F32);
|
||||
let floored = trunc - adjust;
|
||||
Ok(if a.dtype == DType::Int {
|
||||
floored.cast(DType::Int)
|
||||
} else {
|
||||
floored.cast(a.dtype)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_div_tensor_mode(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, DIV_MODE_INPUT_ARG)?;
|
||||
let b = if let Some(name) = node
|
||||
.inputs
|
||||
.get(DIV_MODE_OTHER_ARG)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
{
|
||||
self.get_tensor(name)?
|
||||
} else {
|
||||
let scalar = self.get_float_arg(node, DIV_MODE_OTHER_ARG)? as f32;
|
||||
self.graph
|
||||
.constant_float(scalar)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape)
|
||||
};
|
||||
let (a, b) = crate::pt2_util::ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
|
||||
// Check rounding_mode kwarg
|
||||
let rounding_mode = node.inputs.iter().find_map(|input| {
|
||||
if input.name == "rounding_mode"
|
||||
&& let Argument::Other(val) = &input.arg
|
||||
{
|
||||
return val.as_str().map(|s| s.to_string());
|
||||
}
|
||||
None
|
||||
});
|
||||
|
||||
let quotient = a.cast(DType::F32) / b.cast(DType::F32);
|
||||
match rounding_mode.as_deref() {
|
||||
Some("floor") => {
|
||||
let trunc = quotient.cast(DType::Int).cast(DType::F32);
|
||||
let adjust = quotient.lt(trunc).cast(DType::F32);
|
||||
let floored = trunc - adjust;
|
||||
Ok(if a.dtype == DType::Int {
|
||||
floored.cast(DType::Int)
|
||||
} else {
|
||||
floored.cast(a.dtype)
|
||||
})
|
||||
}
|
||||
Some("trunc") => Ok(if a.dtype == DType::Int {
|
||||
quotient.cast(DType::Int)
|
||||
} else {
|
||||
quotient.cast(DType::Int).cast(a.dtype)
|
||||
}),
|
||||
_ => {
|
||||
// No rounding mode — regular division
|
||||
Ok(quotient.cast(a.dtype))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn translate_clamp(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let min_val = if node.inputs.len() > 1 {
|
||||
|
||||
352
crates/luminal_python/rust/src/typed_data.rs
Normal file
352
crates/luminal_python/rust/src/typed_data.rs
Normal file
@@ -0,0 +1,352 @@
|
||||
//! Dtype-aware buffer type for the luminal_python bridge.
|
||||
//!
|
||||
//! `TypedData` wraps raw bytes with a `DType` tag, enabling multi-dtype data flow
|
||||
//! through the PT2 path without forcing everything to f32.
|
||||
|
||||
use luminal::hlir::NativeData;
|
||||
use luminal::prelude::tracing::warn;
|
||||
use luminal::prelude::*;
|
||||
|
||||
/// A dtype-tagged byte buffer. All weight, constant, and input data flows through this type.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct TypedData {
|
||||
pub bytes: Vec<u8>,
|
||||
pub dtype: DType,
|
||||
}
|
||||
|
||||
impl TypedData {
|
||||
/// Wrap raw bytes with a dtype tag. Caller must ensure bytes are correctly formatted.
|
||||
pub fn from_raw(bytes: Vec<u8>, dtype: DType) -> Self {
|
||||
Self { bytes, dtype }
|
||||
}
|
||||
|
||||
/// Number of bytes in the buffer
|
||||
pub fn n_bytes(&self) -> usize {
|
||||
self.bytes.len()
|
||||
}
|
||||
|
||||
/// Number of logical elements (for byte-aligned dtypes)
|
||||
pub fn n_elements(&self) -> usize {
|
||||
let bits = self.dtype.bits();
|
||||
if bits >= 8 {
|
||||
self.bytes.len() / (bits / 8)
|
||||
} else {
|
||||
// sub-byte types: multiple elements per byte
|
||||
self.bytes.len() * (8 / bits)
|
||||
}
|
||||
}
|
||||
|
||||
/// Read element at `idx` as f64 (used by From<TypedData> for NativeData fallback).
|
||||
fn as_f64(&self, idx: usize) -> f64 {
|
||||
match self.dtype {
|
||||
DType::F32 => {
|
||||
let start = idx * 4;
|
||||
f32::from_le_bytes([
|
||||
self.bytes[start],
|
||||
self.bytes[start + 1],
|
||||
self.bytes[start + 2],
|
||||
self.bytes[start + 3],
|
||||
]) as f64
|
||||
}
|
||||
DType::F64 => {
|
||||
let start = idx * 8;
|
||||
f64::from_le_bytes([
|
||||
self.bytes[start],
|
||||
self.bytes[start + 1],
|
||||
self.bytes[start + 2],
|
||||
self.bytes[start + 3],
|
||||
self.bytes[start + 4],
|
||||
self.bytes[start + 5],
|
||||
self.bytes[start + 6],
|
||||
self.bytes[start + 7],
|
||||
])
|
||||
}
|
||||
DType::F16 => {
|
||||
let start = idx * 2;
|
||||
half::f16::from_le_bytes([self.bytes[start], self.bytes[start + 1]]).to_f64()
|
||||
}
|
||||
DType::Bf16 => {
|
||||
let start = idx * 2;
|
||||
half::bf16::from_le_bytes([self.bytes[start], self.bytes[start + 1]]).to_f64()
|
||||
}
|
||||
DType::Int => {
|
||||
let start = idx * 4;
|
||||
i32::from_le_bytes([
|
||||
self.bytes[start],
|
||||
self.bytes[start + 1],
|
||||
self.bytes[start + 2],
|
||||
self.bytes[start + 3],
|
||||
]) as f64
|
||||
}
|
||||
DType::I8 => self.bytes[idx] as i8 as f64,
|
||||
DType::U8 => self.bytes[idx] as f64,
|
||||
DType::I16 | DType::U16 => {
|
||||
let start = idx * 2;
|
||||
let val = i16::from_le_bytes([self.bytes[start], self.bytes[start + 1]]);
|
||||
if self.dtype == DType::U16 {
|
||||
val as u16 as f64
|
||||
} else {
|
||||
val as f64
|
||||
}
|
||||
}
|
||||
DType::Bool => {
|
||||
if self.bytes[idx] != 0 {
|
||||
1.0
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
_ => panic!("as_f64 not supported for {:?}", self.dtype),
|
||||
}
|
||||
}
|
||||
// -- Constructors from typed Vecs --
|
||||
|
||||
pub fn from_f32_vec(data: Vec<f32>) -> Self {
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 4).to_vec()
|
||||
};
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::F32,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_f16_vec(data: Vec<half::f16>) -> Self {
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 2).to_vec()
|
||||
};
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::F16,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_bf16_vec(data: Vec<half::bf16>) -> Self {
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 2).to_vec()
|
||||
};
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::Bf16,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_i32_vec(data: Vec<i32>) -> Self {
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 4).to_vec()
|
||||
};
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::Int,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_bool_vec(data: Vec<bool>) -> Self {
|
||||
let bytes: Vec<u8> = data.iter().map(|&b| b as u8).collect();
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::Bool,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert raw bytes from a PyTorch tensor (identified by PT2 dtype code) to TypedData
|
||||
/// in luminal's native format. Handles widening/narrowing conversions for types where
|
||||
/// PyTorch's byte layout differs from luminal's:
|
||||
/// - i64 → i32, f64 → f32 (luminal has no 64-bit types)
|
||||
/// - i16 → i32, u8 → i32, i8 → i32 (luminal maps all integer types to i32 for PT2)
|
||||
pub fn from_pytorch_bytes(bytes: Vec<u8>, dtype_code: u32) -> Self {
|
||||
match dtype_code {
|
||||
// Types that map directly — preserve raw bytes
|
||||
7 => Self::from_raw(bytes, DType::F32),
|
||||
6 => Self::from_raw(bytes, DType::F16),
|
||||
13 => Self::from_raw(bytes, DType::Bf16),
|
||||
4 => Self::from_raw(bytes, DType::Int), // i32
|
||||
12 => Self::from_raw(bytes, DType::Bool),
|
||||
// i64 → i32 (truncate)
|
||||
5 => {
|
||||
let i32s: Vec<i32> = bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as i32
|
||||
})
|
||||
.collect();
|
||||
Self::from_i32_vec(i32s)
|
||||
}
|
||||
// f64 → f32 (downcast)
|
||||
8 => {
|
||||
let f32s: Vec<f32> = bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
|
||||
})
|
||||
.collect();
|
||||
Self::from_f32_vec(f32s)
|
||||
}
|
||||
// i16 → i32 (widen)
|
||||
3 => {
|
||||
let i32s: Vec<i32> = bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| i16::from_le_bytes([b[0], b[1]]) as i32)
|
||||
.collect();
|
||||
Self::from_i32_vec(i32s)
|
||||
}
|
||||
// u8 → i32 (widen)
|
||||
1 => {
|
||||
let i32s: Vec<i32> = bytes.iter().map(|&b| b as i32).collect();
|
||||
Self::from_i32_vec(i32s)
|
||||
}
|
||||
// i8 → i32 (widen, signed)
|
||||
2 => {
|
||||
let i32s: Vec<i32> = bytes.iter().map(|&b| (b as i8) as i32).collect();
|
||||
Self::from_i32_vec(i32s)
|
||||
}
|
||||
// Unknown: best-effort pass-through as f32
|
||||
_ => {
|
||||
warn!("Unrecognized pytorch dtype code {dtype_code}, interpreting as f32");
|
||||
Self::from_raw(bytes, DType::F32)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an n-element buffer of "safe" dummy values (1.0 for floats, 1 for ints, true for bool).
|
||||
/// IMPORTANT: Must use 1, NOT 0. Zero inputs cause NaN in many ops (fmod, recip, log, etc.).
|
||||
pub fn ones(n_elements: usize, dtype: DType) -> Self {
|
||||
match dtype {
|
||||
DType::F32 | DType::TF32 => Self::from_f32_vec(vec![1.0f32; n_elements]),
|
||||
DType::F64 => {
|
||||
let data = vec![1.0f64; n_elements];
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 8).to_vec()
|
||||
};
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::F64,
|
||||
}
|
||||
}
|
||||
DType::F16 => Self::from_f16_vec(vec![half::f16::from_f32(1.0); n_elements]),
|
||||
DType::Bf16 => Self::from_bf16_vec(vec![half::bf16::from_f32(1.0); n_elements]),
|
||||
DType::Int => Self::from_i32_vec(vec![1i32; n_elements]),
|
||||
DType::I8 => Self::from_raw(vec![1u8; n_elements], DType::I8),
|
||||
DType::U8 => Self::from_raw(vec![1u8; n_elements], DType::U8),
|
||||
DType::I16 => {
|
||||
let data = vec![1i16; n_elements];
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 2).to_vec()
|
||||
};
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::I16,
|
||||
}
|
||||
}
|
||||
DType::U16 => {
|
||||
let data = vec![1u16; n_elements];
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 2).to_vec()
|
||||
};
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::U16,
|
||||
}
|
||||
}
|
||||
DType::Bool => Self::from_bool_vec(vec![true; n_elements]),
|
||||
_ => panic!("TypedData::ones not supported for {:?}", dtype),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert TypedData to NativeData for the native runtime.
|
||||
impl From<TypedData> for NativeData {
|
||||
fn from(td: TypedData) -> Self {
|
||||
match td.dtype {
|
||||
DType::F32 | DType::TF32 => {
|
||||
let data: Vec<f32> = td
|
||||
.bytes
|
||||
.chunks_exact(4)
|
||||
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
|
||||
.collect();
|
||||
NativeData::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
// Downcast f64 -> f32 for native runtime (which only has F32 variant for floats > 32-bit)
|
||||
let data: Vec<f32> = td
|
||||
.bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
|
||||
})
|
||||
.collect();
|
||||
NativeData::F32(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
let data: Vec<half::f16> = td
|
||||
.bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| half::f16::from_le_bytes([b[0], b[1]]))
|
||||
.collect();
|
||||
NativeData::F16(data)
|
||||
}
|
||||
DType::Bf16 => {
|
||||
let data: Vec<half::bf16> = td
|
||||
.bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| half::bf16::from_le_bytes([b[0], b[1]]))
|
||||
.collect();
|
||||
NativeData::Bf16(data)
|
||||
}
|
||||
DType::Int => {
|
||||
let data: Vec<i32> = td
|
||||
.bytes
|
||||
.chunks_exact(4)
|
||||
.map(|b| i32::from_le_bytes([b[0], b[1], b[2], b[3]]))
|
||||
.collect();
|
||||
NativeData::Int(data)
|
||||
}
|
||||
DType::Bool => {
|
||||
let data: Vec<bool> = td.bytes.iter().map(|&b| b != 0).collect();
|
||||
NativeData::Bool(data)
|
||||
}
|
||||
// Integer types that map to NativeData::Int
|
||||
DType::I8 => {
|
||||
let data: Vec<i32> = td.bytes.iter().map(|&b| b as i8 as i32).collect();
|
||||
NativeData::Int(data)
|
||||
}
|
||||
DType::U8 => {
|
||||
let data: Vec<i32> = td.bytes.iter().map(|&b| b as i32).collect();
|
||||
NativeData::Int(data)
|
||||
}
|
||||
DType::I16 => {
|
||||
let data: Vec<i32> = td
|
||||
.bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| i16::from_le_bytes([b[0], b[1]]) as i32)
|
||||
.collect();
|
||||
NativeData::Int(data)
|
||||
}
|
||||
DType::U16 => {
|
||||
let data: Vec<i32> = td
|
||||
.bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| u16::from_le_bytes([b[0], b[1]]) as i32)
|
||||
.collect();
|
||||
NativeData::Int(data)
|
||||
}
|
||||
// Sub-byte and F8 types: store as raw f32 for native runtime (best effort)
|
||||
_ => {
|
||||
// For exotic types, the native runtime can't handle them natively.
|
||||
// Store as f32 with element-wise conversion.
|
||||
let data: Vec<f32> = (0..td.n_elements()).map(|i| td.as_f64(i) as f32).collect();
|
||||
NativeData::F32(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert &TypedData to NativeData (clone the bytes).
|
||||
impl From<&TypedData> for NativeData {
|
||||
fn from(td: &TypedData) -> Self {
|
||||
td.clone().into()
|
||||
}
|
||||
}
|
||||
|
||||
// CUDA runtime conversion is implemented via ToCudaInput in runtime.rs
|
||||
// (behind the `cuda` feature gate) since it depends on cudarc types.
|
||||
@@ -1,465 +0,0 @@
|
||||
use std::{collections::HashMap, fs, path::Path};
|
||||
|
||||
use luminal::{prelude::GraphTensor, shape::Expression};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
/// Maps ONNX dim_param names (e.g. "seq_len") to luminal Expression variable chars ('a'..'w').
|
||||
pub type DimParamMap = HashMap<String, char>;
|
||||
|
||||
// Given a Value from the Onnx proto return its tensor Shape, if it exists
|
||||
// Note: some times pytorch will create tensors with a 0 shape
|
||||
// we might want to handle, 0 shape and No shape as seperate ideas
|
||||
pub fn get_shape_for_onnx_value(value: &onnx_protobuf::ValueInfoProto) -> Vec<usize> {
|
||||
if let Some(type_proto) = value.type_.as_ref()
|
||||
&& let Some(onnx_protobuf::type_proto::Value::TensorType(tensor)) = &type_proto.value
|
||||
&& let Some(shape) = tensor.shape.as_ref()
|
||||
{
|
||||
// Scalar (0-dim) tensors have an empty dim list; represent as [1] in luminal
|
||||
if shape.dim.is_empty() {
|
||||
return vec![1];
|
||||
}
|
||||
return shape
|
||||
.dim
|
||||
.iter()
|
||||
.map(|dimension| {
|
||||
if let Some(onnx_protobuf::tensor_shape_proto::dimension::Value::DimValue(v)) =
|
||||
&dimension.value
|
||||
{
|
||||
*v as usize
|
||||
} else {
|
||||
1
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// Like `get_shape_for_onnx_value`, but returns `Vec<Expression>` with symbolic vars for DimParam dims.
|
||||
/// Allocates new variable chars in `dim_param_map` for unseen dim_param names.
|
||||
/// `next_char` is updated to the next available char after allocation.
|
||||
pub fn get_shape_for_onnx_value_expr(
|
||||
value: &onnx_protobuf::ValueInfoProto,
|
||||
dim_param_map: &mut DimParamMap,
|
||||
next_char: &mut char,
|
||||
) -> Vec<Expression> {
|
||||
if let Some(type_proto) = value.type_.as_ref()
|
||||
&& let Some(onnx_protobuf::type_proto::Value::TensorType(tensor)) = &type_proto.value
|
||||
&& let Some(shape) = tensor.shape.as_ref()
|
||||
{
|
||||
if shape.dim.is_empty() {
|
||||
return vec![Expression::from(1usize)];
|
||||
}
|
||||
return shape
|
||||
.dim
|
||||
.iter()
|
||||
.map(|dimension| match &dimension.value {
|
||||
Some(onnx_protobuf::tensor_shape_proto::dimension::Value::DimValue(v)) => {
|
||||
Expression::from(*v as usize)
|
||||
}
|
||||
Some(onnx_protobuf::tensor_shape_proto::dimension::Value::DimParam(name)) => {
|
||||
let ch = *dim_param_map.entry(name.clone()).or_insert_with(|| {
|
||||
let c = *next_char;
|
||||
*next_char = (c as u8 + 1) as char;
|
||||
c
|
||||
});
|
||||
Expression::from(ch)
|
||||
}
|
||||
_ => Expression::from(1usize),
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// Compute the broadcast output shape for two tensors using Expressions (numpy rules).
|
||||
pub fn compute_broadcast_shape_expr(a: &[Expression], b: &[Expression]) -> Vec<Expression> {
|
||||
let max_rank = a.len().max(b.len());
|
||||
let mut result = Vec::with_capacity(max_rank);
|
||||
|
||||
for i in 0..max_rank {
|
||||
let a_dim = if i < max_rank - a.len() {
|
||||
Expression::from(1usize)
|
||||
} else {
|
||||
a[i - (max_rank - a.len())]
|
||||
};
|
||||
let b_dim = if i < max_rank - b.len() {
|
||||
Expression::from(1usize)
|
||||
} else {
|
||||
b[i - (max_rank - b.len())]
|
||||
};
|
||||
|
||||
// If both are concrete, use max. If one is 1, use the other.
|
||||
// Otherwise, assume they match (same symbolic dim).
|
||||
let dim = match (a_dim.to_usize(), b_dim.to_usize()) {
|
||||
(Some(a_val), Some(b_val)) => Expression::from(a_val.max(b_val)),
|
||||
(Some(1), _) => b_dim,
|
||||
(_, Some(1)) => a_dim,
|
||||
_ => a_dim, // Both symbolic — assume compatible
|
||||
};
|
||||
result.push(dim);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Broadcast a tensor's shape to match a target Expression shape (numpy-style broadcasting).
|
||||
/// Left-pads with size-1 dims, then expands dims that are 1 to match target.
|
||||
pub fn broadcast_to_expr(mut tensor: GraphTensor, target_shape: &[Expression]) -> GraphTensor {
|
||||
let src_dims = tensor.dims();
|
||||
let src_len = src_dims.len();
|
||||
let tgt_len = target_shape.len();
|
||||
|
||||
if src_len == tgt_len {
|
||||
tensor.shape.expand(target_shape.to_vec());
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// Left-pad with size-1 dims
|
||||
for _ in 0..(tgt_len - src_len) {
|
||||
tensor = tensor.expand_dim(0, 1);
|
||||
}
|
||||
|
||||
tensor.shape.expand(target_shape.to_vec());
|
||||
tensor
|
||||
}
|
||||
|
||||
/// Convert inline data from a TensorProto to f32, based on data_type.
|
||||
/// Returns None if the tensor has no inline data (e.g. external storage).
|
||||
fn convert_inline_data(init: &onnx_protobuf::TensorProto) -> Option<Vec<f32>> {
|
||||
match init.data_type {
|
||||
1 => {
|
||||
// FLOAT
|
||||
if !init.float_data.is_empty() {
|
||||
return Some(init.float_data.clone());
|
||||
}
|
||||
if !init.raw_data.is_empty() {
|
||||
return Some(parse_raw_bytes_as_f32(&init.raw_data, 1));
|
||||
}
|
||||
}
|
||||
7 => {
|
||||
// INT64
|
||||
if !init.int64_data.is_empty() {
|
||||
return Some(init.int64_data.iter().map(|&v| v as f32).collect());
|
||||
}
|
||||
if !init.raw_data.is_empty() {
|
||||
return Some(parse_raw_bytes_as_f32(&init.raw_data, 7));
|
||||
}
|
||||
}
|
||||
6 => {
|
||||
// INT32
|
||||
if !init.int32_data.is_empty() {
|
||||
return Some(init.int32_data.iter().map(|&v| v as f32).collect());
|
||||
}
|
||||
if !init.raw_data.is_empty() {
|
||||
return Some(parse_raw_bytes_as_f32(&init.raw_data, 6));
|
||||
}
|
||||
}
|
||||
9 => {
|
||||
// BOOL
|
||||
if !init.raw_data.is_empty() {
|
||||
return Some(parse_raw_bytes_as_f32(&init.raw_data, 9));
|
||||
}
|
||||
if !init.int32_data.is_empty() {
|
||||
return Some(
|
||||
init.int32_data
|
||||
.iter()
|
||||
.map(|&v| if v != 0 { 1.0 } else { 0.0 })
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Fallback: try float_data or interpret raw_data as F32
|
||||
if !init.float_data.is_empty() {
|
||||
return Some(init.float_data.clone());
|
||||
}
|
||||
if !init.raw_data.is_empty() {
|
||||
return Some(parse_raw_bytes_as_f32(&init.raw_data, 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Parse a raw byte slice as f32 values, respecting the ONNX data_type.
|
||||
fn parse_raw_bytes_as_f32(bytes: &[u8], data_type: i32) -> Vec<f32> {
|
||||
match data_type {
|
||||
1 => bytes
|
||||
.chunks_exact(4)
|
||||
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
|
||||
.collect(),
|
||||
7 => bytes
|
||||
.chunks_exact(8)
|
||||
.map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32)
|
||||
.collect(),
|
||||
6 => bytes
|
||||
.chunks_exact(4)
|
||||
.map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
|
||||
.collect(),
|
||||
9 => bytes
|
||||
.iter()
|
||||
.map(|&b| if b != 0 { 1.0 } else { 0.0 })
|
||||
.collect(),
|
||||
_ => bytes
|
||||
.chunks_exact(4)
|
||||
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Load float data from a TensorProto, handling inline (float_data/raw_data) and external storage.
|
||||
/// Prefer `load_all_tensor_floats` for batch loading (avoids redundant file reads).
|
||||
#[allow(dead_code)]
|
||||
pub fn load_tensor_floats(init: &onnx_protobuf::TensorProto, model_dir: &Path) -> Option<Vec<f32>> {
|
||||
// Try inline data first
|
||||
if let Some(floats) = convert_inline_data(init) {
|
||||
return Some(floats);
|
||||
}
|
||||
// Try external data (data_location == EXTERNAL = 1)
|
||||
if !init.external_data.is_empty() {
|
||||
let mut location: Option<&str> = None;
|
||||
let mut offset: u64 = 0;
|
||||
let mut length: Option<u64> = None;
|
||||
for entry in &init.external_data {
|
||||
match entry.key.as_str() {
|
||||
"location" => location = Some(&entry.value),
|
||||
"offset" => offset = entry.value.parse().unwrap_or(0),
|
||||
"length" => length = entry.value.parse().ok(),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
if let Some(loc) = location {
|
||||
let ext_path = model_dir.join(loc);
|
||||
match fs::read(&ext_path) {
|
||||
Ok(file_data) => {
|
||||
let start = offset as usize;
|
||||
let end = match length {
|
||||
Some(len) => start + len as usize,
|
||||
None => file_data.len(),
|
||||
};
|
||||
if end > file_data.len() {
|
||||
return None;
|
||||
}
|
||||
return Some(parse_raw_bytes_as_f32(
|
||||
&file_data[start..end],
|
||||
init.data_type,
|
||||
));
|
||||
}
|
||||
Err(_) => {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Batch-load float data from multiple TensorProtos, reading each external file only once.
|
||||
/// Returns results in the same order as `inits`, with `None` for tensors that couldn't be loaded.
|
||||
pub fn load_all_tensor_floats(
|
||||
inits: &[onnx_protobuf::TensorProto],
|
||||
model_dir: &Path,
|
||||
) -> Vec<(String, Option<Vec<f32>>)> {
|
||||
let mut results: Vec<(String, Option<Vec<f32>>)> = Vec::with_capacity(inits.len());
|
||||
|
||||
// Pending external data entries: (result_index, offset, length, data_type)
|
||||
// grouped by file location
|
||||
type ExternalEntry = (usize, u64, Option<u64>, i32);
|
||||
let mut external_pending: HashMap<String, Vec<ExternalEntry>> = HashMap::new();
|
||||
|
||||
for (i, init) in inits.iter().enumerate() {
|
||||
// Try inline data first
|
||||
if let Some(floats) = convert_inline_data(init) {
|
||||
results.push((init.name.clone(), Some(floats)));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check for external data
|
||||
if !init.external_data.is_empty() {
|
||||
let mut location: Option<String> = None;
|
||||
let mut offset: u64 = 0;
|
||||
let mut length: Option<u64> = None;
|
||||
for entry in &init.external_data {
|
||||
match entry.key.as_str() {
|
||||
"location" => location = Some(entry.value.clone()),
|
||||
"offset" => offset = entry.value.parse().unwrap_or(0),
|
||||
"length" => length = entry.value.parse().ok(),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
if let Some(loc) = location {
|
||||
// Push placeholder, will fill in later
|
||||
results.push((init.name.clone(), None));
|
||||
external_pending
|
||||
.entry(loc)
|
||||
.or_default()
|
||||
.push((i, offset, length, init.data_type));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
results.push((init.name.clone(), None));
|
||||
}
|
||||
|
||||
// Read each external file once and extract all tensor slices
|
||||
for (loc, entries) in &external_pending {
|
||||
let ext_path = model_dir.join(loc);
|
||||
let file_data = match fs::read(&ext_path) {
|
||||
Ok(data) => data,
|
||||
Err(_) => continue, // results already have None
|
||||
};
|
||||
for &(idx, offset, length, data_type) in entries {
|
||||
let start = offset as usize;
|
||||
let end = match length {
|
||||
Some(len) => start + len as usize,
|
||||
None => file_data.len(),
|
||||
};
|
||||
if end > file_data.len() {
|
||||
continue;
|
||||
}
|
||||
results[idx].1 = Some(parse_raw_bytes_as_f32(&file_data[start..end], data_type));
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
/// Load initializer data as f32 values, handling multiple ONNX data types.
|
||||
/// Used to seed known_values with small constant initializers for constant folding.
|
||||
pub fn load_initializer_as_f32(init: &onnx_protobuf::TensorProto) -> Option<Vec<f32>> {
|
||||
match init.data_type {
|
||||
1 => {
|
||||
// FLOAT
|
||||
if !init.float_data.is_empty() {
|
||||
Some(init.float_data.clone())
|
||||
} else if !init.raw_data.is_empty() {
|
||||
Some(
|
||||
init.raw_data
|
||||
.chunks_exact(4)
|
||||
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
7 => {
|
||||
// INT64
|
||||
if !init.int64_data.is_empty() {
|
||||
Some(init.int64_data.iter().map(|&v| v as f32).collect())
|
||||
} else if !init.raw_data.is_empty() {
|
||||
Some(
|
||||
init.raw_data
|
||||
.chunks_exact(8)
|
||||
.map(|c| {
|
||||
i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
|
||||
as f32
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
6 => {
|
||||
// INT32
|
||||
if !init.int32_data.is_empty() {
|
||||
Some(init.int32_data.iter().map(|&v| v as f32).collect())
|
||||
} else if !init.raw_data.is_empty() {
|
||||
Some(
|
||||
init.raw_data
|
||||
.chunks_exact(4)
|
||||
.map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
16 => {
|
||||
// BFLOAT16 — 2 bytes per element, upper 16 bits of f32
|
||||
if !init.raw_data.is_empty() {
|
||||
Some(
|
||||
init.raw_data
|
||||
.chunks_exact(2)
|
||||
.map(|c| {
|
||||
let bits = u16::from_le_bytes([c[0], c[1]]);
|
||||
f32::from_bits((bits as u32) << 16)
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
9 => {
|
||||
// BOOL — 1 byte per element, 0 → 0.0, non-zero → 1.0
|
||||
if !init.raw_data.is_empty() {
|
||||
Some(
|
||||
init.raw_data
|
||||
.iter()
|
||||
.map(|&b| if b != 0 { 1.0 } else { 0.0 })
|
||||
.collect(),
|
||||
)
|
||||
} else if !init.int32_data.is_empty() {
|
||||
Some(
|
||||
init.int32_data
|
||||
.iter()
|
||||
.map(|&v| if v != 0 { 1.0 } else { 0.0 })
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
11 => {
|
||||
// FLOAT64
|
||||
if !init.raw_data.is_empty() {
|
||||
Some(
|
||||
init.raw_data
|
||||
.chunks_exact(8)
|
||||
.map(|c| {
|
||||
f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
|
||||
as f32
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get an integer attribute from a node, with a default value
|
||||
pub fn get_int_attr(node: &NodeProto, name: &str, default: i64) -> i64 {
|
||||
for attr in &node.attribute {
|
||||
if attr.name == name {
|
||||
return attr.i;
|
||||
}
|
||||
}
|
||||
default
|
||||
}
|
||||
|
||||
/// Get a string attribute from a node, with a default value
|
||||
pub fn get_str_attr(node: &NodeProto, name: &str, default: &str) -> String {
|
||||
for attr in &node.attribute {
|
||||
if attr.name == name {
|
||||
return String::from_utf8_lossy(&attr.s).into_owned();
|
||||
}
|
||||
}
|
||||
default.to_string()
|
||||
}
|
||||
|
||||
/// Get a float attribute from a node, with a default value
|
||||
pub fn get_float_attr(node: &NodeProto, name: &str, default: f32) -> f32 {
|
||||
for attr in &node.attribute {
|
||||
if attr.name == name {
|
||||
return attr.f;
|
||||
}
|
||||
}
|
||||
default
|
||||
}
|
||||
@@ -1,15 +1,13 @@
|
||||
"""Luminal Python bindings - PyTorch backend using Luminal."""
|
||||
|
||||
# Import Python components
|
||||
# Register DynamicCache pytree serialization once at import time
|
||||
from .cache_utils import _register_cache_serialization
|
||||
from .compiled_model import CompiledModel
|
||||
|
||||
# Import Rust extension components (built by maturin)
|
||||
# These are available directly in the package namespace
|
||||
from .luminal import CompiledGraph, process_onnx, process_pt2
|
||||
from .main import luminal_backend
|
||||
|
||||
# Register DynamicCache pytree serialization once at import time
|
||||
from .cache_utils import _register_cache_serialization
|
||||
from .luminal import CompiledGraph, process_pt2
|
||||
from .main import luminal_backend, register_backend
|
||||
|
||||
_register_cache_serialization()
|
||||
|
||||
@@ -17,7 +15,7 @@ _register_cache_serialization()
|
||||
__all__ = [
|
||||
"CompiledModel",
|
||||
"luminal_backend",
|
||||
"process_onnx",
|
||||
"register_backend",
|
||||
"CompiledGraph",
|
||||
"process_pt2",
|
||||
]
|
||||
|
||||
@@ -4,6 +4,9 @@ from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from .dtype_util import code_to_torch_dtype
|
||||
from .dtype_util import torch_dtype_code as _torch_dtype_code
|
||||
|
||||
|
||||
class CompiledModel:
|
||||
"""Wrapper around CompiledGraph that handles PyTorch tensor conversion."""
|
||||
@@ -14,7 +17,7 @@ class CompiledModel:
|
||||
"""Initialize with a compiled CompiledGraph from Rust.
|
||||
|
||||
Args:
|
||||
graph_result: The CompiledGraph from luminal_python.process_onnx() or process_pt2()
|
||||
graph_result: The CompiledGraph from luminal_python.process_pt2()
|
||||
weight_refs: List of PyTorch tensors to keep alive (prevents GC of shared weights)
|
||||
input_names: Override for user input names. If None, uses graph_result.input_names.
|
||||
user_indices: When torch.compile lifts model parameters into extra args,
|
||||
@@ -28,7 +31,18 @@ class CompiledModel:
|
||||
self._has_dynamic_dims = getattr(graph_result, "has_dynamic_dims", False)
|
||||
self._weight_refs = weight_refs or []
|
||||
self._user_indices = user_indices
|
||||
self._is_cuda = graph_result.backend == "cuda"
|
||||
self._is_gpu = getattr(graph_result, "device_type", "cpu") != "cpu"
|
||||
self._supports_device_ptrs = getattr(
|
||||
graph_result, "supports_device_ptrs", False
|
||||
)
|
||||
# Expected input dtypes from graph (used to convert user inputs)
|
||||
input_dtype_codes = graph_result.input_dtypes
|
||||
self._input_dtypes = [
|
||||
code_to_torch_dtype(input_dtype_codes[i])
|
||||
if i < len(input_dtype_codes)
|
||||
else torch.float32
|
||||
for i in range(len(self._input_names))
|
||||
]
|
||||
|
||||
def set_dim(self, param_name: str, value: int) -> None:
|
||||
"""Set a dynamic dimension value by its param name."""
|
||||
@@ -70,44 +84,115 @@ class CompiledModel:
|
||||
input_shapes = [list(t.shape) for t in user_inputs]
|
||||
self._graph.auto_set_dims_from_input_shapes(input_shapes)
|
||||
|
||||
# Set user input data via pointer (avoids Python list conversion).
|
||||
# Set user input data via pointer.
|
||||
# Convert to the graph's expected dtype so bytes match the Input node's dtype tag.
|
||||
# For CUDA inputs, keep references alive so the caching allocator doesn't
|
||||
# recycle GPU memory before run() reads the pointers.
|
||||
_input_refs = []
|
||||
for name, tensor in zip(self._input_names, user_inputs):
|
||||
if self._is_cuda and tensor.is_cuda:
|
||||
t = tensor.detach().contiguous().float()
|
||||
self._graph.set_input_device_ptr(name, t.data_ptr(), t.numel() * 4)
|
||||
for name, tensor, expected_dtype in zip(
|
||||
self._input_names, user_inputs, self._input_dtypes
|
||||
):
|
||||
if self._supports_device_ptrs and tensor.is_cuda:
|
||||
t = tensor.detach().contiguous().to(expected_dtype)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
self._graph.set_input_device_ptr(name, t.data_ptr(), n_bytes)
|
||||
_input_refs.append(t)
|
||||
else:
|
||||
t = tensor.detach().cpu().contiguous().float()
|
||||
self._graph.set_input_from_ptr(name, t.data_ptr(), t.numel())
|
||||
t = tensor.detach().cpu().contiguous().to(expected_dtype)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
dtype_code = _torch_dtype_code(t.dtype)
|
||||
self._graph.set_input_from_ptr(name, t.data_ptr(), n_bytes, dtype_code)
|
||||
|
||||
# Run the graph
|
||||
self._graph.run()
|
||||
|
||||
# Get output shapes — resolve dynamically if needed
|
||||
# Resolve output shapes before run() (needed for pre-allocation).
|
||||
if self._has_dynamic_dims:
|
||||
output_shapes = self._graph.resolve_output_shapes()
|
||||
else:
|
||||
output_shapes = self._output_shapes
|
||||
|
||||
# Get outputs and convert back to PyTorch tensors on the same device as inputs.
|
||||
# For CUDA: DtoD copy avoids the DtoH + HtoD round-trip.
|
||||
outputs = []
|
||||
for name, shape in zip(self._output_names, output_shapes):
|
||||
if self._is_cuda and hasattr(self._graph, "copy_output_to_device_ptr"):
|
||||
out = torch.empty(shape, dtype=torch.float32, device=input_device)
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * 4
|
||||
output_dtype_codes = self._graph.output_dtypes
|
||||
|
||||
# CUDA zero-copy path: pre-allocate output tensors and register their device
|
||||
# pointers so the final kernel writes directly into PyTorch's buffer.
|
||||
_use_zero_copy = self._supports_device_ptrs
|
||||
output_tensors = []
|
||||
if _use_zero_copy:
|
||||
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
|
||||
out_dtype = (
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
else:
|
||||
data = self._graph.get_output(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.float32)
|
||||
.reshape(tuple(shape))
|
||||
.to(input_device)
|
||||
out = torch.empty(shape, dtype=out_dtype, device=input_device)
|
||||
if out_dtype.is_floating_point:
|
||||
self._graph.set_output_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
output_tensors.append(out)
|
||||
|
||||
# Run the graph
|
||||
self._graph.run()
|
||||
|
||||
# Collect outputs
|
||||
if _use_zero_copy:
|
||||
outputs = []
|
||||
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
|
||||
out_dtype = (
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
outputs.append(out)
|
||||
out = output_tensors[i]
|
||||
if out_dtype.is_floating_point:
|
||||
if not self._graph.output_is_zero_copy(name):
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
elif out_dtype == torch.int32:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.int32)
|
||||
.reshape(tuple(shape))
|
||||
.to(input_device)
|
||||
)
|
||||
elif out_dtype == torch.bool:
|
||||
data = self._graph.get_output_bool(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.bool)
|
||||
.reshape(tuple(shape))
|
||||
.to(input_device)
|
||||
)
|
||||
else:
|
||||
data = self._graph.get_output(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.float32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
.to(input_device)
|
||||
)
|
||||
outputs.append(out)
|
||||
else:
|
||||
# Native path: retrieve as f32, then convert to target dtype if needed.
|
||||
outputs = []
|
||||
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
|
||||
out_dtype = (
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
if out_dtype == torch.int32:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = torch.tensor(data, dtype=torch.int32).reshape(tuple(shape))
|
||||
elif out_dtype == torch.bool:
|
||||
data = self._graph.get_output_bool(name)
|
||||
out = torch.tensor(data, dtype=torch.bool).reshape(tuple(shape))
|
||||
else:
|
||||
data = self._graph.get_output(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.float32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
)
|
||||
out = out.to(input_device)
|
||||
outputs.append(out)
|
||||
|
||||
return tuple(outputs)
|
||||
|
||||
28
crates/luminal_python/src/luminal/dtype_util.py
Normal file
28
crates/luminal_python/src/luminal/dtype_util.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Shared dtype utility functions for the luminal Python Bridge"""
|
||||
|
||||
import torch
|
||||
|
||||
_TORCH_DTYPE_TO_CODE = {
|
||||
torch.uint8: 1,
|
||||
torch.int8: 2,
|
||||
torch.int16: 3,
|
||||
torch.int32: 4,
|
||||
torch.int64: 5,
|
||||
torch.float16: 6,
|
||||
torch.float32: 7,
|
||||
torch.float64: 8,
|
||||
torch.bool: 12,
|
||||
torch.bfloat16: 13,
|
||||
}
|
||||
|
||||
_CODE_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_DTYPE_TO_CODE.items()}
|
||||
|
||||
|
||||
def torch_dtype_code(dtype):
|
||||
"""Map torch.dtype to PT2 dtype integer code."""
|
||||
return _TORCH_DTYPE_TO_CODE.get(dtype, 7) # default to f32
|
||||
|
||||
|
||||
def code_to_torch_dtype(code):
|
||||
"""Map PT2 dtype integer code to torch.dtype."""
|
||||
return _CODE_TO_TORCH_DTYPE.get(code, torch.float32)
|
||||
@@ -1,150 +1,101 @@
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
|
||||
import luminal
|
||||
|
||||
from .compiled_model import CompiledModel
|
||||
from .dtype_util import torch_dtype_code as _torch_dtype_code
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared helpers (used by both ONNX and PT2 paths)
|
||||
# Shared helpers (used by PT2 path and compiled_model)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _detect_backend(example_inputs):
|
||||
"""Detect backend from input device. Returns 'cuda' or 'native'."""
|
||||
def _detect_factory_capsule(example_inputs):
|
||||
"""Pick the best built-in factory capsule based on input device."""
|
||||
device = example_inputs[0].device if example_inputs else torch.device("cpu")
|
||||
return "cuda" if device.type == "cuda" else "native"
|
||||
if device.type == "cuda":
|
||||
try:
|
||||
from .luminal import _cuda_lite_factory_capsule
|
||||
|
||||
return _cuda_lite_factory_capsule()
|
||||
except ImportError:
|
||||
pass
|
||||
from .luminal import _native_factory_capsule
|
||||
|
||||
return _native_factory_capsule()
|
||||
|
||||
|
||||
def _collect_weight_pointers(weights, backend):
|
||||
def _collect_weight_pointers(weights):
|
||||
"""Partition weight tensors into CUDA device pointers and CPU host pointers.
|
||||
|
||||
Preserves native dtype — no forced conversion to float32.
|
||||
|
||||
Args:
|
||||
weights: dict of name -> torch.Tensor
|
||||
backend: "cuda", "gpu", "cpu", or "native"
|
||||
|
||||
Returns:
|
||||
(keep_alive, device_ptrs, cpu_ptrs) where:
|
||||
- keep_alive: list[Tensor] to prevent GC of shared weight memory
|
||||
- device_ptrs: {name: (device_ptr, n_bytes)}
|
||||
- cpu_ptrs: {name: (host_ptr, n_elements)}
|
||||
- cpu_ptrs: {name: (host_ptr, n_bytes, dtype_code)}
|
||||
"""
|
||||
keep_alive = []
|
||||
device_ptrs = {}
|
||||
cpu_ptrs = {}
|
||||
for name, tensor in weights.items():
|
||||
t = tensor.detach().contiguous()
|
||||
if t.dtype != torch.float32:
|
||||
t = t.float()
|
||||
if backend in ("cuda", "gpu") and t.is_cuda:
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
if t.is_cuda:
|
||||
keep_alive.append(t)
|
||||
device_ptrs[name] = (t.data_ptr(), t.numel() * 4)
|
||||
device_ptrs[name] = (t.data_ptr(), n_bytes)
|
||||
else:
|
||||
t = t.cpu() if t.is_cuda else t
|
||||
keep_alive.append(t)
|
||||
cpu_ptrs[name] = (t.data_ptr(), t.numel())
|
||||
cpu_ptrs[name] = (t.data_ptr(), n_bytes, _torch_dtype_code(t.dtype))
|
||||
return keep_alive, device_ptrs, cpu_ptrs
|
||||
|
||||
|
||||
def _load_cpu_weights(compiled_graph, cpu_weights):
|
||||
"""Load CPU weight data into a compiled graph after Rust compilation."""
|
||||
for name, (ptr, n_elements) in cpu_weights.items():
|
||||
compiled_graph.set_weight_from_ptr(name, ptr, n_elements)
|
||||
for name, (ptr, n_bytes, dtype_code) in cpu_weights.items():
|
||||
compiled_graph.set_weight_from_ptr(name, ptr, n_bytes, dtype_code)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# torch.compile backend entry point
|
||||
# Backend registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def register_backend(factory_capsule):
|
||||
"""Wrap a backend factory PyCapsule into a torch.compile-compatible callable.
|
||||
|
||||
Args:
|
||||
factory_capsule: PyCapsule wrapping a BackendFactory fn pointer.
|
||||
|
||||
Returns:
|
||||
A callable(gm, example_inputs, options=None) suitable for torch.compile.
|
||||
"""
|
||||
|
||||
def backend(gm, example_inputs, options=None):
|
||||
return _compile_pt2(gm, example_inputs, factory_capsule, options=options)
|
||||
|
||||
return backend
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# torch.compile backend entry point (auto-detecting)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def luminal_backend(gm, example_inputs, options=None):
|
||||
"""Luminal torch.compile backend.
|
||||
"""Auto-detecting torch.compile backend.
|
||||
|
||||
Usage:
|
||||
torch.compile(model, backend=luminal_backend)
|
||||
torch.compile(model, backend=luminal_backend, options={"export_mode": "pt2"})
|
||||
Picks cuda_lite if inputs are on CUDA (and cuda feature is compiled in),
|
||||
native otherwise.
|
||||
|
||||
Options:
|
||||
export_mode: "onnx" (default) or "pt2"
|
||||
opset: ONNX opset version (default 20)
|
||||
For external backends, use register_backend with the backend's factory capsule.
|
||||
"""
|
||||
options = options or {}
|
||||
|
||||
# Env var override
|
||||
env_mode = os.getenv("LUMINAL_EXPORT_MODE", "").lower()
|
||||
export_mode = (
|
||||
env_mode if env_mode in ("pt2", "onnx") else options.get("export_mode", "onnx")
|
||||
)
|
||||
opset = options.get("opset", 20)
|
||||
|
||||
backend = _detect_backend(example_inputs)
|
||||
|
||||
if export_mode == "pt2":
|
||||
return _compile_pt2(gm, example_inputs, backend)
|
||||
return _compile_onnx(gm, example_inputs, backend, opset=opset)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ONNX compilation path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compile_onnx(gm, example_inputs, backend, opset=20):
|
||||
"""ONNX compilation path."""
|
||||
# Identify weight vs user inputs from FX graph placeholders.
|
||||
# torch.compile lifts model parameters into graph inputs — we detect them by name prefix.
|
||||
weight_tensors = {} # onnx_name -> tensor
|
||||
user_indices = []
|
||||
ph_idx = 0
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "placeholder":
|
||||
onnx_name = f"input_{ph_idx}"
|
||||
if node.name.startswith(("l_self_", "l_model_", "l__self_")):
|
||||
weight_tensors[onnx_name] = example_inputs[ph_idx]
|
||||
else:
|
||||
user_indices.append(ph_idx)
|
||||
ph_idx += 1
|
||||
|
||||
# Collect weight pointers for Rust (avoids duplicate GPU buffer allocation)
|
||||
weight_refs, weight_device_ptrs, cpu_weights = _collect_weight_pointers(
|
||||
weight_tensors, backend
|
||||
)
|
||||
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
_ = gm.eval()
|
||||
try:
|
||||
_ = torch.onnx.export(
|
||||
gm,
|
||||
tuple(example_inputs),
|
||||
tmp_path,
|
||||
opset_version=opset,
|
||||
input_names=[f"input_{i}" for i in range(len(example_inputs))],
|
||||
)
|
||||
|
||||
result = luminal.process_onnx(
|
||||
tmp_path, backend, weight_device_ptrs=weight_device_ptrs
|
||||
)
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
# Load CPU weights after compilation
|
||||
_load_cpu_weights(result, cpu_weights)
|
||||
|
||||
# Only expose user input names to CompiledModel (weights are pre-loaded).
|
||||
# user_indices tells __call__ which args from torch.compile are real user inputs.
|
||||
user_input_names = [f"input_{i}" for i in user_indices]
|
||||
return CompiledModel(
|
||||
result,
|
||||
weight_refs=weight_refs,
|
||||
input_names=user_input_names,
|
||||
user_indices=user_indices,
|
||||
)
|
||||
capsule = _detect_factory_capsule(example_inputs)
|
||||
return _compile_pt2(gm, example_inputs, capsule, options=options)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -152,8 +103,8 @@ def _compile_onnx(gm, example_inputs, backend, opset=20):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compile_pt2(gm, example_inputs, backend):
|
||||
def _compile_pt2(gm, example_inputs, factory_capsule, options=None):
|
||||
"""PT2/torch.export path — delegates to pt2.pt2_backend."""
|
||||
from .pt2 import pt2_backend
|
||||
|
||||
return pt2_backend(gm, example_inputs, backend=backend)
|
||||
return pt2_backend(gm, example_inputs, factory=factory_capsule, options=options)
|
||||
|
||||
@@ -14,7 +14,7 @@ import torch
|
||||
|
||||
from .compiled_model import CompiledModel
|
||||
from .luminal import process_pt2
|
||||
from .main import _collect_weight_pointers, _detect_backend, _load_cpu_weights
|
||||
from .main import _collect_weight_pointers, _detect_factory_capsule, _load_cpu_weights
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -32,12 +32,18 @@ def _export_kwargs():
|
||||
return kwargs
|
||||
|
||||
|
||||
def _save_and_compile(ep_or_path, backend, search_iterations, original_weights=None):
|
||||
def _save_and_compile(
|
||||
ep_or_path,
|
||||
factory,
|
||||
original_weights=None,
|
||||
options=None,
|
||||
):
|
||||
"""Compile a PT2 model via Rust, return CompiledModel.
|
||||
|
||||
Args:
|
||||
ep_or_path: Either an ExportedProgram (will be saved to a temp file) or
|
||||
a path to an already-saved .pt2 file.
|
||||
factory: PyCapsule wrapping the BackendFactory to use.
|
||||
original_weights: Optional dict mapping state_dict key -> original PyTorch tensor.
|
||||
When provided, device pointers are taken from these tensors instead of
|
||||
ep.state_dict (which torch.export may have cloned), enabling true zero-copy
|
||||
@@ -58,12 +64,16 @@ def _save_and_compile(ep_or_path, backend, search_iterations, original_weights=N
|
||||
|
||||
# Collect weight pointers for Rust (avoids duplicate GPU buffer allocation)
|
||||
keep_alive, weight_device_ptrs, cpu_weights = _collect_weight_pointers(
|
||||
weight_source, backend
|
||||
weight_source
|
||||
)
|
||||
|
||||
# Compile with device pointers — search uses actual weight memory (zero-copy)
|
||||
compiled = process_pt2(
|
||||
pt2_path, "", backend, search_iterations, weight_device_ptrs
|
||||
pt2_path,
|
||||
"",
|
||||
factory,
|
||||
weight_device_ptrs=weight_device_ptrs,
|
||||
options=options,
|
||||
)
|
||||
|
||||
# Load CPU weights after compilation
|
||||
@@ -136,7 +146,7 @@ def compile(
|
||||
model,
|
||||
example_input,
|
||||
search_iterations=25,
|
||||
backend=None,
|
||||
factory=None,
|
||||
export_kwargs=None,
|
||||
dynamic_dim=None,
|
||||
):
|
||||
@@ -146,7 +156,7 @@ def compile(
|
||||
model: A PyTorch nn.Module.
|
||||
example_input: Example input tensor(s) for tracing.
|
||||
search_iterations: Number of optimization search iterations.
|
||||
backend: "native" or "cuda". Auto-detected if None.
|
||||
factory: PyCapsule wrapping a BackendFactory. Auto-detected if None.
|
||||
export_kwargs: Extra kwargs passed to torch.export.export.
|
||||
dynamic_dim: Which input dimension to make dynamic.
|
||||
|
||||
@@ -156,10 +166,8 @@ def compile(
|
||||
if dynamic_dim is None:
|
||||
dynamic_dim = "auto"
|
||||
|
||||
if backend is None:
|
||||
backend = os.environ.get("LUMINAL_BACKEND", None)
|
||||
if backend is None:
|
||||
backend = "cuda" if torch.cuda.is_available() else "native"
|
||||
if factory is None:
|
||||
factory = _detect_factory_capsule([example_input])
|
||||
|
||||
kwargs = export_kwargs or {}
|
||||
extra = _export_kwargs()
|
||||
@@ -193,6 +201,7 @@ def compile(
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
**extra,
|
||||
)
|
||||
ep = ep.run_decompositions()
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
@@ -205,24 +214,30 @@ def compile(
|
||||
dynamic_shapes=None,
|
||||
**extra,
|
||||
)
|
||||
ep = ep.run_decompositions()
|
||||
|
||||
return _save_and_compile(ep, backend, search_iterations)
|
||||
return _save_and_compile(
|
||||
ep,
|
||||
factory,
|
||||
options={"search_iterations": search_iterations},
|
||||
)
|
||||
|
||||
|
||||
def pt2_backend(gm, example_inputs, backend=None):
|
||||
def pt2_backend(gm, example_inputs, factory=None, options=None):
|
||||
"""torch.compile backend using PT2 pipeline.
|
||||
|
||||
Usage: torch.compile(model, backend=luminal.pt2.pt2_backend)
|
||||
Usage: torch.compile(model, backend=luminal.register_backend(capsule))
|
||||
"""
|
||||
import gc
|
||||
|
||||
if backend is None:
|
||||
backend = _detect_backend(example_inputs)
|
||||
if factory is None:
|
||||
factory = _detect_factory_capsule(example_inputs)
|
||||
|
||||
gm = gm.eval()
|
||||
gm, user_inputs, original_weights = _reinternalize_lifted_params(gm, example_inputs)
|
||||
|
||||
ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs())
|
||||
ep = ep.run_decompositions()
|
||||
|
||||
# When using shared memory (original_weights), strip large weight buffers from
|
||||
# the EP before saving. The Rust side uses device pointers for these weights,
|
||||
@@ -249,7 +264,10 @@ def pt2_backend(gm, example_inputs, backend=None):
|
||||
|
||||
try:
|
||||
result = _save_and_compile(
|
||||
pt2_path, backend, 10, original_weights=original_weights
|
||||
pt2_path,
|
||||
factory,
|
||||
original_weights=original_weights,
|
||||
options=options,
|
||||
)
|
||||
return result
|
||||
finally:
|
||||
|
||||
@@ -1,194 +0,0 @@
|
||||
"""Kimi-K2.5 / DeepseekV3 model integration tests.
|
||||
|
||||
Tests the DeepseekV3 text backbone (MoE + MLA attention with LoRA-compressed KV,
|
||||
SwiGLU, YaRN RoPE) through the PyTorch -> ONNX -> luminal pipeline.
|
||||
|
||||
The model code requires trust_remote_code=True and uses custom HF modules from
|
||||
moonshotai/Kimi-K2.5. Since torch.compile cannot trace the MoE routing (it uses
|
||||
.numpy() and tensor indexing incompatible with dynamo), tests use manual ONNX
|
||||
export + onnxsim simplification + luminal.process_onnx.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import warnings
|
||||
|
||||
import onnx
|
||||
import onnxsim
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
def _get_deepseek_v3_classes():
|
||||
"""Import DeepseekV3Config and DeepseekV3ForCausalLM from the Kimi-K2.5 HF repo."""
|
||||
import importlib
|
||||
|
||||
from transformers import AutoConfig
|
||||
|
||||
config = AutoConfig.from_pretrained("moonshotai/Kimi-K2.5", trust_remote_code=True)
|
||||
tc = config.text_config
|
||||
DeepseekV3Config = type(tc)
|
||||
pkg = DeepseekV3Config.__module__.rsplit(".", 1)[0]
|
||||
modeling_mod = importlib.import_module(f"{pkg}.modeling_deepseek")
|
||||
return DeepseekV3Config, modeling_mod.DeepseekV3ForCausalLM
|
||||
|
||||
|
||||
def _make_deepseek_v3_config(
|
||||
DeepseekV3Config,
|
||||
hidden_size: int = 64,
|
||||
num_attention_heads: int = 4,
|
||||
num_key_value_heads: int = 4,
|
||||
num_hidden_layers: int = 1,
|
||||
intermediate_size: int = 128,
|
||||
vocab_size: int = 256,
|
||||
kv_lora_rank: int = 16,
|
||||
q_lora_rank: int = 32,
|
||||
qk_nope_head_dim: int = 8,
|
||||
qk_rope_head_dim: int = 8,
|
||||
v_head_dim: int = 8,
|
||||
n_routed_experts: int = 4,
|
||||
num_experts_per_tok: int = 2,
|
||||
n_shared_experts: int = 1,
|
||||
moe_intermediate_size: int = 32,
|
||||
first_k_dense_replace: int = 1,
|
||||
):
|
||||
"""Create a small DeepseekV3Config for testing."""
|
||||
config = DeepseekV3Config(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
intermediate_size=intermediate_size,
|
||||
vocab_size=vocab_size,
|
||||
max_position_embeddings=128,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
q_lora_rank=q_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
n_routed_experts=n_routed_experts,
|
||||
num_experts_per_tok=num_experts_per_tok,
|
||||
n_shared_experts=n_shared_experts,
|
||||
moe_intermediate_size=moe_intermediate_size,
|
||||
first_k_dense_replace=first_k_dense_replace,
|
||||
use_cache=False,
|
||||
n_group=1,
|
||||
topk_group=1,
|
||||
topk_method="noaux_tc",
|
||||
scoring_func="sigmoid",
|
||||
rope_scaling={
|
||||
"type": "yarn",
|
||||
"rope_type": "yarn",
|
||||
"factor": 4.0,
|
||||
"original_max_position_embeddings": 32,
|
||||
"beta_fast": 32.0,
|
||||
"beta_slow": 1.0,
|
||||
"mscale": 1.0,
|
||||
"mscale_all_dim": 1.0,
|
||||
"rope_theta": 10000.0,
|
||||
},
|
||||
rope_theta=10000.0,
|
||||
)
|
||||
config._attn_implementation = "eager"
|
||||
return config
|
||||
|
||||
|
||||
def _export_and_simplify(model, input_ids):
|
||||
"""Export model to ONNX and simplify with onnxsim to constant-fold shape chains."""
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
try:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(input_ids,),
|
||||
tmp_path,
|
||||
opset_version=20,
|
||||
input_names=["input_ids"],
|
||||
output_names=["logits"],
|
||||
dynamo=False,
|
||||
)
|
||||
m = onnx.load(tmp_path)
|
||||
m_sim, check = onnxsim.simplify(m)
|
||||
assert check, "onnxsim simplification failed"
|
||||
onnx.save(m_sim, tmp_path)
|
||||
return tmp_path
|
||||
except Exception:
|
||||
os.unlink(tmp_path)
|
||||
raise
|
||||
|
||||
|
||||
def _run_deepseek_v3_test(config, DeepseekV3ForCausalLM, backend: str, atol: float):
|
||||
"""Export DeepseekV3 to ONNX, simplify, run through luminal, compare."""
|
||||
import luminal
|
||||
|
||||
model = DeepseekV3ForCausalLM(config).eval()
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]])
|
||||
|
||||
onnx_path = _export_and_simplify(model, input_ids)
|
||||
try:
|
||||
graph = luminal.process_onnx(onnx_path, backend)
|
||||
graph.set_input("input_ids", [1.0, 2.0, 3.0, 4.0])
|
||||
graph.run()
|
||||
logits_data = graph.get_output("logits")
|
||||
logits = torch.tensor(logits_data, dtype=torch.float32).reshape(
|
||||
1, 4, config.vocab_size
|
||||
)
|
||||
finally:
|
||||
os.unlink(onnx_path)
|
||||
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
|
||||
assert torch.allclose(logits, ref.logits, atol=atol), (
|
||||
f"max_diff={torch.max(torch.abs(logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
# ========== Tests ==========
|
||||
|
||||
|
||||
def test_deepseek_v3_tiny_dense():
|
||||
"""Tiny DeepseekV3 with dense MLP (no MoE): 64 hidden, 1 layer, MLA attention."""
|
||||
DeepseekV3Config, DeepseekV3ForCausalLM = _get_deepseek_v3_classes()
|
||||
config = _make_deepseek_v3_config(
|
||||
DeepseekV3Config,
|
||||
first_k_dense_replace=1, # all layers use dense MLP
|
||||
)
|
||||
backend = os.environ.get("LUMINAL_BACKEND", "native")
|
||||
_run_deepseek_v3_test(config, DeepseekV3ForCausalLM, backend, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="MoE routing uses Int/F32 mixed ops not yet supported")
|
||||
def test_deepseek_v3_tiny_moe():
|
||||
"""Tiny DeepseekV3 with MoE: 64 hidden, 1 layer, 4 routed experts + 1 shared."""
|
||||
DeepseekV3Config, DeepseekV3ForCausalLM = _get_deepseek_v3_classes()
|
||||
config = _make_deepseek_v3_config(
|
||||
DeepseekV3Config,
|
||||
first_k_dense_replace=0, # all layers use MoE
|
||||
)
|
||||
backend = os.environ.get("LUMINAL_BACKEND", "native")
|
||||
_run_deepseek_v3_test(config, DeepseekV3ForCausalLM, backend, atol=1e-5)
|
||||
|
||||
|
||||
def test_deepseek_v3_small_dense():
|
||||
"""Small DeepseekV3 with dense MLP: 256 hidden, 1 layer."""
|
||||
DeepseekV3Config, DeepseekV3ForCausalLM = _get_deepseek_v3_classes()
|
||||
config = _make_deepseek_v3_config(
|
||||
DeepseekV3Config,
|
||||
hidden_size=256,
|
||||
num_attention_heads=8,
|
||||
num_key_value_heads=8,
|
||||
intermediate_size=512,
|
||||
vocab_size=1024,
|
||||
kv_lora_rank=32,
|
||||
q_lora_rank=64,
|
||||
qk_nope_head_dim=16,
|
||||
qk_rope_head_dim=16,
|
||||
v_head_dim=16,
|
||||
first_k_dense_replace=1,
|
||||
)
|
||||
backend = os.environ.get("LUMINAL_BACKEND", "native")
|
||||
_run_deepseek_v3_test(config, DeepseekV3ForCausalLM, backend, atol=1e-4)
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Qwen3-8B HuggingFace model integration tests.
|
||||
|
||||
Tests progressively larger HuggingFace Qwen3ForCausalLM configs through the
|
||||
PyTorch -> ONNX -> luminal pipeline via torch.compile. Qwen3 shares the same
|
||||
PyTorch -> PT2 -> luminal pipeline via torch.compile. Qwen3 shares the same
|
||||
architecture family as Llama (GQA, RoPE, SwiGLU MLP, RMSNorm).
|
||||
"""
|
||||
|
||||
|
||||
@@ -1,426 +0,0 @@
|
||||
"""Qwen-Image diffusion model integration tests.
|
||||
|
||||
Tests the QwenImageTransformer2DModel (MMDiT denoiser) and AutoencoderKLQwenImage (VAE)
|
||||
through the PyTorch -> ONNX -> luminal pipeline.
|
||||
|
||||
The transformer uses complex-valued RoPE (torch.view_as_complex) which isn't ONNX-exportable,
|
||||
so tests use a wrapper that pre-computes RoPE as real-valued cos/sin and replaces the
|
||||
attention processor with a real-valued equivalent.
|
||||
|
||||
The VAE uses Conv3d, which is supported via the N-dimensional unfold-based conv parser.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import warnings
|
||||
|
||||
import onnx
|
||||
import onnxsim
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Transformer helpers
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _apply_rope_real(x, cos, sin):
|
||||
"""Apply RoPE using real-valued cos/sin. x: [B, S, H, D], cos/sin: [S, D/2]."""
|
||||
d = x.shape[-1]
|
||||
x1 = x[..., : d // 2]
|
||||
x2 = x[..., d // 2 :]
|
||||
cos = cos.unsqueeze(0).unsqueeze(2) # [1, S, 1, D/2]
|
||||
sin = sin.unsqueeze(0).unsqueeze(2)
|
||||
rotated_x1 = x1 * cos - x2 * sin
|
||||
rotated_x2 = x2 * cos + x1 * sin
|
||||
return torch.cat([rotated_x1, rotated_x2], dim=-1)
|
||||
|
||||
|
||||
class RealRoPEAttnProcessor:
|
||||
"""Attention processor that uses real-valued RoPE for ONNX compatibility.
|
||||
|
||||
Replaces the default QwenDoubleStreamAttnProcessor2_0 which uses
|
||||
torch.view_as_complex (not ONNX-exportable).
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
encoder_hidden_states_mask=None,
|
||||
attention_mask=None,
|
||||
image_rotary_emb=None,
|
||||
):
|
||||
seq_txt = encoder_hidden_states.shape[1]
|
||||
|
||||
img_query = attn.to_q(hidden_states)
|
||||
img_key = attn.to_k(hidden_states)
|
||||
img_value = attn.to_v(hidden_states)
|
||||
|
||||
txt_query = attn.add_q_proj(encoder_hidden_states)
|
||||
txt_key = attn.add_k_proj(encoder_hidden_states)
|
||||
txt_value = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
img_query = img_query.unflatten(-1, (attn.heads, -1))
|
||||
img_key = img_key.unflatten(-1, (attn.heads, -1))
|
||||
img_value = img_value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
txt_query = txt_query.unflatten(-1, (attn.heads, -1))
|
||||
txt_key = txt_key.unflatten(-1, (attn.heads, -1))
|
||||
txt_value = txt_value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
if attn.norm_q is not None:
|
||||
img_query = attn.norm_q(img_query)
|
||||
if attn.norm_k is not None:
|
||||
img_key = attn.norm_k(img_key)
|
||||
if attn.norm_added_q is not None:
|
||||
txt_query = attn.norm_added_q(txt_query)
|
||||
if attn.norm_added_k is not None:
|
||||
txt_key = attn.norm_added_k(txt_key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
img_cos, img_sin, txt_cos, txt_sin = image_rotary_emb
|
||||
img_query = _apply_rope_real(img_query, img_cos, img_sin)
|
||||
img_key = _apply_rope_real(img_key, img_cos, img_sin)
|
||||
txt_query = _apply_rope_real(txt_query, txt_cos, txt_sin)
|
||||
txt_key = _apply_rope_real(txt_key, txt_cos, txt_sin)
|
||||
|
||||
joint_query = torch.cat([txt_query, img_query], dim=1)
|
||||
joint_key = torch.cat([txt_key, img_key], dim=1)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=1)
|
||||
|
||||
joint_query = joint_query.transpose(1, 2)
|
||||
joint_key = joint_key.transpose(1, 2)
|
||||
joint_value = joint_value.transpose(1, 2)
|
||||
joint_hidden = torch.nn.functional.scaled_dot_product_attention(
|
||||
joint_query, joint_key, joint_value, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
joint_hidden = joint_hidden.transpose(1, 2)
|
||||
joint_hidden = joint_hidden.flatten(2, 3)
|
||||
|
||||
txt_attn = joint_hidden[:, :seq_txt, :]
|
||||
img_attn = joint_hidden[:, seq_txt:, :]
|
||||
|
||||
img_attn = attn.to_out[0](img_attn.contiguous())
|
||||
if len(attn.to_out) > 1:
|
||||
img_attn = attn.to_out[1](img_attn)
|
||||
txt_attn = attn.to_add_out(txt_attn.contiguous())
|
||||
|
||||
return img_attn, txt_attn
|
||||
|
||||
|
||||
class TransformerONNXWrapper(nn.Module):
|
||||
"""Wraps QwenImageTransformer2DModel for ONNX export.
|
||||
|
||||
Pre-computes complex RoPE frequencies as real cos/sin buffers and replaces
|
||||
the attention processors with ONNX-friendly real-valued versions.
|
||||
"""
|
||||
|
||||
def __init__(self, model, img_shapes, txt_seq_len):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
for block in self.model.transformer_blocks:
|
||||
block.attn.set_processor(RealRoPEAttnProcessor())
|
||||
|
||||
with torch.no_grad():
|
||||
img_freqs, txt_freqs = model.pos_embed(
|
||||
img_shapes, max_txt_seq_len=txt_seq_len
|
||||
)
|
||||
self.register_buffer("img_cos", img_freqs.real.float().contiguous())
|
||||
self.register_buffer("img_sin", img_freqs.imag.float().contiguous())
|
||||
self.register_buffer("txt_cos", txt_freqs.real.float().contiguous())
|
||||
self.register_buffer("txt_sin", txt_freqs.imag.float().contiguous())
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states, timestep):
|
||||
hidden_states = self.model.img_in(hidden_states)
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
|
||||
encoder_hidden_states = self.model.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.model.txt_in(encoder_hidden_states)
|
||||
|
||||
temb = self.model.time_text_embed(timestep, hidden_states)
|
||||
|
||||
rope = (self.img_cos, self.img_sin, self.txt_cos, self.txt_sin)
|
||||
|
||||
for block in self.model.transformer_blocks:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=None,
|
||||
temb=temb,
|
||||
image_rotary_emb=rope,
|
||||
)
|
||||
|
||||
hidden_states = self.model.norm_out(hidden_states, temb)
|
||||
output = self.model.proj_out(hidden_states)
|
||||
return output
|
||||
|
||||
|
||||
def _make_tiny_transformer_config():
|
||||
"""Tiny transformer config: ~100K params, 1 layer."""
|
||||
return dict(
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
num_layers=1,
|
||||
attention_head_dim=16,
|
||||
num_attention_heads=4,
|
||||
joint_attention_dim=64,
|
||||
axes_dims_rope=(4, 6, 6),
|
||||
)
|
||||
|
||||
|
||||
def _make_small_transformer_config():
|
||||
"""Small transformer config: ~1M params, 2 layers."""
|
||||
return dict(
|
||||
patch_size=2,
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
num_layers=2,
|
||||
attention_head_dim=32,
|
||||
num_attention_heads=8,
|
||||
joint_attention_dim=256,
|
||||
axes_dims_rope=(8, 12, 12),
|
||||
)
|
||||
|
||||
|
||||
def _make_medium_transformer_config():
|
||||
"""Medium transformer config: ~39M params, 4 layers."""
|
||||
return dict(
|
||||
patch_size=2,
|
||||
in_channels=32,
|
||||
out_channels=32,
|
||||
num_layers=4,
|
||||
attention_head_dim=64,
|
||||
num_attention_heads=8,
|
||||
joint_attention_dim=512,
|
||||
axes_dims_rope=(8, 28, 28),
|
||||
)
|
||||
|
||||
|
||||
def _run_transformer_test(config, atol):
|
||||
"""Compile transformer with luminal backend, compare to PyTorch reference."""
|
||||
from diffusers.models import QwenImageTransformer2DModel
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
model = QwenImageTransformer2DModel(**config).eval()
|
||||
img_seq_len = 4
|
||||
txt_seq_len = 3
|
||||
|
||||
wrapper = TransformerONNXWrapper(model, [(1, 2, 2)], txt_seq_len).eval()
|
||||
wrapper_compiled = torch.compile(wrapper, backend=luminal_backend)
|
||||
|
||||
hidden = torch.randn(1, img_seq_len, config["in_channels"])
|
||||
encoder_hs = torch.randn(1, txt_seq_len, config["joint_attention_dim"])
|
||||
timestep = torch.tensor([1.0])
|
||||
|
||||
with torch.no_grad():
|
||||
ref = wrapper(hidden, encoder_hs, timestep)
|
||||
out = wrapper_compiled(hidden, encoder_hs, timestep)
|
||||
|
||||
assert torch.allclose(out, ref, atol=atol), (
|
||||
f"max_diff={torch.max(torch.abs(out - ref)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# VAE helpers
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class _OnnxFriendlyUpsample(nn.Module):
|
||||
"""Replaces nn.Upsample with repeat_interleave for ONNX compatibility."""
|
||||
|
||||
def __init__(self, scale_factor):
|
||||
super().__init__()
|
||||
if isinstance(scale_factor, (tuple, list)):
|
||||
self.scale_factors = [int(s) for s in scale_factor]
|
||||
else:
|
||||
sf = int(scale_factor)
|
||||
self.scale_factors = [sf]
|
||||
|
||||
def forward(self, x):
|
||||
for dim_offset, sf in enumerate(self.scale_factors):
|
||||
if sf > 1:
|
||||
x = x.repeat_interleave(sf, dim=2 + dim_offset)
|
||||
return x
|
||||
|
||||
|
||||
def _make_tiny_vae_config():
|
||||
"""Tiny VAE config for testing."""
|
||||
return dict(
|
||||
base_dim=8,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2],
|
||||
num_res_blocks=1,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[False],
|
||||
dropout=0.0,
|
||||
input_channels=3,
|
||||
)
|
||||
|
||||
|
||||
def _make_medium_vae_config():
|
||||
"""Medium VAE config: base_dim=32, z_dim=8."""
|
||||
return dict(
|
||||
base_dim=32,
|
||||
z_dim=8,
|
||||
dim_mult=[1, 2, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[False, True],
|
||||
dropout=0.0,
|
||||
input_channels=3,
|
||||
)
|
||||
|
||||
|
||||
def _prepare_vae_for_onnx(vae):
|
||||
"""Replace non-ONNX-exportable modules in the VAE."""
|
||||
import diffusers.models.autoencoders.autoencoder_kl_qwenimage as vae_mod
|
||||
|
||||
def _replace(module):
|
||||
for name, child in module.named_children():
|
||||
if isinstance(child, vae_mod.QwenImageUpsample):
|
||||
setattr(module, name, _OnnxFriendlyUpsample(child.scale_factor))
|
||||
else:
|
||||
_replace(child)
|
||||
|
||||
_replace(vae)
|
||||
return vae
|
||||
|
||||
|
||||
class _VAEDecoderWrapper(nn.Module):
|
||||
def __init__(self, vae):
|
||||
super().__init__()
|
||||
self.vae = vae
|
||||
|
||||
def forward(self, z):
|
||||
return self.vae.decode(z).sample
|
||||
|
||||
|
||||
def _export_and_simplify(wrapper, inputs, input_names, output_names):
|
||||
"""Export model to ONNX and simplify with onnxsim."""
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
try:
|
||||
torch.onnx.export(
|
||||
wrapper,
|
||||
inputs,
|
||||
tmp_path,
|
||||
opset_version=20,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamo=False,
|
||||
)
|
||||
m = onnx.load(tmp_path)
|
||||
m_sim, check = onnxsim.simplify(m)
|
||||
assert check, "onnxsim simplification failed"
|
||||
onnx.save(m_sim, tmp_path)
|
||||
return tmp_path
|
||||
except Exception:
|
||||
os.unlink(tmp_path)
|
||||
raise
|
||||
|
||||
|
||||
def _run_vae_test(config, atol):
|
||||
"""Export VAE decoder to ONNX, run through luminal, compare."""
|
||||
from diffusers import AutoencoderKLQwenImage
|
||||
|
||||
import luminal
|
||||
|
||||
backend = os.environ.get("LUMINAL_BACKEND", "native")
|
||||
vae = AutoencoderKLQwenImage(**config).eval()
|
||||
vae = _prepare_vae_for_onnx(vae)
|
||||
|
||||
wrapper = _VAEDecoderWrapper(vae).eval()
|
||||
latents = torch.randn(1, config["z_dim"], 1, 4, 4)
|
||||
|
||||
with torch.no_grad():
|
||||
ref = wrapper(latents)
|
||||
|
||||
onnx_path = _export_and_simplify(wrapper, (latents,), ["latents"], ["output"])
|
||||
try:
|
||||
graph = luminal.process_onnx(onnx_path, backend)
|
||||
graph.set_input("latents", latents.flatten().tolist())
|
||||
graph.run()
|
||||
out_data = graph.get_output("output")
|
||||
out = torch.tensor(out_data, dtype=torch.float32).reshape(ref.shape)
|
||||
finally:
|
||||
os.unlink(onnx_path)
|
||||
|
||||
assert torch.allclose(out, ref, atol=atol), (
|
||||
f"max_diff={torch.max(torch.abs(out - ref)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_qwen_image_transformer_tiny():
|
||||
"""Tiny QwenImage transformer: 1 layer, 4 heads, dim=64."""
|
||||
_run_transformer_test(_make_tiny_transformer_config(), atol=1e-4)
|
||||
|
||||
|
||||
def test_qwen_image_transformer_small():
|
||||
"""Small QwenImage transformer: 2 layers, 8 heads, dim=256."""
|
||||
_run_transformer_test(_make_small_transformer_config(), atol=1e-4)
|
||||
|
||||
|
||||
def test_qwen_image_transformer_medium():
|
||||
"""Medium QwenImage transformer: 4 layers, 8 heads, dim=512."""
|
||||
_run_transformer_test(_make_medium_transformer_config(), atol=1e-4)
|
||||
|
||||
|
||||
def test_qwen_image_transformer_full():
|
||||
"""Full QwenImage transformer (production defaults)."""
|
||||
from diffusers.models import QwenImageTransformer2DModel
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
model = QwenImageTransformer2DModel().eval()
|
||||
config = {k: v for k, v in dict(model.config).items() if not k.startswith("_")}
|
||||
|
||||
wrapper = TransformerONNXWrapper(model, [(1, 2, 2)], txt_seq_len=3).eval()
|
||||
wrapper_compiled = torch.compile(wrapper, backend=luminal_backend)
|
||||
|
||||
hidden = torch.randn(1, 4, config["in_channels"])
|
||||
encoder_hs = torch.randn(1, 3, config["joint_attention_dim"])
|
||||
timestep = torch.tensor([1.0])
|
||||
|
||||
with torch.no_grad():
|
||||
ref = wrapper(hidden, encoder_hs, timestep)
|
||||
out = wrapper_compiled(hidden, encoder_hs, timestep)
|
||||
|
||||
assert torch.allclose(out, ref, atol=1e-4), (
|
||||
f"max_diff={torch.max(torch.abs(out - ref)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
def test_qwen_image_vae_decoder_tiny():
|
||||
"""Tiny QwenImage VAE decoder: base_dim=8, z_dim=4."""
|
||||
_run_vae_test(_make_tiny_vae_config(), atol=1e-3)
|
||||
|
||||
|
||||
def test_qwen_image_vae_decoder_medium():
|
||||
"""Medium QwenImage VAE decoder: base_dim=32, z_dim=8."""
|
||||
_run_vae_test(_make_medium_vae_config(), atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Full production VAE -- expected to be slow/OOM")
|
||||
def test_qwen_image_vae_decoder_full():
|
||||
"""Full QwenImage VAE decoder (production defaults)."""
|
||||
from diffusers import AutoencoderKLQwenImage
|
||||
|
||||
config = dict(AutoencoderKLQwenImage().config)
|
||||
config = {k: v for k, v in config.items() if not k.startswith("_")}
|
||||
_run_vae_test(config, atol=1e-3)
|
||||
@@ -7,8 +7,8 @@ try:
|
||||
import maturin_import_hook
|
||||
from maturin_import_hook.settings import MaturinSettings
|
||||
|
||||
backend = os.getenv("LUMINAL_BACKEND", "native").lower()
|
||||
settings = MaturinSettings(features=["cuda"]) if backend == "cuda" else None
|
||||
use_cuda = os.getenv("LUMINAL_TEST_DEVICE", "cpu").lower() == "cuda"
|
||||
settings = MaturinSettings(features=["cuda"]) if use_cuda else None
|
||||
maturin_import_hook.install(settings=settings)
|
||||
except ImportError:
|
||||
pass # Hook not available, rebuilds will be manual
|
||||
@@ -22,23 +22,17 @@ torch.set_float32_matmul_precision("highest")
|
||||
|
||||
@pytest.fixture
|
||||
def device() -> torch.device:
|
||||
backend = os.getenv("LUMINAL_BACKEND", "native").lower()
|
||||
return torch.device("cuda") if backend == "cuda" else torch.device("cpu")
|
||||
if (
|
||||
os.getenv("LUMINAL_TEST_DEVICE", "cpu").lower() == "cuda"
|
||||
and torch.cuda.is_available()
|
||||
):
|
||||
return torch.device("cuda")
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="function")
|
||||
def reset_torch_dynamo():
|
||||
# We need this for two reasons
|
||||
# 1. Some of our casts tests use the same model, but those graph have some state to them
|
||||
# and the cache will return old models
|
||||
# 2. The cache adds a large preformace hit to the test suite
|
||||
torch._dynamo.config.cache_size_limit = 1
|
||||
# Disable silent fallback to eager mode so backend errors surface as test failures
|
||||
torch._dynamo.config.suppress_errors = False
|
||||
"""Reset PyTorch Dynamo state after each test to prevent state leakage.
|
||||
|
||||
This fixture automatically runs after every test function to clear
|
||||
torch._dynamo's compilation cache, ensuring test isolation.
|
||||
"""
|
||||
yield # Test runs here
|
||||
yield
|
||||
torch._dynamo.reset()
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
"""Generate pre-computed artifacts for test_hf_llama38b_cached_onnx.
|
||||
|
||||
Run once:
|
||||
uv run python tests/generate_llama38b_artifacts.py
|
||||
|
||||
Produces:
|
||||
tests/llama38b.onnx — ONNX export of Llama 3.1-8B
|
||||
tests/llama38b_ref_logits.pt — reference logits for input_ids=[1,2,3,4]
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
SCRIPT_DIR = Path(__file__).resolve().parent
|
||||
ONNX_PATH = SCRIPT_DIR / "llama38b.onnx"
|
||||
LOGITS_PATH = SCRIPT_DIR / "llama38b_ref_logits.pt"
|
||||
|
||||
INPUT_IDS = torch.tensor([[1, 2, 3, 4]])
|
||||
|
||||
|
||||
def main():
|
||||
config = AutoConfig.from_pretrained("NousResearch/Meta-Llama-3.1-8B-Instruct")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
print("Loading model on CPU...")
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
"NousResearch/Meta-Llama-3.1-8B-Instruct",
|
||||
config=config,
|
||||
torch_dtype=torch.float32,
|
||||
).eval()
|
||||
|
||||
print("Computing reference logits...")
|
||||
with torch.no_grad():
|
||||
ref_logits = model(INPUT_IDS).logits.clone()
|
||||
print(f"Reference logits shape: {ref_logits.shape}")
|
||||
|
||||
print(f"Saving reference logits to {LOGITS_PATH}")
|
||||
torch.save(ref_logits, LOGITS_PATH)
|
||||
|
||||
print(f"Exporting ONNX to {ONNX_PATH}")
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(INPUT_IDS,),
|
||||
str(ONNX_PATH),
|
||||
opset_version=20,
|
||||
input_names=["input_ids"],
|
||||
output_names=["logits"],
|
||||
)
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -7,7 +7,7 @@ Produces:
|
||||
tests/llama38b.pt2 — torch.export of Llama 3.1-8B
|
||||
tests/llama38b_weights.safetensors — model weights
|
||||
tests/llama38b_ref_logits.pt — reference logits for input_ids=[1,2,3,4]
|
||||
(shared with ONNX artifact script)
|
||||
(shared with PT2 artifact script)
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
@@ -36,7 +36,7 @@ def main():
|
||||
torch_dtype=torch.float32,
|
||||
).eval()
|
||||
|
||||
# Generate reference logits (shared with ONNX artifact script)
|
||||
# Generate reference logits (shared with PT2 artifact script)
|
||||
if not LOGITS_PATH.exists():
|
||||
print("Computing reference logits...")
|
||||
with torch.no_grad():
|
||||
|
||||
34
crates/luminal_python/tests/test_capsule_validation.py
Normal file
34
crates/luminal_python/tests/test_capsule_validation.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""FFI-boundary tests for process_pt2's capsule validation.
|
||||
|
||||
Deviates from the standard `torch.compile(..., backend=luminal_backend)`
|
||||
pattern in CLAUDE.md because the thing under test is the capsule-name
|
||||
check itself, not a feature behavior. Exercising it through torch.compile
|
||||
would only cover the happy path (`_native_factory_capsule` produces a
|
||||
correctly-named capsule, so validation passes trivially).
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
|
||||
import pytest
|
||||
|
||||
from luminal import process_pt2
|
||||
|
||||
|
||||
def _new_capsule(name: bytes):
|
||||
PyCapsule_New = ctypes.pythonapi.PyCapsule_New
|
||||
PyCapsule_New.restype = ctypes.py_object
|
||||
PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
|
||||
dummy = ctypes.c_void_p(0xDEADBEEF)
|
||||
return PyCapsule_New(ctypes.byref(dummy), name, None)
|
||||
|
||||
|
||||
def test_process_pt2_rejects_capsule_with_wrong_name():
|
||||
bogus = _new_capsule(b"not.luminal.backend_factory")
|
||||
with pytest.raises(ValueError, match="luminal.backend_factory"):
|
||||
process_pt2("/dev/null", "/dev/null", bogus, None)
|
||||
|
||||
|
||||
def test_process_pt2_rejects_capsule_with_no_name():
|
||||
unnamed = _new_capsule(None)
|
||||
with pytest.raises(ValueError, match="luminal.backend_factory"):
|
||||
process_pt2("/dev/null", "/dev/null", unnamed, None)
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from test_models import (
|
||||
@@ -8,6 +9,8 @@ from test_models import (
|
||||
AddTestModel,
|
||||
# And model
|
||||
AndTestModel,
|
||||
# Dtype round-trip model
|
||||
SelfAddModel,
|
||||
CastBoolToFloatModel,
|
||||
# Cast models
|
||||
CastDoubleToFloatModel,
|
||||
@@ -213,11 +216,120 @@ from test_models import (
|
||||
WhereWithConstantModel,
|
||||
# Xor model
|
||||
XorTestModel,
|
||||
ArgsortStableDuplicatesModel,
|
||||
# Conv models
|
||||
Conv1dNoPadModel,
|
||||
Conv1dSamePadModel,
|
||||
Conv1dBiasModel,
|
||||
Conv2dNoPadModel,
|
||||
Conv2dSamePadModel,
|
||||
Conv2dBiasModel,
|
||||
Conv2dStrideModel,
|
||||
Conv2dDilationModel,
|
||||
Conv3dSamePadModel,
|
||||
DepthwiseConv1dModel,
|
||||
DepthwiseConv2dModel,
|
||||
DepthwiseMultiplierConv2dModel,
|
||||
GroupedConv2dModel,
|
||||
GroupedConv2dGroups3Model,
|
||||
MambaConvBlockModel,
|
||||
TinyMoERoutingModel,
|
||||
)
|
||||
|
||||
import luminal.pt2 as luminal_pt2
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
def test_backend_options_forwarded_to_process_pt2(
|
||||
monkeypatch: pytest.MonkeyPatch, device: torch.device
|
||||
):
|
||||
captured = {}
|
||||
|
||||
def fake_process_pt2(
|
||||
pt2_path,
|
||||
weights_path,
|
||||
factory_capsule,
|
||||
weight_device_ptrs=None,
|
||||
options=None,
|
||||
):
|
||||
captured["pt2_path"] = pt2_path
|
||||
captured["weights_path"] = weights_path
|
||||
captured["factory_capsule"] = factory_capsule
|
||||
captured["weight_device_ptrs"] = weight_device_ptrs
|
||||
captured["options"] = options
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(luminal_pt2, "process_pt2", fake_process_pt2)
|
||||
monkeypatch.setattr(
|
||||
luminal_pt2, "_load_cpu_weights", lambda compiled, weights: None
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
luminal_pt2,
|
||||
"CompiledModel",
|
||||
lambda compiled, weight_refs=None: lambda x: x + x,
|
||||
)
|
||||
|
||||
model: torch.nn.Module = AddTestModel().to(device)
|
||||
compiled: Callable = torch.compile(
|
||||
model,
|
||||
backend=luminal_backend,
|
||||
options={"search_iterations": 3},
|
||||
)
|
||||
|
||||
x: torch.Tensor = torch.rand((5, 5), device=device)
|
||||
compiled(x)
|
||||
|
||||
assert captured["weights_path"] == ""
|
||||
assert type(captured["factory_capsule"]).__name__ == "PyCapsule"
|
||||
assert captured["options"] == {"search_iterations": 3}
|
||||
assert isinstance(captured["weight_device_ptrs"], dict)
|
||||
|
||||
|
||||
def test_backend_options_unknown_key_raises(device: torch.device):
|
||||
model: torch.nn.Module = AddTestModel().to(device)
|
||||
compiled: Callable = torch.compile(
|
||||
model,
|
||||
backend=luminal_backend,
|
||||
options={"unknown_option": 1},
|
||||
)
|
||||
|
||||
x: torch.Tensor = torch.rand((5, 5), device=device)
|
||||
with pytest.raises(torch._dynamo.exc.BackendCompilerFailed) as exc_info:
|
||||
compiled(x)
|
||||
assert isinstance(exc_info.value.inner_exception, ValueError)
|
||||
assert "Unsupported luminal backend option" in str(exc_info.value.inner_exception)
|
||||
|
||||
|
||||
def test_backend_options_non_dict_raises(device: torch.device):
|
||||
model: torch.nn.Module = AddTestModel().to(device)
|
||||
compiled: Callable = torch.compile(
|
||||
model,
|
||||
backend=luminal_backend,
|
||||
options=["pt2"],
|
||||
)
|
||||
|
||||
x: torch.Tensor = torch.rand((5, 5), device=device)
|
||||
with pytest.raises(torch._dynamo.exc.BackendCompilerFailed) as exc_info:
|
||||
compiled(x)
|
||||
assert isinstance(exc_info.value.inner_exception, TypeError)
|
||||
assert "options must be a dict" in str(exc_info.value.inner_exception)
|
||||
|
||||
|
||||
def test_backend_options_bad_search_iterations_type_raises(device: torch.device):
|
||||
model: torch.nn.Module = AddTestModel().to(device)
|
||||
compiled: Callable = torch.compile(
|
||||
model,
|
||||
backend=luminal_backend,
|
||||
options={"search_iterations": "fast"},
|
||||
)
|
||||
|
||||
x: torch.Tensor = torch.rand((5, 5), device=device)
|
||||
with pytest.raises(torch._dynamo.exc.BackendCompilerFailed) as exc_info:
|
||||
compiled(x)
|
||||
assert isinstance(exc_info.value.inner_exception, TypeError)
|
||||
assert "search_iterations" in str(exc_info.value.inner_exception)
|
||||
|
||||
|
||||
def test_add(device: torch.device):
|
||||
add_test_model: torch.nn.Module = AddTestModel().to(device)
|
||||
add_test_mode_compiled: Callable = torch.compile(
|
||||
@@ -416,9 +528,9 @@ def test_transpose_square_matrix(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Constant Node Tests ==========
|
||||
# ========== PT2 Constant Node Tests ==========
|
||||
# These tests verify the parse_constant_node function in ops_parse.rs
|
||||
# which handles ONNX Constant nodes (nodes with embedded data in attributes)
|
||||
# which handles PT2 Constant nodes (nodes with embedded data in attributes)
|
||||
|
||||
|
||||
def test_constant_scalar_float(device: torch.device):
|
||||
@@ -541,9 +653,9 @@ def test_constant_multiple_in_graph(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Cast Node Tests ==========
|
||||
# ========== PT2 Cast Node Tests ==========
|
||||
# These tests verify the parse_cast_node function in ops_parse.rs
|
||||
# which handles ONNX Cast nodes (type conversion operations)
|
||||
# which handles PT2 Cast nodes (type conversion operations)
|
||||
|
||||
|
||||
def test_cast_double_to_float(device: torch.device):
|
||||
@@ -630,7 +742,7 @@ def test_cast_scalar_value(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Mod Node Tests ==========
|
||||
# ========== PT2 Mod Node Tests ==========
|
||||
|
||||
|
||||
def test_mod(device: torch.device):
|
||||
@@ -663,7 +775,7 @@ def test_mod_by_constant(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Floor Node Tests ==========
|
||||
# ========== PT2 Floor Node Tests ==========
|
||||
|
||||
|
||||
def test_floor(device: torch.device):
|
||||
@@ -696,7 +808,7 @@ def test_floor_in_expression(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Ceil Node Tests ==========
|
||||
# ========== PT2 Ceil Node Tests ==========
|
||||
|
||||
|
||||
def test_ceil(device: torch.device):
|
||||
@@ -729,7 +841,7 @@ def test_ceil_in_expression(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Reshape Node Tests ==========
|
||||
# ========== PT2 Reshape Node Tests ==========
|
||||
# These tests verify parse_reshape_node and parse_shape_node in ops_parse.rs
|
||||
|
||||
|
||||
@@ -843,7 +955,7 @@ def test_shape_reshape_view_batch(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Less Node Tests ==========
|
||||
# ========== PT2 Less Node Tests ==========
|
||||
# These tests verify parse_less_node in ops_parse.rs
|
||||
|
||||
|
||||
@@ -877,7 +989,7 @@ def test_less_with_constant(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Equal Node Tests ==========
|
||||
# ========== PT2 Equal Node Tests ==========
|
||||
# These tests verify parse_equal_node in ops_parse/binary.rs
|
||||
|
||||
|
||||
@@ -911,7 +1023,7 @@ def test_equal_with_constant(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Gather Node Tests ==========
|
||||
# ========== PT2 Gather Node Tests ==========
|
||||
# These tests verify parse_gather_node in ops_parse.rs
|
||||
|
||||
|
||||
@@ -975,7 +1087,7 @@ def test_gather_constant_fold(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Squeeze Node Tests ==========
|
||||
# ========== PT2 Squeeze Node Tests ==========
|
||||
# These tests verify parse_squeeze_node in ops_parse.rs
|
||||
|
||||
|
||||
@@ -1029,7 +1141,7 @@ def test_squeeze_in_expression(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX ReduceSum Node Tests ==========
|
||||
# ========== PT2 ReduceSum Node Tests ==========
|
||||
|
||||
|
||||
def test_reduce_sum_axis0(device: torch.device):
|
||||
@@ -1104,7 +1216,7 @@ def test_reduce_sum_in_expression(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ========== ONNX ReduceMax Node Tests ==========
|
||||
# ========== PT2 ReduceMax Node Tests ==========
|
||||
|
||||
|
||||
def test_reduce_max_axis0(device: torch.device):
|
||||
@@ -1179,7 +1291,7 @@ def test_reduce_max_in_expression(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ========== ONNX ReduceMin Node Tests ==========
|
||||
# ========== PT2 ReduceMin Node Tests ==========
|
||||
# These tests verify parse_reduce_min_node in ops_parse/reduction.rs
|
||||
|
||||
|
||||
@@ -1255,7 +1367,7 @@ def test_reduce_min_in_expression(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ========== ONNX ReduceMean Node Tests ==========
|
||||
# ========== PT2 ReduceMean Node Tests ==========
|
||||
# These tests verify parse_reduce_mean_node in ops_parse/reduction.rs
|
||||
|
||||
|
||||
@@ -1331,7 +1443,7 @@ def test_reduce_mean_in_expression(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ========== ONNX Pow Node Tests ==========
|
||||
# ========== PT2 Pow Node Tests ==========
|
||||
# These tests verify parse_pow_node in ops_parse/binary.rs
|
||||
|
||||
|
||||
@@ -1365,7 +1477,7 @@ def test_pow_by_constant(device: torch.device):
|
||||
assert torch.allclose(output, original, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
# ========== ONNX Where Node Tests ==========
|
||||
# ========== PT2 Where Node Tests ==========
|
||||
# These tests verify parse_where_node in ops_parse/binary.rs
|
||||
|
||||
|
||||
@@ -1403,7 +1515,7 @@ def test_where_with_constant(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Max Node Tests ==========
|
||||
# ========== PT2 Max Node Tests ==========
|
||||
# These tests verify parse_max_node in ops_parse/binary.rs
|
||||
|
||||
|
||||
@@ -1427,7 +1539,7 @@ def test_max_with_constant(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Min Node Tests ==========
|
||||
# ========== PT2 Min Node Tests ==========
|
||||
# These tests verify parse_min_node in ops_parse/binary.rs
|
||||
|
||||
|
||||
@@ -1451,7 +1563,7 @@ def test_min_with_constant(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Concat Node Tests ==========
|
||||
# ========== PT2 Concat Node Tests ==========
|
||||
# These tests verify parse_concat_node in ops_parse/movement.rs
|
||||
|
||||
|
||||
@@ -1495,7 +1607,7 @@ def test_concat_in_expression(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x))
|
||||
|
||||
|
||||
# ========== ONNX Softmax Node Tests ==========
|
||||
# ========== PT2 Softmax Node Tests ==========
|
||||
# These tests verify parse_softmax_node in ops_parse/unary.rs
|
||||
|
||||
|
||||
@@ -1519,7 +1631,7 @@ def test_softmax_dim0(device: torch.device):
|
||||
assert torch.allclose(output, original, atol=1e-5)
|
||||
|
||||
|
||||
# ========== ONNX LessOrEqual Node Tests ==========
|
||||
# ========== PT2 LessOrEqual Node Tests ==========
|
||||
|
||||
|
||||
def test_less_or_equal(device: torch.device):
|
||||
@@ -1542,7 +1654,7 @@ def test_less_or_equal_with_constant(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX GreaterOrEqual Node Tests ==========
|
||||
# ========== PT2 GreaterOrEqual Node Tests ==========
|
||||
|
||||
|
||||
def test_greater_or_equal(device: torch.device):
|
||||
@@ -1565,7 +1677,7 @@ def test_greater_or_equal_with_constant(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Not Node Tests ==========
|
||||
# ========== PT2 Not Node Tests ==========
|
||||
|
||||
|
||||
def test_not(device: torch.device):
|
||||
@@ -1578,7 +1690,7 @@ def test_not(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX And Node Tests ==========
|
||||
# ========== PT2 And Node Tests ==========
|
||||
|
||||
|
||||
def test_and(device: torch.device):
|
||||
@@ -1591,7 +1703,7 @@ def test_and(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Or Node Tests ==========
|
||||
# ========== PT2 Or Node Tests ==========
|
||||
|
||||
|
||||
def test_or(device: torch.device):
|
||||
@@ -1604,7 +1716,7 @@ def test_or(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Xor Node Tests ==========
|
||||
# ========== PT2 Xor Node Tests ==========
|
||||
|
||||
|
||||
def test_xor(device: torch.device):
|
||||
@@ -1617,7 +1729,7 @@ def test_xor(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Trilu Node Tests ==========
|
||||
# ========== PT2 Trilu Node Tests ==========
|
||||
|
||||
|
||||
def test_tril(device: torch.device):
|
||||
@@ -1812,11 +1924,11 @@ def test_mlp_block(device: torch.device):
|
||||
assert torch.allclose(output, original, atol=1e-5)
|
||||
|
||||
|
||||
# ========== ONNX GatherElements Node Tests ==========
|
||||
# ========== PT2 GatherElements Node Tests ==========
|
||||
|
||||
|
||||
def test_gather_elements(device: torch.device):
|
||||
"""Tests GatherElements op (torch.gather → ONNX GatherElements)."""
|
||||
"""Tests GatherElements op (torch.gather → PT2 GatherElements)."""
|
||||
model: torch.nn.Module = GatherElementsTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.rand((2, 3), device=device)
|
||||
@@ -1831,18 +1943,18 @@ def test_gather_elements_large(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x))
|
||||
|
||||
|
||||
# ========== ONNX Expand Node Tests ==========
|
||||
# ========== PT2 Expand Node Tests ==========
|
||||
|
||||
|
||||
def test_expand(device: torch.device):
|
||||
"""Tests Expand op (tensor.expand → ONNX Expand)."""
|
||||
"""Tests Expand op (tensor.expand → PT2 Expand)."""
|
||||
model: torch.nn.Module = ExpandTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.rand((1, 4), device=device)
|
||||
assert torch.allclose(model_compiled(x), model(x))
|
||||
|
||||
|
||||
# ========== ONNX IsNaN Node Tests ==========
|
||||
# ========== PT2 IsNaN Node Tests ==========
|
||||
|
||||
|
||||
def test_isnan(device: torch.device):
|
||||
@@ -1853,29 +1965,29 @@ def test_isnan(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x))
|
||||
|
||||
|
||||
# ========== ONNX LayerNormalization Node Tests ==========
|
||||
# ========== PT2 LayerNormalization Node Tests ==========
|
||||
|
||||
|
||||
def test_layernorm(device: torch.device):
|
||||
"""Tests LayerNormalization op (nn.LayerNorm → ONNX LayerNormalization)."""
|
||||
"""Tests LayerNormalization op (nn.LayerNorm → PT2 LayerNormalization)."""
|
||||
model: torch.nn.Module = LayerNormTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.rand((2, 4), device=device)
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ========== ONNX Gemm Node Tests ==========
|
||||
# ========== PT2 Gemm Node Tests ==========
|
||||
|
||||
|
||||
def test_gemm(device: torch.device):
|
||||
"""Tests Gemm op (nn.Linear → ONNX Gemm)."""
|
||||
"""Tests Gemm op (nn.Linear → PT2 Gemm)."""
|
||||
model: torch.nn.Module = GemmTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.rand((3, 4), device=device)
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ========== ONNX Erf Node Tests ==========
|
||||
# ========== PT2 Erf Node Tests ==========
|
||||
|
||||
|
||||
def test_erf(device: torch.device):
|
||||
@@ -1888,7 +2000,7 @@ def test_erf(device: torch.device):
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
# ========== ONNX Slice Node Tests ==========
|
||||
# ========== PT2 Slice Node Tests ==========
|
||||
|
||||
|
||||
def test_slice_1d(device: torch.device):
|
||||
@@ -1907,7 +2019,7 @@ def test_slice_2d(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x))
|
||||
|
||||
|
||||
# ========== ONNX Split Node Tests ==========
|
||||
# ========== PT2 Split Node Tests ==========
|
||||
|
||||
|
||||
def test_split(device: torch.device):
|
||||
@@ -1918,7 +2030,55 @@ def test_split(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x))
|
||||
|
||||
|
||||
# ========== ONNX TopK Node Tests ==========
|
||||
# ========== Argsort / MoE Routing Tests ==========
|
||||
|
||||
|
||||
def test_argsort_stable_duplicates(device: torch.device):
|
||||
"""Duplicate values should follow stable lower-index-first tie-breaking."""
|
||||
model: torch.nn.Module = ArgsortStableDuplicatesModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.tensor(
|
||||
[[2.0, 1.0, 1.0, 3.0]],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert output.dtype == torch.int32
|
||||
assert torch.equal(output, original.to(torch.int32))
|
||||
|
||||
|
||||
def test_tiny_moe_routing(device: torch.device):
|
||||
"""Focused proof for build MoE routing support."""
|
||||
model: torch.nn.Module = TinyMoERoutingModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
scores = torch.tensor(
|
||||
[[0.1, 0.9, 0.4, 0.7], [0.6, -0.8, 0.95, 0.2]],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
expected = model(scores)
|
||||
output = model_compiled(scores)
|
||||
|
||||
expected_dtypes = (
|
||||
torch.int32,
|
||||
torch.float32,
|
||||
torch.int32,
|
||||
torch.bool,
|
||||
torch.int32,
|
||||
torch.float32,
|
||||
)
|
||||
for actual, eager, expected_dtype in zip(output, expected, expected_dtypes):
|
||||
assert actual.dtype == expected_dtype
|
||||
eager = eager.to(actual.dtype)
|
||||
if actual.dtype.is_floating_point:
|
||||
assert torch.allclose(actual, eager)
|
||||
else:
|
||||
assert torch.equal(actual, eager)
|
||||
|
||||
|
||||
# ========== PT2 TopK Node Tests ==========
|
||||
|
||||
|
||||
def test_topk_values(device: torch.device):
|
||||
@@ -1937,7 +2097,7 @@ def test_topk_indices(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x))
|
||||
|
||||
|
||||
# ========== ONNX OneHot Node Tests ==========
|
||||
# ========== PT2 OneHot Node Tests ==========
|
||||
|
||||
|
||||
def test_onehot(device: torch.device):
|
||||
@@ -1984,3 +2144,209 @@ def test_scatter_nd(device: torch.device):
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== Dtype Round-Trip Tests ==========
|
||||
|
||||
|
||||
def test_dtype_float16(device: torch.device):
|
||||
"""Verify float16 input produces float16 output with correct values."""
|
||||
model: torch.nn.Module = SelfAddModel()
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.tensor(
|
||||
[1.0, 2.0, 3.0, 4.0], dtype=torch.float16, device=device
|
||||
)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert output.dtype == torch.float16, f"Expected float16 output, got {output.dtype}"
|
||||
assert torch.allclose(output.float(), original.float())
|
||||
|
||||
|
||||
def test_dtype_float32(device: torch.device):
|
||||
"""Verify float32 input produces float32 output (baseline)."""
|
||||
model: torch.nn.Module = SelfAddModel()
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.tensor(
|
||||
[1.0, 2.0, 3.0, 4.0], dtype=torch.float32, device=device
|
||||
)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert output.dtype == torch.float32, f"Expected float32 output, got {output.dtype}"
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== Convolution Tests ==========
|
||||
|
||||
|
||||
def _run_conv1d_no_pad(device: torch.device):
|
||||
"""Conv1d without padding: output length = input - (kernel-1)."""
|
||||
model: torch.nn.Module = Conv1dNoPadModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(2, 8, 32, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_conv1d_no_pad(device: torch.device):
|
||||
_run_conv1d_no_pad(device)
|
||||
|
||||
|
||||
def test_conv1d_same_pad(device: torch.device):
|
||||
"""Conv1d with padding=1: output length == input length."""
|
||||
model: torch.nn.Module = Conv1dSamePadModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(2, 8, 32, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_conv1d_bias(device: torch.device):
|
||||
"""Conv1d with bias term."""
|
||||
model: torch.nn.Module = Conv1dBiasModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(2, 8, 32, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def _run_conv2d_no_pad(device: torch.device):
|
||||
"""Conv2d without padding: output spatial = input - (kernel-1)."""
|
||||
model: torch.nn.Module = Conv2dNoPadModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(1, 3, 8, 8, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_conv2d_no_pad(device: torch.device):
|
||||
_run_conv2d_no_pad(device)
|
||||
|
||||
|
||||
def test_conv2d_same_pad(device: torch.device):
|
||||
"""Conv2d with padding=1: output spatial == input spatial."""
|
||||
model: torch.nn.Module = Conv2dSamePadModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(1, 3, 8, 8, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_conv2d_bias(device: torch.device):
|
||||
"""Conv2d with bias term."""
|
||||
model: torch.nn.Module = Conv2dBiasModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(1, 3, 8, 8, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_conv2d_stride(device: torch.device):
|
||||
"""Conv2d with stride=2: output spatial dims halved."""
|
||||
model: torch.nn.Module = Conv2dStrideModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(1, 3, 8, 8, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def _run_conv2d_dilation(device: torch.device):
|
||||
"""Conv2d with dilation=2 preserves the expected spatial shape and values."""
|
||||
model: torch.nn.Module = Conv2dDilationModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(2, 8, 17, 19, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_conv2d_dilation(device: torch.device):
|
||||
_run_conv2d_dilation(device)
|
||||
|
||||
|
||||
def _run_conv3d_same_pad(device: torch.device):
|
||||
"""Conv3d exercises the spatial=3 unfold/permute/split path."""
|
||||
model: torch.nn.Module = Conv3dSamePadModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(2, 4, 6, 7, 8, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-3)
|
||||
|
||||
|
||||
def test_conv3d_same_pad(device: torch.device):
|
||||
_run_conv3d_same_pad(device)
|
||||
|
||||
|
||||
def test_depthwise_conv1d(device: torch.device):
|
||||
"""Depthwise Conv1d with groups=in_channels, as used in Mamba."""
|
||||
model: torch.nn.Module = DepthwiseConv1dModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(2, 16, 32, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_depthwise_conv2d(device: torch.device):
|
||||
"""Depthwise Conv2d with groups=in_channels."""
|
||||
model: torch.nn.Module = DepthwiseConv2dModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(1, 8, 8, 8, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def _run_depthwise_multiplier_conv2d(device: torch.device):
|
||||
"""Depthwise Conv2d with multiplier > 1 should preserve both output channels per input channel."""
|
||||
model: torch.nn.Module = DepthwiseMultiplierConv2dModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(2, 8, 9, 9, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_depthwise_multiplier_conv2d(device: torch.device):
|
||||
_run_depthwise_multiplier_conv2d(device)
|
||||
|
||||
|
||||
def test_grouped_conv2d(device: torch.device):
|
||||
"""Conv2d with groups=4 (grouped, not depthwise)."""
|
||||
model: torch.nn.Module = GroupedConv2dModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(1, 16, 8, 8, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def _run_grouped_conv2d_groups3_batch4(device: torch.device):
|
||||
"""Grouped Conv2d with groups=3 and batch>1 exercises the pre-pad + slice path."""
|
||||
model: torch.nn.Module = GroupedConv2dGroups3Model().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(4, 12, 11, 9, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-3)
|
||||
|
||||
|
||||
def test_grouped_conv2d_groups3_batch4(device: torch.device):
|
||||
_run_grouped_conv2d_groups3_batch4(device)
|
||||
|
||||
|
||||
def test_mamba_conv_block(device: torch.device):
|
||||
"""Minimal Mamba-style block: depthwise Conv1d with causal gating (end-to-end)."""
|
||||
model: torch.nn.Module = MambaConvBlockModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(2, 64, 16, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
Tests individual Llama3 building blocks (RMSNorm, RoPE, SwiGLU, causal attention,
|
||||
full transformer block) and progressively larger HuggingFace LlamaForCausalLM configs
|
||||
through the PyTorch -> ONNX -> luminal pipeline via torch.compile.
|
||||
through the PyTorch -> Pt2 -> luminal pipeline via torch.compile.
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
@@ -66,12 +66,15 @@ def test_causal_self_attention(device: torch.device):
|
||||
|
||||
def test_llama_transformer_block(device: torch.device):
|
||||
"""Test full Llama transformer block: RMSNorm -> Attn -> Residual -> RMSNorm -> MLP -> Residual."""
|
||||
torch.manual_seed(0)
|
||||
model: torch.nn.Module = LlamaTransformerBlockModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.rand((1, 4, 32), device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
assert torch.allclose(output, original, atol=1e-3), (
|
||||
f"max_diff={torch.max(torch.abs(output - original)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
# ========== HuggingFace LlamaForCausalLM Tests ==========
|
||||
@@ -362,6 +365,55 @@ def test_hf_llama3_large_full(device: torch.device):
|
||||
)
|
||||
|
||||
|
||||
# ========== Dynamic Dimension Tests ==========
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="CUDA graph in-place update test — requires CUDA",
|
||||
)
|
||||
def test_dynamic_dim_reuse_no_recompile(device: torch.device):
|
||||
"""Compile once with dynamic shapes, execute with varying seq lengths.
|
||||
|
||||
Validates that the luminal runtime correctly handles dynamic dimension
|
||||
changes without recompilation. This is the core scenario optimized by
|
||||
removing the unnecessary CUDA graph rebuild on dyn_map changes: a single
|
||||
compiled graph handles multiple sequence lengths via in-place parameter
|
||||
updates rather than rebuilding the entire CUDA graph each step.
|
||||
"""
|
||||
from luminal.pt2 import compile as luminal_compile
|
||||
|
||||
class DynamicSeqModel(torch.nn.Module):
|
||||
"""Embedding + linear projection with variable-length integer input."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.embed = torch.nn.Embedding(256, 64)
|
||||
self.proj = torch.nn.Linear(64, 64)
|
||||
|
||||
def forward(self, x):
|
||||
return self.proj(self.embed(x))
|
||||
|
||||
model = DynamicSeqModel().eval().to(device)
|
||||
|
||||
# Compile once with dynamic seq dim (auto-detected for integer inputs).
|
||||
# Factory capsule is auto-detected from example.device.
|
||||
example = torch.tensor([[1, 2, 3, 4]], device=device)
|
||||
compiled = luminal_compile(model, example, search_iterations=5)
|
||||
|
||||
# Execute with multiple different seq lengths — each call reuses the
|
||||
# same compiled graph, updating dynamic dims in-place.
|
||||
for seq_len in [4, 5, 6, 7, 8]:
|
||||
input_ids = torch.tensor([list(range(1, seq_len + 1))], device=device)
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
out = compiled(input_ids)
|
||||
assert torch.allclose(out[0], ref, atol=1e-5), (
|
||||
f"seq_len={seq_len}: "
|
||||
f"max_diff={torch.max(torch.abs(out[0] - ref)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
|
||||
def test_hf_llama38b_full(device: torch.device):
|
||||
"""HuggingFace LlamaForCausalLM — full Llama-3.1-8B-Instruct with real pretrained weights.
|
||||
|
||||
@@ -3,6 +3,13 @@
|
||||
import torch
|
||||
|
||||
|
||||
class SelfAddModel(torch.nn.Module):
|
||||
"""Adds input to itself (x + x). Preserves input dtype."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + x
|
||||
|
||||
|
||||
class AddTestModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@@ -145,7 +152,7 @@ class TransposeInExpressionModel(torch.nn.Module):
|
||||
|
||||
|
||||
# ========== Constant Node Test Models ==========
|
||||
# These models test ONNX Constant node handling via inline tensor literals
|
||||
# These models test PT2 Constant node handling via inline tensor literals
|
||||
|
||||
|
||||
class ConstantScalarFloatModel(torch.nn.Module):
|
||||
@@ -284,7 +291,7 @@ class ConstantMultipleInGraphModel(torch.nn.Module):
|
||||
|
||||
|
||||
# ========== Cast Node Test Models ==========
|
||||
# These models test ONNX Cast node handling via .to(dtype) method
|
||||
# These models test PT2 Cast node handling via .to(dtype) method
|
||||
|
||||
|
||||
class CastDoubleToFloatModel(torch.nn.Module):
|
||||
@@ -387,7 +394,7 @@ class ModTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class ModByConstantModel(torch.nn.Module):
|
||||
"""Tests modulo with an inline constant tensor (ONNX Constant node)."""
|
||||
"""Tests modulo with an inline constant tensor (PT2 Constant node)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@@ -446,7 +453,7 @@ class CeilInExpressionModel(torch.nn.Module):
|
||||
|
||||
|
||||
# ========== Reshape Node Test Models ==========
|
||||
# These models test ONNX Reshape node handling in ops_parse.rs
|
||||
# These models test PT2 Reshape node handling in ops_parse.rs
|
||||
|
||||
|
||||
class ReshapeToFlatModel(torch.nn.Module):
|
||||
@@ -534,7 +541,7 @@ class ShapeReshapeKeepBatchModel(torch.nn.Module):
|
||||
|
||||
|
||||
# ========== Less Node Test Models ==========
|
||||
# These models test ONNX Less node handling in ops_parse.rs
|
||||
# These models test PT2 Less node handling in ops_parse.rs
|
||||
|
||||
|
||||
class LessTestModel(torch.nn.Module):
|
||||
@@ -560,7 +567,7 @@ class LessBroadcastModel(torch.nn.Module):
|
||||
|
||||
|
||||
class LessWithConstantModel(torch.nn.Module):
|
||||
"""Tests less-than against an inline constant (ONNX Constant + Less nodes)."""
|
||||
"""Tests less-than against an inline constant (PT2 Constant + Less nodes)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([0.25, 0.5, 0.75]).to(x.device)
|
||||
@@ -568,7 +575,7 @@ class LessWithConstantModel(torch.nn.Module):
|
||||
|
||||
|
||||
# ========== Gather Node Test Models ==========
|
||||
# These models test ONNX Gather node handling in ops_parse.rs
|
||||
# These models test PT2 Gather node handling in ops_parse.rs
|
||||
|
||||
|
||||
class Gather1DModel(torch.nn.Module):
|
||||
@@ -621,7 +628,7 @@ class GatherNegativeIndicesModel(torch.nn.Module):
|
||||
|
||||
|
||||
class GatherConstantFoldModel(torch.nn.Module):
|
||||
"""Tests Gather constant folding: both data and indices are ONNX Constant nodes."""
|
||||
"""Tests Gather constant folding: both data and indices are PT2 Constant nodes."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
data = torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0]).to(x.device)
|
||||
@@ -630,7 +637,7 @@ class GatherConstantFoldModel(torch.nn.Module):
|
||||
|
||||
|
||||
# ========== Squeeze Node Test Models ==========
|
||||
# These models test ONNX Squeeze node handling in ops_parse.rs
|
||||
# These models test PT2 Squeeze node handling in ops_parse.rs
|
||||
|
||||
|
||||
class SqueezeAxisModel(torch.nn.Module):
|
||||
@@ -1140,7 +1147,7 @@ class MaxTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class MaxWithConstantModel(torch.nn.Module):
|
||||
"""Tests element-wise maximum against an inline constant (ONNX Max + Constant nodes)."""
|
||||
"""Tests element-wise maximum against an inline constant (PT2 Max + Constant nodes)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([0.2, 0.4, 0.6, 0.8, 1.0]).to(x.device)
|
||||
@@ -1162,7 +1169,7 @@ class MinTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class MinWithConstantModel(torch.nn.Module):
|
||||
"""Tests element-wise minimum against an inline constant (ONNX Min + Constant nodes)."""
|
||||
"""Tests element-wise minimum against an inline constant (PT2 Min + Constant nodes)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([0.2, 0.4, 0.6, 0.8, 1.0]).to(x.device)
|
||||
@@ -1288,7 +1295,7 @@ class LessOrEqualTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class LessOrEqualWithConstantModel(torch.nn.Module):
|
||||
"""Tests less-than-or-equal against an inline constant (ONNX Constant + LessOrEqual nodes)."""
|
||||
"""Tests less-than-or-equal against an inline constant (PT2 Constant + LessOrEqual nodes)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([0.25, 0.5, 0.75]).to(x.device)
|
||||
@@ -1310,7 +1317,7 @@ class GreaterOrEqualTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class GreaterOrEqualWithConstantModel(torch.nn.Module):
|
||||
"""Tests greater-than-or-equal against an inline constant (ONNX Constant + GreaterOrEqual nodes)."""
|
||||
"""Tests greater-than-or-equal against an inline constant (PT2 Constant + GreaterOrEqual nodes)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([0.25, 0.5, 0.75]).to(x.device)
|
||||
@@ -1432,7 +1439,7 @@ class GreaterTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class GreaterWithConstantModel(torch.nn.Module):
|
||||
"""Tests greater-than against a scalar constant (ONNX Greater + Constant nodes)."""
|
||||
"""Tests greater-than against a scalar constant (PT2 Greater + Constant nodes)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return (x > 0.5).to(torch.float32)
|
||||
@@ -1509,7 +1516,7 @@ class MLPBlockModel(torch.nn.Module):
|
||||
|
||||
|
||||
class GatherElementsTestModel(torch.nn.Module):
|
||||
"""Tests element-wise gather along axis=1 using torch.gather (→ ONNX GatherElements)."""
|
||||
"""Tests element-wise gather along axis=1 using torch.gather (→ PT2 GatherElements)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
idx = torch.tensor([[0, 1, 1], [1, 0, 0]], device=x.device)
|
||||
@@ -1530,7 +1537,7 @@ class GatherElementsLargeTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class ExpandTestModel(torch.nn.Module):
|
||||
"""Tests broadcasting a (1, 4) tensor to (3, 4) via .expand() (→ ONNX Expand)."""
|
||||
"""Tests broadcasting a (1, 4) tensor to (3, 4) via .expand() (→ PT2 Expand)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.expand(3, 4)
|
||||
@@ -1550,7 +1557,7 @@ class IsNaNTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class LayerNormTestModel(torch.nn.Module):
|
||||
"""Tests nn.LayerNorm which exports as ONNX LayerNormalization."""
|
||||
"""Tests nn.LayerNorm which exports as PT2 LayerNormalization."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@@ -1564,7 +1571,7 @@ class LayerNormTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class GemmTestModel(torch.nn.Module):
|
||||
"""Tests Gemm: nn.Linear exports as ONNX Gemm (weight transposed)."""
|
||||
"""Tests Gemm: nn.Linear exports as PT2 Gemm (weight transposed)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@@ -1588,14 +1595,14 @@ class ErfTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class SliceTestModel(torch.nn.Module):
|
||||
"""Tests ONNX Slice: slice axis 0 from index 1 to 3."""
|
||||
"""Tests PT2 Slice: slice axis 0 from index 1 to 3."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x[1:3]
|
||||
|
||||
|
||||
class SliceMultiAxisTestModel(torch.nn.Module):
|
||||
"""Tests ONNX Slice along multiple axes: x[1:3, 0:2]."""
|
||||
"""Tests PT2 Slice along multiple axes: x[1:3, 0:2]."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x[1:3, 0:2]
|
||||
@@ -1612,6 +1619,73 @@ class SplitTestModel(torch.nn.Module):
|
||||
return a + b
|
||||
|
||||
|
||||
# ========== Argsort / MoE Routing Test Models ==========
|
||||
|
||||
|
||||
class ArgsortStableDuplicatesModel(torch.nn.Module):
|
||||
"""Tests deterministic duplicate ordering for exported argsort."""
|
||||
|
||||
SORT_DIM = 1
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.argsort(x, dim=self.SORT_DIM)
|
||||
|
||||
|
||||
class TinyMoERoutingModel(torch.nn.Module):
|
||||
"""Minimal deterministic MoE-style routing proof for PT2/native and CUDA."""
|
||||
|
||||
TOP_K = 2
|
||||
ROUTING_DIM = -1
|
||||
ZERO_FILL = 0.0
|
||||
DISPATCH_ON = 1
|
||||
GROUP_SIZE = 2
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.register_buffer(
|
||||
"expert_scale",
|
||||
torch.tensor([1.5, -0.5, 2.0, 0.25], dtype=torch.float32),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, scores: torch.Tensor
|
||||
) -> tuple[
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
]:
|
||||
topk_values, topk_indices = torch.topk(scores, self.TOP_K, dim=self.ROUTING_DIM)
|
||||
regroup_order = torch.argsort(topk_indices, dim=self.ROUTING_DIM)
|
||||
routed_indices = torch.gather(topk_indices, self.ROUTING_DIM, regroup_order)
|
||||
routed_values = torch.gather(topk_values, self.ROUTING_DIM, regroup_order)
|
||||
|
||||
expert_scale = self.expert_scale.unsqueeze(0).expand(scores.shape[0], -1)
|
||||
gathered_scale = torch.gather(expert_scale, self.ROUTING_DIM, routed_indices)
|
||||
weighted = routed_values * gathered_scale
|
||||
|
||||
inactive_mask = torch.bitwise_not(weighted > 0)
|
||||
masked_values = weighted.masked_fill(inactive_mask, self.ZERO_FILL)
|
||||
|
||||
slots = torch.zeros_like(routed_indices).scatter(
|
||||
self.ROUTING_DIM, regroup_order, self.DISPATCH_ON
|
||||
)
|
||||
active_slots = torch.bitwise_not(inactive_mask).to(slots.dtype)
|
||||
dispatch = slots * active_slots
|
||||
group_ids = torch.floor_divide(routed_indices, self.GROUP_SIZE)
|
||||
routing_sign = torch.sign(masked_values)
|
||||
return (
|
||||
routed_indices,
|
||||
masked_values,
|
||||
dispatch,
|
||||
inactive_mask,
|
||||
group_ids,
|
||||
routing_sign,
|
||||
)
|
||||
|
||||
|
||||
# ========== TopK Node Test Models ==========
|
||||
|
||||
|
||||
@@ -1684,7 +1758,7 @@ class ScatterNDTestModel(torch.nn.Module):
|
||||
class RMSNormModel(torch.nn.Module):
|
||||
"""Tests RMS normalization: x * rsqrt(mean(x^2) + eps) * weight.
|
||||
|
||||
ONNX ops: Pow, ReduceMean, Add, Sqrt, Reciprocal, Mul.
|
||||
PT2 ops: Pow, ReduceMean, Add, Sqrt, Reciprocal, Mul.
|
||||
Input: (1, 4, 32) -> Output: (1, 4, 32).
|
||||
"""
|
||||
|
||||
@@ -1703,7 +1777,7 @@ class RotaryEmbeddingModel(torch.nn.Module):
|
||||
"""Tests rotary position embeddings (RoPE) using rotate-half approach.
|
||||
|
||||
Precomputes cos/sin caches as buffers; at runtime: slice, split halves, rotate.
|
||||
ONNX ops: Slice, Unsqueeze, Mul, Sub, Add, Concat.
|
||||
PT2 ops: Slice, Unsqueeze, Mul, Sub, Add, Concat.
|
||||
Input: (1, 4, 4, 8) [batch, seq, heads, head_dim] -> Output: same shape.
|
||||
"""
|
||||
|
||||
@@ -1732,7 +1806,7 @@ class RotaryEmbeddingModel(torch.nn.Module):
|
||||
class SwiGLUMLPModel(torch.nn.Module):
|
||||
"""Tests SwiGLU MLP: down_proj(silu(gate_proj(x)) * up_proj(x)).
|
||||
|
||||
silu(x) = x * sigmoid(x), decomposes to Sigmoid+Mul in ONNX.
|
||||
silu(x) = x * sigmoid(x), decomposes to Sigmoid+Mul in PT2.
|
||||
Input: (1, 4, 32) -> Output: (1, 4, 32).
|
||||
"""
|
||||
|
||||
@@ -1823,3 +1897,307 @@ class LlamaTransformerBlockModel(torch.nn.Module):
|
||||
h = x + self.attn(self.input_norm(x))
|
||||
out = h + self.mlp(self.post_attn_norm(h))
|
||||
return out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Convolution models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Conv1dNoPadModel(torch.nn.Module):
|
||||
"""Conv1d with no padding: output length shrinks by (kernel-1)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
PADDING = 0
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv1d(
|
||||
8, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv1dSamePadModel(torch.nn.Module):
|
||||
"""Conv1d with same-size padding (output length == input length)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv1d(
|
||||
8, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv1dBiasModel(torch.nn.Module):
|
||||
"""Conv1d with bias."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv1d(
|
||||
8, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=True
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv2dNoPadModel(torch.nn.Module):
|
||||
"""Conv2d with no padding: output spatial dims shrink by (kernel-1)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
PADDING = 0
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
3, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv2dSamePadModel(torch.nn.Module):
|
||||
"""Conv2d with same-size padding."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
3, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv2dBiasModel(torch.nn.Module):
|
||||
"""Conv2d with bias."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
3, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=True
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv2dStrideModel(torch.nn.Module):
|
||||
"""Conv2d with stride=2 (output dims halved)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
STRIDE = 2
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
3,
|
||||
16,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
stride=self.STRIDE,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv2dDilationModel(torch.nn.Module):
|
||||
"""Conv2d with dilation=2 and padding chosen to preserve spatial size."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
DILATION = 2
|
||||
PADDING = 2
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
8,
|
||||
16,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
dilation=self.DILATION,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv3dSamePadModel(torch.nn.Module):
|
||||
"""Conv3d with padding=1 to preserve spatial dimensions."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv3d(
|
||||
4, 8, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class DepthwiseConv1dModel(torch.nn.Module):
|
||||
"""Depthwise Conv1d as used in Mamba (groups == in_channels)."""
|
||||
|
||||
KERNEL_SIZE = 4
|
||||
GROUPS = 16
|
||||
PADDING = 3
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv1d(
|
||||
16,
|
||||
16,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
groups=self.GROUPS,
|
||||
padding=self.PADDING,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# Causal truncation: keep only the first L positions
|
||||
return self.conv(x)[:, :, : x.shape[2]]
|
||||
|
||||
|
||||
class DepthwiseConv2dModel(torch.nn.Module):
|
||||
"""Depthwise Conv2d (groups == in_channels)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
GROUPS = 8
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
8,
|
||||
8,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
groups=self.GROUPS,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class DepthwiseMultiplierConv2dModel(torch.nn.Module):
|
||||
"""Depthwise Conv2d with channel multiplier 2 (out_channels = 2 * in_channels)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
GROUPS = 8
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
8,
|
||||
16,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
groups=self.GROUPS,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class GroupedConv2dModel(torch.nn.Module):
|
||||
"""Conv2d with groups=4 (not depthwise, but grouped)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
GROUPS = 4
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
16,
|
||||
32,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
groups=self.GROUPS,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class GroupedConv2dGroups3Model(torch.nn.Module):
|
||||
"""Conv2d with groups=3 and ch_per_group=4."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
GROUPS = 3
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
12,
|
||||
12,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
groups=self.GROUPS,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MambaConvBlockModel(torch.nn.Module):
|
||||
"""Minimal Mamba-style SSM block: Linear -> split -> depthwise Conv1d -> SiLU gate -> Linear.
|
||||
|
||||
This is the core conv pattern used in Mamba / Mamba-2 models.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int = 16, d_conv: int = 4, expand: int = 2) -> None:
|
||||
super().__init__()
|
||||
d_inner = d_model * expand
|
||||
groups = d_inner
|
||||
padding = d_conv - 1
|
||||
self.in_proj = torch.nn.Linear(d_model, d_inner * 2, bias=False)
|
||||
self.conv1d = torch.nn.Conv1d(
|
||||
d_inner,
|
||||
d_inner,
|
||||
d_conv,
|
||||
groups=groups,
|
||||
padding=padding,
|
||||
bias=True,
|
||||
)
|
||||
self.out_proj = torch.nn.Linear(d_inner, d_model, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
b, seq_len, _ = x.shape
|
||||
xz = self.in_proj(x)
|
||||
x_part, z = xz.chunk(2, dim=-1)
|
||||
x_part = self.conv1d(x_part.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
|
||||
return self.out_proj(
|
||||
torch.nn.functional.silu(x_part) * torch.nn.functional.silu(z)
|
||||
)
|
||||
|
||||
BIN
docs/logo/inference_at_the_speed_of_light.png
Normal file
BIN
docs/logo/inference_at_the_speed_of_light.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 380 KiB |
@@ -199,7 +199,7 @@ impl Gemma {
|
||||
kv_cache.v_caches[i],
|
||||
kv_cache.max_seq,
|
||||
);
|
||||
x = x_new.graph_break();
|
||||
x = x_new;
|
||||
cache_outputs.push((k_out, v_out));
|
||||
}
|
||||
let logits = self.lm_norm.forward(x).matmul(self.lm_head.t());
|
||||
|
||||
22
examples/gemma4_moe/Cargo.toml
Normal file
22
examples/gemma4_moe/Cargo.toml
Normal file
@@ -0,0 +1,22 @@
|
||||
[package]
|
||||
name = "gemma4_moe"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
tokenizers = "0.22.2"
|
||||
rustc-hash = "2"
|
||||
|
||||
# HuggingFace model download
|
||||
hf-hub = { version = "0.4", default-features = false, features = ["rustls-tls", "ureq"] }
|
||||
safetensors = "0.7.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
half = { version = "2.7.1", features = ["bytemuck"] }
|
||||
bytemuck = "1.24.0"
|
||||
memmap2 = "0.9.9"
|
||||
227
examples/gemma4_moe/src/hf.rs
Normal file
227
examples/gemma4_moe/src/hf.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
use half::{bf16, f16};
|
||||
use hf_hub::api::sync::Api;
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::{tensor::TensorView, Dtype, SafeTensors};
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fs::File,
|
||||
io::Write,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use crate::model::HIDDEN;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct SafetensorsIndex {
|
||||
weight_map: HashMap<String, String>,
|
||||
}
|
||||
|
||||
enum TensorData {
|
||||
F32(Vec<f32>),
|
||||
BF16(Vec<u8>),
|
||||
}
|
||||
|
||||
struct StoredTensor {
|
||||
shape: Vec<usize>,
|
||||
data: TensorData,
|
||||
}
|
||||
|
||||
pub fn download_hf_model(repo_id: &str) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
let api = Api::new()?;
|
||||
let repo = api.model(repo_id.to_string());
|
||||
|
||||
let tokenizer_path = repo.get("tokenizer.json")?;
|
||||
let model_dir = tokenizer_path.parent().unwrap().to_path_buf();
|
||||
|
||||
if repo.get("model.safetensors").is_ok() {
|
||||
return Ok(model_dir);
|
||||
}
|
||||
|
||||
let index_path = repo.get("model.safetensors.index.json")?;
|
||||
let index_content = std::fs::read_to_string(&index_path)?;
|
||||
let index: SafetensorsIndex = serde_json::from_str(&index_content)?;
|
||||
|
||||
let mut shard_files: Vec<String> = index.weight_map.values().cloned().collect();
|
||||
shard_files.sort();
|
||||
shard_files.dedup();
|
||||
|
||||
for shard_file in &shard_files {
|
||||
repo.get(shard_file)?;
|
||||
}
|
||||
|
||||
Ok(model_dir)
|
||||
}
|
||||
|
||||
fn tensor_to_f32(tensor: &safetensors::tensor::TensorView) -> Vec<f32> {
|
||||
match tensor.dtype() {
|
||||
Dtype::F32 => bytemuck::cast_slice::<u8, f32>(tensor.data()).to_vec(),
|
||||
Dtype::F16 => {
|
||||
let f16_slice: &[f16] = bytemuck::cast_slice(tensor.data());
|
||||
f16_slice.iter().map(|x| x.to_f32()).collect()
|
||||
}
|
||||
Dtype::BF16 => {
|
||||
let bf16_slice: &[bf16] = bytemuck::cast_slice(tensor.data());
|
||||
bf16_slice.iter().map(|x| x.to_f32()).collect()
|
||||
}
|
||||
other => panic!("Unsupported dtype for conversion: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn tensor_to_bf16_bytes(tensor: &safetensors::tensor::TensorView) -> Vec<u8> {
|
||||
match tensor.dtype() {
|
||||
Dtype::BF16 => tensor.data().to_vec(),
|
||||
Dtype::F16 => {
|
||||
let f16_slice: &[f16] = bytemuck::cast_slice(tensor.data());
|
||||
let bf16_data: Vec<bf16> = f16_slice
|
||||
.iter()
|
||||
.map(|x| bf16::from_f32(x.to_f32()))
|
||||
.collect();
|
||||
bytemuck::cast_slice(&bf16_data).to_vec()
|
||||
}
|
||||
Dtype::F32 => {
|
||||
let f32_slice: &[f32] = bytemuck::cast_slice(tensor.data());
|
||||
let bf16_data: Vec<bf16> = f32_slice.iter().map(|x| bf16::from_f32(*x)).collect();
|
||||
bytemuck::cast_slice(&bf16_data).to_vec()
|
||||
}
|
||||
other => panic!("Unsupported dtype for conversion: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_text_weight(name: &str) -> bool {
|
||||
name.starts_with("model.language_model.")
|
||||
}
|
||||
|
||||
fn is_expert_weight(name: &str) -> bool {
|
||||
name.contains(".experts.")
|
||||
}
|
||||
|
||||
pub fn combine_safetensors(model_dir: &Path) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
let output_path = model_dir.join("model_combined.safetensors");
|
||||
if output_path.exists() {
|
||||
return Ok(output_path);
|
||||
}
|
||||
|
||||
let index_path = model_dir.join("model.safetensors.index.json");
|
||||
let single_shard_path = model_dir.join("model.safetensors");
|
||||
|
||||
let shard_files: Vec<PathBuf> = if single_shard_path.exists() && !index_path.exists() {
|
||||
println!("Single shard model detected...");
|
||||
vec![single_shard_path]
|
||||
} else if index_path.exists() {
|
||||
let index_content = std::fs::read_to_string(&index_path)?;
|
||||
let index: SafetensorsIndex = serde_json::from_str(&index_content)?;
|
||||
|
||||
let mut files: Vec<String> = index.weight_map.values().cloned().collect();
|
||||
files.sort();
|
||||
files.dedup();
|
||||
|
||||
println!("Loading {} shard files...", files.len());
|
||||
files.into_iter().map(|f| model_dir.join(f)).collect()
|
||||
} else {
|
||||
return Err("No model.safetensors or model.safetensors.index.json found".into());
|
||||
};
|
||||
|
||||
let mut all_tensors: HashMap<String, StoredTensor> = HashMap::new();
|
||||
|
||||
for shard_path in &shard_files {
|
||||
println!(
|
||||
" Loading {}...",
|
||||
shard_path.file_name().unwrap().to_string_lossy()
|
||||
);
|
||||
let file = File::open(shard_path)?;
|
||||
let mmap = unsafe { MmapOptions::new().map(&file)? };
|
||||
let st = SafeTensors::deserialize(&mmap)?;
|
||||
|
||||
for name in st.names() {
|
||||
if !is_text_weight(name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let new_name = name.replacen("model.language_model.", "model.", 1);
|
||||
let tensor = st.tensor(name)?;
|
||||
|
||||
if new_name.ends_with(".layer_scalar") {
|
||||
let scalar = tensor_to_f32(&tensor);
|
||||
let scalar = *scalar.first().expect("layer_scalar tensor is empty");
|
||||
all_tensors.insert(
|
||||
new_name,
|
||||
StoredTensor {
|
||||
shape: vec![HIDDEN],
|
||||
data: TensorData::F32(vec![scalar; HIDDEN]),
|
||||
},
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let shape = tensor.shape().to_vec();
|
||||
let data = if is_expert_weight(&new_name) {
|
||||
TensorData::BF16(tensor_to_bf16_bytes(&tensor))
|
||||
} else {
|
||||
TensorData::F32(tensor_to_f32(&tensor))
|
||||
};
|
||||
|
||||
all_tensors.insert(new_name, StoredTensor { shape, data });
|
||||
}
|
||||
}
|
||||
|
||||
println!("Extracted {} text tensors", all_tensors.len());
|
||||
|
||||
let embed_key = "model.embed_tokens.weight";
|
||||
if let Some(embed_tensor) = all_tensors.get(embed_key) {
|
||||
let (shape, embed_data) = match &embed_tensor.data {
|
||||
TensorData::F32(data) => (embed_tensor.shape.clone(), data.clone()),
|
||||
TensorData::BF16(_) => unreachable!("Embedding weights should stay in F32"),
|
||||
};
|
||||
|
||||
all_tensors.insert(
|
||||
"lm_head.weight".to_string(),
|
||||
StoredTensor {
|
||||
shape,
|
||||
data: TensorData::F32(embed_data.clone()),
|
||||
},
|
||||
);
|
||||
|
||||
let embed_scale = (HIDDEN as f32).sqrt();
|
||||
if let Some(stored) = all_tensors.get_mut(embed_key) {
|
||||
match &mut stored.data {
|
||||
TensorData::F32(data) => {
|
||||
for value in data {
|
||||
*value *= embed_scale;
|
||||
}
|
||||
}
|
||||
TensorData::BF16(_) => unreachable!("Embedding weights should stay in F32"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("Saving combined model (BF16 experts + F32 rest)...");
|
||||
let tensor_views: HashMap<String, TensorView<'_>> = all_tensors
|
||||
.iter()
|
||||
.map(|(name, stored)| {
|
||||
let view = match &stored.data {
|
||||
TensorData::F32(data) => {
|
||||
let bytes: &[u8] = bytemuck::cast_slice(data);
|
||||
TensorView::new(Dtype::F32, stored.shape.clone(), bytes).unwrap()
|
||||
}
|
||||
TensorData::BF16(bytes) => {
|
||||
TensorView::new(Dtype::BF16, stored.shape.clone(), bytes).unwrap()
|
||||
}
|
||||
};
|
||||
(name.clone(), view)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let serialized = safetensors::serialize(&tensor_views, None)?;
|
||||
let mut file = File::create(&output_path)?;
|
||||
file.write_all(&serialized)?;
|
||||
|
||||
println!("Combined model saved successfully!");
|
||||
Ok(output_path)
|
||||
}
|
||||
|
||||
pub fn prepare_hf_model(repo_id: &str) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
let model_dir = download_hf_model(repo_id)?;
|
||||
combine_safetensors(&model_dir)?;
|
||||
Ok(model_dir)
|
||||
}
|
||||
190
examples/gemma4_moe/src/main.rs
Normal file
190
examples/gemma4_moe/src/main.rs
Normal file
@@ -0,0 +1,190 @@
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
use hf::prepare_hf_model;
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const REPO_ID: &str = "google/gemma-4-26B-A4B";
|
||||
|
||||
fn env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
fn env_bool(name: &str) -> bool {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.is_some_and(|s| matches!(s.as_str(), "1" | "true" | "TRUE" | "yes" | "YES"))
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let max_seq_len = env_usize("MAX_SEQ_LEN", 4096);
|
||||
let gen_tokens = env_usize("GEN_TOKENS", 30);
|
||||
let search_graphs = env_usize("SEARCH_GRAPHS", 50);
|
||||
let prompt = std::env::var("PROMPT").unwrap_or_else(|_| "The capital of France is".to_string());
|
||||
let print_token_ids = env_bool("PRINT_TOKEN_IDS");
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
let model_dir = prepare_hf_model(REPO_ID).expect("Failed to prepare model");
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
let pos_ids = cx.named_tensor("pos_ids", 's').as_dtype(DType::Int);
|
||||
let kv_cache = KVCache::new(&mut cx, max_seq_len);
|
||||
let (logits, cache_outputs) = Gemma4MoE::init(&mut cx).forward(input, pos_ids, &kv_cache);
|
||||
let logits = logits.output();
|
||||
for (k_out, v_out) in &cache_outputs {
|
||||
k_out.output();
|
||||
v_out.output();
|
||||
}
|
||||
|
||||
println!("Building E-Graph...");
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
println!("Loading weights...");
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
let weights_path = model_dir.join("model_combined.safetensors");
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', 1);
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(pos_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
|
||||
let mut prev_seq = 0usize;
|
||||
let mut fwd_durations = vec![];
|
||||
let mut seen_tokens = FxHashSet::default();
|
||||
let mut generated_token_ids = vec![];
|
||||
let repetition_penalty: f32 = 1.05;
|
||||
|
||||
const EOS_TOKEN: u32 = 1;
|
||||
|
||||
let prefill_start = std::time::Instant::now();
|
||||
for &token in &prompt_tokens {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![token as i32]);
|
||||
runtime.set_data(pos_ids, vec![prev_seq as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
|
||||
prev_seq += 1;
|
||||
}
|
||||
let prefill_duration = prefill_start.elapsed();
|
||||
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let last_row = &logits_data[..VOCAB_SIZE];
|
||||
let mut next_token = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
generated_token_ids.push(next_token);
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
for _ in 1..gen_tokens {
|
||||
let start = std::time::Instant::now();
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![next_token as i32]);
|
||||
runtime.set_data(pos_ids, vec![prev_seq as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
|
||||
prev_seq += 1;
|
||||
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let mut last_row = logits_data[..VOCAB_SIZE].to_vec();
|
||||
for &tok in &seen_tokens {
|
||||
let logit = &mut last_row[tok as usize];
|
||||
if *logit > 0.0 {
|
||||
*logit /= repetition_penalty;
|
||||
} else {
|
||||
*logit *= repetition_penalty;
|
||||
}
|
||||
}
|
||||
next_token = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
generated_token_ids.push(next_token);
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
if next_token == EOS_TOKEN {
|
||||
break;
|
||||
}
|
||||
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
fwd_durations.push(start.elapsed());
|
||||
}
|
||||
println!();
|
||||
if print_token_ids {
|
||||
println!("Generated token ids: {generated_token_ids:?}");
|
||||
}
|
||||
|
||||
println!(
|
||||
" TTFT: {:.2} ms ({} prompt tokens)",
|
||||
prefill_duration.as_secs_f64() * 1e3,
|
||||
prompt_tokens.len()
|
||||
);
|
||||
if fwd_durations.len() > 1 {
|
||||
println!(
|
||||
" TPOT: {:.2} ms",
|
||||
(fwd_durations.iter().skip(1).sum::<Duration>() / (fwd_durations.len() - 1) as u32)
|
||||
.as_secs_f64()
|
||||
* 1_000.
|
||||
);
|
||||
}
|
||||
}
|
||||
621
examples/gemma4_moe/src/model.rs
Normal file
621
examples/gemma4_moe/src/model.rs
Normal file
@@ -0,0 +1,621 @@
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
graph::Graph,
|
||||
prelude::{F32Pow, GraphTensor},
|
||||
shape::Expression,
|
||||
};
|
||||
use luminal_nn::LayerNorm;
|
||||
|
||||
pub const LAYERS: usize = 30;
|
||||
pub const HIDDEN: usize = 2816;
|
||||
pub const INTERMEDIATE: usize = 2112;
|
||||
pub const MOE_INTERMEDIATE: usize = 704;
|
||||
pub const NUM_EXPERTS: usize = 128;
|
||||
pub const TOP_K: usize = 8;
|
||||
pub const N_HEADS: usize = 16;
|
||||
pub const SLIDING_HEAD_DIM: usize = 256;
|
||||
pub const FULL_HEAD_DIM: usize = 512;
|
||||
pub const SLIDING_KV_HEADS: usize = 8;
|
||||
pub const FULL_KV_HEADS: usize = 2;
|
||||
pub const VOCAB_SIZE: usize = 262144;
|
||||
pub const RMS_NORM_EPS: f32 = 1e-6;
|
||||
pub const SLIDING_WINDOW_SIZE: usize = 1024;
|
||||
pub const SLIDING_ROPE_THETA: f32 = 10_000.0;
|
||||
pub const FULL_ROPE_THETA: f32 = 1_000_000.0;
|
||||
pub const FULL_PARTIAL_ROTARY_FACTOR: f32 = 0.25;
|
||||
pub const FINAL_LOGIT_SOFTCAP: f32 = 30.0;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct LayerSpec {
|
||||
is_sliding: bool,
|
||||
head_dim: usize,
|
||||
q_dim: usize,
|
||||
num_kv_heads: usize,
|
||||
kv_dim: usize,
|
||||
kv_groups: usize,
|
||||
rope_theta: f32,
|
||||
partial_rotary_factor: f32,
|
||||
has_v_proj: bool,
|
||||
}
|
||||
|
||||
fn layer_spec(layer: usize) -> LayerSpec {
|
||||
if !(layer + 1).is_multiple_of(6) {
|
||||
LayerSpec {
|
||||
is_sliding: true,
|
||||
head_dim: SLIDING_HEAD_DIM,
|
||||
q_dim: N_HEADS * SLIDING_HEAD_DIM,
|
||||
num_kv_heads: SLIDING_KV_HEADS,
|
||||
kv_dim: SLIDING_KV_HEADS * SLIDING_HEAD_DIM,
|
||||
kv_groups: N_HEADS / SLIDING_KV_HEADS,
|
||||
rope_theta: SLIDING_ROPE_THETA,
|
||||
partial_rotary_factor: 1.0,
|
||||
has_v_proj: true,
|
||||
}
|
||||
} else {
|
||||
LayerSpec {
|
||||
is_sliding: false,
|
||||
head_dim: FULL_HEAD_DIM,
|
||||
q_dim: N_HEADS * FULL_HEAD_DIM,
|
||||
num_kv_heads: FULL_KV_HEADS,
|
||||
kv_dim: FULL_KV_HEADS * FULL_HEAD_DIM,
|
||||
kv_groups: N_HEADS / FULL_KV_HEADS,
|
||||
rope_theta: FULL_ROPE_THETA,
|
||||
partial_rotary_factor: FULL_PARTIAL_ROTARY_FACTOR,
|
||||
has_v_proj: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cache_bytes_for_layer(layer: usize, max_seq: usize) -> usize {
|
||||
let spec = layer_spec(layer);
|
||||
spec.num_kv_heads * max_seq * spec.head_dim * std::mem::size_of::<f32>()
|
||||
}
|
||||
|
||||
pub struct KVCache {
|
||||
pub k_caches: Vec<GraphTensor>,
|
||||
pub v_caches: Vec<GraphTensor>,
|
||||
pub max_seq: usize,
|
||||
}
|
||||
|
||||
impl KVCache {
|
||||
pub fn new(cx: &mut Graph, max_seq: usize) -> Self {
|
||||
let mut k_caches = Vec::with_capacity(LAYERS);
|
||||
let mut v_caches = Vec::with_capacity(LAYERS);
|
||||
for layer in 0..LAYERS {
|
||||
let spec = layer_spec(layer);
|
||||
let k = cx
|
||||
.named_tensor(
|
||||
format!("kv_cache.{layer}.k"),
|
||||
(spec.num_kv_heads, max_seq, spec.head_dim),
|
||||
)
|
||||
.persist();
|
||||
let v = cx
|
||||
.named_tensor(
|
||||
format!("kv_cache.{layer}.v"),
|
||||
(spec.num_kv_heads, max_seq, spec.head_dim),
|
||||
)
|
||||
.persist();
|
||||
k_caches.push(k);
|
||||
v_caches.push(v);
|
||||
}
|
||||
Self {
|
||||
k_caches,
|
||||
v_caches,
|
||||
max_seq,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Gemma4MoE {
|
||||
embedding: GraphTensor,
|
||||
lm_head: GraphTensor,
|
||||
layers: Vec<Gemma4Layer>,
|
||||
lm_norm: LayerNorm,
|
||||
}
|
||||
|
||||
impl Gemma4MoE {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
let mut layers = Vec::with_capacity(LAYERS);
|
||||
for layer in 0..LAYERS {
|
||||
let spec = layer_spec(layer);
|
||||
let gate = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.mlp.gate_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let up = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.mlp.up_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let down = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.mlp.down_proj.weight"),
|
||||
(HIDDEN, INTERMEDIATE),
|
||||
)
|
||||
.persist();
|
||||
|
||||
let q_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.q_proj.weight"),
|
||||
(spec.q_dim, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let k_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.k_proj.weight"),
|
||||
(spec.kv_dim, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let v_proj = spec.has_v_proj.then(|| {
|
||||
cx.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.v_proj.weight"),
|
||||
(spec.kv_dim, HIDDEN),
|
||||
)
|
||||
.persist()
|
||||
});
|
||||
let o_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, spec.q_dim),
|
||||
)
|
||||
.persist();
|
||||
let q_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.q_norm.weight"),
|
||||
spec.head_dim,
|
||||
)
|
||||
.persist();
|
||||
let k_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.k_norm.weight"),
|
||||
spec.head_dim,
|
||||
)
|
||||
.persist();
|
||||
let layer_scalar = cx
|
||||
.named_tensor(format!("model.layers.{layer}.layer_scalar"), HIDDEN)
|
||||
.persist();
|
||||
|
||||
let router_scale = cx
|
||||
.named_tensor(format!("model.layers.{layer}.router.scale"), HIDDEN)
|
||||
.persist();
|
||||
let router_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.router.proj.weight"),
|
||||
(NUM_EXPERTS, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let per_expert_scale = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.router.per_expert_scale"),
|
||||
NUM_EXPERTS,
|
||||
)
|
||||
.persist();
|
||||
let gate_up_weights = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.experts.gate_up_proj"),
|
||||
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
|
||||
)
|
||||
.persist()
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.experts.down_proj"),
|
||||
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
|
||||
)
|
||||
.persist()
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
layers.push(Gemma4Layer {
|
||||
spec,
|
||||
gate,
|
||||
up,
|
||||
down,
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
q_norm,
|
||||
k_norm,
|
||||
layer_scalar,
|
||||
input_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.input_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_attention_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_attention_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
pre_feedforward_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.pre_feedforward_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_feedforward_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm_1: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_feedforward_layernorm_1.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm_2: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_feedforward_layernorm_2.weight"),
|
||||
cx,
|
||||
),
|
||||
pre_feedforward_layernorm_2: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.pre_feedforward_layernorm_2.weight"),
|
||||
cx,
|
||||
),
|
||||
moe: Gemma4SparseMoE {
|
||||
router_scale,
|
||||
router_proj,
|
||||
per_expert_scale,
|
||||
gate_up_weights,
|
||||
down_weights,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
let embedding = cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
let lm_head = cx
|
||||
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
let lm_norm = gemma4_norm(HIDDEN, "model.norm.weight", cx);
|
||||
|
||||
Self {
|
||||
embedding,
|
||||
lm_head,
|
||||
layers,
|
||||
lm_norm,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
token_ids: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = token_ids.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
|
||||
let mut cache_outputs = Vec::with_capacity(LAYERS);
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
x,
|
||||
pos_ids,
|
||||
kv_cache.k_caches[layer_idx],
|
||||
kv_cache.v_caches[layer_idx],
|
||||
kv_cache.max_seq,
|
||||
);
|
||||
x = x_new;
|
||||
cache_outputs.push((k_out, v_out));
|
||||
}
|
||||
|
||||
let logits = self.lm_norm.forward(x).matmul(self.lm_head.t());
|
||||
let logits = (logits / FINAL_LOGIT_SOFTCAP).tanh() * FINAL_LOGIT_SOFTCAP;
|
||||
(logits, cache_outputs)
|
||||
}
|
||||
}
|
||||
|
||||
struct Gemma4Layer {
|
||||
spec: LayerSpec,
|
||||
gate: GraphTensor,
|
||||
up: GraphTensor,
|
||||
down: GraphTensor,
|
||||
q_proj: GraphTensor,
|
||||
k_proj: GraphTensor,
|
||||
v_proj: Option<GraphTensor>,
|
||||
o_proj: GraphTensor,
|
||||
q_norm: GraphTensor,
|
||||
k_norm: GraphTensor,
|
||||
layer_scalar: GraphTensor,
|
||||
input_layernorm: LayerNorm,
|
||||
post_attention_layernorm: LayerNorm,
|
||||
pre_feedforward_layernorm: LayerNorm,
|
||||
post_feedforward_layernorm: LayerNorm,
|
||||
post_feedforward_layernorm_1: LayerNorm,
|
||||
post_feedforward_layernorm_2: LayerNorm,
|
||||
pre_feedforward_layernorm_2: LayerNorm,
|
||||
moe: Gemma4SparseMoE,
|
||||
}
|
||||
|
||||
struct Gemma4SparseMoE {
|
||||
router_scale: GraphTensor,
|
||||
router_proj: GraphTensor,
|
||||
per_expert_scale: GraphTensor,
|
||||
gate_up_weights: GraphTensor,
|
||||
down_weights: GraphTensor,
|
||||
}
|
||||
|
||||
fn gemma4_norm(dim: usize, weight_name: &str, cx: &mut Graph) -> LayerNorm {
|
||||
LayerNorm::new(dim, Some(weight_name), None, false, RMS_NORM_EPS, cx)
|
||||
}
|
||||
|
||||
#[allow(clippy::excessive_precision)]
|
||||
fn gemma_gelu(x: GraphTensor) -> GraphTensor {
|
||||
let scaled = 1.5957691216 * x * (1. + 0.044715 * x * x);
|
||||
x * scaled.sigmoid()
|
||||
}
|
||||
|
||||
fn qk_norm(x: GraphTensor, weight: GraphTensor, n_heads: usize, head_dim: usize) -> GraphTensor {
|
||||
let seq = x.dims()[0];
|
||||
let reshaped = x.split_dims(1, head_dim);
|
||||
let normed = reshaped.std_norm(2, RMS_NORM_EPS);
|
||||
let w = weight.expand_dim(0, n_heads).expand_dim(0, seq);
|
||||
(normed * w).merge_dims(1, 2)
|
||||
}
|
||||
|
||||
fn value_norm(x: GraphTensor, head_dim: usize) -> GraphTensor {
|
||||
x.split_dims(1, head_dim)
|
||||
.std_norm(2, RMS_NORM_EPS)
|
||||
.merge_dims(1, 2)
|
||||
}
|
||||
|
||||
fn gemma4_rotary_embeddings(
|
||||
input: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
n_heads: usize,
|
||||
head_dim: usize,
|
||||
rope_theta: f32,
|
||||
partial_rotary_factor: f32,
|
||||
) -> GraphTensor {
|
||||
let input = input.split_dims(1, head_dim).transpose(0, 1);
|
||||
let half_dim = head_dim / 2;
|
||||
let rope_angles = ((partial_rotary_factor * head_dim as f32) / 2.0).floor() as usize;
|
||||
|
||||
let rotated = input
|
||||
.graph()
|
||||
.arange_options(0, rope_angles * 2, 2)
|
||||
.cast(DType::F32)
|
||||
/ head_dim as f32;
|
||||
let rotated = rope_theta.pow(rotated).reciprocal();
|
||||
let inv_freqs = if rope_angles < half_dim {
|
||||
let zeros = input
|
||||
.graph()
|
||||
.arange(half_dim - rope_angles)
|
||||
.cast(DType::F32)
|
||||
* 0.0;
|
||||
rotated.concat_along(zeros, 0)
|
||||
} else {
|
||||
rotated
|
||||
};
|
||||
|
||||
let emb = pos_ids
|
||||
.cast(DType::F32)
|
||||
.expand_dim(1, 1)
|
||||
.matmul(inv_freqs.expand_dim(0, 1));
|
||||
|
||||
let x0 = input.slice((.., .., ..half_dim));
|
||||
let x1 = input.slice((.., .., half_dim..));
|
||||
|
||||
let cos = emb.cos().expand_dim(0, n_heads);
|
||||
let sin = emb.sin().expand_dim(0, n_heads);
|
||||
let x0_out = x0 * cos - x1 * sin;
|
||||
let x1_out = x1 * cos + x0 * sin;
|
||||
|
||||
x0_out
|
||||
.concat_along(x1_out, 2)
|
||||
.transpose(0, 1)
|
||||
.merge_dims(1, 2)
|
||||
}
|
||||
|
||||
fn gather_experts(
|
||||
graph_source: GraphTensor,
|
||||
top_k_indices: GraphTensor,
|
||||
weights: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let (_, d1, d2) = weights.dims3();
|
||||
let io = d1 * d2;
|
||||
let base = top_k_indices * io;
|
||||
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
|
||||
let n_base = base.dims().len();
|
||||
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
|
||||
let mut exp_within = within;
|
||||
for (axis, dim) in base.dims().iter().enumerate() {
|
||||
exp_within = exp_within.expand_dim(axis, *dim);
|
||||
}
|
||||
let expert_flat_idx = exp_base + exp_within;
|
||||
weights.gather(expert_flat_idx)
|
||||
}
|
||||
|
||||
fn hlir_attention(
|
||||
q_rope: GraphTensor,
|
||||
k_rope: GraphTensor,
|
||||
v: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
spec: LayerSpec,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let cx = q_rope.graph();
|
||||
let seq = q_rope.dims()[0];
|
||||
let prev = Expression::from('p');
|
||||
let total_seq = prev + seq;
|
||||
|
||||
let k_new = k_rope.split_dims(1, spec.head_dim).transpose(0, 1);
|
||||
let v_new = v.split_dims(1, spec.head_dim).transpose(0, 1);
|
||||
|
||||
let h_offset = cx.arange(spec.num_kv_heads) * (max_seq * spec.head_dim);
|
||||
let p_offset = (cx.arange(seq) + prev) * spec.head_dim;
|
||||
let d_offset = cx.arange(spec.head_dim);
|
||||
let scatter_idx = h_offset.expand_dim(1, seq).expand_dim(2, spec.head_dim)
|
||||
+ p_offset
|
||||
.expand_dim(0, spec.num_kv_heads)
|
||||
.expand_dim(2, spec.head_dim)
|
||||
+ d_offset.expand_dim(0, spec.num_kv_heads).expand_dim(1, seq);
|
||||
|
||||
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
|
||||
let k_3d = k_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
|
||||
let v_3d = v_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
|
||||
let q = q_rope.split_dims(1, spec.head_dim).transpose(0, 1);
|
||||
|
||||
// Gemma 4's text attention uses Q/K normalization and then leaves the
|
||||
// attention scaling at 1.0 in the reference implementation.
|
||||
let scores = q.matmul(k_3d.transpose(1, 2));
|
||||
|
||||
let q_abs = cx.arange(seq).cast(DType::F32) + prev;
|
||||
let k_pos = cx.arange(total_seq).cast(DType::F32);
|
||||
let future_mask = k_pos
|
||||
.expand_dim(0, seq)
|
||||
.gt(q_abs.expand_dim(1, total_seq))
|
||||
.cast(DType::F32);
|
||||
|
||||
let mask_2d = if spec.is_sliding {
|
||||
let window_start = q_abs - (SLIDING_WINDOW_SIZE - 1) as f32;
|
||||
let past_mask = window_start
|
||||
.expand_dim(1, total_seq)
|
||||
.gt(k_pos.expand_dim(0, seq))
|
||||
.cast(DType::F32);
|
||||
future_mask + past_mask
|
||||
} else {
|
||||
future_mask
|
||||
};
|
||||
let mask_3d = mask_2d.expand_dim(0, N_HEADS);
|
||||
let masked_scores = scores + mask_3d * (-1e10f32);
|
||||
|
||||
let attn_weights = masked_scores.softmax(2);
|
||||
let attn_out = attn_weights.matmul(v_3d);
|
||||
let out = attn_out.transpose(0, 1).merge_dims(1, 2);
|
||||
|
||||
(out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl Gemma4Layer {
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let residual = x;
|
||||
let x_attn = self.input_layernorm.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k_base = x_attn.matmul(self.k_proj.t());
|
||||
let v_base = if let Some(v_proj) = self.v_proj {
|
||||
x_attn.matmul(v_proj.t())
|
||||
} else {
|
||||
k_base
|
||||
};
|
||||
|
||||
let q_normed = qk_norm(q, self.q_norm, N_HEADS, self.spec.head_dim);
|
||||
let k_normed = qk_norm(
|
||||
k_base,
|
||||
self.k_norm,
|
||||
self.spec.num_kv_heads,
|
||||
self.spec.head_dim,
|
||||
);
|
||||
let v_normed = value_norm(v_base, self.spec.head_dim);
|
||||
|
||||
let q_rope = gemma4_rotary_embeddings(
|
||||
q_normed,
|
||||
pos_ids,
|
||||
N_HEADS,
|
||||
self.spec.head_dim,
|
||||
self.spec.rope_theta,
|
||||
self.spec.partial_rotary_factor,
|
||||
);
|
||||
let k_rope = gemma4_rotary_embeddings(
|
||||
k_normed,
|
||||
pos_ids,
|
||||
self.spec.num_kv_heads,
|
||||
self.spec.head_dim,
|
||||
self.spec.rope_theta,
|
||||
self.spec.partial_rotary_factor,
|
||||
);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
|
||||
q_rope, k_rope, v_normed, k_cache_in, v_cache_in, max_seq, self.spec,
|
||||
);
|
||||
|
||||
let attn_proj = attn_out.matmul(self.o_proj.t());
|
||||
let x = residual + self.post_attention_layernorm.forward(attn_proj);
|
||||
|
||||
let dense_ff = dense_ffn(
|
||||
self.pre_feedforward_layernorm.forward(x),
|
||||
self.gate,
|
||||
self.up,
|
||||
self.down,
|
||||
);
|
||||
let dense_ff = self.post_feedforward_layernorm_1.forward(dense_ff);
|
||||
|
||||
let moe_out = self
|
||||
.moe
|
||||
.forward(x, self.pre_feedforward_layernorm_2.forward(x));
|
||||
let moe_out = self.post_feedforward_layernorm_2.forward(moe_out);
|
||||
|
||||
let ff_out = self.post_feedforward_layernorm.forward(dense_ff + moe_out);
|
||||
let x = x + ff_out;
|
||||
let x = x * self
|
||||
.layer_scalar
|
||||
.expand_lhs(&x.dims()[..x.dims().len() - 1]);
|
||||
|
||||
(x, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
fn dense_ffn(x: GraphTensor, gate: GraphTensor, up: GraphTensor, down: GraphTensor) -> GraphTensor {
|
||||
(gemma_gelu(x.matmul(gate.t())) * x.matmul(up.t())).matmul(down.t())
|
||||
}
|
||||
|
||||
impl Gemma4SparseMoE {
|
||||
fn forward(&self, router_input: GraphTensor, expert_input: GraphTensor) -> GraphTensor {
|
||||
let n = router_input.dims().len();
|
||||
let e_dim = *self.router_proj.dims().first().unwrap();
|
||||
let k_expr = Expression::from(TOP_K);
|
||||
|
||||
let router_hidden = router_input.std_norm(router_input.dims().len() - 1, RMS_NORM_EPS)
|
||||
* self
|
||||
.router_scale
|
||||
.expand_lhs(&router_input.dims()[..router_input.dims().len() - 1])
|
||||
* (HIDDEN as f32).sqrt().recip();
|
||||
let routing_weights = router_hidden.matmul(self.router_proj.t()).softmax(n - 1);
|
||||
|
||||
let top_k_indices = routing_weights.topk_indexes(TOP_K, n - 1);
|
||||
let row_offsets = router_input
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
let top_k_norm = top_k_values.sum(n - 1).expand_dim(n - 1, TOP_K);
|
||||
let top_k_weights =
|
||||
(top_k_values / top_k_norm) * self.per_expert_scale.gather(top_k_indices);
|
||||
|
||||
let gate_up_gathered =
|
||||
gather_experts(expert_input, top_k_indices, self.gate_up_weights).cast(DType::F32);
|
||||
let x_exp = expert_input.expand_dim(n - 1, TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gemma_gelu(gate) * up;
|
||||
|
||||
let down_gathered =
|
||||
gather_experts(expert_input, top_k_indices, self.down_weights).cast(DType::F32);
|
||||
let hidden_exp = hidden.unsqueeze(2);
|
||||
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2);
|
||||
|
||||
(down_out * top_k_weights.unsqueeze(top_k_weights.dims().len())).sum(n - 1)
|
||||
}
|
||||
}
|
||||
@@ -159,7 +159,8 @@ impl Llama {
|
||||
kv_cache.v_caches[i],
|
||||
kv_cache.max_seq,
|
||||
);
|
||||
x = x_new.graph_break();
|
||||
x = x_new;
|
||||
//x = x_new.graph_break();
|
||||
cache_outputs.push((k_out, v_out));
|
||||
}
|
||||
let logits = self.lm_norm.forward(x).matmul(self.lm_head.t());
|
||||
|
||||
@@ -157,7 +157,7 @@ impl Llama {
|
||||
kv_cache.k_caches[i],
|
||||
kv_cache.v_caches[i],
|
||||
);
|
||||
x = x_new.graph_break();
|
||||
x = x_new;
|
||||
cache_outputs.push((k_out, v_out));
|
||||
}
|
||||
let logits = self.lm_norm.forward(x).matmul(self.lm_head.t());
|
||||
|
||||
@@ -178,7 +178,7 @@ impl Qwen {
|
||||
kv_cache.v_caches[i],
|
||||
kv_cache.max_seq,
|
||||
);
|
||||
x = x_new.graph_break();
|
||||
x = x_new;
|
||||
cache_outputs.push((k_out, v_out));
|
||||
}
|
||||
// Tied embeddings: lm_head = embedding.t()
|
||||
|
||||
@@ -186,7 +186,7 @@ impl Qwen3MoE {
|
||||
kv_cache.v_caches[i],
|
||||
kv_cache.max_seq,
|
||||
);
|
||||
x = x_new.graph_break();
|
||||
x = x_new;
|
||||
cache_outputs.push((k_out, v_out));
|
||||
}
|
||||
let logits = self.lm_norm.forward(x).matmul(self.lm_head.t());
|
||||
@@ -239,7 +239,6 @@ impl Qwen3MoELayer {
|
||||
let (attn_out, k_cache_out, v_cache_out) =
|
||||
attention(q_rope, k_rope, v, k_cache_in, v_cache_in, max_seq);
|
||||
x += attn_out.matmul(self.o_proj.t());
|
||||
x = x.graph_break();
|
||||
|
||||
// MoE FFN
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
@@ -264,8 +263,7 @@ impl QwenMoE {
|
||||
let row_offsets = x
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx =
|
||||
(row_offsets.cast(DType::F32) + top_k_indices.cast(DType::F32)).cast(DType::Int);
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
|
||||
// 4. Gather gate_up expert weights → [s, k, intermediate*2, H]
|
||||
@@ -303,18 +301,18 @@ fn gather_experts(
|
||||
) -> GraphTensor {
|
||||
let (_, d1, d2) = weights.dims3();
|
||||
let io = d1 * d2;
|
||||
let base = (top_k_indices * io).cast(DType::F32);
|
||||
let within = graph_source
|
||||
.graph()
|
||||
.iota(Expression::from('z'), (d1, d2))
|
||||
.cast(DType::F32);
|
||||
// Keep expert gather indices in Int all the way through. Routing them through
|
||||
// F32 loses exactness once the flat offsets exceed 2^24, which Qwen's expert
|
||||
// tensors do at realistic hidden sizes.
|
||||
let base = top_k_indices * io;
|
||||
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
|
||||
let n_base = base.dims().len();
|
||||
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
|
||||
let mut exp_within = within;
|
||||
for (i, dim) in base.dims().iter().enumerate() {
|
||||
exp_within = exp_within.expand_dim(i, *dim);
|
||||
}
|
||||
let expert_flat_idx = (exp_base + exp_within).cast(DType::Int);
|
||||
let expert_flat_idx = exp_base + exp_within;
|
||||
weights.gather(expert_flat_idx)
|
||||
}
|
||||
|
||||
|
||||
339
src/dyn_backend.rs
Normal file
339
src/dyn_backend.rs
Normal file
@@ -0,0 +1,339 @@
|
||||
//! Dynamic backend trait and factory-based compilation.
|
||||
//!
|
||||
//! This module provides:
|
||||
//! - [`DynBackend`]: an object-safe trait for dynamic backend dispatch
|
||||
//! - [`compile_backend`]: generic helper that handles the full compilation pipeline
|
||||
//! - [`BackendFactory`]: function pointer type for backend factories
|
||||
//! - [`NativeDynBackend`]: the reference implementation for CPU
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use half::{bf16, f16};
|
||||
use petgraph::stable_graph::NodeIndex;
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::dtype::DType;
|
||||
use crate::graph::Graph;
|
||||
use crate::hlir::{NativeData, NativeRuntime, Output};
|
||||
use crate::op::Runtime;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// DynBackend trait
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Object-safe backend trait for dynamic dispatch.
|
||||
///
|
||||
/// Wraps a concrete [`Runtime`] implementor, providing a uniform interface
|
||||
/// for `luminal_python` (and other dynamic consumers) without requiring
|
||||
/// generic type parameters.
|
||||
pub trait DynBackend {
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// The device type this backend operates on (e.g. "cpu", "cuda").
|
||||
/// Used by the Python frontend to decide input tensor placement.
|
||||
fn device_type(&self) -> &str {
|
||||
"cpu"
|
||||
}
|
||||
|
||||
fn set_data_bytes(&mut self, node: NodeIndex, bytes: Vec<u8>, dtype: DType);
|
||||
fn set_data_f32(&mut self, node: NodeIndex, data: Vec<f32>);
|
||||
fn get_output_f32(&self, node: NodeIndex) -> Vec<f32>;
|
||||
fn get_output_i32(&self, _node: NodeIndex) -> Vec<i32> {
|
||||
panic!("get_output_i32 not supported by '{}'", self.name());
|
||||
}
|
||||
fn get_output_bool(&self, _node: NodeIndex) -> Vec<bool> {
|
||||
panic!("get_output_bool not supported by '{}'", self.name());
|
||||
}
|
||||
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>);
|
||||
|
||||
// --- Optional device pointer support (GPU backends) --------------------
|
||||
|
||||
fn supports_device_ptrs(&self) -> bool {
|
||||
false
|
||||
}
|
||||
/// # Safety
|
||||
/// Device pointer must be valid and point to at least `n_bytes` bytes.
|
||||
unsafe fn set_device_ptr(&mut self, _node: NodeIndex, _ptr: u64, _n_bytes: usize) {
|
||||
panic!("set_device_ptr not supported by '{}'", self.name());
|
||||
}
|
||||
/// # Safety
|
||||
/// Device pointer must remain valid through the next `execute()` call.
|
||||
unsafe fn set_output_device_ptr(&mut self, _node: NodeIndex, _ptr: u64, _n_bytes: usize) {
|
||||
panic!("set_output_device_ptr not supported by '{}'", self.name());
|
||||
}
|
||||
fn output_is_zero_copy(&self, _node: NodeIndex) -> bool {
|
||||
false
|
||||
}
|
||||
/// # Safety
|
||||
/// `dest_ptr` must be a valid device allocation with at least `n_bytes`.
|
||||
unsafe fn copy_output_to_device_ptr(&self, _node: NodeIndex, _dest_ptr: u64, _n_bytes: usize) {
|
||||
panic!(
|
||||
"copy_output_to_device_ptr not supported by '{}'",
|
||||
self.name()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BackendCompileArgs + BackendFactory + Registry
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Arguments passed to a backend factory during compilation.
|
||||
pub struct BackendCompileArgs {
|
||||
pub search_iters: usize,
|
||||
pub weights: Vec<(String, Vec<u8>, DType)>,
|
||||
pub tensor_sizes: HashMap<String, usize>,
|
||||
pub device_ptrs: HashMap<String, (u64, usize)>,
|
||||
}
|
||||
|
||||
/// Canonical PyCapsule name for [`BackendFactory`] function-pointer capsules.
|
||||
///
|
||||
/// Value MUST remain `"luminal.backend_factory"` for compatibility with
|
||||
/// external plugin producers built against older versions of this crate.
|
||||
pub const BACKEND_FACTORY_CAPSULE_NAME: &std::ffi::CStr = c"luminal.backend_factory";
|
||||
|
||||
/// A factory function that compiles a [`Graph`] into a ready-to-execute [`DynBackend`].
|
||||
pub type BackendFactory = fn(&mut Graph, BackendCompileArgs) -> Result<Box<dyn DynBackend>, String>;
|
||||
|
||||
/// Compile a graph using a factory function directly.
|
||||
pub fn compile_backend_from_factory(
|
||||
factory: BackendFactory,
|
||||
graph: &mut Graph,
|
||||
args: BackendCompileArgs,
|
||||
) -> Result<Box<dyn DynBackend>, String> {
|
||||
factory(graph, args)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// compile_backend — generic compilation helper
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Optional callback for uploading a device pointer + byte count to a node.
|
||||
pub type SetDevicePtrFn<'a, Rt> = &'a dyn Fn(&mut Rt, NodeIndex, u64, usize);
|
||||
|
||||
/// Generic compilation pipeline shared by all backends.
|
||||
///
|
||||
/// Handles: build search space → init runtime → set device ptrs → set dummy
|
||||
/// data → search → load weights → wrap as `Box<dyn DynBackend>`.
|
||||
///
|
||||
/// Backend-specific behavior is injected via callbacks:
|
||||
/// - `init`: create the concrete runtime
|
||||
/// - `set_raw`: upload raw bytes + dtype to a node
|
||||
/// - `set_device_ptr`: optional zero-copy device pointer setter
|
||||
/// - `wrap`: wrap the final runtime in a `Box<dyn DynBackend>`
|
||||
pub fn compile_backend<Rt: Runtime + 'static>(
|
||||
graph: &mut Graph,
|
||||
args: BackendCompileArgs,
|
||||
init: impl FnOnce() -> Result<Rt, String>,
|
||||
set_raw: impl Fn(&mut Rt, NodeIndex, Vec<u8>, DType),
|
||||
set_device_ptr: Option<SetDevicePtrFn<'_, Rt>>,
|
||||
wrap: impl FnOnce(Rt) -> Box<dyn DynBackend>,
|
||||
) -> Result<Box<dyn DynBackend>, String> {
|
||||
// Build label map from input_meta (plain data — no downcast needed,
|
||||
// survives cross-binary type identity mismatches with external plugins).
|
||||
let label_map = build_label_map(graph);
|
||||
|
||||
graph.build_search_space::<Rt>();
|
||||
|
||||
let mut rt = init()?;
|
||||
|
||||
// Set device pointers for zero-copy weights (GPU backends)
|
||||
let mut device_ptr_nodes = rustc_hash::FxHashSet::default();
|
||||
if let Some(set_ptr) = set_device_ptr {
|
||||
for (label, &(ptr, n_bytes)) in &args.device_ptrs {
|
||||
if let Some(&node_id) = label_map.get(label) {
|
||||
set_ptr(&mut rt, node_id, ptr, n_bytes);
|
||||
device_ptr_nodes.insert(node_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set dummy ones for Input nodes (required for search profiling).
|
||||
// Must use 1, NOT 0 — zero inputs cause NaN in many ops.
|
||||
for (&node_id, (label, dtype)) in &graph.input_meta {
|
||||
if device_ptr_nodes.contains(&node_id) {
|
||||
continue;
|
||||
}
|
||||
if let Some(&n) = args.tensor_sizes.get(label) {
|
||||
if n > 0 {
|
||||
set_raw(&mut rt, node_id, make_ones_bytes(n, *dtype), *dtype);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Search
|
||||
let mut rt = graph.search(rt, args.search_iters);
|
||||
|
||||
// Rebuild label map after search (graph may have changed)
|
||||
let label_map = build_label_map(graph);
|
||||
|
||||
// Load real weights post-search (skip device-ptr weights)
|
||||
for (label, bytes, dtype) in &args.weights {
|
||||
if !args.device_ptrs.contains_key(label) {
|
||||
if let Some(&node_id) = label_map.get(label) {
|
||||
set_raw(&mut rt, node_id, bytes.clone(), *dtype);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(wrap(rt))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared utilities
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Build a `label → NodeIndex` map for all Input nodes in the graph.
|
||||
///
|
||||
/// Uses `graph.input_meta` (plain data) rather than downcasting, so it works
|
||||
/// correctly when the graph was built by a different compilation unit (e.g.
|
||||
/// an external backend plugin compiled as a separate wheel).
|
||||
pub fn build_label_map(graph: &Graph) -> HashMap<String, NodeIndex> {
|
||||
graph
|
||||
.input_meta
|
||||
.iter()
|
||||
.map(|(&node_id, (label, _))| (label.clone(), node_id))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Create a byte buffer of `n_elements` ones for the given dtype.
|
||||
///
|
||||
/// IMPORTANT: Must use 1, NOT 0 — zero inputs cause NaN in many ops
|
||||
/// (fmod, recip, log, etc.) during search profiling.
|
||||
pub fn make_ones_bytes(n_elements: usize, dtype: DType) -> Vec<u8> {
|
||||
// Safety: all source types have defined bit representations; we just
|
||||
// reinterpret the backing Vec<u8> without changing the allocation.
|
||||
unsafe fn as_bytes<T>(v: Vec<T>) -> Vec<u8> {
|
||||
let mut v = std::mem::ManuallyDrop::new(v);
|
||||
let ptr = v.as_mut_ptr() as *mut u8;
|
||||
let len = v.len() * std::mem::size_of::<T>();
|
||||
unsafe { Vec::from_raw_parts(ptr, len, len) }
|
||||
}
|
||||
match dtype {
|
||||
DType::F32 | DType::TF32 => unsafe { as_bytes(vec![1.0f32; n_elements]) },
|
||||
DType::F64 => unsafe { as_bytes(vec![1.0f64; n_elements]) },
|
||||
DType::F16 => unsafe { as_bytes(vec![f16::from_f32(1.0); n_elements]) },
|
||||
DType::Bf16 => unsafe { as_bytes(vec![bf16::from_f32(1.0); n_elements]) },
|
||||
DType::Int => unsafe { as_bytes(vec![1i32; n_elements]) },
|
||||
DType::I16 => unsafe { as_bytes(vec![1i16; n_elements]) },
|
||||
DType::U16 => unsafe { as_bytes(vec![1u16; n_elements]) },
|
||||
_ => vec![1u8; n_elements], // I8, U8, Bool, sub-byte types
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert raw bytes + [`DType`] to [`NativeData`].
|
||||
pub fn bytes_to_native_data(bytes: Vec<u8>, dtype: DType) -> NativeData {
|
||||
// Safety: source bytes are from a valid typed buffer; we reinterpret.
|
||||
unsafe fn from_bytes<T: Copy>(bytes: Vec<u8>) -> Vec<T> {
|
||||
let n = bytes.len() / std::mem::size_of::<T>();
|
||||
let mut bytes = std::mem::ManuallyDrop::new(bytes);
|
||||
unsafe { Vec::from_raw_parts(bytes.as_mut_ptr() as *mut T, n, n) }
|
||||
}
|
||||
match dtype {
|
||||
DType::F32 | DType::TF32 => NativeData::F32(unsafe { from_bytes(bytes) }),
|
||||
DType::F64 => {
|
||||
let f64s: Vec<f64> = unsafe { from_bytes(bytes) };
|
||||
NativeData::F32(f64s.into_iter().map(|v| v as f32).collect())
|
||||
}
|
||||
DType::F16 => NativeData::F16(unsafe { from_bytes(bytes) }),
|
||||
DType::Bf16 => NativeData::Bf16(unsafe { from_bytes(bytes) }),
|
||||
DType::Int => NativeData::Int(unsafe { from_bytes(bytes) }),
|
||||
DType::Bool => NativeData::Bool(bytes.into_iter().map(|b| b != 0).collect()),
|
||||
DType::I8 => NativeData::Int(bytes.iter().map(|&b| b as i8 as i32).collect()),
|
||||
DType::U8 => NativeData::Int(bytes.iter().map(|&b| b as i32).collect()),
|
||||
DType::I16 => {
|
||||
let i16s: Vec<i16> = unsafe { from_bytes(bytes) };
|
||||
NativeData::Int(i16s.into_iter().map(|v| v as i32).collect())
|
||||
}
|
||||
DType::U16 => {
|
||||
let u16s: Vec<u16> = unsafe { from_bytes(bytes) };
|
||||
NativeData::Int(u16s.into_iter().map(|v| v as i32).collect())
|
||||
}
|
||||
_ => NativeData::F32(vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// NativeDynBackend
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// [`DynBackend`] wrapper for the native (CPU) runtime.
|
||||
pub struct NativeDynBackend {
|
||||
pub runtime: NativeRuntime,
|
||||
}
|
||||
|
||||
impl DynBackend for NativeDynBackend {
|
||||
fn name(&self) -> &str {
|
||||
"native"
|
||||
}
|
||||
|
||||
fn set_data_bytes(&mut self, node: NodeIndex, bytes: Vec<u8>, dtype: DType) {
|
||||
self.runtime
|
||||
.set_data(node, bytes_to_native_data(bytes, dtype));
|
||||
}
|
||||
|
||||
fn set_data_f32(&mut self, node: NodeIndex, data: Vec<f32>) {
|
||||
self.runtime.set_data(node, data);
|
||||
}
|
||||
|
||||
fn get_output_f32(&self, node: NodeIndex) -> Vec<f32> {
|
||||
let data = self.output_buffer(node);
|
||||
(0..data.len()).map(|i| data.f32(i)).collect()
|
||||
}
|
||||
|
||||
fn get_output_i32(&self, node: NodeIndex) -> Vec<i32> {
|
||||
let data = self.output_buffer(node);
|
||||
(0..data.len()).map(|i| data.i32(i)).collect()
|
||||
}
|
||||
|
||||
fn get_output_bool(&self, node: NodeIndex) -> Vec<bool> {
|
||||
let data = self.output_buffer(node);
|
||||
(0..data.len()).map(|i| data.bool(i)).collect()
|
||||
}
|
||||
|
||||
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
self.runtime.execute(dyn_map);
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeDynBackend {
|
||||
fn output_buffer(&self, node: NodeIndex) -> &NativeData {
|
||||
let output_id = self
|
||||
.runtime
|
||||
.graph
|
||||
.node_indices()
|
||||
.find(|n| {
|
||||
(**self.runtime.graph[*n])
|
||||
.as_any()
|
||||
.downcast_ref::<Output>()
|
||||
.is_some_and(|out| out.node == node.index())
|
||||
})
|
||||
.unwrap_or_else(|| panic!("No output node found for {:?}", node));
|
||||
self.runtime
|
||||
.buffers
|
||||
.get(&output_id)
|
||||
.unwrap_or_else(|| panic!("No buffer data for output {:?}", node))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn native_factory(
|
||||
graph: &mut Graph,
|
||||
args: BackendCompileArgs,
|
||||
) -> Result<Box<dyn DynBackend>, String> {
|
||||
compile_backend::<NativeRuntime>(
|
||||
graph,
|
||||
args,
|
||||
|| Ok(NativeRuntime::default()),
|
||||
// NativeRuntime::set_data requires the LLIR graph to be loaded (it searches
|
||||
// for Input nodes in the LLIR). Before search, the LLIR is empty. We guard
|
||||
// against that: if rt.graph is empty, skip (dummy data isn't needed for
|
||||
// native since its profile is a no-op).
|
||||
|rt, node, bytes, dtype| {
|
||||
if rt.graph.node_count() > 0 {
|
||||
rt.set_data(node, bytes_to_native_data(bytes, dtype));
|
||||
}
|
||||
},
|
||||
None,
|
||||
|rt| Box::new(NativeDynBackend { runtime: rt }),
|
||||
)
|
||||
}
|
||||
@@ -232,6 +232,8 @@ pub struct BaseSorts {
|
||||
pub bf16_dt: SortDef,
|
||||
pub int_dt: SortDef,
|
||||
pub bool_dt: SortDef,
|
||||
pub i4_dt: SortDef,
|
||||
pub tf32_dt: SortDef,
|
||||
// Egglog builtin primitives (for term construction only)
|
||||
pub p_add: SortDef,
|
||||
pub p_sub: SortDef,
|
||||
@@ -310,6 +312,8 @@ impl BaseSorts {
|
||||
bf16_dt: sort(DTYPE, "Bf16", &[]),
|
||||
int_dt: sort(DTYPE, "Int", &[]),
|
||||
bool_dt: sort(DTYPE, "Bool", &[]),
|
||||
i4_dt: sort(DTYPE, "I4", &[]),
|
||||
tf32_dt: sort(DTYPE, "TF32", &[]),
|
||||
p_add: func("+", &["a", "b"]),
|
||||
p_sub: func("-", &["a", "b"]),
|
||||
p_mul: func("*", &["a", "b"]),
|
||||
@@ -363,6 +367,8 @@ impl BaseSorts {
|
||||
&self.bf16_dt,
|
||||
&self.int_dt,
|
||||
&self.bool_dt,
|
||||
&self.i4_dt,
|
||||
&self.tf32_dt,
|
||||
] {
|
||||
p.add_sort(s);
|
||||
}
|
||||
@@ -436,6 +442,7 @@ pub fn base_expression_egglog() -> String {
|
||||
|
||||
// Rulesets
|
||||
p.add_ruleset("expr");
|
||||
p.add_ruleset("dtype_prop");
|
||||
p.add_ruleset("cleanup");
|
||||
p.add_ruleset("early");
|
||||
|
||||
|
||||
@@ -6,15 +6,16 @@ use rand::Rng;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::{str, sync::Arc};
|
||||
use std::{str, sync::Arc, time::Duration};
|
||||
use tracing::trace;
|
||||
|
||||
pub mod api;
|
||||
pub mod base;
|
||||
|
||||
pub const RUN_SCHEDULE: &str = "(run-schedule
|
||||
(repeat 100
|
||||
(repeat 10
|
||||
(saturate expr)
|
||||
(saturate dtype_prop)
|
||||
(run)
|
||||
)
|
||||
(saturate expr)
|
||||
@@ -111,24 +112,64 @@ pub fn early_egglog(
|
||||
ops: &[Arc<Box<dyn EgglogOp>>],
|
||||
cleanup: bool,
|
||||
) -> String {
|
||||
let parts = OpTextParts::new(ops, cleanup);
|
||||
early_egglog_with(program, root, &parts)
|
||||
}
|
||||
|
||||
pub fn full_egglog(program: &str, ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool) -> String {
|
||||
let parts = OpTextParts::new(ops, cleanup);
|
||||
full_egglog_with(program, &parts)
|
||||
}
|
||||
|
||||
/// Pre-computed per-op text fragments. `run_egglog` calls early + full back
|
||||
/// to back with identical `ops`; materialising all op-derived strings once
|
||||
/// up front means callers that want to drive multiple egglog runs in parallel
|
||||
/// only need to share `&str` references and never touch the non-Send trait
|
||||
/// objects in `ops`.
|
||||
pub struct OpTextParts {
|
||||
op_defs: String,
|
||||
cleanups: String,
|
||||
early_rewrites: String,
|
||||
full_rewrites: String,
|
||||
}
|
||||
|
||||
impl OpTextParts {
|
||||
pub fn new(ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool) -> Self {
|
||||
Self {
|
||||
op_defs: op_defs_string(ops),
|
||||
cleanups: if cleanup {
|
||||
op_cleanups_string(ops)
|
||||
} else {
|
||||
String::new()
|
||||
},
|
||||
early_rewrites: ops
|
||||
.iter()
|
||||
.flat_map(|o| o.early_rewrites())
|
||||
.map(|r| r.to_egglog_string())
|
||||
.join("\n"),
|
||||
full_rewrites: ops
|
||||
.iter()
|
||||
.flat_map(|o| o.rewrites())
|
||||
.map(|r| r.to_egglog_string())
|
||||
.join("\n"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn early_egglog_with(program: &str, root: &str, parts: &OpTextParts) -> String {
|
||||
[
|
||||
base::base_expression_egglog(),
|
||||
op_defs_string(ops),
|
||||
ops.iter()
|
||||
.flat_map(|o| o.early_rewrites())
|
||||
.map(|r| r.to_egglog_string())
|
||||
.join("\n"),
|
||||
if cleanup {
|
||||
op_cleanups_string(ops)
|
||||
} else {
|
||||
"".to_string()
|
||||
},
|
||||
parts.op_defs.clone(),
|
||||
parts.early_rewrites.clone(),
|
||||
parts.cleanups.clone(),
|
||||
base::base_cleanup_egglog(),
|
||||
program.to_string(),
|
||||
format!(
|
||||
"(run-schedule
|
||||
(saturate expr)
|
||||
(run)
|
||||
(repeat 6
|
||||
(saturate expr)
|
||||
(run)
|
||||
)
|
||||
(saturate base_cleanup)
|
||||
)
|
||||
(extract {root})"
|
||||
@@ -137,20 +178,13 @@ pub fn early_egglog(
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
pub fn full_egglog(program: &str, ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool) -> String {
|
||||
fn full_egglog_with(program: &str, parts: &OpTextParts) -> String {
|
||||
[
|
||||
base::base_expression_egglog(),
|
||||
op_defs_string(ops),
|
||||
if cleanup {
|
||||
op_cleanups_string(ops)
|
||||
} else {
|
||||
"".to_string()
|
||||
},
|
||||
parts.op_defs.clone(),
|
||||
parts.cleanups.clone(),
|
||||
base::base_cleanup_egglog(),
|
||||
ops.iter()
|
||||
.flat_map(|o| o.rewrites())
|
||||
.map(|r| r.to_egglog_string())
|
||||
.join("\n"),
|
||||
parts.full_rewrites.clone(),
|
||||
program.to_string(),
|
||||
RUN_SCHEDULE.to_string(),
|
||||
]
|
||||
@@ -159,8 +193,7 @@ pub fn full_egglog(program: &str, ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool)
|
||||
|
||||
use crate::{
|
||||
dtype::DType,
|
||||
graph::{Graph, LLIRGraph, SubgraphDescriptor},
|
||||
hlir::{Input, Output},
|
||||
graph::{Graph, LLIRGraph},
|
||||
op::{CustomOp, EgglogOp},
|
||||
prelude::FxHashMap,
|
||||
shape::Expression,
|
||||
@@ -178,6 +211,20 @@ pub struct SerializedEGraph {
|
||||
pub roots: Vec<ClassId>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct EgglogStageReport {
|
||||
pub num_matches_per_rule: FxHashMap<String, usize>,
|
||||
pub search_and_apply_time_per_rule: FxHashMap<String, Duration>,
|
||||
pub total_time: Duration,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct EgglogRunReport {
|
||||
pub early: EgglogStageReport,
|
||||
pub full: EgglogStageReport,
|
||||
pub total_time: Duration,
|
||||
}
|
||||
|
||||
impl SerializedEGraph {
|
||||
/// This is an opinionated function which does more than strictly take the state of the egglog object.
|
||||
/// It also filters out "[...]" nodes and then changes the structure from the e-termDAG that egraph-serialize
|
||||
@@ -319,11 +366,17 @@ pub fn hash_egglog_normalized(text: &str) -> u64 {
|
||||
for line in text.lines() {
|
||||
if line.contains("(Input ") {
|
||||
// Format: (let tN (Input NODE "LABEL" (DTYPE)))
|
||||
// Strip the node index and label, keep only the dtype.
|
||||
// Strip the node index and label identity, but preserve whether this
|
||||
// is a synthetic boundary input or a real graph input.
|
||||
// The dtype is the last parenthesized token, e.g. "(F32)".
|
||||
if let Some(dtype_start) = line.rfind(" (") {
|
||||
let dtype = &line[dtype_start + 1..];
|
||||
("INPUT", dtype).hash(&mut hasher);
|
||||
let kind = if line.contains("\"boundary\"") {
|
||||
"BOUNDARY_INPUT"
|
||||
} else {
|
||||
"REAL_INPUT"
|
||||
};
|
||||
(kind, dtype).hash(&mut hasher);
|
||||
} else {
|
||||
line.hash(&mut hasher);
|
||||
}
|
||||
@@ -390,8 +443,10 @@ pub fn hlir_to_egglog(graph: &Graph) -> (String, String) {
|
||||
|
||||
// 2. Map <node-id> → <egglog var name>
|
||||
let mut names: HashMap<NodeIndex, String> = HashMap::new();
|
||||
let mut out = String::new();
|
||||
// Pre-size output to avoid growth reallocations; ops emit ~100-200 chars each.
|
||||
let mut out = String::with_capacity(topo_order.len() * 160);
|
||||
|
||||
use std::fmt::Write;
|
||||
let mut curr_id = 0;
|
||||
for n in topo_order {
|
||||
let sources: Vec<(NodeIndex, String)> = graph
|
||||
@@ -400,7 +455,9 @@ pub fn hlir_to_egglog(graph: &Graph) -> (String, String) {
|
||||
.map(|src| (src, names[&src].clone()))
|
||||
.collect_vec();
|
||||
let code = graph[n].to_egglog(&sources);
|
||||
out.push_str(&format!("(let t{curr_id} {code})\n"));
|
||||
// write!() into the existing buffer skips the intermediate String
|
||||
// that format! would otherwise allocate for each node.
|
||||
let _ = writeln!(out, "(let t{curr_id} {code})");
|
||||
names.insert(n, format!("t{curr_id}"));
|
||||
curr_id += 1;
|
||||
}
|
||||
@@ -413,145 +470,12 @@ pub fn hlir_to_egglog(graph: &Graph) -> (String, String) {
|
||||
let mut root = names[0].clone();
|
||||
for node in names.into_iter().skip(1) {
|
||||
curr_id += 1;
|
||||
out.push_str(&format!("(let t{curr_id} (OutputJoin {root} {node}))\n"));
|
||||
let _ = writeln!(out, "(let t{curr_id} (OutputJoin {root} {node}))");
|
||||
root = format!("t{curr_id}");
|
||||
}
|
||||
(out.replace("(MVar \"z\")", "(MIter)"), root)
|
||||
}
|
||||
|
||||
/// Convert a subgraph of the HLIR to egglog, injecting synthetic Input/Output
|
||||
/// nodes at graph break boundaries.
|
||||
pub fn hlir_subgraph_to_egglog(graph: &Graph, subgraph: &SubgraphDescriptor) -> (String, String) {
|
||||
use std::cmp::Reverse;
|
||||
use std::collections::{BinaryHeap, HashMap};
|
||||
|
||||
let mut names: HashMap<NodeIndex, String> = HashMap::new();
|
||||
let mut out = String::new();
|
||||
let mut curr_id = 0;
|
||||
|
||||
// Emit synthetic Input nodes for boundary inputs
|
||||
for boundary in &subgraph.boundary_inputs {
|
||||
let var_name = format!("t{curr_id}");
|
||||
let code = format!(
|
||||
"(Input {} \"boundary\" ({:?}))",
|
||||
boundary.break_node.index(),
|
||||
boundary.dtype
|
||||
);
|
||||
out.push_str(&format!("(let {var_name} {code})\n"));
|
||||
// Map the GraphBreak node to this synthetic Input variable.
|
||||
// When downstream nodes reference the GraphBreak as a source, they'll use this.
|
||||
names.insert(boundary.break_node, var_name);
|
||||
curr_id += 1;
|
||||
}
|
||||
|
||||
// Topo-order only the nodes in this subgraph
|
||||
// Build sub-indeg map restricted to subgraph nodes
|
||||
let mut indeg: HashMap<NodeIndex, usize> = HashMap::new();
|
||||
for &n in &subgraph.nodes {
|
||||
let count = graph
|
||||
.graph
|
||||
.neighbors_directed(n, Direction::Incoming)
|
||||
.filter(|pred| subgraph.nodes.contains(pred))
|
||||
.count();
|
||||
indeg.insert(n, count);
|
||||
}
|
||||
|
||||
let mut ready: BinaryHeap<(Reverse<usize>, NodeIndex)> = BinaryHeap::new();
|
||||
for (&n, &d) in &indeg {
|
||||
if d == 0 {
|
||||
ready.push((Reverse(n.index()), n));
|
||||
}
|
||||
}
|
||||
|
||||
let mut topo_order: Vec<NodeIndex> = Vec::with_capacity(indeg.len());
|
||||
while let Some((_, n)) = ready.pop() {
|
||||
topo_order.push(n);
|
||||
for succ in graph.graph.neighbors_directed(n, Direction::Outgoing) {
|
||||
if let Some(e) = indeg.get_mut(&succ) {
|
||||
*e -= 1;
|
||||
if *e == 0 {
|
||||
ready.push((Reverse(succ.index()), succ));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert each node in topological order to egglog
|
||||
for n in topo_order {
|
||||
let sources: Vec<(NodeIndex, String)> = graph
|
||||
.get_sources(n)
|
||||
.into_iter()
|
||||
.map(|src| {
|
||||
let name = names
|
||||
.get(&src)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| panic!("Missing egglog name for node {:?}", src));
|
||||
(src, name)
|
||||
})
|
||||
.collect_vec();
|
||||
let code = graph.graph[n].to_egglog(&sources);
|
||||
out.push_str(&format!("(let t{curr_id} {code})\n"));
|
||||
names.insert(n, format!("t{curr_id}"));
|
||||
curr_id += 1;
|
||||
}
|
||||
|
||||
// Emit synthetic Output nodes for boundary outputs
|
||||
for &brk in &subgraph.boundary_outputs {
|
||||
// The predecessor of the GraphBreak is the actual producer
|
||||
let pred = graph
|
||||
.graph
|
||||
.neighbors_directed(brk, Direction::Incoming)
|
||||
.next()
|
||||
.expect("GraphBreak must have exactly one input");
|
||||
let pred_name = names.get(&pred).cloned().unwrap_or_else(|| {
|
||||
panic!(
|
||||
"Missing egglog name for boundary output predecessor {:?}",
|
||||
pred
|
||||
)
|
||||
});
|
||||
let code = format!("(Output {} {})", pred_name, brk.index());
|
||||
out.push_str(&format!("(let t{curr_id} {code})\n"));
|
||||
names.insert(brk, format!("t{curr_id}"));
|
||||
curr_id += 1;
|
||||
}
|
||||
|
||||
// Join outputs: real outputs (nodes with no outgoing edges within the subgraph)
|
||||
// plus boundary outputs
|
||||
let mut output_names: Vec<String> = vec![];
|
||||
|
||||
// Boundary outputs
|
||||
for &brk in &subgraph.boundary_outputs {
|
||||
if let Some(name) = names.get(&brk) {
|
||||
output_names.push(name.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Real outputs: only actual Output HLIR ops that exist in this subgraph
|
||||
// (not arbitrary nodes that happen to have no subgraph successors)
|
||||
for &n in &subgraph.nodes {
|
||||
if graph.try_get_op::<Output>(n).is_some() {
|
||||
if let Some(name) = names.get(&n) {
|
||||
output_names.push(name.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if output_names.is_empty() {
|
||||
// Fallback: use the last node added
|
||||
output_names.push(format!("t{}", curr_id - 1));
|
||||
}
|
||||
|
||||
// Join with OutputJoin
|
||||
let mut root = output_names[0].clone();
|
||||
for node in output_names.into_iter().skip(1) {
|
||||
curr_id += 1;
|
||||
out.push_str(&format!("(let t{curr_id} (OutputJoin {root} {node}))\n"));
|
||||
root = format!("t{curr_id}");
|
||||
}
|
||||
|
||||
(out.replace("(MVar \"z\")", "(MIter)"), root)
|
||||
}
|
||||
|
||||
pub fn elist_to_egglog(shape: &[Expression]) -> String {
|
||||
list_to_egglog(
|
||||
&shape.iter().map(|e| e.to_egglog()).collect_vec(),
|
||||
@@ -588,41 +512,34 @@ fn termdag_to_egglog(td: &egglog::TermDag, root: egglog::TermId) -> (String, Str
|
||||
(out.replace("(MVar \"z\")", "(MIter)"), format!("t{root}"))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn run_egglog(
|
||||
program: &str,
|
||||
root: &str,
|
||||
ops: &[Arc<Box<dyn EgglogOp>>],
|
||||
cleanup: bool,
|
||||
) -> Result<SerializedEGraph, egglog::Error> {
|
||||
let start = std::time::Instant::now();
|
||||
let code = early_egglog(program, root, ops, cleanup);
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let commands = egraph.parser.get_program_from_string(None, &code)?;
|
||||
let outputs = egraph.run_program(commands)?;
|
||||
let CommandOutput::ExtractBest(termdag, _cost, term) = outputs.last().unwrap() else {
|
||||
panic!();
|
||||
};
|
||||
let (program, root) = termdag_to_egglog(termdag, termdag.lookup(term));
|
||||
let code = full_egglog(&program, ops, cleanup);
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let commands = egraph.parser.get_program_from_string(None, &code)?;
|
||||
trace!("{}", "Egglog running...".green());
|
||||
let _outputs = egraph.run_program(commands)?;
|
||||
trace!("{}", "---- Egglog Rule Matches ----".green());
|
||||
fn stage_report(egraph: &egglog::EGraph, total_time: Duration) -> EgglogStageReport {
|
||||
let run_report = egraph.get_overall_run_report();
|
||||
EgglogStageReport {
|
||||
num_matches_per_rule: run_report
|
||||
.num_matches_per_rule
|
||||
.iter()
|
||||
.map(|(name, matches)| (name.to_string(), *matches))
|
||||
.collect(),
|
||||
search_and_apply_time_per_rule: run_report
|
||||
.search_and_apply_time_per_rule
|
||||
.iter()
|
||||
.map(|(name, elapsed)| (name.to_string(), *elapsed))
|
||||
.collect(),
|
||||
total_time,
|
||||
}
|
||||
}
|
||||
|
||||
fn trace_stage_report(header: &str, report: &EgglogStageReport) {
|
||||
trace!("{}", header.green());
|
||||
trace!(
|
||||
"{}",
|
||||
run_report
|
||||
report
|
||||
.num_matches_per_rule
|
||||
.iter()
|
||||
.filter(|(k, _)| !k.contains("("))
|
||||
.map(|(k, v)| format!(
|
||||
"{k}: {v} ({})",
|
||||
pretty_duration::pretty_duration(
|
||||
&run_report.search_and_apply_time_per_rule[k],
|
||||
None
|
||||
)
|
||||
pretty_duration::pretty_duration(&report.search_and_apply_time_per_rule[k], None)
|
||||
))
|
||||
.join("\n")
|
||||
.green()
|
||||
@@ -630,11 +547,73 @@ pub fn run_egglog(
|
||||
trace!(
|
||||
"{}",
|
||||
format!(
|
||||
"---- Egglog Took {} ----",
|
||||
pretty_duration::pretty_duration(&start.elapsed(), None).bold()
|
||||
"---- {} Took {} ----",
|
||||
header,
|
||||
pretty_duration::pretty_duration(&report.total_time, None).bold()
|
||||
)
|
||||
.green()
|
||||
);
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn run_egglog_with_report(
|
||||
program: &str,
|
||||
root: &str,
|
||||
ops: &[Arc<Box<dyn EgglogOp>>],
|
||||
cleanup: bool,
|
||||
) -> Result<(SerializedEGraph, EgglogRunReport), egglog::Error> {
|
||||
let op_parts = OpTextParts::new(ops, cleanup);
|
||||
run_egglog_with_report_parts(program, root, &op_parts)
|
||||
}
|
||||
|
||||
/// Same as [`run_egglog_with_report`], but takes pre-computed [`OpTextParts`].
|
||||
/// Useful when a caller runs many egglog invocations with the same op set
|
||||
/// and wants to factor the op-derived text work out of a parallel loop.
|
||||
/// Takes only `&str` / `&OpTextParts` inputs so the whole function is `Send`.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn run_egglog_with_report_parts(
|
||||
program: &str,
|
||||
root: &str,
|
||||
op_parts: &OpTextParts,
|
||||
) -> Result<(SerializedEGraph, EgglogRunReport), egglog::Error> {
|
||||
let total_start = std::time::Instant::now();
|
||||
|
||||
let early_start = std::time::Instant::now();
|
||||
let code = early_egglog_with(program, root, op_parts);
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let commands = egraph.parser.get_program_from_string(None, &code)?;
|
||||
let outputs = egraph.run_program(commands)?;
|
||||
let early_report = stage_report(&egraph, early_start.elapsed());
|
||||
|
||||
let CommandOutput::ExtractBest(termdag, _cost, term) = outputs.last().unwrap() else {
|
||||
panic!();
|
||||
};
|
||||
let (program, root) = termdag_to_egglog(termdag, termdag.lookup(term));
|
||||
|
||||
let full_start = std::time::Instant::now();
|
||||
let code = full_egglog_with(&program, op_parts);
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let commands = egraph.parser.get_program_from_string(None, &code)?;
|
||||
trace!("{}", "Egglog running...".green());
|
||||
let _outputs = egraph.run_program(commands)?;
|
||||
let full_report = stage_report(&egraph, full_start.elapsed());
|
||||
trace_stage_report("---- Egglog Early Rule Matches ----", &early_report);
|
||||
trace_stage_report("---- Egglog Full Rule Matches ----", &full_report);
|
||||
|
||||
let run_report = EgglogRunReport {
|
||||
early: early_report,
|
||||
full: full_report,
|
||||
total_time: total_start.elapsed(),
|
||||
};
|
||||
trace!(
|
||||
"{}",
|
||||
format!(
|
||||
"---- Egglog Total Took {} ----",
|
||||
pretty_duration::pretty_duration(&run_report.total_time, None).bold()
|
||||
)
|
||||
.green()
|
||||
);
|
||||
|
||||
let (sort, value) = egraph.eval_expr(&var!(root))?;
|
||||
let s = egraph.serialize(egglog::SerializeConfig {
|
||||
root_eclasses: vec![(sort, value)],
|
||||
@@ -719,7 +698,28 @@ pub fn run_egglog(
|
||||
"No valid graphs present in the e-graph!"
|
||||
);
|
||||
|
||||
Ok(egraph)
|
||||
Ok((egraph, run_report))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn run_egglog(
|
||||
program: &str,
|
||||
root: &str,
|
||||
ops: &[Arc<Box<dyn EgglogOp>>],
|
||||
cleanup: bool,
|
||||
) -> Result<SerializedEGraph, egglog::Error> {
|
||||
run_egglog_with_report(program, root, ops, cleanup).map(|(egraph, _)| egraph)
|
||||
}
|
||||
|
||||
/// Same as [`run_egglog`] but takes pre-computed [`OpTextParts`], so the
|
||||
/// whole function is `Send`. Used by the parallel grouped-egraphs build.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn run_egglog_with(
|
||||
program: &str,
|
||||
root: &str,
|
||||
op_parts: &OpTextParts,
|
||||
) -> Result<SerializedEGraph, egglog::Error> {
|
||||
run_egglog_with_report_parts(program, root, op_parts).map(|(egraph, _)| egraph)
|
||||
}
|
||||
|
||||
pub fn extract_expr_list<'a>(
|
||||
@@ -766,6 +766,8 @@ pub fn extract_dtype<'a>(egraph: &'a SerializedEGraph, node: &'a NodeId) -> DTyp
|
||||
"F4E2M1" => DType::F4E2M1,
|
||||
"F8E4M3" => DType::F8E4M3,
|
||||
"F8UE8M0" => DType::F8UE8M0,
|
||||
"I4" => DType::I4,
|
||||
"TF32" => DType::TF32,
|
||||
other => panic!("unknown dtype {other}"),
|
||||
}
|
||||
}
|
||||
@@ -1101,11 +1103,34 @@ pub fn egglog_to_llir<'a>(
|
||||
list_cache: &mut FxHashMap<&'a NodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a NodeId, Expression>,
|
||||
custom_op_id_remap: Option<&FxHashMap<usize, usize>>,
|
||||
) -> LLIRGraph {
|
||||
egglog_to_llir_from_root(
|
||||
egraph,
|
||||
choices,
|
||||
ops,
|
||||
custom_ops,
|
||||
list_cache,
|
||||
expr_cache,
|
||||
custom_op_id_remap,
|
||||
&egraph.roots[0],
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn egglog_to_llir_from_root<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
choices: EGraphChoiceSet<'a>,
|
||||
ops: &'a Vec<Arc<Box<dyn EgglogOp>>>,
|
||||
custom_ops: &[Box<dyn CustomOp>],
|
||||
list_cache: &mut FxHashMap<&'a NodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a NodeId, Expression>,
|
||||
custom_op_id_remap: Option<&FxHashMap<usize, usize>>,
|
||||
root_class: &ClassId,
|
||||
) -> LLIRGraph {
|
||||
// Make reachability set from root
|
||||
let mut reachable = FxHashSet::default();
|
||||
reachable.insert(choices[&egraph.roots[0]]);
|
||||
let mut reachability_stack = vec![choices[&egraph.roots[0]]];
|
||||
reachable.insert(choices[root_class]);
|
||||
let mut reachability_stack = vec![choices[root_class]];
|
||||
while let Some(r) = reachability_stack.pop() {
|
||||
for ch in &egraph.enodes[r].1 {
|
||||
if egraph.eclasses[ch].0.contains("IR") || egraph.eclasses[ch].0.contains("IList") {
|
||||
@@ -1226,135 +1251,10 @@ pub fn egglog_to_llir<'a>(
|
||||
// )
|
||||
// .unwrap();
|
||||
// }
|
||||
// Loop markers (LoopStart/End/Input/InputStatic/Output) are intentionally
|
||||
// preserved here — `crate::graph::collapse_loops_to_first_iter` produces
|
||||
// a single-iteration LLIR for fast per-candidate profiling, and the full
|
||||
// `crate::graph::unroll_loops_in_llir` runs once on the chosen best LLIR
|
||||
// before it is loaded into the runtime.
|
||||
graph
|
||||
}
|
||||
|
||||
/// Merge multiple per-chunk LLIR graphs into a single LLIR graph,
|
||||
/// resolving boundary Input/Output nodes at graph break boundaries.
|
||||
pub fn stitch_llir_graphs(
|
||||
chunk_llirs: &[LLIRGraph],
|
||||
descriptors: &[SubgraphDescriptor],
|
||||
) -> LLIRGraph {
|
||||
use petgraph::stable_graph::NodeIndex;
|
||||
|
||||
let mut merged = LLIRGraph::default();
|
||||
|
||||
// Collect the set of boundary break_node indices for matching
|
||||
let mut boundary_output_set: FxHashSet<usize> = FxHashSet::default();
|
||||
let mut boundary_input_set: FxHashSet<usize> = FxHashSet::default();
|
||||
for desc in descriptors {
|
||||
for brk in &desc.boundary_outputs {
|
||||
boundary_output_set.insert(brk.index());
|
||||
}
|
||||
for bi in &desc.boundary_inputs {
|
||||
boundary_input_set.insert(bi.break_node.index());
|
||||
}
|
||||
}
|
||||
|
||||
// Per-chunk node mapping: old NodeIndex -> new NodeIndex in merged graph
|
||||
let mut node_maps: Vec<FxHashMap<NodeIndex, NodeIndex>> = Vec::with_capacity(chunk_llirs.len());
|
||||
|
||||
// Track boundary producers: break_node_index -> new NodeIndex of the actual producer
|
||||
let mut boundary_producers: FxHashMap<usize, NodeIndex> = FxHashMap::default();
|
||||
|
||||
// Track real Input node deduplication: Input.node -> new NodeIndex
|
||||
let mut real_inputs: FxHashMap<usize, NodeIndex> = FxHashMap::default();
|
||||
|
||||
for (_chunk_idx, chunk_graph) in chunk_llirs.iter().enumerate() {
|
||||
let mut this_map: FxHashMap<NodeIndex, NodeIndex> = FxHashMap::default();
|
||||
|
||||
// Pass 1: Add all non-boundary nodes
|
||||
for old_node in chunk_graph.node_indices() {
|
||||
let op = &chunk_graph[old_node];
|
||||
|
||||
// Check if this is a boundary Output
|
||||
if let Some(output_op) = op.to_op::<Output>() {
|
||||
if boundary_output_set.contains(&output_op.node) {
|
||||
// Skip — will resolve in pass 2
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this is a boundary Input
|
||||
if let Some(input_op) = op.to_op::<Input>() {
|
||||
if boundary_input_set.contains(&input_op.node) {
|
||||
// Skip — will resolve in pass 2
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if this is a real Input that was already added (dedup)
|
||||
if let Some(&existing) = real_inputs.get(&input_op.node) {
|
||||
this_map.insert(old_node, existing);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let new_node = merged.add_node(op.clone());
|
||||
this_map.insert(old_node, new_node);
|
||||
|
||||
// Track real inputs for deduplication
|
||||
if let Some(input_op) = op.to_op::<Input>() {
|
||||
real_inputs.insert(input_op.node, new_node);
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 2: Resolve boundary Output nodes (record the producer)
|
||||
for old_node in chunk_graph.node_indices() {
|
||||
let op = &chunk_graph[old_node];
|
||||
if let Some(output_op) = op.to_op::<Output>() {
|
||||
if boundary_output_set.contains(&output_op.node) {
|
||||
// Find the predecessor (the actual producer)
|
||||
let pred = chunk_graph
|
||||
.neighbors_directed(old_node, petgraph::Direction::Incoming)
|
||||
.next()
|
||||
.expect("Boundary Output must have exactly one input");
|
||||
if let Some(&producer_new) = this_map.get(&pred) {
|
||||
boundary_producers.insert(output_op.node, producer_new);
|
||||
} else {
|
||||
eprintln!(
|
||||
"[stitch] WARNING: chunk {}: boundary Output node={} predecessor {:?} not in this_map!",
|
||||
_chunk_idx,
|
||||
output_op.node,
|
||||
pred.index()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 2b: Resolve boundary Input nodes (map to producer from prior chunk)
|
||||
for old_node in chunk_graph.node_indices() {
|
||||
let op = &chunk_graph[old_node];
|
||||
if let Some(input_op) = op.to_op::<Input>() {
|
||||
if boundary_input_set.contains(&input_op.node) {
|
||||
if let Some(&producer) = boundary_producers.get(&input_op.node) {
|
||||
this_map.insert(old_node, producer);
|
||||
} else {
|
||||
eprintln!(
|
||||
"[stitch] WARNING: chunk {}: boundary Input node={} has no producer in boundary_producers!",
|
||||
_chunk_idx, input_op.node
|
||||
);
|
||||
eprintln!(
|
||||
"[stitch] available producers: {:?}",
|
||||
boundary_producers.keys().collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 3: Add edges (preserving duplicate edges for ops like x*x)
|
||||
for edge in chunk_graph.edge_indices() {
|
||||
let (src, dst) = chunk_graph.edge_endpoints(edge).unwrap();
|
||||
if let (Some(&new_src), Some(&new_dst)) = (this_map.get(&src), this_map.get(&dst)) {
|
||||
if new_src != new_dst {
|
||||
merged.add_edge(new_src, new_dst, ());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
node_maps.push(this_map);
|
||||
}
|
||||
|
||||
merged
|
||||
}
|
||||
|
||||
@@ -105,6 +105,9 @@ impl GraphTensor {
|
||||
if let Some(gmem) = self.graph().try_get_op_mut::<Input>(self.id) {
|
||||
gmem.dtype = dtype;
|
||||
}
|
||||
if let Some((_, d)) = self.graph().input_meta.get_mut(&self.id) {
|
||||
*d = dtype;
|
||||
}
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,15 +57,35 @@ impl GraphTensor {
|
||||
self.graph().get_op_mut::<Input>(self.id).label = name.to_string();
|
||||
}
|
||||
|
||||
/// Mark this tensor as an output
|
||||
/// Mark this tensor as an output.
|
||||
/// If the tensor has non-contiguous strides (e.g. from transpose + merge_dims),
|
||||
/// inserts a gather to materialize contiguous data before the output node.
|
||||
pub fn output(&self) -> GraphTensor {
|
||||
let source = if self.shape.is_contiguous() {
|
||||
*self
|
||||
} else {
|
||||
// Insert gather to make physically contiguous
|
||||
let dims = self.dims();
|
||||
let total = dims.iter().copied().reduce(|a, b| a * b).unwrap();
|
||||
let idx_expr = self.shape.index_expression();
|
||||
let idx = self.graph().iota(idx_expr, total);
|
||||
let mut gathered = self.gather(idx);
|
||||
gathered.shape = ShapeTracker::new(dims);
|
||||
gathered
|
||||
};
|
||||
self.output_raw(source)
|
||||
}
|
||||
|
||||
/// Mark a tensor as an output without any contiguous materialization.
|
||||
/// Used internally by graph_break and persist.
|
||||
fn output_raw(&self, source: GraphTensor) -> GraphTensor {
|
||||
self.graph().add_op(
|
||||
Output {
|
||||
node: self.id.index(),
|
||||
node: source.id.index(),
|
||||
},
|
||||
&[self.id],
|
||||
&[source.id],
|
||||
);
|
||||
*self
|
||||
source
|
||||
}
|
||||
|
||||
/// Required bytes to store this tensor's physical elements. Rounds up to nearest byte.
|
||||
@@ -77,7 +97,7 @@ impl GraphTensor {
|
||||
/// so the buffer is not consumed after execute(), but returns the original
|
||||
/// Input node's GraphTensor (not the Output node).
|
||||
pub fn persist(&self) -> GraphTensor {
|
||||
self.output();
|
||||
self.output_raw(*self);
|
||||
*self
|
||||
}
|
||||
|
||||
|
||||
@@ -152,16 +152,6 @@ impl GraphTensor {
|
||||
GraphTensor::from_id(new_id, self.shape.contiguous(), self.graph_ref, self.dtype)
|
||||
}
|
||||
|
||||
pub fn graph_break(self) -> GraphTensor {
|
||||
let new_id = self.graph().add_op(
|
||||
crate::hlir::GraphBreak {
|
||||
input_shape: self.shape,
|
||||
},
|
||||
&[self.id],
|
||||
);
|
||||
GraphTensor::from_id(new_id, self.shape.contiguous(), self.graph_ref, self.dtype)
|
||||
}
|
||||
|
||||
/// Scale so std is 1.0
|
||||
pub fn std_norm<T>(self, axes: impl ToAxes, epsilon: T) -> GraphTensor
|
||||
where
|
||||
@@ -663,7 +653,7 @@ pub(super) mod tests {
|
||||
let mut out: Vec<(NotNan<f32>, usize)> =
|
||||
heap.into_iter().map(|std::cmp::Reverse(t)| t).collect();
|
||||
|
||||
out.sort_unstable_by(|a, b| b.0.cmp(&a.0));
|
||||
out.sort_unstable_by_key(|entry| std::cmp::Reverse(entry.0));
|
||||
out.into_iter().map(|(_, i)| i).collect()
|
||||
}
|
||||
test_unary(
|
||||
|
||||
2759
src/graph.rs
2759
src/graph.rs
File diff suppressed because it is too large
Load Diff
752
src/hlir.rs
752
src/hlir.rs
@@ -25,6 +25,7 @@ fn dtype_propagation_rule(sort: &SortDef, dtype_source: &str) -> Rule {
|
||||
.fact(eq(e.clone(), op_match))
|
||||
.fact(eq(dty.clone(), dtype(args[dtype_source].clone())))
|
||||
.action(Action::Set(dtype(e), dty))
|
||||
.ruleset("dtype_prop")
|
||||
}
|
||||
|
||||
/// Helper: build a dtype-from-field rule for a direct IR op.
|
||||
@@ -34,6 +35,7 @@ fn dtype_from_field_rule(sort: &SortDef, dtype_field: &str) -> Rule {
|
||||
Rule::new()
|
||||
.fact(eq(e.clone(), op_match))
|
||||
.action(Action::Set(dtype(e), args[dtype_field].clone()))
|
||||
.ruleset("dtype_prop")
|
||||
}
|
||||
|
||||
// --- Dtype helpers for normalized ops (Op OpKind IList) ---
|
||||
@@ -58,6 +60,7 @@ fn dtype_propagation_op(kind_sort: &SortDef) -> Rule {
|
||||
))
|
||||
.fact(eq(dty.clone(), dtype(first_inp)))
|
||||
.action(Action::Set(dtype(e), dty))
|
||||
.ruleset("dtype_prop")
|
||||
}
|
||||
|
||||
/// Dtype from a field on the OpKind (e.g., Cast's dtype field).
|
||||
@@ -68,6 +71,7 @@ fn dtype_from_kind_field(kind_sort: &SortDef, field_name: &str) -> Rule {
|
||||
Rule::new()
|
||||
.fact(eq(e.clone(), op_term(kind_term, inputs)))
|
||||
.action(Action::Set(dtype(e), args[field_name].clone()))
|
||||
.ruleset("dtype_prop")
|
||||
}
|
||||
|
||||
/// Fixed dtype for a normalized op (e.g., Iota always Int).
|
||||
@@ -78,6 +82,7 @@ fn dtype_fixed_op(kind_sort: &SortDef, dtype_sort: &SortDef) -> Rule {
|
||||
Rule::new()
|
||||
.fact(eq(e.clone(), op_term(kind_term, inputs)))
|
||||
.action(Action::Set(dtype(e), dtype_sort.call(())))
|
||||
.ruleset("dtype_prop")
|
||||
}
|
||||
|
||||
/// Build an IList egglog string from input variable names.
|
||||
@@ -114,6 +119,90 @@ pub fn binary_sort(name: &str) -> SortDef {
|
||||
)
|
||||
}
|
||||
|
||||
/// Generate egglog rewrite rules that union a small rolled `body=1, trips=N`
|
||||
/// single-binary-op loop with its fully-unrolled equivalent in the same
|
||||
/// eclass. Both representations coexist; the cost-based extractor picks
|
||||
/// whichever one downstream patterns prefer — the unrolled form when fusions
|
||||
/// (e.g. GLUMoE GemmaGELU, KernelExp's `direct-exp-fusion`) match through
|
||||
/// the flat chain, the rolled form otherwise. Without these unions, rolling
|
||||
/// a tiny chain blocks the fusion entirely and the extracted graph is
|
||||
/// strictly worse than not rolling.
|
||||
///
|
||||
/// **Register in both `EgglogOp::early_rewrites()` AND `rewrites()`.** The
|
||||
/// driver feeds `early_rewrites` into the early-stage program only and
|
||||
/// `rewrites` into the full-stage program only; we need the unrolled chain
|
||||
/// visible in both stages so early-stage fusion patterns (GLUMoE) AND
|
||||
/// full-stage kernel rewrites (`direct-exp-fusion`) can both match it.
|
||||
///
|
||||
/// Generates 2 rules per iter count (state at body input position 0 vs 1)
|
||||
/// for every `n_iters` in `2..=max_trips`. Larger trips stay rolled-only —
|
||||
/// real transformer-block rolls are body ≫ 1 anyway, and carrying both
|
||||
/// forms beyond a small N adds search-time cost without an upside.
|
||||
///
|
||||
/// Each rule matches the rolled shape `LoopEnd(body)` where `body` is the
|
||||
/// binary op consuming `LoopStart(initial)` and `LoopInput(s0..s_{N-1})`,
|
||||
/// and unions `LoopEnd` with the chain
|
||||
/// `u0 = <kind>(initial, s0); u1 = <kind>(u0, s1); … u_{N-1}`.
|
||||
/// (or symmetric for state at position 1.)
|
||||
pub fn binary_op_unroll_rules(op_kind: &str, max_trips: usize) -> Vec<Rule> {
|
||||
let mut rules = Vec::with_capacity((max_trips.saturating_sub(1)) * 2);
|
||||
for n_iters in 2..=max_trips {
|
||||
for state_pos in 0..2 {
|
||||
rules.push(binary_op_unroll_rule(op_kind, n_iters, state_pos));
|
||||
}
|
||||
}
|
||||
rules
|
||||
}
|
||||
|
||||
fn binary_op_unroll_rule(op_kind: &str, n_iters: usize, state_pos: usize) -> Rule {
|
||||
// Swap (state, per_iter) → (input0, input1) by `state_pos`. Both the
|
||||
// body match pattern and the unrolled chain bodies follow this mapping
|
||||
// so a/b stride positions stay aligned.
|
||||
debug_assert!(state_pos < 2);
|
||||
let order = |state: &str, per_iter: &str| -> String {
|
||||
if state_pos == 0 {
|
||||
format!("(ICons {state} (ICons {per_iter} (INil)))")
|
||||
} else {
|
||||
format!("(ICons {per_iter} (ICons {state} (INil)))")
|
||||
}
|
||||
};
|
||||
let li_sources = (0..n_iters).rev().fold(String::from("(INil)"), |acc, i| {
|
||||
format!("(ICons ?s{i} {acc})")
|
||||
});
|
||||
let chain = (0..n_iters)
|
||||
.map(|i| {
|
||||
let prev = if i == 0 {
|
||||
"?initial".to_string()
|
||||
} else {
|
||||
format!("?u{}", i - 1)
|
||||
};
|
||||
format!(
|
||||
" (let ?u{i} (Op ({op_kind} ?sh ?as ?bs ?os) {}))",
|
||||
order(&prev, &format!("?s{i}"))
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
Rule::raw(format!(
|
||||
"(rule
|
||||
(
|
||||
(= ?ls (LoopStart ?initial ?loop_id ?slot_idx (MNum {n_iters}) ?dt))
|
||||
(= ?li (Op (LoopInput ?loop_id ?stream ?dt) {li_sources}))
|
||||
(= ?body (Op ({op_kind} ?sh ?as ?bs ?os) {body_pat}))
|
||||
(= ?le (LoopEnd ?body ?loop_id ?slot_idx ?dt))
|
||||
)
|
||||
(
|
||||
{chain}
|
||||
(union ?le ?u{last})
|
||||
)
|
||||
:ruleset expr
|
||||
:name \"unroll {op_kind} body trips={n_iters} state={state_pos}\"
|
||||
)",
|
||||
body_pat = order("?ls", "?li"),
|
||||
last = n_iters - 1,
|
||||
))
|
||||
}
|
||||
|
||||
/// Reduce op kind: (shape: EList, iters: Expression, strides: EList, iter_stride: Expression, out_strides: EList), IList: [inp]
|
||||
pub fn reduce_sort(name: &str) -> SortDef {
|
||||
sort(
|
||||
@@ -133,6 +222,12 @@ pub type HLIROps = (
|
||||
Input,
|
||||
Output,
|
||||
CustomOpKind,
|
||||
LoopStart,
|
||||
LoopEnd,
|
||||
LoopInput,
|
||||
LoopInputStatic,
|
||||
LoopOutput,
|
||||
LoopOutputSelect,
|
||||
Constant,
|
||||
Cast,
|
||||
Iota,
|
||||
@@ -331,6 +426,607 @@ impl NativeOp for CustomOpKind {
|
||||
}
|
||||
}
|
||||
|
||||
// --- Loop ops ---------------------------------------------------------------
|
||||
//
|
||||
// Automatic loop-rolling replaces N unrolled copies of a repeating body with
|
||||
// a single body plus structural marker ops. All four ops in one loop share a
|
||||
// `loop_id`. `iters` lives on `LoopStart` only; every other op references the
|
||||
// same loop via `loop_id`.
|
||||
//
|
||||
// LoopStart — one per loop-carried slot; takes the initial value, yields
|
||||
// the current iteration's value into the body.
|
||||
// LoopEnd — mirror of LoopStart; takes the body's final value for the
|
||||
// slot, yields the post-loop value.
|
||||
// LoopInput — OpKind (variable-arity). Takes N input tensors (one per
|
||||
// iteration) and yields the current iteration's tensor.
|
||||
// LoopOutput — OpKind (variable-arity, sink). Takes the body's value + N
|
||||
// target tensors; writes body[i] -> target[i] each iteration.
|
||||
//
|
||||
// Execution semantics and iteration driving live in the runtime compilation
|
||||
// step; these ops just carry the structure through HLIR/egglog/LLIR.
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct LoopStart {
|
||||
pub loop_id: usize,
|
||||
pub slot_idx: usize,
|
||||
pub iters: Expression,
|
||||
pub dtype: DType,
|
||||
}
|
||||
|
||||
impl Display for LoopStart {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"LoopStart(id={}, slot={}, iters={:?}, {})",
|
||||
self.loop_id, self.slot_idx, self.iters, self.dtype
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for LoopStart {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
IR,
|
||||
"LoopStart",
|
||||
&[
|
||||
("inp", IR),
|
||||
("loop_id", I64),
|
||||
("slot_idx", I64),
|
||||
("iters", EXPRESSION),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_from_field_rule(&self.sort(), "dtype")]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
_input_enodes: Vec<&'a ENodeId>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let loop_id = egraph.enodes[kind_children[1]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let slot_idx = egraph.enodes[kind_children[2]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let iters = extract_expr(egraph, kind_children[3], expr_cache).unwrap();
|
||||
let dtype = extract_dtype(egraph, kind_children[4]);
|
||||
(
|
||||
LLIROp::new::<LoopStart>(Box::new(Self {
|
||||
loop_id,
|
||||
slot_idx,
|
||||
iters,
|
||||
dtype,
|
||||
})),
|
||||
vec![kind_children[0]],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl HLIROp for LoopStart {
|
||||
fn to_egglog(&self, inp: &[(NodeIndex, String)]) -> String {
|
||||
format!(
|
||||
"(LoopStart {} {} {} {} ({:?}))",
|
||||
inp[0].1,
|
||||
self.loop_id,
|
||||
self.slot_idx,
|
||||
self.iters.to_egglog(),
|
||||
self.dtype,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeOp for LoopStart {
|
||||
fn execute(&self, _: Vec<&NativeData>, _: &FxHashMap<char, usize>) -> NativeData {
|
||||
unimplemented!("LoopStart is driven by the runtime loop compiler")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct LoopEnd {
|
||||
pub loop_id: usize,
|
||||
pub slot_idx: usize,
|
||||
pub dtype: DType,
|
||||
}
|
||||
|
||||
impl Display for LoopEnd {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"LoopEnd(id={}, slot={}, {})",
|
||||
self.loop_id, self.slot_idx, self.dtype
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for LoopEnd {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
IR,
|
||||
"LoopEnd",
|
||||
&[
|
||||
("inp", IR),
|
||||
("loop_id", I64),
|
||||
("slot_idx", I64),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_from_field_rule(&self.sort(), "dtype")]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
_input_enodes: Vec<&'a ENodeId>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let loop_id = egraph.enodes[kind_children[1]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let slot_idx = egraph.enodes[kind_children[2]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let dtype = extract_dtype(egraph, kind_children[3]);
|
||||
(
|
||||
LLIROp::new::<LoopEnd>(Box::new(Self {
|
||||
loop_id,
|
||||
slot_idx,
|
||||
dtype,
|
||||
})),
|
||||
vec![kind_children[0]],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl HLIROp for LoopEnd {
|
||||
fn to_egglog(&self, inp: &[(NodeIndex, String)]) -> String {
|
||||
format!(
|
||||
"(LoopEnd {} {} {} ({:?}))",
|
||||
inp[0].1, self.loop_id, self.slot_idx, self.dtype,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeOp for LoopEnd {
|
||||
fn execute(&self, _: Vec<&NativeData>, _: &FxHashMap<char, usize>) -> NativeData {
|
||||
unimplemented!("LoopEnd is driven by the runtime loop compiler")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct LoopInput {
|
||||
pub loop_id: usize,
|
||||
pub stream_id: usize,
|
||||
pub dtype: DType,
|
||||
}
|
||||
|
||||
impl Display for LoopInput {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"LoopInput(id={}, stream={}, {})",
|
||||
self.loop_id, self.stream_id, self.dtype
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for LoopInput {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"LoopInput",
|
||||
&[("loop_id", I64), ("stream_id", I64), ("dtype", DTYPE)],
|
||||
)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_from_kind_field(&self.sort(), "dtype")]
|
||||
}
|
||||
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
// Declare the `identical_inputs` relation and the three-way unification
|
||||
// chain between `LoopInput`, `LoopInputStatic`, and an inlined source.
|
||||
// Running in Stage 1 alongside fusion rules (e.g. GLUMoE) so that
|
||||
// fusion patterns that expect raw op kinds at boundary positions can
|
||||
// match via the unioned eclass.
|
||||
vec![Rule::raw(
|
||||
r#"
|
||||
(relation identical_inputs (IList))
|
||||
|
||||
; All four rules live in the `expr` ruleset, which the early/full
|
||||
; schedules saturate each iteration. Default-ruleset scheduling
|
||||
; only runs each rule once per outer step, which is not enough to
|
||||
; propagate `identical_inputs` through an N-element IList.
|
||||
|
||||
; Base: single-element list is trivially identical.
|
||||
(rule ((= ?l (ICons ?x (INil))))
|
||||
((identical_inputs ?l))
|
||||
:ruleset expr
|
||||
:name "identical_inputs base")
|
||||
|
||||
; Inductive: head equals next-head, and the tail starting at next-head is identical.
|
||||
(rule ((= ?l (ICons ?x (ICons ?x ?tail)))
|
||||
(identical_inputs (ICons ?x ?tail)))
|
||||
((identical_inputs ?l))
|
||||
:ruleset expr
|
||||
:name "identical_inputs ind")
|
||||
|
||||
; LoopInput with an identical IList is equivalent to LoopInputStatic over a single copy.
|
||||
(rule ((= ?e (Op (LoopInput ?id ?stream ?dt) (ICons ?x ?cont)))
|
||||
(identical_inputs (ICons ?x ?cont)))
|
||||
((let ?static (Op (LoopInputStatic ?id ?stream ?dt) (ICons ?x (INil))))
|
||||
(union ?e ?static))
|
||||
:ruleset expr
|
||||
:name "LoopInput to LoopInputStatic")
|
||||
|
||||
; LoopInputStatic is equivalent to its single inner value — collapses the boundary
|
||||
; wrapper for pattern-matching and extraction purposes.
|
||||
(rule ((= ?e (Op (LoopInputStatic ?id ?stream ?dt) (ICons ?x (INil)))))
|
||||
((union ?e ?x))
|
||||
:ruleset expr
|
||||
:name "LoopInputStatic inline")
|
||||
"#,
|
||||
)]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let loop_id = egraph.enodes[kind_children[0]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let stream_id = egraph.enodes[kind_children[1]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let dtype = extract_dtype(egraph, kind_children[2]);
|
||||
(
|
||||
LLIROp::new::<LoopInput>(Box::new(Self {
|
||||
loop_id,
|
||||
stream_id,
|
||||
dtype,
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl HLIROp for LoopInput {
|
||||
fn to_egglog(&self, inp: &[(NodeIndex, String)]) -> String {
|
||||
format!(
|
||||
"(Op (LoopInput {} {} ({:?})) {})",
|
||||
self.loop_id,
|
||||
self.stream_id,
|
||||
self.dtype,
|
||||
list_to_egglog(&inp.iter().map(|i| &i.1).collect_vec(), "ICons", "INil"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeOp for LoopInput {
|
||||
fn execute(&self, _: Vec<&NativeData>, _: &FxHashMap<char, usize>) -> NativeData {
|
||||
unimplemented!("LoopInput is driven by the runtime loop compiler")
|
||||
}
|
||||
}
|
||||
|
||||
/// Iteration-independent boundary input: the same value flows into every
|
||||
/// iteration of a loop. Structurally a `LoopInput` whose per-iteration
|
||||
/// sources have all been proven equal (via the `identical_inputs` egglog
|
||||
/// relation) collapses into `LoopInputStatic` with a single-element IList,
|
||||
/// and that in turn collapses via a further rewrite into just its inner
|
||||
/// value — so egglog search can explore any of the three representations.
|
||||
/// At unroll time `LoopInputStatic` lowers to a plain edge: every cloned
|
||||
/// body node in every iteration references the single shared source.
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct LoopInputStatic {
|
||||
pub loop_id: usize,
|
||||
pub stream_id: usize,
|
||||
pub dtype: DType,
|
||||
}
|
||||
|
||||
impl Display for LoopInputStatic {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"LoopInputStatic(id={}, stream={}, {})",
|
||||
self.loop_id, self.stream_id, self.dtype
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for LoopInputStatic {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"LoopInputStatic",
|
||||
&[("loop_id", I64), ("stream_id", I64), ("dtype", DTYPE)],
|
||||
)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_from_kind_field(&self.sort(), "dtype")]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let loop_id = egraph.enodes[kind_children[0]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let stream_id = egraph.enodes[kind_children[1]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let dtype = extract_dtype(egraph, kind_children[2]);
|
||||
(
|
||||
LLIROp::new::<LoopInputStatic>(Box::new(Self {
|
||||
loop_id,
|
||||
stream_id,
|
||||
dtype,
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl HLIROp for LoopInputStatic {
|
||||
fn to_egglog(&self, inp: &[(NodeIndex, String)]) -> String {
|
||||
format!(
|
||||
"(Op (LoopInputStatic {} {} ({:?})) {})",
|
||||
self.loop_id,
|
||||
self.stream_id,
|
||||
self.dtype,
|
||||
list_to_egglog(&inp.iter().map(|i| &i.1).collect_vec(), "ICons", "INil"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeOp for LoopInputStatic {
|
||||
fn execute(&self, _: Vec<&NativeData>, _: &FxHashMap<char, usize>) -> NativeData {
|
||||
unimplemented!("LoopInputStatic is driven by the runtime loop compiler")
|
||||
}
|
||||
}
|
||||
|
||||
/// Marker for the per-iter output stream of a rolled loop. Mirrors `LoopInput`
|
||||
/// in reverse: a single body producer (one incoming edge) feeds the marker, and
|
||||
/// `LoopOutputSelect(i)` nodes hang off it to pluck iteration `i`'s value for
|
||||
/// downstream consumers (any post-region op — `Output` HLIR, downstream
|
||||
/// computation, etc.).
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct LoopOutput {
|
||||
pub loop_id: usize,
|
||||
pub stream_id: usize,
|
||||
pub dtype: DType,
|
||||
}
|
||||
|
||||
impl Display for LoopOutput {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"LoopOutput(id={}, stream={}, {})",
|
||||
self.loop_id, self.stream_id, self.dtype
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for LoopOutput {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"LoopOutput",
|
||||
&[("loop_id", I64), ("stream_id", I64), ("dtype", DTYPE)],
|
||||
)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_from_kind_field(&self.sort(), "dtype")]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let loop_id = egraph.enodes[kind_children[0]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let stream_id = egraph.enodes[kind_children[1]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let dtype = extract_dtype(egraph, kind_children[2]);
|
||||
(
|
||||
LLIROp::new::<LoopOutput>(Box::new(Self {
|
||||
loop_id,
|
||||
stream_id,
|
||||
dtype,
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl HLIROp for LoopOutput {
|
||||
fn to_egglog(&self, inp: &[(NodeIndex, String)]) -> String {
|
||||
format!(
|
||||
"(Op (LoopOutput {} {} ({:?})) {})",
|
||||
self.loop_id,
|
||||
self.stream_id,
|
||||
self.dtype,
|
||||
list_to_egglog(&inp.iter().map(|i| &i.1).collect_vec(), "ICons", "INil"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeOp for LoopOutput {
|
||||
fn execute(&self, _: Vec<&NativeData>, _: &FxHashMap<char, usize>) -> NativeData {
|
||||
unimplemented!("LoopOutput is driven by the runtime loop compiler")
|
||||
}
|
||||
}
|
||||
|
||||
/// Per-iteration extractor for a `LoopOutput` stream. Mirrors a per-iter
|
||||
/// `LoopInput` source slot in reverse: every cross-region edge that originally
|
||||
/// went from iteration `i`'s body producer to a post-region consumer is
|
||||
/// rewired through `LoopOutputSelect { iter: i, ... }`. At unroll time
|
||||
/// `Select(i)` lowers to the iter-`i` body clone's producer; at collapse time
|
||||
/// every Select lowers to iter-0's producer.
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct LoopOutputSelect {
|
||||
pub loop_id: usize,
|
||||
pub stream_id: usize,
|
||||
pub iter: usize,
|
||||
pub dtype: DType,
|
||||
}
|
||||
|
||||
impl Display for LoopOutputSelect {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"LoopOutputSelect(id={}, stream={}, iter={}, {})",
|
||||
self.loop_id, self.stream_id, self.iter, self.dtype
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for LoopOutputSelect {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"LoopOutputSelect",
|
||||
&[
|
||||
("loop_id", I64),
|
||||
("stream_id", I64),
|
||||
("iter", I64),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_from_kind_field(&self.sort(), "dtype")]
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
_: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let loop_id = egraph.enodes[kind_children[0]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let stream_id = egraph.enodes[kind_children[1]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let iter = egraph.enodes[kind_children[2]]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse::<usize>()
|
||||
.unwrap();
|
||||
let dtype = extract_dtype(egraph, kind_children[3]);
|
||||
(
|
||||
LLIROp::new::<LoopOutputSelect>(Box::new(Self {
|
||||
loop_id,
|
||||
stream_id,
|
||||
iter,
|
||||
dtype,
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl HLIROp for LoopOutputSelect {
|
||||
fn to_egglog(&self, inp: &[(NodeIndex, String)]) -> String {
|
||||
format!(
|
||||
"(Op (LoopOutputSelect {} {} {} ({:?})) {})",
|
||||
self.loop_id,
|
||||
self.stream_id,
|
||||
self.iter,
|
||||
self.dtype,
|
||||
list_to_egglog(&inp.iter().map(|i| &i.1).collect_vec(), "ICons", "INil"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeOp for LoopOutputSelect {
|
||||
fn execute(&self, _: Vec<&NativeData>, _: &FxHashMap<char, usize>) -> NativeData {
|
||||
unimplemented!("LoopOutputSelect is driven by the runtime loop compiler")
|
||||
}
|
||||
}
|
||||
|
||||
/// Produces a single number constant from an expression or a float
|
||||
#[derive(Clone, PartialEq, Default)]
|
||||
pub struct Constant(pub f32);
|
||||
@@ -550,28 +1246,6 @@ impl NativeOp for Cast {
|
||||
}
|
||||
}
|
||||
|
||||
/// Graph break for chunking search graphs
|
||||
#[derive(Clone, PartialEq, Default)]
|
||||
pub struct GraphBreak {
|
||||
pub input_shape: ShapeTracker,
|
||||
}
|
||||
impl Debug for GraphBreak {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "GraphBreak")
|
||||
}
|
||||
}
|
||||
impl Display for GraphBreak {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "GraphBreak")
|
||||
}
|
||||
}
|
||||
|
||||
impl HLIROp for GraphBreak {
|
||||
fn to_egglog(&self, _: &[(NodeIndex, String)]) -> String {
|
||||
panic!("Cannot turn GraphBreak into egglog op!");
|
||||
}
|
||||
}
|
||||
|
||||
// Unary Op (A -> A)
|
||||
|
||||
fn unary_impl(
|
||||
@@ -1004,7 +1678,12 @@ impl EgglogOp for Add {
|
||||
2
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_propagation_op(&self.sort())]
|
||||
let mut r = vec![dtype_propagation_op(&self.sort())];
|
||||
r.extend(self.early_rewrites());
|
||||
r
|
||||
}
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
binary_op_unroll_rules("Add", 4)
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
@@ -1089,7 +1768,12 @@ impl EgglogOp for Mul {
|
||||
2
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_propagation_op(&self.sort())]
|
||||
let mut r = vec![dtype_propagation_op(&self.sort())];
|
||||
r.extend(self.early_rewrites());
|
||||
r
|
||||
}
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
binary_op_unroll_rules("Mul", 4)
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
@@ -1174,7 +1858,12 @@ impl EgglogOp for Mod {
|
||||
2
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_propagation_op(&self.sort())]
|
||||
let mut r = vec![dtype_propagation_op(&self.sort())];
|
||||
r.extend(self.early_rewrites());
|
||||
r
|
||||
}
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
binary_op_unroll_rules("Mod", 4)
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
@@ -1259,8 +1948,13 @@ impl EgglogOp for LessThan {
|
||||
2
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// Comparison operations always output Bool
|
||||
vec![dtype_fixed_op(&self.sort(), &SORTS.bool_dt)]
|
||||
// Comparisons output Bool, not the input dtype.
|
||||
let mut r = vec![dtype_fixed_op(&self.sort(), &SORTS.bool_dt)];
|
||||
r.extend(self.early_rewrites());
|
||||
r
|
||||
}
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
binary_op_unroll_rules("LessThan", 4)
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
@@ -2195,6 +2889,10 @@ impl Runtime for NativeRuntime {
|
||||
(0, "0 ms".to_string())
|
||||
}
|
||||
|
||||
fn aggregate_profile_metrics(metrics: &[Self::ProfileMetric]) -> Self::ProfileMetric {
|
||||
metrics.iter().copied().sum()
|
||||
}
|
||||
|
||||
fn load_llir(&mut self, llir_graph: &LLIRGraph) {
|
||||
// Extract nativeop graph
|
||||
let mut graph = StableGraph::new();
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod dtype;
|
||||
pub mod dyn_backend;
|
||||
pub mod egglog_utils;
|
||||
pub mod frontend;
|
||||
pub mod graph;
|
||||
|
||||
16
src/op.rs
16
src/op.rs
@@ -21,6 +21,16 @@ pub trait Runtime {
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
trials: usize,
|
||||
) -> (Self::ProfileMetric, String);
|
||||
/// Aggregate multiple profile metrics into one comparable metric.
|
||||
/// Used for regionalized profiling where one candidate maps to multiple LLIR regions.
|
||||
fn aggregate_profile_metrics(metrics: &[Self::ProfileMetric]) -> Self::ProfileMetric {
|
||||
metrics
|
||||
.first()
|
||||
.unwrap_or_else(|| panic!("aggregate_profile_metrics called with empty metrics"))
|
||||
.clone()
|
||||
}
|
||||
/// Optional per-candidate profiling timeout used by search.
|
||||
fn set_profile_timeout(&mut self, _timeout: Option<std::time::Duration>) {}
|
||||
/// Allocate a dummy input buffer for a boundary node during per-chunk profiling.
|
||||
/// `node_index` is the HLIR node index used in the Input op's `node` field.
|
||||
/// `num_bytes` is the number of bytes to allocate.
|
||||
@@ -226,7 +236,11 @@ impl LLIROp {
|
||||
assert!(
|
||||
op.type_name().contains("dyn")
|
||||
|| op.type_name().contains("Input")
|
||||
|| op.type_name().contains("Output"),
|
||||
|| op.type_name().contains("Output")
|
||||
|| op.type_name().contains("LoopStart")
|
||||
|| op.type_name().contains("LoopEnd")
|
||||
|| op.type_name().contains("LoopInput")
|
||||
|| op.type_name().contains("LoopOutput"),
|
||||
"op types must be erased into dialect traits for dialect casting to work!"
|
||||
);
|
||||
Self(Arc::new(Box::new(DialectOp::new(op))))
|
||||
|
||||
@@ -485,3 +485,56 @@ fn test_only_outputs_remain() {
|
||||
.count();
|
||||
assert_eq!(rt.buffers.len(), output_count);
|
||||
}
|
||||
|
||||
fn build_repeated_block_graph(
|
||||
layers: usize,
|
||||
width: usize,
|
||||
) -> (Graph, NodeIndex, Vec<NodeIndex>, NodeIndex) {
|
||||
let mut cx = Graph::new();
|
||||
let x = cx.tensor(width);
|
||||
let mut state = x;
|
||||
let mut weight_nodes = Vec::with_capacity(layers * 2);
|
||||
for i in 0..layers {
|
||||
let w = cx.named_tensor(format!("w_{i}"), width);
|
||||
let b = cx.named_tensor(format!("b_{i}"), width);
|
||||
weight_nodes.push(w.id);
|
||||
weight_nodes.push(b.id);
|
||||
state = ((state * w) + b).sin();
|
||||
}
|
||||
let y = state.output();
|
||||
(cx, x.id, weight_nodes, y.id)
|
||||
}
|
||||
|
||||
fn repeated_block_reference(layers: usize, input: &[f32], weights: &[Vec<f32>]) -> Vec<f32> {
|
||||
let mut state = input.to_vec();
|
||||
for i in 0..layers {
|
||||
let w = &weights[i * 2];
|
||||
let b = &weights[i * 2 + 1];
|
||||
for ((s, wi), bi) in state.iter_mut().zip(w.iter()).zip(b.iter()) {
|
||||
*s = (*s * *wi + *bi).sin();
|
||||
}
|
||||
}
|
||||
state
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn integration_auto_loop_rolling_matches_reference_native_runtime() {
|
||||
let layers = 12;
|
||||
let width = 16;
|
||||
let input = random_vec(width);
|
||||
let weights: Vec<Vec<f32>> = (0..layers * 2).map(|_| random_vec(width)).collect();
|
||||
|
||||
let reference = repeated_block_reference(layers, &input, &weights);
|
||||
|
||||
let (mut graph, input_id, weight_ids, output_id) = build_repeated_block_graph(layers, width);
|
||||
graph.build_search_space::<NativeRuntime>();
|
||||
let mut rt = graph.search(NativeRuntime::default(), 1);
|
||||
rt.set_data(input_id, input);
|
||||
for (node, data) in weight_ids.iter().zip(weights.iter()) {
|
||||
rt.set_data(*node, data.clone());
|
||||
}
|
||||
rt.execute(&graph.dyn_map);
|
||||
let out = rt.get_f32(output_id);
|
||||
|
||||
assert_close(&reference, out);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user