mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
2
crates/luminal_cuda/.cargo/config.toml
Normal file
2
crates/luminal_cuda/.cargo/config.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
[env]
|
||||
RUST_TEST_THREADS = "1"
|
||||
@@ -24,4 +24,6 @@ memmap2 = "0.9.9"
|
||||
uuid = {version="1.19.0", features=["v4"]}
|
||||
|
||||
[dev-dependencies]
|
||||
candle-core = "0.9.1"
|
||||
proptest = "1.9.0"
|
||||
rand = "0.9.2"
|
||||
|
||||
@@ -37,6 +37,7 @@ pub trait BlockOp: Debug + as_any::AsAny {
|
||||
fn launch_range(&self) -> Vec<Expression> {
|
||||
unimplemented!()
|
||||
}
|
||||
/// Returns the output buffer size in elements.
|
||||
fn output_size(&self) -> Expression {
|
||||
unimplemented!()
|
||||
}
|
||||
@@ -46,6 +47,21 @@ pub trait BlockOp: Debug + as_any::AsAny {
|
||||
fn cuda_op(&self) -> (String, String) {
|
||||
("".to_string(), "".to_string())
|
||||
} // C dtype, C function
|
||||
|
||||
/// Returns the number of bytes this op will load from global memory.
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
0.into()
|
||||
}
|
||||
|
||||
/// Returns the number of bytes this op will store to global memory.
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
0.into()
|
||||
}
|
||||
|
||||
/// Returns the number of floating point operations this op performs.
|
||||
fn flops(&self) -> Expression {
|
||||
0.into()
|
||||
}
|
||||
#[allow(clippy::mutable_key_type)]
|
||||
fn schedule_op(
|
||||
&self,
|
||||
|
||||
@@ -100,6 +100,21 @@ impl BlockOp for RowAdd {
|
||||
vec![vec![true; self.range.len()], vec![true; self.range.len()]]
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// Load 2 input rows (a + b) per launch
|
||||
self.range.iter().copied().product::<Expression>().max(1) * self.row_width * 2 * 4
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
// Store 1 output row per launch
|
||||
self.range.iter().copied().product::<Expression>().max(1) * self.row_width * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// 1 add per element
|
||||
self.range.iter().copied().product::<Expression>().max(1) * self.row_width
|
||||
}
|
||||
|
||||
fn cuda_op(&self) -> (String, String) {
|
||||
let struct_body =
|
||||
"const int a_strides; const int b_strides; const int out_strides; int row_width;"
|
||||
@@ -236,6 +251,22 @@ impl BlockOp for RowSwishMul {
|
||||
vec![vec![true; self.range.len()], vec![true; self.range.len()]]
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// Load 2 input rows (a + b) per launch
|
||||
self.range.iter().copied().product::<Expression>() * self.row_width * 2 * 4
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
// Store 1 output row per launch
|
||||
self.range.iter().copied().product::<Expression>() * self.row_width * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// swish(x) * b[idx] = x / (1 + exp(-x)) * b
|
||||
// ~5 ops per element: neg, exp, add, div, mul
|
||||
self.range.iter().copied().product::<Expression>() * self.row_width * 5
|
||||
}
|
||||
|
||||
fn cuda_op(&self) -> (String, String) {
|
||||
let struct_body = "
|
||||
const int a;
|
||||
@@ -423,6 +454,22 @@ impl BlockOp for RowRMSNorm {
|
||||
vec![vec![true; self.range.len()], vec![true; self.range.len()]]
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// Load input row + weight row per launch
|
||||
self.range.iter().copied().product::<Expression>() * self.row_width * 2 * 4
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
// Store 1 output row per launch
|
||||
self.range.iter().copied().product::<Expression>() * self.row_width * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// Per row: d squares, d-1 adds for sum, div by d, add eps, sqrt, recip, then 2d muls (inp * inv_rms * weight)
|
||||
// Approximate: 5*d ops per row
|
||||
self.range.iter().copied().product::<Expression>() * self.row_width * 5
|
||||
}
|
||||
|
||||
fn cuda_op(&self) -> (String, String) {
|
||||
let struct_body = "
|
||||
const int inp;
|
||||
@@ -811,6 +858,22 @@ impl BlockOp for RowRope {
|
||||
vec![vec![true; self.range.len()], vec![true; self.range.len()]]
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// Load input row (row_width floats) + token_ids (1 int per row)
|
||||
self.range.iter().copied().product::<Expression>() * (self.row_width * 4 + 4)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
// Store 1 output row per launch
|
||||
self.range.iter().copied().product::<Expression>() * self.row_width * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// Per pair of elements: pow, sincos, 4 muls, 2 adds ≈ 10 ops
|
||||
// row_width/2 pairs per row
|
||||
self.range.iter().copied().product::<Expression>() * self.row_width * 5
|
||||
}
|
||||
|
||||
fn cuda_op(&self) -> (String, String) {
|
||||
let struct_body = "
|
||||
const int inp;
|
||||
@@ -1035,6 +1098,47 @@ impl BlockOp for TileMatmul {
|
||||
vec![a, b]
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
// Matmul C = A @ B where A is (M, K) and B is (K, N)
|
||||
// Loads: A (M * K) + B (K * N) floats
|
||||
// untiled_range[0] = M, untiled_range[1] = N, iters = K
|
||||
// Batch dimensions from range[0..len-2]
|
||||
let batch: Expression = if self.range.len() > 2 {
|
||||
self.range[..self.range.len() - 2].iter().copied().product()
|
||||
} else {
|
||||
1.into()
|
||||
};
|
||||
let m = self.untiled_range[0];
|
||||
let n = self.untiled_range[1];
|
||||
let k = self.iters;
|
||||
batch * (m * k + k * n) * 4
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
// Store C (M * N) floats
|
||||
let batch: Expression = if self.range.len() > 2 {
|
||||
self.range[..self.range.len() - 2].iter().copied().product()
|
||||
} else {
|
||||
1.into()
|
||||
};
|
||||
let m = self.untiled_range[0];
|
||||
let n = self.untiled_range[1];
|
||||
batch * m * n * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// Matmul FLOPs: 2 * M * N * K (one mul + one add per output element per K iteration)
|
||||
let batch: Expression = if self.range.len() > 2 {
|
||||
self.range[..self.range.len() - 2].iter().copied().product()
|
||||
} else {
|
||||
1.into()
|
||||
};
|
||||
let m = self.untiled_range[0];
|
||||
let n = self.untiled_range[1];
|
||||
let k = self.iters;
|
||||
batch * m * n * k * 2
|
||||
}
|
||||
|
||||
fn cuda_op(&self) -> (String, String) {
|
||||
let struct_body = "
|
||||
const int untiled_range[2];
|
||||
|
||||
@@ -22,7 +22,28 @@ pub trait KernelOp: luminal::op::EgglogOp {
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
);
|
||||
|
||||
/// Returns the output buffer size in elements.
|
||||
fn output_size(&self) -> Expression;
|
||||
|
||||
/// Returns the number of bytes this kernel will load from global memory.
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
0.into()
|
||||
}
|
||||
|
||||
/// Returns the number of bytes this kernel will store to global memory.
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
0.into()
|
||||
}
|
||||
|
||||
/// Returns the number of floating point operations this kernel performs.
|
||||
fn flops(&self) -> Expression {
|
||||
0.into()
|
||||
}
|
||||
|
||||
/// Returns the name of this kernel for profiling display.
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
luminal::impl_into_ops!(KernelOp);
|
||||
|
||||
@@ -192,6 +192,22 @@ extern \"C\" {{
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product::<Expression>() * self.iters * 4
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product::<Expression>() * self.iters
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"MaxReduce"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -366,6 +382,23 @@ extern \"C\" {{
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product::<Expression>() * self.iters * 4
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
let n_outputs: Expression = self.out_shape.iter().copied().product();
|
||||
n_outputs * self.iters + n_outputs
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"MeanReduce"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -536,6 +569,22 @@ extern \"C\" {{
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product::<Expression>() * self.iters * 4
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product::<Expression>() * self.iters
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"SumReduce"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -654,6 +703,22 @@ extern \"C\" {{
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_size() * 4 * 2
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Add"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -772,6 +837,22 @@ extern \"C\" {{
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_size() * 4 * 2
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Mul"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -894,6 +975,22 @@ extern \"C\" {{
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_size() * 4 * 2
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
0.into()
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Gather"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -993,4 +1090,20 @@ extern \"C\" {{
|
||||
fn output_size(&self) -> Expression {
|
||||
self.range
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
0.into()
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
0.into()
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Iota"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use crate::{block::*, kernel::KernelOp};
|
||||
use cudarc::driver::{CudaFunction, CudaSlice, CudaStream, DevicePtr, LaunchConfig, PushKernelArg};
|
||||
use cudarc::driver::{
|
||||
sys::CUevent_flags, CudaFunction, CudaSlice, CudaStream, DevicePtr, LaunchConfig, PushKernelArg,
|
||||
};
|
||||
use fixedbitset::FixedBitSet;
|
||||
use itertools::Itertools;
|
||||
use luminal::hlir::*;
|
||||
@@ -41,9 +43,56 @@ enum ExecutableKernel {
|
||||
inputs: Vec<NodeIndex>,
|
||||
output: NodeIndex,
|
||||
constants: FxHashMap<char, CudaSlice<u8>>,
|
||||
// Profiling metrics
|
||||
kernel_name: &'static str,
|
||||
bytes_loaded: Expression,
|
||||
bytes_stored: Expression,
|
||||
flops: Expression,
|
||||
},
|
||||
}
|
||||
|
||||
/// Statistics for a single kernel execution.
///
/// Derives `Default` and `PartialEq` in addition to `Debug`/`Clone` so that stats can be built
/// with struct-update syntax and compared directly in tests.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct KernelStats {
    /// Kernel name used for profiling display.
    pub name: &'static str,
    /// Measured execution time in microseconds.
    pub execution_time_us: f64,
    /// Bytes loaded from global memory, as reported by the kernel's cost model.
    pub bytes_loaded: usize,
    /// Bytes stored to global memory, as reported by the kernel's cost model.
    pub bytes_stored: usize,
    /// Number of floating point operations, as reported by the kernel's cost model.
    pub flops: usize,
    /// Achieved memory bandwidth in GB/s ((loaded + stored) / time).
    pub bandwidth_gbps: f64,
    /// Achieved compute throughput in TFLOPS (flops / time).
    pub tflops: f64,
}
|
||||
|
||||
/// Statistics for a single block op type, aggregated across all SMs.
///
/// Derives `Default` and `PartialEq` in addition to `Debug`/`Clone` so that stats can be built
/// with struct-update syntax and compared directly in tests.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct BlockOpStats {
    /// Block op name used for profiling display.
    pub name: &'static str,
    /// Execution time in microseconds.
    pub execution_time_us: f64,
    /// Bytes loaded from global memory, as reported by the op's cost model.
    pub bytes_loaded: usize,
    /// Bytes stored to global memory, as reported by the op's cost model.
    pub bytes_stored: usize,
    /// Number of floating point operations, as reported by the op's cost model.
    pub flops: usize,
    /// Achieved memory bandwidth in GB/s ((loaded + stored) / time).
    pub bandwidth_gbps: f64,
    /// Achieved compute throughput in TFLOPS (flops / time).
    pub tflops: f64,
}
|
||||
|
||||
/// Aggregated execution statistics
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct ExecutionStats {
|
||||
pub kernel_stats: Vec<KernelStats>,
|
||||
pub block_op_stats: Vec<BlockOpStats>,
|
||||
pub total_time_us: f64,
|
||||
/// Total bytes loaded across all ops
|
||||
pub total_bytes_loaded: usize,
|
||||
/// Total bytes stored across all ops
|
||||
pub total_bytes_stored: usize,
|
||||
/// Total floating point operations across all ops
|
||||
pub total_flops: usize,
|
||||
/// Aggregate bandwidth in GB/s (total bytes / total time)
|
||||
pub aggregate_bandwidth_gbps: f64,
|
||||
/// Aggregate compute in TFLOPS (total flops / total time)
|
||||
pub aggregate_tflops: f64,
|
||||
}
|
||||
|
||||
impl Drop for ExecutableKernel {
|
||||
fn drop(&mut self) {
|
||||
match self {
|
||||
@@ -90,6 +139,8 @@ pub struct CudaRuntime {
|
||||
pub(crate) timings: Vec<(Vec<SMEvent>, u64, Uuid)>,
|
||||
last_dyn_map: FxHashMap<char, usize>,
|
||||
intermediate_buffer_dims: FxHashSet<char>,
|
||||
/// Statistics from the last execution
|
||||
pub last_execution_stats: ExecutionStats,
|
||||
}
|
||||
|
||||
impl CudaRuntime {
|
||||
@@ -282,6 +333,7 @@ impl Runtime for CudaRuntime {
|
||||
timings: vec![],
|
||||
last_dyn_map: FxHashMap::default(),
|
||||
intermediate_buffer_dims: FxHashSet::default(),
|
||||
last_execution_stats: ExecutionStats::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -332,6 +384,10 @@ impl Runtime for CudaRuntime {
|
||||
output: kernel,
|
||||
shared_mem,
|
||||
constants,
|
||||
kernel_name: kernel_op.kernel_name(),
|
||||
bytes_loaded: kernel_op.bytes_loaded(),
|
||||
bytes_stored: kernel_op.bytes_stored(),
|
||||
flops: kernel_op.flops(),
|
||||
}),
|
||||
);
|
||||
}
|
||||
@@ -368,10 +424,24 @@ impl Runtime for CudaRuntime {
|
||||
let start = std::time::Instant::now();
|
||||
self.execute(dyn_map);
|
||||
self.timings.clear();
|
||||
(
|
||||
start.elapsed(),
|
||||
pretty_duration::pretty_duration(&start.elapsed(), None),
|
||||
)
|
||||
|
||||
let duration = start.elapsed();
|
||||
let stats = &self.last_execution_stats;
|
||||
|
||||
// Compute MBU and MFU from execution stats
|
||||
let peak_bandwidth_gbps = crate::cuda_bandwidth_gbps(self.cuda_stream.context());
|
||||
let peak_tflops = crate::cuda_compute_f32_tflops(self.cuda_stream.context());
|
||||
|
||||
let mbu =
|
||||
peak_bandwidth_gbps.map(|peak_bw| stats.aggregate_bandwidth_gbps / peak_bw as f64);
|
||||
let mfu = peak_tflops.map(|peak_tf| stats.aggregate_tflops / peak_tf as f64);
|
||||
|
||||
let duration_str = pretty_duration::pretty_duration(&duration, None);
|
||||
let mbu_str = mbu.map_or("-".to_string(), |v| format!("{:.1}%", v * 100.0));
|
||||
let mfu_str = mfu.map_or("-".to_string(), |v| format!("{:.1}%", v * 100.0));
|
||||
let display = format!("{duration_str} | MBU: {mbu_str} | MFU: {mfu_str}");
|
||||
|
||||
(duration, display)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
@@ -401,6 +471,8 @@ impl Runtime for CudaRuntime {
|
||||
self.register_buffer(llir_node, ptr);
|
||||
}
|
||||
let mut timings = vec![];
|
||||
let mut kernel_stats = Vec::new();
|
||||
let total_start = std::time::Instant::now();
|
||||
for exec_node in toposort(&self.exec_graph, None).unwrap() {
|
||||
match &mut self.exec_graph[exec_node] {
|
||||
ExecutableKernel::Kernel {
|
||||
@@ -412,6 +484,10 @@ impl Runtime for CudaRuntime {
|
||||
output,
|
||||
shared_mem,
|
||||
constants,
|
||||
kernel_name,
|
||||
bytes_loaded,
|
||||
bytes_stored,
|
||||
flops,
|
||||
} => {
|
||||
for (dyn_dim, val) in dyn_map {
|
||||
if let Some(global) = constants.get_mut(dyn_dim) {
|
||||
@@ -453,8 +529,48 @@ impl Runtime for CudaRuntime {
|
||||
}
|
||||
let span = span!(Level::INFO, "kernel");
|
||||
let _entered = span.enter();
|
||||
|
||||
// Use CUDA events for accurate GPU-side timing
|
||||
let start_event = self
|
||||
.cuda_stream
|
||||
.context()
|
||||
.new_event(Some(CUevent_flags::CU_EVENT_DEFAULT))
|
||||
.unwrap();
|
||||
let end_event = self
|
||||
.cuda_stream
|
||||
.context()
|
||||
.new_event(Some(CUevent_flags::CU_EVENT_DEFAULT))
|
||||
.unwrap();
|
||||
|
||||
start_event.record(&self.cuda_stream).unwrap();
|
||||
unsafe { lb.launch(cfg) }.unwrap();
|
||||
self.cuda_stream.synchronize().unwrap();
|
||||
end_event.record(&self.cuda_stream).unwrap();
|
||||
|
||||
// elapsed_ms synchronizes internally
|
||||
let kernel_time_ms = start_event.elapsed_ms(&end_event).unwrap();
|
||||
let kernel_time_us = kernel_time_ms as f64 * 1000.0;
|
||||
|
||||
// Calculate metrics
|
||||
let loaded = bytes_loaded.exec(dyn_map).unwrap_or(0);
|
||||
let stored = bytes_stored.exec(dyn_map).unwrap_or(0);
|
||||
let flop_count = flops.exec(dyn_map).unwrap_or(0);
|
||||
|
||||
// Calculate bandwidth (GB/s) and compute (TFLOPS)
|
||||
// Total memory traffic = bytes loaded + bytes stored
|
||||
let total_bytes = loaded + stored;
|
||||
let bandwidth_gbps = (total_bytes as f64) / (kernel_time_us * 1e-6) / 1e9;
|
||||
let tflops = (flop_count as f64) / (kernel_time_us * 1e-6) / 1e12;
|
||||
|
||||
kernel_stats.push(KernelStats {
|
||||
name: *kernel_name,
|
||||
execution_time_us: kernel_time_us,
|
||||
bytes_loaded: loaded,
|
||||
bytes_stored: stored,
|
||||
flops: flop_count,
|
||||
bandwidth_gbps,
|
||||
tflops,
|
||||
});
|
||||
|
||||
drop(_entered);
|
||||
drop(span);
|
||||
}
|
||||
@@ -551,7 +667,396 @@ impl Runtime for CudaRuntime {
|
||||
}
|
||||
}
|
||||
}
|
||||
self.timings.extend(timings);
|
||||
self.timings.extend(timings.clone());
|
||||
|
||||
{
|
||||
let span = span!(Level::TRACE, "timings");
|
||||
let _entered = span.enter();
|
||||
// Compute block op stats from SMEvent timings
|
||||
let sm_count = self
|
||||
.cuda_stream
|
||||
.context()
|
||||
.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT)
|
||||
.unwrap() as usize;
|
||||
let block_op_stats =
|
||||
Self::compute_block_op_stats(&self.llir_graph, &timings, dyn_map, sm_count);
|
||||
|
||||
// Compute aggregate stats from kernel and block op stats
|
||||
let total_time_us = total_start.elapsed().as_secs_f64() * 1_000_000.0;
|
||||
|
||||
let total_bytes_loaded: usize =
|
||||
kernel_stats.iter().map(|s| s.bytes_loaded).sum::<usize>()
|
||||
+ block_op_stats.iter().map(|s| s.bytes_loaded).sum::<usize>();
|
||||
let total_bytes_stored: usize =
|
||||
kernel_stats.iter().map(|s| s.bytes_stored).sum::<usize>()
|
||||
+ block_op_stats.iter().map(|s| s.bytes_stored).sum::<usize>();
|
||||
let total_flops: usize = kernel_stats.iter().map(|s| s.flops).sum::<usize>()
|
||||
+ block_op_stats.iter().map(|s| s.flops).sum::<usize>();
|
||||
|
||||
let total_bytes = total_bytes_loaded + total_bytes_stored;
|
||||
let aggregate_bandwidth_gbps = if total_time_us > 0.0 {
|
||||
(total_bytes as f64) / (total_time_us * 1e-6) / 1e9
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
let aggregate_tflops = if total_time_us > 0.0 {
|
||||
(total_flops as f64) / (total_time_us * 1e-6) / 1e12
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
// Store execution stats
|
||||
self.last_execution_stats = ExecutionStats {
|
||||
kernel_stats,
|
||||
block_op_stats,
|
||||
total_time_us,
|
||||
total_bytes_loaded,
|
||||
total_bytes_stored,
|
||||
total_flops,
|
||||
aggregate_bandwidth_gbps,
|
||||
aggregate_tflops,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaRuntime {
|
||||
fn compute_block_op_stats(
|
||||
llir_graph: &LLIRGraph,
|
||||
timings: &[(Vec<SMEvent>, u64, Uuid)],
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
sm_count: usize,
|
||||
) -> Vec<BlockOpStats> {
|
||||
use std::collections::HashMap;
|
||||
|
||||
// Get unique op names (same order as in interpreter)
|
||||
let op_names: Vec<&'static str> = llir_graph
|
||||
.node_indices()
|
||||
.filter_map(|n| llir_graph[n].to_dialect::<dyn BlockOp>())
|
||||
.map(|bo| (bo.op_name(), bo.clone()))
|
||||
.collect::<HashMap<_, _>>()
|
||||
.into_iter()
|
||||
.sorted_by_key(|(n, _)| *n)
|
||||
.map(|(n, _)| n)
|
||||
.collect();
|
||||
|
||||
if op_names.is_empty() {
|
||||
return vec![];
|
||||
}
|
||||
|
||||
// Build map from op_name to index for event decoding
|
||||
let op_name_to_idx: HashMap<&'static str, usize> =
|
||||
op_names.iter().enumerate().map(|(i, n)| (*n, i)).collect();
|
||||
|
||||
// Sum up bytes_loaded, bytes_stored, flops across ALL instances of each op type
|
||||
let mut op_bytes_loaded: Vec<usize> = vec![0; op_names.len()];
|
||||
let mut op_bytes_stored: Vec<usize> = vec![0; op_names.len()];
|
||||
let mut op_flops: Vec<usize> = vec![0; op_names.len()];
|
||||
|
||||
for node in llir_graph.node_indices() {
|
||||
if let Some(op) = llir_graph[node].to_dialect::<dyn BlockOp>() {
|
||||
if let Some(&idx) = op_name_to_idx.get(op.op_name()) {
|
||||
let flops_expr = op.flops();
|
||||
let flops_val = flops_expr.exec(dyn_map).unwrap();
|
||||
if op.op_name() == "TileMatmul" {
|
||||
tracing::debug!(
|
||||
"TileMatmul flops: expr={:?}, val={}",
|
||||
flops_expr,
|
||||
flops_val
|
||||
);
|
||||
}
|
||||
op_bytes_loaded[idx] += op.bytes_loaded().exec(dyn_map).unwrap();
|
||||
op_bytes_stored[idx] += op.bytes_stored().exec(dyn_map).unwrap();
|
||||
op_flops[idx] += flops_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Aggregate timing per op type across all SMs
|
||||
// event codes: 0=Issue, 1=Wait, 2+=BlockOps (index in ops list)
|
||||
let mut op_times_ns: Vec<u64> = vec![0; op_names.len()];
|
||||
|
||||
for (sm_timings, _start_time, _) in timings {
|
||||
for sm_chunk in sm_timings.chunks(1000) {
|
||||
for i in 0..sm_chunk.len().saturating_sub(1) {
|
||||
let event = sm_chunk[i].event;
|
||||
let next_start = sm_chunk[i + 1].start;
|
||||
if next_start == 0 {
|
||||
break; // No more events recorded for this SM
|
||||
}
|
||||
// event >= 2 means it's a block op (0=Issue, 1=Wait)
|
||||
if event >= 2 {
|
||||
let op_idx = (event - 2) as usize;
|
||||
if op_idx < op_names.len() {
|
||||
let duration = next_start.saturating_sub(sm_chunk[i].start);
|
||||
op_times_ns[op_idx] += duration;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Compute bandwidth and TFLOPS from aggregated metrics
|
||||
// Divide total SM-time by sm_count to get wall-clock time
|
||||
op_names
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(i, _)| op_times_ns[*i] > 0)
|
||||
.map(|(i, &name)| {
|
||||
// Total SM-time divided by number of SMs gives wall-clock time
|
||||
let time_us = (op_times_ns[i] as f64 / sm_count as f64) / 1000.0;
|
||||
let bytes_loaded = op_bytes_loaded[i];
|
||||
let bytes_stored = op_bytes_stored[i];
|
||||
let flop_count = op_flops[i];
|
||||
|
||||
let total_bytes = bytes_loaded + bytes_stored;
|
||||
let bandwidth_gbps = if time_us > 0.0 {
|
||||
(total_bytes as f64) / (time_us * 1e-6) / 1e9
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
let tflops = if time_us > 0.0 {
|
||||
(flop_count as f64) / (time_us * 1e-6) / 1e12
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
BlockOpStats {
|
||||
name,
|
||||
execution_time_us: time_us,
|
||||
bytes_loaded,
|
||||
bytes_stored,
|
||||
flops: flop_count,
|
||||
bandwidth_gbps,
|
||||
tflops,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Print execution statistics for the last execution.
|
||||
/// Shows bandwidth and compute utilization for each kernel and block op.
|
||||
pub fn print_execution_stats(&self) {
|
||||
let stats = &self.last_execution_stats;
|
||||
if stats.kernel_stats.is_empty() && stats.block_op_stats.is_empty() {
|
||||
println!("No execution stats available.");
|
||||
return;
|
||||
}
|
||||
|
||||
// Get device peak performance
|
||||
let peak_bandwidth_gbps = crate::cuda_bandwidth_gbps(self.cuda_stream.context());
|
||||
let peak_tflops = crate::cuda_compute_f32_tflops(self.cuda_stream.context());
|
||||
|
||||
// Print kernel stats if any
|
||||
if !stats.kernel_stats.is_empty() {
|
||||
println!("\n=== Kernel Execution Statistics ===\n");
|
||||
println!(
|
||||
"{:<20} {:>12} {:>12} {:>12} {:>12} {:>12} {:>12} {:>8} {:>8}",
|
||||
"Kernel",
|
||||
"Time (us)",
|
||||
"Loaded",
|
||||
"Stored",
|
||||
"Agg FLOPS",
|
||||
"BW (GB/s)",
|
||||
"TFLOPS",
|
||||
"MBU",
|
||||
"MFU"
|
||||
);
|
||||
println!("{}", "-".repeat(116));
|
||||
|
||||
for stat in &stats.kernel_stats {
|
||||
self.print_stat_row(
|
||||
stat.name,
|
||||
stat.execution_time_us,
|
||||
stat.bytes_loaded,
|
||||
stat.bytes_stored,
|
||||
stat.flops,
|
||||
stat.bandwidth_gbps,
|
||||
stat.tflops,
|
||||
peak_bandwidth_gbps,
|
||||
peak_tflops,
|
||||
);
|
||||
}
|
||||
|
||||
println!("{}", "-".repeat(116));
|
||||
}
|
||||
|
||||
// Print block op stats if any
|
||||
if !stats.block_op_stats.is_empty() {
|
||||
println!("\n=== Block Op Execution Statistics ===\n");
|
||||
println!(
|
||||
"{:<20} {:>12} {:>12} {:>12} {:>12} {:>12} {:>12} {:>8} {:>8}",
|
||||
"BlockOp",
|
||||
"Time (us)",
|
||||
"Loaded",
|
||||
"Stored",
|
||||
"Agg FLOPS",
|
||||
"BW (GB/s)",
|
||||
"TFLOPS",
|
||||
"MBU",
|
||||
"MFU"
|
||||
);
|
||||
println!("{}", "-".repeat(116));
|
||||
|
||||
for stat in &stats.block_op_stats {
|
||||
self.print_stat_row(
|
||||
stat.name,
|
||||
stat.execution_time_us,
|
||||
stat.bytes_loaded,
|
||||
stat.bytes_stored,
|
||||
stat.flops,
|
||||
stat.bandwidth_gbps,
|
||||
stat.tflops,
|
||||
peak_bandwidth_gbps,
|
||||
peak_tflops,
|
||||
);
|
||||
}
|
||||
|
||||
println!("{}", "-".repeat(116));
|
||||
}
|
||||
|
||||
// Print aggregate stats
|
||||
println!("\n=== Aggregate Statistics ===\n");
|
||||
println!(
|
||||
"{:<20} {:>12} {:>12} {:>12} {:>12} {:>12} {:>12} {:>8} {:>8}",
|
||||
"", "Time (us)", "Loaded", "Stored", "Agg FLOPS", "BW (GB/s)", "TFLOPS", "MBU", "MFU"
|
||||
);
|
||||
println!("{}", "-".repeat(116));
|
||||
|
||||
let (mbu_str, mfu_str) =
|
||||
if let (Some(peak_bw), Some(peak_tf)) = (peak_bandwidth_gbps, peak_tflops) {
|
||||
let mbu = (stats.aggregate_bandwidth_gbps / peak_bw as f64) * 100.0;
|
||||
let mfu = (stats.aggregate_tflops / peak_tf as f64) * 100.0;
|
||||
(format!("{:.1}%", mbu), format!("{:.1}%", mfu))
|
||||
} else {
|
||||
("-".to_string(), "-".to_string())
|
||||
};
|
||||
|
||||
println!(
|
||||
"{:<20} {:>12.2} {:>12} {:>12} {:>12} {:>12} {:>12} {:>8} {:>8}",
|
||||
"Total",
|
||||
stats.total_time_us,
|
||||
format_size(stats.total_bytes_loaded),
|
||||
format_size(stats.total_bytes_stored),
|
||||
format_flops(stats.total_flops),
|
||||
format!("{:.2}", stats.aggregate_bandwidth_gbps),
|
||||
format!("{:.4}", stats.aggregate_tflops),
|
||||
mbu_str,
|
||||
mfu_str
|
||||
);
|
||||
|
||||
// Print device info
|
||||
if let (Some(peak_bw), Some(peak_tf)) = (peak_bandwidth_gbps, peak_tflops) {
|
||||
println!(
|
||||
"\nDevice peak: {} GB/s bandwidth, {} TFLOPS (F32)",
|
||||
peak_bw, peak_tf
|
||||
);
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
fn print_stat_row(
|
||||
&self,
|
||||
name: &str,
|
||||
execution_time_us: f64,
|
||||
bytes_loaded: usize,
|
||||
bytes_stored: usize,
|
||||
flops: usize,
|
||||
bandwidth_gbps: f64,
|
||||
tflops: f64,
|
||||
peak_bandwidth_gbps: Option<usize>,
|
||||
peak_tflops: Option<usize>,
|
||||
) {
|
||||
let total_bytes = bytes_loaded + bytes_stored;
|
||||
|
||||
// Show "-" if bytes are 0 (not specified)
|
||||
let loaded_str = if bytes_loaded > 0 {
|
||||
format_size(bytes_loaded)
|
||||
} else {
|
||||
"-".to_string()
|
||||
};
|
||||
let stored_str = if bytes_stored > 0 {
|
||||
format_size(bytes_stored)
|
||||
} else {
|
||||
"-".to_string()
|
||||
};
|
||||
let flops_str = if flops > 0 {
|
||||
format_flops(flops)
|
||||
} else {
|
||||
"-".to_string()
|
||||
};
|
||||
let bw_str = if total_bytes > 0 {
|
||||
format!("{:.2}", bandwidth_gbps)
|
||||
} else {
|
||||
"-".to_string()
|
||||
};
|
||||
let tflops_str = if flops > 0 {
|
||||
format!("{:.4}", tflops)
|
||||
} else {
|
||||
"-".to_string()
|
||||
};
|
||||
|
||||
// Calculate MBU (Memory Bandwidth Utilization) and MFU (Model FLOPS Utilization)
|
||||
let mbu_str = if let Some(peak_bw) = peak_bandwidth_gbps {
|
||||
if total_bytes > 0 {
|
||||
let mbu = (bandwidth_gbps / peak_bw as f64) * 100.0;
|
||||
format!("{:.1}%", mbu)
|
||||
} else {
|
||||
"-".to_string()
|
||||
}
|
||||
} else {
|
||||
"-".to_string()
|
||||
};
|
||||
|
||||
let mfu_str = if let Some(peak_tf) = peak_tflops {
|
||||
if flops > 0 {
|
||||
let mfu = (tflops / peak_tf as f64) * 100.0;
|
||||
format!("{:.1}%", mfu)
|
||||
} else {
|
||||
"-".to_string()
|
||||
}
|
||||
} else {
|
||||
"-".to_string()
|
||||
};
|
||||
|
||||
println!(
|
||||
"{:<20} {:>12.2} {:>12} {:>12} {:>12} {:>12} {:>12} {:>8} {:>8}",
|
||||
name,
|
||||
execution_time_us,
|
||||
loaded_str,
|
||||
stored_str,
|
||||
flops_str,
|
||||
bw_str,
|
||||
tflops_str,
|
||||
mbu_str,
|
||||
mfu_str
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Formats a byte count with a decimal (power-of-1000) unit.
///
/// Values of at least 1 KB are shown with two decimals and a `KB`, `MB` or `GB` suffix;
/// smaller values are shown as a whole number of bytes with a `B` suffix.
fn format_size(bytes: usize) -> String {
    // (inclusive lower bound, divisor, suffix), largest unit first.
    const UNITS: [(usize, f64, &str); 3] = [
        (1_000_000_000, 1e9, "GB"),
        (1_000_000, 1e6, "MB"),
        (1_000, 1e3, "KB"),
    ];

    for &(threshold, divisor, suffix) in UNITS.iter() {
        if bytes >= threshold {
            return format!("{:.2} {}", bytes as f64 / divisor, suffix);
        }
    }
    format!("{} B", bytes)
}
|
||||
|
||||
/// Formats an operation count with a decimal (power-of-1000) magnitude suffix.
///
/// Values of at least one thousand are shown with two decimals and a `K`, `M`, `G` or `T`
/// suffix; smaller values are shown as a plain integer.
fn format_flops(flops: usize) -> String {
    // (inclusive lower bound, divisor, suffix), largest magnitude first.
    const MAGNITUDES: [(usize, f64, &str); 4] = [
        (1_000_000_000_000, 1e12, "T"),
        (1_000_000_000, 1e9, "G"),
        (1_000_000, 1e6, "M"),
        (1_000, 1e3, "K"),
    ];

    for &(threshold, divisor, suffix) in MAGNITUDES.iter() {
        if flops >= threshold {
            return format!("{:.2} {}", flops as f64 / divisor, suffix);
        }
    }
    format!("{}", flops)
}
|
||||
|
||||
|
||||
@@ -1,178 +1,243 @@
|
||||
use candle_core::{Device, Tensor};
|
||||
use cudarc::driver::CudaContext;
|
||||
use luminal::prelude::*;
|
||||
use proptest::prelude::*;
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::cuda_bandwidth_gbps;
|
||||
use crate::runtime::CudaRuntime;
|
||||
|
||||
fn random_vec(n: usize) -> Vec<f32> {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
(0..n).map(|_| rng.random_range(-0.5..0.5)).collect()
|
||||
}
|
||||
|
||||
fn assert_close(a_vec: &[f32], b_vec: &[f32]) {
|
||||
assert_close_precision(a_vec, b_vec, 1e-3);
|
||||
}
|
||||
|
||||
/// Asserts that two slices have the same length and match element-wise within `threshold`.
///
/// Two elements match when they are equal (this covers equal infinities), when both are NaN,
/// or when their absolute difference is at most `threshold`. A NaN on only one side is a
/// mismatch; a plain `diff > threshold` test would silently accept it because every comparison
/// with NaN is false.
///
/// # Panics
///
/// Panics if the lengths differ, or at the first mismatching index, reporting both values,
/// the index and the average absolute distance over the whole slice.
fn assert_close_precision(a_vec: &[f32], b_vec: &[f32], threshold: f32) {
    assert_eq!(a_vec.len(), b_vec.len(), "Number of elements doesn't match");
    for (i, (&a, &b)) in a_vec.iter().zip(b_vec.iter()).enumerate() {
        let close = if a.is_nan() || b.is_nan() {
            // NaN only matches NaN.
            a.is_nan() && b.is_nan()
        } else {
            // `a == b` keeps equal infinities matching (inf - inf would be NaN).
            a == b || (a - b).abs() <= threshold
        };
        if !close {
            let avg_distance = a_vec
                .iter()
                .zip(b_vec.iter())
                .map(|(x, y)| (x - y).abs())
                .sum::<f32>()
                / a_vec.len() as f32;
            panic!("{a} is not close to {b}, index {i}, avg distance: {avg_distance}");
        }
    }
}
|
||||
|
||||
fn get_cuda_stream() -> Option<Arc<cudarc::driver::CudaStream>> {
|
||||
let ctx = CudaContext::new(0).ok()?;
|
||||
ctx.bind_to_thread().ok()?;
|
||||
Some(ctx.default_stream())
|
||||
}
|
||||
|
||||
/// Test a unary operation on CUDA against candle reference
|
||||
pub fn test_unary(
|
||||
shape: impl ToShape,
|
||||
func: impl Fn(GraphTensor) -> GraphTensor,
|
||||
ref_func: impl Fn(Tensor) -> Tensor,
|
||||
) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let shape: Vec<usize> = shape
|
||||
.to_shape()
|
||||
.into_iter()
|
||||
.map(|e| e.to_usize().unwrap())
|
||||
.collect();
|
||||
let n_elements: usize = shape.iter().product();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor(shape.clone());
|
||||
let b = func(a).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
let input_data = random_vec(n_elements);
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(b);
|
||||
|
||||
// Reference using candle
|
||||
let device = Device::Cpu;
|
||||
let ref_a = Tensor::from_vec(input_data, shape, &device).unwrap();
|
||||
let ref_b = ref_func(ref_a).flatten_all().unwrap();
|
||||
|
||||
assert_close(&result, &ref_b.to_vec1::<f32>().unwrap());
|
||||
}
|
||||
|
||||
/// Test a binary operation on CUDA against candle reference
|
||||
pub fn test_binary(
|
||||
a_shape: impl ToShape,
|
||||
b_shape: impl ToShape,
|
||||
func: impl Fn(GraphTensor, GraphTensor) -> GraphTensor,
|
||||
ref_func: impl Fn(Tensor, Tensor) -> Tensor,
|
||||
) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let a_shape: Vec<usize> = a_shape
|
||||
.to_shape()
|
||||
.into_iter()
|
||||
.map(|e| e.to_usize().unwrap())
|
||||
.collect();
|
||||
let b_shape: Vec<usize> = b_shape
|
||||
.to_shape()
|
||||
.into_iter()
|
||||
.map(|e| e.to_usize().unwrap())
|
||||
.collect();
|
||||
let a_elements: usize = a_shape.iter().product();
|
||||
let b_elements: usize = b_shape.iter().product();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor(a_shape.clone());
|
||||
let b = cx.tensor(b_shape.clone());
|
||||
let c = func(a, b).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
let a_data = random_vec(a_elements);
|
||||
let b_data = random_vec(b_elements);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b, b_data.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(c);
|
||||
|
||||
// Reference using candle
|
||||
let device = Device::Cpu;
|
||||
let ref_a = Tensor::from_vec(a_data, a_shape, &device).unwrap();
|
||||
let ref_b = Tensor::from_vec(b_data, b_shape, &device).unwrap();
|
||||
let ref_c = ref_func(ref_a, ref_b).flatten_all().unwrap();
|
||||
|
||||
assert_close(&result, &ref_c.to_vec1::<f32>().unwrap());
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(5))]
|
||||
|
||||
#[test]
|
||||
fn cuda_test(len in 1usize..32, values in proptest::collection::vec(-5.0f32..5.0, 1..64)) {
|
||||
prop_assume!(values.len() >= len);
|
||||
let ctx = match CudaContext::new(0) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(_) => return Ok(()),
|
||||
};
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor(len);
|
||||
let output = (input + input).output();
|
||||
fn test_add(x in 1usize..100, y in 1usize..5) {
|
||||
test_binary(x, x, |a, b| a + b, |a, b| (&a + &b).unwrap());
|
||||
test_binary((y, x), (y, x), |a, b| a + b, |a, b| (&a + &b).unwrap());
|
||||
}
|
||||
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let input_values = values.into_iter().take(len).collect::<Vec<f32>>();
|
||||
rt.set_data(input, input_values.clone());
|
||||
rt = cx.search(rt, 10);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let out = rt.get_f32(output);
|
||||
let expected = input_values.into_iter().map(|v| v * 2.0).collect::<Vec<f32>>();
|
||||
assert_eq!(out, expected);
|
||||
#[test]
|
||||
fn test_mul(x in 1usize..100, y in 1usize..5) {
|
||||
test_binary(x, x, |a, b| a * b, |a, b| (&a * &b).unwrap());
|
||||
test_binary((y, x), (y, x), |a, b| a * b, |a, b| (&a * &b).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max(rows in 1usize..8, cols in 1usize..8) {
|
||||
test_unary((rows, cols), |a| a.max(1), |a| a.max(1).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mean(rows in 1usize..8, cols in 1usize..8) {
|
||||
test_unary((rows, cols), |a| a.mean(1), |a| a.mean(1).unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
// #[test] // see kernel/ops.rs for why this is disabled
|
||||
// pub fn cuda_sum_reduce_test() {
|
||||
// let mut cx = Graph::default();
|
||||
// let input = cx.tensor((1000, 1000));
|
||||
// let sum_dim0 = input.sum(0).output(); // row sum
|
||||
// let sum_dim1 = input.sum(1).output(); // col sum
|
||||
|
||||
// let data: Vec<f32> = (0..1_000_000).map(|i| (i % 100) as f32 * 0.01).collect();
|
||||
|
||||
// let expected_dim0: Vec<f32> = (0..1000)
|
||||
// .map(|col| (0..1000).map(|row| data[row * 1000 + col]).sum())
|
||||
// .collect();
|
||||
// let expected_dim1: Vec<f32> = (0..1000)
|
||||
// .map(|row| (0..1000).map(|col| data[row * 1000 + col]).sum())
|
||||
// .collect();
|
||||
|
||||
// let ctx = CudaContext::new(0).unwrap();
|
||||
// ctx.bind_to_thread().unwrap();
|
||||
// let stream = ctx.default_stream();
|
||||
// cx.build_search_space::<CudaRuntime>();
|
||||
// let mut rt = CudaRuntime::initialize(stream);
|
||||
// rt.set_data(input, data);
|
||||
// rt = cx.search(rt, 10);
|
||||
// rt.execute(&cx.dyn_map);
|
||||
|
||||
// let out_dim0 = rt.get_f32(sum_dim0);
|
||||
// let out_dim1 = rt.get_f32(sum_dim1);
|
||||
|
||||
// for i in 0..1000 {
|
||||
// let rel_err_0 = (out_dim0[i] - expected_dim0[i]).abs() / expected_dim0[i].abs().max(1.0);
|
||||
// let rel_err_1 = (out_dim1[i] - expected_dim1[i]).abs() / expected_dim1[i].abs().max(1.0);
|
||||
// assert!(
|
||||
// rel_err_0 < 0.001,
|
||||
// "dim0 mismatch at {i}: got {}, expected {}",
|
||||
// out_dim0[i],
|
||||
// expected_dim0[i]
|
||||
// );
|
||||
// assert!(
|
||||
// rel_err_1 < 0.001,
|
||||
// "dim1 mismatch at {i}: got {}, expected {}",
|
||||
// out_dim1[i],
|
||||
// expected_dim1[i]
|
||||
// );
|
||||
// }
|
||||
// }
|
||||
|
||||
/// Test that measures bandwidth utilization for a large element-wise add kernel.
|
||||
/// This demonstrates that KernelAdd can achieve reasonable bandwidth with large tensors.
|
||||
#[test]
|
||||
pub fn cuda_max_reduce_test() {
|
||||
pub fn kernel_add_bandwidth_test() {
|
||||
// 64M elements = 256MB per tensor, 768MB total memory traffic (2 reads + 1 write)
|
||||
let size = 64 * 1024 * 1024;
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((1000, 1000));
|
||||
let max_dim0 = input.max(0).output(); // row max
|
||||
let max_dim1 = input.max(1).output(); // col max
|
||||
let a = cx.tensor(size);
|
||||
let b = cx.tensor(size);
|
||||
let output = (a + b).output();
|
||||
|
||||
let data: Vec<f32> = (0..1_000_000).map(|i| (i % 100) as f32 * 0.01).collect();
|
||||
|
||||
let expected_dim0: Vec<f32> = (0..1000)
|
||||
.map(|col| {
|
||||
(0..1000)
|
||||
.map(|row| data[row * 1000 + col])
|
||||
.fold(f32::NEG_INFINITY, f32::max)
|
||||
})
|
||||
.collect();
|
||||
let expected_dim1: Vec<f32> = (0..1000)
|
||||
.map(|row| {
|
||||
(0..1000)
|
||||
.map(|col| data[row * 1000 + col])
|
||||
.fold(f32::NEG_INFINITY, f32::max)
|
||||
})
|
||||
// Generate test data
|
||||
let data_a: Vec<f32> = (0..size).map(|i| (i % 1000) as f32 * 0.001).collect();
|
||||
let data_b: Vec<f32> = (0..size)
|
||||
.map(|i| ((i + 500) % 1000) as f32 * 0.001)
|
||||
.collect();
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(input, data);
|
||||
rt = cx.search(rt, 10);
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
rt = cx.search(rt, 5);
|
||||
|
||||
// Warm up
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let out_dim0 = rt.get_f32(max_dim0);
|
||||
let out_dim1 = rt.get_f32(max_dim1);
|
||||
|
||||
for i in 0..1000 {
|
||||
let rel_err_0 = (out_dim0[i] - expected_dim0[i]).abs() / expected_dim0[i].abs().max(1.0);
|
||||
let rel_err_1 = (out_dim1[i] - expected_dim1[i]).abs() / expected_dim1[i].abs().max(1.0);
|
||||
assert!(
|
||||
rel_err_0 < 0.001,
|
||||
"dim0 mismatch at {i}: got {}, expected {}",
|
||||
out_dim0[i],
|
||||
expected_dim0[i]
|
||||
);
|
||||
assert!(
|
||||
rel_err_1 < 0.001,
|
||||
"dim1 mismatch at {i}: got {}, expected {}",
|
||||
out_dim1[i],
|
||||
expected_dim1[i]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn cuda_mean_reduce_test() {
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((1000, 1000));
|
||||
let mean_dim0 = input.mean(0).output(); // mean along rows
|
||||
let mean_dim1 = input.mean(1).output(); // mean along cols
|
||||
|
||||
let data: Vec<f32> = (0..1_000_000).map(|i| (i % 100) as f32 * 0.01).collect();
|
||||
|
||||
let expected_dim0: Vec<f32> = (0..1000)
|
||||
.map(|col| (0..1000).map(|row| data[row * 1000 + col]).sum::<f32>() / 1000.0)
|
||||
.collect();
|
||||
let expected_dim1: Vec<f32> = (0..1000)
|
||||
.map(|row| (0..1000).map(|col| data[row * 1000 + col]).sum::<f32>() / 1000.0)
|
||||
.collect();
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(input, data);
|
||||
rt = cx.search(rt, 10);
|
||||
// Run and measure
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let out_dim0 = rt.get_f32(mean_dim0);
|
||||
let out_dim1 = rt.get_f32(mean_dim1);
|
||||
// Print stats
|
||||
println!("\n=== Large KernelAdd Bandwidth Test ===");
|
||||
println!(
|
||||
"Tensor size: {} elements ({} MB per tensor)",
|
||||
size,
|
||||
size * 4 / 1024 / 1024
|
||||
);
|
||||
println!(
|
||||
"Total memory traffic: {} MB (2 reads + 1 write)",
|
||||
size * 4 * 3 / 1024 / 1024
|
||||
);
|
||||
rt.print_execution_stats();
|
||||
|
||||
for i in 0..1000 {
|
||||
let rel_err_0 = (out_dim0[i] - expected_dim0[i]).abs() / expected_dim0[i].abs().max(1.0);
|
||||
let rel_err_1 = (out_dim1[i] - expected_dim1[i]).abs() / expected_dim1[i].abs().max(1.0);
|
||||
// Verify correctness (spot check)
|
||||
let result = rt.get_f32(output);
|
||||
for i in [0, size / 2, size - 1] {
|
||||
let expected = data_a[i] + data_b[i];
|
||||
let got = result[i];
|
||||
assert!(
|
||||
rel_err_0 < 0.001,
|
||||
"dim0 mismatch at {i}: got {}, expected {}",
|
||||
out_dim0[i],
|
||||
expected_dim0[i]
|
||||
);
|
||||
assert!(
|
||||
rel_err_1 < 0.001,
|
||||
"dim1 mismatch at {i}: got {}, expected {}",
|
||||
out_dim1[i],
|
||||
expected_dim1[i]
|
||||
(got - expected).abs() < 1e-5,
|
||||
"Mismatch at {}: expected {}, got {}",
|
||||
i,
|
||||
expected,
|
||||
got
|
||||
);
|
||||
}
|
||||
|
||||
// Check bandwidth is reasonable (at least 50% of peak for large kernels)
|
||||
let stats = &rt.last_execution_stats;
|
||||
if let Some(peak_bw) = cuda_bandwidth_gbps(&ctx) {
|
||||
for stat in &stats.kernel_stats {
|
||||
let total_bytes = stat.bytes_loaded + stat.bytes_stored;
|
||||
if stat.name == "Add" && total_bytes > 0 {
|
||||
let utilization = stat.bandwidth_gbps / peak_bw as f64 * 100.0;
|
||||
println!(
|
||||
"\nAdd kernel achieved {:.1} GB/s ({:.1}% of {:.0} GB/s peak)",
|
||||
stat.bandwidth_gbps, utilization, peak_bw
|
||||
);
|
||||
println!(
|
||||
" Loaded: {} bytes, Stored: {} bytes",
|
||||
stat.bytes_loaded, stat.bytes_stored
|
||||
);
|
||||
// Large adds should achieve decent bandwidth
|
||||
assert!(
|
||||
utilization > 50.0,
|
||||
"Bandwidth utilization too low: {:.1}%",
|
||||
utilization
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use crate::model::{HEAD_DIM, HIDDEN, INTERMEDIATE, KV_GROUPS, LAYERS, VOCAB_SIZE};
|
||||
|
||||
/// Handles all benchmarking metrics and reporting for the LLM inference
|
||||
pub struct Benchmarker {
|
||||
start_generation: Instant,
|
||||
ttft: Duration,
|
||||
decode_durations: Vec<Duration>,
|
||||
seq_lengths: Vec<(usize, usize)>,
|
||||
current_iter_start: Option<Instant>,
|
||||
}
|
||||
|
||||
impl Benchmarker {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
start_generation: Instant::now(),
|
||||
ttft: Duration::default(),
|
||||
decode_durations: vec![],
|
||||
seq_lengths: vec![],
|
||||
current_iter_start: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Mark the start of an iteration (prefill or decode)
|
||||
pub fn start_iteration(&mut self, seq_len: usize, prev_seq: usize) {
|
||||
self.current_iter_start = Some(Instant::now());
|
||||
self.seq_lengths.push((seq_len, prev_seq));
|
||||
}
|
||||
|
||||
/// Mark the end of an iteration and record timing
|
||||
pub fn end_iteration(&mut self, iteration: usize) {
|
||||
if let Some(start) = self.current_iter_start.take() {
|
||||
let duration = start.elapsed();
|
||||
if iteration == 0 {
|
||||
self.ttft = duration;
|
||||
} else {
|
||||
self.decode_durations.push(duration);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Print the benchmark results to stdout
|
||||
pub fn report(&self, peak_tflops: f64, peak_gbps: f64) {
|
||||
let total_elapsed = self.start_generation.elapsed();
|
||||
let decode_total = self
|
||||
.decode_durations
|
||||
.iter()
|
||||
.fold(Duration::ZERO, |acc, value| acc + *value);
|
||||
|
||||
let (total_flops, total_bytes) = self
|
||||
.seq_lengths
|
||||
.iter()
|
||||
.map(|(seq_len, prev_seq)| llama_estimate_flops_and_bytes(*seq_len, *prev_seq))
|
||||
.fold((0u64, 0u64), |(acc_flops, acc_bytes), (flops, bytes)| {
|
||||
(acc_flops + flops, acc_bytes + bytes)
|
||||
});
|
||||
|
||||
let achieved_tflops = total_flops as f64 / total_elapsed.as_secs_f64() / 1e12;
|
||||
let achieved_gbps = total_bytes as f64 / total_elapsed.as_secs_f64() / 1e9;
|
||||
println!("Benchmark results:");
|
||||
println!(" TTFT: {:.2} ms", self.ttft.as_secs_f64() * 1e3);
|
||||
if !self.decode_durations.is_empty() {
|
||||
println!(
|
||||
" TPOT: {:.2} ms",
|
||||
(decode_total / self.decode_durations.len() as u32).as_secs_f64() * 1e3
|
||||
);
|
||||
}
|
||||
println!(
|
||||
" Achieved: {:.2} TFLOP/s, {:.2} GB/s",
|
||||
achieved_tflops, achieved_gbps
|
||||
);
|
||||
println!(
|
||||
" MFU (est): {:.2}%",
|
||||
(achieved_tflops / peak_tflops) * 100.0
|
||||
);
|
||||
println!(" MBU (est): {:.2}%", (achieved_gbps / peak_gbps) * 100.0);
|
||||
}
|
||||
}
|
||||
|
||||
fn llama_estimate_flops_and_bytes(seq_len: usize, prev_seq: usize) -> (u64, u64) {
|
||||
let total_seq = seq_len + prev_seq;
|
||||
let hidden = HIDDEN as u64;
|
||||
let intermediate = INTERMEDIATE as u64;
|
||||
let seq = seq_len as u64;
|
||||
let total_seq = total_seq as u64;
|
||||
let head_dim = HEAD_DIM as u64;
|
||||
let n_heads = hidden / head_dim;
|
||||
let kv_hidden = (HIDDEN / KV_GROUPS) as u64;
|
||||
let vocab = VOCAB_SIZE as u64;
|
||||
let bytes_per = std::mem::size_of::<f32>() as u64;
|
||||
|
||||
let q_proj_flops = 2 * seq * hidden * hidden;
|
||||
let k_proj_flops = 2 * seq * hidden * kv_hidden;
|
||||
let v_proj_flops = 2 * seq * hidden * kv_hidden;
|
||||
let o_proj_flops = 2 * seq * hidden * hidden;
|
||||
let mlp_flops = 6 * seq * hidden * intermediate;
|
||||
let attn_flops = 4 * seq * total_seq * head_dim * n_heads;
|
||||
let lm_head_flops = 2 * seq * hidden * vocab;
|
||||
|
||||
let per_layer_flops =
|
||||
q_proj_flops + k_proj_flops + v_proj_flops + o_proj_flops + mlp_flops + attn_flops;
|
||||
let total_flops = per_layer_flops * LAYERS as u64 + lm_head_flops;
|
||||
|
||||
let q_bytes = bytes_per * (seq * hidden + hidden * hidden + seq * hidden);
|
||||
let k_bytes = bytes_per * (seq * hidden + hidden * kv_hidden + seq * kv_hidden);
|
||||
let v_bytes = bytes_per * (seq * hidden + hidden * kv_hidden + seq * kv_hidden);
|
||||
let o_bytes = bytes_per * (seq * hidden + hidden * hidden + seq * hidden);
|
||||
let mlp_bytes = bytes_per * (seq * hidden + hidden * intermediate + seq * intermediate) * 2
|
||||
+ bytes_per * (seq * intermediate + intermediate * hidden + seq * hidden);
|
||||
let attn_bytes = bytes_per * (seq * hidden + total_seq * kv_hidden * 2 + seq * hidden);
|
||||
let lm_head_bytes = bytes_per * (seq * hidden + hidden * vocab + seq * vocab);
|
||||
|
||||
let per_layer_bytes = q_bytes + k_bytes + v_bytes + o_bytes + mlp_bytes + attn_bytes;
|
||||
let total_bytes = per_layer_bytes * LAYERS as u64 + lm_head_bytes;
|
||||
|
||||
(total_flops, total_bytes)
|
||||
}
|
||||
@@ -1,13 +1,9 @@
|
||||
mod benchmark;
|
||||
mod model;
|
||||
|
||||
use benchmark::Benchmarker;
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda::{
|
||||
cuda_bandwidth_gbps, cuda_compute_f32_tflops, cudarc::driver::CudaContext, runtime::CudaRuntime,
|
||||
};
|
||||
use luminal_cuda::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use model::*;
|
||||
use std::io::Write;
|
||||
use std::{io::Write, time::Duration};
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing::{span, Level};
|
||||
|
||||
@@ -69,8 +65,9 @@ fn main() {
|
||||
|
||||
// Decode loop
|
||||
let mut prev_seq = 0;
|
||||
let mut benchmarker = Benchmarker::new();
|
||||
let mut fwd_durations = vec![];
|
||||
for i in 0..gen_tokens {
|
||||
let start = std::time::Instant::now();
|
||||
let span = if i == 0 {
|
||||
span!(Level::INFO, "prefill")
|
||||
} else {
|
||||
@@ -94,7 +91,6 @@ fn main() {
|
||||
);
|
||||
|
||||
// Execute forward pass
|
||||
benchmarker.start_iteration(seq_len, prev_seq);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
|
||||
@@ -105,16 +101,17 @@ fn main() {
|
||||
prev_seq += seq_len;
|
||||
print!("{}", tokenizer.decode(&sentence, true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
benchmarker.end_iteration(i);
|
||||
fwd_durations.push(start.elapsed());
|
||||
}
|
||||
println!();
|
||||
|
||||
// Report benchmarks
|
||||
if let (Some(flops), Some(bandwidth)) =
|
||||
(cuda_compute_f32_tflops(&ctx), cuda_bandwidth_gbps(&ctx))
|
||||
{
|
||||
benchmarker.report(flops as f64, bandwidth as f64);
|
||||
}
|
||||
println!(" TTFT: {:.2} ms", fwd_durations[0].as_secs_f64() * 1e3);
|
||||
println!(
|
||||
" TPOT: {:.2} ms",
|
||||
(fwd_durations.iter().sum::<Duration>() / fwd_durations.len() as u32).as_secs_f64() * 1e3
|
||||
);
|
||||
runtime.print_execution_stats();
|
||||
// Dump cuda trace to timeline
|
||||
trace_session.stop();
|
||||
if let Some(path) = trace_session.perfetto_path {
|
||||
|
||||
Reference in New Issue
Block a user