mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
multiple output buffers
This commit is contained in:
@@ -7,4 +7,4 @@ edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
itertools = "0.14.0"
|
||||
#metal-rs = { version = "0.28.0", package = "metal", features = ["mps"] }
|
||||
metal-rs = { version = "0.28.0", package = "metal", features = ["mps"] }
|
||||
|
||||
@@ -432,115 +432,116 @@ fn main() {
|
||||
// println!("---");
|
||||
|
||||
// Set inputs
|
||||
let a = (0..12).map(|i| i as f32).collect::<Vec<_>>();
|
||||
// let a = (0..12).map(|i| i as f32).collect::<Vec<_>>();
|
||||
// let b = (0..8)
|
||||
// .flat_map(|i| (0..8).map(move |j| if j == i { 1.0 } else { 0.0 }))
|
||||
// .collect::<Vec<_>>();
|
||||
|
||||
println!("Out: {:?}", run_graph(vec![a], &kernels));
|
||||
}
|
||||
|
||||
fn run_graph(inputs: Vec<Vec<f32>>, kernels: &[Kernel]) -> Vec<f32> {
|
||||
todo!()
|
||||
// println!("Out: {:?}", run_graph(vec![a], &kernels));
|
||||
}
|
||||
|
||||
// fn run_graph(inputs: Vec<Vec<f32>>, kernels: &[Kernel]) -> Vec<f32> {
|
||||
// let device = Device::system_default().unwrap();
|
||||
// let queue = device.new_command_queue();
|
||||
// let command_buffer = queue.new_command_buffer();
|
||||
|
||||
// // Allocate input buffers
|
||||
// let mut buffers = inputs
|
||||
// .iter()
|
||||
// .map(|buf| {
|
||||
// device.new_buffer_with_data(
|
||||
// buf.as_ptr() as *mut _,
|
||||
// (buf.len() * std::mem::size_of::<f32>()) as u64,
|
||||
// MTLResourceOptions::StorageModeShared,
|
||||
// )
|
||||
// })
|
||||
// .collect::<Vec<_>>();
|
||||
// let n_orig_buffers = buffers.len();
|
||||
// // Allocate output buffers
|
||||
// for kernel in kernels {
|
||||
// assert_eq!(
|
||||
// kernel.outputs.len(),
|
||||
// 1,
|
||||
// "Can't handle more than one kernel output for now"
|
||||
// );
|
||||
// buffers.push(device.new_buffer(
|
||||
// (kernel.outputs[0] * std::mem::size_of::<f32>()) as u64,
|
||||
// MTLResourceOptions::StorageModeShared,
|
||||
// ));
|
||||
// }
|
||||
// // Queue up kernels
|
||||
// for (
|
||||
// n_kernel,
|
||||
// Kernel {
|
||||
// code,
|
||||
// grid,
|
||||
// threadblock,
|
||||
// inputs,
|
||||
// ..
|
||||
// },
|
||||
// ) in kernels.iter().enumerate()
|
||||
// {
|
||||
// let encoder =
|
||||
// command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
|
||||
// // Compile kernel
|
||||
// let options = CompileOptions::new();
|
||||
// options.set_fast_math_enabled(true);
|
||||
// let lib = device.new_library_with_source(code, &options).unwrap();
|
||||
// let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
// pipeline_state_descriptor.set_compute_function(Some(
|
||||
// &lib.get_function(&format!("kernel{n_kernel}"), None)
|
||||
// .unwrap(),
|
||||
// ));
|
||||
// let pipeline = device
|
||||
// .new_compute_pipeline_state_with_function(
|
||||
// pipeline_state_descriptor.compute_function().unwrap(),
|
||||
// )
|
||||
// .unwrap();
|
||||
// encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
// // Set inputs
|
||||
// for (i, input) in inputs.iter().enumerate() {
|
||||
// encoder.set_buffer(i as u64, Some(&buffers[*input]), 0);
|
||||
// }
|
||||
// // Set output
|
||||
// encoder.set_buffer(
|
||||
// inputs.len() as u64,
|
||||
// Some(&buffers[n_kernel + n_orig_buffers]),
|
||||
// 0,
|
||||
// );
|
||||
|
||||
// // Set dispatch
|
||||
// encoder.dispatch_thread_groups(
|
||||
// MTLSize::new(grid.0 as u64, grid.1 as u64, grid.2 as u64),
|
||||
// MTLSize::new(
|
||||
// threadblock.0 as u64,
|
||||
// threadblock.1 as u64,
|
||||
// threadblock.2 as u64,
|
||||
// ),
|
||||
// );
|
||||
// encoder.end_encoding();
|
||||
// }
|
||||
|
||||
// // Run
|
||||
// command_buffer.commit();
|
||||
// command_buffer.wait_until_completed();
|
||||
|
||||
// // Copy back last buffer
|
||||
// let buffer = buffers.last().unwrap();
|
||||
// let mut data = vec![0.0; buffer.length() as usize / std::mem::size_of::<f32>()];
|
||||
// let ptr = buffer.contents() as *mut f32;
|
||||
// for (i, d) in data.iter_mut().enumerate() {
|
||||
// *d = unsafe { *ptr.add(i) };
|
||||
// }
|
||||
// data
|
||||
// todo!()
|
||||
// }
|
||||
|
||||
fn run_graph(inputs: Vec<Vec<f32>>, kernels: &[Kernel]) -> Vec<f32> {
|
||||
let device = Device::system_default().unwrap();
|
||||
let queue = device.new_command_queue();
|
||||
let command_buffer = queue.new_command_buffer();
|
||||
|
||||
// Allocate input buffers
|
||||
let mut buffers = inputs
|
||||
.iter()
|
||||
.map(|buf| {
|
||||
device.new_buffer_with_data(
|
||||
buf.as_ptr() as *mut _,
|
||||
(buf.len() * std::mem::size_of::<f32>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let n_orig_buffers = buffers.len();
|
||||
// Allocate output buffers
|
||||
for kernel in kernels {
|
||||
for output in kernel.outputs {
|
||||
buffers.push(device.new_buffer(
|
||||
(output * std::mem::size_of::<f32>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
));
|
||||
}
|
||||
}
|
||||
// Queue up kernels
|
||||
let mut output_kernel_index = 0;
|
||||
for (
|
||||
n_kernel,
|
||||
Kernel {
|
||||
code,
|
||||
grid,
|
||||
threadblock,
|
||||
inputs,
|
||||
outputs,
|
||||
},
|
||||
) in kernels.iter().enumerate()
|
||||
{
|
||||
let encoder =
|
||||
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
|
||||
// Compile kernel
|
||||
let options = CompileOptions::new();
|
||||
options.set_fast_math_enabled(true);
|
||||
let lib = device.new_library_with_source(code, &options).unwrap();
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(
|
||||
&lib.get_function(&format!("kernel{n_kernel}"), None)
|
||||
.unwrap(),
|
||||
));
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
// Set inputs
|
||||
for (i, input) in inputs.iter().enumerate() {
|
||||
encoder.set_buffer(i as u64, Some(&buffers[*input]), 0);
|
||||
}
|
||||
// Set output
|
||||
for n in 0..outputs.len() {
|
||||
encoder.set_buffer(
|
||||
(inputs.len() + n) as u64,
|
||||
Some(&buffers[output_kernel_index + n_orig_buffers]),
|
||||
0,
|
||||
);
|
||||
output_kernel_index += 1;
|
||||
}
|
||||
|
||||
// Set dispatch
|
||||
encoder.dispatch_thread_groups(
|
||||
MTLSize::new(grid.0 as u64, grid.1 as u64, grid.2 as u64),
|
||||
MTLSize::new(
|
||||
threadblock.0 as u64,
|
||||
threadblock.1 as u64,
|
||||
threadblock.2 as u64,
|
||||
),
|
||||
);
|
||||
encoder.end_encoding();
|
||||
}
|
||||
|
||||
// Run
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
// Copy back last buffer
|
||||
let buffer = buffers.last().unwrap();
|
||||
let mut data = vec![0.0; buffer.length() as usize / std::mem::size_of::<f32>()];
|
||||
let ptr = buffer.contents() as *mut f32;
|
||||
for (i, d) in data.iter_mut().enumerate() {
|
||||
*d = unsafe { *ptr.add(i) };
|
||||
}
|
||||
data
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
struct Stack {
|
||||
inputs: Vec<Input>,
|
||||
@@ -567,7 +568,8 @@ struct Kernel {
|
||||
outputs: Vec<usize>, // buffer sizes this kernel creates
|
||||
}
|
||||
|
||||
fn create_kernels(ir: Vec<Stack>) -> Vec<Kernel> {
|
||||
fn create_kernels(mut ir: Vec<Stack>) -> Vec<Kernel> {
|
||||
ir.last_mut().unwrap().kernel_output = true;
|
||||
// Merge the stacks as much as possible
|
||||
let mut loop_dim = 0;
|
||||
let mut merged_ir = ir
|
||||
@@ -772,7 +774,6 @@ fn create_kernels(ir: Vec<Stack>) -> Vec<Kernel> {
|
||||
}
|
||||
logical_index_start_kernel += instructions.len();
|
||||
}
|
||||
println!("{:?}", merged_ir);
|
||||
for (n_kernel, (stack, instructions)) in merged_ir.into_iter().enumerate() {
|
||||
// Compute grid and threadblock dim assignments
|
||||
let exec_dims = stack
|
||||
@@ -810,12 +811,18 @@ fn create_kernels(ir: Vec<Stack>) -> Vec<Kernel> {
|
||||
.collect::<Vec<_>>();
|
||||
input_buffer_indexes.sort();
|
||||
input_buffer_indexes.dedup();
|
||||
let output_buffer_size = stack
|
||||
.frames
|
||||
let output_buffer_sizes = instructions
|
||||
.iter()
|
||||
.filter(|i| !i.reduce)
|
||||
.map(|i| i.size)
|
||||
.product::<usize>();
|
||||
.filter(|s| s.kernel_output)
|
||||
.map(|stack| {
|
||||
stack
|
||||
.frames
|
||||
.iter()
|
||||
.filter(|i| !i.reduce)
|
||||
.map(|i| i.size)
|
||||
.product::<usize>()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Write kernels
|
||||
let mut kernel = "".to_string();
|
||||
@@ -909,7 +916,7 @@ kernel void kernel{n_kernel}(
|
||||
grid: (grid[0], grid[1], grid[2]),
|
||||
threadblock: (threadblock[0], threadblock[1], threadblock[2]),
|
||||
inputs: input_buffer_indexes,
|
||||
outputs: vec![output_buffer_size],
|
||||
outputs: output_buffer_sizes,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user