mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
17 Commits
readme-ref
...
feat/lumin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4d32024f0f | ||
|
|
d66b3f2643 | ||
|
|
66b0807462 | ||
|
|
c24ea4a7a5 | ||
|
|
c309d9b4ed | ||
|
|
745c071ee5 | ||
|
|
56ffe8bbb3 | ||
|
|
13dbdcb53b | ||
|
|
c8ad5f8b75 | ||
|
|
51c6596f6a | ||
|
|
aef4c68537 | ||
|
|
0af1c186fd | ||
|
|
86b2784b51 | ||
|
|
53c58576fc | ||
|
|
64e4eedcc6 | ||
|
|
63afb602b0 | ||
|
|
985e7752aa |
22
README.md
22
README.md
@@ -45,18 +45,6 @@ cd ./examples/llama
|
||||
cargo run --release
|
||||
```
|
||||
|
||||
**PyTorch models via `torch.compile`**
|
||||
|
||||
Any PyTorch model can be run through Luminal by swapping the backend:
|
||||
```python
|
||||
import torch
|
||||
from luminal import luminal_backend
|
||||
|
||||
model_compiled = torch.compile(model, backend=luminal_backend)
|
||||
output = model_compiled(x)
|
||||
```
|
||||
See `crates/luminal_python/` for the PT2-based bridge.
|
||||
|
||||
## Features
|
||||
|
||||
### Speed
|
||||
@@ -87,7 +75,7 @@ The current ML ecosystem is too fragmented, and the solution isn't another layer
|
||||
|
||||
### Validated against PyTorch
|
||||
|
||||
Correctness matters. We write as many tests as possible to cover all ops and verify they work the same as an equivalent PyTorch implementation.
|
||||
Correctness matters. We write as many tests as possible to cover all ops and verify they work the same as an equivalent PyTorch implementation. ([Improvements needed!](https://github.com/jafioti/luminal/issues/20))
|
||||
|
||||
## Ideology
|
||||
|
||||
@@ -114,12 +102,12 @@ Now we can do:
|
||||
|
||||
## Where are we?
|
||||
|
||||
- Search is the default execution path — compile via `build_search_space` and `search` (see the Usage example above).
|
||||
- Search is partially merged. We are between 1.0 and 2.0 (search), which will be completed within the next month or so.
|
||||
- Metal and CUDA are supported for running models on Macs and Nvidia GPUs respectively, in both full and half precision.
|
||||
- Llama 3, Gemma, Qwen (incl. MoE variants), and a paged-attention Llama are implemented in `examples/`. See instructions above for running.
|
||||
- Full training support with graph-based autograd.
|
||||
- Llama 3, Phi 3, Whisper and Yolo v8 are implemented in `examples/`. See instructions above for running.
|
||||
- We have a small library of NN modules in `luminal_nn`, including transformers.
|
||||
- A large surface of high-level ops lives in `src/frontend/` — aiming to match the most used ~80% of the PyTorch API.
|
||||
- PyTorch models can be run through luminal via `torch.compile` — see `crates/luminal_python/`.
|
||||
- A significant number of high-level ops are implemented in `hl_ops`. We are aiming to match the most used ~80% of the PyTorch API.
|
||||
|
||||
Some things on the roadmap:
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ use crate::{
|
||||
driver::{CudaSlice, CudaStream, DevicePtr},
|
||||
},
|
||||
host::{HostOp, cublas::parse_cublas_op},
|
||||
try_create_cublaslt,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -248,6 +249,19 @@ fn dtype_to_cuda_types(dtype: DType) -> (cudaDataType, cublasComputeType_t, cuda
|
||||
}
|
||||
}
|
||||
|
||||
impl CuBlasLt {
|
||||
fn get_cublaslt(&self, stream: &Arc<CudaStream>) -> anyhow::Result<Arc<CudaBlasLT>> {
|
||||
if let Some(cublaslt) = self.cublaslt.get() {
|
||||
return Ok(cublaslt.clone());
|
||||
}
|
||||
let created = try_create_cublaslt(stream.clone()).map_err(|message| {
|
||||
anyhow::anyhow!("cuBLASLt unavailable on this machine: {message}")
|
||||
})?;
|
||||
let _ = self.cublaslt.set(created.clone());
|
||||
Ok(created)
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for CuBlasLt {
|
||||
fn execute(
|
||||
&self,
|
||||
@@ -324,9 +338,7 @@ impl HostOp for CuBlasLt {
|
||||
)
|
||||
.entered();
|
||||
|
||||
let cublaslt = self
|
||||
.cublaslt
|
||||
.get_or_init(|| Arc::new(CudaBlasLT::new(stream.clone()).unwrap()));
|
||||
let cublaslt = self.get_cublaslt(stream)?;
|
||||
|
||||
let mut matmul_desc: cublasLtMatmulDesc_t = std::ptr::null_mut();
|
||||
let mut a_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
|
||||
|
||||
@@ -1,128 +1,213 @@
|
||||
; GLUMoE: Match the expert computation subgraph of a Gated MoE (SwiGLU variant).
|
||||
; GLUMoE: Match the expert computation subgraph of a gated MoE.
|
||||
;
|
||||
; This matches the pattern produced by QwenMoE::forward() starting from the
|
||||
; expert gathers through to the final weighted sum, and replaces it with a
|
||||
; fused GLUMoE HostOp.
|
||||
; One fused op supports two activation modes:
|
||||
; mode=0: Qwen-style SwiGLU (silu(gate) * up)
|
||||
; mode=1: Gemma-style GELU (gelu(gate) * up, where gelu(g) = g * sigmoid(1.595769 * g * (1 + 0.044715 * g^2)))
|
||||
;
|
||||
; Inputs extracted:
|
||||
; ?x - input activations [s, H] F32
|
||||
; ?topk_idx - top-k expert indices [s, k] Int (from argsort+slice)
|
||||
; ?topk_vals - top-k routing values [s, k] F32 (from gather on softmax)
|
||||
; ?gate_up_w - stacked gate+up expert weights [E, intermediate*2, H] BF16
|
||||
; ?down_w - stacked down expert weights [E, H, intermediate] BF16
|
||||
;
|
||||
; The pattern captures:
|
||||
; 1. Gate-up expert gather (Iota, Mul, Cast, Iota, Cast, Add, Cast, Gather)
|
||||
; 2. Cast BF16→F32 of gathered gate-up weights
|
||||
; 3. Gate-up batched matmul (Mul + SumReduce)
|
||||
; 4. Gate/Up split via Iota+Gather (slice semantics)
|
||||
; 5. SwiGLU: silu(gate) * up
|
||||
; 6. Down expert gather (same pattern as gate-up)
|
||||
; 7. Cast BF16→F32 of gathered down weights
|
||||
; 8. Down batched matmul (Mul + SumReduce)
|
||||
; 9. Weighted sum: (down_out * topk_values) summed over k
|
||||
;
|
||||
; Variables with ? prefix are egglog pattern variables.
|
||||
; We use wildcards (?_xxx) for shapes/strides we don't extract.
|
||||
; To keep matching fast, we stage through marker states:
|
||||
; 1) Shared gate-up matmul marker
|
||||
; 2) Activation marker (separate swiglu / gemma_gelu paths)
|
||||
; 3) Down matmul marker (separate swiglu / gemma_gelu paths)
|
||||
; 4) Final GLUMoE fusion (separate swiglu / gemma_gelu rules)
|
||||
|
||||
(datatype*
|
||||
(GLUMoEGateUpState
|
||||
(MkGLUMoEGateUpState Expression Expression Expression IR IR IR)
|
||||
)
|
||||
(GLUMoESwiGLUState
|
||||
(MkGLUMoESwiGLUState GLUMoEGateUpState)
|
||||
)
|
||||
(GLUMoEGemmaGELUState
|
||||
(MkGLUMoEGemmaGELUState GLUMoEGateUpState)
|
||||
)
|
||||
(GLUMoESwiGLUDownState
|
||||
(MkGLUMoESwiGLUDownState Expression Expression Expression GLUMoESwiGLUState IR IR)
|
||||
)
|
||||
(GLUMoEGemmaDownState
|
||||
(MkGLUMoEGemmaDownState Expression Expression Expression GLUMoEGemmaGELUState IR IR)
|
||||
)
|
||||
)
|
||||
|
||||
(function glumoe_gate_up (IR) GLUMoEGateUpState :merge new)
|
||||
(function glumoe_swiglu (IR) GLUMoESwiGLUState :merge new)
|
||||
(function glumoe_gemma_gelu (IR) GLUMoEGemmaGELUState :merge new)
|
||||
(function glumoe_swiglu_down (IR) GLUMoESwiGLUDownState :merge new)
|
||||
(function glumoe_gemma_down (IR) GLUMoEGemmaDownState :merge new)
|
||||
|
||||
(rule
|
||||
(
|
||||
; ===== Gate-up expert gather =====
|
||||
; t51: Iota for base index (expert_idx * io_gu)
|
||||
(= ?gu_iota_base (Op (Iota ?gu_io ?gu_iota_base_range) (INil)))
|
||||
; t52: Mul topk_indices * io → base offsets [s, k]
|
||||
(= ?gu_mul_base (Op (Mul ?gu_mul_base_shape ?gu_mul_base_a_stride ?gu_mul_base_b_stride ?gu_mul_base_out_stride) (ICons ?topk_idx (ICons ?gu_iota_base (INil)))))
|
||||
; t53: Cast to F32
|
||||
(= ?gu_cast_base (Op (Cast ?gu_cast_base_size (F32)) (ICons ?gu_mul_base (INil))))
|
||||
; t54: Iota for within-expert index
|
||||
(= ?gu_iota_within (Op (Iota (MIter) ?gu_iota_within_range) (INil)))
|
||||
; t55: Cast within to F32
|
||||
(= ?gu_cast_within (Op (Cast ?gu_cast_within_size (F32)) (ICons ?gu_iota_within (INil))))
|
||||
; t56: Add base + within → flat gather indices
|
||||
(= ?gu_add_idx (Op (Add ?gu_add_shape ?gu_add_a_stride ?gu_add_b_stride ?gu_add_out_stride) (ICons ?gu_cast_base (ICons ?gu_cast_within (INil)))))
|
||||
; t57: Cast to Int
|
||||
(= ?gu_cast_idx (Op (Cast ?gu_cast_idx_size (Int)) (ICons ?gu_add_idx (INil))))
|
||||
; t58: Gather gate_up weights
|
||||
(= ?gu_gathered (Op (Gather ?gu_gather_idx_shape ?gu_gather_idx_stride ?gu_gather_data_shape ?gu_gather_data_stride) (ICons ?gu_cast_idx (ICons ?gate_up_w (INil)))))
|
||||
(= ?gu_add_idx (Op (Add ?gu_add_shape ?gu_add_a_stride ?gu_add_b_stride ?gu_add_out_stride) (ICons ?gu_mul_base (ICons ?gu_iota_within (INil)))))
|
||||
(= ?gu_gathered (Op (Gather ?gu_gather_idx_shape ?gu_gather_idx_stride ?gu_gather_data_shape ?gu_gather_data_stride) (ICons ?gu_add_idx (ICons ?gate_up_w (INil)))))
|
||||
|
||||
; ===== Cast BF16→F32 =====
|
||||
; t59: Cast gathered gate_up to F32
|
||||
(= ?gu_f32 (Op (Cast ?gu_f32_size (F32)) (ICons ?gu_gathered (INil))))
|
||||
|
||||
; ===== Gate-up batched matmul =====
|
||||
; t60: Mul x * gathered_gu (broadcast multiply)
|
||||
(= ?gu_matmul_mul (Op (Mul ?gu_matmul_mul_shape ?gu_matmul_a_stride ?gu_matmul_b_stride ?gu_matmul_mul_out_stride) (ICons ?x (ICons ?gu_f32 (INil)))))
|
||||
; t61: SumReduce over K dimension
|
||||
(= ?gu_matmul (Op (Sum ?gu_matmul_out_shape ?gu_matmul_k ?gu_matmul_in_stride ?gu_matmul_k_stride ?gu_matmul_out_stride) (ICons ?gu_matmul_mul (INil))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_gate_up ?gu_matmul)
|
||||
(MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_iota_within_range ?x ?topk_idx ?gate_up_w))
|
||||
)
|
||||
:name "GLUMoE gate-up matmul marker"
|
||||
)
|
||||
|
||||
; ===== SwiGLU activation marker =====
|
||||
(rule
|
||||
(
|
||||
(= ?gate_up_state (glumoe_gate_up ?gu_matmul))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
; ===== Up slice via Iota+Gather =====
|
||||
; t62: Iota with complex expression (slicing the "up" half)
|
||||
(= ?up_iota (Op (Iota ?up_iota_expr ?up_iota_range) (INil)))
|
||||
; t63: Gather to select up portion from matmul result
|
||||
(= ?up_slice (Op (Gather ?up_gather_idx_shape ?up_gather_idx_stride ?up_gather_data_shape ?up_gather_data_stride) (ICons ?up_iota (ICons ?gu_matmul (INil)))))
|
||||
|
||||
; ===== SwiGLU: silu(gate) * up =====
|
||||
; t64: Constant(-1)
|
||||
(= ?neg1 (Op (Constant -1.000000) (INil)))
|
||||
; t65: gate * -1
|
||||
(= ?neg_gate (Op (Mul ?silu_shape1 ?silu_a_stride1 ?silu_b_stride1 ?silu_out_stride1) (ICons ?gu_matmul (ICons ?neg1 (INil)))))
|
||||
; t66: Constant(log2e)
|
||||
(= ?log2e (Op (Constant 1.442695) (INil)))
|
||||
; t67: neg_gate * log2e
|
||||
(= ?scaled (Op (Mul ?silu_shape2 ?silu_a_stride2 ?silu_b_stride2 ?silu_out_stride2) (ICons ?neg_gate (ICons ?log2e (INil)))))
|
||||
; t68: exp2
|
||||
(= ?exp2_val (Op (Exp2 ?silu_shape3 ?silu_in_stride3 ?silu_out_stride3) (ICons ?scaled (INil))))
|
||||
; t69: Constant(1)
|
||||
(= ?one (Op (Constant 1.000000) (INil)))
|
||||
; t70: exp2 + 1
|
||||
(= ?plus1 (Op (Add ?silu_shape4 ?silu_a_stride4 ?silu_b_stride4 ?silu_out_stride4) (ICons ?exp2_val (ICons ?one (INil)))))
|
||||
; t71: recip
|
||||
(= ?sigmoid (Op (Recip ?silu_shape5 ?silu_in_stride5 ?silu_out_stride5) (ICons ?plus1 (INil))))
|
||||
; t72: gate * sigmoid(gate) = silu(gate)
|
||||
(= ?silu_out (Op (Mul ?silu_shape6 ?silu_a_stride6 ?silu_b_stride6 ?silu_out_stride6) (ICons ?gu_matmul (ICons ?sigmoid (INil)))))
|
||||
; t73: silu(gate) * up
|
||||
(= ?swiglu_out (Op (Mul ?swiglu_shape ?swiglu_a_stride ?swiglu_b_stride ?swiglu_out_stride) (ICons ?silu_out (ICons ?up_slice (INil)))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_swiglu ?swiglu_out) (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
)
|
||||
:name "GLUMoE swiglu marker"
|
||||
)
|
||||
|
||||
; ===== Gemma GELU activation marker =====
|
||||
(rule
|
||||
(
|
||||
(= ?gate_up_state (glumoe_gate_up ?gu_matmul))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
(= ?up_iota (Op (Iota ?up_iota_expr ?up_iota_range) (INil)))
|
||||
(= ?up_slice (Op (Gather ?up_gather_idx_shape ?up_gather_idx_stride ?up_gather_data_shape ?up_gather_data_stride) (ICons ?up_iota (ICons ?gu_matmul (INil)))))
|
||||
|
||||
(= ?gelu_coeff_inner (Op (Constant 0.044715) (INil)))
|
||||
(= ?gelu_inner_scaled (Op (Mul ?gelu_inner_scaled_shape ?gelu_inner_scaled_a_stride ?gelu_inner_scaled_b_stride ?gelu_inner_scaled_out_stride) (ICons ?gu_matmul (ICons ?gelu_coeff_inner (INil)))))
|
||||
(= ?gelu_inner_quad (Op (Mul ?gelu_inner_quad_shape ?gelu_inner_quad_a_stride ?gelu_inner_quad_b_stride ?gelu_inner_quad_out_stride) (ICons ?gelu_inner_scaled (ICons ?gu_matmul (INil)))))
|
||||
(= ?gelu_one (Op (Constant 1.000000) (INil)))
|
||||
(= ?gelu_poly (Op (Add ?gelu_poly_shape ?gelu_poly_a_stride ?gelu_poly_b_stride ?gelu_poly_out_stride) (ICons ?gelu_inner_quad (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_coeff_outer (Op (Constant 1.595769) (INil)))
|
||||
(= ?gelu_outer_scaled (Op (Mul ?gelu_outer_scaled_shape ?gelu_outer_scaled_a_stride ?gelu_outer_scaled_b_stride ?gelu_outer_scaled_out_stride) (ICons ?gu_matmul (ICons ?gelu_coeff_outer (INil)))))
|
||||
(= ?gelu_scaled (Op (Mul ?gelu_scaled_shape ?gelu_scaled_a_stride ?gelu_scaled_b_stride ?gelu_scaled_out_stride) (ICons ?gelu_outer_scaled (ICons ?gelu_poly (INil)))))
|
||||
(= ?neg1 (Op (Constant -1.000000) (INil)))
|
||||
(= ?gelu_neg (Op (Mul ?gelu_neg_shape ?gelu_neg_a_stride ?gelu_neg_b_stride ?gelu_neg_out_stride) (ICons ?gelu_scaled (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant 1.442695) (INil)))
|
||||
(= ?gelu_exp_scaled (Op (Mul ?gelu_exp_scaled_shape ?gelu_exp_scaled_a_stride ?gelu_exp_scaled_b_stride ?gelu_exp_scaled_out_stride) (ICons ?gelu_neg (ICons ?log2e (INil)))))
|
||||
(= ?gelu_exp2_val (Op (Exp2 ?gelu_exp_shape ?gelu_exp_in_stride ?gelu_exp_out_stride) (ICons ?gelu_exp_scaled (INil))))
|
||||
(= ?gelu_plus1 (Op (Add ?gelu_plus1_shape ?gelu_plus1_a_stride ?gelu_plus1_b_stride ?gelu_plus1_out_stride) (ICons ?gelu_exp2_val (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_sigmoid (Op (Recip ?gelu_sigmoid_shape ?gelu_sigmoid_in_stride ?gelu_sigmoid_out_stride) (ICons ?gelu_plus1 (INil))))
|
||||
(= ?gelu_out (Op (Mul ?gelu_out_shape ?gelu_out_a_stride ?gelu_out_b_stride ?gelu_out_out_stride) (ICons ?gu_matmul (ICons ?gelu_sigmoid (INil)))))
|
||||
(= ?gemma_out (Op (Mul ?geglu_shape ?geglu_a_stride ?geglu_b_stride ?geglu_out_stride) (ICons ?gelu_out (ICons ?up_slice (INil)))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_gemma_gelu ?gemma_out) (MkGLUMoEGemmaGELUState ?gate_up_state))
|
||||
)
|
||||
:name "GLUMoE gemma gelu marker"
|
||||
)
|
||||
|
||||
; ===== SwiGLU down marker =====
|
||||
(rule
|
||||
(
|
||||
(= ?swiglu_state (glumoe_swiglu ?swiglu_out))
|
||||
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
|
||||
; ===== Down expert gather =====
|
||||
; t74: Iota for base index (expert_idx * io_down)
|
||||
(= ?dn_iota_base (Op (Iota ?dn_io ?dn_iota_base_range) (INil)))
|
||||
; t75: Mul topk_indices * io_down
|
||||
(= ?dn_mul_base (Op (Mul ?dn_mul_base_shape ?dn_mul_base_a_stride ?dn_mul_base_b_stride ?dn_mul_base_out_stride) (ICons ?topk_idx (ICons ?dn_iota_base (INil)))))
|
||||
; t76: Cast to F32
|
||||
(= ?dn_cast_base (Op (Cast ?dn_cast_base_size (F32)) (ICons ?dn_mul_base (INil))))
|
||||
; t77: Iota for within-expert index
|
||||
(= ?dn_iota_within (Op (Iota (MIter) ?dn_iota_within_range) (INil)))
|
||||
; t78: Cast within to F32
|
||||
(= ?dn_cast_within (Op (Cast ?dn_cast_within_size (F32)) (ICons ?dn_iota_within (INil))))
|
||||
; t79: Add base + within
|
||||
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_cast_base (ICons ?dn_cast_within (INil)))))
|
||||
; t80: Cast to Int
|
||||
(= ?dn_cast_idx (Op (Cast ?dn_cast_idx_size (Int)) (ICons ?dn_add_idx (INil))))
|
||||
; t81: Gather down weights
|
||||
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_cast_idx (ICons ?down_w (INil)))))
|
||||
|
||||
; ===== Cast BF16→F32 =====
|
||||
; t82: Cast gathered down to F32
|
||||
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_mul_base (ICons ?dn_iota_within (INil)))))
|
||||
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_add_idx (ICons ?down_w (INil)))))
|
||||
(= ?dn_f32 (Op (Cast ?dn_f32_size (F32)) (ICons ?dn_gathered (INil))))
|
||||
|
||||
; ===== Down batched matmul =====
|
||||
; t83: Mul swiglu_out * gathered_down (broadcast multiply)
|
||||
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?swiglu_out (ICons ?dn_f32 (INil)))))
|
||||
; t84: SumReduce
|
||||
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_swiglu_down ?dn_matmul)
|
||||
(MkGLUMoESwiGLUDownState ?dn_io ?dn_matmul_k ?dn_iota_within_range ?swiglu_state ?topk_idx ?down_w))
|
||||
)
|
||||
:name "GLUMoE swiglu down marker"
|
||||
)
|
||||
|
||||
; ===== Gemma GELU down marker =====
|
||||
(rule
|
||||
(
|
||||
(= ?gemma_state (glumoe_gemma_gelu ?gemma_out))
|
||||
(= ?gemma_state (MkGLUMoEGemmaGELUState ?gate_up_state))
|
||||
|
||||
(= ?dn_iota_base (Op (Iota ?dn_io ?dn_iota_base_range) (INil)))
|
||||
(= ?dn_mul_base (Op (Mul ?dn_mul_base_shape ?dn_mul_base_a_stride ?dn_mul_base_b_stride ?dn_mul_base_out_stride) (ICons ?topk_idx (ICons ?dn_iota_base (INil)))))
|
||||
(= ?dn_iota_within (Op (Iota (MIter) ?dn_iota_within_range) (INil)))
|
||||
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_mul_base (ICons ?dn_iota_within (INil)))))
|
||||
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_add_idx (ICons ?down_w (INil)))))
|
||||
(= ?dn_f32 (Op (Cast ?dn_f32_size (F32)) (ICons ?dn_gathered (INil))))
|
||||
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?gemma_out (ICons ?dn_f32 (INil)))))
|
||||
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_gemma_down ?dn_matmul)
|
||||
(MkGLUMoEGemmaDownState ?dn_io ?dn_matmul_k ?dn_iota_within_range ?gemma_state ?topk_idx ?down_w))
|
||||
)
|
||||
:name "GLUMoE gemma down marker"
|
||||
)
|
||||
|
||||
; ===== Final fusion: mode 0 (SwiGLU) =====
|
||||
(rule
|
||||
(
|
||||
(= ?down_state (glumoe_swiglu_down ?dn_matmul))
|
||||
(= ?down_state (MkGLUMoESwiGLUDownState ?dn_io ?dn_matmul_k ?dn_within_range ?swiglu_state ?topk_idx ?down_w))
|
||||
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
; ===== Weighted sum over k experts =====
|
||||
; t85: Mul down_out * topk_values
|
||||
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?topk_vals (INil)))))
|
||||
; t86: SumReduce over k dimension → [s, H]
|
||||
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
)
|
||||
(
|
||||
(let ?glumoe (Op (GLUMoE
|
||||
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
|
||||
?gu_iota_within_range ?dn_iota_within_range)
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (INil))))))))
|
||||
?gu_within_range ?dn_within_range (MNum 0))
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?topk_vals (INil)))))))))
|
||||
(union ?output ?glumoe)
|
||||
)
|
||||
:name "GLUMoE fused expert computation"
|
||||
:name "GLUMoE fused expert computation (swiglu)"
|
||||
)
|
||||
|
||||
; ===== Final fusion: mode 1 (Gemma GELU) =====
|
||||
(rule
|
||||
(
|
||||
(= ?down_state (glumoe_gemma_down ?dn_matmul))
|
||||
(= ?down_state (MkGLUMoEGemmaDownState ?dn_io ?dn_matmul_k ?dn_within_range ?gemma_state ?topk_idx ?down_w))
|
||||
(= ?gemma_state (MkGLUMoEGemmaGELUState ?gate_up_state))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
; Gemma expert weights: topk_weights = normed_topk * per_expert_scale.gather(topk_idx)
|
||||
(= ?per_expert_vals (Op (Gather ?scale_gather_idx_shape ?scale_gather_idx_stride ?scale_gather_data_shape ?scale_gather_data_stride) (ICons ?topk_idx (ICons ?per_expert_scale (INil)))))
|
||||
(= ?topk_row_offsets (Op (Iota ?topk_row_offsets_expr ?topk_row_offsets_range) (INil)))
|
||||
(= ?topk_flat_idx (Op (Add ?topk_flat_idx_shape ?topk_flat_idx_a_stride ?topk_flat_idx_b_stride ?topk_flat_idx_out_stride) (ICons ?topk_row_offsets (ICons ?topk_idx (INil)))))
|
||||
(= ?topk_vals (Op (Gather ?topk_vals_gather_idx_shape ?topk_vals_gather_idx_stride ?topk_vals_gather_data_shape ?topk_vals_gather_data_stride) (ICons ?topk_flat_idx (ICons ?routing_weights (INil)))))
|
||||
(= ?topk_norm (Op (Sum ?topk_norm_shape ?output_k ?topk_norm_in_stride ?topk_norm_k_stride ?topk_norm_out_stride) (ICons ?topk_vals (INil))))
|
||||
(= ?topk_norm_factor (Op (Recip ?topk_norm_recip_shape ?topk_norm_recip_in_stride ?topk_norm_recip_out_stride) (ICons ?topk_norm (INil))))
|
||||
(= ?normed_topk (Op (Mul ?normed_topk_shape ?normed_topk_a_stride ?normed_topk_b_stride ?normed_topk_out_stride) (ICons ?topk_vals (ICons ?topk_norm_factor (INil)))))
|
||||
(= ?expert_weights (Op (Mul ?expert_weights_shape ?expert_weights_a_stride ?expert_weights_b_stride ?expert_weights_out_stride) (ICons ?normed_topk (ICons ?per_expert_vals (INil)))))
|
||||
|
||||
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?expert_weights (INil)))))
|
||||
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
)
|
||||
(
|
||||
(let ?glumoe (Op (GLUMoE
|
||||
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
|
||||
?gu_within_range ?dn_within_range (MNum 1))
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?per_expert_scale (INil)))))))))
|
||||
(union ?output ?glumoe)
|
||||
)
|
||||
:name "GLUMoE fused expert computation (gemma_gelu)"
|
||||
)
|
||||
|
||||
@@ -33,14 +33,15 @@ use crate::{
|
||||
},
|
||||
},
|
||||
host::HostOp,
|
||||
try_create_cublaslt,
|
||||
};
|
||||
|
||||
const WORKSPACE_SIZE: usize = 32 * 1024 * 1024; // 32 MiB
|
||||
|
||||
/// Fused GLU-MoE HostOp matched via egglog pattern.
|
||||
///
|
||||
/// Replaces the expert computation subgraph (expert gathers + matmuls + SwiGLU
|
||||
/// + weighted sum) with an efficient cuBLASLt implementation.
|
||||
/// Replaces the expert computation subgraph (expert gathers + matmuls + gated
|
||||
/// activation + weighted sum) with an efficient cuBLASLt implementation.
|
||||
///
|
||||
/// Inputs (graph edges, in order):
|
||||
/// 0: x [seq, hidden] F32
|
||||
@@ -48,9 +49,13 @@ const WORKSPACE_SIZE: usize = 32 * 1024 * 1024; // 32 MiB
|
||||
/// 2: topk_values [seq, k] F32
|
||||
/// 3: gate_up_w [E, gate_up_dim, hidden] BF16
|
||||
/// 4: down_w [E, hidden, intermediate] BF16
|
||||
/// 5: mode_aux
|
||||
/// - SwiGLU: ignored (rewriter wires `topk_values` again)
|
||||
/// - GemmaGELU: per_expert_scale [E] F32
|
||||
///
|
||||
/// Output: [seq, hidden] F32
|
||||
pub struct GLUMoE {
|
||||
pub(crate) mode: GLUMoEMode,
|
||||
/// Product of gate_up weight dimensions per expert (gate_up_dim * hidden) used for gather stride
|
||||
gu_io: Expression,
|
||||
/// Product of down weight dimensions per expert (hidden * intermediate) used for gather stride
|
||||
@@ -69,9 +74,35 @@ pub struct GLUMoE {
|
||||
module: OnceLock<(Arc<CudaModule>, CudaFunction, CudaFunction)>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) enum GLUMoEMode {
|
||||
SwiGLU,
|
||||
GemmaGELU,
|
||||
}
|
||||
|
||||
impl GLUMoEMode {
|
||||
fn from_mode_id(mode_id: usize) -> Self {
|
||||
match mode_id {
|
||||
0 => Self::SwiGLU,
|
||||
1 => Self::GemmaGELU,
|
||||
other => {
|
||||
panic!("Unknown GLUMoE mode id: {other}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn activation_kernel_mode(self) -> i32 {
|
||||
match self {
|
||||
Self::SwiGLU => 0,
|
||||
Self::GemmaGELU => 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for GLUMoE {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
mode: GLUMoEMode::SwiGLU,
|
||||
gu_io: Expression::default(),
|
||||
dn_io: Expression::default(),
|
||||
gu_matmul_k: Expression::default(),
|
||||
@@ -88,6 +119,7 @@ impl Default for GLUMoE {
|
||||
impl std::fmt::Debug for GLUMoE {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("GLUMoE")
|
||||
.field("mode", &self.mode)
|
||||
.field("gu_io", &self.gu_io)
|
||||
.field("dn_io", &self.dn_io)
|
||||
.field("gu_matmul_k", &self.gu_matmul_k)
|
||||
@@ -100,6 +132,7 @@ impl std::fmt::Debug for GLUMoE {
|
||||
impl Clone for GLUMoE {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
mode: self.mode,
|
||||
gu_io: self.gu_io,
|
||||
dn_io: self.dn_io,
|
||||
gu_matmul_k: self.gu_matmul_k,
|
||||
@@ -114,9 +147,15 @@ impl Clone for GLUMoE {
|
||||
}
|
||||
|
||||
impl GLUMoE {
|
||||
fn get_cublaslt(&self, stream: &Arc<CudaStream>) -> &Arc<CudaBlasLT> {
|
||||
self.cublaslt
|
||||
.get_or_init(|| Arc::new(CudaBlasLT::new(stream.clone()).unwrap()))
|
||||
fn get_cublaslt(&self, stream: &Arc<CudaStream>) -> anyhow::Result<Arc<CudaBlasLT>> {
|
||||
if let Some(cublaslt) = self.cublaslt.get() {
|
||||
return Ok(cublaslt.clone());
|
||||
}
|
||||
let created = try_create_cublaslt(stream.clone()).map_err(|message| {
|
||||
anyhow::anyhow!("cuBLASLt unavailable on this machine: {message}")
|
||||
})?;
|
||||
let _ = self.cublaslt.set(created.clone());
|
||||
Ok(created)
|
||||
}
|
||||
|
||||
fn get_kernels(
|
||||
@@ -134,23 +173,34 @@ extern "C" __global__ void f32_to_bf16(unsigned long long in_ptr, unsigned long
|
||||
if (i < n) out[i] = __float2bfloat16(in_[i]);
|
||||
}
|
||||
|
||||
extern "C" __global__ void swiglu_bf16(unsigned long long gate_up_ptr, unsigned long long out_ptr, int intermediate) {
|
||||
extern "C" __global__ void glu_activation_bf16(
|
||||
unsigned long long gate_up_ptr,
|
||||
unsigned long long out_ptr,
|
||||
int intermediate,
|
||||
int mode
|
||||
) {
|
||||
const __nv_bfloat16* gate_up = (const __nv_bfloat16*)gate_up_ptr;
|
||||
__nv_bfloat16* out = (__nv_bfloat16*)out_ptr;
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < intermediate) {
|
||||
float gate = __bfloat162float(gate_up[i]);
|
||||
float up = __bfloat162float(gate_up[i + intermediate]);
|
||||
float silu = gate / (1.0f + expf(-gate));
|
||||
out[i] = __float2bfloat16(silu * up);
|
||||
float activated;
|
||||
if (mode == 0) {
|
||||
activated = gate / (1.0f + expf(-gate));
|
||||
} else {
|
||||
float scaled = 1.5957691216f * gate * (1.0f + 0.044715f * gate * gate);
|
||||
activated = gate / (1.0f + expf(-scaled));
|
||||
}
|
||||
out[i] = __float2bfloat16(activated * up);
|
||||
}
|
||||
}
|
||||
"#;
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), src).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let f32_to_bf16 = module.load_function("f32_to_bf16").unwrap();
|
||||
let swiglu = module.load_function("swiglu_bf16").unwrap();
|
||||
(module, f32_to_bf16, swiglu)
|
||||
let activation = module.load_function("glu_activation_bf16").unwrap();
|
||||
(module, f32_to_bf16, activation)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -168,12 +218,27 @@ impl EgglogOp for GLUMoE {
|
||||
("output_k", EXPRESSION),
|
||||
("gu_within_range", EXPRESSION),
|
||||
("dn_within_range", EXPRESSION),
|
||||
("mode", EXPRESSION),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?e (Op (GLUMoE ?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k ?gu_within_range ?dn_within_range ?mode) ?inputs))
|
||||
)
|
||||
(
|
||||
(set (dtype ?e) (F32))
|
||||
)
|
||||
:ruleset dtype_prop
|
||||
)",
|
||||
)]
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
5
|
||||
6
|
||||
}
|
||||
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
@@ -195,8 +260,14 @@ impl EgglogOp for GLUMoE {
|
||||
let output_k = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
|
||||
let gu_within_range = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
|
||||
let dn_within_range = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
|
||||
let mode_expr = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
|
||||
let mode_id = mode_expr
|
||||
.to_usize()
|
||||
.unwrap_or_else(|| panic!("GLUMoE mode must be static, got expression: {mode_expr}"));
|
||||
let mode = GLUMoEMode::from_mode_id(mode_id);
|
||||
|
||||
let extracted = GLUMoE {
|
||||
mode,
|
||||
gu_io,
|
||||
dn_io,
|
||||
gu_matmul_k,
|
||||
@@ -209,7 +280,7 @@ impl EgglogOp for GLUMoE {
|
||||
};
|
||||
|
||||
let op = LLIROp::new::<dyn HostOp>(Box::new(extracted) as Box<dyn HostOp>);
|
||||
// Return the 5 IR inputs: x, topk_idx, topk_vals, gate_up_w, down_w
|
||||
// Return the 6 IR inputs: x, topk_idx, topk_values, gate_up_w, down_w, mode_aux
|
||||
(op, input_enodes)
|
||||
}
|
||||
|
||||
@@ -230,9 +301,9 @@ impl HostOp for GLUMoE {
|
||||
// Resolve dimensions
|
||||
let hidden = self.gu_matmul_k.exec(dyn_map).unwrap();
|
||||
let intermediate = self.dn_matmul_k.exec(dyn_map).unwrap();
|
||||
let top_k = self.output_k.exec(dyn_map).unwrap();
|
||||
let top_k_expected = self.output_k.exec(dyn_map).unwrap();
|
||||
let gate_up_dim = self.gu_io.exec(dyn_map).unwrap() / hidden; // gate_up_dim = gu_io / hidden
|
||||
let _num_experts = self.gu_within_range.exec(dyn_map).unwrap() / (gate_up_dim * hidden);
|
||||
let num_experts = self.gu_within_range.exec(dyn_map).unwrap() / (gate_up_dim * hidden);
|
||||
|
||||
// Derive seq from x buffer size: x is [seq, hidden] F32 → seq = len / (hidden * 4)
|
||||
let x_buf = buffers[&inputs[0]];
|
||||
@@ -243,6 +314,7 @@ impl HostOp for GLUMoE {
|
||||
let topk_vals_buf = buffers[&inputs[2]]; // [seq, k] F32
|
||||
let gate_up_buf = buffers[&inputs[3]]; // [E, gate_up_dim, hidden] BF16
|
||||
let down_buf = buffers[&inputs[4]]; // [E, hidden, intermediate] BF16
|
||||
let mode_aux_buf = buffers[&inputs[5]];
|
||||
let output_buf = buffers[&self_node]; // [seq, hidden] F32
|
||||
|
||||
// Get raw device pointer addresses
|
||||
@@ -251,14 +323,59 @@ impl HostOp for GLUMoE {
|
||||
let down_ptr = buf_ptr(down_buf, stream);
|
||||
let output_ptr = buf_ptr(output_buf, stream);
|
||||
|
||||
let cublaslt = self.get_cublaslt(stream);
|
||||
let (_, f32_to_bf16_fn, swiglu_fn) = self.get_kernels(stream);
|
||||
let cublaslt = self.get_cublaslt(stream)?;
|
||||
let (_, f32_to_bf16_fn, activation_fn) = self.get_kernels(stream);
|
||||
|
||||
// Read topk indices and values from GPU
|
||||
// Read top-k expert indices and routing values from GPU
|
||||
let topk_idx_host: Vec<u8> = stream.clone_dtoh(topk_idx_buf)?;
|
||||
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host);
|
||||
let topk_vals_host: Vec<u8> = stream.clone_dtoh(topk_vals_buf)?;
|
||||
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host);
|
||||
let idx_k = topk_idx_i32
|
||||
.len()
|
||||
.checked_div(seq)
|
||||
.unwrap_or(top_k_expected);
|
||||
let val_k = topk_vals_f32
|
||||
.len()
|
||||
.checked_div(seq)
|
||||
.unwrap_or(top_k_expected);
|
||||
let top_k = idx_k.min(val_k);
|
||||
if seq > 0 && top_k == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Mode-dependent expert weights used for the final reduction:
|
||||
// - SwiGLU: direct topk values
|
||||
// - GemmaGELU: normalize topk values and scale by per-expert factors
|
||||
let mut expert_weights_storage: Vec<f32> = Vec::new();
|
||||
let expert_weights_f32: &[f32] = match self.mode {
|
||||
GLUMoEMode::SwiGLU => topk_vals_f32,
|
||||
GLUMoEMode::GemmaGELU => {
|
||||
let per_expert_scale_host: Vec<u8> = stream.clone_dtoh(mode_aux_buf)?;
|
||||
let per_expert_scale_f32: &[f32] = bytemuck::cast_slice(&per_expert_scale_host);
|
||||
debug_assert!(per_expert_scale_f32.len() >= num_experts);
|
||||
expert_weights_storage.resize(seq * top_k, 0.0);
|
||||
for t in 0..seq {
|
||||
let base = t * top_k;
|
||||
let vals = &topk_vals_f32[base..base + top_k];
|
||||
let norm = vals.iter().copied().sum::<f32>();
|
||||
let inv_norm = if norm != 0.0 { norm.recip() } else { 0.0 };
|
||||
for i in 0..top_k {
|
||||
let expert_idx = topk_idx_i32[base + i] as usize;
|
||||
if expert_idx >= per_expert_scale_f32.len() {
|
||||
anyhow::bail!(
|
||||
"GLUMoE Gemma mode expert index {} out of bounds {}",
|
||||
expert_idx,
|
||||
per_expert_scale_f32.len()
|
||||
);
|
||||
}
|
||||
let scale = per_expert_scale_f32[expert_idx];
|
||||
expert_weights_storage[base + i] = vals[i] * inv_norm * scale;
|
||||
}
|
||||
}
|
||||
&expert_weights_storage
|
||||
}
|
||||
};
|
||||
|
||||
// Allocate temp buffers
|
||||
let x_bf16_buf = unsafe { stream.alloc::<u8>(seq * hidden * 2)? }; // BF16
|
||||
@@ -291,22 +408,10 @@ impl HostOp for GLUMoE {
|
||||
let gu_stride = (gate_up_dim * hidden * 2) as u64; // bytes per expert gate_up (BF16)
|
||||
let down_stride = (hidden * intermediate * 2) as u64; // bytes per expert down (BF16)
|
||||
|
||||
// Normalize top-k values per token (norm_topk_prob=true)
|
||||
let mut normalized_vals = topk_vals_f32.to_vec();
|
||||
for t in 0..seq {
|
||||
let row = &mut normalized_vals[t * top_k..(t + 1) * top_k];
|
||||
let sum: f32 = row.iter().sum();
|
||||
if sum > 0.0 {
|
||||
for v in row.iter_mut() {
|
||||
*v /= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for t in 0..seq {
|
||||
let x_t_ptr = xbf16_ptr + (t * hidden * 2) as u64; // BF16
|
||||
let expert_indices = &topk_idx_i32[t * top_k..(t + 1) * top_k];
|
||||
let weights = &normalized_vals[t * top_k..(t + 1) * top_k];
|
||||
let weights = &expert_weights_f32[t * top_k..(t + 1) * top_k];
|
||||
|
||||
for (i, (&expert_idx, &weight)) in expert_indices.iter().zip(weights.iter()).enumerate()
|
||||
{
|
||||
@@ -316,7 +421,7 @@ impl HostOp for GLUMoE {
|
||||
let expert_gu_ptr = gate_up_ptr + expert_idx as u64 * gu_stride;
|
||||
cublas_matmul(
|
||||
stream,
|
||||
cublaslt,
|
||||
&cublaslt,
|
||||
ws_ptr,
|
||||
gate_up_dim as u64,
|
||||
1,
|
||||
@@ -335,17 +440,19 @@ impl HostOp for GLUMoE {
|
||||
0.0f32,
|
||||
)?;
|
||||
|
||||
// b. SwiGLU kernel (BF16 → BF16)
|
||||
// b. Mode-specific gated activation (BF16 → BF16)
|
||||
let moe_int = intermediate as i32;
|
||||
let swiglu_blocks = (moe_int as u32).div_ceil(256);
|
||||
let activation_mode = self.mode.activation_kernel_mode();
|
||||
let activation_blocks = (moe_int as u32).div_ceil(256);
|
||||
unsafe {
|
||||
stream
|
||||
.launch_builder(swiglu_fn)
|
||||
.launch_builder(activation_fn)
|
||||
.arg(&gu_out_ptr)
|
||||
.arg(&hid_ptr)
|
||||
.arg(&moe_int)
|
||||
.arg(&activation_mode)
|
||||
.launch(LaunchConfig {
|
||||
grid_dim: (swiglu_blocks, 1, 1),
|
||||
grid_dim: (activation_blocks, 1, 1),
|
||||
block_dim: (256, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
})?;
|
||||
@@ -358,7 +465,7 @@ impl HostOp for GLUMoE {
|
||||
let beta = if i == 0 { 0.0f32 } else { 1.0f32 };
|
||||
cublas_matmul_mixed(
|
||||
stream,
|
||||
cublaslt,
|
||||
&cublaslt,
|
||||
ws_ptr,
|
||||
hidden as u64,
|
||||
1,
|
||||
|
||||
@@ -9,6 +9,8 @@ use std::{
|
||||
|
||||
pub use cudarc;
|
||||
|
||||
use cudarc::{cublaslt::CudaBlasLT, driver::CudaStream};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
@@ -137,6 +139,25 @@ fn cuda_driver_diagnostics() -> (Option<i32>, Option<i32>) {
|
||||
(driver_version, None)
|
||||
}
|
||||
|
||||
pub(crate) fn try_create_cublaslt(
|
||||
stream: Arc<CudaStream>,
|
||||
) -> std::result::Result<Arc<CudaBlasLT>, String> {
|
||||
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| CudaBlasLT::new(stream))) {
|
||||
Ok(Ok(handle)) => Ok(Arc::new(handle)),
|
||||
Ok(Err(err)) => Err(err.to_string()),
|
||||
Err(payload) => {
|
||||
let message = if let Some(message) = payload.downcast_ref::<String>() {
|
||||
message.clone()
|
||||
} else if let Some(message) = payload.downcast_ref::<&str>() {
|
||||
message.to_string()
|
||||
} else {
|
||||
"cuBLASLt initialization panicked".to_string()
|
||||
};
|
||||
Err(message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn cuda_nvrtc_compile_options(target_arch: &str) -> Vec<String> {
|
||||
let mut options = cuda_nvrtc_include_paths()
|
||||
.into_iter()
|
||||
@@ -186,9 +207,9 @@ fn get_cubin(program: nvrtc_sys::nvrtcProgram) -> Result<Vec<u8>, NvrtcError> {
|
||||
}
|
||||
|
||||
let mut cubin = Vec::with_capacity(cubin_size);
|
||||
cubin.resize(cubin_size, 0);
|
||||
unsafe { nvrtc_sys::nvrtcGetCUBIN(program, cubin.as_mut_ptr()) }.result()?;
|
||||
Ok(cubin.into_iter().map(|byte| byte as u8).collect())
|
||||
cubin.resize(cubin_size, 0u8);
|
||||
unsafe { nvrtc_sys::nvrtcGetCUBIN(program, cubin.as_mut_ptr() as *mut _) }.result()?;
|
||||
Ok(cubin)
|
||||
}
|
||||
|
||||
pub(crate) fn compile_module_image_for_current_device<S: AsRef<str>>(
|
||||
|
||||
@@ -11,4 +11,6 @@ mod op_functional_tests;
|
||||
#[cfg(test)]
|
||||
mod performance_tests;
|
||||
#[cfg(test)]
|
||||
mod qwen3_moe_rewrite;
|
||||
#[cfg(test)]
|
||||
mod transformer;
|
||||
|
||||
314
crates/luminal_cuda_lite/src/tests/qwen3_moe_rewrite.rs
Normal file
314
crates/luminal_cuda_lite/src/tests/qwen3_moe_rewrite.rs
Normal file
@@ -0,0 +1,314 @@
|
||||
use half::bf16;
|
||||
use luminal::{dtype::DType, prelude::*, shape::Expression};
|
||||
|
||||
use super::utilities::{assert_close, get_cuda_stream, random_f32_vec};
|
||||
use crate::{
|
||||
host::{
|
||||
HostOp,
|
||||
moe::{GLUMoE, GLUMoEMode},
|
||||
},
|
||||
runtime::CudaRuntime,
|
||||
};
|
||||
|
||||
const SEQ: usize = 2;
|
||||
const HIDDEN: usize = 16;
|
||||
const NUM_EXPERTS: usize = 8;
|
||||
const TOP_K: usize = 2;
|
||||
const MOE_INTERMEDIATE: usize = 6;
|
||||
const RMS_NORM_EPS: f32 = 1e-6;
|
||||
|
||||
/// Bundle returned by `build_qwen_moe_graph`: the graph itself plus the
/// tensor handles the tests need to feed inputs (`x`, `router`, expert
/// weights) and to read back the result (`output`).
struct QwenMoeGraph {
|
||||
graph: Graph,
|
||||
x: GraphTensor,
|
||||
router: GraphTensor,
|
||||
gate_up_weights: GraphTensor,
|
||||
down_weights: GraphTensor,
|
||||
output: GraphTensor,
|
||||
}
|
||||
|
||||
/// Bundle returned by `build_gemma_moe_graph`: the graph itself plus the
/// tensor handles the tests need to feed inputs (router/expert activations,
/// router parameters, per-expert scales, expert weights) and to read back
/// the result (`output`).
struct GemmaMoeGraph {
|
||||
graph: Graph,
|
||||
router_input: GraphTensor,
|
||||
expert_input: GraphTensor,
|
||||
router_scale: GraphTensor,
|
||||
router_proj: GraphTensor,
|
||||
per_expert_scale: GraphTensor,
|
||||
gate_up_weights: GraphTensor,
|
||||
down_weights: GraphTensor,
|
||||
output: GraphTensor,
|
||||
}
|
||||
|
||||
fn build_qwen_moe_graph() -> QwenMoeGraph {
|
||||
let mut cx = Graph::default();
|
||||
let x = cx.tensor(('s', HIDDEN));
|
||||
let router = cx.tensor((NUM_EXPERTS, HIDDEN));
|
||||
let gate_up_weights = cx
|
||||
.tensor((NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = x.dims().len();
|
||||
let e_dim = *router.dims().first().unwrap();
|
||||
let k_expr = Expression::from(TOP_K);
|
||||
|
||||
let routing_weights = x.matmul(router.t()).softmax(n - 1);
|
||||
let top_k_indices = routing_weights.topk_indexes(TOP_K, n - 1);
|
||||
|
||||
let row_offsets = x
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
|
||||
let gate_up_gathered = gather_experts(x, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = x.expand_dim(n - 1, TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gate.silu() * up;
|
||||
|
||||
let down_gathered = gather_experts(x, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let output = (down_out * top_k_values.unsqueeze(top_k_values.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
|
||||
QwenMoeGraph {
|
||||
graph: cx,
|
||||
x,
|
||||
router,
|
||||
gate_up_weights,
|
||||
down_weights,
|
||||
output,
|
||||
}
|
||||
}
|
||||
|
||||
fn build_gemma_moe_graph() -> GemmaMoeGraph {
|
||||
let mut cx = Graph::default();
|
||||
let router_input = cx.tensor(('s', HIDDEN));
|
||||
let expert_input = cx.tensor(('s', HIDDEN));
|
||||
let router_scale = cx.tensor(HIDDEN);
|
||||
let router_proj = cx.tensor((NUM_EXPERTS, HIDDEN));
|
||||
let per_expert_scale = cx.tensor(NUM_EXPERTS);
|
||||
let gate_up_weights = cx
|
||||
.tensor((NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = router_input.dims().len();
|
||||
let e_dim = *router_proj.dims().first().unwrap();
|
||||
let k_expr = Expression::from(TOP_K);
|
||||
|
||||
let router_hidden = router_input.std_norm(n - 1, RMS_NORM_EPS)
|
||||
* router_scale.expand_lhs(&router_input.dims()[..n - 1])
|
||||
* (HIDDEN as f32).sqrt().recip();
|
||||
let routing_weights = router_hidden.matmul(router_proj.t()).softmax(n - 1);
|
||||
|
||||
let top_k_indices = routing_weights.topk_indexes(TOP_K, n - 1);
|
||||
let row_offsets = router_input
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
let top_k_norm = top_k_values.sum(n - 1).expand_dim(n - 1, TOP_K);
|
||||
let top_k_weights = (top_k_values / top_k_norm) * per_expert_scale.gather(top_k_indices);
|
||||
|
||||
let gate_up_gathered =
|
||||
gather_experts(expert_input, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = expert_input.expand_dim(n - 1, TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gemma_gelu(gate) * up;
|
||||
|
||||
let down_gathered = gather_experts(expert_input, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let output = (down_out * top_k_weights.unsqueeze(top_k_weights.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
|
||||
GemmaMoeGraph {
|
||||
graph: cx,
|
||||
router_input,
|
||||
expert_input,
|
||||
router_scale,
|
||||
router_proj,
|
||||
per_expert_scale,
|
||||
gate_up_weights,
|
||||
down_weights,
|
||||
output,
|
||||
}
|
||||
}
|
||||
|
||||
fn gather_experts(
|
||||
graph_source: GraphTensor,
|
||||
top_k_indices: GraphTensor,
|
||||
weights: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let (_, d1, d2) = weights.dims3();
|
||||
let io = d1 * d2;
|
||||
let base = top_k_indices * io;
|
||||
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
|
||||
let n_base = base.dims().len();
|
||||
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
|
||||
let mut exp_within = within;
|
||||
for (axis, dim) in base.dims().iter().enumerate() {
|
||||
exp_within = exp_within.expand_dim(axis, *dim);
|
||||
}
|
||||
let expert_flat_idx = exp_base + exp_within;
|
||||
weights.gather(expert_flat_idx)
|
||||
}
|
||||
|
||||
#[allow(clippy::excessive_precision)]
|
||||
fn gemma_gelu(x: GraphTensor) -> GraphTensor {
|
||||
let scaled = 1.5957691216 * x * (1. + 0.044715 * x * x);
|
||||
x * scaled.sigmoid()
|
||||
}
|
||||
|
||||
fn glumoe_modes(rt: &CudaRuntime) -> Vec<GLUMoEMode> {
|
||||
rt.llir_graph()
|
||||
.node_weights()
|
||||
.filter_map(|node| {
|
||||
let op = node.to_dialect::<dyn HostOp>()?;
|
||||
op.as_any()
|
||||
.downcast_ref::<GLUMoE>()
|
||||
.map(|glumoe| glumoe.mode)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn run_qwen_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return (vec![], vec![]);
|
||||
};
|
||||
|
||||
let mut model = build_qwen_moe_graph();
|
||||
model.graph.set_dim('s', SEQ);
|
||||
if use_glumoe {
|
||||
model.graph.build_search_space::<CudaRuntime>();
|
||||
} else {
|
||||
model
|
||||
.graph
|
||||
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>();
|
||||
}
|
||||
|
||||
let x_data = random_f32_vec(SEQ * HIDDEN, 11, -0.15, 0.15);
|
||||
let router_data = random_f32_vec(NUM_EXPERTS * HIDDEN, 12, -0.2, 0.2);
|
||||
let gate_up_data = random_f32_vec(NUM_EXPERTS * MOE_INTERMEDIATE * 2 * HIDDEN, 13, -0.1, 0.1)
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect::<Vec<_>>();
|
||||
let down_data = random_f32_vec(NUM_EXPERTS * HIDDEN * MOE_INTERMEDIATE, 14, -0.1, 0.1)
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(model.x, x_data);
|
||||
rt.set_data(model.router, router_data);
|
||||
rt.set_data(model.gate_up_weights, gate_up_data);
|
||||
rt.set_data(model.down_weights, down_data);
|
||||
rt = model.graph.search(rt, 10);
|
||||
rt.execute(&model.graph.dyn_map);
|
||||
|
||||
(rt.get_f32(model.output.id), glumoe_modes(&rt))
|
||||
}
|
||||
|
||||
fn run_gemma_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return (vec![], vec![]);
|
||||
};
|
||||
|
||||
let mut model = build_gemma_moe_graph();
|
||||
model.graph.set_dim('s', SEQ);
|
||||
if use_glumoe {
|
||||
model.graph.build_search_space::<CudaRuntime>();
|
||||
} else {
|
||||
model
|
||||
.graph
|
||||
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>();
|
||||
}
|
||||
|
||||
let router_input_data = random_f32_vec(SEQ * HIDDEN, 21, -0.15, 0.15);
|
||||
let expert_input_data = random_f32_vec(SEQ * HIDDEN, 22, -0.15, 0.15);
|
||||
let router_scale_data = random_f32_vec(HIDDEN, 23, 0.7, 1.3);
|
||||
let router_proj_data = random_f32_vec(NUM_EXPERTS * HIDDEN, 24, -0.2, 0.2);
|
||||
let per_expert_scale_data = random_f32_vec(NUM_EXPERTS, 25, 0.5, 1.5);
|
||||
let gate_up_data = random_f32_vec(NUM_EXPERTS * MOE_INTERMEDIATE * 2 * HIDDEN, 26, -0.1, 0.1)
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect::<Vec<_>>();
|
||||
let down_data = random_f32_vec(NUM_EXPERTS * HIDDEN * MOE_INTERMEDIATE, 27, -0.1, 0.1)
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(model.router_input, router_input_data);
|
||||
rt.set_data(model.expert_input, expert_input_data);
|
||||
rt.set_data(model.router_scale, router_scale_data);
|
||||
rt.set_data(model.router_proj, router_proj_data);
|
||||
rt.set_data(model.per_expert_scale, per_expert_scale_data);
|
||||
rt.set_data(model.gate_up_weights, gate_up_data);
|
||||
rt.set_data(model.down_weights, down_data);
|
||||
rt = model.graph.search(rt, 10);
|
||||
rt.execute(&model.graph.dyn_map);
|
||||
|
||||
(rt.get_f32(model.output.id), glumoe_modes(&rt))
|
||||
}
|
||||
|
||||
/// The Qwen SwiGLU pattern must be rewritten into exactly one `GLUMoE` op in
/// `SwiGLU` mode.
#[test]
fn test_glumoe_matches_qwen_swiglu_pattern() {
    let (result, modes) = run_qwen_moe(true);
    // `run_qwen_moe` returns an empty result only when no CUDA stream is
    // available. Skipping on an empty `modes` instead would hide the very
    // failure this test exists to catch: the rewrite not firing at all.
    if result.is_empty() {
        return;
    }

    assert_eq!(modes, vec![GLUMoEMode::SwiGLU]);
}
|
||||
|
||||
/// The Gemma GELU pattern must be rewritten into exactly one `GLUMoE` op in
/// `GemmaGELU` mode.
#[test]
fn test_glumoe_matches_gemma_gelu_pattern() {
    let (result, modes) = run_gemma_moe(true);
    // `run_gemma_moe` returns an empty result only when no CUDA stream is
    // available. Skipping on an empty `modes` instead would hide the very
    // failure this test exists to catch: the rewrite not firing at all.
    if result.is_empty() {
        return;
    }

    assert_eq!(modes, vec![GLUMoEMode::GemmaGELU]);
}
|
||||
|
||||
/// The fused `GLUMoE` (SwiGLU) result must agree numerically with the
/// unfused graph compiled without the `GLUMoE` op.
#[test]
fn test_glumoe_swiglu_matches_unfused_output() {
    // Baseline: GLUMoE excluded, so no fused op may appear.
    let (reference, reference_modes) = run_qwen_moe(false);
    if reference.is_empty() {
        // No CUDA stream available.
        return;
    }
    assert!(reference_modes.is_empty());

    // Fused run: exactly one SwiGLU-mode GLUMoE, same numbers within tolerance.
    let (fused, fused_modes) = run_qwen_moe(true);
    assert_eq!(fused_modes, vec![GLUMoEMode::SwiGLU]);
    assert_close(&fused, &reference, 3e-2, 3e-2);
}
|
||||
|
||||
/// The fused `GLUMoE` (GemmaGELU) result must agree numerically with the
/// unfused graph compiled without the `GLUMoE` op.
#[test]
fn test_glumoe_gemma_gelu_matches_unfused_output() {
    // Baseline: GLUMoE excluded, so no fused op may appear.
    let (reference, reference_modes) = run_gemma_moe(false);
    if reference.is_empty() {
        // No CUDA stream available.
        return;
    }
    assert!(reference_modes.is_empty());

    // Fused run: exactly one GemmaGELU-mode GLUMoE, same numbers within tolerance.
    let (fused, fused_modes) = run_gemma_moe(true);
    assert_eq!(fused_modes, vec![GLUMoEMode::GemmaGELU]);
    assert_close(&fused, &reference, 3e-2, 3e-2);
}
|
||||
@@ -628,6 +628,84 @@ impl CompiledGraph {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_output_i32(&self, name: &str) -> PyResult<Vec<i32>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
match &self.runtime {
|
||||
RuntimeBackend::Native(rt) => {
|
||||
let id = *node_id;
|
||||
let output_id = rt
|
||||
.graph
|
||||
.node_indices()
|
||||
.find(|n| {
|
||||
if let Some(out) = (**rt.graph[*n]).as_any().downcast_ref::<Output>() {
|
||||
out.node == id.index()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
pyo3::exceptions::PyRuntimeError::new_err(format!(
|
||||
"No output node found for tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
let data = rt.buffers.get(&output_id).ok_or_else(|| {
|
||||
pyo3::exceptions::PyRuntimeError::new_err(format!(
|
||||
"No buffer data for output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
Ok((0..data.len()).map(|i| data.i32(i)).collect())
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(rt) => Ok(rt.get_i32(*node_id)),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_output_bool(&self, name: &str) -> PyResult<Vec<bool>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
match &self.runtime {
|
||||
RuntimeBackend::Native(rt) => {
|
||||
let id = *node_id;
|
||||
let output_id = rt
|
||||
.graph
|
||||
.node_indices()
|
||||
.find(|n| {
|
||||
if let Some(out) = (**rt.graph[*n]).as_any().downcast_ref::<Output>() {
|
||||
out.node == id.index()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
pyo3::exceptions::PyRuntimeError::new_err(format!(
|
||||
"No output node found for tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
let data = rt.buffers.get(&output_id).ok_or_else(|| {
|
||||
pyo3::exceptions::PyRuntimeError::new_err(format!(
|
||||
"No buffer data for output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
Ok((0..data.len()).map(|i| data.bool(i)).collect())
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(rt) => Ok(rt.get_bool(*node_id)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Copy output tensor data directly to a CUDA device pointer (DtoD).
|
||||
/// Avoids the DtoH + HtoD round-trip of get_output() + .to(device).
|
||||
#[cfg(feature = "cuda")]
|
||||
|
||||
@@ -51,6 +51,7 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.sub.Scalar" => self.translate_binary_scalar_op(node, BinaryOp::Sub)?,
|
||||
"torch.ops.aten.div.Tensor" => self.translate_binary_op(node, BinaryOp::Div)?,
|
||||
"torch.ops.aten.div.Scalar" => self.translate_binary_scalar_op(node, BinaryOp::Div)?,
|
||||
"torch.ops.aten.div.Tensor_mode" => self.translate_div_tensor_mode(node)?,
|
||||
|
||||
// Unary ops
|
||||
"torch.ops.aten.neg.default" => self.translate_unary_op(node, |a| a * (-1.0))?,
|
||||
@@ -71,6 +72,8 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.log.default" => self.translate_unary_op(node, |a| a.log())?,
|
||||
"torch.ops.aten.log2.default" => self.translate_unary_op(node, |a| a.log2())?,
|
||||
"torch.ops.aten.exp2.default" => self.translate_unary_op(node, |a| a.exp2())?,
|
||||
"torch.ops.aten.sign.default" => self.translate_sign(node)?,
|
||||
"torch.ops.aten.bitwise_not.default" => self.translate_bitwise_not(node)?,
|
||||
|
||||
// Cast
|
||||
"torch.ops.aten._to_copy.default" => self.translate_to_copy(node)?,
|
||||
@@ -109,6 +112,7 @@ impl<'a> Translator<'a> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
if !a.shape.is_contiguous() { a + 0.0 } else { a }
|
||||
}
|
||||
"torch.ops.aten.argsort.default" => self.translate_argsort(node)?,
|
||||
|
||||
// Matmul
|
||||
"torch.ops.aten.mm.default" | "torch.ops.aten.bmm.default" => {
|
||||
@@ -159,6 +163,8 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Where
|
||||
"torch.ops.aten.where.self" => self.translate_where(node)?,
|
||||
"torch.ops.aten.where.ScalarOther" => self.translate_where_scalar_other(node)?,
|
||||
"torch.ops.aten.masked_fill.Scalar" => self.translate_masked_fill_scalar(node)?,
|
||||
|
||||
// Pow
|
||||
"torch.ops.aten.pow.Tensor_Scalar" => {
|
||||
@@ -176,6 +182,7 @@ impl<'a> Translator<'a> {
|
||||
// Creation ops
|
||||
"torch.ops.aten.arange.start_step" => self.translate_arange(node)?,
|
||||
"torch.ops.aten.full.default" => self.translate_full(node)?,
|
||||
"torch.ops.aten.full_like.default" => self.translate_full_like(node)?,
|
||||
"torch.ops.aten.scalar_tensor.default" => {
|
||||
let val = self.get_float_arg(node, 0)? as f32;
|
||||
self.graph.constant_float(val)
|
||||
@@ -349,7 +356,17 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Scatter ops
|
||||
"torch.ops.aten.scatter.src" => self.translate_scatter_src(node)?,
|
||||
"torch.ops.aten.index_put.default" => self.translate_index_put(node)?,
|
||||
"torch.ops.aten.scatter.value" => self.translate_scatter_value(node)?,
|
||||
"torch.ops.aten.index_put_.default" | "torch.ops.aten.index_put.default" => {
|
||||
self.translate_index_put(node)?
|
||||
}
|
||||
|
||||
// Integer routing math
|
||||
"torch.ops.aten.floor_divide.default" => self.translate_floor_divide(node)?,
|
||||
|
||||
// Triangular
|
||||
"torch.ops.aten.tril.default" => self.translate_tril(node)?,
|
||||
"torch.ops.aten.triu.default" => self.translate_triu(node)?,
|
||||
|
||||
// TopK — handles its own output storage, returns early
|
||||
"torch.ops.aten.topk.default" => {
|
||||
@@ -357,6 +374,12 @@ impl<'a> Translator<'a> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Sort — handles its own output storage, returns early
|
||||
"torch.ops.aten.sort.default" => {
|
||||
self.translate_sort(node)?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Split
|
||||
"torch.ops.aten.split_with_sizes.default" => self.translate_split_with_sizes(node)?,
|
||||
|
||||
|
||||
@@ -77,11 +77,12 @@ impl<'a> Translator<'a> {
|
||||
let output_names = self.parsed.output_names();
|
||||
for name in &output_names {
|
||||
let tensor = self.get_tensor(name)?;
|
||||
// Cast non-float outputs (Bool, Int) to F32 for the runtime.
|
||||
// Preserve F16/BF16/F32 as-is to avoid corrupting half-precision models.
|
||||
let tensor = match tensor.dtype {
|
||||
DType::Bool | DType::Int => tensor.cast(DType::F32) + 0.0,
|
||||
_ => tensor + 0.0,
|
||||
let tensor = if tensor.dtype == DType::Bool {
|
||||
tensor.cast(DType::Int).cast(DType::Bool)
|
||||
} else if tensor.dtype == DType::Int {
|
||||
tensor
|
||||
} else {
|
||||
tensor + 0.0
|
||||
};
|
||||
tensor.output();
|
||||
self.output_ids.push((name.clone(), tensor.id));
|
||||
@@ -155,6 +156,12 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// --- Helper methods ---
|
||||
|
||||
/// Looks up metadata for the tensor `name`.
///
/// Entries in `extra_tensor_values` take precedence; the parsed program's
/// own metadata is consulted only when no such entry exists.
pub(crate) fn tensor_meta(&self, name: &str) -> Option<&TensorMeta> {
    match self.extra_tensor_values.get(name) {
        found @ Some(_) => found,
        None => self.parsed.tensor_meta(name),
    }
}
|
||||
|
||||
pub(crate) fn get_tensor(&self, name: &str) -> Result<GraphTensor> {
|
||||
self.tensors
|
||||
.get(name)
|
||||
|
||||
@@ -6,6 +6,11 @@ use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
const SCATTER_INPUT_ARG: usize = 0;
|
||||
const SCATTER_DIM_ARG: usize = 1;
|
||||
const SCATTER_INDEX_ARG: usize = 2;
|
||||
const SCATTER_VALUE_ARG: usize = 3;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_reshape(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
@@ -359,6 +364,29 @@ impl<'a> Translator<'a> {
|
||||
Ok(a.scatter_elements(indices.cast(DType::Int), src, dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_scatter_value(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, SCATTER_INPUT_ARG)?;
|
||||
let dim = self.get_int_arg(node, SCATTER_DIM_ARG)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let indices = self.get_input_tensor(node, SCATTER_INDEX_ARG)?;
|
||||
let value_arg = &node
|
||||
.inputs
|
||||
.get(SCATTER_VALUE_ARG)
|
||||
.context("scatter.value missing value input")?
|
||||
.arg;
|
||||
let value = if let Some(b) = value_arg.as_bool() {
|
||||
self.graph.constant(if b { 1 } else { 0 }).cast(a.dtype)
|
||||
} else if let Some(i) = value_arg.as_int() {
|
||||
self.graph.constant(i).cast(a.dtype)
|
||||
} else if let Some(f) = value_arg.as_float() {
|
||||
self.graph.constant_float(f as f32).cast(a.dtype)
|
||||
} else {
|
||||
bail!("scatter.value: unsupported scalar argument {:?}", value_arg);
|
||||
}
|
||||
.expand_rhs(indices.shape);
|
||||
Ok(a.scatter_elements(indices.cast(DType::Int), value, dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_index_put(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let index_names = node.inputs[1]
|
||||
|
||||
@@ -6,6 +6,27 @@ use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
const FULL_SHAPE_ARG: usize = 0;
|
||||
const FULL_VALUE_ARG: usize = 1;
|
||||
|
||||
const FULL_LIKE_INPUT_ARG: usize = 0;
|
||||
const FULL_LIKE_VALUE_ARG: usize = 1;
|
||||
|
||||
const TOPK_INPUT_ARG: usize = 0;
|
||||
const TOPK_K_ARG: usize = 1;
|
||||
const TOPK_DIM_ARG: usize = 2;
|
||||
|
||||
const SORT_INPUT_ARG: usize = 0;
|
||||
const SORT_DIM_ARG: usize = 1;
|
||||
const SORT_DESCENDING_ARG: usize = 2;
|
||||
|
||||
const WHERE_COND_ARG: usize = 0;
|
||||
const WHERE_X_ARG: usize = 1;
|
||||
const WHERE_OTHER_ARG: usize = 2;
|
||||
|
||||
const TRIANGULAR_INPUT_ARG: usize = 0;
|
||||
const TRIANGULAR_DIAGONAL_ARG: usize = 1;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_arange(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let positional_args: Vec<Expression> = node
|
||||
@@ -30,19 +51,55 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
pub(crate) fn translate_full(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let shape = self.get_exprs_arg(node, 0)?;
|
||||
let shape = self.get_exprs_arg(node, FULL_SHAPE_ARG)?;
|
||||
// fill_value can be float, int, or bool after decomposition
|
||||
let val = if let Ok(f) = self.get_float_arg(node, 1) {
|
||||
let val = if let Ok(f) = self.get_float_arg(node, FULL_VALUE_ARG) {
|
||||
f as f32
|
||||
} else if let Ok(b) = self.get_bool_arg(node, 1) {
|
||||
} else if let Ok(b) = self.get_bool_arg(node, FULL_VALUE_ARG) {
|
||||
if b { 1.0 } else { 0.0 }
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"full: unsupported fill value type: {:?}",
|
||||
node.inputs.get(1)
|
||||
node.inputs.get(FULL_VALUE_ARG)
|
||||
);
|
||||
};
|
||||
Ok(self.graph.constant_float(val).expand_rhs(shape))
|
||||
let dtype = self.output_meta_dtype(node)?;
|
||||
let value = self.graph.constant_float(val).cast(dtype);
|
||||
Ok(if shape.is_empty() {
|
||||
value
|
||||
} else {
|
||||
value.expand_rhs(shape)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_full_like(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let reference = self.get_input_tensor(node, FULL_LIKE_INPUT_ARG)?;
|
||||
let val = if let Ok(f) = self.get_float_arg(node, FULL_LIKE_VALUE_ARG) {
|
||||
f as f32
|
||||
} else if let Ok(b) = self.get_bool_arg(node, FULL_LIKE_VALUE_ARG) {
|
||||
if b { 1.0 } else { 0.0 }
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"full_like: unsupported fill value type: {:?}",
|
||||
node.inputs.get(FULL_LIKE_VALUE_ARG)
|
||||
);
|
||||
};
|
||||
let dtype = self.output_meta_dtype(node)?;
|
||||
let value = self.graph.constant_float(val).cast(dtype);
|
||||
Ok(value.expand_rhs(reference.shape))
|
||||
}
|
||||
|
||||
fn output_meta_dtype(&self, node: &Node) -> Result<DType> {
|
||||
let output_name = node
|
||||
.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref())
|
||||
.map(|t| t.name.clone())
|
||||
.unwrap_or_default();
|
||||
let meta = self
|
||||
.tensor_meta(&output_name)
|
||||
.context("Missing tensor meta for output dtype")?;
|
||||
Ok(torch_dtype_int_to_luminal(meta.dtype))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_where(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
@@ -62,11 +119,64 @@ impl<'a> Translator<'a> {
|
||||
Ok(c * x_f + (one - c) * y_f)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_where_scalar_other(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let cond = self.get_input_tensor(node, WHERE_COND_ARG)?;
|
||||
let x = self.get_input_tensor(node, WHERE_X_ARG)?;
|
||||
let other_val = self.get_float_arg(node, WHERE_OTHER_ARG)? as f32;
|
||||
// Broadcast cond and x to a common shape
|
||||
let (cond_b, x_b) = broadcast_binary(cond, x);
|
||||
let c = cond_b.cast(DType::F32);
|
||||
let one = self.graph.constant_float(1.0).expand_rhs(c.shape);
|
||||
let other = self.graph.constant_float(other_val).expand_rhs(c.shape);
|
||||
Ok(c * x_b + (one - c) * other)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_tril(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
self.translate_triangular(node, false)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_triu(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
self.translate_triangular(node, true)
|
||||
}
|
||||
|
||||
fn translate_triangular(&mut self, node: &Node, upper: bool) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, TRIANGULAR_INPUT_ARG)?;
|
||||
let diagonal = if node.inputs.len() > TRIANGULAR_DIAGONAL_ARG {
|
||||
self.get_int_arg(node, TRIANGULAR_DIAGONAL_ARG).unwrap_or(0) as i32
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let dims = a.shape.dims;
|
||||
let rows = dims[dims.len() - 2];
|
||||
let cols = dims[dims.len() - 1];
|
||||
let (r_val, c_val) = match (rows.to_usize(), cols.to_usize()) {
|
||||
(Some(r), Some(c)) => (r, c),
|
||||
_ => anyhow::bail!("tril/triu requires concrete matrix dimensions"),
|
||||
};
|
||||
let size = r_val.max(c_val);
|
||||
let mask = if upper {
|
||||
self.graph.triu(size, diagonal)
|
||||
} else {
|
||||
self.graph.tril(size, diagonal)
|
||||
}
|
||||
.cast(DType::F32);
|
||||
let mask = if rows != cols {
|
||||
mask.slice_along(0..r_val, 0).slice_along(0..c_val, 1)
|
||||
} else {
|
||||
mask
|
||||
};
|
||||
let mut mask_expanded = mask;
|
||||
for i in (0..dims.len() - 2).rev() {
|
||||
mask_expanded = mask_expanded.expand_dim(0, dims[i]);
|
||||
}
|
||||
Ok(a * mask_expanded)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_topk(&mut self, node: &Node) -> Result<()> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let k = self.get_int_arg(node, 1)? as usize;
|
||||
let dim = if node.inputs.len() > 2 {
|
||||
self.get_int_arg(node, 2).unwrap_or(-1)
|
||||
let a = self.get_input_tensor(node, TOPK_INPUT_ARG)?;
|
||||
let k = self.get_int_arg(node, TOPK_K_ARG)? as usize;
|
||||
let dim = if node.inputs.len() > TOPK_DIM_ARG {
|
||||
self.get_int_arg(node, TOPK_DIM_ARG).unwrap_or(-1)
|
||||
} else {
|
||||
-1
|
||||
};
|
||||
@@ -86,13 +196,10 @@ impl<'a> Translator<'a> {
|
||||
None
|
||||
};
|
||||
|
||||
// Use full argsort then slice, rather than topk_indexes/topk_values directly.
|
||||
// This avoids a CUDA gather kernel bug when data and index shapes differ
|
||||
// along the gather axis (topk_indexes returns a sliced tensor).
|
||||
let full_argsort = a.argsort(dim, true);
|
||||
// Build top-k outputs from a full stable argsort, then slice to k.
|
||||
let full_argsort = a.stable_argsort(dim, true);
|
||||
|
||||
// Only build each branch when its output is consumed.
|
||||
// Dead nodes in the graph can confuse the CUDA optimizer.
|
||||
// Only build the outputs that are consumed.
|
||||
if let Some(val_name) = values_name
|
||||
&& !val_name.is_empty()
|
||||
{
|
||||
@@ -100,8 +207,7 @@ impl<'a> Translator<'a> {
|
||||
self.tensors.insert(val_name, values);
|
||||
}
|
||||
if let Some(idx_name) = indices_name {
|
||||
// Materialize Int indices as F32 with `* 1.0` to force a contiguous copy.
|
||||
// Without this, CUDA can't correctly read the sliced Int view.
|
||||
// Materialize the sliced indices through a copy before storing them.
|
||||
let indices = full_argsort.slice_along(..k, dim) * 1.0;
|
||||
self.tensors.insert(idx_name, indices);
|
||||
}
|
||||
@@ -109,6 +215,51 @@ impl<'a> Translator<'a> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn translate_sort(&mut self, node: &Node) -> Result<()> {
|
||||
let a = self.get_input_tensor(node, SORT_INPUT_ARG)?;
|
||||
let dim = if node.inputs.len() > SORT_DIM_ARG {
|
||||
self.get_int_arg(node, SORT_DIM_ARG).unwrap_or(-1)
|
||||
} else {
|
||||
-1
|
||||
};
|
||||
let descending = if node.inputs.len() > SORT_DESCENDING_ARG {
|
||||
self.get_bool_arg(node, SORT_DESCENDING_ARG)
|
||||
.unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
|
||||
// Determine output names (sort returns (values, indices))
|
||||
let values_name = node
|
||||
.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()));
|
||||
let indices_name =
|
||||
if let Some(ts) = node.outputs.first().and_then(|o| o.as_tensors.as_ref()) {
|
||||
ts.get(1).map(|t| t.name.clone())
|
||||
} else if node.outputs.len() > 1 {
|
||||
node.outputs[1].as_tensor.as_ref().map(|t| t.name.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let full_argsort = a.stable_argsort(dim, descending);
|
||||
|
||||
if let Some(val_name) = values_name
|
||||
&& !val_name.is_empty()
|
||||
{
|
||||
let values = a.gather_elements(full_argsort, dim);
|
||||
self.tensors.insert(val_name, values);
|
||||
}
|
||||
if let Some(idx_name) = indices_name {
|
||||
let indices = full_argsort * 1.0;
|
||||
self.tensors.insert(idx_name, indices);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn translate_wrap_set_grad(&mut self, node: &Node) -> Result<()> {
|
||||
let subgraph = node.inputs[1]
|
||||
.arg
|
||||
|
||||
@@ -6,7 +6,38 @@ use crate::pt2_util::{broadcast_binary, torch_dtype_int_to_luminal};
|
||||
|
||||
use super::Translator;
|
||||
|
||||
// Positions of the `argsort` arguments within `node.inputs`
// (input tensor, sort dimension, descending flag).
const ARGSORT_INPUT_ARG: usize = 0;
|
||||
const ARGSORT_DIM_ARG: usize = 1;
|
||||
const ARGSORT_DESCENDING_ARG: usize = 2;
|
||||
|
||||
// Positions of the `masked_fill` (scalar) arguments within `node.inputs`
// (input tensor, mask tensor, scalar fill value).
const MASKED_FILL_INPUT_ARG: usize = 0;
|
||||
const MASKED_FILL_MASK_ARG: usize = 1;
|
||||
const MASKED_FILL_VALUE_ARG: usize = 2;
|
||||
|
||||
// Positions of the `floor_divide` arguments within `node.inputs`
// (dividend tensor, divisor tensor-or-scalar).
const FLOOR_DIVIDE_INPUT_ARG: usize = 0;
|
||||
const FLOOR_DIVIDE_OTHER_ARG: usize = 1;
|
||||
|
||||
// Positions of the `div` (with rounding mode) arguments within `node.inputs`
// (dividend tensor, divisor tensor-or-scalar).
const DIV_MODE_INPUT_ARG: usize = 0;
|
||||
const DIV_MODE_OTHER_ARG: usize = 1;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_argsort(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, ARGSORT_INPUT_ARG)?;
|
||||
let dim = if node.inputs.len() > ARGSORT_DIM_ARG {
|
||||
self.get_int_arg(node, ARGSORT_DIM_ARG).unwrap_or(-1)
|
||||
} else {
|
||||
-1
|
||||
};
|
||||
let descending = if node.inputs.len() > ARGSORT_DESCENDING_ARG {
|
||||
self.get_bool_arg(node, ARGSORT_DESCENDING_ARG)
|
||||
.unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let dim = crate::pt2_util::normalize_dim(dim, a.shape.len());
|
||||
Ok(a.stable_argsort(dim, descending))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_unary_op(
|
||||
&mut self,
|
||||
node: &Node,
|
||||
@@ -19,11 +50,15 @@ impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_to_copy(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
for input in &node.inputs {
|
||||
if input.name == "dtype"
|
||||
&& let Some(dtype_int) = input.arg.as_int()
|
||||
{
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
|
||||
return Ok(a.cast(dtype));
|
||||
if input.name == "dtype" {
|
||||
if let Some(dtype_int) = input.arg.as_int() {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
|
||||
return Ok(a.cast(dtype));
|
||||
}
|
||||
if let Some(dtype_int) = input.arg.as_scalar_type() {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int);
|
||||
return Ok(a.cast(dtype));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(a)
|
||||
@@ -60,6 +95,155 @@ impl<'a> Translator<'a> {
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_sign(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let zero = self
|
||||
.graph
|
||||
.constant_float(0.0)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
let pos = a.gt(zero).cast(DType::Int);
|
||||
let neg = a.lt(zero).cast(DType::Int);
|
||||
let signed = pos - neg;
|
||||
Ok(if a.dtype == DType::Int {
|
||||
signed
|
||||
} else {
|
||||
signed.cast(a.dtype)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_bitwise_not(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
Ok(match a.dtype {
|
||||
DType::Bool => {
|
||||
let one = self
|
||||
.graph
|
||||
.constant_float(1.0)
|
||||
.cast(DType::Int)
|
||||
.expand_rhs(a.shape);
|
||||
(one - a.cast(DType::Int)).cast(DType::Bool)
|
||||
}
|
||||
DType::Int => (a + 1) * -1.0,
|
||||
other => {
|
||||
anyhow::bail!("bitwise_not only supports Bool/Int routing tensors, got {other:?}")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_masked_fill_scalar(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, MASKED_FILL_INPUT_ARG)?;
|
||||
let mask = self.get_input_tensor(node, MASKED_FILL_MASK_ARG)?;
|
||||
let fill = self.get_float_arg(node, MASKED_FILL_VALUE_ARG)? as f32;
|
||||
let (input, mask) = broadcast_binary(input, mask);
|
||||
let work_dtype = if input.dtype == DType::Bool {
|
||||
DType::Int
|
||||
} else {
|
||||
input.dtype
|
||||
};
|
||||
let input_work = if input.dtype == DType::Bool {
|
||||
input.cast(DType::Int)
|
||||
} else {
|
||||
input
|
||||
};
|
||||
let mask_work = mask.cast(work_dtype);
|
||||
let fill_work = self
|
||||
.graph
|
||||
.constant_float(fill)
|
||||
.cast(work_dtype)
|
||||
.expand_rhs(input_work.shape);
|
||||
let one = self
|
||||
.graph
|
||||
.constant_float(1.0)
|
||||
.cast(work_dtype)
|
||||
.expand_rhs(input_work.shape);
|
||||
let result = mask_work * fill_work + (one - mask_work) * input_work;
|
||||
Ok(if input.dtype == DType::Bool {
|
||||
result.cast(DType::Bool)
|
||||
} else {
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_floor_divide(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, FLOOR_DIVIDE_INPUT_ARG)?;
|
||||
let b = if let Some(name) = node
|
||||
.inputs
|
||||
.get(FLOOR_DIVIDE_OTHER_ARG)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
{
|
||||
self.get_tensor(name)?
|
||||
} else {
|
||||
let scalar = self.get_float_arg(node, FLOOR_DIVIDE_OTHER_ARG)? as f32;
|
||||
self.graph
|
||||
.constant_float(scalar)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape)
|
||||
};
|
||||
let (a, b) = crate::pt2_util::ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
let quotient = a.cast(DType::F32) / b.cast(DType::F32);
|
||||
let trunc = quotient.cast(DType::Int).cast(DType::F32);
|
||||
let adjust = quotient.lt(trunc).cast(DType::F32);
|
||||
let floored = trunc - adjust;
|
||||
Ok(if a.dtype == DType::Int {
|
||||
floored.cast(DType::Int)
|
||||
} else {
|
||||
floored.cast(a.dtype)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_div_tensor_mode(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, DIV_MODE_INPUT_ARG)?;
|
||||
let b = if let Some(name) = node
|
||||
.inputs
|
||||
.get(DIV_MODE_OTHER_ARG)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
{
|
||||
self.get_tensor(name)?
|
||||
} else {
|
||||
let scalar = self.get_float_arg(node, DIV_MODE_OTHER_ARG)? as f32;
|
||||
self.graph
|
||||
.constant_float(scalar)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape)
|
||||
};
|
||||
let (a, b) = crate::pt2_util::ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
|
||||
// Check rounding_mode kwarg
|
||||
let rounding_mode = node.inputs.iter().find_map(|input| {
|
||||
if input.name == "rounding_mode"
|
||||
&& let Argument::Other(val) = &input.arg
|
||||
{
|
||||
return val.as_str().map(|s| s.to_string());
|
||||
}
|
||||
None
|
||||
});
|
||||
|
||||
let quotient = a.cast(DType::F32) / b.cast(DType::F32);
|
||||
match rounding_mode.as_deref() {
|
||||
Some("floor") => {
|
||||
let trunc = quotient.cast(DType::Int).cast(DType::F32);
|
||||
let adjust = quotient.lt(trunc).cast(DType::F32);
|
||||
let floored = trunc - adjust;
|
||||
Ok(if a.dtype == DType::Int {
|
||||
floored.cast(DType::Int)
|
||||
} else {
|
||||
floored.cast(a.dtype)
|
||||
})
|
||||
}
|
||||
Some("trunc") => Ok(if a.dtype == DType::Int {
|
||||
quotient.cast(DType::Int)
|
||||
} else {
|
||||
quotient.cast(DType::Int).cast(a.dtype)
|
||||
}),
|
||||
_ => {
|
||||
// No rounding mode — regular division
|
||||
Ok(quotient.cast(a.dtype))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn translate_clamp(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let min_val = if node.inputs.len() > 1 {
|
||||
|
||||
@@ -95,7 +95,7 @@ class CompiledModel:
|
||||
self._graph.set_input_device_ptr(name, t.data_ptr(), n_bytes)
|
||||
_input_refs.append(t)
|
||||
else:
|
||||
t = tensor.detach().cpu().contiguous()
|
||||
t = tensor.detach().cpu().contiguous().to(expected_dtype)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
dtype_code = _torch_dtype_code(t.dtype)
|
||||
self._graph.set_input_from_ptr(name, t.data_ptr(), n_bytes, dtype_code)
|
||||
@@ -120,9 +120,10 @@ class CompiledModel:
|
||||
else torch.float32
|
||||
)
|
||||
out = torch.empty(shape, dtype=out_dtype, device=input_device)
|
||||
self._graph.set_output_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
if out_dtype.is_floating_point:
|
||||
self._graph.set_output_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
output_tensors.append(out)
|
||||
|
||||
# Run the graph
|
||||
@@ -130,13 +131,42 @@ class CompiledModel:
|
||||
|
||||
# Collect outputs
|
||||
if _use_zero_copy:
|
||||
# For aliased outputs that couldn't be zero-copied, fall back to DtoD copy.
|
||||
for name, out in zip(self._output_names, output_tensors):
|
||||
if not self._graph.output_is_zero_copy(name):
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
outputs = []
|
||||
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
|
||||
out_dtype = (
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
out = output_tensors[i]
|
||||
if out_dtype.is_floating_point:
|
||||
if not self._graph.output_is_zero_copy(name):
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
elif out_dtype == torch.int32:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.int32)
|
||||
.reshape(tuple(shape))
|
||||
.to(input_device)
|
||||
)
|
||||
outputs = output_tensors
|
||||
elif out_dtype == torch.bool:
|
||||
data = self._graph.get_output_bool(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.bool)
|
||||
.reshape(tuple(shape))
|
||||
.to(input_device)
|
||||
)
|
||||
else:
|
||||
data = self._graph.get_output(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.float32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
.to(input_device)
|
||||
)
|
||||
outputs.append(out)
|
||||
else:
|
||||
# Native path: retrieve as f32, then convert to target dtype if needed.
|
||||
outputs = []
|
||||
@@ -146,13 +176,20 @@ class CompiledModel:
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
data = self._graph.get_output(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.float32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
.to(input_device)
|
||||
)
|
||||
if out_dtype == torch.int32:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = torch.tensor(data, dtype=torch.int32).reshape(tuple(shape))
|
||||
elif out_dtype == torch.bool:
|
||||
data = self._graph.get_output_bool(name)
|
||||
out = torch.tensor(data, dtype=torch.bool).reshape(tuple(shape))
|
||||
else:
|
||||
data = self._graph.get_output(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.float32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
)
|
||||
out = out.to(input_device)
|
||||
outputs.append(out)
|
||||
|
||||
return tuple(outputs)
|
||||
|
||||
@@ -215,6 +215,7 @@ from test_models import (
|
||||
WhereWithConstantModel,
|
||||
# Xor model
|
||||
XorTestModel,
|
||||
ArgsortStableDuplicatesModel,
|
||||
# Conv models
|
||||
Conv1dNoPadModel,
|
||||
Conv1dSamePadModel,
|
||||
@@ -231,6 +232,7 @@ from test_models import (
|
||||
GroupedConv2dModel,
|
||||
GroupedConv2dGroups3Model,
|
||||
MambaConvBlockModel,
|
||||
TinyMoERoutingModel,
|
||||
)
|
||||
|
||||
from luminal import luminal_backend
|
||||
@@ -1948,6 +1950,54 @@ def test_split(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x))
|
||||
|
||||
|
||||
# ========== Argsort / MoE Routing Tests ==========
|
||||
|
||||
|
||||
def test_argsort_stable_duplicates(device: torch.device):
    """Compiled argsort must break ties stably (lower index first)."""
    model: torch.nn.Module = ArgsortStableDuplicatesModel().to(device)
    model_compiled: Callable = torch.compile(model, backend=luminal_backend)
    # The two 1.0 entries tie; a stable sort must keep index 1 before index 2.
    x = torch.tensor(
        [[2.0, 1.0, 1.0, 3.0]],
        dtype=torch.float32,
        device=device,
    )
    # Eager `torch.argsort` defaults to stable=False, which leaves the order
    # of equal elements unspecified. Build the reference with stable=True so
    # the expectation is deterministic on every device.
    expected: torch.Tensor = torch.argsort(
        x, dim=ArgsortStableDuplicatesModel.SORT_DIM, stable=True
    )
    output: torch.Tensor = model_compiled(x)
    assert output.dtype == torch.int32
    assert torch.equal(output, expected.to(torch.int32))
|
||||
|
||||
|
||||
def test_tiny_moe_routing(device: torch.device):
    """End-to-end check of the ops used by the tiny MoE routing model.

    Compares every output of the compiled model against eager execution:
    dtypes must match the expected list, float outputs must be close, and
    integer/bool outputs must be exactly equal.
    """
    model: torch.nn.Module = TinyMoERoutingModel().to(device)
    model_compiled: Callable = torch.compile(model, backend=luminal_backend)
    scores = torch.tensor(
        [[0.1, 0.9, 0.4, 0.7], [0.6, -0.8, 0.95, 0.2]],
        dtype=torch.float32,
        device=device,
    )

    expected = model(scores)
    output = model_compiled(scores)

    expected_dtypes = (
        torch.int32,
        torch.float32,
        torch.int32,
        torch.bool,
        torch.int32,
        torch.float32,
    )
    # `zip` stops at the shortest sequence, so check the counts first;
    # otherwise a missing output would let the test pass vacuously.
    assert len(output) == len(expected_dtypes)
    assert len(expected) == len(expected_dtypes)
    for actual, eager, expected_dtype in zip(output, expected, expected_dtypes):
        assert actual.dtype == expected_dtype
        # Eager indices are compared in the dtype produced by the backend.
        eager = eager.to(actual.dtype)
        if actual.dtype.is_floating_point:
            assert torch.allclose(actual, eager)
        else:
            assert torch.equal(actual, eager)
|
||||
|
||||
|
||||
# ========== PT2 TopK Node Tests ==========
|
||||
|
||||
|
||||
|
||||
@@ -1619,6 +1619,73 @@ class SplitTestModel(torch.nn.Module):
|
||||
return a + b
|
||||
|
||||
|
||||
# ========== Argsort / MoE Routing Test Models ==========
|
||||
|
||||
|
||||
class ArgsortStableDuplicatesModel(torch.nn.Module):
    """Single-op model: argsort along a fixed dimension.

    Used to check how duplicate values are ordered by the exported argsort.
    """

    # Dimension the input is sorted along.
    SORT_DIM = 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the indices that sort ``x`` along ``SORT_DIM``."""
        sort_indices = torch.argsort(x, dim=self.SORT_DIM)
        return sort_indices
|
||||
|
||||
|
||||
class TinyMoERoutingModel(torch.nn.Module):
    """Small deterministic MoE-style routing pipeline.

    Chains topk, argsort, gather, masked_fill, scatter, bitwise_not,
    floor_divide and sign over a ``(batch, experts)`` score tensor and
    returns six tensors describing the routing decision.
    """

    TOP_K = 2  # experts kept per row
    ROUTING_DIM = -1  # dimension holding the experts
    ZERO_FILL = 0.0  # value written into inactive slots
    DISPATCH_ON = 1  # marker scattered into the dispatch slots
    GROUP_SIZE = 2  # experts per group for ``group_ids``

    def __init__(self) -> None:
        super().__init__()
        # Fixed per-expert scale; a buffer so it follows ``.to(device)``.
        self.register_buffer(
            "expert_scale",
            torch.tensor([1.5, -0.5, 2.0, 0.25], dtype=torch.float32),
        )

    def forward(
        self, scores: torch.Tensor
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        """Return ``(routed_indices, masked_values, dispatch, inactive_mask,
        group_ids, routing_sign)`` for ``scores``."""
        dim = self.ROUTING_DIM

        # Pick the top-k experts, then reorder the picks by expert index.
        top_vals, top_idx = torch.topk(scores, self.TOP_K, dim=dim)
        order = torch.argsort(top_idx, dim=dim)
        routed_idx = torch.gather(top_idx, dim, order)
        routed_vals = torch.gather(top_vals, dim, order)

        # Scale each routed value by its expert's scale factor.
        scale_rows = self.expert_scale.unsqueeze(0).expand(scores.shape[0], -1)
        picked_scale = torch.gather(scale_rows, dim, routed_idx)
        scaled = routed_vals * picked_scale

        # Slots whose scaled value is not strictly positive are inactive.
        inactive = torch.bitwise_not(scaled > 0)
        masked = scaled.masked_fill(inactive, self.ZERO_FILL)

        # Mark every slot via scatter, then zero out the inactive ones.
        slot_marks = torch.zeros_like(routed_idx).scatter(
            dim, order, self.DISPATCH_ON
        )
        active = torch.bitwise_not(inactive).to(slot_marks.dtype)
        dispatch = slot_marks * active

        groups = torch.floor_divide(routed_idx, self.GROUP_SIZE)
        signs = torch.sign(masked)
        return (
            routed_idx,
            masked,
            dispatch,
            inactive,
            groups,
            signs,
        )
|
||||
|
||||
|
||||
# ========== TopK Node Test Models ==========
|
||||
|
||||
|
||||
@@ -1840,9 +1907,14 @@ class LlamaTransformerBlockModel(torch.nn.Module):
|
||||
class Conv1dNoPadModel(torch.nn.Module):
    """Conv1d with no padding: output length shrinks by (kernel-1)."""

    # Convolution hyperparameters.
    KERNEL_SIZE = 3
    PADDING = 0

    def __init__(self) -> None:
        super().__init__()
        # Single construction: the stale literal-argument layer that used to
        # be built and immediately overwritten here has been dropped.
        self.conv = torch.nn.Conv1d(
            8, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution to ``x``."""
        return self.conv(x)
|
||||
@@ -1851,9 +1923,14 @@ class Conv1dNoPadModel(torch.nn.Module):
|
||||
class Conv1dSamePadModel(torch.nn.Module):
    """Conv1d with same-size padding (output length == input length)."""

    # Convolution hyperparameters.
    KERNEL_SIZE = 3
    PADDING = 1

    def __init__(self) -> None:
        super().__init__()
        # Single construction: the stale literal-argument layer that used to
        # be built and immediately overwritten here has been dropped.
        self.conv = torch.nn.Conv1d(
            8, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution to ``x``."""
        return self.conv(x)
|
||||
@@ -1862,9 +1939,14 @@ class Conv1dSamePadModel(torch.nn.Module):
|
||||
class Conv1dBiasModel(torch.nn.Module):
    """Conv1d with bias."""

    # Convolution hyperparameters.
    KERNEL_SIZE = 3
    PADDING = 1

    def __init__(self) -> None:
        super().__init__()
        # Single construction: the stale literal-argument layer that used to
        # be built and immediately overwritten here has been dropped.
        self.conv = torch.nn.Conv1d(
            8, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=True
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution to ``x``."""
        return self.conv(x)
|
||||
@@ -1873,9 +1955,14 @@ class Conv1dBiasModel(torch.nn.Module):
|
||||
class Conv2dNoPadModel(torch.nn.Module):
    """Conv2d with no padding: output spatial dims shrink by (kernel-1)."""

    # Convolution hyperparameters.
    KERNEL_SIZE = 3
    PADDING = 0

    def __init__(self) -> None:
        super().__init__()
        # Single construction: the stale literal-argument layer that used to
        # be built and immediately overwritten here has been dropped.
        self.conv = torch.nn.Conv2d(
            3, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution to ``x``."""
        return self.conv(x)
|
||||
@@ -1884,9 +1971,14 @@ class Conv2dNoPadModel(torch.nn.Module):
|
||||
class Conv2dSamePadModel(torch.nn.Module):
    """Conv2d with same-size padding."""

    # Convolution hyperparameters.
    KERNEL_SIZE = 3
    PADDING = 1

    def __init__(self) -> None:
        super().__init__()
        # Single construction: the stale literal-argument layer that used to
        # be built and immediately overwritten here has been dropped.
        self.conv = torch.nn.Conv2d(
            3, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution to ``x``."""
        return self.conv(x)
|
||||
@@ -1895,9 +1987,14 @@ class Conv2dSamePadModel(torch.nn.Module):
|
||||
class Conv2dBiasModel(torch.nn.Module):
    """Conv2d with bias."""

    # Convolution hyperparameters.
    KERNEL_SIZE = 3
    PADDING = 1

    def __init__(self) -> None:
        super().__init__()
        # Single construction: the stale literal-argument layer that used to
        # be built and immediately overwritten here has been dropped.
        self.conv = torch.nn.Conv2d(
            3, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=True
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution to ``x``."""
        return self.conv(x)
|
||||
@@ -1906,10 +2003,19 @@ class Conv2dBiasModel(torch.nn.Module):
|
||||
class Conv2dStrideModel(torch.nn.Module):
|
||||
"""Conv2d with stride=2 (output dims halved)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
STRIDE = 2
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
3, 16, kernel_size=3, stride=2, padding=1, bias=False
|
||||
3,
|
||||
16,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
stride=self.STRIDE,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -1919,10 +2025,19 @@ class Conv2dStrideModel(torch.nn.Module):
|
||||
class Conv2dDilationModel(torch.nn.Module):
|
||||
"""Conv2d with dilation=2 and padding chosen to preserve spatial size."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
DILATION = 2
|
||||
PADDING = 2
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
8, 16, kernel_size=3, dilation=2, padding=2, bias=False
|
||||
8,
|
||||
16,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
dilation=self.DILATION,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -1932,9 +2047,14 @@ class Conv2dDilationModel(torch.nn.Module):
|
||||
class Conv3dSamePadModel(torch.nn.Module):
    """Conv3d with padding=1 to preserve spatial dimensions."""

    # Convolution hyperparameters.
    KERNEL_SIZE = 3
    PADDING = 1

    def __init__(self) -> None:
        super().__init__()
        # Single construction: the stale literal-argument layer that used to
        # be built and immediately overwritten here has been dropped.
        self.conv = torch.nn.Conv3d(
            4, 8, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution to ``x``."""
        return self.conv(x)
|
||||
@@ -1943,10 +2063,19 @@ class Conv3dSamePadModel(torch.nn.Module):
|
||||
class DepthwiseConv1dModel(torch.nn.Module):
|
||||
"""Depthwise Conv1d as used in Mamba (groups == in_channels)."""
|
||||
|
||||
KERNEL_SIZE = 4
|
||||
GROUPS = 16
|
||||
PADDING = 3
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv1d(
|
||||
16, 16, kernel_size=4, groups=16, padding=3, bias=True
|
||||
16,
|
||||
16,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
groups=self.GROUPS,
|
||||
padding=self.PADDING,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -1957,10 +2086,19 @@ class DepthwiseConv1dModel(torch.nn.Module):
|
||||
class DepthwiseConv2dModel(torch.nn.Module):
|
||||
"""Depthwise Conv2d (groups == in_channels)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
GROUPS = 8
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
8, 8, kernel_size=3, groups=8, padding=1, bias=False
|
||||
8,
|
||||
8,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
groups=self.GROUPS,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -1970,10 +2108,19 @@ class DepthwiseConv2dModel(torch.nn.Module):
|
||||
class DepthwiseMultiplierConv2dModel(torch.nn.Module):
|
||||
"""Depthwise Conv2d with channel multiplier 2 (out_channels = 2 * in_channels)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
GROUPS = 8
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
8, 16, kernel_size=3, groups=8, padding=1, bias=False
|
||||
8,
|
||||
16,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
groups=self.GROUPS,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -1983,10 +2130,19 @@ class DepthwiseMultiplierConv2dModel(torch.nn.Module):
|
||||
class GroupedConv2dModel(torch.nn.Module):
|
||||
"""Conv2d with groups=4 (not depthwise, but grouped)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
GROUPS = 4
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
16, 32, kernel_size=3, groups=4, padding=1, bias=False
|
||||
16,
|
||||
32,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
groups=self.GROUPS,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -1996,10 +2152,19 @@ class GroupedConv2dModel(torch.nn.Module):
|
||||
class GroupedConv2dGroups3Model(torch.nn.Module):
|
||||
"""Conv2d with groups=3 and ch_per_group=4."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
GROUPS = 3
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
12, 12, kernel_size=3, groups=3, padding=1, bias=False
|
||||
12,
|
||||
12,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
groups=self.GROUPS,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -2015,9 +2180,16 @@ class MambaConvBlockModel(torch.nn.Module):
|
||||
def __init__(self, d_model: int = 16, d_conv: int = 4, expand: int = 2) -> None:
|
||||
super().__init__()
|
||||
d_inner = d_model * expand
|
||||
groups = d_inner
|
||||
padding = d_conv - 1
|
||||
self.in_proj = torch.nn.Linear(d_model, d_inner * 2, bias=False)
|
||||
self.conv1d = torch.nn.Conv1d(
|
||||
d_inner, d_inner, d_conv, groups=d_inner, padding=d_conv - 1, bias=True
|
||||
d_inner,
|
||||
d_inner,
|
||||
d_conv,
|
||||
groups=groups,
|
||||
padding=padding,
|
||||
bias=True,
|
||||
)
|
||||
self.out_proj = torch.nn.Linear(d_inner, d_model, bias=False)
|
||||
|
||||
|
||||
22
examples/gemma4_moe/Cargo.toml
Normal file
22
examples/gemma4_moe/Cargo.toml
Normal file
@@ -0,0 +1,22 @@
|
||||
[package]
|
||||
name = "gemma4_moe"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
tokenizers = "0.22.2"
|
||||
rustc-hash = "2"
|
||||
|
||||
# HuggingFace model download
|
||||
hf-hub = { version = "0.4", default-features = false, features = ["rustls-tls", "ureq"] }
|
||||
safetensors = "0.7.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
half = { version = "2.7.1", features = ["bytemuck"] }
|
||||
bytemuck = "1.24.0"
|
||||
memmap2 = "0.9.9"
|
||||
227
examples/gemma4_moe/src/hf.rs
Normal file
227
examples/gemma4_moe/src/hf.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
use half::{bf16, f16};
|
||||
use hf_hub::api::sync::Api;
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::{tensor::TensorView, Dtype, SafeTensors};
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fs::File,
|
||||
io::Write,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use crate::model::HIDDEN;
|
||||
|
||||
/// Deserialized form of `model.safetensors.index.json`; only the
/// `weight_map` field is read.
#[derive(Deserialize)]
|
||||
struct SafetensorsIndex {
|
||||
// Values are shard file names (they are fetched / joined onto the model
// directory); keys are presumably tensor names — TODO confirm.
weight_map: HashMap<String, String>,
|
||||
}
|
||||
|
||||
/// Payload of a stored tensor: either decoded `f32` values or raw bytes
/// (presumably bf16-encoded, as produced by `tensor_to_bf16_bytes` — TODO confirm).
enum TensorData {
|
||||
F32(Vec<f32>),
|
||||
BF16(Vec<u8>),
|
||||
}
|
||||
|
||||
/// A tensor held in memory: its shape together with its payload.
struct StoredTensor {
|
||||
// Dimension sizes of the tensor.
shape: Vec<usize>,
|
||||
// Element data; see `TensorData` for the supported encodings.
data: TensorData,
|
||||
}
|
||||
|
||||
pub fn download_hf_model(repo_id: &str) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
let api = Api::new()?;
|
||||
let repo = api.model(repo_id.to_string());
|
||||
|
||||
let tokenizer_path = repo.get("tokenizer.json")?;
|
||||
let model_dir = tokenizer_path.parent().unwrap().to_path_buf();
|
||||
|
||||
if repo.get("model.safetensors").is_ok() {
|
||||
return Ok(model_dir);
|
||||
}
|
||||
|
||||
let index_path = repo.get("model.safetensors.index.json")?;
|
||||
let index_content = std::fs::read_to_string(&index_path)?;
|
||||
let index: SafetensorsIndex = serde_json::from_str(&index_content)?;
|
||||
|
||||
let mut shard_files: Vec<String> = index.weight_map.values().cloned().collect();
|
||||
shard_files.sort();
|
||||
shard_files.dedup();
|
||||
|
||||
for shard_file in &shard_files {
|
||||
repo.get(shard_file)?;
|
||||
}
|
||||
|
||||
Ok(model_dir)
|
||||
}
|
||||
|
||||
fn tensor_to_f32(tensor: &safetensors::tensor::TensorView) -> Vec<f32> {
|
||||
match tensor.dtype() {
|
||||
Dtype::F32 => bytemuck::cast_slice::<u8, f32>(tensor.data()).to_vec(),
|
||||
Dtype::F16 => {
|
||||
let f16_slice: &[f16] = bytemuck::cast_slice(tensor.data());
|
||||
f16_slice.iter().map(|x| x.to_f32()).collect()
|
||||
}
|
||||
Dtype::BF16 => {
|
||||
let bf16_slice: &[bf16] = bytemuck::cast_slice(tensor.data());
|
||||
bf16_slice.iter().map(|x| x.to_f32()).collect()
|
||||
}
|
||||
other => panic!("Unsupported dtype for conversion: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn tensor_to_bf16_bytes(tensor: &safetensors::tensor::TensorView) -> Vec<u8> {
|
||||
match tensor.dtype() {
|
||||
Dtype::BF16 => tensor.data().to_vec(),
|
||||
Dtype::F16 => {
|
||||
let f16_slice: &[f16] = bytemuck::cast_slice(tensor.data());
|
||||
let bf16_data: Vec<bf16> = f16_slice
|
||||
.iter()
|
||||
.map(|x| bf16::from_f32(x.to_f32()))
|
||||
.collect();
|
||||
bytemuck::cast_slice(&bf16_data).to_vec()
|
||||
}
|
||||
Dtype::F32 => {
|
||||
let f32_slice: &[f32] = bytemuck::cast_slice(tensor.data());
|
||||
let bf16_data: Vec<bf16> = f32_slice.iter().map(|x| bf16::from_f32(*x)).collect();
|
||||
bytemuck::cast_slice(&bf16_data).to_vec()
|
||||
}
|
||||
other => panic!("Unsupported dtype for conversion: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Reports whether `name` lives under the `model.language_model.` prefix.
fn is_text_weight(name: &str) -> bool {
    name.strip_prefix("model.language_model.").is_some()
}
|
||||
|
||||
/// Reports whether `name` contains an `.experts.` path segment.
fn is_expert_weight(name: &str) -> bool {
    name.find(".experts.").is_some()
}
|
||||
|
||||
pub fn combine_safetensors(model_dir: &Path) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
let output_path = model_dir.join("model_combined.safetensors");
|
||||
if output_path.exists() {
|
||||
return Ok(output_path);
|
||||
}
|
||||
|
||||
let index_path = model_dir.join("model.safetensors.index.json");
|
||||
let single_shard_path = model_dir.join("model.safetensors");
|
||||
|
||||
let shard_files: Vec<PathBuf> = if single_shard_path.exists() && !index_path.exists() {
|
||||
println!("Single shard model detected...");
|
||||
vec![single_shard_path]
|
||||
} else if index_path.exists() {
|
||||
let index_content = std::fs::read_to_string(&index_path)?;
|
||||
let index: SafetensorsIndex = serde_json::from_str(&index_content)?;
|
||||
|
||||
let mut files: Vec<String> = index.weight_map.values().cloned().collect();
|
||||
files.sort();
|
||||
files.dedup();
|
||||
|
||||
println!("Loading {} shard files...", files.len());
|
||||
files.into_iter().map(|f| model_dir.join(f)).collect()
|
||||
} else {
|
||||
return Err("No model.safetensors or model.safetensors.index.json found".into());
|
||||
};
|
||||
|
||||
let mut all_tensors: HashMap<String, StoredTensor> = HashMap::new();
|
||||
|
||||
for shard_path in &shard_files {
|
||||
println!(
|
||||
" Loading {}...",
|
||||
shard_path.file_name().unwrap().to_string_lossy()
|
||||
);
|
||||
let file = File::open(shard_path)?;
|
||||
let mmap = unsafe { MmapOptions::new().map(&file)? };
|
||||
let st = SafeTensors::deserialize(&mmap)?;
|
||||
|
||||
for name in st.names() {
|
||||
if !is_text_weight(name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let new_name = name.replacen("model.language_model.", "model.", 1);
|
||||
let tensor = st.tensor(name)?;
|
||||
|
||||
if new_name.ends_with(".layer_scalar") {
|
||||
let scalar = tensor_to_f32(&tensor);
|
||||
let scalar = *scalar.first().expect("layer_scalar tensor is empty");
|
||||
all_tensors.insert(
|
||||
new_name,
|
||||
StoredTensor {
|
||||
shape: vec![HIDDEN],
|
||||
data: TensorData::F32(vec![scalar; HIDDEN]),
|
||||
},
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let shape = tensor.shape().to_vec();
|
||||
let data = if is_expert_weight(&new_name) {
|
||||
TensorData::BF16(tensor_to_bf16_bytes(&tensor))
|
||||
} else {
|
||||
TensorData::F32(tensor_to_f32(&tensor))
|
||||
};
|
||||
|
||||
all_tensors.insert(new_name, StoredTensor { shape, data });
|
||||
}
|
||||
}
|
||||
|
||||
println!("Extracted {} text tensors", all_tensors.len());
|
||||
|
||||
let embed_key = "model.embed_tokens.weight";
|
||||
if let Some(embed_tensor) = all_tensors.get(embed_key) {
|
||||
let (shape, embed_data) = match &embed_tensor.data {
|
||||
TensorData::F32(data) => (embed_tensor.shape.clone(), data.clone()),
|
||||
TensorData::BF16(_) => unreachable!("Embedding weights should stay in F32"),
|
||||
};
|
||||
|
||||
all_tensors.insert(
|
||||
"lm_head.weight".to_string(),
|
||||
StoredTensor {
|
||||
shape,
|
||||
data: TensorData::F32(embed_data.clone()),
|
||||
},
|
||||
);
|
||||
|
||||
let embed_scale = (HIDDEN as f32).sqrt();
|
||||
if let Some(stored) = all_tensors.get_mut(embed_key) {
|
||||
match &mut stored.data {
|
||||
TensorData::F32(data) => {
|
||||
for value in data {
|
||||
*value *= embed_scale;
|
||||
}
|
||||
}
|
||||
TensorData::BF16(_) => unreachable!("Embedding weights should stay in F32"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("Saving combined model (BF16 experts + F32 rest)...");
|
||||
let tensor_views: HashMap<String, TensorView<'_>> = all_tensors
|
||||
.iter()
|
||||
.map(|(name, stored)| {
|
||||
let view = match &stored.data {
|
||||
TensorData::F32(data) => {
|
||||
let bytes: &[u8] = bytemuck::cast_slice(data);
|
||||
TensorView::new(Dtype::F32, stored.shape.clone(), bytes).unwrap()
|
||||
}
|
||||
TensorData::BF16(bytes) => {
|
||||
TensorView::new(Dtype::BF16, stored.shape.clone(), bytes).unwrap()
|
||||
}
|
||||
};
|
||||
(name.clone(), view)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let serialized = safetensors::serialize(&tensor_views, None)?;
|
||||
let mut file = File::create(&output_path)?;
|
||||
file.write_all(&serialized)?;
|
||||
|
||||
println!("Combined model saved successfully!");
|
||||
Ok(output_path)
|
||||
}
|
||||
|
||||
pub fn prepare_hf_model(repo_id: &str) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
let model_dir = download_hf_model(repo_id)?;
|
||||
combine_safetensors(&model_dir)?;
|
||||
Ok(model_dir)
|
||||
}
|
||||
190
examples/gemma4_moe/src/main.rs
Normal file
190
examples/gemma4_moe/src/main.rs
Normal file
@@ -0,0 +1,190 @@
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
use hf::prepare_hf_model;
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
/// Hugging Face repository id passed to `prepare_hf_model`.
const REPO_ID: &str = "google/gemma-4-26B-A4B";
|
||||
|
||||
/// Reads the environment variable `name` as a `usize`.
///
/// Returns `default` when the variable is unset, is not valid Unicode, or
/// does not parse as a `usize`.
fn env_usize(name: &str, default: usize) -> usize {
    match std::env::var(name) {
        Ok(raw) => raw.parse().unwrap_or(default),
        Err(_) => default,
    }
}
|
||||
|
||||
/// Reads the environment variable `name` as a boolean flag.
///
/// Returns `true` when the value, after trimming surrounding whitespace, is
/// `1`, `true`, or `yes` in any ASCII letter case. Returns `false` for every
/// other value and when the variable is unset or not valid Unicode.
///
/// Every spelling accepted before (`1`, `true`, `TRUE`, `yes`, `YES`) is still
/// accepted; mixed-case forms such as `True` and `Yes` now work as well.
fn env_bool(name: &str) -> bool {
    std::env::var(name).ok().is_some_and(|raw| {
        let value = raw.trim();
        ["1", "true", "yes"]
            .iter()
            .any(|truthy| value.eq_ignore_ascii_case(truthy))
    })
}
|
||||
|
||||
fn main() {
|
||||
let max_seq_len = env_usize("MAX_SEQ_LEN", 4096);
|
||||
let gen_tokens = env_usize("GEN_TOKENS", 30);
|
||||
let search_graphs = env_usize("SEARCH_GRAPHS", 50);
|
||||
let prompt = std::env::var("PROMPT").unwrap_or_else(|_| "The capital of France is".to_string());
|
||||
let print_token_ids = env_bool("PRINT_TOKEN_IDS");
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
let model_dir = prepare_hf_model(REPO_ID).expect("Failed to prepare model");
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
let pos_ids = cx.named_tensor("pos_ids", 's').as_dtype(DType::Int);
|
||||
let kv_cache = KVCache::new(&mut cx, max_seq_len);
|
||||
let (logits, cache_outputs) = Gemma4MoE::init(&mut cx).forward(input, pos_ids, &kv_cache);
|
||||
let logits = logits.output();
|
||||
for (k_out, v_out) in &cache_outputs {
|
||||
k_out.output();
|
||||
v_out.output();
|
||||
}
|
||||
|
||||
println!("Building E-Graph...");
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
println!("Loading weights...");
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
let weights_path = model_dir.join("model_combined.safetensors");
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', 1);
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(pos_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
|
||||
let mut prev_seq = 0usize;
|
||||
let mut fwd_durations = vec![];
|
||||
let mut seen_tokens = FxHashSet::default();
|
||||
let mut generated_token_ids = vec![];
|
||||
let repetition_penalty: f32 = 1.05;
|
||||
|
||||
const EOS_TOKEN: u32 = 1;
|
||||
|
||||
let prefill_start = std::time::Instant::now();
|
||||
for &token in &prompt_tokens {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![token as i32]);
|
||||
runtime.set_data(pos_ids, vec![prev_seq as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
|
||||
prev_seq += 1;
|
||||
}
|
||||
let prefill_duration = prefill_start.elapsed();
|
||||
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let last_row = &logits_data[..VOCAB_SIZE];
|
||||
let mut next_token = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
generated_token_ids.push(next_token);
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
for _ in 1..gen_tokens {
|
||||
let start = std::time::Instant::now();
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![next_token as i32]);
|
||||
runtime.set_data(pos_ids, vec![prev_seq as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
|
||||
prev_seq += 1;
|
||||
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let mut last_row = logits_data[..VOCAB_SIZE].to_vec();
|
||||
for &tok in &seen_tokens {
|
||||
let logit = &mut last_row[tok as usize];
|
||||
if *logit > 0.0 {
|
||||
*logit /= repetition_penalty;
|
||||
} else {
|
||||
*logit *= repetition_penalty;
|
||||
}
|
||||
}
|
||||
next_token = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
generated_token_ids.push(next_token);
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
if next_token == EOS_TOKEN {
|
||||
break;
|
||||
}
|
||||
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
fwd_durations.push(start.elapsed());
|
||||
}
|
||||
println!();
|
||||
if print_token_ids {
|
||||
println!("Generated token ids: {generated_token_ids:?}");
|
||||
}
|
||||
|
||||
println!(
|
||||
" TTFT: {:.2} ms ({} prompt tokens)",
|
||||
prefill_duration.as_secs_f64() * 1e3,
|
||||
prompt_tokens.len()
|
||||
);
|
||||
if fwd_durations.len() > 1 {
|
||||
println!(
|
||||
" TPOT: {:.2} ms",
|
||||
(fwd_durations.iter().skip(1).sum::<Duration>() / (fwd_durations.len() - 1) as u32)
|
||||
.as_secs_f64()
|
||||
* 1_000.
|
||||
);
|
||||
}
|
||||
}
|
||||
621
examples/gemma4_moe/src/model.rs
Normal file
621
examples/gemma4_moe/src/model.rs
Normal file
@@ -0,0 +1,621 @@
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
graph::Graph,
|
||||
prelude::{F32Pow, GraphTensor},
|
||||
shape::Expression,
|
||||
};
|
||||
use luminal_nn::LayerNorm;
|
||||
|
||||
/// Number of decoder layers built by `Gemma4MoE::init`.
pub const LAYERS: usize = 30;
/// Width of the hidden state (second dimension of the embedding table).
pub const HIDDEN: usize = 2816;
/// Inner width of the dense feed-forward projections.
pub const INTERMEDIATE: usize = 2112;
/// Inner width of each expert's feed-forward projections.
pub const MOE_INTERMEDIATE: usize = 704;
/// Number of experts per MoE block.
pub const NUM_EXPERTS: usize = 128;
/// Number of experts selected per token by the router.
pub const TOP_K: usize = 8;
/// Number of query heads in every layer.
pub const N_HEADS: usize = 16;
/// Per-head width in sliding-window layers.
pub const SLIDING_HEAD_DIM: usize = 256;
/// Per-head width in full-attention layers.
pub const FULL_HEAD_DIM: usize = 512;
/// Number of key/value heads in sliding-window layers.
pub const SLIDING_KV_HEADS: usize = 8;
/// Number of key/value heads in full-attention layers.
pub const FULL_KV_HEADS: usize = 2;
/// Number of rows in the embedding table and output head.
pub const VOCAB_SIZE: usize = 262144;
/// Epsilon passed to every normalization in this file.
pub const RMS_NORM_EPS: f32 = 1e-6;
/// Number of most recent positions a sliding-window layer may attend to.
pub const SLIDING_WINDOW_SIZE: usize = 1024;
/// Rotary base for sliding-window layers.
pub const SLIDING_ROPE_THETA: f32 = 10_000.0;
/// Rotary base for full-attention layers.
pub const FULL_ROPE_THETA: f32 = 1_000_000.0;
/// Fraction of each head that is rotated in full-attention layers.
pub const FULL_PARTIAL_ROTARY_FACTOR: f32 = 0.25;
/// Bound applied to the output logits via `tanh(x / cap) * cap`.
pub const FINAL_LOGIT_SOFTCAP: f32 = 30.0;
|
||||
|
||||
/// Per-layer attention geometry, produced by `layer_spec`.
#[derive(Clone, Copy)]
struct LayerSpec {
    /// Whether the layer applies the sliding-window mask in addition to the
    /// causal mask.
    is_sliding: bool,
    /// Width of one attention head.
    head_dim: usize,
    /// Total query width (`N_HEADS * head_dim`).
    q_dim: usize,
    /// Number of key/value heads.
    num_kv_heads: usize,
    /// Total key/value width (`num_kv_heads * head_dim`).
    kv_dim: usize,
    /// Query heads per key/value head (`N_HEADS / num_kv_heads`).
    kv_groups: usize,
    /// Base used for the rotary frequencies.
    rope_theta: f32,
    /// Fraction of each head that receives a non-zero rotary frequency.
    partial_rotary_factor: f32,
    /// Whether the layer has its own value projection; when `false` the key
    /// projection output is reused as the value input.
    has_v_proj: bool,
}
|
||||
|
||||
fn layer_spec(layer: usize) -> LayerSpec {
|
||||
if !(layer + 1).is_multiple_of(6) {
|
||||
LayerSpec {
|
||||
is_sliding: true,
|
||||
head_dim: SLIDING_HEAD_DIM,
|
||||
q_dim: N_HEADS * SLIDING_HEAD_DIM,
|
||||
num_kv_heads: SLIDING_KV_HEADS,
|
||||
kv_dim: SLIDING_KV_HEADS * SLIDING_HEAD_DIM,
|
||||
kv_groups: N_HEADS / SLIDING_KV_HEADS,
|
||||
rope_theta: SLIDING_ROPE_THETA,
|
||||
partial_rotary_factor: 1.0,
|
||||
has_v_proj: true,
|
||||
}
|
||||
} else {
|
||||
LayerSpec {
|
||||
is_sliding: false,
|
||||
head_dim: FULL_HEAD_DIM,
|
||||
q_dim: N_HEADS * FULL_HEAD_DIM,
|
||||
num_kv_heads: FULL_KV_HEADS,
|
||||
kv_dim: FULL_KV_HEADS * FULL_HEAD_DIM,
|
||||
kv_groups: N_HEADS / FULL_KV_HEADS,
|
||||
rope_theta: FULL_ROPE_THETA,
|
||||
partial_rotary_factor: FULL_PARTIAL_ROTARY_FACTOR,
|
||||
has_v_proj: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cache_bytes_for_layer(layer: usize, max_seq: usize) -> usize {
|
||||
let spec = layer_spec(layer);
|
||||
spec.num_kv_heads * max_seq * spec.head_dim * std::mem::size_of::<f32>()
|
||||
}
|
||||
|
||||
pub struct KVCache {
|
||||
pub k_caches: Vec<GraphTensor>,
|
||||
pub v_caches: Vec<GraphTensor>,
|
||||
pub max_seq: usize,
|
||||
}
|
||||
|
||||
impl KVCache {
|
||||
pub fn new(cx: &mut Graph, max_seq: usize) -> Self {
|
||||
let mut k_caches = Vec::with_capacity(LAYERS);
|
||||
let mut v_caches = Vec::with_capacity(LAYERS);
|
||||
for layer in 0..LAYERS {
|
||||
let spec = layer_spec(layer);
|
||||
let k = cx
|
||||
.named_tensor(
|
||||
format!("kv_cache.{layer}.k"),
|
||||
(spec.num_kv_heads, max_seq, spec.head_dim),
|
||||
)
|
||||
.persist();
|
||||
let v = cx
|
||||
.named_tensor(
|
||||
format!("kv_cache.{layer}.v"),
|
||||
(spec.num_kv_heads, max_seq, spec.head_dim),
|
||||
)
|
||||
.persist();
|
||||
k_caches.push(k);
|
||||
v_caches.push(v);
|
||||
}
|
||||
Self {
|
||||
k_caches,
|
||||
v_caches,
|
||||
max_seq,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Gemma4MoE {
|
||||
embedding: GraphTensor,
|
||||
lm_head: GraphTensor,
|
||||
layers: Vec<Gemma4Layer>,
|
||||
lm_norm: LayerNorm,
|
||||
}
|
||||
|
||||
impl Gemma4MoE {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
let mut layers = Vec::with_capacity(LAYERS);
|
||||
for layer in 0..LAYERS {
|
||||
let spec = layer_spec(layer);
|
||||
let gate = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.mlp.gate_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let up = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.mlp.up_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let down = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.mlp.down_proj.weight"),
|
||||
(HIDDEN, INTERMEDIATE),
|
||||
)
|
||||
.persist();
|
||||
|
||||
let q_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.q_proj.weight"),
|
||||
(spec.q_dim, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let k_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.k_proj.weight"),
|
||||
(spec.kv_dim, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let v_proj = spec.has_v_proj.then(|| {
|
||||
cx.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.v_proj.weight"),
|
||||
(spec.kv_dim, HIDDEN),
|
||||
)
|
||||
.persist()
|
||||
});
|
||||
let o_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, spec.q_dim),
|
||||
)
|
||||
.persist();
|
||||
let q_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.q_norm.weight"),
|
||||
spec.head_dim,
|
||||
)
|
||||
.persist();
|
||||
let k_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.k_norm.weight"),
|
||||
spec.head_dim,
|
||||
)
|
||||
.persist();
|
||||
let layer_scalar = cx
|
||||
.named_tensor(format!("model.layers.{layer}.layer_scalar"), HIDDEN)
|
||||
.persist();
|
||||
|
||||
let router_scale = cx
|
||||
.named_tensor(format!("model.layers.{layer}.router.scale"), HIDDEN)
|
||||
.persist();
|
||||
let router_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.router.proj.weight"),
|
||||
(NUM_EXPERTS, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let per_expert_scale = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.router.per_expert_scale"),
|
||||
NUM_EXPERTS,
|
||||
)
|
||||
.persist();
|
||||
let gate_up_weights = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.experts.gate_up_proj"),
|
||||
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
|
||||
)
|
||||
.persist()
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.experts.down_proj"),
|
||||
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
|
||||
)
|
||||
.persist()
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
layers.push(Gemma4Layer {
|
||||
spec,
|
||||
gate,
|
||||
up,
|
||||
down,
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
q_norm,
|
||||
k_norm,
|
||||
layer_scalar,
|
||||
input_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.input_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_attention_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_attention_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
pre_feedforward_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.pre_feedforward_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_feedforward_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm_1: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_feedforward_layernorm_1.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm_2: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_feedforward_layernorm_2.weight"),
|
||||
cx,
|
||||
),
|
||||
pre_feedforward_layernorm_2: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.pre_feedforward_layernorm_2.weight"),
|
||||
cx,
|
||||
),
|
||||
moe: Gemma4SparseMoE {
|
||||
router_scale,
|
||||
router_proj,
|
||||
per_expert_scale,
|
||||
gate_up_weights,
|
||||
down_weights,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
let embedding = cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
let lm_head = cx
|
||||
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
let lm_norm = gemma4_norm(HIDDEN, "model.norm.weight", cx);
|
||||
|
||||
Self {
|
||||
embedding,
|
||||
lm_head,
|
||||
layers,
|
||||
lm_norm,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
token_ids: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = token_ids.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
|
||||
let mut cache_outputs = Vec::with_capacity(LAYERS);
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
x,
|
||||
pos_ids,
|
||||
kv_cache.k_caches[layer_idx],
|
||||
kv_cache.v_caches[layer_idx],
|
||||
kv_cache.max_seq,
|
||||
);
|
||||
x = x_new.graph_break();
|
||||
cache_outputs.push((k_out, v_out));
|
||||
}
|
||||
|
||||
let logits = self.lm_norm.forward(x).matmul(self.lm_head.t());
|
||||
let logits = (logits / FINAL_LOGIT_SOFTCAP).tanh() * FINAL_LOGIT_SOFTCAP;
|
||||
(logits, cache_outputs)
|
||||
}
|
||||
}
|
||||
|
||||
struct Gemma4Layer {
|
||||
spec: LayerSpec,
|
||||
gate: GraphTensor,
|
||||
up: GraphTensor,
|
||||
down: GraphTensor,
|
||||
q_proj: GraphTensor,
|
||||
k_proj: GraphTensor,
|
||||
v_proj: Option<GraphTensor>,
|
||||
o_proj: GraphTensor,
|
||||
q_norm: GraphTensor,
|
||||
k_norm: GraphTensor,
|
||||
layer_scalar: GraphTensor,
|
||||
input_layernorm: LayerNorm,
|
||||
post_attention_layernorm: LayerNorm,
|
||||
pre_feedforward_layernorm: LayerNorm,
|
||||
post_feedforward_layernorm: LayerNorm,
|
||||
post_feedforward_layernorm_1: LayerNorm,
|
||||
post_feedforward_layernorm_2: LayerNorm,
|
||||
pre_feedforward_layernorm_2: LayerNorm,
|
||||
moe: Gemma4SparseMoE,
|
||||
}
|
||||
|
||||
struct Gemma4SparseMoE {
|
||||
router_scale: GraphTensor,
|
||||
router_proj: GraphTensor,
|
||||
per_expert_scale: GraphTensor,
|
||||
gate_up_weights: GraphTensor,
|
||||
down_weights: GraphTensor,
|
||||
}
|
||||
|
||||
fn gemma4_norm(dim: usize, weight_name: &str, cx: &mut Graph) -> LayerNorm {
|
||||
LayerNorm::new(dim, Some(weight_name), None, false, RMS_NORM_EPS, cx)
|
||||
}
|
||||
|
||||
#[allow(clippy::excessive_precision)]
|
||||
fn gemma_gelu(x: GraphTensor) -> GraphTensor {
|
||||
let scaled = 1.5957691216 * x * (1. + 0.044715 * x * x);
|
||||
x * scaled.sigmoid()
|
||||
}
|
||||
|
||||
fn qk_norm(x: GraphTensor, weight: GraphTensor, n_heads: usize, head_dim: usize) -> GraphTensor {
|
||||
let seq = x.dims()[0];
|
||||
let reshaped = x.split_dims(1, head_dim);
|
||||
let normed = reshaped.std_norm(2, RMS_NORM_EPS);
|
||||
let w = weight.expand_dim(0, n_heads).expand_dim(0, seq);
|
||||
(normed * w).merge_dims(1, 2)
|
||||
}
|
||||
|
||||
fn value_norm(x: GraphTensor, head_dim: usize) -> GraphTensor {
|
||||
x.split_dims(1, head_dim)
|
||||
.std_norm(2, RMS_NORM_EPS)
|
||||
.merge_dims(1, 2)
|
||||
}
|
||||
|
||||
fn gemma4_rotary_embeddings(
|
||||
input: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
n_heads: usize,
|
||||
head_dim: usize,
|
||||
rope_theta: f32,
|
||||
partial_rotary_factor: f32,
|
||||
) -> GraphTensor {
|
||||
let input = input.split_dims(1, head_dim).transpose(0, 1);
|
||||
let half_dim = head_dim / 2;
|
||||
let rope_angles = ((partial_rotary_factor * head_dim as f32) / 2.0).floor() as usize;
|
||||
|
||||
let rotated = input
|
||||
.graph()
|
||||
.arange_options(0, rope_angles * 2, 2)
|
||||
.cast(DType::F32)
|
||||
/ head_dim as f32;
|
||||
let rotated = rope_theta.pow(rotated).reciprocal();
|
||||
let inv_freqs = if rope_angles < half_dim {
|
||||
let zeros = input
|
||||
.graph()
|
||||
.arange(half_dim - rope_angles)
|
||||
.cast(DType::F32)
|
||||
* 0.0;
|
||||
rotated.concat_along(zeros, 0)
|
||||
} else {
|
||||
rotated
|
||||
};
|
||||
|
||||
let emb = pos_ids
|
||||
.cast(DType::F32)
|
||||
.expand_dim(1, 1)
|
||||
.matmul(inv_freqs.expand_dim(0, 1));
|
||||
|
||||
let x0 = input.slice((.., .., ..half_dim));
|
||||
let x1 = input.slice((.., .., half_dim..));
|
||||
|
||||
let cos = emb.cos().expand_dim(0, n_heads);
|
||||
let sin = emb.sin().expand_dim(0, n_heads);
|
||||
let x0_out = x0 * cos - x1 * sin;
|
||||
let x1_out = x1 * cos + x0 * sin;
|
||||
|
||||
x0_out
|
||||
.concat_along(x1_out, 2)
|
||||
.transpose(0, 1)
|
||||
.merge_dims(1, 2)
|
||||
}
|
||||
|
||||
fn gather_experts(
|
||||
graph_source: GraphTensor,
|
||||
top_k_indices: GraphTensor,
|
||||
weights: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let (_, d1, d2) = weights.dims3();
|
||||
let io = d1 * d2;
|
||||
let base = top_k_indices * io;
|
||||
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
|
||||
let n_base = base.dims().len();
|
||||
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
|
||||
let mut exp_within = within;
|
||||
for (axis, dim) in base.dims().iter().enumerate() {
|
||||
exp_within = exp_within.expand_dim(axis, *dim);
|
||||
}
|
||||
let expert_flat_idx = exp_base + exp_within;
|
||||
weights.gather(expert_flat_idx)
|
||||
}
|
||||
|
||||
fn hlir_attention(
|
||||
q_rope: GraphTensor,
|
||||
k_rope: GraphTensor,
|
||||
v: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
spec: LayerSpec,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let cx = q_rope.graph();
|
||||
let seq = q_rope.dims()[0];
|
||||
let prev = Expression::from('p');
|
||||
let total_seq = prev + seq;
|
||||
|
||||
let k_new = k_rope.split_dims(1, spec.head_dim).transpose(0, 1);
|
||||
let v_new = v.split_dims(1, spec.head_dim).transpose(0, 1);
|
||||
|
||||
let h_offset = cx.arange(spec.num_kv_heads) * (max_seq * spec.head_dim);
|
||||
let p_offset = (cx.arange(seq) + prev) * spec.head_dim;
|
||||
let d_offset = cx.arange(spec.head_dim);
|
||||
let scatter_idx = h_offset.expand_dim(1, seq).expand_dim(2, spec.head_dim)
|
||||
+ p_offset
|
||||
.expand_dim(0, spec.num_kv_heads)
|
||||
.expand_dim(2, spec.head_dim)
|
||||
+ d_offset.expand_dim(0, spec.num_kv_heads).expand_dim(1, seq);
|
||||
|
||||
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
|
||||
let k_3d = k_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
|
||||
let v_3d = v_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
|
||||
let q = q_rope.split_dims(1, spec.head_dim).transpose(0, 1);
|
||||
|
||||
// Gemma 4's text attention uses Q/K normalization and then leaves the
|
||||
// attention scaling at 1.0 in the reference implementation.
|
||||
let scores = q.matmul(k_3d.transpose(1, 2));
|
||||
|
||||
let q_abs = cx.arange(seq).cast(DType::F32) + prev;
|
||||
let k_pos = cx.arange(total_seq).cast(DType::F32);
|
||||
let future_mask = k_pos
|
||||
.expand_dim(0, seq)
|
||||
.gt(q_abs.expand_dim(1, total_seq))
|
||||
.cast(DType::F32);
|
||||
|
||||
let mask_2d = if spec.is_sliding {
|
||||
let window_start = q_abs - (SLIDING_WINDOW_SIZE - 1) as f32;
|
||||
let past_mask = window_start
|
||||
.expand_dim(1, total_seq)
|
||||
.gt(k_pos.expand_dim(0, seq))
|
||||
.cast(DType::F32);
|
||||
future_mask + past_mask
|
||||
} else {
|
||||
future_mask
|
||||
};
|
||||
let mask_3d = mask_2d.expand_dim(0, N_HEADS);
|
||||
let masked_scores = scores + mask_3d * (-1e10f32);
|
||||
|
||||
let attn_weights = masked_scores.softmax(2);
|
||||
let attn_out = attn_weights.matmul(v_3d);
|
||||
let out = attn_out.transpose(0, 1).merge_dims(1, 2);
|
||||
|
||||
(out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl Gemma4Layer {
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let residual = x;
|
||||
let x_attn = self.input_layernorm.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k_base = x_attn.matmul(self.k_proj.t());
|
||||
let v_base = if let Some(v_proj) = self.v_proj {
|
||||
x_attn.matmul(v_proj.t())
|
||||
} else {
|
||||
k_base
|
||||
};
|
||||
|
||||
let q_normed = qk_norm(q, self.q_norm, N_HEADS, self.spec.head_dim);
|
||||
let k_normed = qk_norm(
|
||||
k_base,
|
||||
self.k_norm,
|
||||
self.spec.num_kv_heads,
|
||||
self.spec.head_dim,
|
||||
);
|
||||
let v_normed = value_norm(v_base, self.spec.head_dim);
|
||||
|
||||
let q_rope = gemma4_rotary_embeddings(
|
||||
q_normed,
|
||||
pos_ids,
|
||||
N_HEADS,
|
||||
self.spec.head_dim,
|
||||
self.spec.rope_theta,
|
||||
self.spec.partial_rotary_factor,
|
||||
);
|
||||
let k_rope = gemma4_rotary_embeddings(
|
||||
k_normed,
|
||||
pos_ids,
|
||||
self.spec.num_kv_heads,
|
||||
self.spec.head_dim,
|
||||
self.spec.rope_theta,
|
||||
self.spec.partial_rotary_factor,
|
||||
);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
|
||||
q_rope, k_rope, v_normed, k_cache_in, v_cache_in, max_seq, self.spec,
|
||||
);
|
||||
|
||||
let attn_proj = attn_out.matmul(self.o_proj.t());
|
||||
let x = residual + self.post_attention_layernorm.forward(attn_proj);
|
||||
|
||||
let dense_ff = dense_ffn(
|
||||
self.pre_feedforward_layernorm.forward(x),
|
||||
self.gate,
|
||||
self.up,
|
||||
self.down,
|
||||
);
|
||||
let dense_ff = self.post_feedforward_layernorm_1.forward(dense_ff);
|
||||
|
||||
let moe_out = self
|
||||
.moe
|
||||
.forward(x, self.pre_feedforward_layernorm_2.forward(x));
|
||||
let moe_out = self.post_feedforward_layernorm_2.forward(moe_out);
|
||||
|
||||
let ff_out = self.post_feedforward_layernorm.forward(dense_ff + moe_out);
|
||||
let x = x + ff_out;
|
||||
let x = x * self
|
||||
.layer_scalar
|
||||
.expand_lhs(&x.dims()[..x.dims().len() - 1]);
|
||||
|
||||
(x, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
fn dense_ffn(x: GraphTensor, gate: GraphTensor, up: GraphTensor, down: GraphTensor) -> GraphTensor {
|
||||
(gemma_gelu(x.matmul(gate.t())) * x.matmul(up.t())).matmul(down.t())
|
||||
}
|
||||
|
||||
impl Gemma4SparseMoE {
|
||||
fn forward(&self, router_input: GraphTensor, expert_input: GraphTensor) -> GraphTensor {
|
||||
let n = router_input.dims().len();
|
||||
let e_dim = *self.router_proj.dims().first().unwrap();
|
||||
let k_expr = Expression::from(TOP_K);
|
||||
|
||||
let router_hidden = router_input.std_norm(router_input.dims().len() - 1, RMS_NORM_EPS)
|
||||
* self
|
||||
.router_scale
|
||||
.expand_lhs(&router_input.dims()[..router_input.dims().len() - 1])
|
||||
* (HIDDEN as f32).sqrt().recip();
|
||||
let routing_weights = router_hidden.matmul(self.router_proj.t()).softmax(n - 1);
|
||||
|
||||
let top_k_indices = routing_weights.topk_indexes(TOP_K, n - 1);
|
||||
let row_offsets = router_input
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
let top_k_norm = top_k_values.sum(n - 1).expand_dim(n - 1, TOP_K);
|
||||
let top_k_weights =
|
||||
(top_k_values / top_k_norm) * self.per_expert_scale.gather(top_k_indices);
|
||||
|
||||
let gate_up_gathered =
|
||||
gather_experts(expert_input, top_k_indices, self.gate_up_weights).cast(DType::F32);
|
||||
let x_exp = expert_input.expand_dim(n - 1, TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gemma_gelu(gate) * up;
|
||||
|
||||
let down_gathered =
|
||||
gather_experts(expert_input, top_k_indices, self.down_weights).cast(DType::F32);
|
||||
let hidden_exp = hidden.unsqueeze(2);
|
||||
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2);
|
||||
|
||||
(down_out * top_k_weights.unsqueeze(top_k_weights.dims().len())).sum(n - 1)
|
||||
}
|
||||
}
|
||||
@@ -264,8 +264,7 @@ impl QwenMoE {
|
||||
let row_offsets = x
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx =
|
||||
(row_offsets.cast(DType::F32) + top_k_indices.cast(DType::F32)).cast(DType::Int);
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
|
||||
// 4. Gather gate_up expert weights → [s, k, intermediate*2, H]
|
||||
@@ -303,18 +302,18 @@ fn gather_experts(
|
||||
) -> GraphTensor {
|
||||
let (_, d1, d2) = weights.dims3();
|
||||
let io = d1 * d2;
|
||||
let base = (top_k_indices * io).cast(DType::F32);
|
||||
let within = graph_source
|
||||
.graph()
|
||||
.iota(Expression::from('z'), (d1, d2))
|
||||
.cast(DType::F32);
|
||||
// Keep expert gather indices in Int all the way through. Routing them through
|
||||
// F32 loses exactness once the flat offsets exceed 2^24, which Qwen's expert
|
||||
// tensors do at realistic hidden sizes.
|
||||
let base = top_k_indices * io;
|
||||
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
|
||||
let n_base = base.dims().len();
|
||||
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
|
||||
let mut exp_within = within;
|
||||
for (i, dim) in base.dims().iter().enumerate() {
|
||||
exp_within = exp_within.expand_dim(i, *dim);
|
||||
}
|
||||
let expert_flat_idx = (exp_base + exp_within).cast(DType::Int);
|
||||
let expert_flat_idx = exp_base + exp_within;
|
||||
weights.gather(expert_flat_idx)
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ use rand::Rng;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::{str, sync::Arc};
|
||||
use std::{str, sync::Arc, time::Duration};
|
||||
use tracing::trace;
|
||||
|
||||
pub mod api;
|
||||
@@ -128,8 +128,10 @@ pub fn early_egglog(
|
||||
program.to_string(),
|
||||
format!(
|
||||
"(run-schedule
|
||||
(saturate expr)
|
||||
(run)
|
||||
(repeat 6
|
||||
(saturate expr)
|
||||
(run)
|
||||
)
|
||||
(saturate base_cleanup)
|
||||
)
|
||||
(extract {root})"
|
||||
@@ -179,6 +181,20 @@ pub struct SerializedEGraph {
|
||||
pub roots: Vec<ClassId>,
|
||||
}
|
||||
|
||||
/// Rule statistics and wall-clock time for a single egglog stage.
///
/// The per-rule maps are keyed by rule name and are filled from egglog's
/// overall run report (see `stage_report`).
#[derive(Debug, Clone, Default)]
|
||||
pub struct EgglogStageReport {
|
||||
// Number of matches recorded for each rule.
pub num_matches_per_rule: FxHashMap<String, usize>,
|
||||
// Time egglog reports for searching and applying each rule.
pub search_and_apply_time_per_rule: FxHashMap<String, Duration>,
|
||||
// Elapsed time of the whole stage, as measured by the caller.
pub total_time: Duration,
|
||||
}
|
||||
|
||||
/// Combined report for a two-stage egglog run (early pass, then full pass).
#[derive(Debug, Clone, Default)]
|
||||
pub struct EgglogRunReport {
|
||||
// Statistics of the early stage.
pub early: EgglogStageReport,
|
||||
// Statistics of the full stage.
pub full: EgglogStageReport,
|
||||
// Elapsed time covering both stages.
pub total_time: Duration,
|
||||
}
|
||||
|
||||
impl SerializedEGraph {
|
||||
/// This is an opinionated function which does more than strictly take the state of the egglog object.
|
||||
/// It also filters out "[...]" nodes and then changes the structure from the e-termDAG that egraph-serialize
|
||||
@@ -589,41 +605,34 @@ fn termdag_to_egglog(td: &egglog::TermDag, root: egglog::TermId) -> (String, Str
|
||||
(out.replace("(MVar \"z\")", "(MIter)"), format!("t{root}"))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn run_egglog(
|
||||
program: &str,
|
||||
root: &str,
|
||||
ops: &[Arc<Box<dyn EgglogOp>>],
|
||||
cleanup: bool,
|
||||
) -> Result<SerializedEGraph, egglog::Error> {
|
||||
let start = std::time::Instant::now();
|
||||
let code = early_egglog(program, root, ops, cleanup);
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let commands = egraph.parser.get_program_from_string(None, &code)?;
|
||||
let outputs = egraph.run_program(commands)?;
|
||||
let CommandOutput::ExtractBest(termdag, _cost, term) = outputs.last().unwrap() else {
|
||||
panic!();
|
||||
};
|
||||
let (program, root) = termdag_to_egglog(termdag, termdag.lookup(term));
|
||||
let code = full_egglog(&program, ops, cleanup);
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let commands = egraph.parser.get_program_from_string(None, &code)?;
|
||||
trace!("{}", "Egglog running...".green());
|
||||
let _outputs = egraph.run_program(commands)?;
|
||||
trace!("{}", "---- Egglog Rule Matches ----".green());
|
||||
/// Snapshots the e-graph's overall run report into an owned `EgglogStageReport`.
///
/// Rule names are converted to `String` and the match counts / durations are
/// copied, so the result does not borrow from `egraph`. `total_time` is stored
/// unchanged; it is whatever the caller measured for the stage.
fn stage_report(egraph: &egglog::EGraph, total_time: Duration) -> EgglogStageReport {
|
||||
let run_report = egraph.get_overall_run_report();
|
||||
EgglogStageReport {
|
||||
// Per-rule match counts, keyed by owned rule name.
num_matches_per_rule: run_report
|
||||
.num_matches_per_rule
|
||||
.iter()
|
||||
.map(|(name, matches)| (name.to_string(), *matches))
|
||||
.collect(),
|
||||
// Per-rule search-and-apply durations, keyed by owned rule name.
search_and_apply_time_per_rule: run_report
|
||||
.search_and_apply_time_per_rule
|
||||
.iter()
|
||||
.map(|(name, elapsed)| (name.to_string(), *elapsed))
|
||||
.collect(),
|
||||
total_time,
|
||||
}
|
||||
}
|
||||
|
||||
fn trace_stage_report(header: &str, report: &EgglogStageReport) {
|
||||
trace!("{}", header.green());
|
||||
trace!(
|
||||
"{}",
|
||||
run_report
|
||||
report
|
||||
.num_matches_per_rule
|
||||
.iter()
|
||||
.filter(|(k, _)| !k.contains("("))
|
||||
.map(|(k, v)| format!(
|
||||
"{k}: {v} ({})",
|
||||
pretty_duration::pretty_duration(
|
||||
&run_report.search_and_apply_time_per_rule[k],
|
||||
None
|
||||
)
|
||||
pretty_duration::pretty_duration(&report.search_and_apply_time_per_rule[k], None)
|
||||
))
|
||||
.join("\n")
|
||||
.green()
|
||||
@@ -631,11 +640,59 @@ pub fn run_egglog(
|
||||
trace!(
|
||||
"{}",
|
||||
format!(
|
||||
"---- Egglog Took {} ----",
|
||||
pretty_duration::pretty_duration(&start.elapsed(), None).bold()
|
||||
"---- {} Took {} ----",
|
||||
header,
|
||||
pretty_duration::pretty_duration(&report.total_time, None).bold()
|
||||
)
|
||||
.green()
|
||||
);
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn run_egglog_with_report(
|
||||
program: &str,
|
||||
root: &str,
|
||||
ops: &[Arc<Box<dyn EgglogOp>>],
|
||||
cleanup: bool,
|
||||
) -> Result<(SerializedEGraph, EgglogRunReport), egglog::Error> {
|
||||
let total_start = std::time::Instant::now();
|
||||
|
||||
let early_start = std::time::Instant::now();
|
||||
let code = early_egglog(program, root, ops, cleanup);
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let commands = egraph.parser.get_program_from_string(None, &code)?;
|
||||
let outputs = egraph.run_program(commands)?;
|
||||
let early_report = stage_report(&egraph, early_start.elapsed());
|
||||
|
||||
let CommandOutput::ExtractBest(termdag, _cost, term) = outputs.last().unwrap() else {
|
||||
panic!();
|
||||
};
|
||||
let (program, root) = termdag_to_egglog(termdag, termdag.lookup(term));
|
||||
|
||||
let full_start = std::time::Instant::now();
|
||||
let code = full_egglog(&program, ops, cleanup);
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let commands = egraph.parser.get_program_from_string(None, &code)?;
|
||||
trace!("{}", "Egglog running...".green());
|
||||
let _outputs = egraph.run_program(commands)?;
|
||||
let full_report = stage_report(&egraph, full_start.elapsed());
|
||||
trace_stage_report("---- Egglog Early Rule Matches ----", &early_report);
|
||||
trace_stage_report("---- Egglog Full Rule Matches ----", &full_report);
|
||||
|
||||
let run_report = EgglogRunReport {
|
||||
early: early_report,
|
||||
full: full_report,
|
||||
total_time: total_start.elapsed(),
|
||||
};
|
||||
trace!(
|
||||
"{}",
|
||||
format!(
|
||||
"---- Egglog Total Took {} ----",
|
||||
pretty_duration::pretty_duration(&run_report.total_time, None).bold()
|
||||
)
|
||||
.green()
|
||||
);
|
||||
|
||||
let (sort, value) = egraph.eval_expr(&var!(root))?;
|
||||
let s = egraph.serialize(egglog::SerializeConfig {
|
||||
root_eclasses: vec![(sort, value)],
|
||||
@@ -720,7 +777,17 @@ pub fn run_egglog(
|
||||
"No valid graphs present in the e-graph!"
|
||||
);
|
||||
|
||||
Ok(egraph)
|
||||
Ok((egraph, run_report))
|
||||
}
|
||||
|
||||
/// Runs egglog on `program` and returns only the serialized e-graph.
///
/// Thin wrapper over `run_egglog_with_report` that discards the timing/match
/// report; all arguments are forwarded unchanged and errors are propagated
/// as-is.
#[tracing::instrument(skip_all)]
|
||||
pub fn run_egglog(
|
||||
program: &str,
|
||||
root: &str,
|
||||
ops: &[Arc<Box<dyn EgglogOp>>],
|
||||
cleanup: bool,
|
||||
) -> Result<SerializedEGraph, egglog::Error> {
|
||||
// Keep the e-graph, drop the report.
run_egglog_with_report(program, root, ops, cleanup).map(|(egraph, _)| egraph)
|
||||
}
|
||||
|
||||
pub fn extract_expr_list<'a>(
|
||||
|
||||
@@ -663,7 +663,7 @@ pub(super) mod tests {
|
||||
let mut out: Vec<(NotNan<f32>, usize)> =
|
||||
heap.into_iter().map(|std::cmp::Reverse(t)| t).collect();
|
||||
|
||||
out.sort_unstable_by(|a, b| b.0.cmp(&a.0));
|
||||
out.sort_unstable_by_key(|b| std::cmp::Reverse(b.0));
|
||||
out.into_iter().map(|(_, i)| i).collect()
|
||||
}
|
||||
test_unary(
|
||||
|
||||
@@ -724,10 +724,8 @@ impl Graph {
|
||||
}
|
||||
|
||||
// Track top-N parents for offspring generation
|
||||
let mut parents: Vec<(
|
||||
R::ProfileMetric,
|
||||
crate::egglog_utils::EGraphChoiceSet<'_>,
|
||||
)> = vec![(best_metric.clone(), best_genome.clone())];
|
||||
let mut parents: Vec<(R::ProfileMetric, crate::egglog_utils::EGraphChoiceSet<'_>)> =
|
||||
vec![(best_metric.clone(), best_genome.clone())];
|
||||
|
||||
while n_graphs < limit {
|
||||
// Generate offspring from all parents, dividing budget evenly
|
||||
@@ -769,8 +767,7 @@ impl Graph {
|
||||
None,
|
||||
);
|
||||
runtime.clear_intermediate_buffers();
|
||||
let result =
|
||||
runtime.profile(&llir_graph, dyn_map, options.trials);
|
||||
let result = runtime.profile(&llir_graph, dyn_map, options.trials);
|
||||
let has_nan = runtime.has_nan_outputs(&llir_graph, dyn_map);
|
||||
(result, llir_graph, has_nan)
|
||||
}));
|
||||
|
||||
Reference in New Issue
Block a user