Apply rustfmt formatting and fix clippy lints

This commit is contained in:
Joe Fioti
2026-02-07 22:43:45 +00:00
parent b7196102f9
commit 9e3e038e74
6 changed files with 117 additions and 62 deletions

View File

@@ -8,7 +8,9 @@ pub use ops::*;
pub use to_kernel::block_to_kernel;
use cudarc::{
driver::{CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, DeviceRepr, ValidAsZeroBits},
driver::{
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, DeviceRepr, ValidAsZeroBits,
},
nvrtc::{CompileOptions, compile_ptx_with_opts},
};
use luminal::{
@@ -405,7 +407,6 @@ impl TaskQueue {
// Pad to task_stride
bytes.resize(self.task_stride, 0);
self.data.extend_from_slice(&bytes);
self.num_tasks += 1;
}
@@ -693,6 +694,7 @@ pub fn record_block_op_timings(
}
#[tracing::instrument(skip_all)]
#[allow(clippy::type_complexity)]
fn compile_interpreter(
cuda_stream: &Arc<CudaStream>,
ops: &Vec<Arc<Box<dyn BlockOp>>>,
@@ -902,7 +904,6 @@ fn compile_interpreter(
let (module, func) = if let Some((module, kernel)) = kernel_cache.get(&kernel) {
(module.clone(), kernel.clone())
} else {
let _span = span!(Level::TRACE, "nvrtc").entered();
let ptx = compile_ptx_with_opts(
&kernel,
@@ -955,7 +956,6 @@ pub struct MegakernelOp {
pub sm_count: i32,
}
impl crate::kernel::KernelOp for MegakernelOp {
fn compile(
&self,
@@ -977,8 +977,8 @@ impl crate::kernel::KernelOp for MegakernelOp {
"megakernel".to_string(),
(self.sm_count.into(), 1.into(), 1.into()), // grid: one block per SM
(256.into(), 1.into(), 1.into()), // block: 256 threads
0.into(), // No dynamic shared memory (static scratchpad is sufficient)
self.interpreter_constants.clone(), // Return constants for runtime to manage
0.into(), // No dynamic shared memory (static scratchpad is sufficient)
self.interpreter_constants.clone(), // Return constants for runtime to manage
)
}
@@ -1005,17 +1005,31 @@ impl crate::kernel::KernelOp for MegakernelOp {
// 0: tasks - upload task queue
stream.clone_htod(self.work_queue.as_slice()).unwrap(),
// 1: head - reset in-kernel
stream.alloc_zeros::<u8>(std::mem::size_of::<i32>()).unwrap(),
stream
.alloc_zeros::<u8>(std::mem::size_of::<i32>())
.unwrap(),
// 2: ready - barrier array, reset in-kernel
stream.alloc_zeros::<u8>(n_barriers * std::mem::size_of::<i32>()).unwrap(),
stream
.alloc_zeros::<u8>(n_barriers * std::mem::size_of::<i32>())
.unwrap(),
// 3: queue_lock - reset in-kernel
stream.alloc_zeros::<u8>(std::mem::size_of::<i32>()).unwrap(),
stream
.alloc_zeros::<u8>(std::mem::size_of::<i32>())
.unwrap(),
// 4: timings - per-SM timing events
stream.alloc_zeros::<u8>(self.sm_count as usize * N_TIMING_SLOTS * std::mem::size_of::<SMEvent>()).unwrap(),
stream
.alloc_zeros::<u8>(
self.sm_count as usize * N_TIMING_SLOTS * std::mem::size_of::<SMEvent>(),
)
.unwrap(),
// 5: start_times - per-SM start times
stream.alloc_zeros::<u8>(self.sm_count as usize * std::mem::size_of::<u64>()).unwrap(),
stream
.alloc_zeros::<u8>(self.sm_count as usize * std::mem::size_of::<u64>())
.unwrap(),
// 6: buffers - array of buffer pointers
stream.alloc_zeros::<u8>(buffer_count * std::mem::size_of::<u64>()).unwrap(),
stream
.alloc_zeros::<u8>(buffer_count * std::mem::size_of::<u64>())
.unwrap(),
]
}
@@ -1029,7 +1043,8 @@ impl crate::kernel::KernelOp for MegakernelOp {
) -> Vec<u64> {
// Megakernel params: [tasks, head, ready, queue_lock, timings, start_times, buffers, dyn_dims]
// dyn_dims is handled via constants, pass 0
internal_bufs.iter()
internal_bufs
.iter()
.map(|buf| buf.device_ptr(stream).0)
.chain(std::iter::once(0u64)) // dyn_dims placeholder
.collect()
@@ -1051,7 +1066,9 @@ impl crate::kernel::KernelOp for MegakernelOp {
if let Ok(mut global) = self.module.get_global(&global_name, stream) {
let mut view = global.as_view_mut();
let mut symbol = unsafe { view.transmute_mut::<i32>(1).unwrap() };
stream.memcpy_htod(&[*val as i32], &mut symbol).expect("Failed to update dyn dim constant");
stream
.memcpy_htod(&[*val as i32], &mut symbol)
.expect("Failed to update dyn dim constant");
// IMPORTANT: Don't drop `global` - it would try to free __constant__ memory!
// Leak it intentionally to prevent the Drop from running.
std::mem::forget(global);
@@ -1061,14 +1078,17 @@ impl crate::kernel::KernelOp for MegakernelOp {
// Re-upload tasks with remaining=-1 (index 0)
// This ensures fresh task state for each execution
let task_data = self.work_queue.as_slice();
stream.memcpy_htod(task_data, &mut internal_bufs[0].as_view_mut())
stream
.memcpy_htod(task_data, &mut internal_bufs[0].as_view_mut())
.expect("Failed to re-upload tasks");
// Reset head to 0 (index 1)
{
let mut head_view = internal_bufs[1].as_view_mut();
let mut head_typed = unsafe { head_view.transmute_mut::<i32>(1).unwrap() };
stream.memcpy_htod(&[0i32], &mut head_typed).expect("Failed to reset head");
stream
.memcpy_htod(&[0i32], &mut head_typed)
.expect("Failed to reset head");
}
// Reset barriers to 0 (index 2)
@@ -1078,15 +1098,23 @@ impl crate::kernel::KernelOp for MegakernelOp {
{
let zeros: Vec<i32> = vec![0; allocated_n_barriers];
let mut ready_view = internal_bufs[2].as_view_mut();
let mut ready_typed = unsafe { ready_view.transmute_mut::<i32>(allocated_n_barriers).unwrap() };
stream.memcpy_htod(&zeros, &mut ready_typed).expect("Failed to reset barriers");
let mut ready_typed = unsafe {
ready_view
.transmute_mut::<i32>(allocated_n_barriers)
.unwrap()
};
stream
.memcpy_htod(&zeros, &mut ready_typed)
.expect("Failed to reset barriers");
}
// Reset queue_lock to 0 (index 3)
{
let mut lock_view = internal_bufs[3].as_view_mut();
let mut lock_typed = unsafe { lock_view.transmute_mut::<i32>(1).unwrap() };
stream.memcpy_htod(&[0i32], &mut lock_typed).expect("Failed to reset queue_lock");
stream
.memcpy_htod(&[0i32], &mut lock_typed)
.expect("Failed to reset queue_lock");
}
// Update buffer array (index 6)
@@ -1099,11 +1127,19 @@ impl crate::kernel::KernelOp for MegakernelOp {
}
let mut buffers_view = internal_bufs[6].as_view_mut();
let mut buffers_typed = unsafe { buffers_view.transmute_mut::<u64>(buffer_count).expect("Failed to transmute buffers") };
stream.memcpy_htod(&buffer_array, &mut buffers_typed).expect("Failed to update buffer array");
let mut buffers_typed = unsafe {
buffers_view
.transmute_mut::<u64>(buffer_count)
.expect("Failed to transmute buffers")
};
stream
.memcpy_htod(&buffer_array, &mut buffers_typed)
.expect("Failed to update buffer array");
// Ensure all uploads complete before kernel execution
stream.synchronize().expect("Failed to sync after pre_execute");
stream
.synchronize()
.expect("Failed to sync after pre_execute");
}
fn timing_buffer_indices(&self) -> Option<(usize, usize, usize)> {
@@ -1125,13 +1161,22 @@ impl MegakernelOp {
cuda_stream: &Arc<CudaStream>,
kernel_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> Self {
let (interpreter, module, interpreter_constants, n_barriers, work_queue, _node_to_task_index, node_to_buffer_index) =
make_megakernel_from_llir_graph(llir_graph, subgraph, cuda_stream, kernel_cache);
let (
interpreter,
module,
interpreter_constants,
n_barriers,
work_queue,
_node_to_task_index,
node_to_buffer_index,
) = make_megakernel_from_llir_graph(llir_graph, subgraph, cuda_stream, kernel_cache);
// Get device properties
let ctx = cuda_stream.context();
let sm_count = ctx
.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT)
.attribute(
cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
)
.expect("Failed to get SM count");
Self {
@@ -1147,7 +1192,11 @@ impl MegakernelOp {
/// Returns the number of buffers this megakernel uses.
pub fn buffer_count(&self) -> usize {
self.node_to_buffer_index.values().map(|&i| i + 1).max().unwrap_or(0) as usize
self.node_to_buffer_index
.values()
.map(|&i| i + 1)
.max()
.unwrap_or(0) as usize
}
}
@@ -1167,7 +1216,6 @@ impl Drop for MegakernelOp {
}
}
#[allow(clippy::type_complexity)]
pub(crate) fn make_megakernel_from_llir_graph(
llir_graph: &LLIRGraph,
@@ -1176,12 +1224,12 @@ pub(crate) fn make_megakernel_from_llir_graph(
kernel_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>, // Module (needed for device globals)
Arc<CudaModule>, // Module (needed for device globals)
FxHashMap<char, CudaSlice<u8>>,
Expression,
TaskQueue,
FxHashMap<NodeIndex, usize>,
FxHashMap<NodeIndex, i32>, // node_to_buffer_index mapping
FxHashMap<NodeIndex, i32>, // node_to_buffer_index mapping
) {
let block_ops = llir_graph
.node_indices()

View File

@@ -6,7 +6,10 @@ use cudarc::driver::{CudaFunction, CudaModule, CudaStream};
use luminal::{
graph::LLIRGraph,
op::LLIROp,
prelude::{FxHashMap, FxHashSet, NodeIndex, petgraph::{Direction, visit::EdgeRef}},
prelude::{
FxHashMap, FxHashSet, NodeIndex,
petgraph::{Direction, visit::EdgeRef},
},
};
use tracing::{Level, span};

View File

@@ -265,6 +265,7 @@ pub struct MegakernelParams {
impl MegakernelParams {
/// Create megakernel params with all internal buffer pointers and dyn_dims.
/// Order: tasks, head, ready, queue_lock, timings, start_times, buffers, dyn_dims
#[allow(clippy::too_many_arguments)]
pub fn new(
tasks_ptr: u64,
head_ptr: u64,
@@ -284,7 +285,8 @@ impl MegakernelParams {
start_times_ptr,
buffers_ptr,
dyn_dims_ptr,
].into_boxed_slice();
]
.into_boxed_slice();
let ptrs: Box<[*mut c_void]> = values
.iter()
.map(|v| v as *const u64 as *mut c_void)

View File

@@ -564,10 +564,8 @@ impl CudaGraphOp {
prev_graph_node = Some(graph_node);
}
if tracing_enabled {
if let Some(prev) = prev_graph_node {
graph.add_event_record_node(&[prev], state.timing_events[num_kernels])?;
}
if tracing_enabled && let Some(prev) = prev_graph_node {
graph.add_event_record_node(&[prev], state.timing_events[num_kernels])?;
}
let exec = graph.instantiate()?;
@@ -724,7 +722,13 @@ pub fn kernel_to_host(
// Create CudaGraphOp with RefCell for interior mutability
let state = CudaGraphOpState::new(kernels);
let cuda_graph_op = CudaGraphOp::new(buffer_nodes, all_buffer_sizes, dyn_dims_order, cuda_stream.clone(), state);
let cuda_graph_op = CudaGraphOp::new(
buffer_nodes,
all_buffer_sizes,
dyn_dims_order,
cuda_stream.clone(),
state,
);
// Add CudaGraphOp to llir_graph as a HostOp
let cuda_graph_node =
@@ -789,10 +793,10 @@ pub fn kernel_to_host(
continue; // Same subgraph
}
// Check if consumer is a kernel in another CudaGraphOp
if let Some(&consumer_cuda_graph) = kernel_to_cuda_graph.get(&consumer) {
if consumer_cuda_graph != *cuda_graph_node {
edges_to_add.push((*cuda_graph_node, consumer_cuda_graph));
}
if let Some(&consumer_cuda_graph) = kernel_to_cuda_graph.get(&consumer)
&& consumer_cuda_graph != *cuda_graph_node
{
edges_to_add.push((*cuda_graph_node, consumer_cuda_graph));
}
// Also add edges to HostOps (like cuBLAS ops) that consume our outputs
if llir_graph[consumer]

View File

@@ -22,7 +22,14 @@ use luminal_tracing::PerfettoGuard;
use memmap2::MmapOptions;
use prost::Message;
use safetensors::SafeTensors;
use std::{collections::VecDeque, fmt::Debug, fs::File, mem::size_of, sync::Arc, time::Duration};
use std::{
collections::{VecDeque, hash_map::Entry},
fmt::Debug,
fs::File,
mem::size_of,
sync::Arc,
time::Duration,
};
use tracing::{Level, enabled, span, trace};
use uuid::Uuid;
@@ -677,22 +684,22 @@ impl Runtime for CudaRuntime {
for inp in exec_op.inputs.iter() {
if let Some(buf) = self.buffers.get(inp) {
buffer_map.insert(*inp, buf);
} else if let Some(hlir_node) = self.llir_to_hlir.get(inp) {
if let Some(CudaInput::Buffer(buf)) = self.hlir_buffers.get(hlir_node) {
buffer_map.insert(*inp, buf);
}
} else if let Some(hlir_node) = self.llir_to_hlir.get(inp)
&& let Some(CudaInput::Buffer(buf)) = self.hlir_buffers.get(hlir_node)
{
buffer_map.insert(*inp, buf);
}
}
// Add extra buffer nodes (for CudaGraphOp)
let extra_nodes = exec_op.internal.extra_buffer_nodes();
for extra_node in extra_nodes {
if !buffer_map.contains_key(&extra_node) {
if let Entry::Vacant(e) = buffer_map.entry(extra_node) {
if let Some(buf) = self.buffers.get(&extra_node) {
buffer_map.insert(extra_node, buf);
} else if let Some(hlir_node) = self.llir_to_hlir.get(&extra_node) {
if let Some(CudaInput::Buffer(buf)) = self.hlir_buffers.get(hlir_node) {
buffer_map.insert(extra_node, buf);
}
e.insert(buf);
} else if let Some(hlir_node) = self.llir_to_hlir.get(&extra_node)
&& let Some(CudaInput::Buffer(buf)) = self.hlir_buffers.get(hlir_node)
{
e.insert(buf);
}
}
}

View File

@@ -142,11 +142,7 @@ impl Gemma {
}
}
pub fn forward(
&self,
token_ids: GraphTensor,
kv_cache: &KVCache,
) -> GraphTensor {
pub fn forward(&self, token_ids: GraphTensor, kv_cache: &KVCache) -> GraphTensor {
let batch = token_ids.dims1();
let mut x = self.embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
@@ -184,12 +180,7 @@ struct GemmaLayer {
}
impl GemmaLayer {
pub fn forward(
&self,
x: GraphTensor,
k_cache: u64,
v_cache: u64,
) -> GraphTensor {
pub fn forward(&self, x: GraphTensor, k_cache: u64, v_cache: u64) -> GraphTensor {
let x_attn = self.input_layernorm.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k = x_attn.matmul(self.k_proj.t());