removed special rope op for gemma

This commit is contained in:
Joe Fioti
2026-02-01 07:17:07 +00:00
parent 8f0790736d
commit 56bbd29e21
2 changed files with 114 additions and 146 deletions

View File

@@ -14,7 +14,7 @@ const REPO_ID: &str = "unsloth/gemma-3-4b-it";
fn main() {
let max_seq_len = 4096;
let gen_tokens = 30;
let gen_tokens = 100;
let search_graphs = 5; // the number of graphs we want to search during compilation
let prompt = "Explain what a neural network is in simple terms:";

View File

@@ -1,8 +1,8 @@
use luminal::{
graph::Graph,
op::{CustomOp, LLIROp},
prelude::GraphTensor,
shape::{flatten_mul_strides, Expression, ToShape},
op::{CustomOp, DType, LLIROp},
prelude::{F32Pow, GraphTensor},
shape::{flatten_mul_strides, Expression, ShapeTracker, ToShape},
};
use luminal_cuda::{
block::{cstruct::CStruct, BlockOp},
@@ -23,6 +23,70 @@ pub const KV_DIM: usize = N_KV_HEADS * HEAD_DIM; // = 1024
pub const VOCAB_SIZE: usize = 262208;
pub const RMS_NORM_EPS: f32 = 1e-6;
// Attention pattern constants
pub const SLIDING_WINDOW_PATTERN: usize = 6; // Every 6th layer is global attention
pub const SLIDING_WINDOW_SIZE: usize = 1024; // Local attention window size
pub const ROPE_THETA_GLOBAL: f32 = 1_000_000.0; // RoPE base for global attention layers
pub const ROPE_THETA_LOCAL: f32 = 10_000.0; // RoPE base for local attention layers
/// Apply QK-Norm + RoPE using frontend HLIR operations
fn gemma_qk_norm_rope(
mut input: GraphTensor,
norm_weight: GraphTensor,
pos_ids: GraphTensor,
rope_theta: f32,
) -> GraphTensor {
let orig_shape = input.shape;
// Reshape: (seq, dim) -> (n_heads, seq, head_dim)
input = input.split_dims(1, HEAD_DIM).transpose(0, 1);
// Apply QK-Norm: RMS norm along head_dim with learnable weights
// Note: weights are pre-transformed to (1 + weight) in setup.py
input = input.std_norm(2, RMS_NORM_EPS);
input = input * norm_weight.expand_lhs(&input.dims()[..input.dims().len() - 1]);
// Get freqs: theta_i = base^(-2i/d) for i in [0, d/2)
let freqs = input
.graph()
.arange_options(0, HEAD_DIM, 2)
.cast(DType::F32)
/ HEAD_DIM as f32;
let inv_freqs = rope_theta.pow(freqs).reciprocal();
// emb = pos * inv_freqs, shape (seq, head_dim/2)
let emb = pos_ids
.cast(DType::F32)
.expand_dim(1, 1)
.matmul(inv_freqs.expand_dim(0, 1));
// Split input into first half (x0) and second half (x1) - SPLIT-HALF style
// input: (n_heads, seq, head_dim) -> split into (n_heads, seq, head_dim/2) each
let half_dim = HEAD_DIM / 2;
let x0 = input.slice((.., .., ..half_dim)); // First half: indices 0-127
let x1 = input.slice((.., .., half_dim..)); // Second half: indices 128-255
// Apply rotary embeddings:
// x0_out = x0 * cos(emb) - x1 * sin(emb)
// x1_out = x1 * cos(emb) + x0 * sin(emb)
let cos_emb = emb.cos().expand_dim(0, x0.dims()[0]);
let sin_emb = emb.sin().expand_dim(0, x0.dims()[0]);
let x0_out = x0 * cos_emb - x1 * sin_emb;
let x1_out = x1 * cos_emb + x0 * sin_emb;
// Concatenate back: [first_half | second_half]
let mut s = x0_out.concat_along(x1_out, 2);
// Set proper strides and reshape back
let n_heads = input.dims()[0];
let seq_dim = input.dims()[1];
s.shape = ShapeTracker::new((n_heads, seq_dim, HEAD_DIM));
s = s.transpose(0, 1) * 1.0;
s.shape = orig_shape;
s
}
/// Gemma-specific RMSNorm
/// Note: weights are pre-transformed to (1 + weight) in setup.py
pub struct GemmaRMSNorm {
@@ -59,6 +123,7 @@ impl Gemma {
pub fn init(cx: &mut Graph) -> Self {
let mut w = vec![];
for l in 0..LAYERS {
let is_local = (l + 1) % SLIDING_WINDOW_PATTERN != 0;
w.push(GemmaLayer {
up: cx.named_tensor(
format!("model.layers.{l}.mlp.up_proj.weight"),
@@ -120,6 +185,12 @@ impl Gemma {
format!("model.layers.{l}.self_attn.k_norm.weight"),
HEAD_DIM,
),
is_local,
rope_theta: if is_local {
ROPE_THETA_LOCAL
} else {
ROPE_THETA_GLOBAL
},
});
}
let lm_norm = GemmaRMSNorm::new(HIDDEN, "model.norm.weight", RMS_NORM_EPS, cx);
@@ -170,131 +241,8 @@ struct GemmaLayer {
post_feedforward_layernorm: GemmaRMSNorm,
q_norm: GraphTensor,
k_norm: GraphTensor,
}
/// Fused QK-Norm + RoPE custom operation (interleaved format for Gemma)
#[derive(Debug, Clone)]
pub struct GemmaQKNormRoPE {
range: Vec<Expression>,
inp_stride: Vec<Expression>,
row_width: Expression,
}
impl GemmaQKNormRoPE {
fn new(seq: Expression, row_width: Expression) -> Self {
Self {
range: vec![seq],
inp_stride: vec![row_width],
row_width,
}
}
}
impl CustomOp for GemmaQKNormRoPE {
fn to_llir_op(&self) -> LLIROp {
LLIROp::new::<dyn BlockOp>(Box::new(self.clone()))
}
}
impl BlockOp for GemmaQKNormRoPE {
fn op_name(&self) -> &'static str {
"GemmaQKNormRoPE"
}
fn launch_range(&self) -> Vec<Expression> {
self.range.clone()
}
fn output_size(&self) -> Expression {
self.range.iter().copied().product::<Expression>() * self.row_width
}
fn producer_barriers_seperate(&self) -> Vec<bool> {
vec![true; self.range.len()]
}
fn consumer_barriers_seperate(&self) -> Vec<Vec<bool>> {
vec![
vec![true; self.range.len()],
vec![true; self.range.len()],
vec![true; self.range.len()],
]
}
fn cuda_function(&self) -> String {
format!(
r#"
__shared__ float rms_scale_shared;
const float* inp = source_ptrs[0] + eval_expression(payload.inp, current);
float* out = out_ptr + eval_expression(payload.out, current);
const float* weight = source_ptrs[1];
const int* token_ids = (const int*)source_ptrs[2] + eval_expression(payload.token_ids, current);
const int D_total = eval_expression(payload.row_width, 0);
const int d_head = {HEAD_DIM};
const int n_heads = D_total / d_head;
const int pos = token_ids[0];
const float base = 1000000.0f;
const float eps = {RMS_NORM_EPS}f;
const int half = d_head / 2;
// Process each head
for (int h = 0; h < n_heads; ++h) {{
const float* head_in = inp + h * d_head;
float* head_out = out + h * d_head;
// Step 1: Compute sum of squares for RMS norm
if (t == 0) {{
float sum_sq = 0.0f;
for (int k = 0; k < d_head; ++k) {{
float val = head_in[k];
sum_sq += val * val;
}}
rms_scale_shared = rsqrtf(sum_sq / (float)d_head + eps);
}}
__syncthreads();
float rms_scale = rms_scale_shared;
// Step 2: Apply RMS norm + weight and RoPE (interleaved format)
for (int k = t; k < half; k += blockDim.x) {{
// Interleaved indexing: pairs are (0,1), (2,3), (4,5), ...
const int j_even = 2 * k;
const int j_odd = 2 * k + 1;
// Apply RMS norm and weight (weights have 1+w pre-applied in setup.py)
float x_even = head_in[j_even] * rms_scale * weight[j_even];
float x_odd = head_in[j_odd] * rms_scale * weight[j_odd];
// Compute RoPE rotation
const float exponent = -(2.0f * (float)k) / (float)d_head;
const float theta = (float)pos * __powf(base, exponent);
float s, c;
__sincosf(theta, &s, &c);
// Rotation
head_out[j_even] = x_even * c - x_odd * s;
head_out[j_odd] = x_even * s + x_odd * c;
}}
__syncthreads();
}}
"#,
HEAD_DIM = HEAD_DIM,
RMS_NORM_EPS = RMS_NORM_EPS
)
}
fn build_payload<'a>(&self, _: &Arc<CudaStream>, payload: CStruct<'a>) -> CStruct<'a> {
payload
.expr("inp", flatten_mul_strides(&self.range, &self.inp_stride))
.expr("out", flatten_mul_strides(&self.range, &self.inp_stride))
.expr("row_width", self.row_width)
.expr("weight", 0)
.expr("token_ids", 'z')
}
is_local: bool, // true for sliding window attention, false for global
rope_theta: f32, // RoPE base frequency (different for local vs global attention)
}
impl GemmaLayer {
@@ -313,23 +261,21 @@ impl GemmaLayer {
let k = x_attn.matmul(self.k_proj.t());
let v = x_attn.matmul(self.v_proj.t());
// 3. Apply QK-Norm + RoPE using fused custom kernel
let q_rope = x.graph().custom_op(
GemmaQKNormRoPE::new(q.dims()[0], q.dims()[1]),
(q, self.q_norm, pos_ids),
q.shape,
q.dtype,
);
let k_rope = x.graph().custom_op(
GemmaQKNormRoPE::new(k.dims()[0], k.dims()[1]),
(k, self.k_norm, pos_ids),
k.shape,
k.dtype,
);
// 3. Apply QK-Norm + RoPE using HLIR operations
// Use different RoPE base frequency for local vs global attention
let q_rope = gemma_qk_norm_rope(q, self.q_norm, pos_ids, self.rope_theta);
let k_rope = gemma_qk_norm_rope(k, self.k_norm, pos_ids, self.rope_theta);
// 4. Attention
// Use sliding window for local attention layers, global for others
let attn_out = x.graph().custom_op(
GemmaAttention::new(k_cache, v_cache, q_rope.dims()[0], 'p'.into()),
GemmaAttention::new(
k_cache,
v_cache,
q_rope.dims()[0],
'p'.into(),
self.is_local,
),
(q_rope, k_rope, v),
q_rope.shape,
q_rope.dtype,
@@ -398,10 +344,18 @@ pub struct GemmaAttention {
prev_seq: Expression,
k_cache: u64,
v_cache: u64,
sliding_window: usize, // 0 for global attention, >0 for local sliding window
}
impl GemmaAttention {
fn new(k_cache: u64, v_cache: u64, seq: Expression, prev_seq: Expression) -> Self {
fn new(
k_cache: u64,
v_cache: u64,
seq: Expression,
prev_seq: Expression,
is_local: bool,
) -> Self {
let sliding_window = if is_local { SLIDING_WINDOW_SIZE } else { 0 };
Self {
range: (N_KV_HEADS, KV_GROUPS, seq).to_shape(),
head_dim: HEAD_DIM.into(),
@@ -414,6 +368,7 @@ impl GemmaAttention {
prev_seq,
k_cache,
v_cache,
sliding_window,
}
}
}
@@ -481,6 +436,7 @@ impl BlockOp for GemmaAttention {
"head_pos_stride",
flatten_mul_strides(&self.range, &head_pos_stride),
)
.int("sliding_window", self.sliding_window as i32)
}
fn cuda_function(&self) -> String {
@@ -526,6 +482,9 @@ impl BlockOp for GemmaAttention {
const int kv_row_stride = eval_expression(payload.kv_row_stride, 0);
const int prev = eval_expression(payload.prev_seq, 0);
// Sliding window configuration (0 = global attention, >0 = local sliding window)
const int sliding_window = payload.sliding_window;
const float* __restrict__ K_cur = k;
const float* __restrict__ V_cur = v;
float* __restrict__ O = out;
@@ -535,6 +494,13 @@ impl BlockOp for GemmaAttention {
const int q_pos_total = prev + q_pos_local;
// For sliding window attention, compute the start position
// We only attend to positions within [q_pos_total - sliding_window + 1, q_pos_total]
int attn_start = 0;
if (sliding_window > 0 && q_pos_total >= sliding_window) {
attn_start = q_pos_total - sliding_window + 1;
}
const float scale = rsqrtf((float)d);
__shared__ float max_l_shared;
@@ -559,7 +525,8 @@ impl BlockOp for GemmaAttention {
if (t == 0) max_l_shared = -__int_as_float(0x7f800000);
__syncthreads();
for (int r = 0; r <= q_pos_total; ++r) {
// First pass: find max for numerical stability
for (int r = attn_start; r <= q_pos_total; ++r) {
const float* __restrict__ k_row;
if (r < prev) {
k_row = K_cache + r * kv_row_stride;
@@ -591,7 +558,8 @@ impl BlockOp for GemmaAttention {
float s_local = 0.0f;
__syncthreads();
for (int r = 0; r <= q_pos_total; ++r) {
// Second pass: compute softmax and weighted sum
for (int r = attn_start; r <= q_pos_total; ++r) {
const float* __restrict__ k_row;
const float* __restrict__ v_row;