mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
removed special rope op for gemma
This commit is contained in:
@@ -14,7 +14,7 @@ const REPO_ID: &str = "unsloth/gemma-3-4b-it";
|
||||
|
||||
fn main() {
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 30;
|
||||
let gen_tokens = 100;
|
||||
let search_graphs = 5; // the number of graphs we want to search during compilation
|
||||
let prompt = "Explain what a neural network is in simple terms:";
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use luminal::{
|
||||
graph::Graph,
|
||||
op::{CustomOp, LLIROp},
|
||||
prelude::GraphTensor,
|
||||
shape::{flatten_mul_strides, Expression, ToShape},
|
||||
op::{CustomOp, DType, LLIROp},
|
||||
prelude::{F32Pow, GraphTensor},
|
||||
shape::{flatten_mul_strides, Expression, ShapeTracker, ToShape},
|
||||
};
|
||||
use luminal_cuda::{
|
||||
block::{cstruct::CStruct, BlockOp},
|
||||
@@ -23,6 +23,70 @@ pub const KV_DIM: usize = N_KV_HEADS * HEAD_DIM; // = 1024
|
||||
pub const VOCAB_SIZE: usize = 262208;
|
||||
pub const RMS_NORM_EPS: f32 = 1e-6;
|
||||
|
||||
// Attention pattern constants
|
||||
pub const SLIDING_WINDOW_PATTERN: usize = 6; // Every 6th layer is global attention
|
||||
pub const SLIDING_WINDOW_SIZE: usize = 1024; // Local attention window size
|
||||
pub const ROPE_THETA_GLOBAL: f32 = 1_000_000.0; // RoPE base for global attention layers
|
||||
pub const ROPE_THETA_LOCAL: f32 = 10_000.0; // RoPE base for local attention layers
|
||||
|
||||
/// Apply QK-Norm + RoPE using frontend HLIR operations
///
/// Steps performed, in order:
/// 1. split the last axis of `input` into heads of `HEAD_DIM` and move the
///    head axis to the front,
/// 2. normalize along the head axis-2 with `std_norm(2, RMS_NORM_EPS)` and
///    scale by `norm_weight`,
/// 3. rotate each (first-half, second-half) pair of every head by an angle
///    `pos * rope_theta^(-2i / HEAD_DIM)`,
/// 4. transpose back and restore the shape `input` had on entry.
///
/// * `input` - tensor to normalize and rotate; the inline comments below
///   describe it as `(seq, dim)` (assumes `dim` is a multiple of `HEAD_DIM`
///   - TODO confirm against callers).
/// * `norm_weight` - per-element scale applied after normalization; it is
///   broadcast over all leading axes of the reshaped input.
/// * `pos_ids` - token positions; cast to F32 before being multiplied with
///   the inverse frequencies.
/// * `rope_theta` - base of the rotary frequencies.
///
/// Returns a tensor whose shape tracker is set to `input`'s original shape.
|
||||
fn gemma_qk_norm_rope(
|
||||
mut input: GraphTensor,
|
||||
norm_weight: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
rope_theta: f32,
|
||||
) -> GraphTensor {
|
||||
// Remember the incoming shape so it can be re-applied to the result at the end.
let orig_shape = input.shape;
|
||||
|
||||
// Reshape: (seq, dim) -> (n_heads, seq, head_dim)
|
||||
input = input.split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
|
||||
// Apply QK-Norm: RMS norm along head_dim with learnable weights
|
||||
// Note: weights are pre-transformed to (1 + weight) in setup.py
|
||||
input = input.std_norm(2, RMS_NORM_EPS);
|
||||
// Broadcast the weight across every axis except the last one.
input = input * norm_weight.expand_lhs(&input.dims()[..input.dims().len() - 1]);
|
||||
|
||||
// Get freqs: theta_i = base^(-2i/d) for i in [0, d/2)
|
||||
// `freqs` holds the exponents 2i/d (0, 2, 4, ... divided by HEAD_DIM).
let freqs = input
|
||||
.graph()
|
||||
.arange_options(0, HEAD_DIM, 2)
|
||||
.cast(DType::F32)
|
||||
/ HEAD_DIM as f32;
|
||||
// base^(2i/d) followed by a reciprocal gives base^(-2i/d).
let inv_freqs = rope_theta.pow(freqs).reciprocal();
|
||||
|
||||
// emb = pos * inv_freqs, shape (seq, head_dim/2)
|
||||
// Built as an outer product: (seq, 1) x (1, head_dim/2).
let emb = pos_ids
|
||||
.cast(DType::F32)
|
||||
.expand_dim(1, 1)
|
||||
.matmul(inv_freqs.expand_dim(0, 1));
|
||||
|
||||
// Split input into first half (x0) and second half (x1) - SPLIT-HALF style
|
||||
// input: (n_heads, seq, head_dim) -> split into (n_heads, seq, head_dim/2) each
|
||||
let half_dim = HEAD_DIM / 2;
|
||||
let x0 = input.slice((.., .., ..half_dim)); // First half: indices 0-127
|
||||
let x1 = input.slice((.., .., half_dim..)); // Second half: indices 128-255
|
||||
|
||||
// Apply rotary embeddings:
|
||||
// x0_out = x0 * cos(emb) - x1 * sin(emb)
|
||||
// x1_out = x1 * cos(emb) + x0 * sin(emb)
|
||||
// cos/sin are expanded along a new leading axis so they line up with the head axis of x0/x1.
let cos_emb = emb.cos().expand_dim(0, x0.dims()[0]);
|
||||
let sin_emb = emb.sin().expand_dim(0, x0.dims()[0]);
|
||||
|
||||
let x0_out = x0 * cos_emb - x1 * sin_emb;
|
||||
let x1_out = x1 * cos_emb + x0 * sin_emb;
|
||||
|
||||
// Concatenate back: [first_half | second_half]
|
||||
let mut s = x0_out.concat_along(x1_out, 2);
|
||||
|
||||
// Set proper strides and reshape back
|
||||
let n_heads = input.dims()[0];
|
||||
let seq_dim = input.dims()[1];
|
||||
// Replace the concat result's shape tracker with a fresh (n_heads, seq, HEAD_DIM) one.
s.shape = ShapeTracker::new((n_heads, seq_dim, HEAD_DIM));
|
||||
// NOTE(review): the multiply by 1.0 presumably forces the transposed view to be
// written out before the shape is overwritten below - confirm.
s = s.transpose(0, 1) * 1.0;
|
||||
s.shape = orig_shape;
|
||||
s
|
||||
}
|
||||
|
||||
/// Gemma-specific RMSNorm
|
||||
/// Note: weights are pre-transformed to (1 + weight) in setup.py
|
||||
pub struct GemmaRMSNorm {
|
||||
@@ -59,6 +123,7 @@ impl Gemma {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
let mut w = vec![];
|
||||
for l in 0..LAYERS {
|
||||
let is_local = (l + 1) % SLIDING_WINDOW_PATTERN != 0;
|
||||
w.push(GemmaLayer {
|
||||
up: cx.named_tensor(
|
||||
format!("model.layers.{l}.mlp.up_proj.weight"),
|
||||
@@ -120,6 +185,12 @@ impl Gemma {
|
||||
format!("model.layers.{l}.self_attn.k_norm.weight"),
|
||||
HEAD_DIM,
|
||||
),
|
||||
is_local,
|
||||
rope_theta: if is_local {
|
||||
ROPE_THETA_LOCAL
|
||||
} else {
|
||||
ROPE_THETA_GLOBAL
|
||||
},
|
||||
});
|
||||
}
|
||||
let lm_norm = GemmaRMSNorm::new(HIDDEN, "model.norm.weight", RMS_NORM_EPS, cx);
|
||||
@@ -170,131 +241,8 @@ struct GemmaLayer {
|
||||
post_feedforward_layernorm: GemmaRMSNorm,
|
||||
q_norm: GraphTensor,
|
||||
k_norm: GraphTensor,
|
||||
}
|
||||
|
||||
/// Fused QK-Norm + RoPE custom operation (interleaved format for Gemma)
///
/// Holds the launch geometry for the block-level CUDA kernel that this type
/// provides through its `BlockOp` implementation.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GemmaQKNormRoPE {
|
||||
// Launch dimensions; `new` sets this to a single entry, the sequence length.
range: Vec<Expression>,
|
||||
// Stride per launch dimension; `build_payload` uses it for both the input and the output offset.
inp_stride: Vec<Expression>,
|
||||
// Elements per row; `output_size` multiplies it by the product of `range`.
row_width: Expression,
|
||||
}
|
||||
|
||||
impl GemmaQKNormRoPE {
|
||||
/// Build the op for `seq` rows of `row_width` elements each.
///
/// The launch range gets the single dimension `seq`, and the stride for
/// that dimension is `row_width`.
fn new(seq: Expression, row_width: Expression) -> Self {
|
||||
Self {
|
||||
range: vec![seq],
|
||||
inp_stride: vec![row_width],
|
||||
row_width,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CustomOp for GemmaQKNormRoPE {
|
||||
/// Wrap a clone of this op as a boxed `dyn BlockOp` inside an `LLIROp`.
fn to_llir_op(&self) -> LLIROp {
|
||||
LLIROp::new::<dyn BlockOp>(Box::new(self.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
impl BlockOp for GemmaQKNormRoPE {
|
||||
fn op_name(&self) -> &'static str {
|
||||
"GemmaQKNormRoPE"
|
||||
}
|
||||
|
||||
fn launch_range(&self) -> Vec<Expression> {
|
||||
self.range.clone()
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.range.iter().copied().product::<Expression>() * self.row_width
|
||||
}
|
||||
|
||||
fn producer_barriers_seperate(&self) -> Vec<bool> {
|
||||
vec![true; self.range.len()]
|
||||
}
|
||||
|
||||
fn consumer_barriers_seperate(&self) -> Vec<Vec<bool>> {
|
||||
vec![
|
||||
vec![true; self.range.len()],
|
||||
vec![true; self.range.len()],
|
||||
vec![true; self.range.len()],
|
||||
]
|
||||
}
|
||||
|
||||
fn cuda_function(&self) -> String {
|
||||
format!(
|
||||
r#"
|
||||
__shared__ float rms_scale_shared;
|
||||
|
||||
const float* inp = source_ptrs[0] + eval_expression(payload.inp, current);
|
||||
float* out = out_ptr + eval_expression(payload.out, current);
|
||||
const float* weight = source_ptrs[1];
|
||||
const int* token_ids = (const int*)source_ptrs[2] + eval_expression(payload.token_ids, current);
|
||||
|
||||
const int D_total = eval_expression(payload.row_width, 0);
|
||||
const int d_head = {HEAD_DIM};
|
||||
const int n_heads = D_total / d_head;
|
||||
|
||||
const int pos = token_ids[0];
|
||||
const float base = 1000000.0f;
|
||||
const float eps = {RMS_NORM_EPS}f;
|
||||
|
||||
const int half = d_head / 2;
|
||||
|
||||
// Process each head
|
||||
for (int h = 0; h < n_heads; ++h) {{
|
||||
const float* head_in = inp + h * d_head;
|
||||
float* head_out = out + h * d_head;
|
||||
|
||||
// Step 1: Compute sum of squares for RMS norm
|
||||
if (t == 0) {{
|
||||
float sum_sq = 0.0f;
|
||||
for (int k = 0; k < d_head; ++k) {{
|
||||
float val = head_in[k];
|
||||
sum_sq += val * val;
|
||||
}}
|
||||
rms_scale_shared = rsqrtf(sum_sq / (float)d_head + eps);
|
||||
}}
|
||||
__syncthreads();
|
||||
float rms_scale = rms_scale_shared;
|
||||
|
||||
// Step 2: Apply RMS norm + weight and RoPE (interleaved format)
|
||||
for (int k = t; k < half; k += blockDim.x) {{
|
||||
// Interleaved indexing: pairs are (0,1), (2,3), (4,5), ...
|
||||
const int j_even = 2 * k;
|
||||
const int j_odd = 2 * k + 1;
|
||||
|
||||
// Apply RMS norm and weight (weights have 1+w pre-applied in setup.py)
|
||||
float x_even = head_in[j_even] * rms_scale * weight[j_even];
|
||||
float x_odd = head_in[j_odd] * rms_scale * weight[j_odd];
|
||||
|
||||
// Compute RoPE rotation
|
||||
const float exponent = -(2.0f * (float)k) / (float)d_head;
|
||||
const float theta = (float)pos * __powf(base, exponent);
|
||||
|
||||
float s, c;
|
||||
__sincosf(theta, &s, &c);
|
||||
|
||||
// Rotation
|
||||
head_out[j_even] = x_even * c - x_odd * s;
|
||||
head_out[j_odd] = x_even * s + x_odd * c;
|
||||
}}
|
||||
__syncthreads();
|
||||
}}
|
||||
"#,
|
||||
HEAD_DIM = HEAD_DIM,
|
||||
RMS_NORM_EPS = RMS_NORM_EPS
|
||||
)
|
||||
}
|
||||
|
||||
fn build_payload<'a>(&self, _: &Arc<CudaStream>, payload: CStruct<'a>) -> CStruct<'a> {
|
||||
payload
|
||||
.expr("inp", flatten_mul_strides(&self.range, &self.inp_stride))
|
||||
.expr("out", flatten_mul_strides(&self.range, &self.inp_stride))
|
||||
.expr("row_width", self.row_width)
|
||||
.expr("weight", 0)
|
||||
.expr("token_ids", 'z')
|
||||
}
|
||||
is_local: bool, // true for sliding window attention, false for global
|
||||
rope_theta: f32, // RoPE base frequency (different for local vs global attention)
|
||||
}
|
||||
|
||||
impl GemmaLayer {
|
||||
@@ -313,23 +261,21 @@ impl GemmaLayer {
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
|
||||
// 3. Apply QK-Norm + RoPE using fused custom kernel
|
||||
let q_rope = x.graph().custom_op(
|
||||
GemmaQKNormRoPE::new(q.dims()[0], q.dims()[1]),
|
||||
(q, self.q_norm, pos_ids),
|
||||
q.shape,
|
||||
q.dtype,
|
||||
);
|
||||
let k_rope = x.graph().custom_op(
|
||||
GemmaQKNormRoPE::new(k.dims()[0], k.dims()[1]),
|
||||
(k, self.k_norm, pos_ids),
|
||||
k.shape,
|
||||
k.dtype,
|
||||
);
|
||||
// 3. Apply QK-Norm + RoPE using HLIR operations
|
||||
// Use different RoPE base frequency for local vs global attention
|
||||
let q_rope = gemma_qk_norm_rope(q, self.q_norm, pos_ids, self.rope_theta);
|
||||
let k_rope = gemma_qk_norm_rope(k, self.k_norm, pos_ids, self.rope_theta);
|
||||
|
||||
// 4. Attention
|
||||
// Use sliding window for local attention layers, global for others
|
||||
let attn_out = x.graph().custom_op(
|
||||
GemmaAttention::new(k_cache, v_cache, q_rope.dims()[0], 'p'.into()),
|
||||
GemmaAttention::new(
|
||||
k_cache,
|
||||
v_cache,
|
||||
q_rope.dims()[0],
|
||||
'p'.into(),
|
||||
self.is_local,
|
||||
),
|
||||
(q_rope, k_rope, v),
|
||||
q_rope.shape,
|
||||
q_rope.dtype,
|
||||
@@ -398,10 +344,18 @@ pub struct GemmaAttention {
|
||||
prev_seq: Expression,
|
||||
k_cache: u64,
|
||||
v_cache: u64,
|
||||
sliding_window: usize, // 0 for global attention, >0 for local sliding window
|
||||
}
|
||||
|
||||
impl GemmaAttention {
|
||||
fn new(k_cache: u64, v_cache: u64, seq: Expression, prev_seq: Expression) -> Self {
|
||||
fn new(
|
||||
k_cache: u64,
|
||||
v_cache: u64,
|
||||
seq: Expression,
|
||||
prev_seq: Expression,
|
||||
is_local: bool,
|
||||
) -> Self {
|
||||
let sliding_window = if is_local { SLIDING_WINDOW_SIZE } else { 0 };
|
||||
Self {
|
||||
range: (N_KV_HEADS, KV_GROUPS, seq).to_shape(),
|
||||
head_dim: HEAD_DIM.into(),
|
||||
@@ -414,6 +368,7 @@ impl GemmaAttention {
|
||||
prev_seq,
|
||||
k_cache,
|
||||
v_cache,
|
||||
sliding_window,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -481,6 +436,7 @@ impl BlockOp for GemmaAttention {
|
||||
"head_pos_stride",
|
||||
flatten_mul_strides(&self.range, &head_pos_stride),
|
||||
)
|
||||
.int("sliding_window", self.sliding_window as i32)
|
||||
}
|
||||
|
||||
fn cuda_function(&self) -> String {
|
||||
@@ -526,6 +482,9 @@ impl BlockOp for GemmaAttention {
|
||||
const int kv_row_stride = eval_expression(payload.kv_row_stride, 0);
|
||||
const int prev = eval_expression(payload.prev_seq, 0);
|
||||
|
||||
// Sliding window configuration (0 = global attention, >0 = local sliding window)
|
||||
const int sliding_window = payload.sliding_window;
|
||||
|
||||
const float* __restrict__ K_cur = k;
|
||||
const float* __restrict__ V_cur = v;
|
||||
float* __restrict__ O = out;
|
||||
@@ -535,6 +494,13 @@ impl BlockOp for GemmaAttention {
|
||||
|
||||
const int q_pos_total = prev + q_pos_local;
|
||||
|
||||
// For sliding window attention, compute the start position
|
||||
// We only attend to positions within [q_pos_total - sliding_window + 1, q_pos_total]
|
||||
int attn_start = 0;
|
||||
if (sliding_window > 0 && q_pos_total >= sliding_window) {
|
||||
attn_start = q_pos_total - sliding_window + 1;
|
||||
}
|
||||
|
||||
const float scale = rsqrtf((float)d);
|
||||
|
||||
__shared__ float max_l_shared;
|
||||
@@ -559,7 +525,8 @@ impl BlockOp for GemmaAttention {
|
||||
if (t == 0) max_l_shared = -__int_as_float(0x7f800000);
|
||||
__syncthreads();
|
||||
|
||||
for (int r = 0; r <= q_pos_total; ++r) {
|
||||
// First pass: find max for numerical stability
|
||||
for (int r = attn_start; r <= q_pos_total; ++r) {
|
||||
const float* __restrict__ k_row;
|
||||
if (r < prev) {
|
||||
k_row = K_cache + r * kv_row_stride;
|
||||
@@ -591,7 +558,8 @@ impl BlockOp for GemmaAttention {
|
||||
float s_local = 0.0f;
|
||||
__syncthreads();
|
||||
|
||||
for (int r = 0; r <= q_pos_total; ++r) {
|
||||
// Second pass: compute softmax and weighted sum
|
||||
for (int r = attn_start; r <= q_pos_total; ++r) {
|
||||
const float* __restrict__ k_row;
|
||||
const float* __restrict__ v_row;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user