mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
optimized rowadd op
This commit is contained in:
@@ -19,14 +19,14 @@ pub struct RowAdd {
|
||||
a_stride: Vec<Expression>,
|
||||
b_stride: Vec<Expression>,
|
||||
out_stride: Vec<Expression>,
|
||||
row_width: Expression,
|
||||
row_width: usize,
|
||||
}
|
||||
|
||||
impl EgglogOp for RowAdd {
|
||||
fn term(&self) -> (String, Vec<OpParam>) {
|
||||
(
|
||||
"RowAdd".to_string(),
|
||||
vec![EList, Input, EList, Input, EList, EList, Expr],
|
||||
vec![EList, Input, EList, Input, EList, EList, Int],
|
||||
)
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ impl EgglogOp for RowAdd {
|
||||
(
|
||||
; get add
|
||||
(= ?sa (Add ?shape ?a ?a_stride ?b ?b_stride ?out_stride))
|
||||
(= ?row_width (nth_from_end ?shape 0))
|
||||
(= (MNum ?row_width) (nth_from_end ?shape 0))
|
||||
; assert the row is contiguous
|
||||
(= (MNum 1) (nth_from_end ?a_stride 0))
|
||||
(= (MNum 1) (nth_from_end ?b_stride 0))
|
||||
@@ -72,7 +72,7 @@ impl EgglogOp for RowAdd {
|
||||
a_stride: extract_expr_list(egraph, children[2], list_cache, expr_cache).unwrap(),
|
||||
b_stride: extract_expr_list(egraph, children[4], list_cache, expr_cache).unwrap(),
|
||||
out_stride: extract_expr_list(egraph, children[5], list_cache, expr_cache).unwrap(),
|
||||
row_width: extract_expr(egraph, children[6], expr_cache).unwrap(),
|
||||
row_width: egraph.enodes[&children[6]].0.parse().unwrap(),
|
||||
})),
|
||||
vec![children[1], children[3]],
|
||||
)
|
||||
@@ -85,11 +85,8 @@ impl BlockOp for RowAdd {
|
||||
}
|
||||
|
||||
fn launch_range(&self) -> Vec<Expression> {
|
||||
if self.range.is_empty() {
|
||||
vec![1.into()]
|
||||
} else {
|
||||
self.range.clone()
|
||||
}
|
||||
// Single iteration - process all elements at once to minimize interpreter overhead
|
||||
vec![1.into()]
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
@@ -97,16 +94,17 @@ impl BlockOp for RowAdd {
|
||||
}
|
||||
|
||||
fn consumer_barriers_seperate(&self) -> Vec<Vec<bool>> {
|
||||
vec![vec![true; self.range.len()], vec![true; self.range.len()]]
|
||||
// All inputs must be fully ready before we start (single iteration)
|
||||
vec![vec![false; self.range.len()], vec![false; self.range.len()]]
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// Load 2 input rows (a + b) per launch
|
||||
// Load 2 input arrays (a + b)
|
||||
self.range.iter().copied().product::<Expression>().max(1) * self.row_width * 2 * 4
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
// Store 1 output row per launch
|
||||
// Store 1 output array
|
||||
self.range.iter().copied().product::<Expression>().max(1) * self.row_width * 4
|
||||
}
|
||||
|
||||
@@ -116,18 +114,57 @@ impl BlockOp for RowAdd {
|
||||
}
|
||||
|
||||
fn cuda_op(&self) -> (String, String) {
|
||||
let struct_body =
|
||||
"const int a_strides; const int b_strides; const int out_strides; int row_width;"
|
||||
.to_string();
|
||||
let function_body = "
|
||||
const float* a = source_ptrs[0] + eval_expression(payload.a_strides, current);
|
||||
const float* b = source_ptrs[1] + eval_expression(payload.b_strides, current);
|
||||
float* out = out_ptr + eval_expression(payload.out_strides, current);
|
||||
let struct_body = "const int a; const int b; const int out; int row_width;".to_string();
|
||||
let function_body = r#"
|
||||
// Process all elements in a single iteration (no per-row overhead)
|
||||
const float* __restrict__ a_base = source_ptrs[0] + (current == 0 ? 0 : eval_expression(payload.a, current));
|
||||
const float* __restrict__ b_base = source_ptrs[1] + (current == 0 ? 0 : eval_expression(payload.b, current));
|
||||
float* __restrict__ out_base = out_ptr + (current == 0 ? 0 : eval_expression(payload.out, current));
|
||||
|
||||
for (int idx = t; idx < eval_expression(payload.row_width, 0); idx += blockDim.x) {
|
||||
out[idx] = a[idx] + b[idx];
|
||||
|
||||
// Use float4 for vectorized access when aligned
|
||||
const bool aligned16 =
|
||||
((((unsigned long long)a_base |
|
||||
(unsigned long long)b_base |
|
||||
(unsigned long long)out_base) & 0xFULL) == 0ULL);
|
||||
|
||||
if (aligned16 && (payload.row_width % 4 == 0)) {
|
||||
const float4* __restrict__ a4 = reinterpret_cast<const float4*>(a_base);
|
||||
const float4* __restrict__ b4 = reinterpret_cast<const float4*>(b_base);
|
||||
float4* __restrict__ out4 = reinterpret_cast<float4*>(out_base);
|
||||
|
||||
const int n_vec = payload.row_width >> 2;
|
||||
const int stride = blockDim.x;
|
||||
|
||||
// Process 4 float4 vectors per thread per iteration for ILP
|
||||
int i = t;
|
||||
for (; i + 3 * stride < n_vec; i += 4 * stride) {
|
||||
float4 va0 = a4[i];
|
||||
float4 va1 = a4[i + stride];
|
||||
float4 va2 = a4[i + 2 * stride];
|
||||
float4 va3 = a4[i + 3 * stride];
|
||||
float4 vb0 = b4[i];
|
||||
float4 vb1 = b4[i + stride];
|
||||
float4 vb2 = b4[i + 2 * stride];
|
||||
float4 vb3 = b4[i + 3 * stride];
|
||||
out4[i] = make_float4(va0.x + vb0.x, va0.y + vb0.y, va0.z + vb0.z, va0.w + vb0.w);
|
||||
out4[i + stride] = make_float4(va1.x + vb1.x, va1.y + vb1.y, va1.z + vb1.z, va1.w + vb1.w);
|
||||
out4[i + 2 * stride] = make_float4(va2.x + vb2.x, va2.y + vb2.y, va2.z + vb2.z, va2.w + vb2.w);
|
||||
out4[i + 3 * stride] = make_float4(va3.x + vb3.x, va3.y + vb3.y, va3.z + vb3.z, va3.w + vb3.w);
|
||||
}
|
||||
// Handle remainder
|
||||
for (; i < n_vec; i += stride) {
|
||||
float4 va = a4[i];
|
||||
float4 vb = b4[i];
|
||||
out4[i] = make_float4(va.x + vb.x, va.y + vb.y, va.z + vb.z, va.w + vb.w);
|
||||
}
|
||||
} else {
|
||||
// Scalar fallback with coalesced strided access
|
||||
for (int i = t; i < payload.row_width; i += blockDim.x) {
|
||||
out_base[i] = a_base[i] + b_base[i];
|
||||
}
|
||||
}
|
||||
"
|
||||
"#
|
||||
.to_string();
|
||||
(struct_body, function_body)
|
||||
}
|
||||
@@ -137,17 +174,12 @@ impl BlockOp for RowAdd {
|
||||
.int(expressions[&flatten_mul_strides(&self.range, &self.a_stride)])
|
||||
.int(expressions[&flatten_mul_strides(&self.range, &self.b_stride)])
|
||||
.int(expressions[&flatten_mul_strides(&self.range, &self.out_stride)])
|
||||
.int(expressions[&self.row_width])
|
||||
.int(self.row_width as i32)
|
||||
.finish_struct()
|
||||
}
|
||||
|
||||
fn expressions(&self) -> Vec<Expression> {
|
||||
vec![
|
||||
flatten_mul_strides(&self.range, &self.a_stride),
|
||||
flatten_mul_strides(&self.range, &self.b_stride),
|
||||
flatten_mul_strides(&self.range, &self.out_stride),
|
||||
self.row_width,
|
||||
]
|
||||
vec![self.range.iter().copied().product::<Expression>().max(1) * self.row_width]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -630,7 +630,7 @@ impl Runtime for CudaRuntime {
|
||||
// Launch kernel
|
||||
let cfg = LaunchConfig {
|
||||
grid_dim: (sm_count as u32, 1, 1), // One block per SM
|
||||
block_dim: (1024, 1, 1), // 1024 threads (32 warps) per block
|
||||
block_dim: (256, 1, 1), // 1024 threads (32 warps) per block
|
||||
shared_mem_bytes: (shared_mem_max / 2) as u32,
|
||||
};
|
||||
let mut lb = self.cuda_stream.launch_builder(interpreter);
|
||||
|
||||
@@ -12,7 +12,7 @@ use tracing::{span, Level};
|
||||
fn main() {
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 5;
|
||||
let search_graphs = 10; // the number of graphs we want to search during compilation
|
||||
let search_graphs = 5; // the number of graphs we want to search during compilation
|
||||
let prompt = "Hello, how are you";
|
||||
|
||||
// Set up tracing to perfetto
|
||||
|
||||
Reference in New Issue
Block a user