mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
fmt and clippy
This commit is contained in:
@@ -8,7 +8,9 @@ pub use ops::*;
|
||||
pub use to_kernel::block_to_kernel;
|
||||
|
||||
use cudarc::{
|
||||
driver::{CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, DeviceRepr, ValidAsZeroBits},
|
||||
driver::{
|
||||
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, DeviceRepr, ValidAsZeroBits,
|
||||
},
|
||||
nvrtc::{CompileOptions, compile_ptx_with_opts},
|
||||
};
|
||||
use luminal::{
|
||||
@@ -405,7 +407,6 @@ impl TaskQueue {
|
||||
// Pad to task_stride
|
||||
bytes.resize(self.task_stride, 0);
|
||||
|
||||
|
||||
self.data.extend_from_slice(&bytes);
|
||||
self.num_tasks += 1;
|
||||
}
|
||||
@@ -693,6 +694,7 @@ pub fn record_block_op_timings(
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn compile_interpreter(
|
||||
cuda_stream: &Arc<CudaStream>,
|
||||
ops: &Vec<Arc<Box<dyn BlockOp>>>,
|
||||
@@ -902,7 +904,6 @@ fn compile_interpreter(
|
||||
let (module, func) = if let Some((module, kernel)) = kernel_cache.get(&kernel) {
|
||||
(module.clone(), kernel.clone())
|
||||
} else {
|
||||
|
||||
let _span = span!(Level::TRACE, "nvrtc").entered();
|
||||
let ptx = compile_ptx_with_opts(
|
||||
&kernel,
|
||||
@@ -955,7 +956,6 @@ pub struct MegakernelOp {
|
||||
pub sm_count: i32,
|
||||
}
|
||||
|
||||
|
||||
impl crate::kernel::KernelOp for MegakernelOp {
|
||||
fn compile(
|
||||
&self,
|
||||
@@ -977,8 +977,8 @@ impl crate::kernel::KernelOp for MegakernelOp {
|
||||
"megakernel".to_string(),
|
||||
(self.sm_count.into(), 1.into(), 1.into()), // grid: one block per SM
|
||||
(256.into(), 1.into(), 1.into()), // block: 256 threads
|
||||
0.into(), // No dynamic shared memory (static scratchpad is sufficient)
|
||||
self.interpreter_constants.clone(), // Return constants for runtime to manage
|
||||
0.into(), // No dynamic shared memory (static scratchpad is sufficient)
|
||||
self.interpreter_constants.clone(), // Return constants for runtime to manage
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1005,17 +1005,31 @@ impl crate::kernel::KernelOp for MegakernelOp {
|
||||
// 0: tasks - upload task queue
|
||||
stream.clone_htod(self.work_queue.as_slice()).unwrap(),
|
||||
// 1: head - reset in-kernel
|
||||
stream.alloc_zeros::<u8>(std::mem::size_of::<i32>()).unwrap(),
|
||||
stream
|
||||
.alloc_zeros::<u8>(std::mem::size_of::<i32>())
|
||||
.unwrap(),
|
||||
// 2: ready - barrier array, reset in-kernel
|
||||
stream.alloc_zeros::<u8>(n_barriers * std::mem::size_of::<i32>()).unwrap(),
|
||||
stream
|
||||
.alloc_zeros::<u8>(n_barriers * std::mem::size_of::<i32>())
|
||||
.unwrap(),
|
||||
// 3: queue_lock - reset in-kernel
|
||||
stream.alloc_zeros::<u8>(std::mem::size_of::<i32>()).unwrap(),
|
||||
stream
|
||||
.alloc_zeros::<u8>(std::mem::size_of::<i32>())
|
||||
.unwrap(),
|
||||
// 4: timings - per-SM timing events
|
||||
stream.alloc_zeros::<u8>(self.sm_count as usize * N_TIMING_SLOTS * std::mem::size_of::<SMEvent>()).unwrap(),
|
||||
stream
|
||||
.alloc_zeros::<u8>(
|
||||
self.sm_count as usize * N_TIMING_SLOTS * std::mem::size_of::<SMEvent>(),
|
||||
)
|
||||
.unwrap(),
|
||||
// 5: start_times - per-SM start times
|
||||
stream.alloc_zeros::<u8>(self.sm_count as usize * std::mem::size_of::<u64>()).unwrap(),
|
||||
stream
|
||||
.alloc_zeros::<u8>(self.sm_count as usize * std::mem::size_of::<u64>())
|
||||
.unwrap(),
|
||||
// 6: buffers - array of buffer pointers
|
||||
stream.alloc_zeros::<u8>(buffer_count * std::mem::size_of::<u64>()).unwrap(),
|
||||
stream
|
||||
.alloc_zeros::<u8>(buffer_count * std::mem::size_of::<u64>())
|
||||
.unwrap(),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -1029,7 +1043,8 @@ impl crate::kernel::KernelOp for MegakernelOp {
|
||||
) -> Vec<u64> {
|
||||
// Megakernel params: [tasks, head, ready, queue_lock, timings, start_times, buffers, dyn_dims]
|
||||
// dyn_dims is handled via constants, pass 0
|
||||
internal_bufs.iter()
|
||||
internal_bufs
|
||||
.iter()
|
||||
.map(|buf| buf.device_ptr(stream).0)
|
||||
.chain(std::iter::once(0u64)) // dyn_dims placeholder
|
||||
.collect()
|
||||
@@ -1051,7 +1066,9 @@ impl crate::kernel::KernelOp for MegakernelOp {
|
||||
if let Ok(mut global) = self.module.get_global(&global_name, stream) {
|
||||
let mut view = global.as_view_mut();
|
||||
let mut symbol = unsafe { view.transmute_mut::<i32>(1).unwrap() };
|
||||
stream.memcpy_htod(&[*val as i32], &mut symbol).expect("Failed to update dyn dim constant");
|
||||
stream
|
||||
.memcpy_htod(&[*val as i32], &mut symbol)
|
||||
.expect("Failed to update dyn dim constant");
|
||||
// IMPORTANT: Don't drop `global` - it would try to free __constant__ memory!
|
||||
// Leak it intentionally to prevent the Drop from running.
|
||||
std::mem::forget(global);
|
||||
@@ -1061,14 +1078,17 @@ impl crate::kernel::KernelOp for MegakernelOp {
|
||||
// Re-upload tasks with remaining=-1 (index 0)
|
||||
// This ensures fresh task state for each execution
|
||||
let task_data = self.work_queue.as_slice();
|
||||
stream.memcpy_htod(task_data, &mut internal_bufs[0].as_view_mut())
|
||||
stream
|
||||
.memcpy_htod(task_data, &mut internal_bufs[0].as_view_mut())
|
||||
.expect("Failed to re-upload tasks");
|
||||
|
||||
// Reset head to 0 (index 1)
|
||||
{
|
||||
let mut head_view = internal_bufs[1].as_view_mut();
|
||||
let mut head_typed = unsafe { head_view.transmute_mut::<i32>(1).unwrap() };
|
||||
stream.memcpy_htod(&[0i32], &mut head_typed).expect("Failed to reset head");
|
||||
stream
|
||||
.memcpy_htod(&[0i32], &mut head_typed)
|
||||
.expect("Failed to reset head");
|
||||
}
|
||||
|
||||
// Reset barriers to 0 (index 2)
|
||||
@@ -1078,15 +1098,23 @@ impl crate::kernel::KernelOp for MegakernelOp {
|
||||
{
|
||||
let zeros: Vec<i32> = vec![0; allocated_n_barriers];
|
||||
let mut ready_view = internal_bufs[2].as_view_mut();
|
||||
let mut ready_typed = unsafe { ready_view.transmute_mut::<i32>(allocated_n_barriers).unwrap() };
|
||||
stream.memcpy_htod(&zeros, &mut ready_typed).expect("Failed to reset barriers");
|
||||
let mut ready_typed = unsafe {
|
||||
ready_view
|
||||
.transmute_mut::<i32>(allocated_n_barriers)
|
||||
.unwrap()
|
||||
};
|
||||
stream
|
||||
.memcpy_htod(&zeros, &mut ready_typed)
|
||||
.expect("Failed to reset barriers");
|
||||
}
|
||||
|
||||
// Reset queue_lock to 0 (index 3)
|
||||
{
|
||||
let mut lock_view = internal_bufs[3].as_view_mut();
|
||||
let mut lock_typed = unsafe { lock_view.transmute_mut::<i32>(1).unwrap() };
|
||||
stream.memcpy_htod(&[0i32], &mut lock_typed).expect("Failed to reset queue_lock");
|
||||
stream
|
||||
.memcpy_htod(&[0i32], &mut lock_typed)
|
||||
.expect("Failed to reset queue_lock");
|
||||
}
|
||||
|
||||
// Update buffer array (index 6)
|
||||
@@ -1099,11 +1127,19 @@ impl crate::kernel::KernelOp for MegakernelOp {
|
||||
}
|
||||
|
||||
let mut buffers_view = internal_bufs[6].as_view_mut();
|
||||
let mut buffers_typed = unsafe { buffers_view.transmute_mut::<u64>(buffer_count).expect("Failed to transmute buffers") };
|
||||
stream.memcpy_htod(&buffer_array, &mut buffers_typed).expect("Failed to update buffer array");
|
||||
let mut buffers_typed = unsafe {
|
||||
buffers_view
|
||||
.transmute_mut::<u64>(buffer_count)
|
||||
.expect("Failed to transmute buffers")
|
||||
};
|
||||
stream
|
||||
.memcpy_htod(&buffer_array, &mut buffers_typed)
|
||||
.expect("Failed to update buffer array");
|
||||
|
||||
// Ensure all uploads complete before kernel execution
|
||||
stream.synchronize().expect("Failed to sync after pre_execute");
|
||||
stream
|
||||
.synchronize()
|
||||
.expect("Failed to sync after pre_execute");
|
||||
}
|
||||
|
||||
fn timing_buffer_indices(&self) -> Option<(usize, usize, usize)> {
|
||||
@@ -1125,13 +1161,22 @@ impl MegakernelOp {
|
||||
cuda_stream: &Arc<CudaStream>,
|
||||
kernel_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> Self {
|
||||
let (interpreter, module, interpreter_constants, n_barriers, work_queue, _node_to_task_index, node_to_buffer_index) =
|
||||
make_megakernel_from_llir_graph(llir_graph, subgraph, cuda_stream, kernel_cache);
|
||||
let (
|
||||
interpreter,
|
||||
module,
|
||||
interpreter_constants,
|
||||
n_barriers,
|
||||
work_queue,
|
||||
_node_to_task_index,
|
||||
node_to_buffer_index,
|
||||
) = make_megakernel_from_llir_graph(llir_graph, subgraph, cuda_stream, kernel_cache);
|
||||
|
||||
// Get device properties
|
||||
let ctx = cuda_stream.context();
|
||||
let sm_count = ctx
|
||||
.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT)
|
||||
.attribute(
|
||||
cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
|
||||
)
|
||||
.expect("Failed to get SM count");
|
||||
|
||||
Self {
|
||||
@@ -1147,7 +1192,11 @@ impl MegakernelOp {
|
||||
|
||||
/// Returns the number of buffers this megakernel uses.
|
||||
pub fn buffer_count(&self) -> usize {
|
||||
self.node_to_buffer_index.values().map(|&i| i + 1).max().unwrap_or(0) as usize
|
||||
self.node_to_buffer_index
|
||||
.values()
|
||||
.map(|&i| i + 1)
|
||||
.max()
|
||||
.unwrap_or(0) as usize
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1167,7 +1216,6 @@ impl Drop for MegakernelOp {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub(crate) fn make_megakernel_from_llir_graph(
|
||||
llir_graph: &LLIRGraph,
|
||||
@@ -1176,12 +1224,12 @@ pub(crate) fn make_megakernel_from_llir_graph(
|
||||
kernel_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>, // Module (needed for device globals)
|
||||
Arc<CudaModule>, // Module (needed for device globals)
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
Expression,
|
||||
TaskQueue,
|
||||
FxHashMap<NodeIndex, usize>,
|
||||
FxHashMap<NodeIndex, i32>, // node_to_buffer_index mapping
|
||||
FxHashMap<NodeIndex, i32>, // node_to_buffer_index mapping
|
||||
) {
|
||||
let block_ops = llir_graph
|
||||
.node_indices()
|
||||
|
||||
@@ -6,7 +6,10 @@ use cudarc::driver::{CudaFunction, CudaModule, CudaStream};
|
||||
use luminal::{
|
||||
graph::LLIRGraph,
|
||||
op::LLIROp,
|
||||
prelude::{FxHashMap, FxHashSet, NodeIndex, petgraph::{Direction, visit::EdgeRef}},
|
||||
prelude::{
|
||||
FxHashMap, FxHashSet, NodeIndex,
|
||||
petgraph::{Direction, visit::EdgeRef},
|
||||
},
|
||||
};
|
||||
use tracing::{Level, span};
|
||||
|
||||
|
||||
@@ -265,6 +265,7 @@ pub struct MegakernelParams {
|
||||
impl MegakernelParams {
|
||||
/// Create megakernel params with all internal buffer pointers and dyn_dims.
|
||||
/// Order: tasks, head, ready, queue_lock, timings, start_times, buffers, dyn_dims
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
tasks_ptr: u64,
|
||||
head_ptr: u64,
|
||||
@@ -284,7 +285,8 @@ impl MegakernelParams {
|
||||
start_times_ptr,
|
||||
buffers_ptr,
|
||||
dyn_dims_ptr,
|
||||
].into_boxed_slice();
|
||||
]
|
||||
.into_boxed_slice();
|
||||
let ptrs: Box<[*mut c_void]> = values
|
||||
.iter()
|
||||
.map(|v| v as *const u64 as *mut c_void)
|
||||
|
||||
@@ -564,10 +564,8 @@ impl CudaGraphOp {
|
||||
prev_graph_node = Some(graph_node);
|
||||
}
|
||||
|
||||
if tracing_enabled {
|
||||
if let Some(prev) = prev_graph_node {
|
||||
graph.add_event_record_node(&[prev], state.timing_events[num_kernels])?;
|
||||
}
|
||||
if tracing_enabled && let Some(prev) = prev_graph_node {
|
||||
graph.add_event_record_node(&[prev], state.timing_events[num_kernels])?;
|
||||
}
|
||||
|
||||
let exec = graph.instantiate()?;
|
||||
@@ -724,7 +722,13 @@ pub fn kernel_to_host(
|
||||
// Create CudaGraphOp with RefCell for interior mutability
|
||||
let state = CudaGraphOpState::new(kernels);
|
||||
|
||||
let cuda_graph_op = CudaGraphOp::new(buffer_nodes, all_buffer_sizes, dyn_dims_order, cuda_stream.clone(), state);
|
||||
let cuda_graph_op = CudaGraphOp::new(
|
||||
buffer_nodes,
|
||||
all_buffer_sizes,
|
||||
dyn_dims_order,
|
||||
cuda_stream.clone(),
|
||||
state,
|
||||
);
|
||||
|
||||
// Add CudaGraphOp to llir_graph as a HostOp
|
||||
let cuda_graph_node =
|
||||
@@ -789,10 +793,10 @@ pub fn kernel_to_host(
|
||||
continue; // Same subgraph
|
||||
}
|
||||
// Check if consumer is a kernel in another CudaGraphOp
|
||||
if let Some(&consumer_cuda_graph) = kernel_to_cuda_graph.get(&consumer) {
|
||||
if consumer_cuda_graph != *cuda_graph_node {
|
||||
edges_to_add.push((*cuda_graph_node, consumer_cuda_graph));
|
||||
}
|
||||
if let Some(&consumer_cuda_graph) = kernel_to_cuda_graph.get(&consumer)
|
||||
&& consumer_cuda_graph != *cuda_graph_node
|
||||
{
|
||||
edges_to_add.push((*cuda_graph_node, consumer_cuda_graph));
|
||||
}
|
||||
// Also add edges to HostOps (like cuBLAS ops) that consume our outputs
|
||||
if llir_graph[consumer]
|
||||
|
||||
@@ -22,7 +22,14 @@ use luminal_tracing::PerfettoGuard;
|
||||
use memmap2::MmapOptions;
|
||||
use prost::Message;
|
||||
use safetensors::SafeTensors;
|
||||
use std::{collections::VecDeque, fmt::Debug, fs::File, mem::size_of, sync::Arc, time::Duration};
|
||||
use std::{
|
||||
collections::{VecDeque, hash_map::Entry},
|
||||
fmt::Debug,
|
||||
fs::File,
|
||||
mem::size_of,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use tracing::{Level, enabled, span, trace};
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -677,22 +684,22 @@ impl Runtime for CudaRuntime {
|
||||
for inp in exec_op.inputs.iter() {
|
||||
if let Some(buf) = self.buffers.get(inp) {
|
||||
buffer_map.insert(*inp, buf);
|
||||
} else if let Some(hlir_node) = self.llir_to_hlir.get(inp) {
|
||||
if let Some(CudaInput::Buffer(buf)) = self.hlir_buffers.get(hlir_node) {
|
||||
buffer_map.insert(*inp, buf);
|
||||
}
|
||||
} else if let Some(hlir_node) = self.llir_to_hlir.get(inp)
|
||||
&& let Some(CudaInput::Buffer(buf)) = self.hlir_buffers.get(hlir_node)
|
||||
{
|
||||
buffer_map.insert(*inp, buf);
|
||||
}
|
||||
}
|
||||
// Add extra buffer nodes (for CudaGraphOp)
|
||||
let extra_nodes = exec_op.internal.extra_buffer_nodes();
|
||||
for extra_node in extra_nodes {
|
||||
if !buffer_map.contains_key(&extra_node) {
|
||||
if let Entry::Vacant(e) = buffer_map.entry(extra_node) {
|
||||
if let Some(buf) = self.buffers.get(&extra_node) {
|
||||
buffer_map.insert(extra_node, buf);
|
||||
} else if let Some(hlir_node) = self.llir_to_hlir.get(&extra_node) {
|
||||
if let Some(CudaInput::Buffer(buf)) = self.hlir_buffers.get(hlir_node) {
|
||||
buffer_map.insert(extra_node, buf);
|
||||
}
|
||||
e.insert(buf);
|
||||
} else if let Some(hlir_node) = self.llir_to_hlir.get(&extra_node)
|
||||
&& let Some(CudaInput::Buffer(buf)) = self.hlir_buffers.get(hlir_node)
|
||||
{
|
||||
e.insert(buf);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,11 +142,7 @@ impl Gemma {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
token_ids: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> GraphTensor {
|
||||
pub fn forward(&self, token_ids: GraphTensor, kv_cache: &KVCache) -> GraphTensor {
|
||||
let batch = token_ids.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
@@ -184,12 +180,7 @@ struct GemmaLayer {
|
||||
}
|
||||
|
||||
impl GemmaLayer {
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: GraphTensor,
|
||||
k_cache: u64,
|
||||
v_cache: u64,
|
||||
) -> GraphTensor {
|
||||
pub fn forward(&self, x: GraphTensor, k_cache: u64, v_cache: u64) -> GraphTensor {
|
||||
let x_attn = self.input_layernorm.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
|
||||
Reference in New Issue
Block a user