rowembed op

This commit is contained in:
Joe Fioti
2026-01-20 19:12:47 +00:00
parent 8f6e5aaae2
commit 3057849501

View File

@@ -17,6 +17,7 @@ pub type Ops = (
RowRMSNorm,
RowRope,
TileMatmulFullSplit,
RowEmbed, // Working but with slight numerical differences - needs investigation
);
#[derive(Debug, Default)]
@@ -1967,3 +1968,203 @@ impl CStruct {
self
}
}
#[derive(Debug, Default)]
pub struct RowEmbed {
range: Vec<Expression>, // batch dimensions (e.g., [s] for sequence length)
token_stride: Vec<Expression>, // stride for token_ids input
out_stride: Vec<Expression>, // stride for output
embed_dim: Expression, // embedding dimension (e.g., HIDDEN)
}
impl EgglogOp for RowEmbed {
fn term(&self) -> (String, Vec<OpParam>) {
(
"RowEmbed".to_string(),
vec![EList, Input, EList, Input, EList, Expr],
)
}
fn rewrites(&self) -> Vec<String> {
vec![
// Match Gather with Add(Mul(Cast(token_ids), const), Iota) indices
"(rule
(
(= ?gather (Gather ?indices ?idx_shape ?idx_stride ?embed_table ?embed_shape ?embed_stride))
(= ?indices (Add ?add_shape ?mul_result ?mul_stride ?iota_result ?iota_stride ?add_out_stride))
(= ?mul_result (Mul ?mul_shape ?token_ids_cast ?token_cast_stride ?mul_const ?mul_const_stride ?mul_out_stride))
(= ?token_ids_cast (Cast ?token_ids ?cast_dtype))
(= ?embed_dim (nth_from_end ?embed_shape 0))
(= ?batch_shape (RemoveNthFromEnd ?idx_shape 0))
(= ?out_stride_batch (RemoveNthFromEnd ?add_out_stride 0))
)
(
(let ?re (RowEmbed ?batch_shape ?token_ids ?token_cast_stride ?embed_table ?out_stride_batch ?embed_dim))
(union ?gather ?re)
(set (dtype ?re) (F32))
)
:name \"row embed with cast mul\"
)".to_string(),
// Match Gather with Add(Iota, Mul(Cast(token_ids), const)) indices (reversed order)
"(rule
(
(= ?gather (Gather ?indices ?idx_shape ?idx_stride ?embed_table ?embed_shape ?embed_stride))
(= ?indices (Add ?add_shape ?iota_result ?iota_stride ?mul_result ?mul_stride ?add_out_stride))
(= ?mul_result (Mul ?mul_shape ?token_ids_cast ?token_cast_stride ?mul_const ?mul_const_stride ?mul_out_stride))
(= ?token_ids_cast (Cast ?token_ids ?cast_dtype))
(= ?embed_dim (nth_from_end ?embed_shape 0))
(= ?batch_shape (RemoveNthFromEnd ?idx_shape 0))
(= ?out_stride_batch (RemoveNthFromEnd ?add_out_stride 0))
)
(
(let ?re (RowEmbed ?batch_shape ?token_ids ?token_cast_stride ?embed_table ?out_stride_batch ?embed_dim))
(union ?gather ?re)
(set (dtype ?re) (F32))
)
:name \"row embed with cast mul reversed\"
)".to_string(),
// Match Gather with Add(Mul(token_ids, const), Iota) indices (no Cast)
"(rule
(
(= ?gather (Gather ?indices ?idx_shape ?idx_stride ?embed_table ?embed_shape ?embed_stride))
(= ?indices (Add ?add_shape ?mul_result ?mul_stride ?iota_result ?iota_stride ?add_out_stride))
(= ?mul_result (Mul ?mul_shape ?token_ids ?token_stride ?mul_const ?mul_const_stride ?mul_out_stride))
(= ?embed_dim (nth_from_end ?embed_shape 0))
(= ?batch_shape (RemoveNthFromEnd ?idx_shape 0))
(= ?out_stride_batch (RemoveNthFromEnd ?add_out_stride 0))
)
(
(let ?re (RowEmbed ?batch_shape ?token_ids ?token_stride ?embed_table ?out_stride_batch ?embed_dim))
(union ?gather ?re)
(set (dtype ?re) (F32))
)
:name \"row embed with mul\"
)".to_string(),
// Match Gather with Add(Iota, Mul(token_ids, const)) indices (reversed order, no Cast)
"(rule
(
(= ?gather (Gather ?indices ?idx_shape ?idx_stride ?embed_table ?embed_shape ?embed_stride))
(= ?indices (Add ?add_shape ?iota_result ?iota_stride ?mul_result ?mul_stride ?add_out_stride))
(= ?mul_result (Mul ?mul_shape ?token_ids ?token_stride ?mul_const ?mul_const_stride ?mul_out_stride))
(= ?embed_dim (nth_from_end ?embed_shape 0))
(= ?batch_shape (RemoveNthFromEnd ?idx_shape 0))
(= ?out_stride_batch (RemoveNthFromEnd ?add_out_stride 0))
)
(
(let ?re (RowEmbed ?batch_shape ?token_ids ?token_stride ?embed_table ?out_stride_batch ?embed_dim))
(union ?gather ?re)
(set (dtype ?re) (F32))
)
:name \"row embed with mul reversed\"
)".to_string(),
]
}
fn cleanup(&self) -> bool {
false
}
fn extract<'a>(
&self,
egraph: &'a SerializedEGraph,
children: &[&'a ENodeId],
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn BlockOp>(Box::new(Self {
range: extract_expr_list(egraph, children[0], list_cache, expr_cache).unwrap(),
token_stride: extract_expr_list(egraph, children[2], list_cache, expr_cache).unwrap(),
out_stride: extract_expr_list(egraph, children[4], list_cache, expr_cache).unwrap(),
embed_dim: extract_expr(egraph, children[5], expr_cache).unwrap(),
})),
vec![children[1], children[3]], // token_ids, embedding_table
)
}
}
impl BlockOp for RowEmbed {
fn op_name(&self) -> &'static str {
"RowEmbed"
}
fn launch_range(&self) -> Vec<Expression> {
if self.range.is_empty() {
vec![1.into()]
} else {
self.range.clone()
}
}
fn output_size(&self) -> Expression {
self.range.iter().copied().product::<Expression>().max(1) * self.embed_dim
}
fn producer_barriers_seperate(&self) -> Vec<bool> {
vec![true; self.range.len()]
}
fn consumer_barriers_seperate(&self) -> Vec<Vec<bool>> {
vec![vec![true; self.range.len()], vec![true; self.range.len()]]
}
fn bytes_loaded(&self) -> Expression {
// Load: 1 token ID (4 bytes) + 1 embedding row (embed_dim * 4 bytes)
self.range.iter().copied().product::<Expression>().max(1) * (4 + self.embed_dim * 4)
}
fn bytes_stored(&self) -> Expression {
// Store: 1 embedding row per launch
self.range.iter().copied().product::<Expression>().max(1) * self.embed_dim * 4
}
fn flops(&self) -> Expression {
// No FLOPs - just memory copy
0.into()
}
fn cuda_struct(&self) -> String {
"const int token_stride; const int out_stride; int embed_dim;".to_string()
}
fn cuda_function(&self) -> String {
"
int embed_dim = eval_expression(payload.embed_dim, 0);
// Get stride offsets
int token_offset = eval_expression(payload.token_stride, current);
int out_offset = eval_expression(payload.out_stride, current);
// Get pointers
const int* token_ids = (const int*)(source_ptrs[0]) + token_offset;
const float* embed_table = source_ptrs[1];
float* out_row = out_ptr + out_offset;
// Read token ID (stored as int)
int token_id = token_ids[0];
// Lookup and copy embedding row
const float* embed_row = embed_table + (long long)token_id * embed_dim;
for (int i = t; i < embed_dim; i += blockDim.x) {
out_row[i] = embed_row[i];
}
"
.to_string()
}
fn schedule_op(&self, _: &CudaStream, expressions: &FxHashMap<Expression, i32>) -> Vec<u8> {
CStruct::new()
.int(expressions[&flatten_mul_strides(&self.range, &self.token_stride)])
.int(expressions[&flatten_mul_strides(&self.range, &self.out_stride)])
.int(expressions[&self.embed_dim])
.finish_struct()
}
fn expressions(&self) -> Vec<Expression> {
vec![
flatten_mul_strides(&self.range, &self.token_stride),
flatten_mul_strides(&self.range, &self.out_stride),
self.embed_dim,
]
}
}