mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
rowembed op
This commit is contained in:
@@ -17,6 +17,7 @@ pub type Ops = (
|
||||
RowRMSNorm,
|
||||
RowRope,
|
||||
TileMatmulFullSplit,
|
||||
RowEmbed, // Working but with slight numerical differences - needs investigation
|
||||
);
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
@@ -1967,3 +1968,203 @@ impl CStruct {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct RowEmbed {
|
||||
range: Vec<Expression>, // batch dimensions (e.g., [s] for sequence length)
|
||||
token_stride: Vec<Expression>, // stride for token_ids input
|
||||
out_stride: Vec<Expression>, // stride for output
|
||||
embed_dim: Expression, // embedding dimension (e.g., HIDDEN)
|
||||
}
|
||||
|
||||
impl EgglogOp for RowEmbed {
|
||||
fn term(&self) -> (String, Vec<OpParam>) {
|
||||
(
|
||||
"RowEmbed".to_string(),
|
||||
vec![EList, Input, EList, Input, EList, Expr],
|
||||
)
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<String> {
|
||||
vec![
|
||||
// Match Gather with Add(Mul(Cast(token_ids), const), Iota) indices
|
||||
"(rule
|
||||
(
|
||||
(= ?gather (Gather ?indices ?idx_shape ?idx_stride ?embed_table ?embed_shape ?embed_stride))
|
||||
(= ?indices (Add ?add_shape ?mul_result ?mul_stride ?iota_result ?iota_stride ?add_out_stride))
|
||||
(= ?mul_result (Mul ?mul_shape ?token_ids_cast ?token_cast_stride ?mul_const ?mul_const_stride ?mul_out_stride))
|
||||
(= ?token_ids_cast (Cast ?token_ids ?cast_dtype))
|
||||
(= ?embed_dim (nth_from_end ?embed_shape 0))
|
||||
(= ?batch_shape (RemoveNthFromEnd ?idx_shape 0))
|
||||
(= ?out_stride_batch (RemoveNthFromEnd ?add_out_stride 0))
|
||||
)
|
||||
(
|
||||
(let ?re (RowEmbed ?batch_shape ?token_ids ?token_cast_stride ?embed_table ?out_stride_batch ?embed_dim))
|
||||
(union ?gather ?re)
|
||||
(set (dtype ?re) (F32))
|
||||
)
|
||||
:name \"row embed with cast mul\"
|
||||
)".to_string(),
|
||||
// Match Gather with Add(Iota, Mul(Cast(token_ids), const)) indices (reversed order)
|
||||
"(rule
|
||||
(
|
||||
(= ?gather (Gather ?indices ?idx_shape ?idx_stride ?embed_table ?embed_shape ?embed_stride))
|
||||
(= ?indices (Add ?add_shape ?iota_result ?iota_stride ?mul_result ?mul_stride ?add_out_stride))
|
||||
(= ?mul_result (Mul ?mul_shape ?token_ids_cast ?token_cast_stride ?mul_const ?mul_const_stride ?mul_out_stride))
|
||||
(= ?token_ids_cast (Cast ?token_ids ?cast_dtype))
|
||||
(= ?embed_dim (nth_from_end ?embed_shape 0))
|
||||
(= ?batch_shape (RemoveNthFromEnd ?idx_shape 0))
|
||||
(= ?out_stride_batch (RemoveNthFromEnd ?add_out_stride 0))
|
||||
)
|
||||
(
|
||||
(let ?re (RowEmbed ?batch_shape ?token_ids ?token_cast_stride ?embed_table ?out_stride_batch ?embed_dim))
|
||||
(union ?gather ?re)
|
||||
(set (dtype ?re) (F32))
|
||||
)
|
||||
:name \"row embed with cast mul reversed\"
|
||||
)".to_string(),
|
||||
// Match Gather with Add(Mul(token_ids, const), Iota) indices (no Cast)
|
||||
"(rule
|
||||
(
|
||||
(= ?gather (Gather ?indices ?idx_shape ?idx_stride ?embed_table ?embed_shape ?embed_stride))
|
||||
(= ?indices (Add ?add_shape ?mul_result ?mul_stride ?iota_result ?iota_stride ?add_out_stride))
|
||||
(= ?mul_result (Mul ?mul_shape ?token_ids ?token_stride ?mul_const ?mul_const_stride ?mul_out_stride))
|
||||
(= ?embed_dim (nth_from_end ?embed_shape 0))
|
||||
(= ?batch_shape (RemoveNthFromEnd ?idx_shape 0))
|
||||
(= ?out_stride_batch (RemoveNthFromEnd ?add_out_stride 0))
|
||||
)
|
||||
(
|
||||
(let ?re (RowEmbed ?batch_shape ?token_ids ?token_stride ?embed_table ?out_stride_batch ?embed_dim))
|
||||
(union ?gather ?re)
|
||||
(set (dtype ?re) (F32))
|
||||
)
|
||||
:name \"row embed with mul\"
|
||||
)".to_string(),
|
||||
// Match Gather with Add(Iota, Mul(token_ids, const)) indices (reversed order, no Cast)
|
||||
"(rule
|
||||
(
|
||||
(= ?gather (Gather ?indices ?idx_shape ?idx_stride ?embed_table ?embed_shape ?embed_stride))
|
||||
(= ?indices (Add ?add_shape ?iota_result ?iota_stride ?mul_result ?mul_stride ?add_out_stride))
|
||||
(= ?mul_result (Mul ?mul_shape ?token_ids ?token_stride ?mul_const ?mul_const_stride ?mul_out_stride))
|
||||
(= ?embed_dim (nth_from_end ?embed_shape 0))
|
||||
(= ?batch_shape (RemoveNthFromEnd ?idx_shape 0))
|
||||
(= ?out_stride_batch (RemoveNthFromEnd ?add_out_stride 0))
|
||||
)
|
||||
(
|
||||
(let ?re (RowEmbed ?batch_shape ?token_ids ?token_stride ?embed_table ?out_stride_batch ?embed_dim))
|
||||
(union ?gather ?re)
|
||||
(set (dtype ?re) (F32))
|
||||
)
|
||||
:name \"row embed with mul reversed\"
|
||||
)".to_string(),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
children: &[&'a ENodeId],
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn BlockOp>(Box::new(Self {
|
||||
range: extract_expr_list(egraph, children[0], list_cache, expr_cache).unwrap(),
|
||||
token_stride: extract_expr_list(egraph, children[2], list_cache, expr_cache).unwrap(),
|
||||
out_stride: extract_expr_list(egraph, children[4], list_cache, expr_cache).unwrap(),
|
||||
embed_dim: extract_expr(egraph, children[5], expr_cache).unwrap(),
|
||||
})),
|
||||
vec![children[1], children[3]], // token_ids, embedding_table
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl BlockOp for RowEmbed {
|
||||
fn op_name(&self) -> &'static str {
|
||||
"RowEmbed"
|
||||
}
|
||||
|
||||
fn launch_range(&self) -> Vec<Expression> {
|
||||
if self.range.is_empty() {
|
||||
vec![1.into()]
|
||||
} else {
|
||||
self.range.clone()
|
||||
}
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.range.iter().copied().product::<Expression>().max(1) * self.embed_dim
|
||||
}
|
||||
|
||||
fn producer_barriers_seperate(&self) -> Vec<bool> {
|
||||
vec![true; self.range.len()]
|
||||
}
|
||||
|
||||
fn consumer_barriers_seperate(&self) -> Vec<Vec<bool>> {
|
||||
vec![vec![true; self.range.len()], vec![true; self.range.len()]]
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// Load: 1 token ID (4 bytes) + 1 embedding row (embed_dim * 4 bytes)
|
||||
self.range.iter().copied().product::<Expression>().max(1) * (4 + self.embed_dim * 4)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
// Store: 1 embedding row per launch
|
||||
self.range.iter().copied().product::<Expression>().max(1) * self.embed_dim * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// No FLOPs - just memory copy
|
||||
0.into()
|
||||
}
|
||||
|
||||
fn cuda_struct(&self) -> String {
|
||||
"const int token_stride; const int out_stride; int embed_dim;".to_string()
|
||||
}
|
||||
|
||||
fn cuda_function(&self) -> String {
|
||||
"
|
||||
int embed_dim = eval_expression(payload.embed_dim, 0);
|
||||
|
||||
// Get stride offsets
|
||||
int token_offset = eval_expression(payload.token_stride, current);
|
||||
int out_offset = eval_expression(payload.out_stride, current);
|
||||
|
||||
// Get pointers
|
||||
const int* token_ids = (const int*)(source_ptrs[0]) + token_offset;
|
||||
const float* embed_table = source_ptrs[1];
|
||||
float* out_row = out_ptr + out_offset;
|
||||
|
||||
// Read token ID (stored as int)
|
||||
int token_id = token_ids[0];
|
||||
|
||||
// Lookup and copy embedding row
|
||||
const float* embed_row = embed_table + (long long)token_id * embed_dim;
|
||||
for (int i = t; i < embed_dim; i += blockDim.x) {
|
||||
out_row[i] = embed_row[i];
|
||||
}
|
||||
"
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn schedule_op(&self, _: &CudaStream, expressions: &FxHashMap<Expression, i32>) -> Vec<u8> {
|
||||
CStruct::new()
|
||||
.int(expressions[&flatten_mul_strides(&self.range, &self.token_stride)])
|
||||
.int(expressions[&flatten_mul_strides(&self.range, &self.out_stride)])
|
||||
.int(expressions[&self.embed_dim])
|
||||
.finish_struct()
|
||||
}
|
||||
|
||||
fn expressions(&self) -> Vec<Expression> {
|
||||
vec![
|
||||
flatten_mul_strides(&self.range, &self.token_stride),
|
||||
flatten_mul_strides(&self.range, &self.out_stride),
|
||||
self.embed_dim,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user