symbolic inputs / outputs

This commit is contained in:
Joe Fioti
2025-07-21 22:24:02 -07:00
parent 97d5e05820
commit 752d3e2401
7 changed files with 94 additions and 43 deletions

View File

@@ -26,6 +26,7 @@ pub fn codegen(
outputs: Vec<NodeIndex>,
mut arch: GPUArch,
n_graph: usize,
dyn_vars: &FxHashMap<char, usize>,
) -> Option<StableGraph<Kernel, (usize, usize), Directed>> {
let gmems = graph
.node_weights()
@@ -199,6 +200,7 @@ pub fn codegen(
if !input_comment.is_empty() {
input_comment = format!("\t// Inputs\n{input_comment}\n");
}
let n_inputs_outputs = inputs.len() + outputs.len();
let input_string = inputs
.into_iter()
.map(|(_, a)| a)
@@ -210,6 +212,12 @@ pub fn codegen(
var_to_char(node_to_var[&a].0)
)
})
.chain(dyn_vars.iter().sorted().enumerate().map(|(i, (c, _))| {
format!(
"constant uint& const_{c} [[buffer({})]]",
i + n_inputs_outputs
)
}))
.join(",\n\t");
let (smem_setup, smem_input) = if smem_buffers.is_empty() {
("".to_string(), "".to_string())
@@ -434,14 +442,11 @@ fn make_kernel(
// Make accumulator
*prev_max_var += 1;
arch.add_metal_buffer_type(*prev_max_var, "thread ");
let mut map = FxHashMap::default();
map.insert('s', 1);
map.insert('p', 0);
kernel_lines.push(format!(
"{spacing}{}float {}[{}] = {{0.0}};",
arch.metal_buffer_type(*prev_max_var),
var_to_char(*prev_max_var),
size.exec(&map).unwrap()
size.to_kernel()
));
// Copy from source to accumulator
@@ -451,12 +456,13 @@ fn make_kernel(
.unwrap();
// Use a single loop with correct striding from the input
kernel_lines.push(format!(
"{spacing}for (int load = 0; load < {loads}; ++load) {{"
"{spacing}for (int load = 0; load < {}; ++load) {{",
loads.to_kernel()
));
let indexing_expression = indexing_expression
.simplify()
.to_string()
.replace("z", "load");
.to_kernel()
.replace("const_z", "load");
kernel_lines.push(format!(
"{inner_spacing}{}[{indexing_expression}] = *({} + {indexing_expression});",
var_to_char(*prev_max_var),
@@ -521,7 +527,7 @@ fn make_kernel(
} else {
*prev_max_var += 1;
let loop_var = var_to_char(*prev_max_var);
kernel_lines.push(format!("{spacing}for (int loop_{loop_var} = 0; loop_{loop_var} < {range}; ++loop_{loop_var}) {{"));
kernel_lines.push(format!("{spacing}for (int loop_{loop_var} = 0; loop_{loop_var} < {}; ++loop_{loop_var}) {{", range.to_kernel()));
};
let loop_var = var_to_char(*prev_max_var);
let loop_var_int = *prev_max_var;
@@ -550,7 +556,9 @@ fn make_kernel(
arch.metal_buffer_type(*prev_max_var),
var_to_char(*prev_max_var),
var_to_char(real_input),
stride.to_string().replace('z', &format!("loop_{loop_var}"))
stride
.to_kernel()
.replace("const_z", &format!("loop_{loop_var}"))
));
node_to_var.insert(*input, (*prev_max_var, is_ptr));
}
@@ -584,7 +592,9 @@ fn make_kernel(
arch.metal_buffer_type(*prev_max_var),
var_to_char(*prev_max_var),
var_to_char(real_output),
stride.to_string().replace('z', &format!("loop_{loop_var}"))
stride
.to_kernel()
.replace("const_z", &format!("loop_{loop_var}"))
));
new_output_vars.push(*prev_max_var);
node_to_var.insert(*output, (*prev_max_var, is_ptr));
@@ -687,12 +697,13 @@ fn make_kernel(
current_elem_size *= range;
}
kernel_lines.push(format!(
"{spacing}for (int save = 0; save < {size}; ++save) {{"
"{spacing}for (int save = 0; save < {}; ++save) {{",
size.to_kernel()
));
let indexing_expression = indexing_expression
.simplify()
.to_string()
.replace("z", "save");
.to_kernel()
.replace("const_z", "save");
kernel_lines.push(format!(
"{inner_spacing}{}[{indexing_expression}] = *({} + {indexing_expression});",
var_to_char(outer_out),

View File

@@ -31,6 +31,7 @@ pub fn search(
egraph: &EGraph,
inputs: &[(NodeIndex, Vec<f32>)],
arch: GPUArch,
dyn_vars: &FxHashMap<char, usize>,
) -> Option<StableGraph<Kernel, (usize, usize)>> {
fn recurse<'a>(
egraph: &'a EGraph,
@@ -115,7 +116,8 @@ pub fn search(
// Build termdag
let graph = extraction_to_graph(egraph, &trajectory);
let root = graph.externals(Direction::Outgoing).next().unwrap();
let Some(kernels) = crate::codegen::codegen(graph.clone(), vec![root], arch.clone(), 0)
let Some(kernels) =
crate::codegen::codegen(graph.clone(), vec![root], arch.clone(), 0, dyn_vars)
else {
continue;
};

View File

@@ -13,9 +13,6 @@ use metal_rs::{
Buffer, CompileOptions, ComputePassDescriptor, ComputePipelineDescriptor, Device,
MTLResourceOptions, MTLSize,
};
use regex::Regex;
use rustc_hash::FxHashMap;
use serde::Serialize;
use std::{collections::HashMap, fmt::Debug};
#[derive(Clone, PartialEq, Eq)]
@@ -88,8 +85,11 @@ pub enum GraphTerm {
Custom(Kernel),
}
impl Operator for Kernel {
#[derive(Debug)]
pub struct CompatKernel(Kernel, *mut Graph);
impl Operator for CompatKernel {
fn process(&mut self, inputs: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
let dyn_vars = &unsafe { self.1.as_ref().unwrap() }.dyn_map;
let device = Device::system_default().unwrap();
let queue = device.new_command_queue();
let command_buffer = queue.new_command_buffer();
@@ -98,7 +98,7 @@ impl Operator for Kernel {
let options = CompileOptions::new();
// options.set_fast_math_enabled(true);
let lib = device
.new_library_with_source(&self.code, &options)
.new_library_with_source(&self.0.code, &options)
.unwrap();
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor
@@ -120,33 +120,32 @@ impl Operator for Kernel {
}
// set output
let mut buffers = vec![];
for (i, size) in self.outputs.iter().enumerate() {
for (i, size) in self.0.outputs.iter().enumerate() {
buffers.push(device.new_buffer(
(size.exec(&FxHashMap::default()).unwrap() * std::mem::size_of::<f32>()) as u64,
(size.exec(dyn_vars).unwrap() * std::mem::size_of::<f32>()) as u64,
MTLResourceOptions::StorageModeShared,
));
encoder.set_buffer((i + inputs.len()) as u64, Some(buffers.last().unwrap()), 0);
}
// set smem
if !self.smem.is_empty() {
if !self.0.smem.is_empty() {
encoder.set_threadgroup_memory_length(
0,
(self.smem.exec(&FxHashMap::default()).unwrap() * std::mem::size_of::<f32>())
as u64,
(self.0.smem.exec(dyn_vars).unwrap() * std::mem::size_of::<f32>()) as u64,
);
}
// Set dispatch
encoder.dispatch_thread_groups(
MTLSize::new(
self.grid.0.exec(&FxHashMap::default()).unwrap() as u64,
self.grid.1.exec(&FxHashMap::default()).unwrap() as u64,
self.grid.2.exec(&FxHashMap::default()).unwrap() as u64,
self.0.grid.0.exec(dyn_vars).unwrap() as u64,
self.0.grid.1.exec(dyn_vars).unwrap() as u64,
self.0.grid.2.exec(dyn_vars).unwrap() as u64,
),
MTLSize::new(
self.threadblock.0.exec(&FxHashMap::default()).unwrap() as u64,
self.threadblock.1.exec(&FxHashMap::default()).unwrap() as u64,
self.threadblock.2.exec(&FxHashMap::default()).unwrap() as u64,
self.0.threadblock.0.exec(dyn_vars).unwrap() as u64,
self.0.threadblock.1.exec(dyn_vars).unwrap() as u64,
self.0.threadblock.2.exec(dyn_vars).unwrap() as u64,
),
);
encoder.end_encoding();
@@ -165,7 +164,8 @@ pub fn custom_kernel<const O: usize>(
output_shapes: [impl ToShape; O],
cx: &mut Graph,
) -> [GraphTensor; O] {
let mut kernel_op = cx.add_op(kernel);
let graph_ref: *mut Graph = cx;
let mut kernel_op = cx.add_op(CompatKernel(kernel, graph_ref));
for input in inputs {
kernel_op = kernel_op.input(input.id, 0, input.shape);
}

View File

@@ -99,10 +99,11 @@ pub fn run_graph(
// Copy outputs back
return (outputs, time_taken_micros);
} else {
// println!("Grid {:?} TB: {:?}", kernel.grid, kernel.threadblock);
// println!("{}", kernel.code);
println!("Grid {:?} TB: {:?}", kernel.grid, kernel.threadblock);
println!("{}", kernel.code);
// compile kernel
let n_inputs = kernels.edges_directed(node, Direction::Incoming).count();
let command_buffer = queue.new_command_buffer();
let encoder = command_buffer
.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
@@ -155,8 +156,9 @@ pub fn run_graph(
// println!("Inp {i}: {}", buffers[&(input, input_index)].length());
}
// set output
let n_inputs = kernels.edges_directed(node, Direction::Incoming).count();
for (i, size) in kernel.outputs.iter().enumerate() {
println!("{size}");
println!("{:?}", dyn_vars);
buffers.insert(
(node, i),
Some(device.new_buffer(
@@ -170,6 +172,14 @@ pub fn run_graph(
0,
);
}
// set dynamic dimensions
for (i, (_, v)) in dyn_vars.iter().sorted_by_key(|(k, _)| **k).enumerate() {
encoder.set_bytes(
(i + n_inputs + kernel.outputs.len()) as u64,
std::mem::size_of::<usize>() as u64,
v as *const usize as *const _,
);
}
// set smem
if !kernel.smem.is_empty() {
encoder.set_threadgroup_memory_length(

View File

@@ -5,7 +5,7 @@ use luminal::prelude::{
use rustc_hash::FxHashMap;
use crate::{
GraphTerm, Kernel,
CompatKernel, GraphTerm, Kernel,
codegen::{GRID_DIMS, THREADBLOCK_DIMS},
utils::loop_in,
};
@@ -250,9 +250,9 @@ pub fn translate_graph(
inits.push((new, vec![value]));
}
_ => {
if let Some(kernel) = node_weight.as_any().downcast_ref::<Kernel>() {
if let Some(kernel) = node_weight.as_any().downcast_ref::<CompatKernel>() {
// Add a custom kernel
let custom = new_graph.add_node(GraphTerm::Custom(kernel.clone()));
let custom = new_graph.add_node(GraphTerm::Custom(kernel.0.clone()));
for (source, ind, _) in sources {
let new_source = node_mapping[&(source, ind)];
new_graph.add_edge(new_source, custom, ());

View File

@@ -50,12 +50,12 @@ fn main() {
// Set up graph
let mut cx = Graph::new();
let mut input = cx.named_tensor("Input", 2);
let mut input = cx.named_tensor("Input", 's');
let mut cache_src: Vec<KVCache> = (0..model::NUM_LAYERS)
.map(|_| {
(
cx.named_tensor("Key Cache", (N_KV_HEADS, 0, HEAD_DIM)),
cx.named_tensor("Value Cache", (N_KV_HEADS, 0, HEAD_DIM)),
cx.named_tensor("Key Cache", (N_KV_HEADS, 'p', HEAD_DIM)),
cx.named_tensor("Value Cache", (N_KV_HEADS, 'p', HEAD_DIM)),
)
})
.collect();
@@ -97,6 +97,7 @@ fn main() {
let input_data = vec![0.0, 1.0];
input.set(input_data.clone());
cx.set_dyn_dim('s', 2);
cx.set_dyn_dim('p', 0);
cx.execute_debug();
// cx.display();
println!(
@@ -124,6 +125,7 @@ fn main() {
outputs,
luminal_2::GPUArch::Metal(HashMap::default()),
0,
&cx.dyn_map,
)
.unwrap();
// luminal_2::utils::display_graph(&kernels, &[]);
@@ -138,8 +140,6 @@ fn main() {
// );
// luminal_2::utils::display_graph(&kernels, &[]);
// luminal_2::utils::print_kernels(&kernels);
let mut dyn_map = FxHashMap::default();
dyn_map.insert('s', 2);
println!("input: {:?}", old_to_new_mapping[&input.id]);
let mut inps = vec![(
old_to_new_mapping[&input.id],
@@ -162,7 +162,7 @@ fn main() {
// let (buf_sizes, buf_map) = produce_buffer_map(&kernels);
// println!("bufs: {:?}", buf_sizes);
let (mut outputs, runtime) = run_graph(inps, &kernels, &dyn_map);
let (mut outputs, runtime) = run_graph(inps, &kernels, &cx.dyn_map);
let logits = logits.data();
println!("Old Logits: {:?}", &logits[..10]);
println!("New Logits: {:?}", &outputs[0][..10]);

View File

@@ -236,6 +236,34 @@ impl Expression {
}
symbols.pop().unwrap_or_default()
}
pub fn to_kernel(&self) -> String {
let mut symbols = vec![];
for term in self.terms.read().iter() {
let new_symbol = match term {
Term::Num(n) => n.to_string(),
Term::Var(c) => format!("const_{c}"),
Term::Acc(_) => "1".to_string(), // super jank, exists so that we can max(Acc, x)
Term::Max => format!(
"max((int){}, (int){})",
symbols.pop().unwrap(),
symbols.pop().unwrap()
),
Term::Min => format!(
"min((int){}, (int){})",
symbols.pop().unwrap(),
symbols.pop().unwrap()
),
_ => format!(
"({}{term:?}{})",
symbols.pop().unwrap(),
symbols.pop().unwrap()
),
};
symbols.push(new_symbol);
}
symbols.pop().unwrap_or_default()
}
}
impl std::fmt::Display for Expression {