changed cuda constants

This commit is contained in:
Joe Fioti
2026-01-04 02:24:49 +00:00
parent 2bcc5d54a6
commit d6321f3f6a
2 changed files with 23 additions and 11 deletions

View File

@@ -666,6 +666,10 @@ impl Runtime for CudaRuntime {
work_queue,
..
} => {
let sm_count = self
.cuda_context
.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT)
.unwrap();
let span = span!(Level::INFO, "megakernel_setup");
let _entered = span.enter();
// Upload queue, barriers and program counter
@@ -676,10 +680,15 @@ impl Runtime for CudaRuntime {
let d_tasks = self.cuda_stream.memcpy_stod(work_queue).unwrap();
let d_head = self.cuda_stream.memcpy_stod(&[0i32]).unwrap();
let queue_lock = self.cuda_stream.memcpy_stod(&[0i32]).unwrap();
// Set up timing buffer (start_time_u64,[[event_start_u64,event_type_i32 for sm_event in sm[:1000] for sm in sms[:114]])
let timing_buffer =
self.cuda_stream.alloc_zeros::<SMEvent>(114 * 1000).unwrap();
let start_time = self.cuda_stream.alloc_zeros::<u64>(114).unwrap();
// Set up timing buffer (start_time_u64,[[event_start_u64,event_type_i32 for sm_event in sm[:1000] for sm in sms[:sm_count]])
let timing_buffer = self
.cuda_stream
.alloc_zeros::<SMEvent>(sm_count as usize * 1000)
.unwrap();
let start_time = self
.cuda_stream
.alloc_zeros::<u64>(sm_count as usize)
.unwrap();
// Set up dyn dims
for (dyn_dim, val) in dyn_map {
@@ -692,17 +701,21 @@ impl Runtime for CudaRuntime {
}
}
let shared_mem_max = self
.cuda_context
.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR)
.unwrap();
interpreter.set_attribute(
cudarc::driver::sys::CUfunction_attribute::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
100_000,
shared_mem_max / 2, // Half shared mem, half L2
).unwrap();
// Launch kernel
let num_sm = 114;
let cfg = LaunchConfig {
grid_dim: (num_sm, 1, 1), // One block per SM
block_dim: (1024, 1, 1), // 512 threads per block
shared_mem_bytes: 100_000, // 40 kb
grid_dim: (sm_count as u32, 1, 1), // One block per SM
block_dim: (1024, 1, 1), // 1024 threads (32 warps) per block
shared_mem_bytes: (shared_mem_max / 2) as u32,
};
let mut lb = self.cuda_stream.launch_builder(interpreter);
let n_tasks = work_queue.len() as i32;

View File

@@ -8,7 +8,6 @@ pub fn cuda_test() {
let mut cx = Graph::default();
let input = cx.tensor(5);
let output = (input + input).output();
display_graph(&cx.graph, None, "hlir.txt");
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
@@ -16,7 +15,7 @@ pub fn cuda_test() {
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize((ctx, stream, FxHashMap::default()));
rt.set_data(input, Box::new(vec![0., 1., 2., 3., 4.]));
rt = cx.search(rt, 1);
rt = cx.search(rt, 10);
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
let out = rt.get_f32(output);