mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
symbolic inputs / outputs
This commit is contained in:
@@ -26,6 +26,7 @@ pub fn codegen(
|
||||
outputs: Vec<NodeIndex>,
|
||||
mut arch: GPUArch,
|
||||
n_graph: usize,
|
||||
dyn_vars: &FxHashMap<char, usize>,
|
||||
) -> Option<StableGraph<Kernel, (usize, usize), Directed>> {
|
||||
let gmems = graph
|
||||
.node_weights()
|
||||
@@ -199,6 +200,7 @@ pub fn codegen(
|
||||
if !input_comment.is_empty() {
|
||||
input_comment = format!("\t// Inputs\n{input_comment}\n");
|
||||
}
|
||||
let n_inputs_outputs = inputs.len() + outputs.len();
|
||||
let input_string = inputs
|
||||
.into_iter()
|
||||
.map(|(_, a)| a)
|
||||
@@ -210,6 +212,12 @@ pub fn codegen(
|
||||
var_to_char(node_to_var[&a].0)
|
||||
)
|
||||
})
|
||||
.chain(dyn_vars.iter().sorted().enumerate().map(|(i, (c, _))| {
|
||||
format!(
|
||||
"constant uint& const_{c} [[buffer({})]]",
|
||||
i + n_inputs_outputs
|
||||
)
|
||||
}))
|
||||
.join(",\n\t");
|
||||
let (smem_setup, smem_input) = if smem_buffers.is_empty() {
|
||||
("".to_string(), "".to_string())
|
||||
@@ -434,14 +442,11 @@ fn make_kernel(
|
||||
// Make accumulator
|
||||
*prev_max_var += 1;
|
||||
arch.add_metal_buffer_type(*prev_max_var, "thread ");
|
||||
let mut map = FxHashMap::default();
|
||||
map.insert('s', 1);
|
||||
map.insert('p', 0);
|
||||
kernel_lines.push(format!(
|
||||
"{spacing}{}float {}[{}] = {{0.0}};",
|
||||
arch.metal_buffer_type(*prev_max_var),
|
||||
var_to_char(*prev_max_var),
|
||||
size.exec(&map).unwrap()
|
||||
size.to_kernel()
|
||||
));
|
||||
|
||||
// Copy from source to accumulator
|
||||
@@ -451,12 +456,13 @@ fn make_kernel(
|
||||
.unwrap();
|
||||
// Use a single loop with correct striding from the input
|
||||
kernel_lines.push(format!(
|
||||
"{spacing}for (int load = 0; load < {loads}; ++load) {{"
|
||||
"{spacing}for (int load = 0; load < {}; ++load) {{",
|
||||
loads.to_kernel()
|
||||
));
|
||||
let indexing_expression = indexing_expression
|
||||
.simplify()
|
||||
.to_string()
|
||||
.replace("z", "load");
|
||||
.to_kernel()
|
||||
.replace("const_z", "load");
|
||||
kernel_lines.push(format!(
|
||||
"{inner_spacing}{}[{indexing_expression}] = *({} + {indexing_expression});",
|
||||
var_to_char(*prev_max_var),
|
||||
@@ -521,7 +527,7 @@ fn make_kernel(
|
||||
} else {
|
||||
*prev_max_var += 1;
|
||||
let loop_var = var_to_char(*prev_max_var);
|
||||
kernel_lines.push(format!("{spacing}for (int loop_{loop_var} = 0; loop_{loop_var} < {range}; ++loop_{loop_var}) {{"));
|
||||
kernel_lines.push(format!("{spacing}for (int loop_{loop_var} = 0; loop_{loop_var} < {}; ++loop_{loop_var}) {{", range.to_kernel()));
|
||||
};
|
||||
let loop_var = var_to_char(*prev_max_var);
|
||||
let loop_var_int = *prev_max_var;
|
||||
@@ -550,7 +556,9 @@ fn make_kernel(
|
||||
arch.metal_buffer_type(*prev_max_var),
|
||||
var_to_char(*prev_max_var),
|
||||
var_to_char(real_input),
|
||||
stride.to_string().replace('z', &format!("loop_{loop_var}"))
|
||||
stride
|
||||
.to_kernel()
|
||||
.replace("const_z", &format!("loop_{loop_var}"))
|
||||
));
|
||||
node_to_var.insert(*input, (*prev_max_var, is_ptr));
|
||||
}
|
||||
@@ -584,7 +592,9 @@ fn make_kernel(
|
||||
arch.metal_buffer_type(*prev_max_var),
|
||||
var_to_char(*prev_max_var),
|
||||
var_to_char(real_output),
|
||||
stride.to_string().replace('z', &format!("loop_{loop_var}"))
|
||||
stride
|
||||
.to_kernel()
|
||||
.replace("const_z", &format!("loop_{loop_var}"))
|
||||
));
|
||||
new_output_vars.push(*prev_max_var);
|
||||
node_to_var.insert(*output, (*prev_max_var, is_ptr));
|
||||
@@ -687,12 +697,13 @@ fn make_kernel(
|
||||
current_elem_size *= range;
|
||||
}
|
||||
kernel_lines.push(format!(
|
||||
"{spacing}for (int save = 0; save < {size}; ++save) {{"
|
||||
"{spacing}for (int save = 0; save < {}; ++save) {{",
|
||||
size.to_kernel()
|
||||
));
|
||||
let indexing_expression = indexing_expression
|
||||
.simplify()
|
||||
.to_string()
|
||||
.replace("z", "save");
|
||||
.to_kernel()
|
||||
.replace("const_z", "save");
|
||||
kernel_lines.push(format!(
|
||||
"{inner_spacing}{}[{indexing_expression}] = *({} + {indexing_expression});",
|
||||
var_to_char(outer_out),
|
||||
|
||||
@@ -31,6 +31,7 @@ pub fn search(
|
||||
egraph: &EGraph,
|
||||
inputs: &[(NodeIndex, Vec<f32>)],
|
||||
arch: GPUArch,
|
||||
dyn_vars: &FxHashMap<char, usize>,
|
||||
) -> Option<StableGraph<Kernel, (usize, usize)>> {
|
||||
fn recurse<'a>(
|
||||
egraph: &'a EGraph,
|
||||
@@ -115,7 +116,8 @@ pub fn search(
|
||||
// Build termdag
|
||||
let graph = extraction_to_graph(egraph, &trajectory);
|
||||
let root = graph.externals(Direction::Outgoing).next().unwrap();
|
||||
let Some(kernels) = crate::codegen::codegen(graph.clone(), vec![root], arch.clone(), 0)
|
||||
let Some(kernels) =
|
||||
crate::codegen::codegen(graph.clone(), vec![root], arch.clone(), 0, dyn_vars)
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
@@ -13,9 +13,6 @@ use metal_rs::{
|
||||
Buffer, CompileOptions, ComputePassDescriptor, ComputePipelineDescriptor, Device,
|
||||
MTLResourceOptions, MTLSize,
|
||||
};
|
||||
use regex::Regex;
|
||||
use rustc_hash::FxHashMap;
|
||||
use serde::Serialize;
|
||||
use std::{collections::HashMap, fmt::Debug};
|
||||
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
@@ -88,8 +85,11 @@ pub enum GraphTerm {
|
||||
Custom(Kernel),
|
||||
}
|
||||
|
||||
impl Operator for Kernel {
|
||||
#[derive(Debug)]
|
||||
pub struct CompatKernel(Kernel, *mut Graph);
|
||||
impl Operator for CompatKernel {
|
||||
fn process(&mut self, inputs: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
let dyn_vars = &unsafe { self.1.as_ref().unwrap() }.dyn_map;
|
||||
let device = Device::system_default().unwrap();
|
||||
let queue = device.new_command_queue();
|
||||
let command_buffer = queue.new_command_buffer();
|
||||
@@ -98,7 +98,7 @@ impl Operator for Kernel {
|
||||
let options = CompileOptions::new();
|
||||
// options.set_fast_math_enabled(true);
|
||||
let lib = device
|
||||
.new_library_with_source(&self.code, &options)
|
||||
.new_library_with_source(&self.0.code, &options)
|
||||
.unwrap();
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor
|
||||
@@ -120,33 +120,32 @@ impl Operator for Kernel {
|
||||
}
|
||||
// set output
|
||||
let mut buffers = vec![];
|
||||
for (i, size) in self.outputs.iter().enumerate() {
|
||||
for (i, size) in self.0.outputs.iter().enumerate() {
|
||||
buffers.push(device.new_buffer(
|
||||
(size.exec(&FxHashMap::default()).unwrap() * std::mem::size_of::<f32>()) as u64,
|
||||
(size.exec(dyn_vars).unwrap() * std::mem::size_of::<f32>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
));
|
||||
encoder.set_buffer((i + inputs.len()) as u64, Some(buffers.last().unwrap()), 0);
|
||||
}
|
||||
// set smem
|
||||
if !self.smem.is_empty() {
|
||||
if !self.0.smem.is_empty() {
|
||||
encoder.set_threadgroup_memory_length(
|
||||
0,
|
||||
(self.smem.exec(&FxHashMap::default()).unwrap() * std::mem::size_of::<f32>())
|
||||
as u64,
|
||||
(self.0.smem.exec(dyn_vars).unwrap() * std::mem::size_of::<f32>()) as u64,
|
||||
);
|
||||
}
|
||||
|
||||
// Set dispatch
|
||||
encoder.dispatch_thread_groups(
|
||||
MTLSize::new(
|
||||
self.grid.0.exec(&FxHashMap::default()).unwrap() as u64,
|
||||
self.grid.1.exec(&FxHashMap::default()).unwrap() as u64,
|
||||
self.grid.2.exec(&FxHashMap::default()).unwrap() as u64,
|
||||
self.0.grid.0.exec(dyn_vars).unwrap() as u64,
|
||||
self.0.grid.1.exec(dyn_vars).unwrap() as u64,
|
||||
self.0.grid.2.exec(dyn_vars).unwrap() as u64,
|
||||
),
|
||||
MTLSize::new(
|
||||
self.threadblock.0.exec(&FxHashMap::default()).unwrap() as u64,
|
||||
self.threadblock.1.exec(&FxHashMap::default()).unwrap() as u64,
|
||||
self.threadblock.2.exec(&FxHashMap::default()).unwrap() as u64,
|
||||
self.0.threadblock.0.exec(dyn_vars).unwrap() as u64,
|
||||
self.0.threadblock.1.exec(dyn_vars).unwrap() as u64,
|
||||
self.0.threadblock.2.exec(dyn_vars).unwrap() as u64,
|
||||
),
|
||||
);
|
||||
encoder.end_encoding();
|
||||
@@ -165,7 +164,8 @@ pub fn custom_kernel<const O: usize>(
|
||||
output_shapes: [impl ToShape; O],
|
||||
cx: &mut Graph,
|
||||
) -> [GraphTensor; O] {
|
||||
let mut kernel_op = cx.add_op(kernel);
|
||||
let graph_ref: *mut Graph = cx;
|
||||
let mut kernel_op = cx.add_op(CompatKernel(kernel, graph_ref));
|
||||
for input in inputs {
|
||||
kernel_op = kernel_op.input(input.id, 0, input.shape);
|
||||
}
|
||||
|
||||
@@ -99,10 +99,11 @@ pub fn run_graph(
|
||||
// Copy outputs back
|
||||
return (outputs, time_taken_micros);
|
||||
} else {
|
||||
// println!("Grid {:?} TB: {:?}", kernel.grid, kernel.threadblock);
|
||||
// println!("{}", kernel.code);
|
||||
println!("Grid {:?} TB: {:?}", kernel.grid, kernel.threadblock);
|
||||
println!("{}", kernel.code);
|
||||
|
||||
// compile kernel
|
||||
let n_inputs = kernels.edges_directed(node, Direction::Incoming).count();
|
||||
let command_buffer = queue.new_command_buffer();
|
||||
let encoder = command_buffer
|
||||
.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
@@ -155,8 +156,9 @@ pub fn run_graph(
|
||||
// println!("Inp {i}: {}", buffers[&(input, input_index)].length());
|
||||
}
|
||||
// set output
|
||||
let n_inputs = kernels.edges_directed(node, Direction::Incoming).count();
|
||||
for (i, size) in kernel.outputs.iter().enumerate() {
|
||||
println!("{size}");
|
||||
println!("{:?}", dyn_vars);
|
||||
buffers.insert(
|
||||
(node, i),
|
||||
Some(device.new_buffer(
|
||||
@@ -170,6 +172,14 @@ pub fn run_graph(
|
||||
0,
|
||||
);
|
||||
}
|
||||
// set dynamic dimensions
|
||||
for (i, (_, v)) in dyn_vars.iter().sorted_by_key(|(k, _)| **k).enumerate() {
|
||||
encoder.set_bytes(
|
||||
(i + n_inputs + kernel.outputs.len()) as u64,
|
||||
std::mem::size_of::<usize>() as u64,
|
||||
v as *const usize as *const _,
|
||||
);
|
||||
}
|
||||
// set smem
|
||||
if !kernel.smem.is_empty() {
|
||||
encoder.set_threadgroup_memory_length(
|
||||
|
||||
@@ -5,7 +5,7 @@ use luminal::prelude::{
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::{
|
||||
GraphTerm, Kernel,
|
||||
CompatKernel, GraphTerm, Kernel,
|
||||
codegen::{GRID_DIMS, THREADBLOCK_DIMS},
|
||||
utils::loop_in,
|
||||
};
|
||||
@@ -250,9 +250,9 @@ pub fn translate_graph(
|
||||
inits.push((new, vec![value]));
|
||||
}
|
||||
_ => {
|
||||
if let Some(kernel) = node_weight.as_any().downcast_ref::<Kernel>() {
|
||||
if let Some(kernel) = node_weight.as_any().downcast_ref::<CompatKernel>() {
|
||||
// Add a custom kernel
|
||||
let custom = new_graph.add_node(GraphTerm::Custom(kernel.clone()));
|
||||
let custom = new_graph.add_node(GraphTerm::Custom(kernel.0.clone()));
|
||||
for (source, ind, _) in sources {
|
||||
let new_source = node_mapping[&(source, ind)];
|
||||
new_graph.add_edge(new_source, custom, ());
|
||||
|
||||
@@ -50,12 +50,12 @@ fn main() {
|
||||
|
||||
// Set up graph
|
||||
let mut cx = Graph::new();
|
||||
let mut input = cx.named_tensor("Input", 2);
|
||||
let mut input = cx.named_tensor("Input", 's');
|
||||
let mut cache_src: Vec<KVCache> = (0..model::NUM_LAYERS)
|
||||
.map(|_| {
|
||||
(
|
||||
cx.named_tensor("Key Cache", (N_KV_HEADS, 0, HEAD_DIM)),
|
||||
cx.named_tensor("Value Cache", (N_KV_HEADS, 0, HEAD_DIM)),
|
||||
cx.named_tensor("Key Cache", (N_KV_HEADS, 'p', HEAD_DIM)),
|
||||
cx.named_tensor("Value Cache", (N_KV_HEADS, 'p', HEAD_DIM)),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
@@ -97,6 +97,7 @@ fn main() {
|
||||
let input_data = vec![0.0, 1.0];
|
||||
input.set(input_data.clone());
|
||||
cx.set_dyn_dim('s', 2);
|
||||
cx.set_dyn_dim('p', 0);
|
||||
cx.execute_debug();
|
||||
// cx.display();
|
||||
println!(
|
||||
@@ -124,6 +125,7 @@ fn main() {
|
||||
outputs,
|
||||
luminal_2::GPUArch::Metal(HashMap::default()),
|
||||
0,
|
||||
&cx.dyn_map,
|
||||
)
|
||||
.unwrap();
|
||||
// luminal_2::utils::display_graph(&kernels, &[]);
|
||||
@@ -138,8 +140,6 @@ fn main() {
|
||||
// );
|
||||
// luminal_2::utils::display_graph(&kernels, &[]);
|
||||
// luminal_2::utils::print_kernels(&kernels);
|
||||
let mut dyn_map = FxHashMap::default();
|
||||
dyn_map.insert('s', 2);
|
||||
println!("input: {:?}", old_to_new_mapping[&input.id]);
|
||||
let mut inps = vec![(
|
||||
old_to_new_mapping[&input.id],
|
||||
@@ -162,7 +162,7 @@ fn main() {
|
||||
|
||||
// let (buf_sizes, buf_map) = produce_buffer_map(&kernels);
|
||||
// println!("bufs: {:?}", buf_sizes);
|
||||
let (mut outputs, runtime) = run_graph(inps, &kernels, &dyn_map);
|
||||
let (mut outputs, runtime) = run_graph(inps, &kernels, &cx.dyn_map);
|
||||
let logits = logits.data();
|
||||
println!("Old Logits: {:?}", &logits[..10]);
|
||||
println!("New Logits: {:?}", &outputs[0][..10]);
|
||||
|
||||
@@ -236,6 +236,34 @@ impl Expression {
|
||||
}
|
||||
symbols.pop().unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn to_kernel(&self) -> String {
|
||||
let mut symbols = vec![];
|
||||
for term in self.terms.read().iter() {
|
||||
let new_symbol = match term {
|
||||
Term::Num(n) => n.to_string(),
|
||||
Term::Var(c) => format!("const_{c}"),
|
||||
Term::Acc(_) => "1".to_string(), // super jank, exists so that we can max(Acc, x)
|
||||
Term::Max => format!(
|
||||
"max((int){}, (int){})",
|
||||
symbols.pop().unwrap(),
|
||||
symbols.pop().unwrap()
|
||||
),
|
||||
Term::Min => format!(
|
||||
"min((int){}, (int){})",
|
||||
symbols.pop().unwrap(),
|
||||
symbols.pop().unwrap()
|
||||
),
|
||||
_ => format!(
|
||||
"({}{term:?}{})",
|
||||
symbols.pop().unwrap(),
|
||||
symbols.pop().unwrap()
|
||||
),
|
||||
};
|
||||
symbols.push(new_symbol);
|
||||
}
|
||||
symbols.pop().unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Expression {
|
||||
|
||||
Reference in New Issue
Block a user