Refactored expression system

This commit is contained in:
Joe Fioti
2024-07-20 22:51:24 -05:00
parent 2bdad487d9
commit b5c38dc6db
48 changed files with 1806 additions and 967 deletions

View File

@@ -25,6 +25,8 @@ as-any = "0.3.1"
egg = "0.9.5"
symbolic_expressions = "5.0.3"
serde = {version="1.0.202", features=["derive"]}
thread_local = "1.1.8"
generational-box = "0.5.6"
[dev-dependencies]
dfdx = { version = "0.13", features = ["f16"] }

View File

@@ -231,7 +231,7 @@ impl Compiler for GatherCompiler {
.as_data()
.unwrap()
.2
.shape()[2]
.dims()[2]
.to_usize()
.unwrap();

View File

@@ -63,7 +63,7 @@ pub struct MatMul2D;
impl Operator for MatMul2D {
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
let (a_shape, b_shape) = (inp[0].1.dims(), inp[1].1.dims());
let (a_strides, b_strides) = (inp[0].1.strides(), inp[1].1.strides());
let a_data = inp[0].0.borrowed().downcast_ref::<Vec<f32>>().unwrap();
let b_data = inp[1].0.borrowed().downcast_ref::<Vec<f32>>().unwrap();
@@ -151,7 +151,7 @@ pub struct BatchedMatMul2D;
// ABCxCD -> ABD
impl Operator for BatchedMatMul2D {
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
let (a_shape, b_shape) = (inp[0].1.dims(), inp[1].1.dims());
let (a_strides, b_strides) = (inp[0].1.strides(), inp[1].1.strides());
let a_data = inp[0].0.borrowed().downcast_ref::<Vec<f32>>().unwrap();
let b_data = inp[1].0.borrowed().downcast_ref::<Vec<f32>>().unwrap();

View File

@@ -8,7 +8,7 @@ use super::binary::Sub;
#[derive(Debug, Clone, PartialEq)]
pub struct ARange {
pub size: BigExpression,
pub size: Expression,
dyn_map: *const FxHashMap<char, usize>,
}
@@ -61,7 +61,7 @@ impl Compiler for ARangeCompiler {
};
let arange_op = graph
.add_op(ARange {
size: arange_amount.into(),
size: arange_amount,
dyn_map: &graph.dyn_map,
})
.finish();

View File

@@ -71,7 +71,7 @@ kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name}
}
impl<T> MetalKernel for MetalSub<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
vec![input_shapes[0].n_elements() * size_of::<T>()]
}
fn metal_forward(
@@ -252,7 +252,7 @@ kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name}
}
impl<T> MetalKernel for MetalEqual<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
vec![input_shapes[0].n_elements() * size_of::<T>()]
}
fn metal_forward(
@@ -419,7 +419,7 @@ impl<T: MetalFloat> Operator for MetalGather<T> {
// Setup buffers
let indexes = tensors[0].0.borrowed().downcast_ref::<Vec<f32>>().unwrap();
let index_buffer = self.device.new_buffer_with_data(
unsafe { std::mem::transmute(indexes.as_ptr()) },
indexes.as_ptr() as *const _,
(indexes.len() * std::mem::size_of::<f32>()) as u64,
MTLResourceOptions::StorageModeShared,
);
@@ -498,7 +498,7 @@ impl<T: MetalFloat> Compiler for MetalGatherCompiler<T> {
.as_data()
.unwrap()
.2;
let embed_dim = emb_shape.shape()[2].to_usize().unwrap();
let embed_dim = emb_shape.dims()[2].to_usize().unwrap();
let index_shape = graph
.edges_connecting(s.get(&indexes), s.get(&ind_copy))
.next()

View File

@@ -236,10 +236,10 @@ impl std::fmt::Debug for CommandBufferWrapper {
}
impl MetalKernel for CommandBufferWrapper {
fn intermediate_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
fn intermediate_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
self.wrapper.intermediate_buffer_sizes(input_shapes)
}
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
self.wrapper.output_buffer_sizes(input_shapes)
}
fn metal_forward(

View File

@@ -121,7 +121,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
let mut subexpressions_b = graph
.try_get_op::<FusedElementwiseOp<T>>(b)
.map(|o| o.subexpressions.clone())
.unwrap_or_else(|| vec![(expression_b.clone(), ShapeTracker::default())]);
.unwrap_or_else(|| vec![(expression_b.clone(), ShapeTracker::new(()))]);
let a_to_b_indexes = graph
.edges_connecting(a, b)
.map(|e| e.weight().as_data().unwrap().0 as usize)
@@ -141,7 +141,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
let mut subexpressions_a = graph
.try_get_op::<FusedElementwiseOp<T>>(a)
.map(|o| o.subexpressions.clone())
.unwrap_or_else(|| vec![(expression_a.clone(), ShapeTracker::default())]);
.unwrap_or_else(|| vec![(expression_a.clone(), ShapeTracker::new(()))]);
subexpressions_a.last_mut().unwrap().1 = connecting_shape;
// Re-reference b intermediates
for i in (0..subexpressions_b.len()).rev() {
@@ -236,6 +236,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
.into_iter()
.map(|s| s.simplify_cache(&mut simplification_cache))
.collect();
let g: *mut Graph = graph;
let new_op = graph
.add_op(FusedElementwiseOp::<T> {
kernel_str: "".to_string(),
@@ -247,6 +248,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
device: device.clone(),
output_buffer_sizes,
_phantom: Default::default(),
graph: g,
})
.finish();
// Add edges to new op
@@ -293,17 +295,20 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
.into_iter()
.map(|s| s.simplify_cache(&mut simplification_cache))
.collect();
let sh = ShapeTracker::new(());
let g: *mut Graph = graph;
let new_op = graph
.add_op(FusedElementwiseOp::<T> {
kernel_str: "".to_string(),
kernel: None,
dyn_map: &graph.dyn_map,
dyn_chars: vec![],
subexpressions: vec![(op_string, ShapeTracker::default())],
subexpressions: vec![(op_string, sh)],
queue: queue.clone(),
device: device.clone(),
output_buffer_sizes,
_phantom: Default::default(),
graph: g,
})
.finish();
// Add edges to new op
@@ -346,8 +351,9 @@ pub struct FusedElementwiseOp<T> {
pub subexpressions: Vec<(String, ShapeTracker)>,
pub queue: CommandQueue,
pub device: Device,
pub output_buffer_sizes: Vec<BigExpression>,
pub output_buffer_sizes: Vec<Expression>,
pub _phantom: PhantomData<T>,
pub graph: *mut Graph,
}
crate::debug_type!(FusedElementwiseOp);
@@ -357,7 +363,7 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
input_shapes: Vec<ShapeTracker>,
input_regexes: &mut FxHashMap<usize, Regex>,
intermediate_match: &Regex,
simplification_cache: &mut FxHashMap<BigExpression, BigExpression>,
simplification_cache: &mut FxHashMap<Expression, Expression>,
) {
let mut subexpressions = self.subexpressions.clone();
let shapes_used = subexpressions
@@ -415,7 +421,7 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
s.iter()
.rev()
.take(s.len() - 1)
.fold(BigExpression::from('z'), |acc, inp| {
.fold(Expression::from('z'), |acc, inp| {
inp.index_expression().substitute('z', acc)
})
})
@@ -451,8 +457,8 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
input_regexes.get(&i).unwrap()
};
let (ind, val) = (
ind_exp.clone().simplify_cache(simplification_cache),
val_exp.clone().simplify_cache(simplification_cache),
ind_exp.simplify_cache(simplification_cache),
val_exp.simplify_cache(simplification_cache),
);
*subexp = re
.replace_all(
@@ -475,10 +481,10 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
.iter()
.rev()
.fold(
(BigExpression::from(true), BigExpression::from('z')),
(Expression::from(true), Expression::from('z')),
|(_, ind_acc), inp| {
(
inp.valid_expression().substitute('z', ind_acc.clone()),
inp.valid_expression().substitute('z', ind_acc),
inp.index_expression().substitute('z', ind_acc),
)
},
@@ -526,7 +532,7 @@ out[idx] = ({type_name})({});
}
impl<T> MetalKernel for FusedElementwiseOp<T> {
fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<BigExpression> {
fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<Expression> {
self.output_buffer_sizes.clone()
}
fn metal_forward(
@@ -669,7 +675,7 @@ mod tests {
let inp = random_vec_rng(10, &mut rng);
let a = cx.named_tensor("a", (2, 5)).set(inp);
let mut padded = a
.slice((..Expression::from(1), ..))
.slice((..1, ..))
.cos()
.pad(((0, 1), (0, 0)))
.exp2()
@@ -711,7 +717,7 @@ mod tests {
const HEAD_DIM: usize = 4;
let freqs = (cx.arange(HEAD_DIM / 2) * 2.0) / (HEAD_DIM as f32);
let freqs = 1000000_f32.pow(freqs);
let pos = cx.arange(SEQ) + BigExpression::from(0);
let pos = cx.arange(SEQ) + 0;
let mut emb = pos.expand(1, 1).matmul(freqs.expand(0, SEQ)).retrieve();
cx.execute();
@@ -775,16 +781,12 @@ mod tests {
.keep();
let freqs = (cx.arange(HEAD_DIM / 2) * 2.0) / (HEAD_DIM as f32);
let freqs = 1000000_f32.pow(freqs);
let pos = cx.arange(SEQ) + BigExpression::from(0);
let pos = cx.arange(SEQ) + 0;
let emb = pos.expand(1, 1).matmul(freqs.expand(0, SEQ));
// Split input into evens and odds
let split = a.reshape((BATCH, N_HEADS, SEQ, HEAD_DIM / 2, 2));
let x0 = split
.slice((.., .., .., .., ..Expression::from(1)))
.contiguous();
let x1 = split
.slice((.., .., .., .., Expression::from(1)..))
.contiguous();
let x0 = split.slice((.., .., .., .., ..1)).contiguous();
let x1 = split.slice((.., .., .., .., 1..)).contiguous();
// Apply sin and cos embeddings
let x0_out = x0 * emb.cos().expand_to(x0.shape) - x1 * emb.sin().expand_to(x1.shape);
@@ -852,15 +854,9 @@ mod tests {
}
}
fn apply_rotary_embeddings_ggml(
input: GraphTensor,
prev_seq: BigExpression,
) -> GraphTensor {
fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: Expression) -> GraphTensor {
assert_eq!(input.shape.len(), 4); // batch, n_heads, seq, head_dim
let batch = input.shape()[0].small();
let n_heads = input.shape()[1].small();
let seq = input.shape()[2].small();
let head_dim = input.shape()[3].small();
let (batch, n_heads, seq, head_dim) = input.dims4();
// Get freqs
let freqs =
(input.graph().arange(head_dim / 2) * 2.0) / (head_dim.to_usize().unwrap() as f32);
@@ -892,9 +888,8 @@ mod tests {
type Output = (GraphTensor, KVCache);
fn forward(&self, (x, (k_cache, v_cache)): (GraphTensor, KVCache)) -> Self::Output {
// x: batch, seq, hidden
let batch = x.shape()[0].small();
let seq = x.shape()[1].small();
let prev_seq = k_cache.shape()[2].small();
let (batch, seq, _) = x.dims3();
let (_, _, prev_seq, _) = k_cache.dims4();
// Apply the Projections
let queries = x
.matmul(self.q_proj.permute((1, 0)))
@@ -912,8 +907,8 @@ mod tests {
.permute((0, 2, 1, 3));
// Rotary embed queries and keys
let queries = apply_rotary_embeddings_ggml(queries, prev_seq.big());
let keys = apply_rotary_embeddings_ggml(keys, prev_seq.big());
let queries = apply_rotary_embeddings_ggml(queries, prev_seq);
let keys = apply_rotary_embeddings_ggml(keys, prev_seq);
// Add KV cache
let keys = k_cache.concat_along(keys, 2);

View File

@@ -140,11 +140,11 @@ impl MetalFloat for f16 {
pub trait MetalKernel: Debug {
/// Annotate the buffer sizes of the intermediate buffers
fn intermediate_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<BigExpression> {
fn intermediate_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<Expression> {
vec![]
}
/// Annotate the buffer sizes of the output buffers
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression>;
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression>;
/// Set up the kernel on the buffer
fn metal_forward(
&self,
@@ -227,7 +227,7 @@ impl Deref for MetalKernelWrapper {
}
impl MetalKernel for () {
fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<BigExpression> {
fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<Expression> {
vec![]
}
fn metal_forward(
@@ -365,14 +365,10 @@ fn render_dyn_dim_inputs(shapes: &[ShapeTracker], offset: usize) -> (Vec<char>,
let symbols: Vec<char> = shapes
.iter()
.flat_map(|st| {
st.shape()
st.dims()
.into_iter()
.chain(
st.padding
.into_iter()
.flat_map(|i| [i.0.into(), i.1.into()]),
)
.chain(st.mask.into_iter().flat_map(|i| [i.0.into(), i.1.into()]))
.chain(st.padding.into_iter().flat_map(|i| [i.0, i.1]))
.chain(st.mask.into_iter().flat_map(|i| [i.0, i.1]))
})
.flat_map(|d| d.to_symbols())
.unique()
@@ -389,9 +385,9 @@ fn render_dyn_dim_inputs(shapes: &[ShapeTracker], offset: usize) -> (Vec<char>,
)
}
fn expr_to_metal_string(expr: &BigExpression) -> String {
fn expr_to_metal_string(expr: &Expression) -> String {
let mut symbols = vec![];
for term in expr.terms.clone() {
for term in expr.terms.read().clone() {
let new_symbol = match term {
Term::Num(n) => n.to_string(),
Term::Var(c) => {

View File

@@ -33,15 +33,15 @@ impl<T> Debug for Matmul<T> {
const BM: u64 = 8;
const BN: u64 = 32;
impl<T> MetalKernel for Matmul<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
let m = input_shapes[0].shape()[input_shapes[0].len() - 2].clone();
let n = input_shapes[1].shape()[input_shapes[1].len() - 1].clone();
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
let m = input_shapes[0].dims()[input_shapes[0].len() - 2];
let n = input_shapes[1].dims()[input_shapes[1].len() - 1];
let batch_size = input_shapes[0]
.shape()
.dims()
.into_iter()
.take(input_shapes[0].len() - 2)
.product::<BigExpression>()
.max(BigExpression::from(1));
.product::<Expression>()
.max(1);
vec![batch_size * m * n * size_of::<T>()]
}
fn metal_forward(
@@ -54,13 +54,13 @@ impl<T> MetalKernel for Matmul<T> {
let (a_shape, b_shape) = (
inputs[0]
.1
.shape()
.dims()
.into_iter()
.map(|i| i.to_usize().unwrap())
.collect::<Vec<_>>(),
inputs[1]
.1
.shape()
.dims()
.into_iter()
.map(|i| i.to_usize().unwrap())
.collect::<Vec<_>>(),
@@ -158,7 +158,7 @@ impl<T: MetalFloat> Operator for Matmul<T> {
// Setup command queue / command buffer / encoder
let command_buffer = self.queue.new_command_buffer();
let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
let (a_shape, b_shape) = (inp[0].1.dims(), inp[1].1.dims());
let n = b_shape.last().unwrap().to_usize().unwrap();
let batch_size = a_shape
.iter()

View File

@@ -71,7 +71,7 @@ pub struct MetalARange<T: MetalFloat> {
pipeline: ComputePipelineState,
queue: CommandQueue,
device: Device,
pub size: BigExpression,
pub size: Expression,
dyn_map: *const FxHashMap<char, usize>,
_phantom: PhantomData<T>,
}
@@ -86,7 +86,7 @@ impl<T: MetalFloat> MetalARange<T> {
pub fn new(
device: Device,
queue: CommandQueue,
size: BigExpression,
size: Expression,
dyn_map: *const FxHashMap<char, usize>,
) -> Self {
let type_name = T::type_name();
@@ -109,8 +109,8 @@ kernel void metal_arange(device {type_name} *out [[buffer(0)]], device int& n_el
}
impl<T: MetalFloat> MetalKernel for MetalARange<T> {
fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<BigExpression> {
vec![self.size.clone() * std::mem::size_of::<f16>()]
fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<Expression> {
vec![self.size * std::mem::size_of::<f16>()]
}
fn metal_forward(
&self,
@@ -213,7 +213,7 @@ impl<T: MetalFloat> Compiler for ARangeCompiler<T> {
.add_op(MetalARange::<T>::new(
dev.clone(),
queue.clone(),
arange_amount.into(),
arange_amount,
&graph.dyn_map,
))
.finish();

View File

@@ -161,7 +161,7 @@ macro_rules! metal_unary_op {
}
impl<T> MetalKernel for $op_name<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
vec![input_shapes[0].contiguous().n_elements() * size_of::<T>()]
}
fn metal_forward(
@@ -288,7 +288,7 @@ kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name}
}
impl<T> MetalKernel for MetalAdd<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
vec![input_shapes[0].n_elements() * size_of::<T>()]
}
fn metal_forward(
@@ -405,7 +405,7 @@ kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name}
}
}
impl<T> MetalKernel for MetalMul<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
vec![input_shapes[0].n_elements() * size_of::<T>()]
}
fn metal_forward(
@@ -532,7 +532,7 @@ kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name}
}
impl<T> MetalKernel for MetalLessThan<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
vec![input_shapes[0].n_elements() * size_of::<T>()]
}
fn metal_forward(
@@ -650,7 +650,7 @@ kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name}
}
}
impl<T> MetalKernel for MetalMod<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
vec![input_shapes[0].n_elements() * size_of::<T>()]
}
fn metal_forward(
@@ -785,7 +785,7 @@ kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *o
}
impl<T> MetalKernel for MetalSumReduce<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
let mut sh = input_shapes[0];
sh.remove_dim(self.dim);
vec![sh.n_elements() * size_of::<T>()]
@@ -802,19 +802,19 @@ impl<T> MetalKernel for MetalSumReduce<T> {
let inp_size = sh.n_elements().to_usize().unwrap();
let front_size: usize = inputs[0]
.1
.shape()
.dims()
.iter()
.take(self.dim)
.map(|i| i.to_usize().unwrap())
.product();
let back_size: usize = inputs[0]
.1
.shape()
.dims()
.iter()
.skip(self.dim + 1)
.map(|i| i.to_usize().unwrap())
.product();
let dim_size = inputs[0].1.shape()[self.dim].to_usize().unwrap();
let dim_size = inputs[0].1.dims()[self.dim].to_usize().unwrap();
let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
@@ -938,7 +938,7 @@ kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *o
}
}
impl<T> MetalKernel for MetalMaxReduce<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
let mut sh = input_shapes[0];
sh.remove_dim(self.dim);
vec![sh.n_elements() * size_of::<T>()]
@@ -955,19 +955,19 @@ impl<T> MetalKernel for MetalMaxReduce<T> {
let inp_size = sh.contiguous().n_elements().to_usize().unwrap();
let front_size: usize = inputs[0]
.1
.shape()
.dims()
.iter()
.take(self.dim)
.map(|i| i.to_usize().unwrap())
.product();
let back_size: usize = inputs[0]
.1
.shape()
.dims()
.iter()
.skip(self.dim + 1)
.map(|i| i.to_usize().unwrap())
.product();
let dim_size = inputs[0].1.shape()[self.dim].to_usize().unwrap();
let dim_size = inputs[0].1.dims()[self.dim].to_usize().unwrap();
let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
@@ -1052,9 +1052,10 @@ impl<T: MetalFloat + 'static> Compiler for PrimitiveCompiler<T> {
> 0
{
// Copy outputs to device
let sh = ShapeTracker::new(());
let copy_node = graph
.add_op(MetalCopyToDevice::<T>::new(dev.clone()))
.input(function_node, 0, ShapeTracker::default())
.input(function_node, 0, sh)
.finish();
// Switch outgoing edges from input to copy_node

View File

@@ -1,14 +1,4 @@
use std::{
any::Any,
collections::hash_map::Entry,
fs::File,
io::Write,
marker::PhantomData,
mem::size_of,
path::{Path, PathBuf},
str::FromStr,
sync::Arc,
};
use std::{any::Any, marker::PhantomData, mem::size_of, sync::Arc};
use metal_rs::{
objc::rc::autoreleasepool, Buffer, CommandBufferRef, CommandQueue, ComputePassDescriptor,
@@ -20,20 +10,10 @@ use luminal::{
op::{InputTensor, Operator},
prelude::*,
};
use rustc_hash::FxHashMap;
use serde_json::{json, Value};
use crate::{
binary::MetalGather,
compile_lib,
elementwise_fusion::FusedElementwiseOp,
get_buffer_from_tensor,
matmul::Matmul,
other::MetalARange,
prim::{MetalConstant, MetalCopyFromDevice, MetalCopyToDevice, MetalMaxReduce, MetalSumReduce},
select_function_from_lib,
unary::{MetalMeanReduce, MetalStdNorm},
MetalBuffer, MetalFloat, MetalKernel, MetalKernelWrapper,
binary::MetalGather, get_buffer_from_tensor, MetalBuffer, MetalFloat, MetalKernel,
MetalKernelWrapper,
};
use super::{compile_function, SetInt};
@@ -300,15 +280,15 @@ kernel void matvec(
}
impl<T> MetalKernel for QuantizedMatmul<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
let m = input_shapes[0].shape()[input_shapes[0].len() - 2].clone();
let n = input_shapes[1].shape()[input_shapes[1].len() - 1].clone();
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
let m = input_shapes[0].dims()[input_shapes[0].len() - 2];
let n = input_shapes[1].dims()[input_shapes[1].len() - 1];
let batch_size = input_shapes[0]
.shape()
.dims()
.into_iter()
.take(input_shapes[0].len() - 2)
.product::<BigExpression>()
.max(BigExpression::from(1));
.product::<Expression>()
.max(1);
vec![batch_size * m * n * size_of::<T>()]
}
fn metal_forward(
@@ -325,13 +305,13 @@ impl<T> MetalKernel for QuantizedMatmul<T> {
let (a_shape, b_shape) = (
inputs[0]
.1
.shape()
.dims()
.into_iter()
.map(|i| i.to_usize().unwrap())
.collect::<Vec<_>>(),
inputs[1]
.1
.shape()
.dims()
.into_iter()
.map(|i| i.to_usize().unwrap())
.collect::<Vec<_>>(),
@@ -393,7 +373,7 @@ impl<T: MetalFloat> Operator for QuantizedMatmul<T> {
// Setup command queue / command buffer / encoder
let command_buffer = self.queue.new_command_buffer();
let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
let (a_shape, b_shape) = (inp[0].1.dims(), inp[1].1.dims());
let n = b_shape[1].to_usize().unwrap();
let (batch_size, m) = if a_shape.len() == 3 {
(
@@ -474,7 +454,7 @@ impl<T: MetalFloat> Operator for QuantizedGather<T> {
// Setup buffers
let indexes = tensors[0].0.borrowed().downcast_ref::<Vec<f32>>().unwrap();
let index_buffer = self.device.new_buffer_with_data(
unsafe { std::mem::transmute(indexes.as_ptr()) },
indexes.as_ptr() as *const _,
(indexes.len() * std::mem::size_of::<f32>()) as u64,
MTLResourceOptions::StorageModeShared,
);
@@ -570,297 +550,297 @@ impl<T: MetalFloat + Default> Compiler for MetalQuantizedCompiler<T> {
}
}
#[derive(Default)]
pub struct SerializeQuantizedGraph<T> {
path: PathBuf,
_phantom: PhantomData<T>,
}
// #[derive(Default)]
// pub struct SerializeQuantizedGraph<T> {
// path: PathBuf,
// _phantom: PhantomData<T>,
// }
impl<T: MetalFloat> SerializeQuantizedGraph<T> {
pub fn new(path: impl AsRef<Path>) -> Self {
Self {
path: path.as_ref().to_path_buf(),
_phantom: Default::default(),
}
}
}
// impl<T: MetalFloat> SerializeQuantizedGraph<T> {
// pub fn new(path: impl AsRef<Path>) -> Self {
// Self {
// path: path.as_ref().to_path_buf(),
// _phantom: Default::default(),
// }
// }
// }
impl<T: MetalFloat + Default> Compiler for SerializeQuantizedGraph<T> {
type Output = ();
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut nodes: To) {
// Serialize and save graph away
let mut ops = vec![];
for node in graph.node_indices().collect::<Vec<_>>() {
if graph.check_node_type::<Function>(node) {
continue;
}
let data = if let Some(op) = graph.try_get_op::<MetalMeanReduce<T>>(node) {
json!({
"dim": op.3,
"shape": op.7,
})
} else if let Some(op) = graph.try_get_op::<MetalARange<T>>(node) {
json!({
"size": op.size,
})
} else if let Some(op) = graph.try_get_op::<MetalStdNorm<T>>(node) {
json!({
"eps": op.epsilon,
})
} else if let Some(op) = graph.try_get_op::<MetalSumReduce<T>>(node) {
json!({
"dim": op.dim,
"shape": op.shape,
})
} else if let Some(op) = graph.try_get_op::<MetalMaxReduce<T>>(node) {
json!({
"dim": op.dim,
"shape": op.shape,
})
} else if let Some(op) = graph.try_get_op::<QuantizedGather<T>>(node) {
json!({
"embed_dim": op.embed_dim,
})
} else if let Some(op) = graph.try_get_op::<MetalGather<T>>(node) {
json!({
"embed_dim": op.embed_dim,
})
} else if let Some(op) = graph.try_get_op::<MetalConstant<T>>(node) {
json!({
"value": op.0,
})
} else if let Some(op) = graph.try_get_op::<Matmul<T>>(node) {
json!({
"matmul_kernel": op.matmul_kernel,
"matvec_kernel": op.matvec_kernel,
})
} else if let Some(op) = graph.try_get_op::<FusedElementwiseOp<T>>(node) {
json!({
"kernel_str": op.kernel_str,
"dyn_chars": op.dyn_chars,
"output_buffer_sizes": op.output_buffer_sizes,
})
} else if graph.check_node_type::<QuantizedMatmul<T>>(node)
|| graph.check_node_type::<MetalCopyFromDevice<T>>(node)
|| graph.check_node_type::<MetalCopyToDevice<T>>(node)
{
json!({})
} else {
panic!(
"Found unserializable op: {:?}",
graph.node_weight(node).unwrap()
);
};
ops.push(json!({
"type": format!("{:?}", graph.node_weight(node).unwrap()),
"id": node.index(),
"data": data
}));
}
let edges = graph
.edge_indices()
.map(|e| (e, graph.edge_endpoints(e).unwrap()))
.map(|(e, (a, b))| (a.index(), b.index(), *graph.edge_weight(e).unwrap()))
.collect::<Vec<_>>();
let value = json!({
"nodes": nodes.to_ids_mut().iter().map(|i| i.index()).collect::<Vec<_>>(),
"ops": ops,
"edges": edges,
"no_delete": graph.no_delete.iter().map(|i| i.index()).collect::<Vec<_>>(),
"to_retrieve": graph.to_retrieve.iter().map(|(k, v)| (k.index(), v)).collect::<Vec<_>>(),
});
File::create(&self.path)
.unwrap()
.write_all(value.to_string().as_bytes())
.unwrap();
}
}
// impl<T: MetalFloat + Default> Compiler for SerializeQuantizedGraph<T> {
// type Output = ();
// fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut nodes: To) {
// // Serialize and save graph away
// let mut ops = vec![];
// for node in graph.node_indices().collect::<Vec<_>>() {
// if graph.check_node_type::<Function>(node) {
// continue;
// }
// let data = if let Some(op) = graph.try_get_op::<MetalMeanReduce<T>>(node) {
// json!({
// "dim": op.3,
// "shape": op.7,
// })
// } else if let Some(op) = graph.try_get_op::<MetalARange<T>>(node) {
// json!({
// "size": op.size,
// })
// } else if let Some(op) = graph.try_get_op::<MetalStdNorm<T>>(node) {
// json!({
// "eps": op.epsilon,
// })
// } else if let Some(op) = graph.try_get_op::<MetalSumReduce<T>>(node) {
// json!({
// "dim": op.dim,
// "shape": op.shape,
// })
// } else if let Some(op) = graph.try_get_op::<MetalMaxReduce<T>>(node) {
// json!({
// "dim": op.dim,
// "shape": op.shape,
// })
// } else if let Some(op) = graph.try_get_op::<QuantizedGather<T>>(node) {
// json!({
// "embed_dim": op.embed_dim,
// })
// } else if let Some(op) = graph.try_get_op::<MetalGather<T>>(node) {
// json!({
// "embed_dim": op.embed_dim,
// })
// } else if let Some(op) = graph.try_get_op::<MetalConstant<T>>(node) {
// json!({
// "value": op.0,
// })
// } else if let Some(op) = graph.try_get_op::<Matmul<T>>(node) {
// json!({
// "matmul_kernel": op.matmul_kernel,
// "matvec_kernel": op.matvec_kernel,
// })
// } else if let Some(op) = graph.try_get_op::<FusedElementwiseOp<T>>(node) {
// json!({
// "kernel_str": op.kernel_str,
// "dyn_chars": op.dyn_chars,
// "output_buffer_sizes": op.output_buffer_sizes,
// })
// } else if graph.check_node_type::<QuantizedMatmul<T>>(node)
// || graph.check_node_type::<MetalCopyFromDevice<T>>(node)
// || graph.check_node_type::<MetalCopyToDevice<T>>(node)
// {
// json!({})
// } else {
// panic!(
// "Found unserializable op: {:?}",
// graph.node_weight(node).unwrap()
// );
// };
// ops.push(json!({
// "type": format!("{:?}", graph.node_weight(node).unwrap()),
// "id": node.index(),
// "data": data
// }));
// }
// let edges = graph
// .edge_indices()
// .map(|e| (e, graph.edge_endpoints(e).unwrap()))
// .map(|(e, (a, b))| (a.index(), b.index(), *graph.edge_weight(e).unwrap()))
// .collect::<Vec<_>>();
// let value = json!({
// "nodes": nodes.to_ids_mut().iter().map(|i| i.index()).collect::<Vec<_>>(),
// "ops": ops,
// "edges": edges,
// "no_delete": graph.no_delete.iter().map(|i| i.index()).collect::<Vec<_>>(),
// "to_retrieve": graph.to_retrieve.iter().map(|(k, v)| (k.index(), v)).collect::<Vec<_>>(),
// });
// File::create(&self.path)
// .unwrap()
// .write_all(value.to_string().as_bytes())
// .unwrap();
// }
// }
/// Deserialize a metal graph
#[derive(Debug, Clone)]
pub struct DeserializeQuantizedGraph<T>(String, PhantomData<T>);
// /// Deserialize a metal graph
// #[derive(Debug, Clone)]
// pub struct DeserializeQuantizedGraph<T>(String, PhantomData<T>);
impl<T> DeserializeQuantizedGraph<T> {
pub fn new(data: impl ToString) -> Self {
Self(data.to_string(), Default::default())
}
}
// impl<T> DeserializeQuantizedGraph<T> {
// pub fn new(data: impl ToString) -> Self {
// Self(data.to_string(), Default::default())
// }
// }
impl<T: MetalFloat> Compiler for DeserializeQuantizedGraph<T> {
type Output = ();
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut ids: To) -> Self::Output {
let mut value = Value::from_str(&self.0).unwrap();
let dev = Device::system_default().unwrap();
let queue = dev.new_command_queue();
// Create ops
let mut op_map = FxHashMap::<usize, NodeIndex>::default();
let matmul_library = compile_lib(&dev, include_str!("kernels/gemm.metal"));
let matvec_library = compile_lib(&dev, include_str!("kernels/gemv.metal"));
let quantized_matmul = QuantizedMatmul::<T>::new(dev.clone(), queue.clone());
for op in value["ops"].as_array_mut().unwrap() {
let name = op["type"].as_str().unwrap().to_string();
let new_id = if name == "MetalMeanReduce" {
graph
.add_op(MetalMeanReduce::<T>::new(
dev.clone(),
queue.clone(),
op["data"]["dim"].as_u64().unwrap() as usize,
serde_json::from_value(op["data"]["shape"].take()).unwrap(),
&graph.dyn_map,
))
.finish()
} else if name == "MetalSumReduce" {
graph
.add_op(MetalSumReduce::<T>::new(
serde_json::from_value(op["data"]["shape"].take()).unwrap(),
op["data"]["dim"].as_u64().unwrap() as usize,
dev.clone(),
queue.clone(),
&graph.dyn_map,
))
.finish()
} else if name == "MetalMaxReduce" {
graph
.add_op(MetalMaxReduce::<T>::new(
serde_json::from_value(op["data"]["shape"].take()).unwrap(),
op["data"]["dim"].as_u64().unwrap() as usize,
dev.clone(),
queue.clone(),
&graph.dyn_map,
))
.finish()
} else if name == "MetalStdNorm" {
graph
.add_op(MetalStdNorm::<T>::new(
op["data"]["eps"].as_f64().unwrap() as f32,
dev.clone(),
queue.clone(),
))
.finish()
} else if name.contains("MetalARange") {
graph
.add_op(MetalARange::<T>::new(
dev.clone(),
queue.clone(),
serde_json::from_value(op["data"]["size"].take()).unwrap(),
&graph.dyn_map,
))
.finish()
} else if name == "QuantizedGather" {
graph
.add_op(QuantizedGather::<T>::new(
dev.clone(),
queue.clone(),
op["data"]["embed_dim"].as_u64().unwrap() as usize,
))
.finish()
} else if name == "MetalGather" {
graph
.add_op(MetalGather::<T>::new(
dev.clone(),
queue.clone(),
op["data"]["embed_dim"].as_u64().unwrap() as usize,
))
.finish()
} else if name.contains("MetalConstant") {
graph
.add_op(MetalConstant::<T>(
serde_json::from_value(op["data"]["value"].take()).unwrap(),
dev.clone(),
&graph.dyn_map,
Default::default(),
))
.finish()
} else if name == "FusedElementwiseOp" {
let mut fused_op = FusedElementwiseOp::<T> {
kernel: None,
dyn_map: &graph.dyn_map,
kernel_str: op["data"]["kernel_str"].as_str().unwrap().to_string(),
dyn_chars: serde_json::from_value(op["data"]["dyn_chars"].take()).unwrap(),
subexpressions: vec![],
queue: queue.clone(),
device: dev.clone(),
output_buffer_sizes: serde_json::from_value(
op["data"]["output_buffer_sizes"].take(),
)
.unwrap(),
_phantom: Default::default(),
};
fused_op.compile(&dev);
graph.add_op(fused_op).finish()
} else if name == "Matmul" {
let matmul_kernel =
serde_json::from_value::<String>(op["data"]["matmul_kernel"].take()).unwrap();
let matvec_kernel =
serde_json::from_value::<String>(op["data"]["matvec_kernel"].take()).unwrap();
graph
.add_op(Matmul::<T> {
matmul_pipeline: select_function_from_lib(
&matmul_library,
&matmul_kernel,
&dev,
),
matvec_pipeline: select_function_from_lib(
&matvec_library,
&matvec_kernel,
&dev,
),
matmul_kernel,
matvec_kernel,
queue: queue.clone(),
device: dev.clone(),
_phantom: Default::default(),
})
.finish()
} else if name == "QuantizedMatmul" {
graph.add_op(quantized_matmul.clone()).finish()
} else if name == "MetalCopyToDevice" {
graph
.add_op(MetalCopyToDevice::<T>::new(dev.clone()))
.finish()
} else if name == "MetalCopyFromDevice" {
graph.add_op(MetalCopyFromDevice::<T>::default()).finish()
} else {
panic!("Found unexpected serialized op: {name}");
};
op_map.insert(op["id"].as_u64().unwrap() as usize, new_id);
}
// Remap nodes that are in the op_map
for (saved, mut new) in value["nodes"]
.as_array()
.unwrap()
.iter()
.zip(ids.to_ids_mut())
{
let saved = saved.as_u64().unwrap() as usize;
if let Entry::Vacant(e) = op_map.entry(saved) {
e.insert(*new);
} else {
graph.remove_node(*new);
remap(*new, op_map[&saved], &mut new, graph);
}
}
// Create edges
let edges =
serde_json::from_value::<Vec<(usize, usize, Dependency)>>(value["edges"].take())
.unwrap();
for (a, b, dep) in edges {
graph.add_edge(op_map[&a], op_map[&b], dep);
}
// Update no_delete and to_retrieve
graph.no_delete = serde_json::from_value::<Vec<usize>>(value["no_delete"].take())
.unwrap()
.into_iter()
.map(|i| op_map[&i])
.collect();
graph.to_retrieve =
serde_json::from_value::<Vec<(usize, (u8, ShapeTracker))>>(value["to_retrieve"].take())
.unwrap()
.into_iter()
.map(|(k, v)| (op_map[&k], v))
.collect();
}
}
// impl<T: MetalFloat> Compiler for DeserializeQuantizedGraph<T> {
// type Output = ();
// fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut ids: To) -> Self::Output {
// let mut value = Value::from_str(&self.0).unwrap();
// let dev = Device::system_default().unwrap();
// let queue = dev.new_command_queue();
// // Create ops
// let mut op_map = FxHashMap::<usize, NodeIndex>::default();
// let matmul_library = compile_lib(&dev, include_str!("kernels/gemm.metal"));
// let matvec_library = compile_lib(&dev, include_str!("kernels/gemv.metal"));
// let quantized_matmul = QuantizedMatmul::<T>::new(dev.clone(), queue.clone());
// for op in value["ops"].as_array_mut().unwrap() {
// let name = op["type"].as_str().unwrap().to_string();
// let new_id = if name == "MetalMeanReduce" {
// graph
// .add_op(MetalMeanReduce::<T>::new(
// dev.clone(),
// queue.clone(),
// op["data"]["dim"].as_u64().unwrap() as usize,
// serde_json::from_value(op["data"]["shape"].take()).unwrap(),
// &graph.dyn_map,
// ))
// .finish()
// } else if name == "MetalSumReduce" {
// graph
// .add_op(MetalSumReduce::<T>::new(
// serde_json::from_value(op["data"]["shape"].take()).unwrap(),
// op["data"]["dim"].as_u64().unwrap() as usize,
// dev.clone(),
// queue.clone(),
// &graph.dyn_map,
// ))
// .finish()
// } else if name == "MetalMaxReduce" {
// graph
// .add_op(MetalMaxReduce::<T>::new(
// serde_json::from_value(op["data"]["shape"].take()).unwrap(),
// op["data"]["dim"].as_u64().unwrap() as usize,
// dev.clone(),
// queue.clone(),
// &graph.dyn_map,
// ))
// .finish()
// } else if name == "MetalStdNorm" {
// graph
// .add_op(MetalStdNorm::<T>::new(
// op["data"]["eps"].as_f64().unwrap() as f32,
// dev.clone(),
// queue.clone(),
// ))
// .finish()
// } else if name.contains("MetalARange") {
// graph
// .add_op(MetalARange::<T>::new(
// dev.clone(),
// queue.clone(),
// serde_json::from_value(op["data"]["size"].take()).unwrap(),
// &graph.dyn_map,
// ))
// .finish()
// } else if name == "QuantizedGather" {
// graph
// .add_op(QuantizedGather::<T>::new(
// dev.clone(),
// queue.clone(),
// op["data"]["embed_dim"].as_u64().unwrap() as usize,
// ))
// .finish()
// } else if name == "MetalGather" {
// graph
// .add_op(MetalGather::<T>::new(
// dev.clone(),
// queue.clone(),
// op["data"]["embed_dim"].as_u64().unwrap() as usize,
// ))
// .finish()
// } else if name.contains("MetalConstant") {
// graph
// .add_op(MetalConstant::<T>(
// serde_json::from_value(op["data"]["value"].take()).unwrap(),
// dev.clone(),
// &graph.dyn_map,
// Default::default(),
// ))
// .finish()
// } else if name == "FusedElementwiseOp" {
// let mut fused_op = FusedElementwiseOp::<T> {
// kernel: None,
// dyn_map: &graph.dyn_map,
// kernel_str: op["data"]["kernel_str"].as_str().unwrap().to_string(),
// dyn_chars: serde_json::from_value(op["data"]["dyn_chars"].take()).unwrap(),
// subexpressions: vec![],
// queue: queue.clone(),
// device: dev.clone(),
// output_buffer_sizes: serde_json::from_value(
// op["data"]["output_buffer_sizes"].take(),
// )
// .unwrap(),
// _phantom: Default::default(),
// };
// fused_op.compile(&dev);
// graph.add_op(fused_op).finish()
// } else if name == "Matmul" {
// let matmul_kernel =
// serde_json::from_value::<String>(op["data"]["matmul_kernel"].take()).unwrap();
// let matvec_kernel =
// serde_json::from_value::<String>(op["data"]["matvec_kernel"].take()).unwrap();
// graph
// .add_op(Matmul::<T> {
// matmul_pipeline: select_function_from_lib(
// &matmul_library,
// &matmul_kernel,
// &dev,
// ),
// matvec_pipeline: select_function_from_lib(
// &matvec_library,
// &matvec_kernel,
// &dev,
// ),
// matmul_kernel,
// matvec_kernel,
// queue: queue.clone(),
// device: dev.clone(),
// _phantom: Default::default(),
// })
// .finish()
// } else if name == "QuantizedMatmul" {
// graph.add_op(quantized_matmul.clone()).finish()
// } else if name == "MetalCopyToDevice" {
// graph
// .add_op(MetalCopyToDevice::<T>::new(dev.clone()))
// .finish()
// } else if name == "MetalCopyFromDevice" {
// graph.add_op(MetalCopyFromDevice::<T>::default()).finish()
// } else {
// panic!("Found unexpected serialized op: {name}");
// };
// op_map.insert(op["id"].as_u64().unwrap() as usize, new_id);
// }
// // Remap nodes that are in the op_map
// for (saved, mut new) in value["nodes"]
// .as_array()
// .unwrap()
// .iter()
// .zip(ids.to_ids_mut())
// {
// let saved = saved.as_u64().unwrap() as usize;
// if let Entry::Vacant(e) = op_map.entry(saved) {
// e.insert(*new);
// } else {
// graph.remove_node(*new);
// remap(*new, op_map[&saved], &mut new, graph);
// }
// }
// // Create edges
// let edges =
// serde_json::from_value::<Vec<(usize, usize, Dependency)>>(value["edges"].take())
// .unwrap();
// for (a, b, dep) in edges {
// graph.add_edge(op_map[&a], op_map[&b], dep);
// }
// // Update no_delete and to_retrieve
// graph.no_delete = serde_json::from_value::<Vec<usize>>(value["no_delete"].take())
// .unwrap()
// .into_iter()
// .map(|i| op_map[&i])
// .collect();
// graph.to_retrieve =
// serde_json::from_value::<Vec<(usize, (u8, ShapeTracker))>>(value["to_retrieve"].take())
// .unwrap()
// .into_iter()
// .map(|(k, v)| (op_map[&k], v))
// .collect();
// }
// }
#[cfg(test)]
mod tests {

View File

@@ -299,7 +299,7 @@ fn btreeset_intersection<T: Ord>(mut a: BTreeSet<T>, b: &BTreeSet<T>) -> BTreeSe
struct AllocateMetalBuffers {
dev: Device,
dyn_map: *const FxHashMap<char, usize>,
buffer_sizes: Vec<BigExpression>,
buffer_sizes: Vec<Expression>,
buffers: Arc<UnsafeCell<Vec<Buffer>>>,
}
impl Debug for AllocateMetalBuffers {

View File

@@ -80,8 +80,8 @@ fn test_rotate() {
const D: usize = 4;
let data = random_vec(D * B * F);
let a = cx.tensor((F, B, D)).set(data.clone()).permute((1, 0, 2));
let x1 = a.slice((.., .., ..Expression::from(D / 2)));
let x2 = a.slice((.., .., Expression::from(D / 2)..));
let x1 = a.slice((.., .., ..D / 2));
let x2 = a.slice((.., .., D / 2..));
let mut rotated_a = (-x2).concat_along(x1, 1).retrieve();
cx.execute();
let unopt = rotated_a.data();
@@ -709,7 +709,7 @@ fn test_slice() {
let data = random_vec(256);
let mut cx = Graph::new();
let a = cx.tensor(256).set(data.clone());
let mut c = a.slice(..Expression::from(20)).contiguous().retrieve();
let mut c = a.slice(..20).contiguous().retrieve();
cx.compile(MetalCompiler::<f16>::default(), &mut c);
cx.execute();
@@ -752,10 +752,10 @@ fn test_pad_contig() {
let a_data = random_vec_rng(m * k, &mut rng);
let mut a = cx.tensor(('M', 'K')).set_dyn(a_data, (m, k)).retrieve();
let mut b = a
.pad(&[(0, 0.into()), (0, Expression::from(24) - 'K')])
.pad(((0, 0), (0, Expression::from(24) - 'K')))
.contiguous()
.retrieve();
let mut c = (a.slice((.., ..Expression::from(k))) / 1.0).retrieve();
let mut c = (a.slice((.., ..k)) / 1.0).retrieve();
cx.compile(MetalCompiler::<f16>::default(), (&mut a, &mut b, &mut c));
cx.execute();
@@ -771,7 +771,7 @@ fn test_movement() {
let mut cx = Graph::new();
let a = cx.tensor(32).set(data.clone());
let b = a.pad((0, 10)).contiguous().retrieve();
let mut c = b.slice((..Expression::from(25),)).contiguous().retrieve();
let mut c = b.slice((..25,)).contiguous().retrieve();
cx.compile(MetalCompiler::<f16>::default(), &mut c);
cx.execute();

View File

@@ -83,7 +83,7 @@ kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *o
}
impl<T> MetalKernel for MetalMeanReduce<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
let mut sh = input_shapes[0];
sh.remove_dim(self.3);
vec![sh.n_elements() * size_of::<T>()]
@@ -101,19 +101,19 @@ impl<T> MetalKernel for MetalMeanReduce<T> {
let front_size: usize = inputs[0]
.1
.shape()
.dims()
.iter()
.take(self.3)
.map(|i| i.to_usize().unwrap())
.product();
let back_size: usize = inputs[0]
.1
.shape()
.dims()
.iter()
.skip(self.3 + 1)
.map(|i| i.to_usize().unwrap())
.product();
let dim_size = inputs[0].1.shape()[self.3].to_usize().unwrap();
let dim_size = inputs[0].1.dims()[self.3].to_usize().unwrap();
let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
@@ -306,7 +306,7 @@ kernel void kernel_std_norm(
}
impl<T> MetalKernel for MetalStdNorm<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
vec![input_shapes[0].n_elements() * size_of::<T>()]
}
@@ -320,7 +320,7 @@ impl<T> MetalKernel for MetalStdNorm<T> {
let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
encoder.set_compute_pipeline_state(&self.pipeline);
let row_size = inputs[0].1.shape().last().unwrap().to_usize().unwrap();
let row_size = inputs[0].1.dims().last().unwrap().to_usize().unwrap();
// Set inputs
encoder.set_buffer(0, Some(inputs[0].0), 0);
@@ -329,7 +329,7 @@ impl<T> MetalKernel for MetalStdNorm<T> {
encoder.set_f32(3, self.epsilon);
let batch_size = inputs[0]
.1
.shape()
.dims()
.into_iter()
.take(inputs[0].1.len() - 1)
.map(|i| i.to_usize().unwrap())
@@ -432,7 +432,7 @@ impl<T: MetalFloat> Compiler for StdNormCompiler<T> {
}
}
if sh
.shape()
.dims()
.last()
.unwrap()
.to_usize()
@@ -515,7 +515,7 @@ kernel void kernel_metal_exp(device {type_name} *inp [[buffer(0)]], device {type
}
impl<T: MetalFloat> MetalKernel for MetalExp<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
vec![input_shapes[0].n_elements() * std::mem::size_of::<T>()]
}
fn metal_forward(
@@ -653,7 +653,7 @@ kernel void kernel_metal_cos(device {type_name} *inp [[buffer(0)]], device {type
}
impl<T: MetalFloat> MetalKernel for MetalCos<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
vec![input_shapes[0].n_elements() * std::mem::size_of::<T>()]
}
fn metal_forward(

View File

@@ -48,11 +48,10 @@ impl SerializeModule for Conv1D {
}
}
// Single
impl Module<GraphTensor> for Conv1D {
type Output = GraphTensor;
fn forward(&self, input: GraphTensor) -> Self::Output {
assert_eq!(input.shape()[input.shape.len() - 2], self.ch_in);
assert_eq!(input.dims()[input.shape.len() - 2], self.ch_in);
// Input: batch_dims, ch_in, dim_in
// Reshape to 2 batch dims
let n_expands = 4 - input.shape.len();
@@ -61,9 +60,9 @@ impl Module<GraphTensor> for Conv1D {
inp = inp.expand(0, 1);
}
let batch1 = inp.shape()[0].small();
let batch2 = inp.shape()[1].small();
let dim_in = input.shape().last().unwrap().small();
let batch1 = inp.dims()[0];
let batch2 = inp.dims()[1];
let dim_in = *input.dims().last().unwrap();
let dim_out = (((dim_in + 2 * self.padding - self.dilation * (self.kernel - 1) - 1)
/ self.stride)
+ 1)
@@ -86,7 +85,7 @@ impl Module<GraphTensor> for Conv1D {
}
// Reshape back to original shape
let mut final_shape = out.shape();
let mut final_shape = out.dims();
for _ in 0..n_expands {
final_shape.remove(0);
}
@@ -96,6 +95,7 @@ impl Module<GraphTensor> for Conv1D {
pub struct Conv2D {
pub weight: GraphTensor, // ch_out, ch_in * kernel_x * kernel_y
pub bias: Option<GraphTensor>, // ch_out
kernel: (usize, usize),
stride: (usize, usize),
dilation: (usize, usize),
@@ -110,10 +110,16 @@ impl Conv2D {
kernel: (usize, usize),
stride: (usize, usize),
dilation: (usize, usize),
bias: bool,
cx: &mut Graph,
) -> Self {
Self {
weight: cx.named_tensor("Weight", (ch_out, ch_in * kernel.0 * kernel.1)),
bias: if bias {
Some(cx.named_tensor("Bias", ch_out))
} else {
None
},
kernel,
stride,
dilation,
@@ -126,15 +132,22 @@ impl Conv2D {
impl SerializeModule for Conv2D {
fn serialize(&self, s: &mut luminal::module::Serializer) {
s.tensor("weight", self.weight);
if let Some(bias) = self.bias {
s.tensor("bias", bias);
}
}
}
// Single
impl Conv2D {
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
// Input: (ch_in, dimx_in, dimy_in)
let dimx_in = input.shape()[1].small();
let dimy_in = input.shape()[2].small();
pub fn forward(&self, mut input: GraphTensor) -> GraphTensor {
// Input: (batch (optional), ch_in, dimx_in, dimy_in)
let mut expanded = false;
if input.shape.len() == 3 {
// Expand batch
input = input.expand(0, 1);
expanded = true;
}
let (batch, _, dimx_in, dimy_in) = input.dims4();
let dimx_out = (((dimx_in - self.dilation.0 * (self.kernel.0 - 1) - 1) / self.stride.0)
+ 1)
.simplify();
@@ -143,21 +156,34 @@ impl Conv2D {
.simplify();
let input_pooled = input
.pool_last_dim(self.kernel.1, self.stride.1, self.dilation.1)
.permute((0, 2, 3, 1))
.permute((0, 1, 3, 4, 2))
.pool_last_dim(self.kernel.0, self.stride.0, self.dilation.0)
.permute((0, 4, 2, 3, 1))
.permute((0, 1, 5, 3, 4, 2))
.reshape((
batch,
self.ch_in * self.kernel.0 * self.kernel.1,
dimx_out * dimy_out,
));
self.weight
.matmul(input_pooled)
.reshape((self.ch_out, dimx_out, dimy_out))
let mut o = self.weight.expand(0, batch).matmul(input_pooled).reshape((
batch,
self.ch_out,
dimx_out,
dimy_out,
));
if let Some(b) = self.bias {
o += b.expand_to(o.shape);
}
if expanded {
o.reshape((self.ch_out, dimx_out, dimy_out))
} else {
o
}
}
}
pub struct Conv3D {
pub weight: GraphTensor, // ch_out, ch_in * kernel_x * kernel_y * kernel_z
pub bias: Option<GraphTensor>, // ch_out
kernel: (usize, usize, usize),
stride: (usize, usize, usize),
dilation: (usize, usize, usize),
@@ -172,10 +198,16 @@ impl Conv3D {
kernel: (usize, usize, usize),
stride: (usize, usize, usize),
dilation: (usize, usize, usize),
bias: bool,
cx: &mut Graph,
) -> Self {
Self {
weight: cx.named_tensor("Weight", (ch_out, ch_in * kernel.0 * kernel.1 * kernel.2)),
bias: if bias {
Some(cx.named_tensor("Bias", ch_out))
} else {
None
},
kernel,
stride,
dilation,
@@ -188,15 +220,18 @@ impl Conv3D {
impl SerializeModule for Conv3D {
fn serialize(&self, s: &mut luminal::module::Serializer) {
s.tensor("weight", self.weight);
if let Some(bias) = self.bias {
s.tensor("bias", bias);
}
}
}
impl Conv3D {
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
// Input: ch_in, dimx_in, dimy_in, dimz_in
let dimx_in = input.shape()[1].small();
let dimy_in = input.shape()[2].small();
let dimz_in = input.shape()[3].small();
let dimx_in = input.dims()[1];
let dimy_in = input.dims()[2];
let dimz_in = input.dims()[3];
let dimx_out = (((dimx_in - self.dilation.0 * (self.kernel.0 - 1) - 1) / self.stride.0)
+ 1)
.simplify();
@@ -264,8 +299,8 @@ mod tests {
cx.execute();
assert_eq!(
out1.shape(),
vec![BigExpression::from(CH_OUT), BigExpression::from(DIM_OUT)]
out1.dims(),
vec![Expression::from(CH_OUT), Expression::from(DIM_OUT)]
);
assert_close(&out1.data(), &[0.0948, -0.9498, -1.2342]);
}
@@ -421,6 +456,7 @@ mod tests {
(KERNELX, KERNELY),
(STRIDEX, STRIDEY),
(DILATIONX, DILATIONY),
false,
&mut cx,
);
model.weight.set(vec![
@@ -485,6 +521,7 @@ mod tests {
(KERNELX, KERNELY, KERNELZ),
(STRIDEX, STRIDEY, STRIDEZ),
(DILATIONX, DILATIONY, DILATIONZ),
false,
&mut cx,
);
let weights = vec![

View File

@@ -51,7 +51,7 @@ impl Module<GraphTensor> for Embedding {
self.weight.gather(inp)
};
// Unflatten
let mut new_shape = input.shape();
let mut new_shape = input.dims();
new_shape.push(self.embedding_dim.into());
out.reshape(new_shape)
}

View File

@@ -62,17 +62,16 @@ impl Module<(GraphTensor, GraphTensor, GraphTensor)> for MultiHeadSelfAttention
GraphTensor, // batch, s1, dim
),
) -> Self::Output {
let orig_query_shape = queries.shape();
let s1 = keys.shape()[keys.shape.len() - 2].small();
let s2 = queries.shape()[queries.shape.len() - 2].small();
let orig_query_shape = queries.dims();
let s1 = keys.dims()[keys.shape.len() - 2];
let s2 = queries.dims()[queries.shape.len() - 2];
let n_batches = queries
.shape()
.dims()
.into_iter()
.take(queries.shape.len() - 2)
.product::<BigExpression>()
.max(1)
.small();
let dim = queries.shape().last().unwrap().small();
.product::<Expression>()
.max(1);
let dim = *queries.dims().last().unwrap();
let keys = keys.reshape((n_batches, s1, dim));
let values = values.reshape((n_batches, s1, dim));
let queries = queries.reshape((n_batches, s2, dim));

View File

@@ -73,16 +73,15 @@ impl Module<(GraphTensor, GraphTensor)> for TransformerDecoderBlock {
// Input: batch_dims, seq1, dim
// From_enc: batch_dims, seq2, dim
// Flatten to single batch dim
let seq1 = input.shape()[input.shape.len() - 2].small();
let seq2 = from_enc.shape()[from_enc.shape.len() - 2].small();
let dim = input.shape().last().unwrap().small();
let seq1 = input.dims()[input.shape.len() - 2];
let seq2 = from_enc.dims()[from_enc.shape.len() - 2];
let dim = *input.dims().last().unwrap();
let n_batches = input
.shape()
.dims()
.into_iter()
.take(input.shape.len() - 2)
.product::<BigExpression>()
.max(1)
.small();
.product::<Expression>()
.max(1);
let inp = input.reshape((n_batches, seq1, dim));
let fe = from_enc.reshape((n_batches, seq2, dim));
// Batched forward pass

View File

@@ -71,18 +71,18 @@ impl Module<GraphTensor> for TransformerEncoderBlock {
// Input: batch_dims, sequence, dim
// Reshape to 1 batch dim, sequence, dim
let n_batches = input
.shape()
.dims()
.into_iter()
.take(input.shape.len() - 2)
.product::<BigExpression>()
.product::<Expression>()
.max(1);
let sequence = input.shape()[input.shape.len() - 2].small();
let dim = input.shape()[input.shape.len() - 1].small();
let sequence = input.dims()[input.shape.len() - 2];
let dim = input.dims()[input.shape.len() - 1];
let x = input.reshape((n_batches, sequence, dim));
let x = x + self.attention.forward(x);
let x = x.layer_norm(2, 1e-5);
let x = x + self.ff.forward(x);
x.layer_norm(2, 1e-5).reshape(input.shape())
x.layer_norm(2, 1e-5).reshape(input.dims())
}
}

View File

@@ -24,7 +24,7 @@ impl Autograd {
// Run dfs with a starting stack and record all encountered nodes in a set
fn build_dfs_set(
stack: &mut Vec<NodeIndex>,
graph: &MainGraph,
graph: &StorageGraph,
direction: Direction,
) -> FxHashSet<NodeIndex> {
let mut set = FxHashSet::default();
@@ -61,7 +61,7 @@ impl Compiler for Autograd {
*loss,
(
graph.constant(1.0).id,
ShapeTracker::default(), // Assume scalar loss for now
ShapeTracker::new(()), // Assume scalar loss for now
),
);
let weight_set = params.iter().copied().collect::<FxHashSet<_>>();
@@ -217,7 +217,7 @@ fn add_grad(
} else if let Some(MaxReduce(dim)) = graph.try_get_op(fwd.id) {
pre_fwd_shape.remove_dim(*dim);
}
if grad.shape.shape() != pre_fwd_shape.shape() {
if grad.shape.dims() != pre_fwd_shape.dims() {
if !grad.shape.is_contiguous() {
grad = grad.contiguous();
}

View File

@@ -79,7 +79,7 @@ pub fn cross_entropy_with_logits_loss(
let inv_last_axis_numel = 1.0
/ logits
.graph()
.constant(logits.shape.shape().last().unwrap());
.constant(*logits.shape.dims().last().unwrap());
let probs = logits.log_softmax(logits.shape.last_axis());
(-(probs * target_probabilities).mean_reduce(target_probabilities.shape.all_axes()))
/ inv_last_axis_numel
@@ -102,7 +102,7 @@ pub fn kl_div_with_logits_loss(
let inv_last_axis_numel = 1.0
/ logits
.graph()
.constant(logits.shape.shape().last().unwrap());
.constant(*logits.shape.dims().last().unwrap());
let probs = logits.log_softmax(logits.shape.last_axis());
(-((probs - target_probabilities.ln()) * target_probabilities)
.mean_reduce(target_probabilities.shape.all_axes()))

View File

@@ -54,7 +54,7 @@ fn main() {
cx.keep_tensors(&model_weights);
let (logits, mut cache_dest) = model.forward((input, &cache_src));
let mut logits = logits
.slice((.., (Expression::from('s') - 1).., ..))
.slice((.., Expression::from('s') - 1.., ..))
.retrieve();
cache_dest.keep();
println!("\t\t - {}ms", now.elapsed().as_millis());

View File

@@ -49,7 +49,7 @@ impl SerializeModule for Mlp {
}
}
fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: BigExpression) -> GraphTensor {
fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: Expression) -> GraphTensor {
assert_eq!(input.shape.len(), 4); // batch, n_heads, seq, head_dim
let (batch, n_heads, seq, head_dim) = input.dims4();
// Get freqs
@@ -102,8 +102,8 @@ impl Module<(GraphTensor, KVCache)> for SelfAttention {
.permute((0, 2, 1, 3));
// Rotary embed queries and keys
let queries = apply_rotary_embeddings_ggml(queries, prev_seq.big());
let keys = apply_rotary_embeddings_ggml(keys, prev_seq.big());
let queries = apply_rotary_embeddings_ggml(queries, prev_seq);
let keys = apply_rotary_embeddings_ggml(keys, prev_seq);
// Add KV cache
let keys = k_cache.concat_along(keys, 2);

View File

@@ -49,7 +49,7 @@ impl SerializeModule for Mlp {
}
}
fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: BigExpression) -> GraphTensor {
fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: Expression) -> GraphTensor {
assert_eq!(input.shape.len(), 4); // batch, n_heads, seq, head_dim
let (batch, n_heads, seq, head_dim) = input.dims4();
// Get freqs
@@ -102,8 +102,8 @@ impl Module<(GraphTensor, KVCache)> for SelfAttention {
.permute((0, 2, 1, 3));
// Rotary embed queries and keys
let queries = apply_rotary_embeddings_ggml(queries, prev_seq.big());
let keys = apply_rotary_embeddings_ggml(keys, prev_seq.big());
let queries = apply_rotary_embeddings_ggml(queries, prev_seq);
let keys = apply_rotary_embeddings_ggml(keys, prev_seq);
// Add KV cache
let keys = k_cache.concat_along(keys, 2);

View File

@@ -47,7 +47,7 @@ impl SerializeModule for Mlp {
}
}
fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: BigExpression) -> GraphTensor {
fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: Expression) -> GraphTensor {
assert_eq!(input.shape.len(), 4); // batch, n_heads, seq, head_dim
let (batch, n_heads, seq, head_dim) = input.dims4();
// Get freqs
@@ -100,8 +100,8 @@ impl Module<(GraphTensor, KVCache)> for SelfAttention {
.permute((0, 2, 1, 3));
// Rotary embed queries and keys
let queries = apply_rotary_embeddings_ggml(queries, prev_seq.big());
let keys = apply_rotary_embeddings_ggml(keys, prev_seq.big());
let queries = apply_rotary_embeddings_ggml(queries, prev_seq);
let keys = apply_rotary_embeddings_ggml(keys, prev_seq);
// Add KV cache
let keys = k_cache.concat_along(keys, 2);

View File

@@ -19,3 +19,7 @@ memmap2 = "0.9.4"
colored = "2.1.0"
itertools = "0.12.1"
tokenizers = "0.15.2"
image = "0.25.1"
imageproc = "0.25.0"
ab_glyph = "0.2.28"
safetensors = "0.4.3"

BIN
examples/yolo_v8/bike.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 179 KiB

Binary file not shown.

View File

@@ -0,0 +1,45 @@
use std::io::Read;
use std::path::Path;
use std::{fs::File, io::Seek};
use luminal::{op::Function, prelude::*};
use memmap2::MmapOptions;
use safetensors::{Dtype, SafeTensors};
pub fn load<M: SerializeModule>(path: &str, model: &M, graph: &mut Graph) {
for (weight_name, node_index) in param_dict(model) {
if let Some(loading_node) = graph
.graph
.node_weight_mut(node_index)
.and_then(|op| op.as_any_mut().downcast_mut::<Function>())
{
let path = path.to_string();
loading_node.1 = Box::new(move |_| {
let mut bytes = vec![];
let mut file = File::open(&path).unwrap();
file.read_to_end(&mut bytes).unwrap();
let safetensors = SafeTensors::deserialize(&bytes).unwrap();
if let Ok(tensor_view) = safetensors.tensor(&weight_name.replace('/', ".")) {
// Convert to fp32
let data: Vec<f32> = match tensor_view.dtype() {
Dtype::F32 => tensor_view
.data()
.chunks_exact(4)
.map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
.collect(),
Dtype::F16 => tensor_view
.data()
.chunks_exact(2)
.map(|c| f16::from_ne_bytes([c[0], c[1]]).to_f32())
.collect(),
_ => panic!("{:?} is not a supported dtype", tensor_view.dtype()),
};
return vec![Tensor::new(data)];
}
panic!("Tensor \"{weight_name}\" not found in files");
});
}
}
}

View File

@@ -1,5 +1,293 @@
mod loader;
mod model;
fn main() {
println!("Hello, world!");
use image::DynamicImage;
use luminal::prelude::*;
pub const NAMES: [&str; 80] = [
"person",
"bicycle",
"car",
"motorbike",
"aeroplane",
"bus",
"train",
"truck",
"boat",
"traffic light",
"fire hydrant",
"stop sign",
"parking meter",
"bench",
"bird",
"cat",
"dog",
"horse",
"sheep",
"cow",
"elephant",
"bear",
"zebra",
"giraffe",
"backpack",
"umbrella",
"handbag",
"tie",
"suitcase",
"frisbee",
"skis",
"snowboard",
"sports ball",
"kite",
"baseball bat",
"baseball glove",
"skateboard",
"surfboard",
"tennis racket",
"bottle",
"wine glass",
"cup",
"fork",
"knife",
"spoon",
"bowl",
"banana",
"apple",
"sandwich",
"orange",
"broccoli",
"carrot",
"hot dog",
"pizza",
"donut",
"cake",
"chair",
"sofa",
"pottedplant",
"bed",
"diningtable",
"toilet",
"tvmonitor",
"laptop",
"mouse",
"remote",
"keyboard",
"cell phone",
"microwave",
"oven",
"toaster",
"sink",
"refrigerator",
"book",
"clock",
"vase",
"scissors",
"teddy bear",
"hair drier",
"toothbrush",
];
/// A bounding box around an object.
#[derive(Debug, Clone)]
pub struct Bbox<D> {
pub xmin: f32,
pub ymin: f32,
pub xmax: f32,
pub ymax: f32,
pub confidence: f32,
pub data: D,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct KeyPoint {
pub x: f32,
pub y: f32,
pub mask: f32,
}
/// Intersection over union of two bounding boxes.
pub fn iou<D>(b1: &Bbox<D>, b2: &Bbox<D>) -> f32 {
let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.);
let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.);
let i_xmin = b1.xmin.max(b2.xmin);
let i_xmax = b1.xmax.min(b2.xmax);
let i_ymin = b1.ymin.max(b2.ymin);
let i_ymax = b1.ymax.min(b2.ymax);
let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.);
i_area / (b1_area + b2_area - i_area)
}
pub fn non_maximum_suppression<D>(bboxes: &mut [Vec<Bbox<D>>], threshold: f32) {
// Perform non-maximum suppression.
for bboxes_for_class in bboxes.iter_mut() {
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
let mut current_index = 0;
for index in 0..bboxes_for_class.len() {
let mut drop = false;
for prev_index in 0..current_index {
let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
if iou > threshold {
drop = true;
break;
}
}
if !drop {
bboxes_for_class.swap(current_index, index);
current_index += 1;
}
}
bboxes_for_class.truncate(current_index);
}
}
#[allow(clippy::too_many_arguments)]
pub fn report_detect(
pred_size: usize,
n_preds: usize,
pred: &[f32],
img: DynamicImage,
w: usize,
h: usize,
confidence_threshold: f32,
nms_threshold: f32,
legend_size: u32,
) -> DynamicImage {
let nclasses = pred_size - 4;
// The bounding boxes grouped by (maximum) class index.
let mut bboxes: Vec<Vec<Bbox<Vec<KeyPoint>>>> = (0..nclasses).map(|_| vec![]).collect();
// Extract the bounding boxes for which confidence is above the threshold.
for index in 0..n_preds {
let pred = pred[pred_size * index..pred_size * (index + 1)].to_vec();
let confidence = *pred[4..].iter().max_by(|x, y| x.total_cmp(y)).unwrap();
if confidence > confidence_threshold {
let mut class_index = 0;
for i in 0..nclasses {
if pred[4 + i] > pred[4 + class_index] {
class_index = i
}
}
if pred[class_index + 4] > 0. {
let bbox = Bbox {
xmin: pred[0] - pred[2] / 2.,
ymin: pred[1] - pred[3] / 2.,
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
data: vec![],
};
bboxes[class_index].push(bbox)
}
}
}
non_maximum_suppression(&mut bboxes, nms_threshold);
// Annotate the original image and print boxes information.
let (initial_h, initial_w) = (img.height(), img.width());
let w_ratio = initial_w as f32 / w as f32;
let h_ratio = initial_h as f32 / h as f32;
let mut img = img.to_rgb8();
let font = Vec::from(include_bytes!("../roboto-mono-stripped.ttf") as &[u8]);
let font = ab_glyph::FontRef::try_from_slice(&font).unwrap();
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
for b in bboxes_for_class.iter() {
println!("{}: {:?}", NAMES[class_index], b);
let xmin = (b.xmin * w_ratio) as i32;
let ymin = (b.ymin * h_ratio) as i32;
let dx = (b.xmax - b.xmin) * w_ratio;
let dy = (b.ymax - b.ymin) * h_ratio;
if dx >= 0. && dy >= 0. {
imageproc::drawing::draw_hollow_rect_mut(
&mut img,
imageproc::rect::Rect::at(xmin, ymin).of_size(dx as u32, dy as u32),
image::Rgb([255, 0, 0]),
);
}
if legend_size > 0 {
imageproc::drawing::draw_filled_rect_mut(
&mut img,
imageproc::rect::Rect::at(xmin, ymin).of_size(dx as u32, legend_size),
image::Rgb([170, 0, 0]),
);
let legend = format!("{} {:.0}%", NAMES[class_index], 100. * b.confidence);
imageproc::drawing::draw_text_mut(
&mut img,
image::Rgb([255, 255, 255]),
xmin,
ymin,
ab_glyph::PxScale {
x: legend_size as f32 - 1.,
y: legend_size as f32 - 1.,
},
&font,
&legend,
)
}
}
}
DynamicImage::ImageRgb8(img)
}
fn main() {
// Setup graph
let mut cx = Graph::new();
let mut input = cx.tensor((1, 3, 'h', 'w'));
let model = model::Yolo::new(0.33, 0.25, 2.0, 80, &mut cx);
let mut model_params = params(&model);
let mut output = model.forward(input).retrieve();
loader::load("yolov8n.safetensors", &model, &mut cx);
// Compile
cx.compile(
GenericCompiler::default(),
(&mut input, &mut model_params, &mut output),
);
let mut image_name = std::path::PathBuf::from("bike.jpg");
let original_image = image::io::Reader::open(&image_name)
.unwrap()
.decode()
.unwrap();
let (width, height) = {
let w = original_image.width() as usize;
let h = original_image.height() as usize;
if w < h {
let w = w * 640 / h;
// Sizes have to be divisible by 32.
(w / 32 * 32, 640)
} else {
let h = h * 640 / w;
(640, h / 32 * 32)
}
};
println!("Width: {width} Height: {height}");
let img = original_image.resize_exact(
width as u32,
height as u32,
image::imageops::FilterType::CatmullRom,
);
let data = img
.to_rgb8()
.into_raw()
.into_iter()
.map(|i| i as f32 / 255.)
.collect::<Vec<_>>();
input.set_dyn(data, (1, 3, img.height() as usize, img.width() as usize));
let (_, pred_size, n_preds) = output.dims3();
cx.execute();
let image_t = report_detect(
pred_size.exec(&cx.dyn_map).unwrap(),
n_preds.exec(&cx.dyn_map).unwrap(),
&output.data(),
original_image,
width,
height,
0.25,
0.45,
14,
);
image_name.set_extension("pp.jpg");
println!("writing {image_name:?}");
image_t.save(image_name).unwrap();
}

View File

@@ -1,88 +1,584 @@
use luminal::prelude::*;
use luminal_nn::Conv2D;
struct Bottleneck<const CH_IN: usize, const CH_OUT: usize> {
cv1: Conv2D<CH_IN, CH_OUT, 3, 3, 1, 1>,
cv2: Conv2D<CH_OUT, CH_OUT, 3, 3, 1, 1>,
struct Upsample {
scale_factor: usize,
}
impl<const CH_IN: usize, const CH_OUT: usize, Width: Dimension, Height: Dimension>
Module<GraphTensor<(Const<CH_IN>, Height, Width)>> for Bottleneck<CH_IN, CH_OUT>
{
type Output = GraphTensor<(Const<CH_OUT>, Height, Width)>;
fn forward(&self, input: GraphTensor<(Const<CH_IN>, Height, Width)>) -> Self::Output {
self.cv2
.forward(
self.cv1
.forward::<Width, Height, Width, Height>(input.permute()),
)
.permute()
impl Upsample {
fn new(scale_factor: usize) -> Self {
Upsample { scale_factor }
}
}
struct C2F {
cv1: ConvBlock,
cv2: ConvBlock,
bottleneck: Vec<Bottleneck>,
impl Module<GraphTensor> for Upsample {
type Output = GraphTensor;
fn forward(&self, xs: GraphTensor) -> GraphTensor {
let (batch, channels, h, w) = xs.dims4();
xs.expand(3, self.scale_factor)
.expand(5, self.scale_factor)
.reshape((
batch,
channels,
self.scale_factor * h,
self.scale_factor * w,
)) // Double height and width
}
}
struct Bottleneck {
cv1: Conv2D,
cv2: Conv2D,
residual: bool,
}
impl Bottleneck {
pub fn new(ch_in: usize, ch_out: usize, shortcut: bool, cx: &mut Graph) -> Self {
Self {
cv1: Conv2D::new(ch_in, ch_out, (3, 3), (3, 3), (1, 1), true, cx),
cv2: Conv2D::new(ch_out, ch_out, (3, 3), (3, 3), (1, 1), true, cx),
residual: ch_in == ch_out && shortcut,
}
}
}
impl SerializeModule for Bottleneck {
fn serialize(&self, s: &mut Serializer) {
s.module("cv1", &self.cv1);
s.module("cv2", &self.cv2);
}
}
impl Module<GraphTensor> for Bottleneck {
type Output = GraphTensor;
fn forward(&self, input: GraphTensor) -> Self::Output {
let mut out = self
.cv2
.forward(self.cv1.forward(input.permute((0, 1, 3, 2))))
.permute((0, 1, 3, 2));
if self.residual {
out += input
}
out
}
}
struct C2f {
cv1: Conv2D,
cv2: Conv2D,
bottleneck: Vec<Bottleneck>,
c: usize,
}
impl C2f {
pub fn new(c1: usize, c2: usize, n: usize, shortcut: bool, cx: &mut Graph) -> Self {
let c = (c2 as f64 / 2.) as usize;
Self {
cv1: Conv2D::new(c1, 2 * c, (1, 1), (1, 1), (1, 1), true, cx),
cv2: Conv2D::new((2 + n) * c, c2, (1, 1), (1, 1), (1, 1), true, cx),
bottleneck: (0..n)
.map(|_| Bottleneck::new(c, c, shortcut, cx))
.collect(),
c,
}
}
}
impl SerializeModule for C2f {
fn serialize(&self, s: &mut Serializer) {
s.module("cv1", &self.cv1);
s.module("cv2", &self.cv2);
for (i, l) in self.bottleneck.iter().enumerate() {
s.module(&format!("bottleneck.{i}"), l);
}
}
}
impl Module<GraphTensor> for C2f {
type Output = GraphTensor;
fn forward(&self, input: GraphTensor) -> Self::Output {
let ys = self.cv1.forward(input);
let mut ys = chunk(ys, 2, 1);
for m in self.bottleneck.iter() {
ys.push(m.forward(*ys.last().unwrap()));
}
let mut fin = ys.remove(0);
for t in ys {
fin = fin.concat_along(t, 1);
}
self.cv2.forward(fin)
}
}
fn chunk(tensor: GraphTensor, chunks: usize, dim: usize) -> Vec<GraphTensor> {
let chunk_size = tensor.dims()[dim] / chunks;
let mut t = vec![];
for i in 0..chunks {
t.push(tensor.slice_along(i * chunk_size..(i + 1) * chunk_size, dim));
}
t
}
#[allow(clippy::upper_case_acronyms)]
struct SPPF {
cv1: ConvBlock,
cv2: ConvBlock,
cv1: Conv2D,
cv2: Conv2D,
k: usize,
}
impl SerializeModule for SPPF {
fn serialize(&self, s: &mut Serializer) {
s.module("cv1", &self.cv1);
s.module("cv2", &self.cv2);
}
}
impl SPPF {
pub fn new(c1: usize, c2: usize, k: usize, cx: &mut Graph) -> Self {
let c_ = c1 / 2;
Self {
cv1: Conv2D::new(c1, c_, (1, 1), (1, 1), (1, 1), true, cx),
cv2: Conv2D::new(c_ * 4, c2, (1, 1), (1, 1), (1, 1), true, cx),
k,
}
}
}
impl Module<GraphTensor> for SPPF {
type Output = GraphTensor;
fn forward(&self, xs: GraphTensor) -> Self::Output {
let xs = self.cv1.forward(xs);
let xs2 = xs
.pad((
(0, 0),
(0, 0),
(self.k / 2, self.k / 2),
(self.k / 2, self.k / 2),
))
.pool_last_dim(self.k, 1, 1)
.max_reduce(4);
let xs3 = xs2
.pad((
(0, 0),
(0, 0),
(self.k / 2, self.k / 2),
(self.k / 2, self.k / 2),
))
.pool_last_dim(self.k, 1, 1)
.max_reduce(4);
let xs4 = xs3
.pad((
(0, 0),
(0, 0),
(self.k / 2, self.k / 2),
(self.k / 2, self.k / 2),
))
.pool_last_dim(self.k, 1, 1)
.max_reduce(4);
self.cv2.forward(
xs.concat_along(xs2, 1)
.concat_along(xs3, 1)
.concat_along(xs4, 1),
)
}
}
struct DFL {
conv: Conv2D,
num_classes: usize,
}
impl SerializeModule for DFL {
fn serialize(&self, s: &mut Serializer) {
s.module("conv", &self.conv);
}
}
impl DFL {
pub fn new(num_classes: usize, cx: &mut Graph) -> Self {
Self {
conv: Conv2D::new(num_classes, 1, (1, 1), (1, 1), (1, 1), false, cx),
num_classes,
}
}
}
impl Module<GraphTensor> for DFL {
type Output = GraphTensor;
fn forward(&self, xs: GraphTensor) -> Self::Output {
let (b_sz, _channels, anchors) = xs.dims3();
let xs = xs
.reshape((b_sz, 4, self.num_classes, anchors))
.permute((0, 1, 3, 2))
.softmax(1);
self.conv.forward(xs).reshape((b_sz, 4, anchors))
}
}
struct DarkNet {
b1_0: ConvBlock,
b1_1: ConvBlock,
b2_0: C2F,
b2_1: ConvBlock,
b2_2: C2F,
b3_0: ConvBlock,
b3_1: C2F,
b4_0: ConvBlock,
b4_1: C2F,
b1_0: Conv2D,
b1_1: Conv2D,
b2_0: C2f,
b2_1: Conv2D,
b2_2: C2f,
b3_0: Conv2D,
b3_1: C2f,
b4_0: Conv2D,
b4_1: C2f,
b5: SPPF,
}
impl SerializeModule for DarkNet {
fn serialize(&self, s: &mut Serializer) {
s.module("b1_0", &self.b1_0);
s.module("b1_1", &self.b1_1);
s.module("b2_0", &self.b2_0);
s.module("b2_1", &self.b2_1);
s.module("b3_0", &self.b3_0);
s.module("b3_1", &self.b3_1);
s.module("b4_0", &self.b4_0);
s.module("b4_1", &self.b4_1);
s.module("b5", &self.b5);
}
}
impl DarkNet {
pub fn new(w: f64, r: f64, d: f64, cx: &mut Graph) -> Self {
Self {
b1_0: Conv2D::new(3, (64. * w) as usize, (3, 3), (2, 2), (1, 1), true, cx),
b1_1: Conv2D::new(
(64. * w) as usize,
(128. * w) as usize,
(3, 3),
(2, 2),
(1, 1),
true,
cx,
),
b2_0: C2f::new(
(128. * w) as usize,
(128. * w) as usize,
(3. * d).round() as usize,
true,
cx,
),
b2_1: Conv2D::new(
(128. * w) as usize,
(256. * w) as usize,
(3, 3),
(2, 2),
(1, 1),
true,
cx,
),
b2_2: C2f::new(
(256. * w) as usize,
(256. * w) as usize,
(6. * d).round() as usize,
true,
cx,
),
b3_0: Conv2D::new(
(256. * w) as usize,
(512. * w) as usize,
(3, 3),
(2, 2),
(1, 1),
true,
cx,
),
b3_1: C2f::new(
(512. * w) as usize,
(512. * w) as usize,
(6. * d).round() as usize,
true,
cx,
),
b4_0: Conv2D::new(
(512. * w) as usize,
(512. * w * r) as usize,
(3, 3),
(2, 2),
(1, 1),
true,
cx,
),
b4_1: C2f::new(
(512. * w * r) as usize,
(512. * w * r) as usize,
(3. * d).round() as usize,
true,
cx,
),
b5: SPPF::new((512. * w * r) as usize, (512. * w * r) as usize, 5, cx),
}
}
}
impl Module<GraphTensor> for DarkNet {
type Output = (GraphTensor, GraphTensor, GraphTensor);
fn forward(&self, xs: GraphTensor) -> Self::Output {
let x1 = self.b1_1.forward(self.b1_0.forward(xs));
let x2 = self.b2_2.forward(self.b2_1.forward(self.b2_0.forward(x1)));
let x3 = self.b3_1.forward(self.b3_0.forward(x2));
let x4 = self.b4_1.forward(self.b4_0.forward(x3));
let x5 = self.b5.forward(x4);
(x2, x3, x5)
}
}
struct YoloNeck {
n1: C2F,
n2: C2F,
n3: ConvBlock,
n4: C2F,
n5: ConvBlock,
n6: C2F,
up: Upsample,
n1: C2f,
n2: C2f,
n3: Conv2D,
n4: C2f,
n5: Conv2D,
n6: C2f,
}
impl SerializeModule for YoloNeck {
fn serialize(&self, s: &mut Serializer) {
s.module("n1", &self.n1);
s.module("n2", &self.n2);
s.module("n3", &self.n3);
s.module("n4", &self.n4);
s.module("n5", &self.n5);
s.module("n6", &self.n6);
}
}
impl YoloNeck {
pub fn new(w: f64, r: f64, d: f64, cx: &mut Graph) -> Self {
let n = (3. * d).round() as usize;
Self {
up: Upsample::new(2),
n1: C2f::new(
(512. * w * (1. + r)) as usize,
(512. * w) as usize,
n,
false,
cx,
),
n2: C2f::new((768. * w) as usize, (256. * w) as usize, n, false, cx),
n3: Conv2D::new(
(256. * w) as usize,
(256. * w) as usize,
(3, 3),
(2, 2),
(1, 1),
true,
cx,
),
n4: C2f::new((768. * w) as usize, (512. * w) as usize, n, false, cx),
n5: Conv2D::new(
(512. * w) as usize,
(512. * w) as usize,
(3, 3),
(2, 2),
(1, 1),
true,
cx,
),
n6: C2f::new(
(512. * w * (1. + r)) as usize,
(512. * w * r) as usize,
n,
false,
cx,
),
}
}
}
impl Module<(GraphTensor, GraphTensor, GraphTensor)> for YoloNeck {
type Output = (GraphTensor, GraphTensor, GraphTensor);
fn forward(&self, (p3, p4, p5): (GraphTensor, GraphTensor, GraphTensor)) -> Self::Output {
let x = self.n1.forward(self.up.forward(p5).concat_along(p4, 1));
let head_1 = self.n2.forward(self.up.forward(x).concat_along(p3, 1));
let head_2 = self.n4.forward(self.n3.forward(head_1).concat_along(x, 1));
let head_3 = self.n6.forward(self.n5.forward(head_2).concat_along(p5, 1));
(head_1, head_2, head_3)
}
}
struct DetectionHead {
dfl: DFL,
cv2: [(ConvBlock, ConvBlock, Conv2D); 3],
cv3: [(ConvBlock, ConvBlock, Conv2D); 3],
cv2: [(Conv2D, Conv2D, Conv2D); 3],
cv3: [(Conv2D, Conv2D, Conv2D); 3],
ch: usize,
no: usize,
}
pub struct Yolo {
net: DarkNet,
fpn: YoloNeck,
head: DetectionHead,
}
impl<
const CHANNELS: usize,
const CLASSIFICATION: usize,
Batch: Dimension,
Height: Dimension,
Width: Dimension,
> Module<GraphTensor<(Batch, Const<CHANNELS>, Width, Height)>> for Yolo
{
type Output = GraphTensor<(Batch,)>;
fn forward(&self, input: GraphTensor<(Batch, Const<CHANNELS>, Width, Height)>) -> Self::Output {
impl SerializeModule for DetectionHead {
fn serialize(&self, s: &mut Serializer) {
s.module("dfl", &self.dfl);
for (i, m) in self.cv2.iter().enumerate() {
s.module(&format!("cv2/{i}"), m);
}
for (i, m) in self.cv3.iter().enumerate() {
s.module(&format!("cv3/{i}"), m);
}
}
}
impl DetectionHead {
pub fn new(nc: usize, filters: (usize, usize, usize), cx: &mut Graph) -> Self {
let ch = 16;
let c1 = usize::max(filters.0, nc);
let c2 = usize::max(filters.0 / 4, ch * 4);
Self {
dfl: DFL::new(ch, cx),
cv2: [
Self::new_cv2(c2, ch, filters.0, cx),
Self::new_cv2(c2, ch, filters.1, cx),
Self::new_cv2(c2, ch, filters.2, cx),
],
cv3: [
Self::new_cv3(c1, nc, filters.0, cx),
Self::new_cv3(c1, nc, filters.1, cx),
Self::new_cv3(c1, nc, filters.2, cx),
],
ch,
no: nc + ch * 4,
}
}
fn new_cv3(c1: usize, nc: usize, filter: usize, cx: &mut Graph) -> (Conv2D, Conv2D, Conv2D) {
(
Conv2D::new(filter, c1, (3, 3), (1, 1), (1, 1), true, cx),
Conv2D::new(c1, c1, (3, 3), (1, 1), (1, 1), true, cx),
Conv2D::new(c1, nc, (1, 1), (1, 1), (1, 1), true, cx),
)
}
fn new_cv2(c2: usize, ch: usize, filter: usize, cx: &mut Graph) -> (Conv2D, Conv2D, Conv2D) {
(
Conv2D::new(filter, c2, (3, 3), (1, 1), (1, 1), true, cx),
Conv2D::new(c2, c2, (3, 3), (1, 1), (1, 1), true, cx),
Conv2D::new(c2, 4 * ch, (1, 1), (1, 1), (1, 1), true, cx),
)
}
}
impl Module<(GraphTensor, GraphTensor, GraphTensor)> for DetectionHead {
type Output = (GraphTensor, GraphTensor, GraphTensor);
fn forward(&self, (xs0, xs1, xs2): (GraphTensor, GraphTensor, GraphTensor)) -> Self::Output {
let forward_cv = |xs, i: usize| {
let xs_2 = self.cv2[i].0.forward(xs);
let xs_2 = self.cv2[i].1.forward(xs_2);
let xs_2 = self.cv2[i].2.forward(xs_2);
let xs_3 = self.cv3[i].0.forward(xs);
let xs_3 = self.cv3[i].1.forward(xs_3);
let xs_3 = self.cv3[i].2.forward(xs_3);
xs_2.concat_along(xs_3, 1)
};
let xs0 = forward_cv(xs0, 0);
let xs1 = forward_cv(xs1, 1);
let xs2 = forward_cv(xs2, 2);
let (anchors, strides) = make_anchors(xs0, xs1, xs2, (8, 16, 32), 0.5);
let anchors = anchors.permute((1, 0)).expand(0, 1);
let strides = strides.permute((1, 0));
let reshape = |xs: GraphTensor| {
let d = xs.dims()[0];
let el = xs.shape.n_elements();
xs.reshape((d, self.no, el / (d * self.no)))
};
let ys0 = reshape(xs0);
let ys1 = reshape(xs1);
let ys2 = reshape(xs2);
let x_cat = ys0.concat_along(ys1, 2).concat_along(ys2, 2);
let box_ = x_cat.slice((.., ..self.ch * 4));
let cls = x_cat.slice((.., self.ch * 4..));
let dbox = dist2bbox(self.dfl.forward(box_), anchors);
let dbox = dbox * strides.expand_to(dbox.shape);
let pred = dbox.concat_along(cls.sigmoid(), 1);
(pred, anchors, strides)
}
}
pub struct Yolo {
net: DarkNet, // Backbone
fpn: YoloNeck, // Neck
head: DetectionHead, // Head
}
impl SerializeModule for Yolo {
fn serialize(&self, s: &mut Serializer) {
s.module("net", &self.net);
s.module("fpn", &self.fpn);
s.module("head", &self.head);
}
}
impl Yolo {
pub fn new(w: f64, r: f64, d: f64, num_classes: usize, cx: &mut Graph) -> Self {
let f1 = (256. * w) as usize;
let f2 = (512. * r) as usize;
let f3 = (512. * w * r) as usize;
Self {
net: DarkNet::new(w, r, d, cx),
fpn: YoloNeck::new(w, r, d, cx),
head: DetectionHead::new(num_classes, (f1, f2, f3), cx),
}
}
}
fn make_anchors(
xs0: GraphTensor,
xs1: GraphTensor,
xs2: GraphTensor,
(s0, s1, s2): (usize, usize, usize),
grid_cell_offset: f64,
) -> (GraphTensor, GraphTensor) {
let cx = xs0.graph();
let mut anchor_points = vec![];
let mut stride_tensor = vec![];
for (xs, stride) in [(xs0, s0), (xs1, s1), (xs2, s2)] {
// xs is only used to extract the h and w dimensions.
let (_, _, h, w) = xs.dims4();
let sx = cx.arange(w) + grid_cell_offset as f32;
let sy = cx.arange(h) + grid_cell_offset as f32;
let sx = sx.reshape((1, w)).expand(0, h).reshape(h * w);
let sy = sy.reshape((h, 1)).expand(1, w).reshape(h * w);
anchor_points.push(sx.expand(0, 1).concat_along(sy.expand(0, 1), 0));
stride_tensor.push(cx.constant(1.).expand(0, h * w) * stride as f32);
}
let anchor_points = anchor_points
.into_iter()
.reduce(|acc, t| acc.concat_along(t, 0))
.unwrap();
let stride_tensor = stride_tensor
.into_iter()
.reduce(|acc, t| acc.concat_along(t, 0))
.unwrap()
.expand(1, 1);
(anchor_points, stride_tensor)
}
fn dist2bbox(distance: GraphTensor, anchor_points: GraphTensor) -> GraphTensor {
let chunks = chunk(distance, 2, 1);
let lt = chunks[0];
let rb = chunks[1];
let x1y1 = anchor_points - lt;
let x2y2 = anchor_points + rb;
let c_xy = (x1y1 + x2y2) * 0.5;
let wh = x2y2 - x1y1;
c_xy.concat_along(wh, 1)
}
impl Module<GraphTensor> for Yolo {
type Output = GraphTensor;
fn forward(&self, xs: GraphTensor) -> Self::Output {
let (xs1, xs2, xs3) = self.net.forward(xs);
let (xs1, xs2, xs3) = self.fpn.forward((xs1, xs2, xs3));
let (pred, _, _) = self.head.forward((xs1, xs2, xs3));
pred
}
}

Binary file not shown.

View File

@@ -410,7 +410,7 @@ impl Graph {
&& edge
.weight()
.as_data()
.map(|d| !d.2.shape().is_empty())
.map(|d| !d.2.is_empty())
.unwrap_or_default()
{
new_graph
@@ -418,7 +418,7 @@ impl Graph {
.unwrap()
.push_str(&format!(
" | {:?}",
edge.weight().as_data().unwrap().2.shape()
edge.weight().as_data().unwrap().2.dims()
));
}
}
@@ -686,7 +686,7 @@ fn backtrack_match(
pattern_root: NodeIndex,
pattern_graph: &StableGraph<(Uuid, SelectOp), Option<u8>>,
main_root: NodeIndex,
main_graph: &mut MainGraph,
main_graph: &mut StorageGraph,
) -> Option<FxHashMap<NodeIndex, NodeIndex>> {
fn get_parents<N, E>(
graph: &petgraph::stable_graph::StableGraph<N, E>,
@@ -736,7 +736,7 @@ fn test_node(
shape,
fake,
}: &SelectOp,
graph: &mut MainGraph,
graph: &mut StorageGraph,
graph_node: NodeIndex,
) -> bool {
let input_shapes = graph
@@ -763,7 +763,7 @@ fn test_node(
if a_sh.len() != b_sh.dims.len() {
return false;
}
for (a, b) in a_sh.iter().zip(b_sh.shape().iter()) {
for (a, b) in a_sh.iter().zip(b_sh.dims().into_iter()) {
match a.to_usize() {
Some(n) => {
if b.to_usize().map(|i| i != n).unwrap_or(true) {
@@ -776,11 +776,11 @@ fn test_node(
.pop()
.expect("Selector dimension must be either a symbol or number");
if let Some(expected) = shape_map.get(&c) {
if b != expected {
if b != *expected {
return false;
}
} else {
shape_map.insert(c, b.clone());
shape_map.insert(c, b);
}
}
}

View File

@@ -11,7 +11,7 @@ use itertools::Itertools;
use petgraph::{stable_graph::StableGraph, visit::EdgeRef, Direction};
use rustc_hash::{FxHashMap, FxHashSet};
pub type MainGraph = StableGraph<Box<dyn Operator>, Dependency>;
pub type StorageGraph = StableGraph<Box<dyn Operator>, Dependency>;
/// A Luminal compute graph.
///
@@ -24,7 +24,7 @@ pub struct Graph {
/// A map of dynamic dimensions to concrete dimension sizes
pub dyn_map: FxHashMap<char, usize>,
/// Edge weights: (Input index, Output index, Input shape)
pub graph: MainGraph,
pub graph: StorageGraph,
/// Tensors marked in this set will not get deleted when the graph is ran
pub no_delete: FxHashSet<NodeIndex>,
/// Tensors marked in this set need to be retrieved later (mostly for optimizers to insert copy back calls, the graph itself doesn't treat these differently)
@@ -37,7 +37,7 @@ pub struct Graph {
}
/// A dependency between two nodes
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, Copy)]
#[allow(clippy::large_enum_variant)]
pub enum Dependency {
/// A data dependency (transferring a tensor from one node to the next)
@@ -125,7 +125,7 @@ impl Graph {
Box::new(move |_| panic!("You must set a value for this tensor! ({name})")),
))),
graph_ref: self,
shape: ShapeTracker::new(&shape.to_shape()),
shape: ShapeTracker::new(shape),
}
}
@@ -305,7 +305,7 @@ impl Graph {
.map(|(_, s)| {
format!(
"{:?}",
s.shape()
s.dims()
.into_iter()
.map(|i| i.to_usize().unwrap())
.collect::<Vec<_>>()
@@ -366,7 +366,7 @@ impl Graph {
}
impl Deref for Graph {
type Target = MainGraph;
type Target = StorageGraph;
fn deref(&self) -> &Self::Target {
&self.graph
}
@@ -378,6 +378,12 @@ impl DerefMut for Graph {
}
}
impl Drop for Graph {
fn drop(&mut self) {
expression_cleanup();
}
}
/// Get source tensor array for a node
fn get_source_tensors<'a>(
no_delete: &'a FxHashSet<NodeIndex>,

View File

@@ -65,7 +65,7 @@ impl GraphTensor {
/// ```
pub fn set_dyn(self, data: impl Data + Clone, shape: impl ToShape) -> Self {
// Report dyn dim values to graph dyn map
for (d, s) in self.shape.shape().iter().zip(shape.to_shape().into_iter()) {
for (d, s) in self.shape.dims().iter().zip(shape.to_shape().into_iter()) {
if let Some(c) = d.to_symbols().pop() {
self.graph().dyn_map.insert(c, s.to_usize().unwrap());
}
@@ -103,39 +103,58 @@ impl GraphTensor {
data
}
pub fn shape(&self) -> Vec<BigExpression> {
self.shape.shape()
pub fn dims(&self) -> Vec<Expression> {
self.shape.dims()
}
pub fn dims1(&self) -> Expression {
self.shape.shape()[0].small()
assert_eq!(
self.shape.len(),
1,
"Shape has {} dimensions, tried to get 1",
self.shape.len()
);
self.dims()[0]
}
pub fn dims2(&self) -> (Expression, Expression) {
(self.shape.shape()[0].small(), self.shape.shape()[1].small())
assert_eq!(
self.shape.len(),
2,
"Shape has {} dimensions, tried to get 2",
self.shape.len()
);
let dims = self.dims();
(dims[0], dims[1])
}
pub fn dims3(&self) -> (Expression, Expression, Expression) {
(
self.shape.shape()[0].small(),
self.shape.shape()[1].small(),
self.shape.shape()[2].small(),
)
assert_eq!(
self.shape.len(),
3,
"Shape has {} dimensions, tried to get 3",
self.shape.len()
);
let dims = self.dims();
(dims[0], dims[1], dims[2])
}
pub fn dims4(&self) -> (Expression, Expression, Expression, Expression) {
(
self.shape.shape()[0].small(),
self.shape.shape()[1].small(),
self.shape.shape()[2].small(),
self.shape.shape()[3].small(),
)
assert_eq!(
self.shape.len(),
4,
"Shape has {} dimensions, tried to get 4",
self.shape.len()
);
let dims = self.dims();
(dims[0], dims[1], dims[2], dims[3])
}
pub fn dims5(&self) -> (Expression, Expression, Expression, Expression, Expression) {
(
self.shape.shape()[0].small(),
self.shape.shape()[1].small(),
self.shape.shape()[2].small(),
self.shape.shape()[3].small(),
self.shape.shape()[4].small(),
)
assert_eq!(
self.shape.len(),
5,
"Shape has {} dimensions, tried to get 5",
self.shape.len()
);
let dims = self.dims();
(dims[0], dims[1], dims[2], dims[3], dims[4])
}
/// Set the value of the tensor matching the constant shape

View File

@@ -140,13 +140,10 @@ impl Add<f32> for GraphTensor {
}
}
impl<St: ExpressionStorage> Add<GenericExpression<St>> for GraphTensor
where
GenericExpression<Vec<Term>>: From<GenericExpression<St>>,
{
impl<S: Into<Expression>> Add<S> for GraphTensor {
type Output = GraphTensor;
fn add(self, rhs: GenericExpression<St>) -> Self::Output {
fn add(self, rhs: S) -> Self::Output {
self + self.graph().constant_expr(rhs).expand_to(self.shape)
}
}
@@ -159,13 +156,10 @@ impl Sub<f32> for GraphTensor {
}
}
impl<St: ExpressionStorage> Sub<GenericExpression<St>> for GraphTensor
where
GenericExpression<Vec<Term>>: From<GenericExpression<St>>,
{
impl<S: Into<Expression>> Sub<S> for GraphTensor {
type Output = GraphTensor;
fn sub(self, rhs: GenericExpression<St>) -> Self::Output {
fn sub(self, rhs: S) -> Self::Output {
self - self.graph().constant_expr(rhs).expand_to(self.shape)
}
}
@@ -178,13 +172,10 @@ impl Mul<f32> for GraphTensor {
}
}
impl<St: ExpressionStorage> Mul<GenericExpression<St>> for GraphTensor
where
GenericExpression<Vec<Term>>: From<GenericExpression<St>>,
{
impl<S: Into<Expression>> Mul<S> for GraphTensor {
type Output = GraphTensor;
fn mul(self, rhs: GenericExpression<St>) -> Self::Output {
fn mul(self, rhs: S) -> Self::Output {
self * self.graph().constant_expr(rhs).expand_to(self.shape)
}
}
@@ -198,13 +189,10 @@ impl Div<f32> for GraphTensor {
}
}
impl<St: ExpressionStorage> Div<GenericExpression<St>> for GraphTensor
where
GenericExpression<Vec<Term>>: From<GenericExpression<St>>,
{
impl<S: Into<Expression>> Div<S> for GraphTensor {
type Output = GraphTensor;
fn div(self, rhs: GenericExpression<St>) -> Self::Output {
fn div(self, rhs: S) -> Self::Output {
self / self.graph().constant_expr(rhs).expand_to(self.shape)
}
}
@@ -217,13 +205,10 @@ impl Rem<f32> for GraphTensor {
}
}
impl<St: ExpressionStorage> Rem<GenericExpression<St>> for GraphTensor
where
GenericExpression<Vec<Term>>: From<GenericExpression<St>>,
{
impl<S: Into<Expression>> Rem<S> for GraphTensor {
type Output = GraphTensor;
fn rem(self, rhs: GenericExpression<St>) -> Self::Output {
fn rem(self, rhs: S) -> Self::Output {
self % self.graph().constant_expr(rhs).expand_to(self.shape)
}
}

View File

@@ -15,11 +15,11 @@ impl GraphTensor {
// Sum Reduce
let mut ret = mul.sum_reduce(2);
if vec {
ret = ret.reshape(ret.shape().last().unwrap());
ret = ret.reshape(ret.dims().last().unwrap());
}
ret
} else if self.shape.len() == 3 {
let d = rhs.shape().last().unwrap().small();
let d = *rhs.dims().last().unwrap();
let (a, b, _) = self.dims3();
if rhs.shape.len() == 2 {
// ABCxCD -> ABD
@@ -43,8 +43,8 @@ impl GraphTensor {
} else {
panic!(
"Can't matmul lhs {:?} and rhs {:?}",
self.shape(),
rhs.shape()
self.dims(),
rhs.dims()
)
}
} else if self.shape.len() == 4 {
@@ -73,8 +73,8 @@ impl GraphTensor {
} else {
panic!(
"Can't matmul lhs {:?} and rhs {:?}",
self.shape(),
rhs.shape()
self.dims(),
rhs.dims()
)
}
} else if self.shape.len() == 5 && rhs.shape.len() == 5 {
@@ -93,8 +93,8 @@ impl GraphTensor {
} else {
panic!(
"Can't matmul lhs {:?} and rhs {:?}",
self.shape(),
rhs.shape()
self.dims(),
rhs.dims()
)
}
}

View File

@@ -28,7 +28,7 @@ impl GraphTensor {
pub fn reshape(mut self, new_shape: impl ToShape) -> GraphTensor {
// Insert contiguous call
self = self.contiguous();
self.shape = ShapeTracker::new(&new_shape.to_shape());
self.shape = ShapeTracker::new(new_shape);
self
}
@@ -60,6 +60,12 @@ impl GraphTensor {
self
}
pub fn slice_along(self, slice: impl SliceRange, axis: usize) -> GraphTensor {
let mut s = vec![(Expression::from(0), Expression::from(i32::MAX)); axis + 1];
s[axis] = slice.bounds();
self.slice(s)
}
/// Cut out 'size' elements every 'spacing' elements in the last dimension. 'size' must be smaller than the last dimension
pub fn excise(mut self, spacing: usize, size: usize) -> GraphTensor {
let n_dims = self.shape.len();
@@ -90,38 +96,36 @@ impl GraphTensor {
/// Pool elements along the last dimension, pools are exposed as a new dimension
pub fn pool_last_dim(
mut self,
kernel: impl Into<BigExpression>,
stride: impl Into<BigExpression>,
kernel: impl Into<Expression>,
stride: impl Into<Expression>,
dilation: usize,
) -> GraphTensor {
let (kernel, stride) = (kernel.into(), stride.into());
let n_dims = self.shape.len();
let full_kernel = kernel.clone() + (kernel.clone() - 1) * (dilation - 1);
let dim_size = self.shape.shape().pop().unwrap().simplify().small();
let number_of_windows = (((dim_size.big() - full_kernel.clone()) / stride.clone()) + 1)
.simplify()
.small();
let full_kernel = kernel + (kernel - 1) * (dilation - 1);
let dim_size = self.dims().pop().unwrap().simplify();
let number_of_windows = (((dim_size - full_kernel) / stride) + 1).simplify();
// Expand new dimension
self.shape.expand(n_dims - 1, number_of_windows);
self = self.contiguous();
if n_dims > 1 {
// View as single dimension of matrix with wider width
let mat_size = (dim_size.big() + stride.clone()) * number_of_windows;
let actual_size = (dim_size.big() * number_of_windows).simplify().small();
let mat_size = (dim_size + stride) * number_of_windows;
let actual_size = (dim_size * number_of_windows).simplify();
// Reshape into single dimension to pad
self.shape.remove_dim(n_dims);
self.shape.dims[self.shape.indexes[n_dims - 1]] = actual_size;
self.shape.padding[self.shape.indexes[n_dims - 1]].1 =
(mat_size - actual_size).simplify().small();
(mat_size - actual_size).simplify();
self = self.contiguous();
// Reshape back (mats should be full now)
self.shape.add_dim(n_dims, dim_size + stride.clone());
self.shape.add_dim(n_dims, dim_size + stride);
self.shape.dims[self.shape.indexes[n_dims - 1]] = number_of_windows;
} else {
self.shape.dims[self.shape.indexes[n_dims]] = dim_size + stride;
}
// Slice down to kernel size
self.shape.mask[self.shape.indexes[n_dims]].1 = full_kernel.simplify().small();
self.shape.mask[self.shape.indexes[n_dims]].1 = full_kernel.simplify();
self.shape.mask[self.shape.indexes[n_dims - 1]].1 = number_of_windows;
self = self.contiguous();
@@ -147,14 +151,21 @@ impl GraphTensor {
self
}
pub fn pad_along(
self,
left: impl Into<Expression>,
right: impl Into<Expression>,
axis: usize,
) -> GraphTensor {
let mut p = vec![(Expression::from(0), Expression::from(0)); axis + 1];
p[axis] = (left.into(), right.into());
self.pad(p)
}
pub fn concat_along(self, rhs: GraphTensor, axis: usize) -> GraphTensor {
// Create padding
let mut a_padding = vec![(Expression::default(), Expression::default()); self.shape.len()];
a_padding[axis].1 = rhs.shape.shape()[axis].small();
let mut b_padding = vec![(Expression::default(), Expression::default()); rhs.shape.len()];
b_padding[axis].0 = self.shape.shape()[axis].small();
// Pad and add
self.pad(a_padding) + rhs.pad(b_padding)
self.pad_along(0, rhs.shape.dims()[axis], axis)
+ rhs.pad_along(self.shape.dims()[axis], 0, axis)
}
}
@@ -273,7 +284,7 @@ mod tests {
let a = cx
.tensor((3, 2))
.set(vec![1.4325, 2.492428, 3.127365, 33.2834, 4.18734, 23.854]);
let b = a.slice((.., ..Expression::from(1))).retrieve();
let b = a.slice((.., ..1)).retrieve();
cx.execute();
let d_dev = Cpu::default();
@@ -431,8 +442,8 @@ mod tests {
let mut cx = Graph::new();
let a = cx.tensor((3, 2));
a.set(vec![1.4325, 2.492428, 3.127365, 33.2834, 4.18734, 23.854]);
let x1 = a.slice((.., ..Expression::from(1))).contiguous();
let x2 = a.slice((.., Expression::from(1)..)).contiguous();
let x1 = a.slice((.., ..1)).contiguous();
let x2 = a.slice((.., 1..)).contiguous();
let c = (-x2).concat_along(x1, 1);
c.retrieve();
cx.execute();

View File

@@ -48,24 +48,14 @@ impl From<f64> for ConstantValue {
ConstantValue::Float(value as f32)
}
}
impl From<BigExpression> for ConstantValue {
fn from(value: BigExpression) -> Self {
ConstantValue::Expression(value)
}
}
impl From<Expression> for ConstantValue {
fn from(value: Expression) -> Self {
ConstantValue::Expression((&value).into())
}
}
impl From<&BigExpression> for ConstantValue {
fn from(value: &BigExpression) -> Self {
ConstantValue::Expression(value.clone())
ConstantValue::Expression(value)
}
}
impl From<&Expression> for ConstantValue {
fn from(value: &Expression) -> Self {
ConstantValue::Expression(value.into())
ConstantValue::Expression(*value)
}
}
@@ -74,26 +64,26 @@ impl Graph {
pub fn constant(&mut self, i: impl Into<ConstantValue>) -> GraphTensor {
GraphTensor::from_id(
self.add_op(Constant(i.into(), &self.dyn_map)).finish(),
ShapeTracker::default(),
ShapeTracker::new(()),
self,
)
}
/// A scalar constant evaluated from an expression at runtime
pub fn constant_expr<E: Into<BigExpression>>(&mut self, expr: E) -> GraphTensor {
pub fn constant_expr<E: Into<Expression>>(&mut self, expr: E) -> GraphTensor {
GraphTensor::from_id(
self.add_op(Constant(
ConstantValue::Expression(expr.into().simplify()),
&self.dyn_map,
))
.finish(),
ShapeTracker::default(),
ShapeTracker::new(()),
self,
)
}
/// ARange from 0 to N
pub fn arange(&mut self, to: impl Into<Expression> + Copy) -> GraphTensor {
pub fn arange(&mut self, to: impl Into<Expression>) -> GraphTensor {
let to = to.into();
if to.to_usize().map(|i| i == 1).unwrap_or_default() {
// Single number ARange is just 0
@@ -106,7 +96,8 @@ impl Graph {
/// Lower left-hand triangle of 1s. Currently required to be square
///
/// Same API as https://pytorch.org/docs/stable/generated/torch.tril
pub fn tril(&mut self, size: impl Into<Expression> + Copy, diagonal: i32) -> GraphTensor {
pub fn tril(&mut self, size: impl Into<Expression>, diagonal: i32) -> GraphTensor {
let size = size.into();
let horizontal = self.arange(size).expand(0, size);
let vertical = self.arange(size).expand(1, size);
@@ -116,7 +107,8 @@ impl Graph {
/// Upper right-hand triangle of 1s
///
/// Same API as https://pytorch.org/docs/stable/generated/torch.triu
pub fn triu(&mut self, size: impl Into<Expression> + Copy, diagonal: i32) -> GraphTensor {
pub fn triu(&mut self, size: impl Into<Expression>, diagonal: i32) -> GraphTensor {
let size = size.into();
let horizontal = self.arange(size).expand(0, size);
let vertical = self.arange(size).expand(1, size);

View File

@@ -52,23 +52,13 @@ impl GraphTensor {
let mul_tensor = self
.graph()
.add_op(op::Recip)
.input(div_tensor, 0, ShapeTracker::default())
.input(div_tensor, 0, ShapeTracker::new(()))
.finish();
node_id = self
.graph()
.add_op(op::Mul)
.input(node_id, 0, shape)
.input(
mul_tensor,
0,
ShapeTracker::fake(
&shape
.shape()
.iter()
.map(Expression::from)
.collect::<Vec<_>>(),
),
)
.input(mul_tensor, 0, ShapeTracker::fake(shape))
.finish();
}
GraphTensor::from_id(node_id, shape, self.graph_ref)

View File

@@ -128,12 +128,7 @@ impl GraphTensor {
let r = self
.graph()
.constant(1.)
.expand_to(ShapeTracker::new(&[self
.shape
.shape()
.last()
.unwrap()
.small()]))
.expand_to(ShapeTracker::new(self.shape.dims().last().unwrap()))
.cumsum_last_dim()
- 1.;
// Multiply one-hot by expanded index arange

View File

@@ -194,7 +194,7 @@ macro_rules! tuple_impls {
impl<$($name: SerializeModule,)+> SerializeModule for ($($name,)+) {
fn serialize(&self, s: &mut Serializer) {
$(s.module(&format!("layer{}", $idx), &self.$idx);)+
$(s.module(&format!("{}", $idx), &self.$idx);)+
}
}
};

View File

@@ -132,9 +132,9 @@ impl Debug for Function {
}
/// A constant value placed on the graph at runtime. Can either be an expression evaluated at runtime, or a constant float
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq)]
pub enum ConstantValue {
Expression(BigExpression),
Expression(Expression),
Float(f32),
}
@@ -383,7 +383,7 @@ fn get_vec<'a>(tensor: &'a InputTensor<'a>) -> &'a Vec<f32> {
fn get_index(
data: &[f32],
(ind, val): &(BigExpression, BigExpression),
(ind, val): &(Expression, Expression),
stack: &mut Vec<i64>,
index: usize,
) -> f32 {

View File

@@ -17,7 +17,7 @@ use std::ops::{Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeTo, RangeTo
fn get_start_bound<D: Into<Expression> + Copy>(bound: Bound<D>) -> Expression {
match bound {
Bound::Included(x) => x.into(),
Bound::Excluded(x) => x.into() + Expression::from(1),
Bound::Excluded(x) => x.into() + 1,
Bound::Unbounded => 0.into(),
}
}
@@ -25,7 +25,7 @@ fn get_start_bound<D: Into<Expression> + Copy>(bound: Bound<D>) -> Expression {
fn get_end_bound<D: Into<Expression> + Copy>(bound: Bound<D>) -> Expression {
match bound {
Bound::Excluded(x) => x.into(),
Bound::Included(x) => x.into() + Expression::from(1),
Bound::Included(x) => x.into() + 1,
Bound::Unbounded => Expression::from(i32::MAX),
}
}
@@ -110,29 +110,29 @@ impl<R: SliceRange> SliceRange for (R,) {
}
pub trait ToSlice {
fn to_range_vec(&self) -> Vec<(Expression, Expression)>;
fn to_range_vec(self) -> Vec<(Expression, Expression)>;
}
impl<R: SliceRange> ToSlice for R {
fn to_range_vec(&self) -> Vec<(Expression, Expression)> {
fn to_range_vec(self) -> Vec<(Expression, Expression)> {
vec![self.bounds()]
}
}
impl<R1: SliceRange, R2: SliceRange> ToSlice for (R1, R2) {
fn to_range_vec(&self) -> Vec<(Expression, Expression)> {
fn to_range_vec(self) -> Vec<(Expression, Expression)> {
vec![self.0.bounds(), self.1.bounds()]
}
}
impl<R1: SliceRange, R2: SliceRange, R3: SliceRange> ToSlice for (R1, R2, R3) {
fn to_range_vec(&self) -> Vec<(Expression, Expression)> {
fn to_range_vec(self) -> Vec<(Expression, Expression)> {
vec![self.0.bounds(), self.1.bounds(), self.2.bounds()]
}
}
impl<R1: SliceRange, R2: SliceRange, R3: SliceRange, R4: SliceRange> ToSlice for (R1, R2, R3, R4) {
fn to_range_vec(&self) -> Vec<(Expression, Expression)> {
fn to_range_vec(self) -> Vec<(Expression, Expression)> {
vec![
self.0.bounds(),
self.1.bounds(),
@@ -145,7 +145,7 @@ impl<R1: SliceRange, R2: SliceRange, R3: SliceRange, R4: SliceRange> ToSlice for
impl<R1: SliceRange, R2: SliceRange, R3: SliceRange, R4: SliceRange, R5: SliceRange> ToSlice
for (R1, R2, R3, R4, R5)
{
fn to_range_vec(&self) -> Vec<(Expression, Expression)> {
fn to_range_vec(self) -> Vec<(Expression, Expression)> {
vec![
self.0.bounds(),
self.1.bounds(),
@@ -156,6 +156,32 @@ impl<R1: SliceRange, R2: SliceRange, R3: SliceRange, R4: SliceRange, R5: SliceRa
}
}
impl<A: Into<Expression>, B: Into<Expression>> ToSlice for Vec<(A, B)> {
fn to_range_vec(self) -> Vec<(Expression, Expression)> {
self.into_iter().map(|i| (i.0.into(), i.1.into())).collect()
}
}
impl<A: Into<Expression> + Copy, B: Into<Expression> + Copy> ToSlice for &Vec<(A, B)> {
fn to_range_vec(self) -> Vec<(Expression, Expression)> {
self.iter().map(|i| (i.0.into(), i.1.into())).collect()
}
}
impl<A: Into<Expression> + Copy, B: Into<Expression> + Copy> ToSlice for &[(A, B)] {
fn to_range_vec(self) -> Vec<(Expression, Expression)> {
self.iter().map(|i| (i.0.into(), i.1.into())).collect()
}
}
impl<const N: usize, A: Into<Expression> + Copy, B: Into<Expression> + Copy> ToSlice
for &[(A, B); N]
{
fn to_range_vec(self) -> Vec<(Expression, Expression)> {
self.iter().map(|i| (i.0.into(), i.1.into())).collect()
}
}
pub trait ToPad {
fn to_pad_vec(self) -> Vec<(Expression, Expression)>;
}
@@ -464,6 +490,6 @@ impl<A: Into<Expression>> ToShape for A {
impl ToShape for ShapeTracker {
fn to_shape(self) -> Vec<Expression> {
self.shape().to_shape()
self.dims().to_shape()
}
}

View File

@@ -1,20 +1,55 @@
use egg::*;
use generational_box::{AnyStorage, GenerationalBox, Owner, UnsyncStorage};
use rustc_hash::FxHashMap;
use std::{
cell::RefCell,
fmt::Debug,
hash::Hash,
ops::{
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, Div, DivAssign, IndexMut, Mul,
MulAssign, Rem, RemAssign, Sub, SubAssign,
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, Div, DivAssign, Mul, MulAssign,
Rem, RemAssign, Sub, SubAssign,
},
};
use symbolic_expressions::Sexp;
use tinyvec::ArrayVec;
/// A symbolic expression stored on the stack
pub type Expression = GenericExpression<ArrayVec<[Term; 20]>>; // We need to figure out how to reduce this, can't be fixed at 20. ShapeTracker would take up 6 dims * 12 pads * 12 slices * 20 terms * 8 bytes = 138kb
/// A symbolic expression stored on the heap
pub type BigExpression = GenericExpression<Vec<Term>>;
thread_local! {
static EXPRESSION_OWNER: RefCell<Option<Owner<UnsyncStorage>>> = RefCell::new(Some(UnsyncStorage::owner()));
}
/// Clean up symbolic expresion storage
pub fn expression_cleanup() {
EXPRESSION_OWNER.with(|cell| cell.borrow_mut().take());
}
/// Get the thread-local owner of expression storage
fn expression_owner() -> Owner {
EXPRESSION_OWNER.with(|cell| cell.borrow().clone().unwrap())
}
#[derive(Clone, Copy)]
pub struct Expression {
pub terms: GenerationalBox<Vec<Term>>,
}
impl Expression {
fn new(terms: Vec<Term>) -> Self {
Self {
terms: expression_owner().insert(terms),
}
}
}
impl Hash for Expression {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.terms.read().hash(state);
}
}
impl Default for Expression {
fn default() -> Self {
Expression::new(vec![])
}
}
/// A single term of a symbolic expression such as a variable, number or operation.
#[derive(Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
@@ -79,99 +114,27 @@ impl Term {
}
}
/// Trait implemented on the 2 main symbolic expression storage types, Vec<Term> and ArrayVec<Term>
#[allow(clippy::len_without_is_empty)]
pub trait ExpressionStorage:
Clone
+ IndexMut<usize, Output = Term>
+ std::iter::Extend<Term>
+ Default
+ PartialEq
+ Debug
+ Hash
+ Eq
{
fn len(&self) -> usize;
fn push(&mut self, term: Term);
fn pop(&mut self) -> Option<Term>;
fn remove(&mut self, index: usize) -> Term;
fn into_vec(self) -> Vec<Term>;
fn iter_ref(&self) -> impl Iterator<Item = &Term>;
}
// Implement the main storage types
impl ExpressionStorage for Vec<Term> {
fn len(&self) -> usize {
Vec::len(self)
}
fn push(&mut self, term: Term) {
Vec::push(self, term)
}
fn pop(&mut self) -> Option<Term> {
Vec::pop(self)
}
fn remove(&mut self, index: usize) -> Term {
Vec::remove(self, index)
}
fn into_vec(self) -> Vec<Term> {
self
}
fn iter_ref(&self) -> impl Iterator<Item = &Term> {
self.iter()
}
}
impl<const C: usize> ExpressionStorage for ArrayVec<[Term; C]>
impl<T> PartialEq<T> for Expression
where
[Term; C]: tinyvec::Array<Item = Term>,
{
fn len(&self) -> usize {
ArrayVec::len(self)
}
fn push(&mut self, term: Term) {
ArrayVec::push(self, term)
}
fn pop(&mut self) -> Option<Term> {
ArrayVec::pop(self)
}
fn remove(&mut self, index: usize) -> Term {
ArrayVec::remove(self, index)
}
fn into_vec(self) -> Vec<Term> {
self.to_vec()
}
fn iter_ref(&self) -> impl Iterator<Item = &Term> {
self.iter()
}
}
/// A symbolic expression
#[derive(Clone, Copy, Hash, Eq, serde::Serialize, serde::Deserialize)]
pub struct GenericExpression<S: ExpressionStorage> {
pub terms: S, // Terms in postfix notation
}
impl<S: ExpressionStorage, T> PartialEq<T> for GenericExpression<S>
where
for<'a> &'a T: Into<Self>,
for<'a> &'a T: Into<Expression>,
{
fn eq(&self, other: &T) -> bool {
self.terms == other.into().terms
*self.terms.read() == *other.into().terms.read()
}
}
impl<S: ExpressionStorage> Default for GenericExpression<S> {
fn default() -> Self {
let mut s = S::default();
s.push(Term::Num(0));
Self { terms: s }
impl From<&Expression> for Expression {
fn from(value: &Expression) -> Self {
*value
}
}
impl<S: ExpressionStorage + Clone> Debug for GenericExpression<S> {
impl Eq for Expression {}
impl Debug for Expression {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut symbols = vec![];
for term in self.terms.iter_ref() {
for term in self.terms.read().iter() {
let new_symbol = match term {
Term::Num(n) => n.to_string(),
Term::Var(c) => c.to_string(),
@@ -197,16 +160,16 @@ impl<S: ExpressionStorage + Clone> Debug for GenericExpression<S> {
}
}
impl<S: ExpressionStorage + Clone> std::fmt::Display for GenericExpression<S> {
impl std::fmt::Display for Expression {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
write!(f, "{self:?}")
}
}
impl<S: ExpressionStorage> GenericExpression<S> {
impl Expression {
/// Simplify the expression to its minimal terms
pub fn simplify(self) -> Self {
if self.terms.len() == 1 {
if self.terms.read().len() == 1 {
return self;
}
egg_simplify(self)
@@ -215,17 +178,17 @@ impl<S: ExpressionStorage> GenericExpression<S> {
/// Simplify the expression to its minimal terms, using a cache to retrieve / store the simplification
pub fn simplify_cache(self, cache: &mut FxHashMap<Self, Self>) -> Self {
if let Some(s) = cache.get(&self) {
s.clone()
*s
} else {
let simplified = self.clone().simplify();
cache.insert(self, simplified.clone());
let simplified = self.simplify();
cache.insert(self, simplified);
simplified
}
}
pub fn as_num(&self) -> Option<i32> {
if let Term::Num(n) = self.terms[0] {
if self.terms.len() == 1 {
if let Term::Num(n) = self.terms.read()[0] {
if self.terms.read().len() == 1 {
return Some(n);
}
}
@@ -233,22 +196,23 @@ impl<S: ExpressionStorage> GenericExpression<S> {
}
/// Minimum
pub fn min<E: Into<Self>>(self, rhs: E) -> Self {
let mut rhs = rhs.into();
pub fn min(self, rhs: impl Into<Self>) -> Self {
let rhs = rhs.into();
if rhs == self || rhs == i32::MAX {
return self;
}
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
return a.min(b).into();
}
rhs.terms.extend(self.terms.iter_ref().copied());
rhs.terms.push(Term::Min);
rhs
let mut terms = rhs.terms.read().clone();
terms.extend(self.terms.read().iter().copied());
terms.push(Term::Min);
Expression::new(terms)
}
/// Maximum
pub fn max<E: Into<Self>>(self, rhs: E) -> Self {
let mut rhs = rhs.into();
pub fn max<E: Into<Expression>>(self, rhs: E) -> Self {
let rhs = rhs.into();
if rhs == self || rhs == 0 || self == i32::MAX {
return self;
}
@@ -258,14 +222,15 @@ impl<S: ExpressionStorage> GenericExpression<S> {
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
return a.max(b).into();
}
rhs.terms.extend(self.terms.iter_ref().copied());
rhs.terms.push(Term::Max);
rhs
let mut terms = rhs.terms.read().clone();
terms.extend(self.terms.read().iter().copied());
terms.push(Term::Max);
Expression::new(terms)
}
/// Greater than or equals
pub fn gte<E: Into<Self>>(self, rhs: E) -> Self {
let mut rhs = rhs.into();
pub fn gte<E: Into<Expression>>(self, rhs: E) -> Self {
let rhs = rhs.into();
if rhs == self {
return true.into();
}
@@ -275,37 +240,42 @@ impl<S: ExpressionStorage> GenericExpression<S> {
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
return (a >= b).into();
}
rhs.terms.extend(self.terms.iter_ref().copied());
rhs.terms.push(Term::Gte);
rhs
let mut terms = rhs.terms.read().clone();
terms.extend(self.terms.read().iter().copied());
terms.push(Term::Gte);
Expression::new(terms)
}
/// Less than
pub fn lt<E: Into<Self>>(self, rhs: E) -> Self {
let mut rhs = rhs.into();
pub fn lt<E: Into<Expression>>(self, rhs: E) -> Self {
let rhs = rhs.into();
if rhs == self {
return false.into();
}
if let Term::Num(n) = rhs.terms[0] {
if self.terms[self.terms.len() - 1] == Term::Mod && self.terms[0] == Term::Num(n) {
if let Term::Num(n) = rhs.terms.read()[0] {
if self.terms.read()[self.terms.read().len() - 1] == Term::Mod
&& self.terms.read()[0] == Term::Num(n)
{
return true.into();
}
}
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
return (a < b).into();
}
rhs.terms.extend(self.terms.iter_ref().copied());
rhs.terms.push(Term::Lt);
rhs
let mut terms = rhs.terms.read().clone();
terms.extend(self.terms.read().iter().copied());
terms.push(Term::Lt);
Expression::new(terms)
}
/// Substitute an expression for a variable
pub fn substitute<N: ExpressionStorage>(self, var: char, expr: GenericExpression<N>) -> Self {
let mut new_terms = S::default();
for term in self.terms.iter_ref() {
pub fn substitute(self, var: char, expr: impl Into<Expression>) -> Self {
let mut new_terms = vec![];
let t = expr.into().terms.read();
for term in self.terms.read().iter() {
match term {
Term::Var(c) if *c == var => {
for t in expr.terms.iter_ref() {
for t in t.iter() {
new_terms.push(*t);
}
}
@@ -314,11 +284,11 @@ impl<S: ExpressionStorage> GenericExpression<S> {
}
}
}
Self { terms: new_terms }
Expression::new(new_terms)
}
}
impl<S: ExpressionStorage> GenericExpression<S> {
impl Expression {
/// Evaluate the expression with no variables. Returns Some(value) if no variables are required, otherwise returns None.
pub fn to_usize(&self) -> Option<usize> {
self.exec(&FxHashMap::default())
@@ -330,7 +300,7 @@ impl<S: ExpressionStorage> GenericExpression<S> {
}
/// Evaluate the expression with one value for all variables. Uses a provided stack
pub fn exec_single_var_stack(&self, value: usize, stack: &mut Vec<i64>) -> usize {
for term in self.terms.iter_ref() {
for term in self.terms.read().iter() {
match term {
Term::Num(n) => stack.push(*n as i64),
Term::Var(_) => stack.push(value as i64),
@@ -353,7 +323,7 @@ impl<S: ExpressionStorage> GenericExpression<S> {
variables: &FxHashMap<char, usize>,
stack: &mut Vec<i64>,
) -> Option<usize> {
for term in self.terms.iter_ref() {
for term in self.terms.read().iter() {
match term {
Term::Num(n) => stack.push(*n as i64),
Term::Var(c) =>
@@ -377,7 +347,8 @@ impl<S: ExpressionStorage> GenericExpression<S> {
/// Retrieve all symbols in the expression.
pub fn to_symbols(&self) -> Vec<char> {
self.terms
.iter_ref()
.read()
.iter()
.filter_map(|t| match t {
Term::Var(c) => Some(*c),
_ => None,
@@ -387,108 +358,92 @@ impl<S: ExpressionStorage> GenericExpression<S> {
/// Check if the '-' variable exists in the expression.
pub fn is_unknown(&self) -> bool {
self.terms.iter_ref().any(|t| matches!(t, Term::Var('-')))
self.terms
.read()
.iter()
.any(|t| matches!(t, Term::Var('-')))
}
}
impl Expression {
pub fn big(&self) -> BigExpression {
BigExpression::from(*self)
}
}
impl BigExpression {
pub fn small(&self) -> Expression {
Expression::from(self)
}
}
impl<S: ExpressionStorage> From<Term> for GenericExpression<S> {
impl From<Term> for Expression {
fn from(value: Term) -> Self {
let mut terms = S::default();
terms.push(value);
GenericExpression { terms }
Expression::new(vec![value])
}
}
impl<S: ExpressionStorage> From<char> for GenericExpression<S> {
impl From<char> for Expression {
fn from(value: char) -> Self {
GenericExpression::from(Term::Var(value))
Expression::new(vec![Term::Var(value)])
}
}
impl<S: ExpressionStorage> From<&char> for GenericExpression<S> {
impl From<&char> for Expression {
fn from(value: &char) -> Self {
GenericExpression::from(Term::Var(*value))
Expression::new(vec![Term::Var(*value)])
}
}
impl<S: ExpressionStorage> From<usize> for GenericExpression<S> {
impl From<usize> for Expression {
fn from(value: usize) -> Self {
GenericExpression::from(Term::Num(value as i32))
Expression::new(vec![Term::Num(value as i32)])
}
}
impl<S: ExpressionStorage> From<&usize> for GenericExpression<S> {
impl From<&usize> for Expression {
fn from(value: &usize) -> Self {
GenericExpression::from(Term::Num(*value as i32))
Expression::new(vec![Term::Num(*value as i32)])
}
}
impl<S: ExpressionStorage> From<i32> for GenericExpression<S> {
impl From<i32> for Expression {
fn from(value: i32) -> Self {
GenericExpression::from(value as usize)
Expression::new(vec![Term::Num(value)])
}
}
impl<S: ExpressionStorage> From<&i32> for GenericExpression<S> {
impl From<&i32> for Expression {
fn from(value: &i32) -> Self {
GenericExpression::from(*value as usize)
Expression::new(vec![Term::Num(*value)])
}
}
impl<S: ExpressionStorage> From<bool> for GenericExpression<S> {
impl From<bool> for Expression {
fn from(value: bool) -> Self {
GenericExpression::from(value as usize)
Expression::new(vec![Term::Num(value as i32)])
}
}
impl<S: ExpressionStorage> From<&bool> for GenericExpression<S> {
impl From<&bool> for Expression {
fn from(value: &bool) -> Self {
GenericExpression::from(*value as usize)
Expression::new(vec![Term::Num(*value as i32)])
}
}
impl<S: ExpressionStorage, T: ExpressionStorage> From<&GenericExpression<T>>
for GenericExpression<S>
{
fn from(value: &GenericExpression<T>) -> Self {
let mut s = S::default();
s.extend(value.terms.iter_ref().copied());
Self { terms: s }
impl Sub<Expression> for usize {
type Output = Expression;
fn sub(self, rhs: Expression) -> Self::Output {
Expression::from(self) - rhs
}
}
impl From<Expression> for BigExpression {
fn from(value: Expression) -> Self {
Self {
terms: value.terms.to_vec(),
}
impl Mul<Expression> for usize {
type Output = Expression;
fn mul(self, rhs: Expression) -> Self::Output {
rhs * self
}
}
impl From<BigExpression> for Expression {
fn from(value: BigExpression) -> Self {
let mut terms = ArrayVec::new();
terms.extend(value.terms);
Self { terms }
impl Div<Expression> for usize {
type Output = Expression;
fn div(self, rhs: Expression) -> Self::Output {
Expression::from(self) / rhs
}
}
impl<S: ExpressionStorage, E: Into<Self>> Add<E> for GenericExpression<S> {
impl<E: Into<Expression>> Add<E> for Expression {
type Output = Self;
fn add(self, rhs: E) -> Self::Output {
let mut rhs = rhs.into();
let rhs = rhs.into();
if rhs == 0 {
return self;
}
@@ -501,16 +456,17 @@ impl<S: ExpressionStorage, E: Into<Self>> Add<E> for GenericExpression<S> {
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
return (a + b).into();
}
rhs.terms.extend(self.terms.iter_ref().copied());
rhs.terms.push(Term::Add);
rhs
let mut terms = rhs.terms.read().clone();
terms.extend(self.terms.read().iter().copied());
terms.push(Term::Add);
Expression::new(terms)
}
}
impl<S: ExpressionStorage, E: Into<Self>> Sub<E> for GenericExpression<S> {
impl<E: Into<Expression>> Sub<E> for Expression {
type Output = Self;
fn sub(self, rhs: E) -> Self::Output {
let mut rhs = rhs.into();
let rhs = rhs.into();
if rhs == 0 {
return self;
}
@@ -520,16 +476,17 @@ impl<S: ExpressionStorage, E: Into<Self>> Sub<E> for GenericExpression<S> {
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
return (a - b).into();
}
rhs.terms.extend(self.terms.iter_ref().copied());
rhs.terms.push(Term::Sub);
rhs
let mut terms = rhs.terms.read().clone();
terms.extend(self.terms.read().iter().copied());
terms.push(Term::Sub);
Expression::new(terms)
}
}
impl<S: ExpressionStorage, E: Into<Self>> Mul<E> for GenericExpression<S> {
impl<E: Into<Expression>> Mul<E> for Expression {
type Output = Self;
fn mul(self, rhs: E) -> Self::Output {
let mut rhs = rhs.into();
let rhs = rhs.into();
if rhs == 1 {
return self;
}
@@ -544,16 +501,17 @@ impl<S: ExpressionStorage, E: Into<Self>> Mul<E> for GenericExpression<S> {
return c.into();
}
}
rhs.terms.extend(self.terms.iter_ref().copied());
rhs.terms.push(Term::Mul);
rhs
let mut terms = rhs.terms.read().clone();
terms.extend(self.terms.read().iter().copied());
terms.push(Term::Mul);
Expression::new(terms)
}
}
impl<S: ExpressionStorage, E: Into<Self>> Div<E> for GenericExpression<S> {
impl<E: Into<Expression>> Div<E> for Expression {
type Output = Self;
fn div(self, rhs: E) -> Self::Output {
let mut rhs = rhs.into();
let rhs = rhs.into();
if rhs == 1 {
return self;
}
@@ -566,32 +524,34 @@ impl<S: ExpressionStorage, E: Into<Self>> Div<E> for GenericExpression<S> {
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
return (a / b).into();
}
rhs.terms.extend(self.terms.iter_ref().copied());
rhs.terms.push(Term::Div);
rhs
let mut terms = rhs.terms.read().clone();
terms.extend(self.terms.read().iter().copied());
terms.push(Term::Div);
Expression::new(terms)
}
}
impl<S: ExpressionStorage, E: Into<Self>> Rem<E> for GenericExpression<S> {
impl<E: Into<Expression>> Rem<E> for Expression {
type Output = Self;
fn rem(self, rhs: E) -> Self::Output {
let mut rhs = rhs.into();
let rhs = rhs.into();
if rhs == 1 || rhs == self {
return 0.into();
}
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
return (a % b).into();
}
rhs.terms.extend(self.terms.iter_ref().copied());
rhs.terms.push(Term::Mod);
rhs
let mut terms = rhs.terms.read().clone();
terms.extend(self.terms.read().iter().copied());
terms.push(Term::Mod);
Expression::new(terms)
}
}
impl<S: ExpressionStorage, E: Into<Self>> BitAnd<E> for GenericExpression<S> {
impl<E: Into<Expression>> BitAnd<E> for Expression {
type Output = Self;
fn bitand(self, rhs: E) -> Self::Output {
let mut rhs = rhs.into();
let rhs = rhs.into();
if rhs == 0 || self == 0 {
return 0.into();
}
@@ -604,30 +564,32 @@ impl<S: ExpressionStorage, E: Into<Self>> BitAnd<E> for GenericExpression<S> {
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
return (a != 0 && b != 0).into();
}
rhs.terms.extend(self.terms.iter_ref().copied());
rhs.terms.push(Term::And);
rhs
let mut terms = rhs.terms.read().clone();
terms.extend(self.terms.read().iter().copied());
terms.push(Term::And);
Expression::new(terms)
}
}
impl<S: ExpressionStorage, E: Into<Self>> BitOr<E> for GenericExpression<S> {
impl<E: Into<Expression>> BitOr<E> for Expression {
type Output = Self;
fn bitor(self, rhs: E) -> Self::Output {
let mut rhs = rhs.into();
let rhs = rhs.into();
if rhs == 1 || self == 1 {
return 1.into();
}
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
return (a != 0 || b != 0).into();
}
rhs.terms.extend(self.terms.iter_ref().copied());
rhs.terms.push(Term::Or);
rhs
let mut terms = rhs.terms.read().clone();
terms.extend(self.terms.read().iter().copied());
terms.push(Term::Or);
Expression::new(terms)
}
}
impl<S: ExpressionStorage> std::iter::Product for GenericExpression<S> {
fn product<I: Iterator<Item = GenericExpression<S>>>(mut iter: I) -> Self {
impl std::iter::Product for Expression {
fn product<I: Iterator<Item = Expression>>(mut iter: I) -> Self {
let Some(mut p) = iter.next() else {
return 0.into();
};
@@ -638,45 +600,45 @@ impl<S: ExpressionStorage> std::iter::Product for GenericExpression<S> {
}
}
impl<S: ExpressionStorage, E: Into<Self>> AddAssign<E> for GenericExpression<S> {
impl<E: Into<Expression>> AddAssign<E> for Expression {
fn add_assign(&mut self, rhs: E) {
*self = self.clone() + rhs;
*self = *self + rhs;
}
}
impl<S: ExpressionStorage, E: Into<Self>> SubAssign<E> for GenericExpression<S> {
impl<E: Into<Expression>> SubAssign<E> for Expression {
fn sub_assign(&mut self, rhs: E) {
*self = self.clone() - rhs;
*self = *self - rhs;
}
}
impl<S: ExpressionStorage, E: Into<Self>> MulAssign<E> for GenericExpression<S> {
impl<E: Into<Expression>> MulAssign<E> for Expression {
fn mul_assign(&mut self, rhs: E) {
*self = self.clone() * rhs;
*self = *self * rhs;
}
}
impl<S: ExpressionStorage, E: Into<Self>> DivAssign<E> for GenericExpression<S> {
impl<E: Into<Expression>> DivAssign<E> for Expression {
fn div_assign(&mut self, rhs: E) {
*self = self.clone() / rhs;
*self = *self / rhs;
}
}
impl<S: ExpressionStorage, E: Into<Self>> RemAssign<E> for GenericExpression<S> {
impl<E: Into<Expression>> RemAssign<E> for Expression {
fn rem_assign(&mut self, rhs: E) {
*self = self.clone() % rhs;
*self = *self % rhs;
}
}
impl<S: ExpressionStorage, E: Into<Self>> BitAndAssign<E> for GenericExpression<S> {
impl<E: Into<Expression>> BitAndAssign<E> for Expression {
fn bitand_assign(&mut self, rhs: E) {
*self = self.clone() & rhs;
*self = *self & rhs;
}
}
impl<S: ExpressionStorage, E: Into<Self>> BitOrAssign<E> for GenericExpression<S> {
impl<E: Into<Expression>> BitOrAssign<E> for Expression {
fn bitor_assign(&mut self, rhs: E) {
*self = self.clone() | rhs;
*self = *self | rhs;
}
}
@@ -689,10 +651,10 @@ define_language! {
}
}
fn luminal_to_egg<S: ExpressionStorage>(expr: &GenericExpression<S>) -> RecExpr<Math> {
fn luminal_to_egg(expr: &Expression) -> RecExpr<Math> {
let mut stack = Vec::new();
for term in expr.terms.iter_ref() {
for term in expr.terms.read().iter() {
match term {
Term::Num(_) | Term::Var(_) => {
stack.push(symbolic_expressions::Sexp::String(format!("{term:?}")))
@@ -741,7 +703,7 @@ fn luminal_to_egg<S: ExpressionStorage>(expr: &GenericExpression<S>) -> RecExpr<
expr
}
fn egg_to_luminal<S: ExpressionStorage>(expr: RecExpr<Math>) -> GenericExpression<S> {
fn egg_to_luminal(expr: RecExpr<Math>) -> Expression {
fn create_postfix(expr: &[Math]) -> Vec<Term> {
match expr.last().unwrap() {
Math::Num(i) => vec![Term::Num(*i)],
@@ -814,9 +776,9 @@ fn egg_to_luminal<S: ExpressionStorage>(expr: RecExpr<Math>) -> GenericExpressio
Math::Symbol(s) => vec![Term::Var(s.as_str().chars().next().unwrap())],
}
}
let mut terms = S::default();
let mut terms = vec![];
terms.extend(create_postfix(expr.as_ref()));
GenericExpression { terms }
Expression::new(terms)
}
type EGraph = egg::EGraph<Math, ConstantFold>;
@@ -962,9 +924,9 @@ fn make_rules() -> Vec<Rewrite> {
]
}
fn egg_simplify<S: ExpressionStorage>(expr: GenericExpression<S>) -> GenericExpression<S> {
fn egg_simplify(e: Expression) -> Expression {
// Convert to egg expression
let expr = luminal_to_egg(&expr);
let expr = luminal_to_egg(&e);
// Simplify
let runner = Runner::default()
.with_expr(&expr)
@@ -978,10 +940,11 @@ fn egg_simplify<S: ExpressionStorage>(expr: GenericExpression<S>) -> GenericExpr
#[cfg(test)]
mod tests {
use crate::prelude::*;
#[test]
fn test_expressions() {
let n = Expression::from('x') + (Expression::from(256) - (Expression::from('x') % 256));
let n = Expression::from('x') + (256 - (Expression::from('x') % 256));
assert_eq!(
n.simplify()
.exec(&[('x', 767)].into_iter().collect())
@@ -992,7 +955,7 @@ mod tests {
#[test]
fn test_minimizations() {
let expr = ((BigExpression::from('a') * 1) + 0) / 1 + (1 - 1);
let expr = ((Expression::from('a') * 1) + 0) / 1 + (1 - 1);
let reduced_expr = expr.simplify();
assert_eq!(reduced_expr, 'a');
}
@@ -1007,9 +970,8 @@ mod tests {
#[test]
fn test_group_terms() {
let s = BigExpression::from('s');
let expr = (s.clone() * ((s.clone() - 4) + 1))
+ (((s.clone() + 1) * ((s.clone() - 4) + 1)) - (s.clone() * ((s.clone() - 4) + 1)));
assert_eq!(expr.simplify().terms.len(), 7);
let s = Expression::from('s');
let expr = (s * ((s - 4) + 1)) + (((s + 1) * ((s - 4) + 1)) - (s * ((s - 4) + 1)));
assert_eq!(expr.simplify().terms.read().len(), 7);
}
}

View File

@@ -3,9 +3,7 @@ use tinyvec::ArrayVec;
use crate::prelude::*;
#[derive(
Debug, Clone, Copy, Default, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize,
)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ShapeTracker {
pub dims: ArrayVec<[Expression; 6]>,
pub indexes: ArrayVec<[usize; 6]>,
@@ -15,7 +13,8 @@ pub struct ShapeTracker {
}
impl ShapeTracker {
pub fn new(dims: &[impl Into<Expression> + Copy]) -> Self {
#[allow(clippy::not_unsafe_ptr_arg_deref)]
pub fn new(dims: impl ToShape) -> Self {
let mut s = Self {
dims: Default::default(),
indexes: Default::default(),
@@ -23,8 +22,8 @@ impl ShapeTracker {
mask: Default::default(),
padding: Default::default(),
};
for (i, d) in dims.iter().enumerate() {
s.dims.push((*d).into());
for (i, d) in dims.to_shape().into_iter().enumerate() {
s.dims.push(d);
s.indexes.push(i);
s.fake.push(false);
s.mask.push((0.into(), i32::MAX.into())); // Unset upper bound mask are i32::MAX
@@ -34,7 +33,7 @@ impl ShapeTracker {
}
/// Create a shape tracker where all dims are fake
pub fn fake(dims: &[Expression]) -> Self {
pub fn fake(dims: impl ToShape) -> Self {
let mut s = Self::new(dims);
s.fake.iter_mut().for_each(|i| *i = true);
s
@@ -82,13 +81,13 @@ impl ShapeTracker {
}
/// Strides without permute applied
fn unordered_strides(&self) -> Vec<BigExpression> {
fn unordered_strides(&self) -> Vec<Expression> {
let mut strides = (0..self.len())
.rev()
.scan(BigExpression::from(1), |state, i| {
let ret = state.clone();
.scan(Expression::from(1), |state, i| {
let ret = *state;
if !self.fake[i] {
*state = state.clone() * self.dims[i];
*state *= self.dims[i];
}
Some(ret)
})
@@ -98,22 +97,19 @@ impl ShapeTracker {
}
/// Compute strides
pub fn strides(&self) -> Vec<BigExpression> {
pub fn strides(&self) -> Vec<Expression> {
let strides = self.unordered_strides();
self.indexes
.into_iter()
.map(|i| strides[i].clone())
.collect()
self.indexes.into_iter().map(|i| strides[i]).collect()
}
/// Create an expression to translate logical indexes into physical indexes, without expression simplification
pub fn index_expression_no_simplify(&self) -> BigExpression {
pub fn index_expression_no_simplify(&self) -> Expression {
if !self.is_reshaped() {
return 'z'.into();
}
let strides = self.unordered_strides(); // Dimension strides in original order
let mut ind_expr = BigExpression::from(0); // The final index expression
let mut current_elem_size = BigExpression::from(1); // Keep track of the size of each element of the current dim (last dim elem size: 1)
let mut ind_expr = 0.into(); // The final index expression
let mut current_elem_size = Expression::from(1); // Keep track of the size of each element of the current dim (last dim elem size: 1)
// Loop through all dims in reverse order
for i in self.indexes.into_iter().rev() {
@@ -121,15 +117,15 @@ impl ShapeTracker {
let current_size = pad_mask_dim(self.dims[i], self.padding[i], self.mask[i]);
// Don't include fake dimensions in the index expression
if !self.fake[i] {
let mut dim_ind = BigExpression::from('z');
let mut dim_ind = Expression::from('z');
// Remove other dim components
dim_ind /= current_elem_size.clone();
dim_ind /= current_elem_size;
// Get position in current dim
dim_ind %= current_size.clone();
dim_ind %= current_size;
// Add offset
dim_ind += self.mask[i].0 - self.padding[i].0;
// Multiply by stride
dim_ind *= strides[i].clone();
dim_ind *= strides[i];
// Add to index expression
ind_expr += dim_ind;
}
@@ -140,28 +136,28 @@ impl ShapeTracker {
}
/// Create an expression to translate logical indexes into physical indexes
pub fn index_expression(&self) -> BigExpression {
pub fn index_expression(&self) -> Expression {
self.index_expression_no_simplify().simplify()
}
/// If this expression evaluates to 0, the logical index is invalid. Otherwise it is valid. No simplification
pub fn valid_expression_no_simplify(&self) -> BigExpression {
pub fn valid_expression_no_simplify(&self) -> Expression {
if !self.is_reshaped() {
return true.into();
}
let mut ret = BigExpression::from(1);
let mut acc = BigExpression::from(1);
let logical = BigExpression::from('z');
let mut ret = Expression::from(1);
let mut acc = Expression::from(1);
let logical = Expression::from('z');
for i in self.indexes.into_iter().rev() {
let (bottom_slice, top_slice) = self.mask[i];
let logical_sh = pad_mask_dim(self.dims[i], self.padding[i], self.mask[i]);
if !self.fake[i] {
let dim_ind = (logical.clone() / acc.clone()) % logical_sh.clone();
let greater_than = self.padding[i].0.big() - bottom_slice;
let dim_ind = (logical / acc) % logical_sh;
let greater_than = self.padding[i].0 - bottom_slice;
if greater_than != 0 {
ret &= dim_ind.clone().gte(greater_than);
ret &= dim_ind.gte(greater_than);
}
ret &= dim_ind.lt(self.dims[i].big() + self.padding[i].0);
ret &= dim_ind.lt(self.dims[i] + self.padding[i].0);
if top_slice
.to_usize()
.map(|s| self.dims[i].to_usize().map(|dim| s < dim).unwrap_or(true))
@@ -176,22 +172,22 @@ impl ShapeTracker {
}
/// If this expression evaluates to 0, the logical index is invalid. Otherwise it is valid
pub fn valid_expression(&self) -> BigExpression {
pub fn valid_expression(&self) -> Expression {
self.valid_expression_no_simplify().simplify()
}
/// The number of elements in this tensor, including padding and mask
pub fn n_elements(&self) -> BigExpression {
self.shape().into_iter().product::<BigExpression>().max(1)
pub fn n_elements(&self) -> Expression {
self.dims().into_iter().product::<Expression>().max(1)
}
/// The number of elements in this tensor, not including pads and mask
pub fn n_physical_elements(&self) -> BigExpression {
pub fn n_physical_elements(&self) -> Expression {
self.indexes
.into_iter()
.filter(|i| !self.fake[*i])
.map(|i| self.dims[i].big())
.product::<BigExpression>()
.map(|i| self.dims[i])
.product::<Expression>()
.max(1)
}
@@ -222,10 +218,9 @@ impl ShapeTracker {
/// Create a contiguous version
pub fn contiguous(self) -> Self {
Self::new(
&self
.shape()
self.dims()
.into_iter()
.map(|i| i.simplify().small())
.map(|i| i.simplify())
.collect::<Vec<_>>(),
)
}
@@ -241,7 +236,7 @@ impl ShapeTracker {
}
/// Realize the true shape
pub fn shape(&self) -> Vec<BigExpression> {
pub fn dims(&self) -> Vec<Expression> {
self.indexes
.into_iter()
.map(|i| pad_mask_dim(self.dims[i], self.padding[i], self.mask[i]))
@@ -250,7 +245,7 @@ impl ShapeTracker {
/// Realize the true shape and convert it to usizes. All dyn dims must be replaced already
pub fn shape_usize(&self) -> Vec<usize> {
self.shape().iter().map(|e| e.to_usize().unwrap()).collect()
self.dims().iter().map(|e| e.to_usize().unwrap()).collect()
}
/// Take a slice
@@ -284,8 +279,9 @@ impl ShapeTracker {
{
panic!("Adding padding to a masked shape isn't supported")
}
self.padding[ind].0 += s.max(0);
self.padding[ind].1 += e.max(0);
let (s, e) = (s.max(0), e.max(0));
self.padding[ind].0 += s;
self.padding[ind].1 += e;
}
}
@@ -329,11 +325,11 @@ impl ShapeTracker {
}
fn pad_mask_dim(
dim: impl Into<BigExpression>,
dim: impl Into<Expression>,
padding: (Expression, Expression),
mask: (Expression, Expression),
) -> BigExpression {
(dim.into() + padding.0 + padding.1).min(mask.1) - mask.0
) -> Expression {
(padding.0 + padding.1 + dim).min(mask.1) - mask.0
}
/// Resolve shapes between the two trackers to the best of our ability
@@ -364,7 +360,7 @@ mod tests {
use crate::prelude::*;
#[test]
fn test_idx_expr() {
let mut tracker = ShapeTracker::new(&[
let mut tracker = ShapeTracker::new([
Expression::from(10),
Expression::from(5),
Expression::from(3),

View File

@@ -50,7 +50,7 @@ fn test_expand() {
fn test_slice() {
let mut cx = Graph::new();
let a = cx.tensor((2, 3)).set([[1., 2., 3.], [1., 2., 3.]]);
let b = a.slice((Expression::from(1).., ..)).retrieve();
let b = a.slice((1.., ..)).retrieve();
cx.execute();
let d_dev = Cpu::default();
@@ -323,3 +323,18 @@ fn test_max_reduce() {
assert_close(&c.data(), &d_c.as_vec());
assert_close(&d.data(), &d_d.as_vec());
}
#[test]
fn test_dot() {
let mut cx = Graph::new();
let a = cx.tensor(5).set([34.4, -96.0, 144.0, 43.0, 560.0]);
let b = cx.tensor(5).set([43.0, 560.0, 180.0, 700.0, 225.0]);
let c = a.dot(b).retrieve();
cx.execute();
let d_dev = Cpu::default();
let d_a = d_dev.tensor([34.4, -96.0, 144.0, 43.0, 560.0]);
let d_b = d_dev.tensor([43.0, 560.0, 180.0, 700.0, 225.0]);
let d_c = (d_a * d_b).sum();
assert_close(&c.data(), &d_c.as_vec());
}