multiple output buffers

This commit is contained in:
Joe Fioti
2025-02-10 11:40:41 -06:00
parent 5de0e273e6
commit 80602e4278
2 changed files with 116 additions and 109 deletions

View File

@@ -7,4 +7,4 @@ edition = "2021"
[dependencies]
itertools = "0.14.0"
#metal-rs = { version = "0.28.0", package = "metal", features = ["mps"] }
metal-rs = { version = "0.28.0", package = "metal", features = ["mps"] }

View File

@@ -432,115 +432,116 @@ fn main() {
// println!("---");
// Set inputs
let a = (0..12).map(|i| i as f32).collect::<Vec<_>>();
// let a = (0..12).map(|i| i as f32).collect::<Vec<_>>();
// let b = (0..8)
// .flat_map(|i| (0..8).map(move |j| if j == i { 1.0 } else { 0.0 }))
// .collect::<Vec<_>>();
println!("Out: {:?}", run_graph(vec![a], &kernels));
}
fn run_graph(inputs: Vec<Vec<f32>>, kernels: &[Kernel]) -> Vec<f32> {
todo!()
// println!("Out: {:?}", run_graph(vec![a], &kernels));
}
// fn run_graph(inputs: Vec<Vec<f32>>, kernels: &[Kernel]) -> Vec<f32> {
// let device = Device::system_default().unwrap();
// let queue = device.new_command_queue();
// let command_buffer = queue.new_command_buffer();
// // Allocate input buffers
// let mut buffers = inputs
// .iter()
// .map(|buf| {
// device.new_buffer_with_data(
// buf.as_ptr() as *mut _,
// (buf.len() * std::mem::size_of::<f32>()) as u64,
// MTLResourceOptions::StorageModeShared,
// )
// })
// .collect::<Vec<_>>();
// let n_orig_buffers = buffers.len();
// // Allocate output buffers
// for kernel in kernels {
// assert_eq!(
// kernel.outputs.len(),
// 1,
// "Can't handle more than one kernel output for now"
// );
// buffers.push(device.new_buffer(
// (kernel.outputs[0] * std::mem::size_of::<f32>()) as u64,
// MTLResourceOptions::StorageModeShared,
// ));
// }
// // Queue up kernels
// for (
// n_kernel,
// Kernel {
// code,
// grid,
// threadblock,
// inputs,
// ..
// },
// ) in kernels.iter().enumerate()
// {
// let encoder =
// command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
// // Compile kernel
// let options = CompileOptions::new();
// options.set_fast_math_enabled(true);
// let lib = device.new_library_with_source(code, &options).unwrap();
// let pipeline_state_descriptor = ComputePipelineDescriptor::new();
// pipeline_state_descriptor.set_compute_function(Some(
// &lib.get_function(&format!("kernel{n_kernel}"), None)
// .unwrap(),
// ));
// let pipeline = device
// .new_compute_pipeline_state_with_function(
// pipeline_state_descriptor.compute_function().unwrap(),
// )
// .unwrap();
// encoder.set_compute_pipeline_state(&pipeline);
// // Set inputs
// for (i, input) in inputs.iter().enumerate() {
// encoder.set_buffer(i as u64, Some(&buffers[*input]), 0);
// }
// // Set output
// encoder.set_buffer(
// inputs.len() as u64,
// Some(&buffers[n_kernel + n_orig_buffers]),
// 0,
// );
// // Set dispatch
// encoder.dispatch_thread_groups(
// MTLSize::new(grid.0 as u64, grid.1 as u64, grid.2 as u64),
// MTLSize::new(
// threadblock.0 as u64,
// threadblock.1 as u64,
// threadblock.2 as u64,
// ),
// );
// encoder.end_encoding();
// }
// // Run
// command_buffer.commit();
// command_buffer.wait_until_completed();
// // Copy back last buffer
// let buffer = buffers.last().unwrap();
// let mut data = vec![0.0; buffer.length() as usize / std::mem::size_of::<f32>()];
// let ptr = buffer.contents() as *mut f32;
// for (i, d) in data.iter_mut().enumerate() {
// *d = unsafe { *ptr.add(i) };
// }
// data
// todo!()
// }
fn run_graph(inputs: Vec<Vec<f32>>, kernels: &[Kernel]) -> Vec<f32> {
let device = Device::system_default().unwrap();
let queue = device.new_command_queue();
let command_buffer = queue.new_command_buffer();
// Allocate input buffers
let mut buffers = inputs
.iter()
.map(|buf| {
device.new_buffer_with_data(
buf.as_ptr() as *mut _,
(buf.len() * std::mem::size_of::<f32>()) as u64,
MTLResourceOptions::StorageModeShared,
)
})
.collect::<Vec<_>>();
let n_orig_buffers = buffers.len();
// Allocate output buffers
for kernel in kernels {
for output in kernel.outputs {
buffers.push(device.new_buffer(
(output * std::mem::size_of::<f32>()) as u64,
MTLResourceOptions::StorageModeShared,
));
}
}
// Queue up kernels
let mut output_kernel_index = 0;
for (
n_kernel,
Kernel {
code,
grid,
threadblock,
inputs,
outputs,
},
) in kernels.iter().enumerate()
{
let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
// Compile kernel
let options = CompileOptions::new();
options.set_fast_math_enabled(true);
let lib = device.new_library_with_source(code, &options).unwrap();
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(
&lib.get_function(&format!("kernel{n_kernel}"), None)
.unwrap(),
));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
encoder.set_compute_pipeline_state(&pipeline);
// Set inputs
for (i, input) in inputs.iter().enumerate() {
encoder.set_buffer(i as u64, Some(&buffers[*input]), 0);
}
// Set output
for n in 0..outputs.len() {
encoder.set_buffer(
(inputs.len() + n) as u64,
Some(&buffers[output_kernel_index + n_orig_buffers]),
0,
);
output_kernel_index += 1;
}
// Set dispatch
encoder.dispatch_thread_groups(
MTLSize::new(grid.0 as u64, grid.1 as u64, grid.2 as u64),
MTLSize::new(
threadblock.0 as u64,
threadblock.1 as u64,
threadblock.2 as u64,
),
);
encoder.end_encoding();
}
// Run
command_buffer.commit();
command_buffer.wait_until_completed();
// Copy back last buffer
let buffer = buffers.last().unwrap();
let mut data = vec![0.0; buffer.length() as usize / std::mem::size_of::<f32>()];
let ptr = buffer.contents() as *mut f32;
for (i, d) in data.iter_mut().enumerate() {
*d = unsafe { *ptr.add(i) };
}
data
}
#[derive(Clone, Debug, Default)]
struct Stack {
inputs: Vec<Input>,
@@ -567,7 +568,8 @@ struct Kernel {
outputs: Vec<usize>, // buffer sizes this kernel creates
}
fn create_kernels(ir: Vec<Stack>) -> Vec<Kernel> {
fn create_kernels(mut ir: Vec<Stack>) -> Vec<Kernel> {
ir.last_mut().unwrap().kernel_output = true;
// Merge the stacks as much as possible
let mut loop_dim = 0;
let mut merged_ir = ir
@@ -772,7 +774,6 @@ fn create_kernels(ir: Vec<Stack>) -> Vec<Kernel> {
}
logical_index_start_kernel += instructions.len();
}
println!("{:?}", merged_ir);
for (n_kernel, (stack, instructions)) in merged_ir.into_iter().enumerate() {
// Compute grid and threadblock dim assignments
let exec_dims = stack
@@ -810,12 +811,18 @@ fn create_kernels(ir: Vec<Stack>) -> Vec<Kernel> {
.collect::<Vec<_>>();
input_buffer_indexes.sort();
input_buffer_indexes.dedup();
let output_buffer_size = stack
.frames
let output_buffer_sizes = instructions
.iter()
.filter(|i| !i.reduce)
.map(|i| i.size)
.product::<usize>();
.filter(|s| s.kernel_output)
.map(|stack| {
stack
.frames
.iter()
.filter(|i| !i.reduce)
.map(|i| i.size)
.product::<usize>()
})
.collect::<Vec<_>>();
// Write kernels
let mut kernel = "".to_string();
@@ -909,7 +916,7 @@ kernel void kernel{n_kernel}(
grid: (grid[0], grid[1], grid[2]),
threadblock: (threadblock[0], threadblock[1], threadblock[2]),
inputs: input_buffer_indexes,
outputs: vec![output_buffer_size],
outputs: output_buffer_sizes,
});
}