Compare commits

..

17 Commits

Author SHA1 Message Date
Tucker Morgan
4d32024f0f fix: rename swiglu_fn to activation_fn to match usage
The destructured kernel variable at line 327 was named swiglu_fn but
referenced as activation_fn at line 449, causing E0425. The kernel now
supports multiple activation modes (SiLU + approximate GELU), so
activation_fn is the more accurate name.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-20 20:25:56 +00:00
tucker-luminal
d66b3f2643 Merge branch 'main' into feat/luminal-python-moe-routing-support 2026-04-20 13:16:43 -07:00
Joe Fioti
66b0807462 Merge pull request #272 from luminal-ai/gemma
Gemma
2026-04-19 09:02:30 -07:00
Joe Fioti
c24ea4a7a5 fmt 2026-04-19 15:38:38 +00:00
Joe Fioti
c309d9b4ed clippy 2026-04-19 15:37:44 +00:00
Joe Fioti
745c071ee5 factored out the moe rules 2026-04-19 04:59:38 +00:00
Joe Fioti
56ffe8bbb3 Remove example tests and generated graph artifacts 2026-04-18 17:42:43 +00:00
Joe Fioti
13dbdcb53b gemma fix 2026-04-17 18:47:18 +00:00
Joe Fioti
c8ad5f8b75 fix 2026-04-17 18:01:56 +00:00
Joe Fioti
51c6596f6a cicd fix 2026-04-17 15:35:23 +00:00
Joe Fioti
aef4c68537 fixed qwen3_moe precision and rewrites 2026-04-17 05:16:03 +00:00
tucker-luminal
0af1c186fd Update unary.rs
Fixing a bug here, this should get the cuda tests passing again
2026-04-15 11:18:03 -07:00
Ubuntu
86b2784b51 Merge main into MoE routing branch, fix PyTorch 2.11 compat 2026-04-15 16:38:25 +00:00
Ubuntu
53c58576fc Fix qwen3 MoE cuBLASLt rewrite gating 2026-04-12 02:29:19 +00:00
Ubuntu
64e4eedcc6 Fix qwen3 MoE cuBLASLt rewrite gating 2026-04-12 02:29:05 +00:00
Ubuntu
63afb602b0 Format MoE routing test model 2026-04-10 11:07:42 +00:00
Ubuntu
985e7752aa build MoE routing support in luminal_python 2026-04-10 10:45:07 +00:00
24 changed files with 2630 additions and 248 deletions

View File

@@ -45,18 +45,6 @@ cd ./examples/llama
cargo run --release
```
**PyTorch models via `torch.compile`**
Any PyTorch model can be run through Luminal by swapping the backend:
```python
import torch
from luminal import luminal_backend
model_compiled = torch.compile(model, backend=luminal_backend)
output = model_compiled(x)
```
See `crates/luminal_python/` for the PT2-based bridge.
## Features
### Speed
@@ -87,7 +75,7 @@ The current ML ecosystem is too fragmented, and the solution isn't another layer
### Validated against Pytorch
Correctness matters. We write as much tests as possible to cover all ops and verify they work the same as an equivalent Pytorch implementation.
Correctness matters. We write as much tests as possible to cover all ops and verify they work the same as an equivalent Pytorch implementation. ([Improvements needed!](https://github.com/jafioti/luminal/issues/20))
## Ideology
@@ -114,12 +102,12 @@ Now we can do:
## Where are we?
- Search is the default execution path — compile via `build_search_space` and `search` (see the Usage example above).
- Search is partially merged. We are between 1.0 and 2.0 (search), which will be completed within the next month or so.
- Metal and Cuda are supported for running models on Macs and Nvidia GPUs respectively, in both full and half precision.
- Llama 3, Gemma, Qwen (incl. MoE variants), and a paged-attention Llama are implemented in `examples/`. See instructions above for running.
- Full training support with graph-based autograd.
- Llama 3, Phi 3, Whisper and Yolo v8 are implemented in `examples/`. See instructions above for running.
- We have a small library of NN modules in `luminal_nn`, including transformers.
- A large surface of high-level ops lives in `src/frontend/` — aiming to match the most used ~80% of the PyTorch api.
- PyTorch models can be run through luminal via `torch.compile` — see `crates/luminal_python/`.
- A significant amount of high-level ops are implemented in `hl_ops`. We are aiming to match the most used ~80% of the pytorch api.
Some things on the roadmap:

View File

@@ -32,6 +32,7 @@ use crate::{
driver::{CudaSlice, CudaStream, DevicePtr},
},
host::{HostOp, cublas::parse_cublas_op},
try_create_cublaslt,
};
#[derive(Debug)]
@@ -248,6 +249,19 @@ fn dtype_to_cuda_types(dtype: DType) -> (cudaDataType, cublasComputeType_t, cuda
}
}
impl CuBlasLt {
fn get_cublaslt(&self, stream: &Arc<CudaStream>) -> anyhow::Result<Arc<CudaBlasLT>> {
if let Some(cublaslt) = self.cublaslt.get() {
return Ok(cublaslt.clone());
}
let created = try_create_cublaslt(stream.clone()).map_err(|message| {
anyhow::anyhow!("cuBLASLt unavailable on this machine: {message}")
})?;
let _ = self.cublaslt.set(created.clone());
Ok(created)
}
}
impl HostOp for CuBlasLt {
fn execute(
&self,
@@ -324,9 +338,7 @@ impl HostOp for CuBlasLt {
)
.entered();
let cublaslt = self
.cublaslt
.get_or_init(|| Arc::new(CudaBlasLT::new(stream.clone()).unwrap()));
let cublaslt = self.get_cublaslt(stream)?;
let mut matmul_desc: cublasLtMatmulDesc_t = std::ptr::null_mut();
let mut a_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();

View File

@@ -1,128 +1,213 @@
; GLUMoE: Match the expert computation subgraph of a Gated MoE (SwiGLU variant).
; GLUMoE: Match the expert computation subgraph of a gated MoE.
;
; This matches the pattern produced by QwenMoE::forward() starting from the
; expert gathers through to the final weighted sum, and replaces it with a
; fused GLUMoE HostOp.
; One fused op supports two activation modes:
; mode=0: Qwen-style SwiGLU (silu(gate) * up)
; mode=1: Gemma-style GELU (gate * sigmoid(1.595769 * gate * (1 + 0.044715 * gate^2)))
;
; Inputs extracted:
; ?x - input activations [s, H] F32
; ?topk_idx - top-k expert indices [s, k] Int (from argsort+slice)
; ?topk_vals - top-k routing values [s, k] F32 (from gather on softmax)
; ?gate_up_w - stacked gate+up expert weights [E, intermediate*2, H] BF16
; ?down_w - stacked down expert weights [E, H, intermediate] BF16
;
; The pattern captures:
; 1. Gate-up expert gather (Iota, Mul, Cast, Iota, Cast, Add, Cast, Gather)
; 2. Cast BF16→F32 of gathered gate-up weights
; 3. Gate-up batched matmul (Mul + SumReduce)
; 4. Gate/Up split via Iota+Gather (slice semantics)
; 5. SwiGLU: silu(gate) * up
; 6. Down expert gather (same pattern as gate-up)
; 7. Cast BF16→F32 of gathered down weights
; 8. Down batched matmul (Mul + SumReduce)
; 9. Weighted sum: (down_out * topk_values) summed over k
;
; Variables with ? prefix are egglog pattern variables.
; We use wildcards (?_xxx) for shapes/strides we don't extract.
; To keep matching fast, we stage through marker states:
; 1) Shared gate-up matmul marker
; 2) Activation marker (separate swiglu / gemma_gelu paths)
; 3) Down matmul marker (separate swiglu / gemma_gelu paths)
; 4) Final GLUMoE fusion (separate swiglu / gemma_gelu rules)
(datatype*
(GLUMoEGateUpState
(MkGLUMoEGateUpState Expression Expression Expression IR IR IR)
)
(GLUMoESwiGLUState
(MkGLUMoESwiGLUState GLUMoEGateUpState)
)
(GLUMoEGemmaGELUState
(MkGLUMoEGemmaGELUState GLUMoEGateUpState)
)
(GLUMoESwiGLUDownState
(MkGLUMoESwiGLUDownState Expression Expression Expression GLUMoESwiGLUState IR IR)
)
(GLUMoEGemmaDownState
(MkGLUMoEGemmaDownState Expression Expression Expression GLUMoEGemmaGELUState IR IR)
)
)
(function glumoe_gate_up (IR) GLUMoEGateUpState :merge new)
(function glumoe_swiglu (IR) GLUMoESwiGLUState :merge new)
(function glumoe_gemma_gelu (IR) GLUMoEGemmaGELUState :merge new)
(function glumoe_swiglu_down (IR) GLUMoESwiGLUDownState :merge new)
(function glumoe_gemma_down (IR) GLUMoEGemmaDownState :merge new)
(rule
(
; ===== Gate-up expert gather =====
; t51: Iota for base index (expert_idx * io_gu)
(= ?gu_iota_base (Op (Iota ?gu_io ?gu_iota_base_range) (INil)))
; t52: Mul topk_indices * io → base offsets [s, k]
(= ?gu_mul_base (Op (Mul ?gu_mul_base_shape ?gu_mul_base_a_stride ?gu_mul_base_b_stride ?gu_mul_base_out_stride) (ICons ?topk_idx (ICons ?gu_iota_base (INil)))))
; t53: Cast to F32
(= ?gu_cast_base (Op (Cast ?gu_cast_base_size (F32)) (ICons ?gu_mul_base (INil))))
; t54: Iota for within-expert index
(= ?gu_iota_within (Op (Iota (MIter) ?gu_iota_within_range) (INil)))
; t55: Cast within to F32
(= ?gu_cast_within (Op (Cast ?gu_cast_within_size (F32)) (ICons ?gu_iota_within (INil))))
; t56: Add base + within → flat gather indices
(= ?gu_add_idx (Op (Add ?gu_add_shape ?gu_add_a_stride ?gu_add_b_stride ?gu_add_out_stride) (ICons ?gu_cast_base (ICons ?gu_cast_within (INil)))))
; t57: Cast to Int
(= ?gu_cast_idx (Op (Cast ?gu_cast_idx_size (Int)) (ICons ?gu_add_idx (INil))))
; t58: Gather gate_up weights
(= ?gu_gathered (Op (Gather ?gu_gather_idx_shape ?gu_gather_idx_stride ?gu_gather_data_shape ?gu_gather_data_stride) (ICons ?gu_cast_idx (ICons ?gate_up_w (INil)))))
(= ?gu_add_idx (Op (Add ?gu_add_shape ?gu_add_a_stride ?gu_add_b_stride ?gu_add_out_stride) (ICons ?gu_mul_base (ICons ?gu_iota_within (INil)))))
(= ?gu_gathered (Op (Gather ?gu_gather_idx_shape ?gu_gather_idx_stride ?gu_gather_data_shape ?gu_gather_data_stride) (ICons ?gu_add_idx (ICons ?gate_up_w (INil)))))
; ===== Cast BF16→F32 =====
; t59: Cast gathered gate_up to F32
(= ?gu_f32 (Op (Cast ?gu_f32_size (F32)) (ICons ?gu_gathered (INil))))
; ===== Gate-up batched matmul =====
; t60: Mul x * gathered_gu (broadcast multiply)
(= ?gu_matmul_mul (Op (Mul ?gu_matmul_mul_shape ?gu_matmul_a_stride ?gu_matmul_b_stride ?gu_matmul_mul_out_stride) (ICons ?x (ICons ?gu_f32 (INil)))))
; t61: SumReduce over K dimension
(= ?gu_matmul (Op (Sum ?gu_matmul_out_shape ?gu_matmul_k ?gu_matmul_in_stride ?gu_matmul_k_stride ?gu_matmul_out_stride) (ICons ?gu_matmul_mul (INil))))
)
(
(set (glumoe_gate_up ?gu_matmul)
(MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_iota_within_range ?x ?topk_idx ?gate_up_w))
)
:name "GLUMoE gate-up matmul marker"
)
; ===== SwiGLU activation marker =====
(rule
(
(= ?gate_up_state (glumoe_gate_up ?gu_matmul))
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
; ===== Up slice via Iota+Gather =====
; t62: Iota with complex expression (slicing the "up" half)
(= ?up_iota (Op (Iota ?up_iota_expr ?up_iota_range) (INil)))
; t63: Gather to select up portion from matmul result
(= ?up_slice (Op (Gather ?up_gather_idx_shape ?up_gather_idx_stride ?up_gather_data_shape ?up_gather_data_stride) (ICons ?up_iota (ICons ?gu_matmul (INil)))))
; ===== SwiGLU: silu(gate) * up =====
; t64: Constant(-1)
(= ?neg1 (Op (Constant -1.000000) (INil)))
; t65: gate * -1
(= ?neg_gate (Op (Mul ?silu_shape1 ?silu_a_stride1 ?silu_b_stride1 ?silu_out_stride1) (ICons ?gu_matmul (ICons ?neg1 (INil)))))
; t66: Constant(log2e)
(= ?log2e (Op (Constant 1.442695) (INil)))
; t67: neg_gate * log2e
(= ?scaled (Op (Mul ?silu_shape2 ?silu_a_stride2 ?silu_b_stride2 ?silu_out_stride2) (ICons ?neg_gate (ICons ?log2e (INil)))))
; t68: exp2
(= ?exp2_val (Op (Exp2 ?silu_shape3 ?silu_in_stride3 ?silu_out_stride3) (ICons ?scaled (INil))))
; t69: Constant(1)
(= ?one (Op (Constant 1.000000) (INil)))
; t70: exp2 + 1
(= ?plus1 (Op (Add ?silu_shape4 ?silu_a_stride4 ?silu_b_stride4 ?silu_out_stride4) (ICons ?exp2_val (ICons ?one (INil)))))
; t71: recip
(= ?sigmoid (Op (Recip ?silu_shape5 ?silu_in_stride5 ?silu_out_stride5) (ICons ?plus1 (INil))))
; t72: gate * sigmoid(gate) = silu(gate)
(= ?silu_out (Op (Mul ?silu_shape6 ?silu_a_stride6 ?silu_b_stride6 ?silu_out_stride6) (ICons ?gu_matmul (ICons ?sigmoid (INil)))))
; t73: silu(gate) * up
(= ?swiglu_out (Op (Mul ?swiglu_shape ?swiglu_a_stride ?swiglu_b_stride ?swiglu_out_stride) (ICons ?silu_out (ICons ?up_slice (INil)))))
)
(
(set (glumoe_swiglu ?swiglu_out) (MkGLUMoESwiGLUState ?gate_up_state))
)
:name "GLUMoE swiglu marker"
)
; ===== Gemma GELU activation marker =====
(rule
(
(= ?gate_up_state (glumoe_gate_up ?gu_matmul))
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
(= ?up_iota (Op (Iota ?up_iota_expr ?up_iota_range) (INil)))
(= ?up_slice (Op (Gather ?up_gather_idx_shape ?up_gather_idx_stride ?up_gather_data_shape ?up_gather_data_stride) (ICons ?up_iota (ICons ?gu_matmul (INil)))))
(= ?gelu_coeff_inner (Op (Constant 0.044715) (INil)))
(= ?gelu_inner_scaled (Op (Mul ?gelu_inner_scaled_shape ?gelu_inner_scaled_a_stride ?gelu_inner_scaled_b_stride ?gelu_inner_scaled_out_stride) (ICons ?gu_matmul (ICons ?gelu_coeff_inner (INil)))))
(= ?gelu_inner_quad (Op (Mul ?gelu_inner_quad_shape ?gelu_inner_quad_a_stride ?gelu_inner_quad_b_stride ?gelu_inner_quad_out_stride) (ICons ?gelu_inner_scaled (ICons ?gu_matmul (INil)))))
(= ?gelu_one (Op (Constant 1.000000) (INil)))
(= ?gelu_poly (Op (Add ?gelu_poly_shape ?gelu_poly_a_stride ?gelu_poly_b_stride ?gelu_poly_out_stride) (ICons ?gelu_inner_quad (ICons ?gelu_one (INil)))))
(= ?gelu_coeff_outer (Op (Constant 1.595769) (INil)))
(= ?gelu_outer_scaled (Op (Mul ?gelu_outer_scaled_shape ?gelu_outer_scaled_a_stride ?gelu_outer_scaled_b_stride ?gelu_outer_scaled_out_stride) (ICons ?gu_matmul (ICons ?gelu_coeff_outer (INil)))))
(= ?gelu_scaled (Op (Mul ?gelu_scaled_shape ?gelu_scaled_a_stride ?gelu_scaled_b_stride ?gelu_scaled_out_stride) (ICons ?gelu_outer_scaled (ICons ?gelu_poly (INil)))))
(= ?neg1 (Op (Constant -1.000000) (INil)))
(= ?gelu_neg (Op (Mul ?gelu_neg_shape ?gelu_neg_a_stride ?gelu_neg_b_stride ?gelu_neg_out_stride) (ICons ?gelu_scaled (ICons ?neg1 (INil)))))
(= ?log2e (Op (Constant 1.442695) (INil)))
(= ?gelu_exp_scaled (Op (Mul ?gelu_exp_scaled_shape ?gelu_exp_scaled_a_stride ?gelu_exp_scaled_b_stride ?gelu_exp_scaled_out_stride) (ICons ?gelu_neg (ICons ?log2e (INil)))))
(= ?gelu_exp2_val (Op (Exp2 ?gelu_exp_shape ?gelu_exp_in_stride ?gelu_exp_out_stride) (ICons ?gelu_exp_scaled (INil))))
(= ?gelu_plus1 (Op (Add ?gelu_plus1_shape ?gelu_plus1_a_stride ?gelu_plus1_b_stride ?gelu_plus1_out_stride) (ICons ?gelu_exp2_val (ICons ?gelu_one (INil)))))
(= ?gelu_sigmoid (Op (Recip ?gelu_sigmoid_shape ?gelu_sigmoid_in_stride ?gelu_sigmoid_out_stride) (ICons ?gelu_plus1 (INil))))
(= ?gelu_out (Op (Mul ?gelu_out_shape ?gelu_out_a_stride ?gelu_out_b_stride ?gelu_out_out_stride) (ICons ?gu_matmul (ICons ?gelu_sigmoid (INil)))))
(= ?gemma_out (Op (Mul ?geglu_shape ?geglu_a_stride ?geglu_b_stride ?geglu_out_stride) (ICons ?gelu_out (ICons ?up_slice (INil)))))
)
(
(set (glumoe_gemma_gelu ?gemma_out) (MkGLUMoEGemmaGELUState ?gate_up_state))
)
:name "GLUMoE gemma gelu marker"
)
; ===== SwiGLU down marker =====
(rule
(
(= ?swiglu_state (glumoe_swiglu ?swiglu_out))
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
; ===== Down expert gather =====
; t74: Iota for base index (expert_idx * io_down)
(= ?dn_iota_base (Op (Iota ?dn_io ?dn_iota_base_range) (INil)))
; t75: Mul topk_indices * io_down
(= ?dn_mul_base (Op (Mul ?dn_mul_base_shape ?dn_mul_base_a_stride ?dn_mul_base_b_stride ?dn_mul_base_out_stride) (ICons ?topk_idx (ICons ?dn_iota_base (INil)))))
; t76: Cast to F32
(= ?dn_cast_base (Op (Cast ?dn_cast_base_size (F32)) (ICons ?dn_mul_base (INil))))
; t77: Iota for within-expert index
(= ?dn_iota_within (Op (Iota (MIter) ?dn_iota_within_range) (INil)))
; t78: Cast within to F32
(= ?dn_cast_within (Op (Cast ?dn_cast_within_size (F32)) (ICons ?dn_iota_within (INil))))
; t79: Add base + within
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_cast_base (ICons ?dn_cast_within (INil)))))
; t80: Cast to Int
(= ?dn_cast_idx (Op (Cast ?dn_cast_idx_size (Int)) (ICons ?dn_add_idx (INil))))
; t81: Gather down weights
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_cast_idx (ICons ?down_w (INil)))))
; ===== Cast BF16→F32 =====
; t82: Cast gathered down to F32
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_mul_base (ICons ?dn_iota_within (INil)))))
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_add_idx (ICons ?down_w (INil)))))
(= ?dn_f32 (Op (Cast ?dn_f32_size (F32)) (ICons ?dn_gathered (INil))))
; ===== Down batched matmul =====
; t83: Mul swiglu_out * gathered_down (broadcast multiply)
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?swiglu_out (ICons ?dn_f32 (INil)))))
; t84: SumReduce
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
)
(
(set (glumoe_swiglu_down ?dn_matmul)
(MkGLUMoESwiGLUDownState ?dn_io ?dn_matmul_k ?dn_iota_within_range ?swiglu_state ?topk_idx ?down_w))
)
:name "GLUMoE swiglu down marker"
)
; ===== Gemma GELU down marker =====
(rule
(
(= ?gemma_state (glumoe_gemma_gelu ?gemma_out))
(= ?gemma_state (MkGLUMoEGemmaGELUState ?gate_up_state))
(= ?dn_iota_base (Op (Iota ?dn_io ?dn_iota_base_range) (INil)))
(= ?dn_mul_base (Op (Mul ?dn_mul_base_shape ?dn_mul_base_a_stride ?dn_mul_base_b_stride ?dn_mul_base_out_stride) (ICons ?topk_idx (ICons ?dn_iota_base (INil)))))
(= ?dn_iota_within (Op (Iota (MIter) ?dn_iota_within_range) (INil)))
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_mul_base (ICons ?dn_iota_within (INil)))))
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_add_idx (ICons ?down_w (INil)))))
(= ?dn_f32 (Op (Cast ?dn_f32_size (F32)) (ICons ?dn_gathered (INil))))
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?gemma_out (ICons ?dn_f32 (INil)))))
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
)
(
(set (glumoe_gemma_down ?dn_matmul)
(MkGLUMoEGemmaDownState ?dn_io ?dn_matmul_k ?dn_iota_within_range ?gemma_state ?topk_idx ?down_w))
)
:name "GLUMoE gemma down marker"
)
; ===== Final fusion: mode 0 (SwiGLU) =====
(rule
(
(= ?down_state (glumoe_swiglu_down ?dn_matmul))
(= ?down_state (MkGLUMoESwiGLUDownState ?dn_io ?dn_matmul_k ?dn_within_range ?swiglu_state ?topk_idx ?down_w))
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
; ===== Weighted sum over k experts =====
; t85: Mul down_out * topk_values
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?topk_vals (INil)))))
; t86: SumReduce over k dimension → [s, H]
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
)
(
(let ?glumoe (Op (GLUMoE
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
?gu_iota_within_range ?dn_iota_within_range)
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (INil))))))))
?gu_within_range ?dn_within_range (MNum 0))
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?topk_vals (INil)))))))))
(union ?output ?glumoe)
)
:name "GLUMoE fused expert computation"
:name "GLUMoE fused expert computation (swiglu)"
)
; ===== Final fusion: mode 1 (Gemma GELU) =====
(rule
(
(= ?down_state (glumoe_gemma_down ?dn_matmul))
(= ?down_state (MkGLUMoEGemmaDownState ?dn_io ?dn_matmul_k ?dn_within_range ?gemma_state ?topk_idx ?down_w))
(= ?gemma_state (MkGLUMoEGemmaGELUState ?gate_up_state))
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
; Gemma expert weights: topk_weights = normed_topk * per_expert_scale.gather(topk_idx)
(= ?per_expert_vals (Op (Gather ?scale_gather_idx_shape ?scale_gather_idx_stride ?scale_gather_data_shape ?scale_gather_data_stride) (ICons ?topk_idx (ICons ?per_expert_scale (INil)))))
(= ?topk_row_offsets (Op (Iota ?topk_row_offsets_expr ?topk_row_offsets_range) (INil)))
(= ?topk_flat_idx (Op (Add ?topk_flat_idx_shape ?topk_flat_idx_a_stride ?topk_flat_idx_b_stride ?topk_flat_idx_out_stride) (ICons ?topk_row_offsets (ICons ?topk_idx (INil)))))
(= ?topk_vals (Op (Gather ?topk_vals_gather_idx_shape ?topk_vals_gather_idx_stride ?topk_vals_gather_data_shape ?topk_vals_gather_data_stride) (ICons ?topk_flat_idx (ICons ?routing_weights (INil)))))
(= ?topk_norm (Op (Sum ?topk_norm_shape ?output_k ?topk_norm_in_stride ?topk_norm_k_stride ?topk_norm_out_stride) (ICons ?topk_vals (INil))))
(= ?topk_norm_factor (Op (Recip ?topk_norm_recip_shape ?topk_norm_recip_in_stride ?topk_norm_recip_out_stride) (ICons ?topk_norm (INil))))
(= ?normed_topk (Op (Mul ?normed_topk_shape ?normed_topk_a_stride ?normed_topk_b_stride ?normed_topk_out_stride) (ICons ?topk_vals (ICons ?topk_norm_factor (INil)))))
(= ?expert_weights (Op (Mul ?expert_weights_shape ?expert_weights_a_stride ?expert_weights_b_stride ?expert_weights_out_stride) (ICons ?normed_topk (ICons ?per_expert_vals (INil)))))
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?expert_weights (INil)))))
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
)
(
(let ?glumoe (Op (GLUMoE
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
?gu_within_range ?dn_within_range (MNum 1))
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?per_expert_scale (INil)))))))))
(union ?output ?glumoe)
)
:name "GLUMoE fused expert computation (gemma_gelu)"
)

View File

@@ -33,14 +33,15 @@ use crate::{
},
},
host::HostOp,
try_create_cublaslt,
};
const WORKSPACE_SIZE: usize = 32 * 1024 * 1024; // 32 MiB
/// Fused GLU-MoE HostOp matched via egglog pattern.
///
/// Replaces the expert computation subgraph (expert gathers + matmuls + SwiGLU
/// + weighted sum) with an efficient cuBLASLt implementation.
/// Replaces the expert computation subgraph (expert gathers + matmuls + gated
/// activation + weighted sum) with an efficient cuBLASLt implementation.
///
/// Inputs (graph edges, in order):
/// 0: x [seq, hidden] F32
@@ -48,9 +49,13 @@ const WORKSPACE_SIZE: usize = 32 * 1024 * 1024; // 32 MiB
/// 2: topk_values [seq, k] F32
/// 3: gate_up_w [E, gate_up_dim, hidden] BF16
/// 4: down_w [E, hidden, intermediate] BF16
/// 5: mode_aux
/// - SwiGLU: ignored (rewriter wires `topk_values` again)
/// - GemmaGELU: per_expert_scale [E] F32
///
/// Output: [seq, hidden] F32
pub struct GLUMoE {
pub(crate) mode: GLUMoEMode,
/// Product of gate_up weight dimensions per expert (gate_up_dim * hidden) used for gather stride
gu_io: Expression,
/// Product of down weight dimensions per expert (hidden * intermediate) used for gather stride
@@ -69,9 +74,35 @@ pub struct GLUMoE {
module: OnceLock<(Arc<CudaModule>, CudaFunction, CudaFunction)>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum GLUMoEMode {
SwiGLU,
GemmaGELU,
}
impl GLUMoEMode {
fn from_mode_id(mode_id: usize) -> Self {
match mode_id {
0 => Self::SwiGLU,
1 => Self::GemmaGELU,
other => {
panic!("Unknown GLUMoE mode id: {other}");
}
}
}
fn activation_kernel_mode(self) -> i32 {
match self {
Self::SwiGLU => 0,
Self::GemmaGELU => 1,
}
}
}
impl Default for GLUMoE {
fn default() -> Self {
Self {
mode: GLUMoEMode::SwiGLU,
gu_io: Expression::default(),
dn_io: Expression::default(),
gu_matmul_k: Expression::default(),
@@ -88,6 +119,7 @@ impl Default for GLUMoE {
impl std::fmt::Debug for GLUMoE {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("GLUMoE")
.field("mode", &self.mode)
.field("gu_io", &self.gu_io)
.field("dn_io", &self.dn_io)
.field("gu_matmul_k", &self.gu_matmul_k)
@@ -100,6 +132,7 @@ impl std::fmt::Debug for GLUMoE {
impl Clone for GLUMoE {
fn clone(&self) -> Self {
Self {
mode: self.mode,
gu_io: self.gu_io,
dn_io: self.dn_io,
gu_matmul_k: self.gu_matmul_k,
@@ -114,9 +147,15 @@ impl Clone for GLUMoE {
}
impl GLUMoE {
fn get_cublaslt(&self, stream: &Arc<CudaStream>) -> &Arc<CudaBlasLT> {
self.cublaslt
.get_or_init(|| Arc::new(CudaBlasLT::new(stream.clone()).unwrap()))
fn get_cublaslt(&self, stream: &Arc<CudaStream>) -> anyhow::Result<Arc<CudaBlasLT>> {
if let Some(cublaslt) = self.cublaslt.get() {
return Ok(cublaslt.clone());
}
let created = try_create_cublaslt(stream.clone()).map_err(|message| {
anyhow::anyhow!("cuBLASLt unavailable on this machine: {message}")
})?;
let _ = self.cublaslt.set(created.clone());
Ok(created)
}
fn get_kernels(
@@ -134,23 +173,34 @@ extern "C" __global__ void f32_to_bf16(unsigned long long in_ptr, unsigned long
if (i < n) out[i] = __float2bfloat16(in_[i]);
}
extern "C" __global__ void swiglu_bf16(unsigned long long gate_up_ptr, unsigned long long out_ptr, int intermediate) {
extern "C" __global__ void glu_activation_bf16(
unsigned long long gate_up_ptr,
unsigned long long out_ptr,
int intermediate,
int mode
) {
const __nv_bfloat16* gate_up = (const __nv_bfloat16*)gate_up_ptr;
__nv_bfloat16* out = (__nv_bfloat16*)out_ptr;
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < intermediate) {
float gate = __bfloat162float(gate_up[i]);
float up = __bfloat162float(gate_up[i + intermediate]);
float silu = gate / (1.0f + expf(-gate));
out[i] = __float2bfloat16(silu * up);
float activated;
if (mode == 0) {
activated = gate / (1.0f + expf(-gate));
} else {
float scaled = 1.5957691216f * gate * (1.0f + 0.044715f * gate * gate);
activated = gate / (1.0f + expf(-scaled));
}
out[i] = __float2bfloat16(activated * up);
}
}
"#;
let ptx = compile_module_image_for_current_device(stream.context(), src).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let f32_to_bf16 = module.load_function("f32_to_bf16").unwrap();
let swiglu = module.load_function("swiglu_bf16").unwrap();
(module, f32_to_bf16, swiglu)
let activation = module.load_function("glu_activation_bf16").unwrap();
(module, f32_to_bf16, activation)
})
}
}
@@ -168,12 +218,27 @@ impl EgglogOp for GLUMoE {
("output_k", EXPRESSION),
("gu_within_range", EXPRESSION),
("dn_within_range", EXPRESSION),
("mode", EXPRESSION),
],
)
}
fn rewrites(&self) -> Vec<Rule> {
vec![Rule::raw(
"(rule
(
(= ?e (Op (GLUMoE ?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k ?gu_within_range ?dn_within_range ?mode) ?inputs))
)
(
(set (dtype ?e) (F32))
)
:ruleset dtype_prop
)",
)]
}
fn n_inputs(&self) -> usize {
5
6
}
fn early_rewrites(&self) -> Vec<Rule> {
@@ -195,8 +260,14 @@ impl EgglogOp for GLUMoE {
let output_k = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
let gu_within_range = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
let dn_within_range = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
let mode_expr = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
let mode_id = mode_expr
.to_usize()
.unwrap_or_else(|| panic!("GLUMoE mode must be static, got expression: {mode_expr}"));
let mode = GLUMoEMode::from_mode_id(mode_id);
let extracted = GLUMoE {
mode,
gu_io,
dn_io,
gu_matmul_k,
@@ -209,7 +280,7 @@ impl EgglogOp for GLUMoE {
};
let op = LLIROp::new::<dyn HostOp>(Box::new(extracted) as Box<dyn HostOp>);
// Return the 5 IR inputs: x, topk_idx, topk_vals, gate_up_w, down_w
// Return the 6 IR inputs: x, topk_idx, topk_values, gate_up_w, down_w, mode_aux
(op, input_enodes)
}
@@ -230,9 +301,9 @@ impl HostOp for GLUMoE {
// Resolve dimensions
let hidden = self.gu_matmul_k.exec(dyn_map).unwrap();
let intermediate = self.dn_matmul_k.exec(dyn_map).unwrap();
let top_k = self.output_k.exec(dyn_map).unwrap();
let top_k_expected = self.output_k.exec(dyn_map).unwrap();
let gate_up_dim = self.gu_io.exec(dyn_map).unwrap() / hidden; // gate_up_dim = gu_io / hidden
let _num_experts = self.gu_within_range.exec(dyn_map).unwrap() / (gate_up_dim * hidden);
let num_experts = self.gu_within_range.exec(dyn_map).unwrap() / (gate_up_dim * hidden);
// Derive seq from x buffer size: x is [seq, hidden] F32 → seq = len / (hidden * 4)
let x_buf = buffers[&inputs[0]];
@@ -243,6 +314,7 @@ impl HostOp for GLUMoE {
let topk_vals_buf = buffers[&inputs[2]]; // [seq, k] F32
let gate_up_buf = buffers[&inputs[3]]; // [E, gate_up_dim, hidden] BF16
let down_buf = buffers[&inputs[4]]; // [E, hidden, intermediate] BF16
let mode_aux_buf = buffers[&inputs[5]];
let output_buf = buffers[&self_node]; // [seq, hidden] F32
// Get raw device pointer addresses
@@ -251,14 +323,59 @@ impl HostOp for GLUMoE {
let down_ptr = buf_ptr(down_buf, stream);
let output_ptr = buf_ptr(output_buf, stream);
let cublaslt = self.get_cublaslt(stream);
let (_, f32_to_bf16_fn, swiglu_fn) = self.get_kernels(stream);
let cublaslt = self.get_cublaslt(stream)?;
let (_, f32_to_bf16_fn, activation_fn) = self.get_kernels(stream);
// Read topk indices and values from GPU
// Read top-k routing values from GPU
let topk_idx_host: Vec<u8> = stream.clone_dtoh(topk_idx_buf)?;
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host);
let topk_vals_host: Vec<u8> = stream.clone_dtoh(topk_vals_buf)?;
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host);
let idx_k = topk_idx_i32
.len()
.checked_div(seq)
.unwrap_or(top_k_expected);
let val_k = topk_vals_f32
.len()
.checked_div(seq)
.unwrap_or(top_k_expected);
let top_k = idx_k.min(val_k);
if seq > 0 && top_k == 0 {
return Ok(());
}
// Mode-dependent expert weights used for the final reduction:
// - SwiGLU: direct topk values
// - GemmaGELU: normalize topk values and scale by per-expert factors
let mut expert_weights_storage: Vec<f32> = Vec::new();
let expert_weights_f32: &[f32] = match self.mode {
GLUMoEMode::SwiGLU => topk_vals_f32,
GLUMoEMode::GemmaGELU => {
let per_expert_scale_host: Vec<u8> = stream.clone_dtoh(mode_aux_buf)?;
let per_expert_scale_f32: &[f32] = bytemuck::cast_slice(&per_expert_scale_host);
debug_assert!(per_expert_scale_f32.len() >= num_experts);
expert_weights_storage.resize(seq * top_k, 0.0);
for t in 0..seq {
let base = t * top_k;
let vals = &topk_vals_f32[base..base + top_k];
let norm = vals.iter().copied().sum::<f32>();
let inv_norm = if norm != 0.0 { norm.recip() } else { 0.0 };
for i in 0..top_k {
let expert_idx = topk_idx_i32[base + i] as usize;
if expert_idx >= per_expert_scale_f32.len() {
anyhow::bail!(
"GLUMoE Gemma mode expert index {} out of bounds {}",
expert_idx,
per_expert_scale_f32.len()
);
}
let scale = per_expert_scale_f32[expert_idx];
expert_weights_storage[base + i] = vals[i] * inv_norm * scale;
}
}
&expert_weights_storage
}
};
// Allocate temp buffers
let x_bf16_buf = unsafe { stream.alloc::<u8>(seq * hidden * 2)? }; // BF16
@@ -291,22 +408,10 @@ impl HostOp for GLUMoE {
let gu_stride = (gate_up_dim * hidden * 2) as u64; // bytes per expert gate_up (BF16)
let down_stride = (hidden * intermediate * 2) as u64; // bytes per expert down (BF16)
// Normalize top-k values per token (norm_topk_prob=true)
let mut normalized_vals = topk_vals_f32.to_vec();
for t in 0..seq {
let row = &mut normalized_vals[t * top_k..(t + 1) * top_k];
let sum: f32 = row.iter().sum();
if sum > 0.0 {
for v in row.iter_mut() {
*v /= sum;
}
}
}
for t in 0..seq {
let x_t_ptr = xbf16_ptr + (t * hidden * 2) as u64; // BF16
let expert_indices = &topk_idx_i32[t * top_k..(t + 1) * top_k];
let weights = &normalized_vals[t * top_k..(t + 1) * top_k];
let weights = &expert_weights_f32[t * top_k..(t + 1) * top_k];
for (i, (&expert_idx, &weight)) in expert_indices.iter().zip(weights.iter()).enumerate()
{
@@ -316,7 +421,7 @@ impl HostOp for GLUMoE {
let expert_gu_ptr = gate_up_ptr + expert_idx as u64 * gu_stride;
cublas_matmul(
stream,
cublaslt,
&cublaslt,
ws_ptr,
gate_up_dim as u64,
1,
@@ -335,17 +440,19 @@ impl HostOp for GLUMoE {
0.0f32,
)?;
// b. SwiGLU kernel (BF16 → BF16)
// b. Mode-specific gated activation (BF16 → BF16)
let moe_int = intermediate as i32;
let swiglu_blocks = (moe_int as u32).div_ceil(256);
let activation_mode = self.mode.activation_kernel_mode();
let activation_blocks = (moe_int as u32).div_ceil(256);
unsafe {
stream
.launch_builder(swiglu_fn)
.launch_builder(activation_fn)
.arg(&gu_out_ptr)
.arg(&hid_ptr)
.arg(&moe_int)
.arg(&activation_mode)
.launch(LaunchConfig {
grid_dim: (swiglu_blocks, 1, 1),
grid_dim: (activation_blocks, 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
})?;
@@ -358,7 +465,7 @@ impl HostOp for GLUMoE {
let beta = if i == 0 { 0.0f32 } else { 1.0f32 };
cublas_matmul_mixed(
stream,
cublaslt,
&cublaslt,
ws_ptr,
hidden as u64,
1,

View File

@@ -9,6 +9,8 @@ use std::{
pub use cudarc;
use cudarc::{cublaslt::CudaBlasLT, driver::CudaStream};
#[cfg(test)]
mod tests;
@@ -137,6 +139,25 @@ fn cuda_driver_diagnostics() -> (Option<i32>, Option<i32>) {
(driver_version, None)
}
pub(crate) fn try_create_cublaslt(
stream: Arc<CudaStream>,
) -> std::result::Result<Arc<CudaBlasLT>, String> {
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| CudaBlasLT::new(stream))) {
Ok(Ok(handle)) => Ok(Arc::new(handle)),
Ok(Err(err)) => Err(err.to_string()),
Err(payload) => {
let message = if let Some(message) = payload.downcast_ref::<String>() {
message.clone()
} else if let Some(message) = payload.downcast_ref::<&str>() {
message.to_string()
} else {
"cuBLASLt initialization panicked".to_string()
};
Err(message)
}
}
}
fn cuda_nvrtc_compile_options(target_arch: &str) -> Vec<String> {
let mut options = cuda_nvrtc_include_paths()
.into_iter()
@@ -186,9 +207,9 @@ fn get_cubin(program: nvrtc_sys::nvrtcProgram) -> Result<Vec<u8>, NvrtcError> {
}
let mut cubin = Vec::with_capacity(cubin_size);
cubin.resize(cubin_size, 0);
unsafe { nvrtc_sys::nvrtcGetCUBIN(program, cubin.as_mut_ptr()) }.result()?;
Ok(cubin.into_iter().map(|byte| byte as u8).collect())
cubin.resize(cubin_size, 0u8);
unsafe { nvrtc_sys::nvrtcGetCUBIN(program, cubin.as_mut_ptr() as *mut _) }.result()?;
Ok(cubin)
}
pub(crate) fn compile_module_image_for_current_device<S: AsRef<str>>(

View File

@@ -11,4 +11,6 @@ mod op_functional_tests;
#[cfg(test)]
mod performance_tests;
#[cfg(test)]
mod qwen3_moe_rewrite;
#[cfg(test)]
mod transformer;

View File

@@ -0,0 +1,314 @@
use half::bf16;
use luminal::{dtype::DType, prelude::*, shape::Expression};
use super::utilities::{assert_close, get_cuda_stream, random_f32_vec};
use crate::{
host::{
HostOp,
moe::{GLUMoE, GLUMoEMode},
},
runtime::CudaRuntime,
};
const SEQ: usize = 2;
const HIDDEN: usize = 16;
const NUM_EXPERTS: usize = 8;
const TOP_K: usize = 2;
const MOE_INTERMEDIATE: usize = 6;
const RMS_NORM_EPS: f32 = 1e-6;
struct QwenMoeGraph {
graph: Graph,
x: GraphTensor,
router: GraphTensor,
gate_up_weights: GraphTensor,
down_weights: GraphTensor,
output: GraphTensor,
}
struct GemmaMoeGraph {
graph: Graph,
router_input: GraphTensor,
expert_input: GraphTensor,
router_scale: GraphTensor,
router_proj: GraphTensor,
per_expert_scale: GraphTensor,
gate_up_weights: GraphTensor,
down_weights: GraphTensor,
output: GraphTensor,
}
fn build_qwen_moe_graph() -> QwenMoeGraph {
let mut cx = Graph::default();
let x = cx.tensor(('s', HIDDEN));
let router = cx.tensor((NUM_EXPERTS, HIDDEN));
let gate_up_weights = cx
.tensor((NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN))
.as_dtype(DType::Bf16);
let down_weights = cx
.tensor((NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE))
.as_dtype(DType::Bf16);
let n = x.dims().len();
let e_dim = *router.dims().first().unwrap();
let k_expr = Expression::from(TOP_K);
let routing_weights = x.matmul(router.t()).softmax(n - 1);
let top_k_indices = routing_weights.topk_indexes(TOP_K, n - 1);
let row_offsets = x
.graph()
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
let routing_flat_idx = row_offsets + top_k_indices;
let top_k_values = routing_weights.gather(routing_flat_idx);
let gate_up_gathered = gather_experts(x, top_k_indices, gate_up_weights).cast(DType::F32);
let x_exp = x.expand_dim(n - 1, TOP_K).unsqueeze(n);
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
let hidden = gate.silu() * up;
let down_gathered = gather_experts(x, top_k_indices, down_weights).cast(DType::F32);
let down_out = hidden
.unsqueeze(2)
.matmul(down_gathered.transpose(2, 3))
.squeeze(2);
let output = (down_out * top_k_values.unsqueeze(top_k_values.dims().len()))
.sum(n - 1)
.output();
QwenMoeGraph {
graph: cx,
x,
router,
gate_up_weights,
down_weights,
output,
}
}
fn build_gemma_moe_graph() -> GemmaMoeGraph {
let mut cx = Graph::default();
let router_input = cx.tensor(('s', HIDDEN));
let expert_input = cx.tensor(('s', HIDDEN));
let router_scale = cx.tensor(HIDDEN);
let router_proj = cx.tensor((NUM_EXPERTS, HIDDEN));
let per_expert_scale = cx.tensor(NUM_EXPERTS);
let gate_up_weights = cx
.tensor((NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN))
.as_dtype(DType::Bf16);
let down_weights = cx
.tensor((NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE))
.as_dtype(DType::Bf16);
let n = router_input.dims().len();
let e_dim = *router_proj.dims().first().unwrap();
let k_expr = Expression::from(TOP_K);
let router_hidden = router_input.std_norm(n - 1, RMS_NORM_EPS)
* router_scale.expand_lhs(&router_input.dims()[..n - 1])
* (HIDDEN as f32).sqrt().recip();
let routing_weights = router_hidden.matmul(router_proj.t()).softmax(n - 1);
let top_k_indices = routing_weights.topk_indexes(TOP_K, n - 1);
let row_offsets = router_input
.graph()
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
let routing_flat_idx = row_offsets + top_k_indices;
let top_k_values = routing_weights.gather(routing_flat_idx);
let top_k_norm = top_k_values.sum(n - 1).expand_dim(n - 1, TOP_K);
let top_k_weights = (top_k_values / top_k_norm) * per_expert_scale.gather(top_k_indices);
let gate_up_gathered =
gather_experts(expert_input, top_k_indices, gate_up_weights).cast(DType::F32);
let x_exp = expert_input.expand_dim(n - 1, TOP_K).unsqueeze(n);
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
let hidden = gemma_gelu(gate) * up;
let down_gathered = gather_experts(expert_input, top_k_indices, down_weights).cast(DType::F32);
let down_out = hidden
.unsqueeze(2)
.matmul(down_gathered.transpose(2, 3))
.squeeze(2);
let output = (down_out * top_k_weights.unsqueeze(top_k_weights.dims().len()))
.sum(n - 1)
.output();
GemmaMoeGraph {
graph: cx,
router_input,
expert_input,
router_scale,
router_proj,
per_expert_scale,
gate_up_weights,
down_weights,
output,
}
}
fn gather_experts(
graph_source: GraphTensor,
top_k_indices: GraphTensor,
weights: GraphTensor,
) -> GraphTensor {
let (_, d1, d2) = weights.dims3();
let io = d1 * d2;
let base = top_k_indices * io;
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
let n_base = base.dims().len();
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
let mut exp_within = within;
for (axis, dim) in base.dims().iter().enumerate() {
exp_within = exp_within.expand_dim(axis, *dim);
}
let expert_flat_idx = exp_base + exp_within;
weights.gather(expert_flat_idx)
}
#[allow(clippy::excessive_precision)]
fn gemma_gelu(x: GraphTensor) -> GraphTensor {
let scaled = 1.5957691216 * x * (1. + 0.044715 * x * x);
x * scaled.sigmoid()
}
fn glumoe_modes(rt: &CudaRuntime) -> Vec<GLUMoEMode> {
rt.llir_graph()
.node_weights()
.filter_map(|node| {
let op = node.to_dialect::<dyn HostOp>()?;
op.as_any()
.downcast_ref::<GLUMoE>()
.map(|glumoe| glumoe.mode)
})
.collect()
}
fn run_qwen_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
let Some(stream) = get_cuda_stream() else {
return (vec![], vec![]);
};
let mut model = build_qwen_moe_graph();
model.graph.set_dim('s', SEQ);
if use_glumoe {
model.graph.build_search_space::<CudaRuntime>();
} else {
model
.graph
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>();
}
let x_data = random_f32_vec(SEQ * HIDDEN, 11, -0.15, 0.15);
let router_data = random_f32_vec(NUM_EXPERTS * HIDDEN, 12, -0.2, 0.2);
let gate_up_data = random_f32_vec(NUM_EXPERTS * MOE_INTERMEDIATE * 2 * HIDDEN, 13, -0.1, 0.1)
.into_iter()
.map(bf16::from_f32)
.collect::<Vec<_>>();
let down_data = random_f32_vec(NUM_EXPERTS * HIDDEN * MOE_INTERMEDIATE, 14, -0.1, 0.1)
.into_iter()
.map(bf16::from_f32)
.collect::<Vec<_>>();
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(model.x, x_data);
rt.set_data(model.router, router_data);
rt.set_data(model.gate_up_weights, gate_up_data);
rt.set_data(model.down_weights, down_data);
rt = model.graph.search(rt, 10);
rt.execute(&model.graph.dyn_map);
(rt.get_f32(model.output.id), glumoe_modes(&rt))
}
fn run_gemma_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
let Some(stream) = get_cuda_stream() else {
return (vec![], vec![]);
};
let mut model = build_gemma_moe_graph();
model.graph.set_dim('s', SEQ);
if use_glumoe {
model.graph.build_search_space::<CudaRuntime>();
} else {
model
.graph
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>();
}
let router_input_data = random_f32_vec(SEQ * HIDDEN, 21, -0.15, 0.15);
let expert_input_data = random_f32_vec(SEQ * HIDDEN, 22, -0.15, 0.15);
let router_scale_data = random_f32_vec(HIDDEN, 23, 0.7, 1.3);
let router_proj_data = random_f32_vec(NUM_EXPERTS * HIDDEN, 24, -0.2, 0.2);
let per_expert_scale_data = random_f32_vec(NUM_EXPERTS, 25, 0.5, 1.5);
let gate_up_data = random_f32_vec(NUM_EXPERTS * MOE_INTERMEDIATE * 2 * HIDDEN, 26, -0.1, 0.1)
.into_iter()
.map(bf16::from_f32)
.collect::<Vec<_>>();
let down_data = random_f32_vec(NUM_EXPERTS * HIDDEN * MOE_INTERMEDIATE, 27, -0.1, 0.1)
.into_iter()
.map(bf16::from_f32)
.collect::<Vec<_>>();
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(model.router_input, router_input_data);
rt.set_data(model.expert_input, expert_input_data);
rt.set_data(model.router_scale, router_scale_data);
rt.set_data(model.router_proj, router_proj_data);
rt.set_data(model.per_expert_scale, per_expert_scale_data);
rt.set_data(model.gate_up_weights, gate_up_data);
rt.set_data(model.down_weights, down_data);
rt = model.graph.search(rt, 10);
rt.execute(&model.graph.dyn_map);
(rt.get_f32(model.output.id), glumoe_modes(&rt))
}
#[test]
fn test_glumoe_matches_qwen_swiglu_pattern() {
let (_result, modes) = run_qwen_moe(true);
if modes.is_empty() {
return;
}
assert_eq!(modes, vec![GLUMoEMode::SwiGLU]);
}
#[test]
fn test_glumoe_matches_gemma_gelu_pattern() {
let (_result, modes) = run_gemma_moe(true);
if modes.is_empty() {
return;
}
assert_eq!(modes, vec![GLUMoEMode::GemmaGELU]);
}
#[test]
fn test_glumoe_swiglu_matches_unfused_output() {
let (expected, baseline_modes) = run_qwen_moe(false);
if expected.is_empty() {
return;
}
assert!(baseline_modes.is_empty());
let (actual, fused_modes) = run_qwen_moe(true);
assert_eq!(fused_modes, vec![GLUMoEMode::SwiGLU]);
assert_close(&actual, &expected, 3e-2, 3e-2);
}
#[test]
fn test_glumoe_gemma_gelu_matches_unfused_output() {
let (expected, baseline_modes) = run_gemma_moe(false);
if expected.is_empty() {
return;
}
assert!(baseline_modes.is_empty());
let (actual, fused_modes) = run_gemma_moe(true);
assert_eq!(fused_modes, vec![GLUMoEMode::GemmaGELU]);
assert_close(&actual, &expected, 3e-2, 3e-2);
}

View File

@@ -628,6 +628,84 @@ impl CompiledGraph {
}
}
fn get_output_i32(&self, name: &str) -> PyResult<Vec<i32>> {
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
"Unknown output tensor: {}",
name
))
})?;
match &self.runtime {
RuntimeBackend::Native(rt) => {
let id = *node_id;
let output_id = rt
.graph
.node_indices()
.find(|n| {
if let Some(out) = (**rt.graph[*n]).as_any().downcast_ref::<Output>() {
out.node == id.index()
} else {
false
}
})
.ok_or_else(|| {
pyo3::exceptions::PyRuntimeError::new_err(format!(
"No output node found for tensor: {}",
name
))
})?;
let data = rt.buffers.get(&output_id).ok_or_else(|| {
pyo3::exceptions::PyRuntimeError::new_err(format!(
"No buffer data for output tensor: {}",
name
))
})?;
Ok((0..data.len()).map(|i| data.i32(i)).collect())
}
#[cfg(feature = "cuda")]
RuntimeBackend::Cuda(rt) => Ok(rt.get_i32(*node_id)),
}
}
fn get_output_bool(&self, name: &str) -> PyResult<Vec<bool>> {
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
"Unknown output tensor: {}",
name
))
})?;
match &self.runtime {
RuntimeBackend::Native(rt) => {
let id = *node_id;
let output_id = rt
.graph
.node_indices()
.find(|n| {
if let Some(out) = (**rt.graph[*n]).as_any().downcast_ref::<Output>() {
out.node == id.index()
} else {
false
}
})
.ok_or_else(|| {
pyo3::exceptions::PyRuntimeError::new_err(format!(
"No output node found for tensor: {}",
name
))
})?;
let data = rt.buffers.get(&output_id).ok_or_else(|| {
pyo3::exceptions::PyRuntimeError::new_err(format!(
"No buffer data for output tensor: {}",
name
))
})?;
Ok((0..data.len()).map(|i| data.bool(i)).collect())
}
#[cfg(feature = "cuda")]
RuntimeBackend::Cuda(rt) => Ok(rt.get_bool(*node_id)),
}
}
/// Copy output tensor data directly to a CUDA device pointer (DtoD).
/// Avoids the DtoH + HtoD round-trip of get_output() + .to(device).
#[cfg(feature = "cuda")]

View File

@@ -51,6 +51,7 @@ impl<'a> Translator<'a> {
"torch.ops.aten.sub.Scalar" => self.translate_binary_scalar_op(node, BinaryOp::Sub)?,
"torch.ops.aten.div.Tensor" => self.translate_binary_op(node, BinaryOp::Div)?,
"torch.ops.aten.div.Scalar" => self.translate_binary_scalar_op(node, BinaryOp::Div)?,
"torch.ops.aten.div.Tensor_mode" => self.translate_div_tensor_mode(node)?,
// Unary ops
"torch.ops.aten.neg.default" => self.translate_unary_op(node, |a| a * (-1.0))?,
@@ -71,6 +72,8 @@ impl<'a> Translator<'a> {
"torch.ops.aten.log.default" => self.translate_unary_op(node, |a| a.log())?,
"torch.ops.aten.log2.default" => self.translate_unary_op(node, |a| a.log2())?,
"torch.ops.aten.exp2.default" => self.translate_unary_op(node, |a| a.exp2())?,
"torch.ops.aten.sign.default" => self.translate_sign(node)?,
"torch.ops.aten.bitwise_not.default" => self.translate_bitwise_not(node)?,
// Cast
"torch.ops.aten._to_copy.default" => self.translate_to_copy(node)?,
@@ -109,6 +112,7 @@ impl<'a> Translator<'a> {
let a = self.get_input_tensor(node, 0)?;
if !a.shape.is_contiguous() { a + 0.0 } else { a }
}
"torch.ops.aten.argsort.default" => self.translate_argsort(node)?,
// Matmul
"torch.ops.aten.mm.default" | "torch.ops.aten.bmm.default" => {
@@ -159,6 +163,8 @@ impl<'a> Translator<'a> {
// Where
"torch.ops.aten.where.self" => self.translate_where(node)?,
"torch.ops.aten.where.ScalarOther" => self.translate_where_scalar_other(node)?,
"torch.ops.aten.masked_fill.Scalar" => self.translate_masked_fill_scalar(node)?,
// Pow
"torch.ops.aten.pow.Tensor_Scalar" => {
@@ -176,6 +182,7 @@ impl<'a> Translator<'a> {
// Creation ops
"torch.ops.aten.arange.start_step" => self.translate_arange(node)?,
"torch.ops.aten.full.default" => self.translate_full(node)?,
"torch.ops.aten.full_like.default" => self.translate_full_like(node)?,
"torch.ops.aten.scalar_tensor.default" => {
let val = self.get_float_arg(node, 0)? as f32;
self.graph.constant_float(val)
@@ -349,7 +356,17 @@ impl<'a> Translator<'a> {
// Scatter ops
"torch.ops.aten.scatter.src" => self.translate_scatter_src(node)?,
"torch.ops.aten.index_put.default" => self.translate_index_put(node)?,
"torch.ops.aten.scatter.value" => self.translate_scatter_value(node)?,
"torch.ops.aten.index_put_.default" | "torch.ops.aten.index_put.default" => {
self.translate_index_put(node)?
}
// Integer routing math
"torch.ops.aten.floor_divide.default" => self.translate_floor_divide(node)?,
// Triangular
"torch.ops.aten.tril.default" => self.translate_tril(node)?,
"torch.ops.aten.triu.default" => self.translate_triu(node)?,
// TopK — handles its own output storage, returns early
"torch.ops.aten.topk.default" => {
@@ -357,6 +374,12 @@ impl<'a> Translator<'a> {
return Ok(());
}
// Sort — handles its own output storage, returns early
"torch.ops.aten.sort.default" => {
self.translate_sort(node)?;
return Ok(());
}
// Split
"torch.ops.aten.split_with_sizes.default" => self.translate_split_with_sizes(node)?,

View File

@@ -77,11 +77,12 @@ impl<'a> Translator<'a> {
let output_names = self.parsed.output_names();
for name in &output_names {
let tensor = self.get_tensor(name)?;
// Cast non-float outputs (Bool, Int) to F32 for the runtime.
// Preserve F16/BF16/F32 as-is to avoid corrupting half-precision models.
let tensor = match tensor.dtype {
DType::Bool | DType::Int => tensor.cast(DType::F32) + 0.0,
_ => tensor + 0.0,
let tensor = if tensor.dtype == DType::Bool {
tensor.cast(DType::Int).cast(DType::Bool)
} else if tensor.dtype == DType::Int {
tensor
} else {
tensor + 0.0
};
tensor.output();
self.output_ids.push((name.clone(), tensor.id));
@@ -155,6 +156,12 @@ impl<'a> Translator<'a> {
// --- Helper methods ---
pub(crate) fn tensor_meta(&self, name: &str) -> Option<&TensorMeta> {
self.extra_tensor_values
.get(name)
.or_else(|| self.parsed.tensor_meta(name))
}
pub(crate) fn get_tensor(&self, name: &str) -> Result<GraphTensor> {
self.tensors
.get(name)

View File

@@ -6,6 +6,11 @@ use crate::pt2_util::*;
use super::Translator;
const SCATTER_INPUT_ARG: usize = 0;
const SCATTER_DIM_ARG: usize = 1;
const SCATTER_INDEX_ARG: usize = 2;
const SCATTER_VALUE_ARG: usize = 3;
impl<'a> Translator<'a> {
pub(crate) fn translate_reshape(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
@@ -359,6 +364,29 @@ impl<'a> Translator<'a> {
Ok(a.scatter_elements(indices.cast(DType::Int), src, dim))
}
pub(crate) fn translate_scatter_value(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, SCATTER_INPUT_ARG)?;
let dim = self.get_int_arg(node, SCATTER_DIM_ARG)?;
let dim = normalize_dim(dim, a.shape.len());
let indices = self.get_input_tensor(node, SCATTER_INDEX_ARG)?;
let value_arg = &node
.inputs
.get(SCATTER_VALUE_ARG)
.context("scatter.value missing value input")?
.arg;
let value = if let Some(b) = value_arg.as_bool() {
self.graph.constant(if b { 1 } else { 0 }).cast(a.dtype)
} else if let Some(i) = value_arg.as_int() {
self.graph.constant(i).cast(a.dtype)
} else if let Some(f) = value_arg.as_float() {
self.graph.constant_float(f as f32).cast(a.dtype)
} else {
bail!("scatter.value: unsupported scalar argument {:?}", value_arg);
}
.expand_rhs(indices.shape);
Ok(a.scatter_elements(indices.cast(DType::Int), value, dim))
}
pub(crate) fn translate_index_put(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let index_names = node.inputs[1]

View File

@@ -6,6 +6,27 @@ use crate::pt2_util::*;
use super::Translator;
const FULL_SHAPE_ARG: usize = 0;
const FULL_VALUE_ARG: usize = 1;
const FULL_LIKE_INPUT_ARG: usize = 0;
const FULL_LIKE_VALUE_ARG: usize = 1;
const TOPK_INPUT_ARG: usize = 0;
const TOPK_K_ARG: usize = 1;
const TOPK_DIM_ARG: usize = 2;
const SORT_INPUT_ARG: usize = 0;
const SORT_DIM_ARG: usize = 1;
const SORT_DESCENDING_ARG: usize = 2;
const WHERE_COND_ARG: usize = 0;
const WHERE_X_ARG: usize = 1;
const WHERE_OTHER_ARG: usize = 2;
const TRIANGULAR_INPUT_ARG: usize = 0;
const TRIANGULAR_DIAGONAL_ARG: usize = 1;
impl<'a> Translator<'a> {
pub(crate) fn translate_arange(&mut self, node: &Node) -> Result<GraphTensor> {
let positional_args: Vec<Expression> = node
@@ -30,19 +51,55 @@ impl<'a> Translator<'a> {
}
pub(crate) fn translate_full(&mut self, node: &Node) -> Result<GraphTensor> {
let shape = self.get_exprs_arg(node, 0)?;
let shape = self.get_exprs_arg(node, FULL_SHAPE_ARG)?;
// fill_value can be float, int, or bool after decomposition
let val = if let Ok(f) = self.get_float_arg(node, 1) {
let val = if let Ok(f) = self.get_float_arg(node, FULL_VALUE_ARG) {
f as f32
} else if let Ok(b) = self.get_bool_arg(node, 1) {
} else if let Ok(b) = self.get_bool_arg(node, FULL_VALUE_ARG) {
if b { 1.0 } else { 0.0 }
} else {
anyhow::bail!(
"full: unsupported fill value type: {:?}",
node.inputs.get(1)
node.inputs.get(FULL_VALUE_ARG)
);
};
Ok(self.graph.constant_float(val).expand_rhs(shape))
let dtype = self.output_meta_dtype(node)?;
let value = self.graph.constant_float(val).cast(dtype);
Ok(if shape.is_empty() {
value
} else {
value.expand_rhs(shape)
})
}
pub(crate) fn translate_full_like(&mut self, node: &Node) -> Result<GraphTensor> {
let reference = self.get_input_tensor(node, FULL_LIKE_INPUT_ARG)?;
let val = if let Ok(f) = self.get_float_arg(node, FULL_LIKE_VALUE_ARG) {
f as f32
} else if let Ok(b) = self.get_bool_arg(node, FULL_LIKE_VALUE_ARG) {
if b { 1.0 } else { 0.0 }
} else {
anyhow::bail!(
"full_like: unsupported fill value type: {:?}",
node.inputs.get(FULL_LIKE_VALUE_ARG)
);
};
let dtype = self.output_meta_dtype(node)?;
let value = self.graph.constant_float(val).cast(dtype);
Ok(value.expand_rhs(reference.shape))
}
fn output_meta_dtype(&self, node: &Node) -> Result<DType> {
let output_name = node
.outputs
.first()
.and_then(|o| o.as_tensor.as_ref())
.map(|t| t.name.clone())
.unwrap_or_default();
let meta = self
.tensor_meta(&output_name)
.context("Missing tensor meta for output dtype")?;
Ok(torch_dtype_int_to_luminal(meta.dtype))
}
pub(crate) fn translate_where(&mut self, node: &Node) -> Result<GraphTensor> {
@@ -62,11 +119,64 @@ impl<'a> Translator<'a> {
Ok(c * x_f + (one - c) * y_f)
}
pub(crate) fn translate_where_scalar_other(&mut self, node: &Node) -> Result<GraphTensor> {
let cond = self.get_input_tensor(node, WHERE_COND_ARG)?;
let x = self.get_input_tensor(node, WHERE_X_ARG)?;
let other_val = self.get_float_arg(node, WHERE_OTHER_ARG)? as f32;
// Broadcast cond and x to a common shape
let (cond_b, x_b) = broadcast_binary(cond, x);
let c = cond_b.cast(DType::F32);
let one = self.graph.constant_float(1.0).expand_rhs(c.shape);
let other = self.graph.constant_float(other_val).expand_rhs(c.shape);
Ok(c * x_b + (one - c) * other)
}
pub(crate) fn translate_tril(&mut self, node: &Node) -> Result<GraphTensor> {
self.translate_triangular(node, false)
}
pub(crate) fn translate_triu(&mut self, node: &Node) -> Result<GraphTensor> {
self.translate_triangular(node, true)
}
fn translate_triangular(&mut self, node: &Node, upper: bool) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, TRIANGULAR_INPUT_ARG)?;
let diagonal = if node.inputs.len() > TRIANGULAR_DIAGONAL_ARG {
self.get_int_arg(node, TRIANGULAR_DIAGONAL_ARG).unwrap_or(0) as i32
} else {
0
};
let dims = a.shape.dims;
let rows = dims[dims.len() - 2];
let cols = dims[dims.len() - 1];
let (r_val, c_val) = match (rows.to_usize(), cols.to_usize()) {
(Some(r), Some(c)) => (r, c),
_ => anyhow::bail!("tril/triu requires concrete matrix dimensions"),
};
let size = r_val.max(c_val);
let mask = if upper {
self.graph.triu(size, diagonal)
} else {
self.graph.tril(size, diagonal)
}
.cast(DType::F32);
let mask = if rows != cols {
mask.slice_along(0..r_val, 0).slice_along(0..c_val, 1)
} else {
mask
};
let mut mask_expanded = mask;
for i in (0..dims.len() - 2).rev() {
mask_expanded = mask_expanded.expand_dim(0, dims[i]);
}
Ok(a * mask_expanded)
}
pub(crate) fn translate_topk(&mut self, node: &Node) -> Result<()> {
let a = self.get_input_tensor(node, 0)?;
let k = self.get_int_arg(node, 1)? as usize;
let dim = if node.inputs.len() > 2 {
self.get_int_arg(node, 2).unwrap_or(-1)
let a = self.get_input_tensor(node, TOPK_INPUT_ARG)?;
let k = self.get_int_arg(node, TOPK_K_ARG)? as usize;
let dim = if node.inputs.len() > TOPK_DIM_ARG {
self.get_int_arg(node, TOPK_DIM_ARG).unwrap_or(-1)
} else {
-1
};
@@ -86,13 +196,10 @@ impl<'a> Translator<'a> {
None
};
// Use full argsort then slice, rather than topk_indexes/topk_values directly.
// This avoids a CUDA gather kernel bug when data and index shapes differ
// along the gather axis (topk_indexes returns a sliced tensor).
let full_argsort = a.argsort(dim, true);
// Build top-k outputs from a full stable argsort, then slice to k.
let full_argsort = a.stable_argsort(dim, true);
// Only build each branch when its output is consumed.
// Dead nodes in the graph can confuse the CUDA optimizer.
// Only build the outputs that are consumed.
if let Some(val_name) = values_name
&& !val_name.is_empty()
{
@@ -100,8 +207,7 @@ impl<'a> Translator<'a> {
self.tensors.insert(val_name, values);
}
if let Some(idx_name) = indices_name {
// Materialize Int indices as F32 with `* 1.0` to force a contiguous copy.
// Without this, CUDA can't correctly read the sliced Int view.
// Materialize the sliced indices through a copy before storing them.
let indices = full_argsort.slice_along(..k, dim) * 1.0;
self.tensors.insert(idx_name, indices);
}
@@ -109,6 +215,51 @@ impl<'a> Translator<'a> {
Ok(())
}
pub(crate) fn translate_sort(&mut self, node: &Node) -> Result<()> {
let a = self.get_input_tensor(node, SORT_INPUT_ARG)?;
let dim = if node.inputs.len() > SORT_DIM_ARG {
self.get_int_arg(node, SORT_DIM_ARG).unwrap_or(-1)
} else {
-1
};
let descending = if node.inputs.len() > SORT_DESCENDING_ARG {
self.get_bool_arg(node, SORT_DESCENDING_ARG)
.unwrap_or(false)
} else {
false
};
let dim = normalize_dim(dim, a.shape.len());
// Determine output names (sort returns (values, indices))
let values_name = node
.outputs
.first()
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()));
let indices_name =
if let Some(ts) = node.outputs.first().and_then(|o| o.as_tensors.as_ref()) {
ts.get(1).map(|t| t.name.clone())
} else if node.outputs.len() > 1 {
node.outputs[1].as_tensor.as_ref().map(|t| t.name.clone())
} else {
None
};
let full_argsort = a.stable_argsort(dim, descending);
if let Some(val_name) = values_name
&& !val_name.is_empty()
{
let values = a.gather_elements(full_argsort, dim);
self.tensors.insert(val_name, values);
}
if let Some(idx_name) = indices_name {
let indices = full_argsort * 1.0;
self.tensors.insert(idx_name, indices);
}
Ok(())
}
pub(crate) fn translate_wrap_set_grad(&mut self, node: &Node) -> Result<()> {
let subgraph = node.inputs[1]
.arg

View File

@@ -6,7 +6,38 @@ use crate::pt2_util::{broadcast_binary, torch_dtype_int_to_luminal};
use super::Translator;
const ARGSORT_INPUT_ARG: usize = 0;
const ARGSORT_DIM_ARG: usize = 1;
const ARGSORT_DESCENDING_ARG: usize = 2;
const MASKED_FILL_INPUT_ARG: usize = 0;
const MASKED_FILL_MASK_ARG: usize = 1;
const MASKED_FILL_VALUE_ARG: usize = 2;
const FLOOR_DIVIDE_INPUT_ARG: usize = 0;
const FLOOR_DIVIDE_OTHER_ARG: usize = 1;
const DIV_MODE_INPUT_ARG: usize = 0;
const DIV_MODE_OTHER_ARG: usize = 1;
impl<'a> Translator<'a> {
pub(crate) fn translate_argsort(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, ARGSORT_INPUT_ARG)?;
let dim = if node.inputs.len() > ARGSORT_DIM_ARG {
self.get_int_arg(node, ARGSORT_DIM_ARG).unwrap_or(-1)
} else {
-1
};
let descending = if node.inputs.len() > ARGSORT_DESCENDING_ARG {
self.get_bool_arg(node, ARGSORT_DESCENDING_ARG)
.unwrap_or(false)
} else {
false
};
let dim = crate::pt2_util::normalize_dim(dim, a.shape.len());
Ok(a.stable_argsort(dim, descending))
}
pub(crate) fn translate_unary_op(
&mut self,
node: &Node,
@@ -19,11 +50,15 @@ impl<'a> Translator<'a> {
pub(crate) fn translate_to_copy(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
for input in &node.inputs {
if input.name == "dtype"
&& let Some(dtype_int) = input.arg.as_int()
{
let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
return Ok(a.cast(dtype));
if input.name == "dtype" {
if let Some(dtype_int) = input.arg.as_int() {
let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
return Ok(a.cast(dtype));
}
if let Some(dtype_int) = input.arg.as_scalar_type() {
let dtype = torch_dtype_int_to_luminal(dtype_int);
return Ok(a.cast(dtype));
}
}
}
Ok(a)
@@ -60,6 +95,155 @@ impl<'a> Translator<'a> {
Ok(result)
}
pub(crate) fn translate_sign(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let zero = self
.graph
.constant_float(0.0)
.cast(a.dtype)
.expand_rhs(a.shape);
let pos = a.gt(zero).cast(DType::Int);
let neg = a.lt(zero).cast(DType::Int);
let signed = pos - neg;
Ok(if a.dtype == DType::Int {
signed
} else {
signed.cast(a.dtype)
})
}
pub(crate) fn translate_bitwise_not(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
Ok(match a.dtype {
DType::Bool => {
let one = self
.graph
.constant_float(1.0)
.cast(DType::Int)
.expand_rhs(a.shape);
(one - a.cast(DType::Int)).cast(DType::Bool)
}
DType::Int => (a + 1) * -1.0,
other => {
anyhow::bail!("bitwise_not only supports Bool/Int routing tensors, got {other:?}")
}
})
}
pub(crate) fn translate_masked_fill_scalar(&mut self, node: &Node) -> Result<GraphTensor> {
let input = self.get_input_tensor(node, MASKED_FILL_INPUT_ARG)?;
let mask = self.get_input_tensor(node, MASKED_FILL_MASK_ARG)?;
let fill = self.get_float_arg(node, MASKED_FILL_VALUE_ARG)? as f32;
let (input, mask) = broadcast_binary(input, mask);
let work_dtype = if input.dtype == DType::Bool {
DType::Int
} else {
input.dtype
};
let input_work = if input.dtype == DType::Bool {
input.cast(DType::Int)
} else {
input
};
let mask_work = mask.cast(work_dtype);
let fill_work = self
.graph
.constant_float(fill)
.cast(work_dtype)
.expand_rhs(input_work.shape);
let one = self
.graph
.constant_float(1.0)
.cast(work_dtype)
.expand_rhs(input_work.shape);
let result = mask_work * fill_work + (one - mask_work) * input_work;
Ok(if input.dtype == DType::Bool {
result.cast(DType::Bool)
} else {
result
})
}
pub(crate) fn translate_floor_divide(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, FLOOR_DIVIDE_INPUT_ARG)?;
let b = if let Some(name) = node
.inputs
.get(FLOOR_DIVIDE_OTHER_ARG)
.and_then(|i| i.arg.as_tensor_name())
{
self.get_tensor(name)?
} else {
let scalar = self.get_float_arg(node, FLOOR_DIVIDE_OTHER_ARG)? as f32;
self.graph
.constant_float(scalar)
.cast(a.dtype)
.expand_rhs(a.shape)
};
let (a, b) = crate::pt2_util::ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
let quotient = a.cast(DType::F32) / b.cast(DType::F32);
let trunc = quotient.cast(DType::Int).cast(DType::F32);
let adjust = quotient.lt(trunc).cast(DType::F32);
let floored = trunc - adjust;
Ok(if a.dtype == DType::Int {
floored.cast(DType::Int)
} else {
floored.cast(a.dtype)
})
}
pub(crate) fn translate_div_tensor_mode(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, DIV_MODE_INPUT_ARG)?;
let b = if let Some(name) = node
.inputs
.get(DIV_MODE_OTHER_ARG)
.and_then(|i| i.arg.as_tensor_name())
{
self.get_tensor(name)?
} else {
let scalar = self.get_float_arg(node, DIV_MODE_OTHER_ARG)? as f32;
self.graph
.constant_float(scalar)
.cast(a.dtype)
.expand_rhs(a.shape)
};
let (a, b) = crate::pt2_util::ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
// Check rounding_mode kwarg
let rounding_mode = node.inputs.iter().find_map(|input| {
if input.name == "rounding_mode"
&& let Argument::Other(val) = &input.arg
{
return val.as_str().map(|s| s.to_string());
}
None
});
let quotient = a.cast(DType::F32) / b.cast(DType::F32);
match rounding_mode.as_deref() {
Some("floor") => {
let trunc = quotient.cast(DType::Int).cast(DType::F32);
let adjust = quotient.lt(trunc).cast(DType::F32);
let floored = trunc - adjust;
Ok(if a.dtype == DType::Int {
floored.cast(DType::Int)
} else {
floored.cast(a.dtype)
})
}
Some("trunc") => Ok(if a.dtype == DType::Int {
quotient.cast(DType::Int)
} else {
quotient.cast(DType::Int).cast(a.dtype)
}),
_ => {
// No rounding mode — regular division
Ok(quotient.cast(a.dtype))
}
}
}
pub(crate) fn translate_clamp(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let min_val = if node.inputs.len() > 1 {

View File

@@ -95,7 +95,7 @@ class CompiledModel:
self._graph.set_input_device_ptr(name, t.data_ptr(), n_bytes)
_input_refs.append(t)
else:
t = tensor.detach().cpu().contiguous()
t = tensor.detach().cpu().contiguous().to(expected_dtype)
n_bytes = t.numel() * t.element_size()
dtype_code = _torch_dtype_code(t.dtype)
self._graph.set_input_from_ptr(name, t.data_ptr(), n_bytes, dtype_code)
@@ -120,9 +120,10 @@ class CompiledModel:
else torch.float32
)
out = torch.empty(shape, dtype=out_dtype, device=input_device)
self._graph.set_output_device_ptr(
name, out.data_ptr(), out.numel() * out.element_size()
)
if out_dtype.is_floating_point:
self._graph.set_output_device_ptr(
name, out.data_ptr(), out.numel() * out.element_size()
)
output_tensors.append(out)
# Run the graph
@@ -130,13 +131,42 @@ class CompiledModel:
# Collect outputs
if _use_zero_copy:
# For aliased outputs that couldn't be zero-copied, fall back to DtoD copy.
for name, out in zip(self._output_names, output_tensors):
if not self._graph.output_is_zero_copy(name):
self._graph.copy_output_to_device_ptr(
name, out.data_ptr(), out.numel() * out.element_size()
outputs = []
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
out_dtype = (
code_to_torch_dtype(output_dtype_codes[i])
if i < len(output_dtype_codes)
else torch.float32
)
out = output_tensors[i]
if out_dtype.is_floating_point:
if not self._graph.output_is_zero_copy(name):
self._graph.copy_output_to_device_ptr(
name, out.data_ptr(), out.numel() * out.element_size()
)
elif out_dtype == torch.int32:
data = self._graph.get_output_i32(name)
out = (
torch.tensor(data, dtype=torch.int32)
.reshape(tuple(shape))
.to(input_device)
)
outputs = output_tensors
elif out_dtype == torch.bool:
data = self._graph.get_output_bool(name)
out = (
torch.tensor(data, dtype=torch.bool)
.reshape(tuple(shape))
.to(input_device)
)
else:
data = self._graph.get_output(name)
out = (
torch.tensor(data, dtype=torch.float32)
.reshape(tuple(shape))
.to(out_dtype)
.to(input_device)
)
outputs.append(out)
else:
# Native path: retrieve as f32, then convert to target dtype if needed.
outputs = []
@@ -146,13 +176,20 @@ class CompiledModel:
if i < len(output_dtype_codes)
else torch.float32
)
data = self._graph.get_output(name)
out = (
torch.tensor(data, dtype=torch.float32)
.reshape(tuple(shape))
.to(out_dtype)
.to(input_device)
)
if out_dtype == torch.int32:
data = self._graph.get_output_i32(name)
out = torch.tensor(data, dtype=torch.int32).reshape(tuple(shape))
elif out_dtype == torch.bool:
data = self._graph.get_output_bool(name)
out = torch.tensor(data, dtype=torch.bool).reshape(tuple(shape))
else:
data = self._graph.get_output(name)
out = (
torch.tensor(data, dtype=torch.float32)
.reshape(tuple(shape))
.to(out_dtype)
)
out = out.to(input_device)
outputs.append(out)
return tuple(outputs)

View File

@@ -215,6 +215,7 @@ from test_models import (
WhereWithConstantModel,
# Xor model
XorTestModel,
ArgsortStableDuplicatesModel,
# Conv models
Conv1dNoPadModel,
Conv1dSamePadModel,
@@ -231,6 +232,7 @@ from test_models import (
GroupedConv2dModel,
GroupedConv2dGroups3Model,
MambaConvBlockModel,
TinyMoERoutingModel,
)
from luminal import luminal_backend
@@ -1948,6 +1950,54 @@ def test_split(device: torch.device):
assert torch.allclose(model_compiled(x), model(x))
# ========== Argsort / MoE Routing Tests ==========
def test_argsort_stable_duplicates(device: torch.device):
"""Duplicate values should follow stable lower-index-first tie-breaking."""
model: torch.nn.Module = ArgsortStableDuplicatesModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x = torch.tensor(
[[2.0, 1.0, 1.0, 3.0]],
dtype=torch.float32,
device=device,
)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert output.dtype == torch.int32
assert torch.equal(output, original.to(torch.int32))
def test_tiny_moe_routing(device: torch.device):
"""Focused proof for build MoE routing support."""
model: torch.nn.Module = TinyMoERoutingModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
scores = torch.tensor(
[[0.1, 0.9, 0.4, 0.7], [0.6, -0.8, 0.95, 0.2]],
dtype=torch.float32,
device=device,
)
expected = model(scores)
output = model_compiled(scores)
expected_dtypes = (
torch.int32,
torch.float32,
torch.int32,
torch.bool,
torch.int32,
torch.float32,
)
for actual, eager, expected_dtype in zip(output, expected, expected_dtypes):
assert actual.dtype == expected_dtype
eager = eager.to(actual.dtype)
if actual.dtype.is_floating_point:
assert torch.allclose(actual, eager)
else:
assert torch.equal(actual, eager)
# ========== PT2 TopK Node Tests ==========

View File

@@ -1619,6 +1619,73 @@ class SplitTestModel(torch.nn.Module):
return a + b
# ========== Argsort / MoE Routing Test Models ==========
class ArgsortStableDuplicatesModel(torch.nn.Module):
"""Tests deterministic duplicate ordering for exported argsort."""
SORT_DIM = 1
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.argsort(x, dim=self.SORT_DIM)
class TinyMoERoutingModel(torch.nn.Module):
"""Minimal deterministic MoE-style routing proof for PT2/native and CUDA."""
TOP_K = 2
ROUTING_DIM = -1
ZERO_FILL = 0.0
DISPATCH_ON = 1
GROUP_SIZE = 2
def __init__(self) -> None:
super().__init__()
self.register_buffer(
"expert_scale",
torch.tensor([1.5, -0.5, 2.0, 0.25], dtype=torch.float32),
)
def forward(
self, scores: torch.Tensor
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
topk_values, topk_indices = torch.topk(scores, self.TOP_K, dim=self.ROUTING_DIM)
regroup_order = torch.argsort(topk_indices, dim=self.ROUTING_DIM)
routed_indices = torch.gather(topk_indices, self.ROUTING_DIM, regroup_order)
routed_values = torch.gather(topk_values, self.ROUTING_DIM, regroup_order)
expert_scale = self.expert_scale.unsqueeze(0).expand(scores.shape[0], -1)
gathered_scale = torch.gather(expert_scale, self.ROUTING_DIM, routed_indices)
weighted = routed_values * gathered_scale
inactive_mask = torch.bitwise_not(weighted > 0)
masked_values = weighted.masked_fill(inactive_mask, self.ZERO_FILL)
slots = torch.zeros_like(routed_indices).scatter(
self.ROUTING_DIM, regroup_order, self.DISPATCH_ON
)
active_slots = torch.bitwise_not(inactive_mask).to(slots.dtype)
dispatch = slots * active_slots
group_ids = torch.floor_divide(routed_indices, self.GROUP_SIZE)
routing_sign = torch.sign(masked_values)
return (
routed_indices,
masked_values,
dispatch,
inactive_mask,
group_ids,
routing_sign,
)
# ========== TopK Node Test Models ==========
@@ -1840,9 +1907,14 @@ class LlamaTransformerBlockModel(torch.nn.Module):
class Conv1dNoPadModel(torch.nn.Module):
"""Conv1d with no padding: output length shrinks by (kernel-1)."""
KERNEL_SIZE = 3
PADDING = 0
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv1d(8, 16, kernel_size=3, padding=0, bias=False)
self.conv = torch.nn.Conv1d(
8, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
@@ -1851,9 +1923,14 @@ class Conv1dNoPadModel(torch.nn.Module):
class Conv1dSamePadModel(torch.nn.Module):
"""Conv1d with same-size padding (output length == input length)."""
KERNEL_SIZE = 3
PADDING = 1
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv1d(8, 16, kernel_size=3, padding=1, bias=False)
self.conv = torch.nn.Conv1d(
8, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
@@ -1862,9 +1939,14 @@ class Conv1dSamePadModel(torch.nn.Module):
class Conv1dBiasModel(torch.nn.Module):
"""Conv1d with bias."""
KERNEL_SIZE = 3
PADDING = 1
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv1d(8, 16, kernel_size=3, padding=1, bias=True)
self.conv = torch.nn.Conv1d(
8, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=True
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
@@ -1873,9 +1955,14 @@ class Conv1dBiasModel(torch.nn.Module):
class Conv2dNoPadModel(torch.nn.Module):
"""Conv2d with no padding: output spatial dims shrink by (kernel-1)."""
KERNEL_SIZE = 3
PADDING = 0
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=0, bias=False)
self.conv = torch.nn.Conv2d(
3, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
@@ -1884,9 +1971,14 @@ class Conv2dNoPadModel(torch.nn.Module):
class Conv2dSamePadModel(torch.nn.Module):
"""Conv2d with same-size padding."""
KERNEL_SIZE = 3
PADDING = 1
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
self.conv = torch.nn.Conv2d(
3, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
@@ -1895,9 +1987,14 @@ class Conv2dSamePadModel(torch.nn.Module):
class Conv2dBiasModel(torch.nn.Module):
"""Conv2d with bias."""
KERNEL_SIZE = 3
PADDING = 1
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=True)
self.conv = torch.nn.Conv2d(
3, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=True
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
@@ -1906,10 +2003,19 @@ class Conv2dBiasModel(torch.nn.Module):
class Conv2dStrideModel(torch.nn.Module):
"""Conv2d with stride=2 (output dims halved)."""
KERNEL_SIZE = 3
STRIDE = 2
PADDING = 1
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
3, 16, kernel_size=3, stride=2, padding=1, bias=False
3,
16,
kernel_size=self.KERNEL_SIZE,
stride=self.STRIDE,
padding=self.PADDING,
bias=False,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1919,10 +2025,19 @@ class Conv2dStrideModel(torch.nn.Module):
class Conv2dDilationModel(torch.nn.Module):
"""Conv2d with dilation=2 and padding chosen to preserve spatial size."""
KERNEL_SIZE = 3
DILATION = 2
PADDING = 2
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
8, 16, kernel_size=3, dilation=2, padding=2, bias=False
8,
16,
kernel_size=self.KERNEL_SIZE,
dilation=self.DILATION,
padding=self.PADDING,
bias=False,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1932,9 +2047,14 @@ class Conv2dDilationModel(torch.nn.Module):
class Conv3dSamePadModel(torch.nn.Module):
"""Conv3d with padding=1 to preserve spatial dimensions."""
KERNEL_SIZE = 3
PADDING = 1
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv3d(4, 8, kernel_size=3, padding=1, bias=False)
self.conv = torch.nn.Conv3d(
4, 8, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
@@ -1943,10 +2063,19 @@ class Conv3dSamePadModel(torch.nn.Module):
class DepthwiseConv1dModel(torch.nn.Module):
"""Depthwise Conv1d as used in Mamba (groups == in_channels)."""
KERNEL_SIZE = 4
GROUPS = 16
PADDING = 3
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv1d(
16, 16, kernel_size=4, groups=16, padding=3, bias=True
16,
16,
kernel_size=self.KERNEL_SIZE,
groups=self.GROUPS,
padding=self.PADDING,
bias=True,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1957,10 +2086,19 @@ class DepthwiseConv1dModel(torch.nn.Module):
class DepthwiseConv2dModel(torch.nn.Module):
"""Depthwise Conv2d (groups == in_channels)."""
KERNEL_SIZE = 3
GROUPS = 8
PADDING = 1
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
8, 8, kernel_size=3, groups=8, padding=1, bias=False
8,
8,
kernel_size=self.KERNEL_SIZE,
groups=self.GROUPS,
padding=self.PADDING,
bias=False,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1970,10 +2108,19 @@ class DepthwiseConv2dModel(torch.nn.Module):
class DepthwiseMultiplierConv2dModel(torch.nn.Module):
"""Depthwise Conv2d with channel multiplier 2 (out_channels = 2 * in_channels)."""
KERNEL_SIZE = 3
GROUPS = 8
PADDING = 1
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
8, 16, kernel_size=3, groups=8, padding=1, bias=False
8,
16,
kernel_size=self.KERNEL_SIZE,
groups=self.GROUPS,
padding=self.PADDING,
bias=False,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1983,10 +2130,19 @@ class DepthwiseMultiplierConv2dModel(torch.nn.Module):
class GroupedConv2dModel(torch.nn.Module):
"""Conv2d with groups=4 (not depthwise, but grouped)."""
KERNEL_SIZE = 3
GROUPS = 4
PADDING = 1
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
16, 32, kernel_size=3, groups=4, padding=1, bias=False
16,
32,
kernel_size=self.KERNEL_SIZE,
groups=self.GROUPS,
padding=self.PADDING,
bias=False,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1996,10 +2152,19 @@ class GroupedConv2dModel(torch.nn.Module):
class GroupedConv2dGroups3Model(torch.nn.Module):
"""Conv2d with groups=3 and ch_per_group=4."""
KERNEL_SIZE = 3
GROUPS = 3
PADDING = 1
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
12, 12, kernel_size=3, groups=3, padding=1, bias=False
12,
12,
kernel_size=self.KERNEL_SIZE,
groups=self.GROUPS,
padding=self.PADDING,
bias=False,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -2015,9 +2180,16 @@ class MambaConvBlockModel(torch.nn.Module):
def __init__(self, d_model: int = 16, d_conv: int = 4, expand: int = 2) -> None:
super().__init__()
d_inner = d_model * expand
groups = d_inner
padding = d_conv - 1
self.in_proj = torch.nn.Linear(d_model, d_inner * 2, bias=False)
self.conv1d = torch.nn.Conv1d(
d_inner, d_inner, d_conv, groups=d_inner, padding=d_conv - 1, bias=True
d_inner,
d_inner,
d_conv,
groups=groups,
padding=padding,
bias=True,
)
self.out_proj = torch.nn.Linear(d_inner, d_model, bias=False)

View File

@@ -0,0 +1,22 @@
[package]
name = "gemma4_moe"
version = "0.1.0"
edition = "2021"
[features]
[dependencies]
luminal = { path = "../.." }
luminal_nn = { path = "../../crates/luminal_nn" }
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
tokenizers = "0.22.2"
rustc-hash = "2"
# HuggingFace model download
hf-hub = { version = "0.4", default-features = false, features = ["rustls-tls", "ureq"] }
safetensors = "0.7.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
half = { version = "2.7.1", features = ["bytemuck"] }
bytemuck = "1.24.0"
memmap2 = "0.9.9"

View File

@@ -0,0 +1,227 @@
use half::{bf16, f16};
use hf_hub::api::sync::Api;
use memmap2::MmapOptions;
use safetensors::{tensor::TensorView, Dtype, SafeTensors};
use serde::Deserialize;
use std::{
collections::HashMap,
fs::File,
io::Write,
path::{Path, PathBuf},
};
use crate::model::HIDDEN;
#[derive(Deserialize)]
struct SafetensorsIndex {
weight_map: HashMap<String, String>,
}
enum TensorData {
F32(Vec<f32>),
BF16(Vec<u8>),
}
struct StoredTensor {
shape: Vec<usize>,
data: TensorData,
}
pub fn download_hf_model(repo_id: &str) -> Result<PathBuf, Box<dyn std::error::Error>> {
let api = Api::new()?;
let repo = api.model(repo_id.to_string());
let tokenizer_path = repo.get("tokenizer.json")?;
let model_dir = tokenizer_path.parent().unwrap().to_path_buf();
if repo.get("model.safetensors").is_ok() {
return Ok(model_dir);
}
let index_path = repo.get("model.safetensors.index.json")?;
let index_content = std::fs::read_to_string(&index_path)?;
let index: SafetensorsIndex = serde_json::from_str(&index_content)?;
let mut shard_files: Vec<String> = index.weight_map.values().cloned().collect();
shard_files.sort();
shard_files.dedup();
for shard_file in &shard_files {
repo.get(shard_file)?;
}
Ok(model_dir)
}
fn tensor_to_f32(tensor: &safetensors::tensor::TensorView) -> Vec<f32> {
match tensor.dtype() {
Dtype::F32 => bytemuck::cast_slice::<u8, f32>(tensor.data()).to_vec(),
Dtype::F16 => {
let f16_slice: &[f16] = bytemuck::cast_slice(tensor.data());
f16_slice.iter().map(|x| x.to_f32()).collect()
}
Dtype::BF16 => {
let bf16_slice: &[bf16] = bytemuck::cast_slice(tensor.data());
bf16_slice.iter().map(|x| x.to_f32()).collect()
}
other => panic!("Unsupported dtype for conversion: {other:?}"),
}
}
fn tensor_to_bf16_bytes(tensor: &safetensors::tensor::TensorView) -> Vec<u8> {
match tensor.dtype() {
Dtype::BF16 => tensor.data().to_vec(),
Dtype::F16 => {
let f16_slice: &[f16] = bytemuck::cast_slice(tensor.data());
let bf16_data: Vec<bf16> = f16_slice
.iter()
.map(|x| bf16::from_f32(x.to_f32()))
.collect();
bytemuck::cast_slice(&bf16_data).to_vec()
}
Dtype::F32 => {
let f32_slice: &[f32] = bytemuck::cast_slice(tensor.data());
let bf16_data: Vec<bf16> = f32_slice.iter().map(|x| bf16::from_f32(*x)).collect();
bytemuck::cast_slice(&bf16_data).to_vec()
}
other => panic!("Unsupported dtype for conversion: {other:?}"),
}
}
fn is_text_weight(name: &str) -> bool {
name.starts_with("model.language_model.")
}
fn is_expert_weight(name: &str) -> bool {
name.contains(".experts.")
}
pub fn combine_safetensors(model_dir: &Path) -> Result<PathBuf, Box<dyn std::error::Error>> {
let output_path = model_dir.join("model_combined.safetensors");
if output_path.exists() {
return Ok(output_path);
}
let index_path = model_dir.join("model.safetensors.index.json");
let single_shard_path = model_dir.join("model.safetensors");
let shard_files: Vec<PathBuf> = if single_shard_path.exists() && !index_path.exists() {
println!("Single shard model detected...");
vec![single_shard_path]
} else if index_path.exists() {
let index_content = std::fs::read_to_string(&index_path)?;
let index: SafetensorsIndex = serde_json::from_str(&index_content)?;
let mut files: Vec<String> = index.weight_map.values().cloned().collect();
files.sort();
files.dedup();
println!("Loading {} shard files...", files.len());
files.into_iter().map(|f| model_dir.join(f)).collect()
} else {
return Err("No model.safetensors or model.safetensors.index.json found".into());
};
let mut all_tensors: HashMap<String, StoredTensor> = HashMap::new();
for shard_path in &shard_files {
println!(
" Loading {}...",
shard_path.file_name().unwrap().to_string_lossy()
);
let file = File::open(shard_path)?;
let mmap = unsafe { MmapOptions::new().map(&file)? };
let st = SafeTensors::deserialize(&mmap)?;
for name in st.names() {
if !is_text_weight(name) {
continue;
}
let new_name = name.replacen("model.language_model.", "model.", 1);
let tensor = st.tensor(name)?;
if new_name.ends_with(".layer_scalar") {
let scalar = tensor_to_f32(&tensor);
let scalar = *scalar.first().expect("layer_scalar tensor is empty");
all_tensors.insert(
new_name,
StoredTensor {
shape: vec![HIDDEN],
data: TensorData::F32(vec![scalar; HIDDEN]),
},
);
continue;
}
let shape = tensor.shape().to_vec();
let data = if is_expert_weight(&new_name) {
TensorData::BF16(tensor_to_bf16_bytes(&tensor))
} else {
TensorData::F32(tensor_to_f32(&tensor))
};
all_tensors.insert(new_name, StoredTensor { shape, data });
}
}
println!("Extracted {} text tensors", all_tensors.len());
let embed_key = "model.embed_tokens.weight";
if let Some(embed_tensor) = all_tensors.get(embed_key) {
let (shape, embed_data) = match &embed_tensor.data {
TensorData::F32(data) => (embed_tensor.shape.clone(), data.clone()),
TensorData::BF16(_) => unreachable!("Embedding weights should stay in F32"),
};
all_tensors.insert(
"lm_head.weight".to_string(),
StoredTensor {
shape,
data: TensorData::F32(embed_data.clone()),
},
);
let embed_scale = (HIDDEN as f32).sqrt();
if let Some(stored) = all_tensors.get_mut(embed_key) {
match &mut stored.data {
TensorData::F32(data) => {
for value in data {
*value *= embed_scale;
}
}
TensorData::BF16(_) => unreachable!("Embedding weights should stay in F32"),
}
}
}
println!("Saving combined model (BF16 experts + F32 rest)...");
let tensor_views: HashMap<String, TensorView<'_>> = all_tensors
.iter()
.map(|(name, stored)| {
let view = match &stored.data {
TensorData::F32(data) => {
let bytes: &[u8] = bytemuck::cast_slice(data);
TensorView::new(Dtype::F32, stored.shape.clone(), bytes).unwrap()
}
TensorData::BF16(bytes) => {
TensorView::new(Dtype::BF16, stored.shape.clone(), bytes).unwrap()
}
};
(name.clone(), view)
})
.collect();
let serialized = safetensors::serialize(&tensor_views, None)?;
let mut file = File::create(&output_path)?;
file.write_all(&serialized)?;
println!("Combined model saved successfully!");
Ok(output_path)
}
pub fn prepare_hf_model(repo_id: &str) -> Result<PathBuf, Box<dyn std::error::Error>> {
let model_dir = download_hf_model(repo_id)?;
combine_safetensors(&model_dir)?;
Ok(model_dir)
}

View File

@@ -0,0 +1,190 @@
mod hf;
mod model;
use hf::prepare_hf_model;
use luminal::prelude::*;
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use model::*;
use rustc_hash::FxHashSet;
use std::{io::Write, time::Duration};
use tokenizers::Tokenizer;
const REPO_ID: &str = "google/gemma-4-26B-A4B";
fn env_usize(name: &str, default: usize) -> usize {
std::env::var(name)
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(default)
}
fn env_bool(name: &str) -> bool {
std::env::var(name)
.ok()
.is_some_and(|s| matches!(s.as_str(), "1" | "true" | "TRUE" | "yes" | "YES"))
}
fn main() {
let max_seq_len = env_usize("MAX_SEQ_LEN", 4096);
let gen_tokens = env_usize("GEN_TOKENS", 30);
let search_graphs = env_usize("SEARCH_GRAPHS", 50);
let prompt = std::env::var("PROMPT").unwrap_or_else(|_| "The capital of France is".to_string());
let print_token_ids = env_bool("PRINT_TOKEN_IDS");
let ctx = CudaContext::new(0).unwrap();
let stream = ctx.default_stream();
let model_dir = prepare_hf_model(REPO_ID).expect("Failed to prepare model");
println!("Using model directory: {}", model_dir.display());
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
let prompt_tokens = tokenizer
.encode(prompt.as_str(), true)
.unwrap()
.get_ids()
.to_vec();
let mut cx = Graph::default();
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
let pos_ids = cx.named_tensor("pos_ids", 's').as_dtype(DType::Int);
let kv_cache = KVCache::new(&mut cx, max_seq_len);
let (logits, cache_outputs) = Gemma4MoE::init(&mut cx).forward(input, pos_ids, &kv_cache);
let logits = logits.output();
for (k_out, v_out) in &cache_outputs {
k_out.output();
v_out.output();
}
println!("Building E-Graph...");
cx.build_search_space::<CudaRuntime>();
println!("Loading weights...");
let mut runtime = CudaRuntime::initialize(stream);
let weights_path = model_dir.join("model_combined.safetensors");
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
for layer in 0..LAYERS {
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
}
println!("Compiling...");
cx.set_dim('s', 1);
cx.set_dim('p', 1);
runtime.set_data(input, vec![1]);
runtime.set_data(pos_ids, vec![1]);
runtime = cx.search(runtime, search_graphs);
for layer in 0..LAYERS {
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
}
print!("{prompt}");
std::io::stdout().flush().unwrap();
let mut prev_seq = 0usize;
let mut fwd_durations = vec![];
let mut seen_tokens = FxHashSet::default();
let mut generated_token_ids = vec![];
let repetition_penalty: f32 = 1.05;
const EOS_TOKEN: u32 = 1;
let prefill_start = std::time::Instant::now();
for &token in &prompt_tokens {
cx.set_dim('s', 1);
cx.set_dim('p', prev_seq);
runtime.set_data(input, vec![token as i32]);
runtime.set_data(pos_ids, vec![prev_seq as i32]);
runtime.execute(&cx.dyn_map);
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
let k_buf = runtime.remove_buffer(*k_out);
let v_buf = runtime.remove_buffer(*v_out);
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
}
prev_seq += 1;
}
let prefill_duration = prefill_start.elapsed();
let logits_data = runtime.get_f32(logits);
let last_row = &logits_data[..VOCAB_SIZE];
let mut next_token = last_row
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.total_cmp(b))
.unwrap()
.0 as u32;
generated_token_ids.push(next_token);
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
std::io::stdout().flush().unwrap();
seen_tokens.insert(next_token);
for _ in 1..gen_tokens {
let start = std::time::Instant::now();
cx.set_dim('s', 1);
cx.set_dim('p', prev_seq);
runtime.set_data(input, vec![next_token as i32]);
runtime.set_data(pos_ids, vec![prev_seq as i32]);
runtime.execute(&cx.dyn_map);
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
let k_buf = runtime.remove_buffer(*k_out);
let v_buf = runtime.remove_buffer(*v_out);
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
}
prev_seq += 1;
let logits_data = runtime.get_f32(logits);
let mut last_row = logits_data[..VOCAB_SIZE].to_vec();
for &tok in &seen_tokens {
let logit = &mut last_row[tok as usize];
if *logit > 0.0 {
*logit /= repetition_penalty;
} else {
*logit *= repetition_penalty;
}
}
next_token = last_row
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.total_cmp(b))
.unwrap()
.0 as u32;
generated_token_ids.push(next_token);
seen_tokens.insert(next_token);
if next_token == EOS_TOKEN {
break;
}
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
std::io::stdout().flush().unwrap();
fwd_durations.push(start.elapsed());
}
println!();
if print_token_ids {
println!("Generated token ids: {generated_token_ids:?}");
}
println!(
" TTFT: {:.2} ms ({} prompt tokens)",
prefill_duration.as_secs_f64() * 1e3,
prompt_tokens.len()
);
if fwd_durations.len() > 1 {
println!(
" TPOT: {:.2} ms",
(fwd_durations.iter().skip(1).sum::<Duration>() / (fwd_durations.len() - 1) as u32)
.as_secs_f64()
* 1_000.
);
}
}

View File

@@ -0,0 +1,621 @@
use luminal::{
dtype::DType,
graph::Graph,
prelude::{F32Pow, GraphTensor},
shape::Expression,
};
use luminal_nn::LayerNorm;
pub const LAYERS: usize = 30;
pub const HIDDEN: usize = 2816;
pub const INTERMEDIATE: usize = 2112;
pub const MOE_INTERMEDIATE: usize = 704;
pub const NUM_EXPERTS: usize = 128;
pub const TOP_K: usize = 8;
pub const N_HEADS: usize = 16;
pub const SLIDING_HEAD_DIM: usize = 256;
pub const FULL_HEAD_DIM: usize = 512;
pub const SLIDING_KV_HEADS: usize = 8;
pub const FULL_KV_HEADS: usize = 2;
pub const VOCAB_SIZE: usize = 262144;
pub const RMS_NORM_EPS: f32 = 1e-6;
pub const SLIDING_WINDOW_SIZE: usize = 1024;
pub const SLIDING_ROPE_THETA: f32 = 10_000.0;
pub const FULL_ROPE_THETA: f32 = 1_000_000.0;
pub const FULL_PARTIAL_ROTARY_FACTOR: f32 = 0.25;
pub const FINAL_LOGIT_SOFTCAP: f32 = 30.0;
#[derive(Clone, Copy)]
struct LayerSpec {
is_sliding: bool,
head_dim: usize,
q_dim: usize,
num_kv_heads: usize,
kv_dim: usize,
kv_groups: usize,
rope_theta: f32,
partial_rotary_factor: f32,
has_v_proj: bool,
}
fn layer_spec(layer: usize) -> LayerSpec {
if !(layer + 1).is_multiple_of(6) {
LayerSpec {
is_sliding: true,
head_dim: SLIDING_HEAD_DIM,
q_dim: N_HEADS * SLIDING_HEAD_DIM,
num_kv_heads: SLIDING_KV_HEADS,
kv_dim: SLIDING_KV_HEADS * SLIDING_HEAD_DIM,
kv_groups: N_HEADS / SLIDING_KV_HEADS,
rope_theta: SLIDING_ROPE_THETA,
partial_rotary_factor: 1.0,
has_v_proj: true,
}
} else {
LayerSpec {
is_sliding: false,
head_dim: FULL_HEAD_DIM,
q_dim: N_HEADS * FULL_HEAD_DIM,
num_kv_heads: FULL_KV_HEADS,
kv_dim: FULL_KV_HEADS * FULL_HEAD_DIM,
kv_groups: N_HEADS / FULL_KV_HEADS,
rope_theta: FULL_ROPE_THETA,
partial_rotary_factor: FULL_PARTIAL_ROTARY_FACTOR,
has_v_proj: false,
}
}
}
pub fn cache_bytes_for_layer(layer: usize, max_seq: usize) -> usize {
let spec = layer_spec(layer);
spec.num_kv_heads * max_seq * spec.head_dim * std::mem::size_of::<f32>()
}
pub struct KVCache {
pub k_caches: Vec<GraphTensor>,
pub v_caches: Vec<GraphTensor>,
pub max_seq: usize,
}
impl KVCache {
pub fn new(cx: &mut Graph, max_seq: usize) -> Self {
let mut k_caches = Vec::with_capacity(LAYERS);
let mut v_caches = Vec::with_capacity(LAYERS);
for layer in 0..LAYERS {
let spec = layer_spec(layer);
let k = cx
.named_tensor(
format!("kv_cache.{layer}.k"),
(spec.num_kv_heads, max_seq, spec.head_dim),
)
.persist();
let v = cx
.named_tensor(
format!("kv_cache.{layer}.v"),
(spec.num_kv_heads, max_seq, spec.head_dim),
)
.persist();
k_caches.push(k);
v_caches.push(v);
}
Self {
k_caches,
v_caches,
max_seq,
}
}
}
pub struct Gemma4MoE {
embedding: GraphTensor,
lm_head: GraphTensor,
layers: Vec<Gemma4Layer>,
lm_norm: LayerNorm,
}
impl Gemma4MoE {
pub fn init(cx: &mut Graph) -> Self {
let mut layers = Vec::with_capacity(LAYERS);
for layer in 0..LAYERS {
let spec = layer_spec(layer);
let gate = cx
.named_tensor(
format!("model.layers.{layer}.mlp.gate_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist();
let up = cx
.named_tensor(
format!("model.layers.{layer}.mlp.up_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist();
let down = cx
.named_tensor(
format!("model.layers.{layer}.mlp.down_proj.weight"),
(HIDDEN, INTERMEDIATE),
)
.persist();
let q_proj = cx
.named_tensor(
format!("model.layers.{layer}.self_attn.q_proj.weight"),
(spec.q_dim, HIDDEN),
)
.persist();
let k_proj = cx
.named_tensor(
format!("model.layers.{layer}.self_attn.k_proj.weight"),
(spec.kv_dim, HIDDEN),
)
.persist();
let v_proj = spec.has_v_proj.then(|| {
cx.named_tensor(
format!("model.layers.{layer}.self_attn.v_proj.weight"),
(spec.kv_dim, HIDDEN),
)
.persist()
});
let o_proj = cx
.named_tensor(
format!("model.layers.{layer}.self_attn.o_proj.weight"),
(HIDDEN, spec.q_dim),
)
.persist();
let q_norm = cx
.named_tensor(
format!("model.layers.{layer}.self_attn.q_norm.weight"),
spec.head_dim,
)
.persist();
let k_norm = cx
.named_tensor(
format!("model.layers.{layer}.self_attn.k_norm.weight"),
spec.head_dim,
)
.persist();
let layer_scalar = cx
.named_tensor(format!("model.layers.{layer}.layer_scalar"), HIDDEN)
.persist();
let router_scale = cx
.named_tensor(format!("model.layers.{layer}.router.scale"), HIDDEN)
.persist();
let router_proj = cx
.named_tensor(
format!("model.layers.{layer}.router.proj.weight"),
(NUM_EXPERTS, HIDDEN),
)
.persist();
let per_expert_scale = cx
.named_tensor(
format!("model.layers.{layer}.router.per_expert_scale"),
NUM_EXPERTS,
)
.persist();
let gate_up_weights = cx
.named_tensor(
format!("model.layers.{layer}.experts.gate_up_proj"),
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
)
.persist()
.as_dtype(DType::Bf16);
let down_weights = cx
.named_tensor(
format!("model.layers.{layer}.experts.down_proj"),
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
)
.persist()
.as_dtype(DType::Bf16);
layers.push(Gemma4Layer {
spec,
gate,
up,
down,
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
layer_scalar,
input_layernorm: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.input_layernorm.weight"),
cx,
),
post_attention_layernorm: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.post_attention_layernorm.weight"),
cx,
),
pre_feedforward_layernorm: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.pre_feedforward_layernorm.weight"),
cx,
),
post_feedforward_layernorm: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.post_feedforward_layernorm.weight"),
cx,
),
post_feedforward_layernorm_1: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.post_feedforward_layernorm_1.weight"),
cx,
),
post_feedforward_layernorm_2: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.post_feedforward_layernorm_2.weight"),
cx,
),
pre_feedforward_layernorm_2: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.pre_feedforward_layernorm_2.weight"),
cx,
),
moe: Gemma4SparseMoE {
router_scale,
router_proj,
per_expert_scale,
gate_up_weights,
down_weights,
},
});
}
let embedding = cx
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
.persist();
let lm_head = cx
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
.persist();
let lm_norm = gemma4_norm(HIDDEN, "model.norm.weight", cx);
Self {
embedding,
lm_head,
layers,
lm_norm,
}
}
pub fn forward(
&self,
token_ids: GraphTensor,
pos_ids: GraphTensor,
kv_cache: &KVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = token_ids.dims1();
let mut x = self.embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
);
let mut cache_outputs = Vec::with_capacity(LAYERS);
for (layer_idx, layer) in self.layers.iter().enumerate() {
let (x_new, k_out, v_out) = layer.forward(
x,
pos_ids,
kv_cache.k_caches[layer_idx],
kv_cache.v_caches[layer_idx],
kv_cache.max_seq,
);
x = x_new.graph_break();
cache_outputs.push((k_out, v_out));
}
let logits = self.lm_norm.forward(x).matmul(self.lm_head.t());
let logits = (logits / FINAL_LOGIT_SOFTCAP).tanh() * FINAL_LOGIT_SOFTCAP;
(logits, cache_outputs)
}
}
struct Gemma4Layer {
spec: LayerSpec,
gate: GraphTensor,
up: GraphTensor,
down: GraphTensor,
q_proj: GraphTensor,
k_proj: GraphTensor,
v_proj: Option<GraphTensor>,
o_proj: GraphTensor,
q_norm: GraphTensor,
k_norm: GraphTensor,
layer_scalar: GraphTensor,
input_layernorm: LayerNorm,
post_attention_layernorm: LayerNorm,
pre_feedforward_layernorm: LayerNorm,
post_feedforward_layernorm: LayerNorm,
post_feedforward_layernorm_1: LayerNorm,
post_feedforward_layernorm_2: LayerNorm,
pre_feedforward_layernorm_2: LayerNorm,
moe: Gemma4SparseMoE,
}
struct Gemma4SparseMoE {
router_scale: GraphTensor,
router_proj: GraphTensor,
per_expert_scale: GraphTensor,
gate_up_weights: GraphTensor,
down_weights: GraphTensor,
}
fn gemma4_norm(dim: usize, weight_name: &str, cx: &mut Graph) -> LayerNorm {
LayerNorm::new(dim, Some(weight_name), None, false, RMS_NORM_EPS, cx)
}
#[allow(clippy::excessive_precision)]
fn gemma_gelu(x: GraphTensor) -> GraphTensor {
let scaled = 1.5957691216 * x * (1. + 0.044715 * x * x);
x * scaled.sigmoid()
}
fn qk_norm(x: GraphTensor, weight: GraphTensor, n_heads: usize, head_dim: usize) -> GraphTensor {
let seq = x.dims()[0];
let reshaped = x.split_dims(1, head_dim);
let normed = reshaped.std_norm(2, RMS_NORM_EPS);
let w = weight.expand_dim(0, n_heads).expand_dim(0, seq);
(normed * w).merge_dims(1, 2)
}
fn value_norm(x: GraphTensor, head_dim: usize) -> GraphTensor {
x.split_dims(1, head_dim)
.std_norm(2, RMS_NORM_EPS)
.merge_dims(1, 2)
}
fn gemma4_rotary_embeddings(
input: GraphTensor,
pos_ids: GraphTensor,
n_heads: usize,
head_dim: usize,
rope_theta: f32,
partial_rotary_factor: f32,
) -> GraphTensor {
let input = input.split_dims(1, head_dim).transpose(0, 1);
let half_dim = head_dim / 2;
let rope_angles = ((partial_rotary_factor * head_dim as f32) / 2.0).floor() as usize;
let rotated = input
.graph()
.arange_options(0, rope_angles * 2, 2)
.cast(DType::F32)
/ head_dim as f32;
let rotated = rope_theta.pow(rotated).reciprocal();
let inv_freqs = if rope_angles < half_dim {
let zeros = input
.graph()
.arange(half_dim - rope_angles)
.cast(DType::F32)
* 0.0;
rotated.concat_along(zeros, 0)
} else {
rotated
};
let emb = pos_ids
.cast(DType::F32)
.expand_dim(1, 1)
.matmul(inv_freqs.expand_dim(0, 1));
let x0 = input.slice((.., .., ..half_dim));
let x1 = input.slice((.., .., half_dim..));
let cos = emb.cos().expand_dim(0, n_heads);
let sin = emb.sin().expand_dim(0, n_heads);
let x0_out = x0 * cos - x1 * sin;
let x1_out = x1 * cos + x0 * sin;
x0_out
.concat_along(x1_out, 2)
.transpose(0, 1)
.merge_dims(1, 2)
}
fn gather_experts(
graph_source: GraphTensor,
top_k_indices: GraphTensor,
weights: GraphTensor,
) -> GraphTensor {
let (_, d1, d2) = weights.dims3();
let io = d1 * d2;
let base = top_k_indices * io;
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
let n_base = base.dims().len();
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
let mut exp_within = within;
for (axis, dim) in base.dims().iter().enumerate() {
exp_within = exp_within.expand_dim(axis, *dim);
}
let expert_flat_idx = exp_base + exp_within;
weights.gather(expert_flat_idx)
}
fn hlir_attention(
q_rope: GraphTensor,
k_rope: GraphTensor,
v: GraphTensor,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
spec: LayerSpec,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let cx = q_rope.graph();
let seq = q_rope.dims()[0];
let prev = Expression::from('p');
let total_seq = prev + seq;
let k_new = k_rope.split_dims(1, spec.head_dim).transpose(0, 1);
let v_new = v.split_dims(1, spec.head_dim).transpose(0, 1);
let h_offset = cx.arange(spec.num_kv_heads) * (max_seq * spec.head_dim);
let p_offset = (cx.arange(seq) + prev) * spec.head_dim;
let d_offset = cx.arange(spec.head_dim);
let scatter_idx = h_offset.expand_dim(1, seq).expand_dim(2, spec.head_dim)
+ p_offset
.expand_dim(0, spec.num_kv_heads)
.expand_dim(2, spec.head_dim)
+ d_offset.expand_dim(0, spec.num_kv_heads).expand_dim(1, seq);
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
let k_3d = k_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
let v_3d = v_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
let q = q_rope.split_dims(1, spec.head_dim).transpose(0, 1);
// Gemma 4's text attention uses Q/K normalization and then leaves the
// attention scaling at 1.0 in the reference implementation.
let scores = q.matmul(k_3d.transpose(1, 2));
let q_abs = cx.arange(seq).cast(DType::F32) + prev;
let k_pos = cx.arange(total_seq).cast(DType::F32);
let future_mask = k_pos
.expand_dim(0, seq)
.gt(q_abs.expand_dim(1, total_seq))
.cast(DType::F32);
let mask_2d = if spec.is_sliding {
let window_start = q_abs - (SLIDING_WINDOW_SIZE - 1) as f32;
let past_mask = window_start
.expand_dim(1, total_seq)
.gt(k_pos.expand_dim(0, seq))
.cast(DType::F32);
future_mask + past_mask
} else {
future_mask
};
let mask_3d = mask_2d.expand_dim(0, N_HEADS);
let masked_scores = scores + mask_3d * (-1e10f32);
let attn_weights = masked_scores.softmax(2);
let attn_out = attn_weights.matmul(v_3d);
let out = attn_out.transpose(0, 1).merge_dims(1, 2);
(out, k_cache_out, v_cache_out)
}
impl Gemma4Layer {
pub fn forward(
&self,
x: GraphTensor,
pos_ids: GraphTensor,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let residual = x;
let x_attn = self.input_layernorm.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k_base = x_attn.matmul(self.k_proj.t());
let v_base = if let Some(v_proj) = self.v_proj {
x_attn.matmul(v_proj.t())
} else {
k_base
};
let q_normed = qk_norm(q, self.q_norm, N_HEADS, self.spec.head_dim);
let k_normed = qk_norm(
k_base,
self.k_norm,
self.spec.num_kv_heads,
self.spec.head_dim,
);
let v_normed = value_norm(v_base, self.spec.head_dim);
let q_rope = gemma4_rotary_embeddings(
q_normed,
pos_ids,
N_HEADS,
self.spec.head_dim,
self.spec.rope_theta,
self.spec.partial_rotary_factor,
);
let k_rope = gemma4_rotary_embeddings(
k_normed,
pos_ids,
self.spec.num_kv_heads,
self.spec.head_dim,
self.spec.rope_theta,
self.spec.partial_rotary_factor,
);
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
q_rope, k_rope, v_normed, k_cache_in, v_cache_in, max_seq, self.spec,
);
let attn_proj = attn_out.matmul(self.o_proj.t());
let x = residual + self.post_attention_layernorm.forward(attn_proj);
let dense_ff = dense_ffn(
self.pre_feedforward_layernorm.forward(x),
self.gate,
self.up,
self.down,
);
let dense_ff = self.post_feedforward_layernorm_1.forward(dense_ff);
let moe_out = self
.moe
.forward(x, self.pre_feedforward_layernorm_2.forward(x));
let moe_out = self.post_feedforward_layernorm_2.forward(moe_out);
let ff_out = self.post_feedforward_layernorm.forward(dense_ff + moe_out);
let x = x + ff_out;
let x = x * self
.layer_scalar
.expand_lhs(&x.dims()[..x.dims().len() - 1]);
(x, k_cache_out, v_cache_out)
}
}
fn dense_ffn(x: GraphTensor, gate: GraphTensor, up: GraphTensor, down: GraphTensor) -> GraphTensor {
(gemma_gelu(x.matmul(gate.t())) * x.matmul(up.t())).matmul(down.t())
}
impl Gemma4SparseMoE {
fn forward(&self, router_input: GraphTensor, expert_input: GraphTensor) -> GraphTensor {
let n = router_input.dims().len();
let e_dim = *self.router_proj.dims().first().unwrap();
let k_expr = Expression::from(TOP_K);
let router_hidden = router_input.std_norm(router_input.dims().len() - 1, RMS_NORM_EPS)
* self
.router_scale
.expand_lhs(&router_input.dims()[..router_input.dims().len() - 1])
* (HIDDEN as f32).sqrt().recip();
let routing_weights = router_hidden.matmul(self.router_proj.t()).softmax(n - 1);
let top_k_indices = routing_weights.topk_indexes(TOP_K, n - 1);
let row_offsets = router_input
.graph()
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
let routing_flat_idx = row_offsets + top_k_indices;
let top_k_values = routing_weights.gather(routing_flat_idx);
let top_k_norm = top_k_values.sum(n - 1).expand_dim(n - 1, TOP_K);
let top_k_weights =
(top_k_values / top_k_norm) * self.per_expert_scale.gather(top_k_indices);
let gate_up_gathered =
gather_experts(expert_input, top_k_indices, self.gate_up_weights).cast(DType::F32);
let x_exp = expert_input.expand_dim(n - 1, TOP_K).unsqueeze(n);
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
let hidden = gemma_gelu(gate) * up;
let down_gathered =
gather_experts(expert_input, top_k_indices, self.down_weights).cast(DType::F32);
let hidden_exp = hidden.unsqueeze(2);
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2);
(down_out * top_k_weights.unsqueeze(top_k_weights.dims().len())).sum(n - 1)
}
}

View File

@@ -264,8 +264,7 @@ impl QwenMoE {
let row_offsets = x
.graph()
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
let routing_flat_idx =
(row_offsets.cast(DType::F32) + top_k_indices.cast(DType::F32)).cast(DType::Int);
let routing_flat_idx = row_offsets + top_k_indices;
let top_k_values = routing_weights.gather(routing_flat_idx);
// 4. Gather gate_up expert weights → [s, k, intermediate*2, H]
@@ -303,18 +302,18 @@ fn gather_experts(
) -> GraphTensor {
let (_, d1, d2) = weights.dims3();
let io = d1 * d2;
let base = (top_k_indices * io).cast(DType::F32);
let within = graph_source
.graph()
.iota(Expression::from('z'), (d1, d2))
.cast(DType::F32);
// Keep expert gather indices in Int all the way through. Routing them through
// F32 loses exactness once the flat offsets exceed 2^24, which Qwen's expert
// tensors do at realistic hidden sizes.
let base = top_k_indices * io;
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
let n_base = base.dims().len();
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
let mut exp_within = within;
for (i, dim) in base.dims().iter().enumerate() {
exp_within = exp_within.expand_dim(i, *dim);
}
let expert_flat_idx = (exp_base + exp_within).cast(DType::Int);
let expert_flat_idx = exp_base + exp_within;
weights.gather(expert_flat_idx)
}

View File

@@ -6,7 +6,7 @@ use rand::Rng;
use rustc_hash::FxHashSet;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::{str, sync::Arc};
use std::{str, sync::Arc, time::Duration};
use tracing::trace;
pub mod api;
@@ -128,8 +128,10 @@ pub fn early_egglog(
program.to_string(),
format!(
"(run-schedule
(saturate expr)
(run)
(repeat 6
(saturate expr)
(run)
)
(saturate base_cleanup)
)
(extract {root})"
@@ -179,6 +181,20 @@ pub struct SerializedEGraph {
pub roots: Vec<ClassId>,
}
#[derive(Debug, Clone, Default)]
pub struct EgglogStageReport {
pub num_matches_per_rule: FxHashMap<String, usize>,
pub search_and_apply_time_per_rule: FxHashMap<String, Duration>,
pub total_time: Duration,
}
#[derive(Debug, Clone, Default)]
pub struct EgglogRunReport {
pub early: EgglogStageReport,
pub full: EgglogStageReport,
pub total_time: Duration,
}
impl SerializedEGraph {
/// This is an opinionated function which does more than strictly take the state of the egglog object.
/// It also filters out "[...]" nodes and then changes the structure from the e-termDAG that egraph-serialize
@@ -589,41 +605,34 @@ fn termdag_to_egglog(td: &egglog::TermDag, root: egglog::TermId) -> (String, Str
(out.replace("(MVar \"z\")", "(MIter)"), format!("t{root}"))
}
#[tracing::instrument(skip_all)]
pub fn run_egglog(
program: &str,
root: &str,
ops: &[Arc<Box<dyn EgglogOp>>],
cleanup: bool,
) -> Result<SerializedEGraph, egglog::Error> {
let start = std::time::Instant::now();
let code = early_egglog(program, root, ops, cleanup);
let mut egraph = egglog::EGraph::default();
let commands = egraph.parser.get_program_from_string(None, &code)?;
let outputs = egraph.run_program(commands)?;
let CommandOutput::ExtractBest(termdag, _cost, term) = outputs.last().unwrap() else {
panic!();
};
let (program, root) = termdag_to_egglog(termdag, termdag.lookup(term));
let code = full_egglog(&program, ops, cleanup);
let mut egraph = egglog::EGraph::default();
let commands = egraph.parser.get_program_from_string(None, &code)?;
trace!("{}", "Egglog running...".green());
let _outputs = egraph.run_program(commands)?;
trace!("{}", "---- Egglog Rule Matches ----".green());
fn stage_report(egraph: &egglog::EGraph, total_time: Duration) -> EgglogStageReport {
let run_report = egraph.get_overall_run_report();
EgglogStageReport {
num_matches_per_rule: run_report
.num_matches_per_rule
.iter()
.map(|(name, matches)| (name.to_string(), *matches))
.collect(),
search_and_apply_time_per_rule: run_report
.search_and_apply_time_per_rule
.iter()
.map(|(name, elapsed)| (name.to_string(), *elapsed))
.collect(),
total_time,
}
}
fn trace_stage_report(header: &str, report: &EgglogStageReport) {
trace!("{}", header.green());
trace!(
"{}",
run_report
report
.num_matches_per_rule
.iter()
.filter(|(k, _)| !k.contains("("))
.map(|(k, v)| format!(
"{k}: {v} ({})",
pretty_duration::pretty_duration(
&run_report.search_and_apply_time_per_rule[k],
None
)
pretty_duration::pretty_duration(&report.search_and_apply_time_per_rule[k], None)
))
.join("\n")
.green()
@@ -631,11 +640,59 @@ pub fn run_egglog(
trace!(
"{}",
format!(
"---- Egglog Took {} ----",
pretty_duration::pretty_duration(&start.elapsed(), None).bold()
"---- {} Took {} ----",
header,
pretty_duration::pretty_duration(&report.total_time, None).bold()
)
.green()
);
}
#[tracing::instrument(skip_all)]
pub fn run_egglog_with_report(
program: &str,
root: &str,
ops: &[Arc<Box<dyn EgglogOp>>],
cleanup: bool,
) -> Result<(SerializedEGraph, EgglogRunReport), egglog::Error> {
let total_start = std::time::Instant::now();
let early_start = std::time::Instant::now();
let code = early_egglog(program, root, ops, cleanup);
let mut egraph = egglog::EGraph::default();
let commands = egraph.parser.get_program_from_string(None, &code)?;
let outputs = egraph.run_program(commands)?;
let early_report = stage_report(&egraph, early_start.elapsed());
let CommandOutput::ExtractBest(termdag, _cost, term) = outputs.last().unwrap() else {
panic!();
};
let (program, root) = termdag_to_egglog(termdag, termdag.lookup(term));
let full_start = std::time::Instant::now();
let code = full_egglog(&program, ops, cleanup);
let mut egraph = egglog::EGraph::default();
let commands = egraph.parser.get_program_from_string(None, &code)?;
trace!("{}", "Egglog running...".green());
let _outputs = egraph.run_program(commands)?;
let full_report = stage_report(&egraph, full_start.elapsed());
trace_stage_report("---- Egglog Early Rule Matches ----", &early_report);
trace_stage_report("---- Egglog Full Rule Matches ----", &full_report);
let run_report = EgglogRunReport {
early: early_report,
full: full_report,
total_time: total_start.elapsed(),
};
trace!(
"{}",
format!(
"---- Egglog Total Took {} ----",
pretty_duration::pretty_duration(&run_report.total_time, None).bold()
)
.green()
);
let (sort, value) = egraph.eval_expr(&var!(root))?;
let s = egraph.serialize(egglog::SerializeConfig {
root_eclasses: vec![(sort, value)],
@@ -720,7 +777,17 @@ pub fn run_egglog(
"No valid graphs present in the e-graph!"
);
Ok(egraph)
Ok((egraph, run_report))
}
#[tracing::instrument(skip_all)]
pub fn run_egglog(
program: &str,
root: &str,
ops: &[Arc<Box<dyn EgglogOp>>],
cleanup: bool,
) -> Result<SerializedEGraph, egglog::Error> {
run_egglog_with_report(program, root, ops, cleanup).map(|(egraph, _)| egraph)
}
pub fn extract_expr_list<'a>(

View File

@@ -663,7 +663,7 @@ pub(super) mod tests {
let mut out: Vec<(NotNan<f32>, usize)> =
heap.into_iter().map(|std::cmp::Reverse(t)| t).collect();
out.sort_unstable_by(|a, b| b.0.cmp(&a.0));
out.sort_unstable_by_key(|b| std::cmp::Reverse(b.0));
out.into_iter().map(|(_, i)| i).collect()
}
test_unary(

View File

@@ -724,10 +724,8 @@ impl Graph {
}
// Track top-N parents for offspring generation
let mut parents: Vec<(
R::ProfileMetric,
crate::egglog_utils::EGraphChoiceSet<'_>,
)> = vec![(best_metric.clone(), best_genome.clone())];
let mut parents: Vec<(R::ProfileMetric, crate::egglog_utils::EGraphChoiceSet<'_>)> =
vec![(best_metric.clone(), best_genome.clone())];
while n_graphs < limit {
// Generate offspring from all parents, dividing budget evenly
@@ -769,8 +767,7 @@ impl Graph {
None,
);
runtime.clear_intermediate_buffers();
let result =
runtime.profile(&llir_graph, dyn_map, options.trials);
let result = runtime.profile(&llir_graph, dyn_map, options.trials);
let has_nan = runtime.has_nan_outputs(&llir_graph, dyn_map);
(result, llir_graph, has_nan)
}));