forked from Rust-related/luminal
Refactored expression system
This commit is contained in:
@@ -25,6 +25,8 @@ as-any = "0.3.1"
|
||||
egg = "0.9.5"
|
||||
symbolic_expressions = "5.0.3"
|
||||
serde = {version="1.0.202", features=["derive"]}
|
||||
thread_local = "1.1.8"
|
||||
generational-box = "0.5.6"
|
||||
|
||||
[dev-dependencies]
|
||||
dfdx = { version = "0.13", features = ["f16"] }
|
||||
|
||||
@@ -231,7 +231,7 @@ impl Compiler for GatherCompiler {
|
||||
.as_data()
|
||||
.unwrap()
|
||||
.2
|
||||
.shape()[2]
|
||||
.dims()[2]
|
||||
.to_usize()
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ pub struct MatMul2D;
|
||||
|
||||
impl Operator for MatMul2D {
|
||||
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
|
||||
let (a_shape, b_shape) = (inp[0].1.dims(), inp[1].1.dims());
|
||||
let (a_strides, b_strides) = (inp[0].1.strides(), inp[1].1.strides());
|
||||
let a_data = inp[0].0.borrowed().downcast_ref::<Vec<f32>>().unwrap();
|
||||
let b_data = inp[1].0.borrowed().downcast_ref::<Vec<f32>>().unwrap();
|
||||
@@ -151,7 +151,7 @@ pub struct BatchedMatMul2D;
|
||||
// ABCxCD -> ABD
|
||||
impl Operator for BatchedMatMul2D {
|
||||
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
|
||||
let (a_shape, b_shape) = (inp[0].1.dims(), inp[1].1.dims());
|
||||
let (a_strides, b_strides) = (inp[0].1.strides(), inp[1].1.strides());
|
||||
let a_data = inp[0].0.borrowed().downcast_ref::<Vec<f32>>().unwrap();
|
||||
let b_data = inp[1].0.borrowed().downcast_ref::<Vec<f32>>().unwrap();
|
||||
|
||||
@@ -8,7 +8,7 @@ use super::binary::Sub;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct ARange {
|
||||
pub size: BigExpression,
|
||||
pub size: Expression,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ impl Compiler for ARangeCompiler {
|
||||
};
|
||||
let arange_op = graph
|
||||
.add_op(ARange {
|
||||
size: arange_amount.into(),
|
||||
size: arange_amount,
|
||||
dyn_map: &graph.dyn_map,
|
||||
})
|
||||
.finish();
|
||||
|
||||
@@ -71,7 +71,7 @@ kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name}
|
||||
}
|
||||
|
||||
impl<T> MetalKernel for MetalSub<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
|
||||
vec![input_shapes[0].n_elements() * size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
@@ -252,7 +252,7 @@ kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name}
|
||||
}
|
||||
|
||||
impl<T> MetalKernel for MetalEqual<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
|
||||
vec![input_shapes[0].n_elements() * size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
@@ -419,7 +419,7 @@ impl<T: MetalFloat> Operator for MetalGather<T> {
|
||||
// Setup buffers
|
||||
let indexes = tensors[0].0.borrowed().downcast_ref::<Vec<f32>>().unwrap();
|
||||
let index_buffer = self.device.new_buffer_with_data(
|
||||
unsafe { std::mem::transmute(indexes.as_ptr()) },
|
||||
indexes.as_ptr() as *const _,
|
||||
(indexes.len() * std::mem::size_of::<f32>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
@@ -498,7 +498,7 @@ impl<T: MetalFloat> Compiler for MetalGatherCompiler<T> {
|
||||
.as_data()
|
||||
.unwrap()
|
||||
.2;
|
||||
let embed_dim = emb_shape.shape()[2].to_usize().unwrap();
|
||||
let embed_dim = emb_shape.dims()[2].to_usize().unwrap();
|
||||
let index_shape = graph
|
||||
.edges_connecting(s.get(&indexes), s.get(&ind_copy))
|
||||
.next()
|
||||
|
||||
@@ -236,10 +236,10 @@ impl std::fmt::Debug for CommandBufferWrapper {
|
||||
}
|
||||
|
||||
impl MetalKernel for CommandBufferWrapper {
|
||||
fn intermediate_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
fn intermediate_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
|
||||
self.wrapper.intermediate_buffer_sizes(input_shapes)
|
||||
}
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
|
||||
self.wrapper.output_buffer_sizes(input_shapes)
|
||||
}
|
||||
fn metal_forward(
|
||||
|
||||
@@ -121,7 +121,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
|
||||
let mut subexpressions_b = graph
|
||||
.try_get_op::<FusedElementwiseOp<T>>(b)
|
||||
.map(|o| o.subexpressions.clone())
|
||||
.unwrap_or_else(|| vec![(expression_b.clone(), ShapeTracker::default())]);
|
||||
.unwrap_or_else(|| vec![(expression_b.clone(), ShapeTracker::new(()))]);
|
||||
let a_to_b_indexes = graph
|
||||
.edges_connecting(a, b)
|
||||
.map(|e| e.weight().as_data().unwrap().0 as usize)
|
||||
@@ -141,7 +141,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
|
||||
let mut subexpressions_a = graph
|
||||
.try_get_op::<FusedElementwiseOp<T>>(a)
|
||||
.map(|o| o.subexpressions.clone())
|
||||
.unwrap_or_else(|| vec![(expression_a.clone(), ShapeTracker::default())]);
|
||||
.unwrap_or_else(|| vec![(expression_a.clone(), ShapeTracker::new(()))]);
|
||||
subexpressions_a.last_mut().unwrap().1 = connecting_shape;
|
||||
// Re-reference b intermediates
|
||||
for i in (0..subexpressions_b.len()).rev() {
|
||||
@@ -236,6 +236,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
|
||||
.into_iter()
|
||||
.map(|s| s.simplify_cache(&mut simplification_cache))
|
||||
.collect();
|
||||
let g: *mut Graph = graph;
|
||||
let new_op = graph
|
||||
.add_op(FusedElementwiseOp::<T> {
|
||||
kernel_str: "".to_string(),
|
||||
@@ -247,6 +248,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
|
||||
device: device.clone(),
|
||||
output_buffer_sizes,
|
||||
_phantom: Default::default(),
|
||||
graph: g,
|
||||
})
|
||||
.finish();
|
||||
// Add edges to new op
|
||||
@@ -293,17 +295,20 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
|
||||
.into_iter()
|
||||
.map(|s| s.simplify_cache(&mut simplification_cache))
|
||||
.collect();
|
||||
let sh = ShapeTracker::new(());
|
||||
let g: *mut Graph = graph;
|
||||
let new_op = graph
|
||||
.add_op(FusedElementwiseOp::<T> {
|
||||
kernel_str: "".to_string(),
|
||||
kernel: None,
|
||||
dyn_map: &graph.dyn_map,
|
||||
dyn_chars: vec![],
|
||||
subexpressions: vec![(op_string, ShapeTracker::default())],
|
||||
subexpressions: vec![(op_string, sh)],
|
||||
queue: queue.clone(),
|
||||
device: device.clone(),
|
||||
output_buffer_sizes,
|
||||
_phantom: Default::default(),
|
||||
graph: g,
|
||||
})
|
||||
.finish();
|
||||
// Add edges to new op
|
||||
@@ -346,8 +351,9 @@ pub struct FusedElementwiseOp<T> {
|
||||
pub subexpressions: Vec<(String, ShapeTracker)>,
|
||||
pub queue: CommandQueue,
|
||||
pub device: Device,
|
||||
pub output_buffer_sizes: Vec<BigExpression>,
|
||||
pub output_buffer_sizes: Vec<Expression>,
|
||||
pub _phantom: PhantomData<T>,
|
||||
pub graph: *mut Graph,
|
||||
}
|
||||
crate::debug_type!(FusedElementwiseOp);
|
||||
|
||||
@@ -357,7 +363,7 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
|
||||
input_shapes: Vec<ShapeTracker>,
|
||||
input_regexes: &mut FxHashMap<usize, Regex>,
|
||||
intermediate_match: &Regex,
|
||||
simplification_cache: &mut FxHashMap<BigExpression, BigExpression>,
|
||||
simplification_cache: &mut FxHashMap<Expression, Expression>,
|
||||
) {
|
||||
let mut subexpressions = self.subexpressions.clone();
|
||||
let shapes_used = subexpressions
|
||||
@@ -415,7 +421,7 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
|
||||
s.iter()
|
||||
.rev()
|
||||
.take(s.len() - 1)
|
||||
.fold(BigExpression::from('z'), |acc, inp| {
|
||||
.fold(Expression::from('z'), |acc, inp| {
|
||||
inp.index_expression().substitute('z', acc)
|
||||
})
|
||||
})
|
||||
@@ -451,8 +457,8 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
|
||||
input_regexes.get(&i).unwrap()
|
||||
};
|
||||
let (ind, val) = (
|
||||
ind_exp.clone().simplify_cache(simplification_cache),
|
||||
val_exp.clone().simplify_cache(simplification_cache),
|
||||
ind_exp.simplify_cache(simplification_cache),
|
||||
val_exp.simplify_cache(simplification_cache),
|
||||
);
|
||||
*subexp = re
|
||||
.replace_all(
|
||||
@@ -475,10 +481,10 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
|
||||
.iter()
|
||||
.rev()
|
||||
.fold(
|
||||
(BigExpression::from(true), BigExpression::from('z')),
|
||||
(Expression::from(true), Expression::from('z')),
|
||||
|(_, ind_acc), inp| {
|
||||
(
|
||||
inp.valid_expression().substitute('z', ind_acc.clone()),
|
||||
inp.valid_expression().substitute('z', ind_acc),
|
||||
inp.index_expression().substitute('z', ind_acc),
|
||||
)
|
||||
},
|
||||
@@ -526,7 +532,7 @@ out[idx] = ({type_name})({});
|
||||
}
|
||||
|
||||
impl<T> MetalKernel for FusedElementwiseOp<T> {
|
||||
fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<Expression> {
|
||||
self.output_buffer_sizes.clone()
|
||||
}
|
||||
fn metal_forward(
|
||||
@@ -669,7 +675,7 @@ mod tests {
|
||||
let inp = random_vec_rng(10, &mut rng);
|
||||
let a = cx.named_tensor("a", (2, 5)).set(inp);
|
||||
let mut padded = a
|
||||
.slice((..Expression::from(1), ..))
|
||||
.slice((..1, ..))
|
||||
.cos()
|
||||
.pad(((0, 1), (0, 0)))
|
||||
.exp2()
|
||||
@@ -711,7 +717,7 @@ mod tests {
|
||||
const HEAD_DIM: usize = 4;
|
||||
let freqs = (cx.arange(HEAD_DIM / 2) * 2.0) / (HEAD_DIM as f32);
|
||||
let freqs = 1000000_f32.pow(freqs);
|
||||
let pos = cx.arange(SEQ) + BigExpression::from(0);
|
||||
let pos = cx.arange(SEQ) + 0;
|
||||
let mut emb = pos.expand(1, 1).matmul(freqs.expand(0, SEQ)).retrieve();
|
||||
|
||||
cx.execute();
|
||||
@@ -775,16 +781,12 @@ mod tests {
|
||||
.keep();
|
||||
let freqs = (cx.arange(HEAD_DIM / 2) * 2.0) / (HEAD_DIM as f32);
|
||||
let freqs = 1000000_f32.pow(freqs);
|
||||
let pos = cx.arange(SEQ) + BigExpression::from(0);
|
||||
let pos = cx.arange(SEQ) + 0;
|
||||
let emb = pos.expand(1, 1).matmul(freqs.expand(0, SEQ));
|
||||
// Split input into evens and odds
|
||||
let split = a.reshape((BATCH, N_HEADS, SEQ, HEAD_DIM / 2, 2));
|
||||
let x0 = split
|
||||
.slice((.., .., .., .., ..Expression::from(1)))
|
||||
.contiguous();
|
||||
let x1 = split
|
||||
.slice((.., .., .., .., Expression::from(1)..))
|
||||
.contiguous();
|
||||
let x0 = split.slice((.., .., .., .., ..1)).contiguous();
|
||||
let x1 = split.slice((.., .., .., .., 1..)).contiguous();
|
||||
|
||||
// Apply sin and cos embeddings
|
||||
let x0_out = x0 * emb.cos().expand_to(x0.shape) - x1 * emb.sin().expand_to(x1.shape);
|
||||
@@ -852,15 +854,9 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_rotary_embeddings_ggml(
|
||||
input: GraphTensor,
|
||||
prev_seq: BigExpression,
|
||||
) -> GraphTensor {
|
||||
fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: Expression) -> GraphTensor {
|
||||
assert_eq!(input.shape.len(), 4); // batch, n_heads, seq, head_dim
|
||||
let batch = input.shape()[0].small();
|
||||
let n_heads = input.shape()[1].small();
|
||||
let seq = input.shape()[2].small();
|
||||
let head_dim = input.shape()[3].small();
|
||||
let (batch, n_heads, seq, head_dim) = input.dims4();
|
||||
// Get freqs
|
||||
let freqs =
|
||||
(input.graph().arange(head_dim / 2) * 2.0) / (head_dim.to_usize().unwrap() as f32);
|
||||
@@ -892,9 +888,8 @@ mod tests {
|
||||
type Output = (GraphTensor, KVCache);
|
||||
fn forward(&self, (x, (k_cache, v_cache)): (GraphTensor, KVCache)) -> Self::Output {
|
||||
// x: batch, seq, hidden
|
||||
let batch = x.shape()[0].small();
|
||||
let seq = x.shape()[1].small();
|
||||
let prev_seq = k_cache.shape()[2].small();
|
||||
let (batch, seq, _) = x.dims3();
|
||||
let (_, _, prev_seq, _) = k_cache.dims4();
|
||||
// Apply the Projections
|
||||
let queries = x
|
||||
.matmul(self.q_proj.permute((1, 0)))
|
||||
@@ -912,8 +907,8 @@ mod tests {
|
||||
.permute((0, 2, 1, 3));
|
||||
|
||||
// Rotary embed queries and keys
|
||||
let queries = apply_rotary_embeddings_ggml(queries, prev_seq.big());
|
||||
let keys = apply_rotary_embeddings_ggml(keys, prev_seq.big());
|
||||
let queries = apply_rotary_embeddings_ggml(queries, prev_seq);
|
||||
let keys = apply_rotary_embeddings_ggml(keys, prev_seq);
|
||||
|
||||
// Add KV cache
|
||||
let keys = k_cache.concat_along(keys, 2);
|
||||
|
||||
@@ -140,11 +140,11 @@ impl MetalFloat for f16 {
|
||||
|
||||
pub trait MetalKernel: Debug {
|
||||
/// Annotate the buffer sizes of the intermediate buffers
|
||||
fn intermediate_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
fn intermediate_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<Expression> {
|
||||
vec![]
|
||||
}
|
||||
/// Annotate the buffer sizes of the output buffers
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression>;
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression>;
|
||||
/// Set up the kernel on the buffer
|
||||
fn metal_forward(
|
||||
&self,
|
||||
@@ -227,7 +227,7 @@ impl Deref for MetalKernelWrapper {
|
||||
}
|
||||
|
||||
impl MetalKernel for () {
|
||||
fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<Expression> {
|
||||
vec![]
|
||||
}
|
||||
fn metal_forward(
|
||||
@@ -365,14 +365,10 @@ fn render_dyn_dim_inputs(shapes: &[ShapeTracker], offset: usize) -> (Vec<char>,
|
||||
let symbols: Vec<char> = shapes
|
||||
.iter()
|
||||
.flat_map(|st| {
|
||||
st.shape()
|
||||
st.dims()
|
||||
.into_iter()
|
||||
.chain(
|
||||
st.padding
|
||||
.into_iter()
|
||||
.flat_map(|i| [i.0.into(), i.1.into()]),
|
||||
)
|
||||
.chain(st.mask.into_iter().flat_map(|i| [i.0.into(), i.1.into()]))
|
||||
.chain(st.padding.into_iter().flat_map(|i| [i.0, i.1]))
|
||||
.chain(st.mask.into_iter().flat_map(|i| [i.0, i.1]))
|
||||
})
|
||||
.flat_map(|d| d.to_symbols())
|
||||
.unique()
|
||||
@@ -389,9 +385,9 @@ fn render_dyn_dim_inputs(shapes: &[ShapeTracker], offset: usize) -> (Vec<char>,
|
||||
)
|
||||
}
|
||||
|
||||
fn expr_to_metal_string(expr: &BigExpression) -> String {
|
||||
fn expr_to_metal_string(expr: &Expression) -> String {
|
||||
let mut symbols = vec![];
|
||||
for term in expr.terms.clone() {
|
||||
for term in expr.terms.read().clone() {
|
||||
let new_symbol = match term {
|
||||
Term::Num(n) => n.to_string(),
|
||||
Term::Var(c) => {
|
||||
|
||||
@@ -33,15 +33,15 @@ impl<T> Debug for Matmul<T> {
|
||||
const BM: u64 = 8;
|
||||
const BN: u64 = 32;
|
||||
impl<T> MetalKernel for Matmul<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
let m = input_shapes[0].shape()[input_shapes[0].len() - 2].clone();
|
||||
let n = input_shapes[1].shape()[input_shapes[1].len() - 1].clone();
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
|
||||
let m = input_shapes[0].dims()[input_shapes[0].len() - 2];
|
||||
let n = input_shapes[1].dims()[input_shapes[1].len() - 1];
|
||||
let batch_size = input_shapes[0]
|
||||
.shape()
|
||||
.dims()
|
||||
.into_iter()
|
||||
.take(input_shapes[0].len() - 2)
|
||||
.product::<BigExpression>()
|
||||
.max(BigExpression::from(1));
|
||||
.product::<Expression>()
|
||||
.max(1);
|
||||
vec![batch_size * m * n * size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
@@ -54,13 +54,13 @@ impl<T> MetalKernel for Matmul<T> {
|
||||
let (a_shape, b_shape) = (
|
||||
inputs[0]
|
||||
.1
|
||||
.shape()
|
||||
.dims()
|
||||
.into_iter()
|
||||
.map(|i| i.to_usize().unwrap())
|
||||
.collect::<Vec<_>>(),
|
||||
inputs[1]
|
||||
.1
|
||||
.shape()
|
||||
.dims()
|
||||
.into_iter()
|
||||
.map(|i| i.to_usize().unwrap())
|
||||
.collect::<Vec<_>>(),
|
||||
@@ -158,7 +158,7 @@ impl<T: MetalFloat> Operator for Matmul<T> {
|
||||
// Setup command queue / command buffer / encoder
|
||||
let command_buffer = self.queue.new_command_buffer();
|
||||
|
||||
let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
|
||||
let (a_shape, b_shape) = (inp[0].1.dims(), inp[1].1.dims());
|
||||
let n = b_shape.last().unwrap().to_usize().unwrap();
|
||||
let batch_size = a_shape
|
||||
.iter()
|
||||
|
||||
@@ -71,7 +71,7 @@ pub struct MetalARange<T: MetalFloat> {
|
||||
pipeline: ComputePipelineState,
|
||||
queue: CommandQueue,
|
||||
device: Device,
|
||||
pub size: BigExpression,
|
||||
pub size: Expression,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
@@ -86,7 +86,7 @@ impl<T: MetalFloat> MetalARange<T> {
|
||||
pub fn new(
|
||||
device: Device,
|
||||
queue: CommandQueue,
|
||||
size: BigExpression,
|
||||
size: Expression,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
) -> Self {
|
||||
let type_name = T::type_name();
|
||||
@@ -109,8 +109,8 @@ kernel void metal_arange(device {type_name} *out [[buffer(0)]], device int& n_el
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> MetalKernel for MetalARange<T> {
|
||||
fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
vec![self.size.clone() * std::mem::size_of::<f16>()]
|
||||
fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<Expression> {
|
||||
vec![self.size * std::mem::size_of::<f16>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
&self,
|
||||
@@ -213,7 +213,7 @@ impl<T: MetalFloat> Compiler for ARangeCompiler<T> {
|
||||
.add_op(MetalARange::<T>::new(
|
||||
dev.clone(),
|
||||
queue.clone(),
|
||||
arange_amount.into(),
|
||||
arange_amount,
|
||||
&graph.dyn_map,
|
||||
))
|
||||
.finish();
|
||||
|
||||
@@ -161,7 +161,7 @@ macro_rules! metal_unary_op {
|
||||
}
|
||||
|
||||
impl<T> MetalKernel for $op_name<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
|
||||
vec![input_shapes[0].contiguous().n_elements() * size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
@@ -288,7 +288,7 @@ kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name}
|
||||
}
|
||||
|
||||
impl<T> MetalKernel for MetalAdd<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
|
||||
vec![input_shapes[0].n_elements() * size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
@@ -405,7 +405,7 @@ kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name}
|
||||
}
|
||||
}
|
||||
impl<T> MetalKernel for MetalMul<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
|
||||
vec![input_shapes[0].n_elements() * size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
@@ -532,7 +532,7 @@ kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name}
|
||||
}
|
||||
|
||||
impl<T> MetalKernel for MetalLessThan<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
|
||||
vec![input_shapes[0].n_elements() * size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
@@ -650,7 +650,7 @@ kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name}
|
||||
}
|
||||
}
|
||||
impl<T> MetalKernel for MetalMod<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
|
||||
vec![input_shapes[0].n_elements() * size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
@@ -785,7 +785,7 @@ kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *o
|
||||
}
|
||||
|
||||
impl<T> MetalKernel for MetalSumReduce<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
|
||||
let mut sh = input_shapes[0];
|
||||
sh.remove_dim(self.dim);
|
||||
vec![sh.n_elements() * size_of::<T>()]
|
||||
@@ -802,19 +802,19 @@ impl<T> MetalKernel for MetalSumReduce<T> {
|
||||
let inp_size = sh.n_elements().to_usize().unwrap();
|
||||
let front_size: usize = inputs[0]
|
||||
.1
|
||||
.shape()
|
||||
.dims()
|
||||
.iter()
|
||||
.take(self.dim)
|
||||
.map(|i| i.to_usize().unwrap())
|
||||
.product();
|
||||
let back_size: usize = inputs[0]
|
||||
.1
|
||||
.shape()
|
||||
.dims()
|
||||
.iter()
|
||||
.skip(self.dim + 1)
|
||||
.map(|i| i.to_usize().unwrap())
|
||||
.product();
|
||||
let dim_size = inputs[0].1.shape()[self.dim].to_usize().unwrap();
|
||||
let dim_size = inputs[0].1.dims()[self.dim].to_usize().unwrap();
|
||||
|
||||
let encoder =
|
||||
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
@@ -938,7 +938,7 @@ kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *o
|
||||
}
|
||||
}
|
||||
impl<T> MetalKernel for MetalMaxReduce<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
|
||||
let mut sh = input_shapes[0];
|
||||
sh.remove_dim(self.dim);
|
||||
vec![sh.n_elements() * size_of::<T>()]
|
||||
@@ -955,19 +955,19 @@ impl<T> MetalKernel for MetalMaxReduce<T> {
|
||||
let inp_size = sh.contiguous().n_elements().to_usize().unwrap();
|
||||
let front_size: usize = inputs[0]
|
||||
.1
|
||||
.shape()
|
||||
.dims()
|
||||
.iter()
|
||||
.take(self.dim)
|
||||
.map(|i| i.to_usize().unwrap())
|
||||
.product();
|
||||
let back_size: usize = inputs[0]
|
||||
.1
|
||||
.shape()
|
||||
.dims()
|
||||
.iter()
|
||||
.skip(self.dim + 1)
|
||||
.map(|i| i.to_usize().unwrap())
|
||||
.product();
|
||||
let dim_size = inputs[0].1.shape()[self.dim].to_usize().unwrap();
|
||||
let dim_size = inputs[0].1.dims()[self.dim].to_usize().unwrap();
|
||||
|
||||
let encoder =
|
||||
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
@@ -1052,9 +1052,10 @@ impl<T: MetalFloat + 'static> Compiler for PrimitiveCompiler<T> {
|
||||
> 0
|
||||
{
|
||||
// Copy outputs to device
|
||||
let sh = ShapeTracker::new(());
|
||||
let copy_node = graph
|
||||
.add_op(MetalCopyToDevice::<T>::new(dev.clone()))
|
||||
.input(function_node, 0, ShapeTracker::default())
|
||||
.input(function_node, 0, sh)
|
||||
.finish();
|
||||
|
||||
// Switch outgoing edges from input to copy_node
|
||||
|
||||
@@ -1,14 +1,4 @@
|
||||
use std::{
|
||||
any::Any,
|
||||
collections::hash_map::Entry,
|
||||
fs::File,
|
||||
io::Write,
|
||||
marker::PhantomData,
|
||||
mem::size_of,
|
||||
path::{Path, PathBuf},
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
};
|
||||
use std::{any::Any, marker::PhantomData, mem::size_of, sync::Arc};
|
||||
|
||||
use metal_rs::{
|
||||
objc::rc::autoreleasepool, Buffer, CommandBufferRef, CommandQueue, ComputePassDescriptor,
|
||||
@@ -20,20 +10,10 @@ use luminal::{
|
||||
op::{InputTensor, Operator},
|
||||
prelude::*,
|
||||
};
|
||||
use rustc_hash::FxHashMap;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use crate::{
|
||||
binary::MetalGather,
|
||||
compile_lib,
|
||||
elementwise_fusion::FusedElementwiseOp,
|
||||
get_buffer_from_tensor,
|
||||
matmul::Matmul,
|
||||
other::MetalARange,
|
||||
prim::{MetalConstant, MetalCopyFromDevice, MetalCopyToDevice, MetalMaxReduce, MetalSumReduce},
|
||||
select_function_from_lib,
|
||||
unary::{MetalMeanReduce, MetalStdNorm},
|
||||
MetalBuffer, MetalFloat, MetalKernel, MetalKernelWrapper,
|
||||
binary::MetalGather, get_buffer_from_tensor, MetalBuffer, MetalFloat, MetalKernel,
|
||||
MetalKernelWrapper,
|
||||
};
|
||||
|
||||
use super::{compile_function, SetInt};
|
||||
@@ -300,15 +280,15 @@ kernel void matvec(
|
||||
}
|
||||
|
||||
impl<T> MetalKernel for QuantizedMatmul<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
let m = input_shapes[0].shape()[input_shapes[0].len() - 2].clone();
|
||||
let n = input_shapes[1].shape()[input_shapes[1].len() - 1].clone();
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
|
||||
let m = input_shapes[0].dims()[input_shapes[0].len() - 2];
|
||||
let n = input_shapes[1].dims()[input_shapes[1].len() - 1];
|
||||
let batch_size = input_shapes[0]
|
||||
.shape()
|
||||
.dims()
|
||||
.into_iter()
|
||||
.take(input_shapes[0].len() - 2)
|
||||
.product::<BigExpression>()
|
||||
.max(BigExpression::from(1));
|
||||
.product::<Expression>()
|
||||
.max(1);
|
||||
vec![batch_size * m * n * size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
@@ -325,13 +305,13 @@ impl<T> MetalKernel for QuantizedMatmul<T> {
|
||||
let (a_shape, b_shape) = (
|
||||
inputs[0]
|
||||
.1
|
||||
.shape()
|
||||
.dims()
|
||||
.into_iter()
|
||||
.map(|i| i.to_usize().unwrap())
|
||||
.collect::<Vec<_>>(),
|
||||
inputs[1]
|
||||
.1
|
||||
.shape()
|
||||
.dims()
|
||||
.into_iter()
|
||||
.map(|i| i.to_usize().unwrap())
|
||||
.collect::<Vec<_>>(),
|
||||
@@ -393,7 +373,7 @@ impl<T: MetalFloat> Operator for QuantizedMatmul<T> {
|
||||
// Setup command queue / command buffer / encoder
|
||||
let command_buffer = self.queue.new_command_buffer();
|
||||
|
||||
let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
|
||||
let (a_shape, b_shape) = (inp[0].1.dims(), inp[1].1.dims());
|
||||
let n = b_shape[1].to_usize().unwrap();
|
||||
let (batch_size, m) = if a_shape.len() == 3 {
|
||||
(
|
||||
@@ -474,7 +454,7 @@ impl<T: MetalFloat> Operator for QuantizedGather<T> {
|
||||
// Setup buffers
|
||||
let indexes = tensors[0].0.borrowed().downcast_ref::<Vec<f32>>().unwrap();
|
||||
let index_buffer = self.device.new_buffer_with_data(
|
||||
unsafe { std::mem::transmute(indexes.as_ptr()) },
|
||||
indexes.as_ptr() as *const _,
|
||||
(indexes.len() * std::mem::size_of::<f32>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
@@ -570,297 +550,297 @@ impl<T: MetalFloat + Default> Compiler for MetalQuantizedCompiler<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct SerializeQuantizedGraph<T> {
|
||||
path: PathBuf,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
// #[derive(Default)]
|
||||
// pub struct SerializeQuantizedGraph<T> {
|
||||
// path: PathBuf,
|
||||
// _phantom: PhantomData<T>,
|
||||
// }
|
||||
|
||||
impl<T: MetalFloat> SerializeQuantizedGraph<T> {
|
||||
pub fn new(path: impl AsRef<Path>) -> Self {
|
||||
Self {
|
||||
path: path.as_ref().to_path_buf(),
|
||||
_phantom: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
// impl<T: MetalFloat> SerializeQuantizedGraph<T> {
|
||||
// pub fn new(path: impl AsRef<Path>) -> Self {
|
||||
// Self {
|
||||
// path: path.as_ref().to_path_buf(),
|
||||
// _phantom: Default::default(),
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
impl<T: MetalFloat + Default> Compiler for SerializeQuantizedGraph<T> {
|
||||
type Output = ();
|
||||
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut nodes: To) {
|
||||
// Serialize and save graph away
|
||||
let mut ops = vec![];
|
||||
for node in graph.node_indices().collect::<Vec<_>>() {
|
||||
if graph.check_node_type::<Function>(node) {
|
||||
continue;
|
||||
}
|
||||
let data = if let Some(op) = graph.try_get_op::<MetalMeanReduce<T>>(node) {
|
||||
json!({
|
||||
"dim": op.3,
|
||||
"shape": op.7,
|
||||
})
|
||||
} else if let Some(op) = graph.try_get_op::<MetalARange<T>>(node) {
|
||||
json!({
|
||||
"size": op.size,
|
||||
})
|
||||
} else if let Some(op) = graph.try_get_op::<MetalStdNorm<T>>(node) {
|
||||
json!({
|
||||
"eps": op.epsilon,
|
||||
})
|
||||
} else if let Some(op) = graph.try_get_op::<MetalSumReduce<T>>(node) {
|
||||
json!({
|
||||
"dim": op.dim,
|
||||
"shape": op.shape,
|
||||
})
|
||||
} else if let Some(op) = graph.try_get_op::<MetalMaxReduce<T>>(node) {
|
||||
json!({
|
||||
"dim": op.dim,
|
||||
"shape": op.shape,
|
||||
})
|
||||
} else if let Some(op) = graph.try_get_op::<QuantizedGather<T>>(node) {
|
||||
json!({
|
||||
"embed_dim": op.embed_dim,
|
||||
})
|
||||
} else if let Some(op) = graph.try_get_op::<MetalGather<T>>(node) {
|
||||
json!({
|
||||
"embed_dim": op.embed_dim,
|
||||
})
|
||||
} else if let Some(op) = graph.try_get_op::<MetalConstant<T>>(node) {
|
||||
json!({
|
||||
"value": op.0,
|
||||
})
|
||||
} else if let Some(op) = graph.try_get_op::<Matmul<T>>(node) {
|
||||
json!({
|
||||
"matmul_kernel": op.matmul_kernel,
|
||||
"matvec_kernel": op.matvec_kernel,
|
||||
})
|
||||
} else if let Some(op) = graph.try_get_op::<FusedElementwiseOp<T>>(node) {
|
||||
json!({
|
||||
"kernel_str": op.kernel_str,
|
||||
"dyn_chars": op.dyn_chars,
|
||||
"output_buffer_sizes": op.output_buffer_sizes,
|
||||
})
|
||||
} else if graph.check_node_type::<QuantizedMatmul<T>>(node)
|
||||
|| graph.check_node_type::<MetalCopyFromDevice<T>>(node)
|
||||
|| graph.check_node_type::<MetalCopyToDevice<T>>(node)
|
||||
{
|
||||
json!({})
|
||||
} else {
|
||||
panic!(
|
||||
"Found unserializable op: {:?}",
|
||||
graph.node_weight(node).unwrap()
|
||||
);
|
||||
};
|
||||
ops.push(json!({
|
||||
"type": format!("{:?}", graph.node_weight(node).unwrap()),
|
||||
"id": node.index(),
|
||||
"data": data
|
||||
}));
|
||||
}
|
||||
let edges = graph
|
||||
.edge_indices()
|
||||
.map(|e| (e, graph.edge_endpoints(e).unwrap()))
|
||||
.map(|(e, (a, b))| (a.index(), b.index(), *graph.edge_weight(e).unwrap()))
|
||||
.collect::<Vec<_>>();
|
||||
let value = json!({
|
||||
"nodes": nodes.to_ids_mut().iter().map(|i| i.index()).collect::<Vec<_>>(),
|
||||
"ops": ops,
|
||||
"edges": edges,
|
||||
"no_delete": graph.no_delete.iter().map(|i| i.index()).collect::<Vec<_>>(),
|
||||
"to_retrieve": graph.to_retrieve.iter().map(|(k, v)| (k.index(), v)).collect::<Vec<_>>(),
|
||||
});
|
||||
File::create(&self.path)
|
||||
.unwrap()
|
||||
.write_all(value.to_string().as_bytes())
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
// impl<T: MetalFloat + Default> Compiler for SerializeQuantizedGraph<T> {
|
||||
// type Output = ();
|
||||
// fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut nodes: To) {
|
||||
// // Serialize and save graph away
|
||||
// let mut ops = vec![];
|
||||
// for node in graph.node_indices().collect::<Vec<_>>() {
|
||||
// if graph.check_node_type::<Function>(node) {
|
||||
// continue;
|
||||
// }
|
||||
// let data = if let Some(op) = graph.try_get_op::<MetalMeanReduce<T>>(node) {
|
||||
// json!({
|
||||
// "dim": op.3,
|
||||
// "shape": op.7,
|
||||
// })
|
||||
// } else if let Some(op) = graph.try_get_op::<MetalARange<T>>(node) {
|
||||
// json!({
|
||||
// "size": op.size,
|
||||
// })
|
||||
// } else if let Some(op) = graph.try_get_op::<MetalStdNorm<T>>(node) {
|
||||
// json!({
|
||||
// "eps": op.epsilon,
|
||||
// })
|
||||
// } else if let Some(op) = graph.try_get_op::<MetalSumReduce<T>>(node) {
|
||||
// json!({
|
||||
// "dim": op.dim,
|
||||
// "shape": op.shape,
|
||||
// })
|
||||
// } else if let Some(op) = graph.try_get_op::<MetalMaxReduce<T>>(node) {
|
||||
// json!({
|
||||
// "dim": op.dim,
|
||||
// "shape": op.shape,
|
||||
// })
|
||||
// } else if let Some(op) = graph.try_get_op::<QuantizedGather<T>>(node) {
|
||||
// json!({
|
||||
// "embed_dim": op.embed_dim,
|
||||
// })
|
||||
// } else if let Some(op) = graph.try_get_op::<MetalGather<T>>(node) {
|
||||
// json!({
|
||||
// "embed_dim": op.embed_dim,
|
||||
// })
|
||||
// } else if let Some(op) = graph.try_get_op::<MetalConstant<T>>(node) {
|
||||
// json!({
|
||||
// "value": op.0,
|
||||
// })
|
||||
// } else if let Some(op) = graph.try_get_op::<Matmul<T>>(node) {
|
||||
// json!({
|
||||
// "matmul_kernel": op.matmul_kernel,
|
||||
// "matvec_kernel": op.matvec_kernel,
|
||||
// })
|
||||
// } else if let Some(op) = graph.try_get_op::<FusedElementwiseOp<T>>(node) {
|
||||
// json!({
|
||||
// "kernel_str": op.kernel_str,
|
||||
// "dyn_chars": op.dyn_chars,
|
||||
// "output_buffer_sizes": op.output_buffer_sizes,
|
||||
// })
|
||||
// } else if graph.check_node_type::<QuantizedMatmul<T>>(node)
|
||||
// || graph.check_node_type::<MetalCopyFromDevice<T>>(node)
|
||||
// || graph.check_node_type::<MetalCopyToDevice<T>>(node)
|
||||
// {
|
||||
// json!({})
|
||||
// } else {
|
||||
// panic!(
|
||||
// "Found unserializable op: {:?}",
|
||||
// graph.node_weight(node).unwrap()
|
||||
// );
|
||||
// };
|
||||
// ops.push(json!({
|
||||
// "type": format!("{:?}", graph.node_weight(node).unwrap()),
|
||||
// "id": node.index(),
|
||||
// "data": data
|
||||
// }));
|
||||
// }
|
||||
// let edges = graph
|
||||
// .edge_indices()
|
||||
// .map(|e| (e, graph.edge_endpoints(e).unwrap()))
|
||||
// .map(|(e, (a, b))| (a.index(), b.index(), *graph.edge_weight(e).unwrap()))
|
||||
// .collect::<Vec<_>>();
|
||||
// let value = json!({
|
||||
// "nodes": nodes.to_ids_mut().iter().map(|i| i.index()).collect::<Vec<_>>(),
|
||||
// "ops": ops,
|
||||
// "edges": edges,
|
||||
// "no_delete": graph.no_delete.iter().map(|i| i.index()).collect::<Vec<_>>(),
|
||||
// "to_retrieve": graph.to_retrieve.iter().map(|(k, v)| (k.index(), v)).collect::<Vec<_>>(),
|
||||
// });
|
||||
// File::create(&self.path)
|
||||
// .unwrap()
|
||||
// .write_all(value.to_string().as_bytes())
|
||||
// .unwrap();
|
||||
// }
|
||||
// }
|
||||
|
||||
/// Deserialize a metal graph
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DeserializeQuantizedGraph<T>(String, PhantomData<T>);
|
||||
// /// Deserialize a metal graph
|
||||
// #[derive(Debug, Clone)]
|
||||
// pub struct DeserializeQuantizedGraph<T>(String, PhantomData<T>);
|
||||
|
||||
impl<T> DeserializeQuantizedGraph<T> {
|
||||
pub fn new(data: impl ToString) -> Self {
|
||||
Self(data.to_string(), Default::default())
|
||||
}
|
||||
}
|
||||
// impl<T> DeserializeQuantizedGraph<T> {
|
||||
// pub fn new(data: impl ToString) -> Self {
|
||||
// Self(data.to_string(), Default::default())
|
||||
// }
|
||||
// }
|
||||
|
||||
impl<T: MetalFloat> Compiler for DeserializeQuantizedGraph<T> {
|
||||
type Output = ();
|
||||
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut ids: To) -> Self::Output {
|
||||
let mut value = Value::from_str(&self.0).unwrap();
|
||||
let dev = Device::system_default().unwrap();
|
||||
let queue = dev.new_command_queue();
|
||||
// Create ops
|
||||
let mut op_map = FxHashMap::<usize, NodeIndex>::default();
|
||||
let matmul_library = compile_lib(&dev, include_str!("kernels/gemm.metal"));
|
||||
let matvec_library = compile_lib(&dev, include_str!("kernels/gemv.metal"));
|
||||
let quantized_matmul = QuantizedMatmul::<T>::new(dev.clone(), queue.clone());
|
||||
for op in value["ops"].as_array_mut().unwrap() {
|
||||
let name = op["type"].as_str().unwrap().to_string();
|
||||
let new_id = if name == "MetalMeanReduce" {
|
||||
graph
|
||||
.add_op(MetalMeanReduce::<T>::new(
|
||||
dev.clone(),
|
||||
queue.clone(),
|
||||
op["data"]["dim"].as_u64().unwrap() as usize,
|
||||
serde_json::from_value(op["data"]["shape"].take()).unwrap(),
|
||||
&graph.dyn_map,
|
||||
))
|
||||
.finish()
|
||||
} else if name == "MetalSumReduce" {
|
||||
graph
|
||||
.add_op(MetalSumReduce::<T>::new(
|
||||
serde_json::from_value(op["data"]["shape"].take()).unwrap(),
|
||||
op["data"]["dim"].as_u64().unwrap() as usize,
|
||||
dev.clone(),
|
||||
queue.clone(),
|
||||
&graph.dyn_map,
|
||||
))
|
||||
.finish()
|
||||
} else if name == "MetalMaxReduce" {
|
||||
graph
|
||||
.add_op(MetalMaxReduce::<T>::new(
|
||||
serde_json::from_value(op["data"]["shape"].take()).unwrap(),
|
||||
op["data"]["dim"].as_u64().unwrap() as usize,
|
||||
dev.clone(),
|
||||
queue.clone(),
|
||||
&graph.dyn_map,
|
||||
))
|
||||
.finish()
|
||||
} else if name == "MetalStdNorm" {
|
||||
graph
|
||||
.add_op(MetalStdNorm::<T>::new(
|
||||
op["data"]["eps"].as_f64().unwrap() as f32,
|
||||
dev.clone(),
|
||||
queue.clone(),
|
||||
))
|
||||
.finish()
|
||||
} else if name.contains("MetalARange") {
|
||||
graph
|
||||
.add_op(MetalARange::<T>::new(
|
||||
dev.clone(),
|
||||
queue.clone(),
|
||||
serde_json::from_value(op["data"]["size"].take()).unwrap(),
|
||||
&graph.dyn_map,
|
||||
))
|
||||
.finish()
|
||||
} else if name == "QuantizedGather" {
|
||||
graph
|
||||
.add_op(QuantizedGather::<T>::new(
|
||||
dev.clone(),
|
||||
queue.clone(),
|
||||
op["data"]["embed_dim"].as_u64().unwrap() as usize,
|
||||
))
|
||||
.finish()
|
||||
} else if name == "MetalGather" {
|
||||
graph
|
||||
.add_op(MetalGather::<T>::new(
|
||||
dev.clone(),
|
||||
queue.clone(),
|
||||
op["data"]["embed_dim"].as_u64().unwrap() as usize,
|
||||
))
|
||||
.finish()
|
||||
} else if name.contains("MetalConstant") {
|
||||
graph
|
||||
.add_op(MetalConstant::<T>(
|
||||
serde_json::from_value(op["data"]["value"].take()).unwrap(),
|
||||
dev.clone(),
|
||||
&graph.dyn_map,
|
||||
Default::default(),
|
||||
))
|
||||
.finish()
|
||||
} else if name == "FusedElementwiseOp" {
|
||||
let mut fused_op = FusedElementwiseOp::<T> {
|
||||
kernel: None,
|
||||
dyn_map: &graph.dyn_map,
|
||||
kernel_str: op["data"]["kernel_str"].as_str().unwrap().to_string(),
|
||||
dyn_chars: serde_json::from_value(op["data"]["dyn_chars"].take()).unwrap(),
|
||||
subexpressions: vec![],
|
||||
queue: queue.clone(),
|
||||
device: dev.clone(),
|
||||
output_buffer_sizes: serde_json::from_value(
|
||||
op["data"]["output_buffer_sizes"].take(),
|
||||
)
|
||||
.unwrap(),
|
||||
_phantom: Default::default(),
|
||||
};
|
||||
fused_op.compile(&dev);
|
||||
graph.add_op(fused_op).finish()
|
||||
} else if name == "Matmul" {
|
||||
let matmul_kernel =
|
||||
serde_json::from_value::<String>(op["data"]["matmul_kernel"].take()).unwrap();
|
||||
let matvec_kernel =
|
||||
serde_json::from_value::<String>(op["data"]["matvec_kernel"].take()).unwrap();
|
||||
graph
|
||||
.add_op(Matmul::<T> {
|
||||
matmul_pipeline: select_function_from_lib(
|
||||
&matmul_library,
|
||||
&matmul_kernel,
|
||||
&dev,
|
||||
),
|
||||
matvec_pipeline: select_function_from_lib(
|
||||
&matvec_library,
|
||||
&matvec_kernel,
|
||||
&dev,
|
||||
),
|
||||
matmul_kernel,
|
||||
matvec_kernel,
|
||||
queue: queue.clone(),
|
||||
device: dev.clone(),
|
||||
_phantom: Default::default(),
|
||||
})
|
||||
.finish()
|
||||
} else if name == "QuantizedMatmul" {
|
||||
graph.add_op(quantized_matmul.clone()).finish()
|
||||
} else if name == "MetalCopyToDevice" {
|
||||
graph
|
||||
.add_op(MetalCopyToDevice::<T>::new(dev.clone()))
|
||||
.finish()
|
||||
} else if name == "MetalCopyFromDevice" {
|
||||
graph.add_op(MetalCopyFromDevice::<T>::default()).finish()
|
||||
} else {
|
||||
panic!("Found unexpected serialized op: {name}");
|
||||
};
|
||||
op_map.insert(op["id"].as_u64().unwrap() as usize, new_id);
|
||||
}
|
||||
// Remap nodes that are in the op_map
|
||||
for (saved, mut new) in value["nodes"]
|
||||
.as_array()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.zip(ids.to_ids_mut())
|
||||
{
|
||||
let saved = saved.as_u64().unwrap() as usize;
|
||||
if let Entry::Vacant(e) = op_map.entry(saved) {
|
||||
e.insert(*new);
|
||||
} else {
|
||||
graph.remove_node(*new);
|
||||
remap(*new, op_map[&saved], &mut new, graph);
|
||||
}
|
||||
}
|
||||
// Create edges
|
||||
let edges =
|
||||
serde_json::from_value::<Vec<(usize, usize, Dependency)>>(value["edges"].take())
|
||||
.unwrap();
|
||||
for (a, b, dep) in edges {
|
||||
graph.add_edge(op_map[&a], op_map[&b], dep);
|
||||
}
|
||||
// Update no_delete and to_retrieve
|
||||
graph.no_delete = serde_json::from_value::<Vec<usize>>(value["no_delete"].take())
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|i| op_map[&i])
|
||||
.collect();
|
||||
graph.to_retrieve =
|
||||
serde_json::from_value::<Vec<(usize, (u8, ShapeTracker))>>(value["to_retrieve"].take())
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|(k, v)| (op_map[&k], v))
|
||||
.collect();
|
||||
}
|
||||
}
|
||||
// impl<T: MetalFloat> Compiler for DeserializeQuantizedGraph<T> {
|
||||
// type Output = ();
|
||||
// fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut ids: To) -> Self::Output {
|
||||
// let mut value = Value::from_str(&self.0).unwrap();
|
||||
// let dev = Device::system_default().unwrap();
|
||||
// let queue = dev.new_command_queue();
|
||||
// // Create ops
|
||||
// let mut op_map = FxHashMap::<usize, NodeIndex>::default();
|
||||
// let matmul_library = compile_lib(&dev, include_str!("kernels/gemm.metal"));
|
||||
// let matvec_library = compile_lib(&dev, include_str!("kernels/gemv.metal"));
|
||||
// let quantized_matmul = QuantizedMatmul::<T>::new(dev.clone(), queue.clone());
|
||||
// for op in value["ops"].as_array_mut().unwrap() {
|
||||
// let name = op["type"].as_str().unwrap().to_string();
|
||||
// let new_id = if name == "MetalMeanReduce" {
|
||||
// graph
|
||||
// .add_op(MetalMeanReduce::<T>::new(
|
||||
// dev.clone(),
|
||||
// queue.clone(),
|
||||
// op["data"]["dim"].as_u64().unwrap() as usize,
|
||||
// serde_json::from_value(op["data"]["shape"].take()).unwrap(),
|
||||
// &graph.dyn_map,
|
||||
// ))
|
||||
// .finish()
|
||||
// } else if name == "MetalSumReduce" {
|
||||
// graph
|
||||
// .add_op(MetalSumReduce::<T>::new(
|
||||
// serde_json::from_value(op["data"]["shape"].take()).unwrap(),
|
||||
// op["data"]["dim"].as_u64().unwrap() as usize,
|
||||
// dev.clone(),
|
||||
// queue.clone(),
|
||||
// &graph.dyn_map,
|
||||
// ))
|
||||
// .finish()
|
||||
// } else if name == "MetalMaxReduce" {
|
||||
// graph
|
||||
// .add_op(MetalMaxReduce::<T>::new(
|
||||
// serde_json::from_value(op["data"]["shape"].take()).unwrap(),
|
||||
// op["data"]["dim"].as_u64().unwrap() as usize,
|
||||
// dev.clone(),
|
||||
// queue.clone(),
|
||||
// &graph.dyn_map,
|
||||
// ))
|
||||
// .finish()
|
||||
// } else if name == "MetalStdNorm" {
|
||||
// graph
|
||||
// .add_op(MetalStdNorm::<T>::new(
|
||||
// op["data"]["eps"].as_f64().unwrap() as f32,
|
||||
// dev.clone(),
|
||||
// queue.clone(),
|
||||
// ))
|
||||
// .finish()
|
||||
// } else if name.contains("MetalARange") {
|
||||
// graph
|
||||
// .add_op(MetalARange::<T>::new(
|
||||
// dev.clone(),
|
||||
// queue.clone(),
|
||||
// serde_json::from_value(op["data"]["size"].take()).unwrap(),
|
||||
// &graph.dyn_map,
|
||||
// ))
|
||||
// .finish()
|
||||
// } else if name == "QuantizedGather" {
|
||||
// graph
|
||||
// .add_op(QuantizedGather::<T>::new(
|
||||
// dev.clone(),
|
||||
// queue.clone(),
|
||||
// op["data"]["embed_dim"].as_u64().unwrap() as usize,
|
||||
// ))
|
||||
// .finish()
|
||||
// } else if name == "MetalGather" {
|
||||
// graph
|
||||
// .add_op(MetalGather::<T>::new(
|
||||
// dev.clone(),
|
||||
// queue.clone(),
|
||||
// op["data"]["embed_dim"].as_u64().unwrap() as usize,
|
||||
// ))
|
||||
// .finish()
|
||||
// } else if name.contains("MetalConstant") {
|
||||
// graph
|
||||
// .add_op(MetalConstant::<T>(
|
||||
// serde_json::from_value(op["data"]["value"].take()).unwrap(),
|
||||
// dev.clone(),
|
||||
// &graph.dyn_map,
|
||||
// Default::default(),
|
||||
// ))
|
||||
// .finish()
|
||||
// } else if name == "FusedElementwiseOp" {
|
||||
// let mut fused_op = FusedElementwiseOp::<T> {
|
||||
// kernel: None,
|
||||
// dyn_map: &graph.dyn_map,
|
||||
// kernel_str: op["data"]["kernel_str"].as_str().unwrap().to_string(),
|
||||
// dyn_chars: serde_json::from_value(op["data"]["dyn_chars"].take()).unwrap(),
|
||||
// subexpressions: vec![],
|
||||
// queue: queue.clone(),
|
||||
// device: dev.clone(),
|
||||
// output_buffer_sizes: serde_json::from_value(
|
||||
// op["data"]["output_buffer_sizes"].take(),
|
||||
// )
|
||||
// .unwrap(),
|
||||
// _phantom: Default::default(),
|
||||
// };
|
||||
// fused_op.compile(&dev);
|
||||
// graph.add_op(fused_op).finish()
|
||||
// } else if name == "Matmul" {
|
||||
// let matmul_kernel =
|
||||
// serde_json::from_value::<String>(op["data"]["matmul_kernel"].take()).unwrap();
|
||||
// let matvec_kernel =
|
||||
// serde_json::from_value::<String>(op["data"]["matvec_kernel"].take()).unwrap();
|
||||
// graph
|
||||
// .add_op(Matmul::<T> {
|
||||
// matmul_pipeline: select_function_from_lib(
|
||||
// &matmul_library,
|
||||
// &matmul_kernel,
|
||||
// &dev,
|
||||
// ),
|
||||
// matvec_pipeline: select_function_from_lib(
|
||||
// &matvec_library,
|
||||
// &matvec_kernel,
|
||||
// &dev,
|
||||
// ),
|
||||
// matmul_kernel,
|
||||
// matvec_kernel,
|
||||
// queue: queue.clone(),
|
||||
// device: dev.clone(),
|
||||
// _phantom: Default::default(),
|
||||
// })
|
||||
// .finish()
|
||||
// } else if name == "QuantizedMatmul" {
|
||||
// graph.add_op(quantized_matmul.clone()).finish()
|
||||
// } else if name == "MetalCopyToDevice" {
|
||||
// graph
|
||||
// .add_op(MetalCopyToDevice::<T>::new(dev.clone()))
|
||||
// .finish()
|
||||
// } else if name == "MetalCopyFromDevice" {
|
||||
// graph.add_op(MetalCopyFromDevice::<T>::default()).finish()
|
||||
// } else {
|
||||
// panic!("Found unexpected serialized op: {name}");
|
||||
// };
|
||||
// op_map.insert(op["id"].as_u64().unwrap() as usize, new_id);
|
||||
// }
|
||||
// // Remap nodes that are in the op_map
|
||||
// for (saved, mut new) in value["nodes"]
|
||||
// .as_array()
|
||||
// .unwrap()
|
||||
// .iter()
|
||||
// .zip(ids.to_ids_mut())
|
||||
// {
|
||||
// let saved = saved.as_u64().unwrap() as usize;
|
||||
// if let Entry::Vacant(e) = op_map.entry(saved) {
|
||||
// e.insert(*new);
|
||||
// } else {
|
||||
// graph.remove_node(*new);
|
||||
// remap(*new, op_map[&saved], &mut new, graph);
|
||||
// }
|
||||
// }
|
||||
// // Create edges
|
||||
// let edges =
|
||||
// serde_json::from_value::<Vec<(usize, usize, Dependency)>>(value["edges"].take())
|
||||
// .unwrap();
|
||||
// for (a, b, dep) in edges {
|
||||
// graph.add_edge(op_map[&a], op_map[&b], dep);
|
||||
// }
|
||||
// // Update no_delete and to_retrieve
|
||||
// graph.no_delete = serde_json::from_value::<Vec<usize>>(value["no_delete"].take())
|
||||
// .unwrap()
|
||||
// .into_iter()
|
||||
// .map(|i| op_map[&i])
|
||||
// .collect();
|
||||
// graph.to_retrieve =
|
||||
// serde_json::from_value::<Vec<(usize, (u8, ShapeTracker))>>(value["to_retrieve"].take())
|
||||
// .unwrap()
|
||||
// .into_iter()
|
||||
// .map(|(k, v)| (op_map[&k], v))
|
||||
// .collect();
|
||||
// }
|
||||
// }
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
@@ -299,7 +299,7 @@ fn btreeset_intersection<T: Ord>(mut a: BTreeSet<T>, b: &BTreeSet<T>) -> BTreeSe
|
||||
struct AllocateMetalBuffers {
|
||||
dev: Device,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
buffer_sizes: Vec<BigExpression>,
|
||||
buffer_sizes: Vec<Expression>,
|
||||
buffers: Arc<UnsafeCell<Vec<Buffer>>>,
|
||||
}
|
||||
impl Debug for AllocateMetalBuffers {
|
||||
|
||||
@@ -80,8 +80,8 @@ fn test_rotate() {
|
||||
const D: usize = 4;
|
||||
let data = random_vec(D * B * F);
|
||||
let a = cx.tensor((F, B, D)).set(data.clone()).permute((1, 0, 2));
|
||||
let x1 = a.slice((.., .., ..Expression::from(D / 2)));
|
||||
let x2 = a.slice((.., .., Expression::from(D / 2)..));
|
||||
let x1 = a.slice((.., .., ..D / 2));
|
||||
let x2 = a.slice((.., .., D / 2..));
|
||||
let mut rotated_a = (-x2).concat_along(x1, 1).retrieve();
|
||||
cx.execute();
|
||||
let unopt = rotated_a.data();
|
||||
@@ -709,7 +709,7 @@ fn test_slice() {
|
||||
let data = random_vec(256);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(256).set(data.clone());
|
||||
let mut c = a.slice(..Expression::from(20)).contiguous().retrieve();
|
||||
let mut c = a.slice(..20).contiguous().retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f16>::default(), &mut c);
|
||||
cx.execute();
|
||||
@@ -752,10 +752,10 @@ fn test_pad_contig() {
|
||||
let a_data = random_vec_rng(m * k, &mut rng);
|
||||
let mut a = cx.tensor(('M', 'K')).set_dyn(a_data, (m, k)).retrieve();
|
||||
let mut b = a
|
||||
.pad(&[(0, 0.into()), (0, Expression::from(24) - 'K')])
|
||||
.pad(((0, 0), (0, Expression::from(24) - 'K')))
|
||||
.contiguous()
|
||||
.retrieve();
|
||||
let mut c = (a.slice((.., ..Expression::from(k))) / 1.0).retrieve();
|
||||
let mut c = (a.slice((.., ..k)) / 1.0).retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f16>::default(), (&mut a, &mut b, &mut c));
|
||||
cx.execute();
|
||||
@@ -771,7 +771,7 @@ fn test_movement() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(32).set(data.clone());
|
||||
let b = a.pad((0, 10)).contiguous().retrieve();
|
||||
let mut c = b.slice((..Expression::from(25),)).contiguous().retrieve();
|
||||
let mut c = b.slice((..25,)).contiguous().retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f16>::default(), &mut c);
|
||||
cx.execute();
|
||||
|
||||
@@ -83,7 +83,7 @@ kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *o
|
||||
}
|
||||
|
||||
impl<T> MetalKernel for MetalMeanReduce<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
|
||||
let mut sh = input_shapes[0];
|
||||
sh.remove_dim(self.3);
|
||||
vec![sh.n_elements() * size_of::<T>()]
|
||||
@@ -101,19 +101,19 @@ impl<T> MetalKernel for MetalMeanReduce<T> {
|
||||
|
||||
let front_size: usize = inputs[0]
|
||||
.1
|
||||
.shape()
|
||||
.dims()
|
||||
.iter()
|
||||
.take(self.3)
|
||||
.map(|i| i.to_usize().unwrap())
|
||||
.product();
|
||||
let back_size: usize = inputs[0]
|
||||
.1
|
||||
.shape()
|
||||
.dims()
|
||||
.iter()
|
||||
.skip(self.3 + 1)
|
||||
.map(|i| i.to_usize().unwrap())
|
||||
.product();
|
||||
let dim_size = inputs[0].1.shape()[self.3].to_usize().unwrap();
|
||||
let dim_size = inputs[0].1.dims()[self.3].to_usize().unwrap();
|
||||
|
||||
let encoder =
|
||||
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
@@ -306,7 +306,7 @@ kernel void kernel_std_norm(
|
||||
}
|
||||
|
||||
impl<T> MetalKernel for MetalStdNorm<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
|
||||
vec![input_shapes[0].n_elements() * size_of::<T>()]
|
||||
}
|
||||
|
||||
@@ -320,7 +320,7 @@ impl<T> MetalKernel for MetalStdNorm<T> {
|
||||
let encoder =
|
||||
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
encoder.set_compute_pipeline_state(&self.pipeline);
|
||||
let row_size = inputs[0].1.shape().last().unwrap().to_usize().unwrap();
|
||||
let row_size = inputs[0].1.dims().last().unwrap().to_usize().unwrap();
|
||||
|
||||
// Set inputs
|
||||
encoder.set_buffer(0, Some(inputs[0].0), 0);
|
||||
@@ -329,7 +329,7 @@ impl<T> MetalKernel for MetalStdNorm<T> {
|
||||
encoder.set_f32(3, self.epsilon);
|
||||
let batch_size = inputs[0]
|
||||
.1
|
||||
.shape()
|
||||
.dims()
|
||||
.into_iter()
|
||||
.take(inputs[0].1.len() - 1)
|
||||
.map(|i| i.to_usize().unwrap())
|
||||
@@ -432,7 +432,7 @@ impl<T: MetalFloat> Compiler for StdNormCompiler<T> {
|
||||
}
|
||||
}
|
||||
if sh
|
||||
.shape()
|
||||
.dims()
|
||||
.last()
|
||||
.unwrap()
|
||||
.to_usize()
|
||||
@@ -515,7 +515,7 @@ kernel void kernel_metal_exp(device {type_name} *inp [[buffer(0)]], device {type
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> MetalKernel for MetalExp<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
|
||||
vec![input_shapes[0].n_elements() * std::mem::size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
@@ -653,7 +653,7 @@ kernel void kernel_metal_cos(device {type_name} *inp [[buffer(0)]], device {type
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> MetalKernel for MetalCos<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
|
||||
vec![input_shapes[0].n_elements() * std::mem::size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
|
||||
@@ -48,11 +48,10 @@ impl SerializeModule for Conv1D {
|
||||
}
|
||||
}
|
||||
|
||||
// Single
|
||||
impl Module<GraphTensor> for Conv1D {
|
||||
type Output = GraphTensor;
|
||||
fn forward(&self, input: GraphTensor) -> Self::Output {
|
||||
assert_eq!(input.shape()[input.shape.len() - 2], self.ch_in);
|
||||
assert_eq!(input.dims()[input.shape.len() - 2], self.ch_in);
|
||||
// Input: batch_dims, ch_in, dim_in
|
||||
// Reshape to 2 batch dims
|
||||
let n_expands = 4 - input.shape.len();
|
||||
@@ -61,9 +60,9 @@ impl Module<GraphTensor> for Conv1D {
|
||||
inp = inp.expand(0, 1);
|
||||
}
|
||||
|
||||
let batch1 = inp.shape()[0].small();
|
||||
let batch2 = inp.shape()[1].small();
|
||||
let dim_in = input.shape().last().unwrap().small();
|
||||
let batch1 = inp.dims()[0];
|
||||
let batch2 = inp.dims()[1];
|
||||
let dim_in = *input.dims().last().unwrap();
|
||||
let dim_out = (((dim_in + 2 * self.padding - self.dilation * (self.kernel - 1) - 1)
|
||||
/ self.stride)
|
||||
+ 1)
|
||||
@@ -86,7 +85,7 @@ impl Module<GraphTensor> for Conv1D {
|
||||
}
|
||||
|
||||
// Reshape back to original shape
|
||||
let mut final_shape = out.shape();
|
||||
let mut final_shape = out.dims();
|
||||
for _ in 0..n_expands {
|
||||
final_shape.remove(0);
|
||||
}
|
||||
@@ -96,6 +95,7 @@ impl Module<GraphTensor> for Conv1D {
|
||||
|
||||
pub struct Conv2D {
|
||||
pub weight: GraphTensor, // ch_out, ch_in * kernel_x * kernel_y
|
||||
pub bias: Option<GraphTensor>, // ch_out
|
||||
kernel: (usize, usize),
|
||||
stride: (usize, usize),
|
||||
dilation: (usize, usize),
|
||||
@@ -110,10 +110,16 @@ impl Conv2D {
|
||||
kernel: (usize, usize),
|
||||
stride: (usize, usize),
|
||||
dilation: (usize, usize),
|
||||
bias: bool,
|
||||
cx: &mut Graph,
|
||||
) -> Self {
|
||||
Self {
|
||||
weight: cx.named_tensor("Weight", (ch_out, ch_in * kernel.0 * kernel.1)),
|
||||
bias: if bias {
|
||||
Some(cx.named_tensor("Bias", ch_out))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
kernel,
|
||||
stride,
|
||||
dilation,
|
||||
@@ -126,15 +132,22 @@ impl Conv2D {
|
||||
impl SerializeModule for Conv2D {
|
||||
fn serialize(&self, s: &mut luminal::module::Serializer) {
|
||||
s.tensor("weight", self.weight);
|
||||
if let Some(bias) = self.bias {
|
||||
s.tensor("bias", bias);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Single
|
||||
impl Conv2D {
|
||||
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
|
||||
// Input: (ch_in, dimx_in, dimy_in)
|
||||
let dimx_in = input.shape()[1].small();
|
||||
let dimy_in = input.shape()[2].small();
|
||||
pub fn forward(&self, mut input: GraphTensor) -> GraphTensor {
|
||||
// Input: (batch (optional), ch_in, dimx_in, dimy_in)
|
||||
let mut expanded = false;
|
||||
if input.shape.len() == 3 {
|
||||
// Expand batch
|
||||
input = input.expand(0, 1);
|
||||
expanded = true;
|
||||
}
|
||||
let (batch, _, dimx_in, dimy_in) = input.dims4();
|
||||
let dimx_out = (((dimx_in - self.dilation.0 * (self.kernel.0 - 1) - 1) / self.stride.0)
|
||||
+ 1)
|
||||
.simplify();
|
||||
@@ -143,21 +156,34 @@ impl Conv2D {
|
||||
.simplify();
|
||||
let input_pooled = input
|
||||
.pool_last_dim(self.kernel.1, self.stride.1, self.dilation.1)
|
||||
.permute((0, 2, 3, 1))
|
||||
.permute((0, 1, 3, 4, 2))
|
||||
.pool_last_dim(self.kernel.0, self.stride.0, self.dilation.0)
|
||||
.permute((0, 4, 2, 3, 1))
|
||||
.permute((0, 1, 5, 3, 4, 2))
|
||||
.reshape((
|
||||
batch,
|
||||
self.ch_in * self.kernel.0 * self.kernel.1,
|
||||
dimx_out * dimy_out,
|
||||
));
|
||||
|
||||
self.weight
|
||||
.matmul(input_pooled)
|
||||
.reshape((self.ch_out, dimx_out, dimy_out))
|
||||
let mut o = self.weight.expand(0, batch).matmul(input_pooled).reshape((
|
||||
batch,
|
||||
self.ch_out,
|
||||
dimx_out,
|
||||
dimy_out,
|
||||
));
|
||||
if let Some(b) = self.bias {
|
||||
o += b.expand_to(o.shape);
|
||||
}
|
||||
if expanded {
|
||||
o.reshape((self.ch_out, dimx_out, dimy_out))
|
||||
} else {
|
||||
o
|
||||
}
|
||||
}
|
||||
}
|
||||
pub struct Conv3D {
|
||||
pub weight: GraphTensor, // ch_out, ch_in * kernel_x * kernel_y * kernel_z
|
||||
pub bias: Option<GraphTensor>, // ch_out
|
||||
kernel: (usize, usize, usize),
|
||||
stride: (usize, usize, usize),
|
||||
dilation: (usize, usize, usize),
|
||||
@@ -172,10 +198,16 @@ impl Conv3D {
|
||||
kernel: (usize, usize, usize),
|
||||
stride: (usize, usize, usize),
|
||||
dilation: (usize, usize, usize),
|
||||
bias: bool,
|
||||
cx: &mut Graph,
|
||||
) -> Self {
|
||||
Self {
|
||||
weight: cx.named_tensor("Weight", (ch_out, ch_in * kernel.0 * kernel.1 * kernel.2)),
|
||||
bias: if bias {
|
||||
Some(cx.named_tensor("Bias", ch_out))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
kernel,
|
||||
stride,
|
||||
dilation,
|
||||
@@ -188,15 +220,18 @@ impl Conv3D {
|
||||
impl SerializeModule for Conv3D {
|
||||
fn serialize(&self, s: &mut luminal::module::Serializer) {
|
||||
s.tensor("weight", self.weight);
|
||||
if let Some(bias) = self.bias {
|
||||
s.tensor("bias", bias);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Conv3D {
|
||||
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
|
||||
// Input: ch_in, dimx_in, dimy_in, dimz_in
|
||||
let dimx_in = input.shape()[1].small();
|
||||
let dimy_in = input.shape()[2].small();
|
||||
let dimz_in = input.shape()[3].small();
|
||||
let dimx_in = input.dims()[1];
|
||||
let dimy_in = input.dims()[2];
|
||||
let dimz_in = input.dims()[3];
|
||||
let dimx_out = (((dimx_in - self.dilation.0 * (self.kernel.0 - 1) - 1) / self.stride.0)
|
||||
+ 1)
|
||||
.simplify();
|
||||
@@ -264,8 +299,8 @@ mod tests {
|
||||
cx.execute();
|
||||
|
||||
assert_eq!(
|
||||
out1.shape(),
|
||||
vec![BigExpression::from(CH_OUT), BigExpression::from(DIM_OUT)]
|
||||
out1.dims(),
|
||||
vec![Expression::from(CH_OUT), Expression::from(DIM_OUT)]
|
||||
);
|
||||
assert_close(&out1.data(), &[0.0948, -0.9498, -1.2342]);
|
||||
}
|
||||
@@ -421,6 +456,7 @@ mod tests {
|
||||
(KERNELX, KERNELY),
|
||||
(STRIDEX, STRIDEY),
|
||||
(DILATIONX, DILATIONY),
|
||||
false,
|
||||
&mut cx,
|
||||
);
|
||||
model.weight.set(vec![
|
||||
@@ -485,6 +521,7 @@ mod tests {
|
||||
(KERNELX, KERNELY, KERNELZ),
|
||||
(STRIDEX, STRIDEY, STRIDEZ),
|
||||
(DILATIONX, DILATIONY, DILATIONZ),
|
||||
false,
|
||||
&mut cx,
|
||||
);
|
||||
let weights = vec![
|
||||
|
||||
@@ -51,7 +51,7 @@ impl Module<GraphTensor> for Embedding {
|
||||
self.weight.gather(inp)
|
||||
};
|
||||
// Unflatten
|
||||
let mut new_shape = input.shape();
|
||||
let mut new_shape = input.dims();
|
||||
new_shape.push(self.embedding_dim.into());
|
||||
out.reshape(new_shape)
|
||||
}
|
||||
|
||||
@@ -62,17 +62,16 @@ impl Module<(GraphTensor, GraphTensor, GraphTensor)> for MultiHeadSelfAttention
|
||||
GraphTensor, // batch, s1, dim
|
||||
),
|
||||
) -> Self::Output {
|
||||
let orig_query_shape = queries.shape();
|
||||
let s1 = keys.shape()[keys.shape.len() - 2].small();
|
||||
let s2 = queries.shape()[queries.shape.len() - 2].small();
|
||||
let orig_query_shape = queries.dims();
|
||||
let s1 = keys.dims()[keys.shape.len() - 2];
|
||||
let s2 = queries.dims()[queries.shape.len() - 2];
|
||||
let n_batches = queries
|
||||
.shape()
|
||||
.dims()
|
||||
.into_iter()
|
||||
.take(queries.shape.len() - 2)
|
||||
.product::<BigExpression>()
|
||||
.max(1)
|
||||
.small();
|
||||
let dim = queries.shape().last().unwrap().small();
|
||||
.product::<Expression>()
|
||||
.max(1);
|
||||
let dim = *queries.dims().last().unwrap();
|
||||
let keys = keys.reshape((n_batches, s1, dim));
|
||||
let values = values.reshape((n_batches, s1, dim));
|
||||
let queries = queries.reshape((n_batches, s2, dim));
|
||||
|
||||
@@ -73,16 +73,15 @@ impl Module<(GraphTensor, GraphTensor)> for TransformerDecoderBlock {
|
||||
// Input: batch_dims, seq1, dim
|
||||
// From_enc: batch_dims, seq2, dim
|
||||
// Flatten to single batch dim
|
||||
let seq1 = input.shape()[input.shape.len() - 2].small();
|
||||
let seq2 = from_enc.shape()[from_enc.shape.len() - 2].small();
|
||||
let dim = input.shape().last().unwrap().small();
|
||||
let seq1 = input.dims()[input.shape.len() - 2];
|
||||
let seq2 = from_enc.dims()[from_enc.shape.len() - 2];
|
||||
let dim = *input.dims().last().unwrap();
|
||||
let n_batches = input
|
||||
.shape()
|
||||
.dims()
|
||||
.into_iter()
|
||||
.take(input.shape.len() - 2)
|
||||
.product::<BigExpression>()
|
||||
.max(1)
|
||||
.small();
|
||||
.product::<Expression>()
|
||||
.max(1);
|
||||
let inp = input.reshape((n_batches, seq1, dim));
|
||||
let fe = from_enc.reshape((n_batches, seq2, dim));
|
||||
// Batched forward pass
|
||||
|
||||
@@ -71,18 +71,18 @@ impl Module<GraphTensor> for TransformerEncoderBlock {
|
||||
// Input: batch_dims, sequence, dim
|
||||
// Reshape to 1 batch dim, sequence, dim
|
||||
let n_batches = input
|
||||
.shape()
|
||||
.dims()
|
||||
.into_iter()
|
||||
.take(input.shape.len() - 2)
|
||||
.product::<BigExpression>()
|
||||
.product::<Expression>()
|
||||
.max(1);
|
||||
let sequence = input.shape()[input.shape.len() - 2].small();
|
||||
let dim = input.shape()[input.shape.len() - 1].small();
|
||||
let sequence = input.dims()[input.shape.len() - 2];
|
||||
let dim = input.dims()[input.shape.len() - 1];
|
||||
let x = input.reshape((n_batches, sequence, dim));
|
||||
let x = x + self.attention.forward(x);
|
||||
let x = x.layer_norm(2, 1e-5);
|
||||
let x = x + self.ff.forward(x);
|
||||
x.layer_norm(2, 1e-5).reshape(input.shape())
|
||||
x.layer_norm(2, 1e-5).reshape(input.dims())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ impl Autograd {
|
||||
// Run dfs with a starting stack and record all encountered nodes in a set
|
||||
fn build_dfs_set(
|
||||
stack: &mut Vec<NodeIndex>,
|
||||
graph: &MainGraph,
|
||||
graph: &StorageGraph,
|
||||
direction: Direction,
|
||||
) -> FxHashSet<NodeIndex> {
|
||||
let mut set = FxHashSet::default();
|
||||
@@ -61,7 +61,7 @@ impl Compiler for Autograd {
|
||||
*loss,
|
||||
(
|
||||
graph.constant(1.0).id,
|
||||
ShapeTracker::default(), // Assume scalar loss for now
|
||||
ShapeTracker::new(()), // Assume scalar loss for now
|
||||
),
|
||||
);
|
||||
let weight_set = params.iter().copied().collect::<FxHashSet<_>>();
|
||||
@@ -217,7 +217,7 @@ fn add_grad(
|
||||
} else if let Some(MaxReduce(dim)) = graph.try_get_op(fwd.id) {
|
||||
pre_fwd_shape.remove_dim(*dim);
|
||||
}
|
||||
if grad.shape.shape() != pre_fwd_shape.shape() {
|
||||
if grad.shape.dims() != pre_fwd_shape.dims() {
|
||||
if !grad.shape.is_contiguous() {
|
||||
grad = grad.contiguous();
|
||||
}
|
||||
|
||||
@@ -79,7 +79,7 @@ pub fn cross_entropy_with_logits_loss(
|
||||
let inv_last_axis_numel = 1.0
|
||||
/ logits
|
||||
.graph()
|
||||
.constant(logits.shape.shape().last().unwrap());
|
||||
.constant(*logits.shape.dims().last().unwrap());
|
||||
let probs = logits.log_softmax(logits.shape.last_axis());
|
||||
(-(probs * target_probabilities).mean_reduce(target_probabilities.shape.all_axes()))
|
||||
/ inv_last_axis_numel
|
||||
@@ -102,7 +102,7 @@ pub fn kl_div_with_logits_loss(
|
||||
let inv_last_axis_numel = 1.0
|
||||
/ logits
|
||||
.graph()
|
||||
.constant(logits.shape.shape().last().unwrap());
|
||||
.constant(*logits.shape.dims().last().unwrap());
|
||||
let probs = logits.log_softmax(logits.shape.last_axis());
|
||||
(-((probs - target_probabilities.ln()) * target_probabilities)
|
||||
.mean_reduce(target_probabilities.shape.all_axes()))
|
||||
|
||||
@@ -54,7 +54,7 @@ fn main() {
|
||||
cx.keep_tensors(&model_weights);
|
||||
let (logits, mut cache_dest) = model.forward((input, &cache_src));
|
||||
let mut logits = logits
|
||||
.slice((.., (Expression::from('s') - 1).., ..))
|
||||
.slice((.., Expression::from('s') - 1.., ..))
|
||||
.retrieve();
|
||||
cache_dest.keep();
|
||||
println!("\t\t - {}ms", now.elapsed().as_millis());
|
||||
|
||||
@@ -49,7 +49,7 @@ impl SerializeModule for Mlp {
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: BigExpression) -> GraphTensor {
|
||||
fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: Expression) -> GraphTensor {
|
||||
assert_eq!(input.shape.len(), 4); // batch, n_heads, seq, head_dim
|
||||
let (batch, n_heads, seq, head_dim) = input.dims4();
|
||||
// Get freqs
|
||||
@@ -102,8 +102,8 @@ impl Module<(GraphTensor, KVCache)> for SelfAttention {
|
||||
.permute((0, 2, 1, 3));
|
||||
|
||||
// Rotary embed queries and keys
|
||||
let queries = apply_rotary_embeddings_ggml(queries, prev_seq.big());
|
||||
let keys = apply_rotary_embeddings_ggml(keys, prev_seq.big());
|
||||
let queries = apply_rotary_embeddings_ggml(queries, prev_seq);
|
||||
let keys = apply_rotary_embeddings_ggml(keys, prev_seq);
|
||||
|
||||
// Add KV cache
|
||||
let keys = k_cache.concat_along(keys, 2);
|
||||
|
||||
@@ -49,7 +49,7 @@ impl SerializeModule for Mlp {
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: BigExpression) -> GraphTensor {
|
||||
fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: Expression) -> GraphTensor {
|
||||
assert_eq!(input.shape.len(), 4); // batch, n_heads, seq, head_dim
|
||||
let (batch, n_heads, seq, head_dim) = input.dims4();
|
||||
// Get freqs
|
||||
@@ -102,8 +102,8 @@ impl Module<(GraphTensor, KVCache)> for SelfAttention {
|
||||
.permute((0, 2, 1, 3));
|
||||
|
||||
// Rotary embed queries and keys
|
||||
let queries = apply_rotary_embeddings_ggml(queries, prev_seq.big());
|
||||
let keys = apply_rotary_embeddings_ggml(keys, prev_seq.big());
|
||||
let queries = apply_rotary_embeddings_ggml(queries, prev_seq);
|
||||
let keys = apply_rotary_embeddings_ggml(keys, prev_seq);
|
||||
|
||||
// Add KV cache
|
||||
let keys = k_cache.concat_along(keys, 2);
|
||||
|
||||
@@ -47,7 +47,7 @@ impl SerializeModule for Mlp {
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: BigExpression) -> GraphTensor {
|
||||
fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: Expression) -> GraphTensor {
|
||||
assert_eq!(input.shape.len(), 4); // batch, n_heads, seq, head_dim
|
||||
let (batch, n_heads, seq, head_dim) = input.dims4();
|
||||
// Get freqs
|
||||
@@ -100,8 +100,8 @@ impl Module<(GraphTensor, KVCache)> for SelfAttention {
|
||||
.permute((0, 2, 1, 3));
|
||||
|
||||
// Rotary embed queries and keys
|
||||
let queries = apply_rotary_embeddings_ggml(queries, prev_seq.big());
|
||||
let keys = apply_rotary_embeddings_ggml(keys, prev_seq.big());
|
||||
let queries = apply_rotary_embeddings_ggml(queries, prev_seq);
|
||||
let keys = apply_rotary_embeddings_ggml(keys, prev_seq);
|
||||
|
||||
// Add KV cache
|
||||
let keys = k_cache.concat_along(keys, 2);
|
||||
|
||||
@@ -19,3 +19,7 @@ memmap2 = "0.9.4"
|
||||
colored = "2.1.0"
|
||||
itertools = "0.12.1"
|
||||
tokenizers = "0.15.2"
|
||||
image = "0.25.1"
|
||||
imageproc = "0.25.0"
|
||||
ab_glyph = "0.2.28"
|
||||
safetensors = "0.4.3"
|
||||
|
||||
BIN
examples/yolo_v8/bike.jpg
Normal file
BIN
examples/yolo_v8/bike.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 179 KiB |
BIN
examples/yolo_v8/roboto-mono-stripped.ttf
Normal file
BIN
examples/yolo_v8/roboto-mono-stripped.ttf
Normal file
Binary file not shown.
45
examples/yolo_v8/src/loader.rs
Normal file
45
examples/yolo_v8/src/loader.rs
Normal file
@@ -0,0 +1,45 @@
|
||||
use std::io::Read;
|
||||
use std::path::Path;
|
||||
use std::{fs::File, io::Seek};
|
||||
|
||||
use luminal::{op::Function, prelude::*};
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::{Dtype, SafeTensors};
|
||||
|
||||
pub fn load<M: SerializeModule>(path: &str, model: &M, graph: &mut Graph) {
|
||||
for (weight_name, node_index) in param_dict(model) {
|
||||
if let Some(loading_node) = graph
|
||||
.graph
|
||||
.node_weight_mut(node_index)
|
||||
.and_then(|op| op.as_any_mut().downcast_mut::<Function>())
|
||||
{
|
||||
let path = path.to_string();
|
||||
loading_node.1 = Box::new(move |_| {
|
||||
let mut bytes = vec![];
|
||||
let mut file = File::open(&path).unwrap();
|
||||
file.read_to_end(&mut bytes).unwrap();
|
||||
let safetensors = SafeTensors::deserialize(&bytes).unwrap();
|
||||
|
||||
if let Ok(tensor_view) = safetensors.tensor(&weight_name.replace('/', ".")) {
|
||||
// Convert to fp32
|
||||
let data: Vec<f32> = match tensor_view.dtype() {
|
||||
Dtype::F32 => tensor_view
|
||||
.data()
|
||||
.chunks_exact(4)
|
||||
.map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
|
||||
.collect(),
|
||||
Dtype::F16 => tensor_view
|
||||
.data()
|
||||
.chunks_exact(2)
|
||||
.map(|c| f16::from_ne_bytes([c[0], c[1]]).to_f32())
|
||||
.collect(),
|
||||
_ => panic!("{:?} is not a supported dtype", tensor_view.dtype()),
|
||||
};
|
||||
return vec![Tensor::new(data)];
|
||||
}
|
||||
|
||||
panic!("Tensor \"{weight_name}\" not found in files");
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,293 @@
|
||||
mod loader;
|
||||
mod model;
|
||||
|
||||
fn main() {
|
||||
println!("Hello, world!");
|
||||
use image::DynamicImage;
|
||||
use luminal::prelude::*;
|
||||
|
||||
/// Human-readable label for each of the 80 class indices; `report_detect` indexes this
/// table with the winning class index when printing and annotating detections.
#[rustfmt::skip]
pub const NAMES: [&str; 80] = [
    "person", "bicycle", "car", "motorbike", "aeroplane", "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
    "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "sofa", "pottedplant", "bed",
    "diningtable", "toilet", "tvmonitor", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
    "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
];
|
||||
|
||||
/// A bounding box around an object.
///
/// Corners are stored as `(xmin, ymin)`–`(xmax, ymax)`; `confidence` is the detection
/// score and `data` is an arbitrary caller-defined payload.
#[derive(Debug, Clone)]
pub struct Bbox<D> {
    pub xmin: f32,
    pub ymin: f32,
    pub xmax: f32,
    pub ymax: f32,
    pub confidence: f32,
    pub data: D,
}

/// A 2-D point with an associated `mask` value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct KeyPoint {
    pub x: f32,
    pub y: f32,
    pub mask: f32,
}

/// Intersection over union of two bounding boxes.
///
/// Widths and heights are computed inclusively (`max - min + 1`), and a negative overlap
/// along either axis is clamped to zero, so disjoint boxes yield `0.0`.
pub fn iou<D>(b1: &Bbox<D>, b2: &Bbox<D>) -> f32 {
    let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.);
    let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.);
    let i_xmin = b1.xmin.max(b2.xmin);
    let i_xmax = b1.xmax.min(b2.xmax);
    let i_ymin = b1.ymin.max(b2.ymin);
    let i_ymax = b1.ymax.min(b2.ymax);
    let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.);
    i_area / (b1_area + b2_area - i_area)
}

/// Performs per-class non-maximum suppression in place.
///
/// Each inner vector is sorted by descending confidence; a box is kept only if its IoU
/// with every already-kept box of the same class is at most `threshold`. On return each
/// vector holds just the kept boxes, still in descending-confidence order.
///
/// Confidences are ordered with `f32::total_cmp`, so a NaN confidence no longer panics
/// (the previous `partial_cmp(..).unwrap()` did).
pub fn non_maximum_suppression<D>(bboxes: &mut [Vec<Bbox<D>>], threshold: f32) {
    for bboxes_for_class in bboxes.iter_mut() {
        bboxes_for_class.sort_by(|b1, b2| b2.confidence.total_cmp(&b1.confidence));
        // Kept boxes are compacted into the prefix `[0, current_index)`.
        let mut current_index = 0;
        for index in 0..bboxes_for_class.len() {
            let drop = (0..current_index).any(|prev_index| {
                iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]) > threshold
            });
            if !drop {
                bboxes_for_class.swap(current_index, index);
                current_index += 1;
            }
        }
        bboxes_for_class.truncate(current_index);
    }
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn report_detect(
|
||||
pred_size: usize,
|
||||
n_preds: usize,
|
||||
pred: &[f32],
|
||||
img: DynamicImage,
|
||||
w: usize,
|
||||
h: usize,
|
||||
confidence_threshold: f32,
|
||||
nms_threshold: f32,
|
||||
legend_size: u32,
|
||||
) -> DynamicImage {
|
||||
let nclasses = pred_size - 4;
|
||||
// The bounding boxes grouped by (maximum) class index.
|
||||
let mut bboxes: Vec<Vec<Bbox<Vec<KeyPoint>>>> = (0..nclasses).map(|_| vec![]).collect();
|
||||
// Extract the bounding boxes for which confidence is above the threshold.
|
||||
for index in 0..n_preds {
|
||||
let pred = pred[pred_size * index..pred_size * (index + 1)].to_vec();
|
||||
let confidence = *pred[4..].iter().max_by(|x, y| x.total_cmp(y)).unwrap();
|
||||
if confidence > confidence_threshold {
|
||||
let mut class_index = 0;
|
||||
for i in 0..nclasses {
|
||||
if pred[4 + i] > pred[4 + class_index] {
|
||||
class_index = i
|
||||
}
|
||||
}
|
||||
if pred[class_index + 4] > 0. {
|
||||
let bbox = Bbox {
|
||||
xmin: pred[0] - pred[2] / 2.,
|
||||
ymin: pred[1] - pred[3] / 2.,
|
||||
xmax: pred[0] + pred[2] / 2.,
|
||||
ymax: pred[1] + pred[3] / 2.,
|
||||
confidence,
|
||||
data: vec![],
|
||||
};
|
||||
bboxes[class_index].push(bbox)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
non_maximum_suppression(&mut bboxes, nms_threshold);
|
||||
|
||||
// Annotate the original image and print boxes information.
|
||||
let (initial_h, initial_w) = (img.height(), img.width());
|
||||
let w_ratio = initial_w as f32 / w as f32;
|
||||
let h_ratio = initial_h as f32 / h as f32;
|
||||
let mut img = img.to_rgb8();
|
||||
let font = Vec::from(include_bytes!("../roboto-mono-stripped.ttf") as &[u8]);
|
||||
let font = ab_glyph::FontRef::try_from_slice(&font).unwrap();
|
||||
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
|
||||
for b in bboxes_for_class.iter() {
|
||||
println!("{}: {:?}", NAMES[class_index], b);
|
||||
let xmin = (b.xmin * w_ratio) as i32;
|
||||
let ymin = (b.ymin * h_ratio) as i32;
|
||||
let dx = (b.xmax - b.xmin) * w_ratio;
|
||||
let dy = (b.ymax - b.ymin) * h_ratio;
|
||||
if dx >= 0. && dy >= 0. {
|
||||
imageproc::drawing::draw_hollow_rect_mut(
|
||||
&mut img,
|
||||
imageproc::rect::Rect::at(xmin, ymin).of_size(dx as u32, dy as u32),
|
||||
image::Rgb([255, 0, 0]),
|
||||
);
|
||||
}
|
||||
if legend_size > 0 {
|
||||
imageproc::drawing::draw_filled_rect_mut(
|
||||
&mut img,
|
||||
imageproc::rect::Rect::at(xmin, ymin).of_size(dx as u32, legend_size),
|
||||
image::Rgb([170, 0, 0]),
|
||||
);
|
||||
let legend = format!("{} {:.0}%", NAMES[class_index], 100. * b.confidence);
|
||||
imageproc::drawing::draw_text_mut(
|
||||
&mut img,
|
||||
image::Rgb([255, 255, 255]),
|
||||
xmin,
|
||||
ymin,
|
||||
ab_glyph::PxScale {
|
||||
x: legend_size as f32 - 1.,
|
||||
y: legend_size as f32 - 1.,
|
||||
},
|
||||
&font,
|
||||
&legend,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
DynamicImage::ImageRgb8(img)
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// Setup graph
|
||||
let mut cx = Graph::new();
|
||||
let mut input = cx.tensor((1, 3, 'h', 'w'));
|
||||
let model = model::Yolo::new(0.33, 0.25, 2.0, 80, &mut cx);
|
||||
let mut model_params = params(&model);
|
||||
let mut output = model.forward(input).retrieve();
|
||||
loader::load("yolov8n.safetensors", &model, &mut cx);
|
||||
|
||||
// Compile
|
||||
cx.compile(
|
||||
GenericCompiler::default(),
|
||||
(&mut input, &mut model_params, &mut output),
|
||||
);
|
||||
let mut image_name = std::path::PathBuf::from("bike.jpg");
|
||||
let original_image = image::io::Reader::open(&image_name)
|
||||
.unwrap()
|
||||
.decode()
|
||||
.unwrap();
|
||||
let (width, height) = {
|
||||
let w = original_image.width() as usize;
|
||||
let h = original_image.height() as usize;
|
||||
if w < h {
|
||||
let w = w * 640 / h;
|
||||
// Sizes have to be divisible by 32.
|
||||
(w / 32 * 32, 640)
|
||||
} else {
|
||||
let h = h * 640 / w;
|
||||
(640, h / 32 * 32)
|
||||
}
|
||||
};
|
||||
println!("Width: {width} Height: {height}");
|
||||
let img = original_image.resize_exact(
|
||||
width as u32,
|
||||
height as u32,
|
||||
image::imageops::FilterType::CatmullRom,
|
||||
);
|
||||
let data = img
|
||||
.to_rgb8()
|
||||
.into_raw()
|
||||
.into_iter()
|
||||
.map(|i| i as f32 / 255.)
|
||||
.collect::<Vec<_>>();
|
||||
input.set_dyn(data, (1, 3, img.height() as usize, img.width() as usize));
|
||||
let (_, pred_size, n_preds) = output.dims3();
|
||||
|
||||
cx.execute();
|
||||
|
||||
let image_t = report_detect(
|
||||
pred_size.exec(&cx.dyn_map).unwrap(),
|
||||
n_preds.exec(&cx.dyn_map).unwrap(),
|
||||
&output.data(),
|
||||
original_image,
|
||||
width,
|
||||
height,
|
||||
0.25,
|
||||
0.45,
|
||||
14,
|
||||
);
|
||||
image_name.set_extension("pp.jpg");
|
||||
println!("writing {image_name:?}");
|
||||
image_t.save(image_name).unwrap();
|
||||
}
|
||||
|
||||
@@ -1,88 +1,584 @@
|
||||
use luminal::prelude::*;
|
||||
use luminal_nn::Conv2D;
|
||||
|
||||
struct Bottleneck<const CH_IN: usize, const CH_OUT: usize> {
|
||||
cv1: Conv2D<CH_IN, CH_OUT, 3, 3, 1, 1>,
|
||||
cv2: Conv2D<CH_OUT, CH_OUT, 3, 3, 1, 1>,
|
||||
/// Parameter-free layer that enlarges a tensor by an integer factor.
struct Upsample {
    /// Multiplier applied to each of the two trailing axes in `forward`.
    scale_factor: usize,
}
|
||||
|
||||
impl<const CH_IN: usize, const CH_OUT: usize, Width: Dimension, Height: Dimension>
|
||||
Module<GraphTensor<(Const<CH_IN>, Height, Width)>> for Bottleneck<CH_IN, CH_OUT>
|
||||
{
|
||||
type Output = GraphTensor<(Const<CH_OUT>, Height, Width)>;
|
||||
fn forward(&self, input: GraphTensor<(Const<CH_IN>, Height, Width)>) -> Self::Output {
|
||||
self.cv2
|
||||
.forward(
|
||||
self.cv1
|
||||
.forward::<Width, Height, Width, Height>(input.permute()),
|
||||
)
|
||||
.permute()
|
||||
impl Upsample {
|
||||
fn new(scale_factor: usize) -> Self {
|
||||
Upsample { scale_factor }
|
||||
}
|
||||
}
|
||||
|
||||
struct C2F {
|
||||
cv1: ConvBlock,
|
||||
cv2: ConvBlock,
|
||||
bottleneck: Vec<Bottleneck>,
|
||||
impl Module<GraphTensor> for Upsample {
    type Output = GraphTensor;

    /// Enlarges a 4-D tensor `(batch, channels, h, w)` to
    /// `(batch, channels, scale_factor * h, scale_factor * w)`.
    ///
    /// A new axis of size `scale_factor` is inserted after each of the two trailing
    /// axes, then each pair is merged by the reshape.
    fn forward(&self, xs: GraphTensor) -> GraphTensor {
        let factor = self.scale_factor;
        let (batch, channels, h, w) = xs.dims4();
        let repeated = xs.expand(3, factor).expand(5, factor);
        repeated.reshape((batch, channels, factor * h, factor * w))
    }
}
|
||||
|
||||
struct Bottleneck {
|
||||
cv1: Conv2D,
|
||||
cv2: Conv2D,
|
||||
residual: bool,
|
||||
}
|
||||
|
||||
impl Bottleneck {
|
||||
pub fn new(ch_in: usize, ch_out: usize, shortcut: bool, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
cv1: Conv2D::new(ch_in, ch_out, (3, 3), (3, 3), (1, 1), true, cx),
|
||||
cv2: Conv2D::new(ch_out, ch_out, (3, 3), (3, 3), (1, 1), true, cx),
|
||||
residual: ch_in == ch_out && shortcut,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SerializeModule for Bottleneck {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("cv1", &self.cv1);
|
||||
s.module("cv2", &self.cv2);
|
||||
}
|
||||
}
|
||||
|
||||
impl Module<GraphTensor> for Bottleneck {
|
||||
type Output = GraphTensor;
|
||||
fn forward(&self, input: GraphTensor) -> Self::Output {
|
||||
let mut out = self
|
||||
.cv2
|
||||
.forward(self.cv1.forward(input.permute((0, 1, 3, 2))))
|
||||
.permute((0, 1, 3, 2));
|
||||
if self.residual {
|
||||
out += input
|
||||
}
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
struct C2f {
|
||||
cv1: Conv2D,
|
||||
cv2: Conv2D,
|
||||
bottleneck: Vec<Bottleneck>,
|
||||
c: usize,
|
||||
}
|
||||
|
||||
impl C2f {
|
||||
pub fn new(c1: usize, c2: usize, n: usize, shortcut: bool, cx: &mut Graph) -> Self {
|
||||
let c = (c2 as f64 / 2.) as usize;
|
||||
Self {
|
||||
cv1: Conv2D::new(c1, 2 * c, (1, 1), (1, 1), (1, 1), true, cx),
|
||||
cv2: Conv2D::new((2 + n) * c, c2, (1, 1), (1, 1), (1, 1), true, cx),
|
||||
bottleneck: (0..n)
|
||||
.map(|_| Bottleneck::new(c, c, shortcut, cx))
|
||||
.collect(),
|
||||
c,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SerializeModule for C2f {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("cv1", &self.cv1);
|
||||
s.module("cv2", &self.cv2);
|
||||
for (i, l) in self.bottleneck.iter().enumerate() {
|
||||
s.module(&format!("bottleneck.{i}"), l);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Module<GraphTensor> for C2f {
|
||||
type Output = GraphTensor;
|
||||
fn forward(&self, input: GraphTensor) -> Self::Output {
|
||||
let ys = self.cv1.forward(input);
|
||||
let mut ys = chunk(ys, 2, 1);
|
||||
for m in self.bottleneck.iter() {
|
||||
ys.push(m.forward(*ys.last().unwrap()));
|
||||
}
|
||||
let mut fin = ys.remove(0);
|
||||
for t in ys {
|
||||
fin = fin.concat_along(t, 1);
|
||||
}
|
||||
self.cv2.forward(fin)
|
||||
}
|
||||
}
|
||||
|
||||
/// Splits `tensor` into `chunks` consecutive slices along axis `dim`.
///
/// Every slice spans `dims()[dim] / chunks` elements; when that division leaves a
/// remainder, the trailing elements are not part of any slice.
fn chunk(tensor: GraphTensor, chunks: usize, dim: usize) -> Vec<GraphTensor> {
    let chunk_size = tensor.dims()[dim] / chunks;
    (0..chunks)
        .map(|i| tensor.slice_along(i * chunk_size..(i + 1) * chunk_size, dim))
        .collect()
}
|
||||
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
struct SPPF {
|
||||
cv1: ConvBlock,
|
||||
cv2: ConvBlock,
|
||||
cv1: Conv2D,
|
||||
cv2: Conv2D,
|
||||
k: usize,
|
||||
}
|
||||
|
||||
impl SerializeModule for SPPF {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("cv1", &self.cv1);
|
||||
s.module("cv2", &self.cv2);
|
||||
}
|
||||
}
|
||||
|
||||
impl SPPF {
|
||||
pub fn new(c1: usize, c2: usize, k: usize, cx: &mut Graph) -> Self {
|
||||
let c_ = c1 / 2;
|
||||
Self {
|
||||
cv1: Conv2D::new(c1, c_, (1, 1), (1, 1), (1, 1), true, cx),
|
||||
cv2: Conv2D::new(c_ * 4, c2, (1, 1), (1, 1), (1, 1), true, cx),
|
||||
k,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Module<GraphTensor> for SPPF {
|
||||
type Output = GraphTensor;
|
||||
fn forward(&self, xs: GraphTensor) -> Self::Output {
|
||||
let xs = self.cv1.forward(xs);
|
||||
let xs2 = xs
|
||||
.pad((
|
||||
(0, 0),
|
||||
(0, 0),
|
||||
(self.k / 2, self.k / 2),
|
||||
(self.k / 2, self.k / 2),
|
||||
))
|
||||
.pool_last_dim(self.k, 1, 1)
|
||||
.max_reduce(4);
|
||||
let xs3 = xs2
|
||||
.pad((
|
||||
(0, 0),
|
||||
(0, 0),
|
||||
(self.k / 2, self.k / 2),
|
||||
(self.k / 2, self.k / 2),
|
||||
))
|
||||
.pool_last_dim(self.k, 1, 1)
|
||||
.max_reduce(4);
|
||||
let xs4 = xs3
|
||||
.pad((
|
||||
(0, 0),
|
||||
(0, 0),
|
||||
(self.k / 2, self.k / 2),
|
||||
(self.k / 2, self.k / 2),
|
||||
))
|
||||
.pool_last_dim(self.k, 1, 1)
|
||||
.max_reduce(4);
|
||||
self.cv2.forward(
|
||||
xs.concat_along(xs2, 1)
|
||||
.concat_along(xs3, 1)
|
||||
.concat_along(xs4, 1),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
struct DFL {
|
||||
conv: Conv2D,
|
||||
num_classes: usize,
|
||||
}
|
||||
|
||||
impl SerializeModule for DFL {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("conv", &self.conv);
|
||||
}
|
||||
}
|
||||
|
||||
impl DFL {
|
||||
pub fn new(num_classes: usize, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
conv: Conv2D::new(num_classes, 1, (1, 1), (1, 1), (1, 1), false, cx),
|
||||
num_classes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Module<GraphTensor> for DFL {
|
||||
type Output = GraphTensor;
|
||||
fn forward(&self, xs: GraphTensor) -> Self::Output {
|
||||
let (b_sz, _channels, anchors) = xs.dims3();
|
||||
let xs = xs
|
||||
.reshape((b_sz, 4, self.num_classes, anchors))
|
||||
.permute((0, 1, 3, 2))
|
||||
.softmax(1);
|
||||
self.conv.forward(xs).reshape((b_sz, 4, anchors))
|
||||
}
|
||||
}
|
||||
|
||||
struct DarkNet {
|
||||
b1_0: ConvBlock,
|
||||
b1_1: ConvBlock,
|
||||
b2_0: C2F,
|
||||
b2_1: ConvBlock,
|
||||
b2_2: C2F,
|
||||
b3_0: ConvBlock,
|
||||
b3_1: C2F,
|
||||
b4_0: ConvBlock,
|
||||
b4_1: C2F,
|
||||
b1_0: Conv2D,
|
||||
b1_1: Conv2D,
|
||||
b2_0: C2f,
|
||||
b2_1: Conv2D,
|
||||
b2_2: C2f,
|
||||
b3_0: Conv2D,
|
||||
b3_1: C2f,
|
||||
b4_0: Conv2D,
|
||||
b4_1: C2f,
|
||||
b5: SPPF,
|
||||
}
|
||||
|
||||
impl SerializeModule for DarkNet {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("b1_0", &self.b1_0);
|
||||
s.module("b1_1", &self.b1_1);
|
||||
s.module("b2_0", &self.b2_0);
|
||||
s.module("b2_1", &self.b2_1);
|
||||
s.module("b3_0", &self.b3_0);
|
||||
s.module("b3_1", &self.b3_1);
|
||||
s.module("b4_0", &self.b4_0);
|
||||
s.module("b4_1", &self.b4_1);
|
||||
s.module("b5", &self.b5);
|
||||
}
|
||||
}
|
||||
|
||||
impl DarkNet {
|
||||
pub fn new(w: f64, r: f64, d: f64, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
b1_0: Conv2D::new(3, (64. * w) as usize, (3, 3), (2, 2), (1, 1), true, cx),
|
||||
b1_1: Conv2D::new(
|
||||
(64. * w) as usize,
|
||||
(128. * w) as usize,
|
||||
(3, 3),
|
||||
(2, 2),
|
||||
(1, 1),
|
||||
true,
|
||||
cx,
|
||||
),
|
||||
b2_0: C2f::new(
|
||||
(128. * w) as usize,
|
||||
(128. * w) as usize,
|
||||
(3. * d).round() as usize,
|
||||
true,
|
||||
cx,
|
||||
),
|
||||
b2_1: Conv2D::new(
|
||||
(128. * w) as usize,
|
||||
(256. * w) as usize,
|
||||
(3, 3),
|
||||
(2, 2),
|
||||
(1, 1),
|
||||
true,
|
||||
cx,
|
||||
),
|
||||
b2_2: C2f::new(
|
||||
(256. * w) as usize,
|
||||
(256. * w) as usize,
|
||||
(6. * d).round() as usize,
|
||||
true,
|
||||
cx,
|
||||
),
|
||||
b3_0: Conv2D::new(
|
||||
(256. * w) as usize,
|
||||
(512. * w) as usize,
|
||||
(3, 3),
|
||||
(2, 2),
|
||||
(1, 1),
|
||||
true,
|
||||
cx,
|
||||
),
|
||||
b3_1: C2f::new(
|
||||
(512. * w) as usize,
|
||||
(512. * w) as usize,
|
||||
(6. * d).round() as usize,
|
||||
true,
|
||||
cx,
|
||||
),
|
||||
b4_0: Conv2D::new(
|
||||
(512. * w) as usize,
|
||||
(512. * w * r) as usize,
|
||||
(3, 3),
|
||||
(2, 2),
|
||||
(1, 1),
|
||||
true,
|
||||
cx,
|
||||
),
|
||||
b4_1: C2f::new(
|
||||
(512. * w * r) as usize,
|
||||
(512. * w * r) as usize,
|
||||
(3. * d).round() as usize,
|
||||
true,
|
||||
cx,
|
||||
),
|
||||
b5: SPPF::new((512. * w * r) as usize, (512. * w * r) as usize, 5, cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Module<GraphTensor> for DarkNet {
    type Output = (GraphTensor, GraphTensor, GraphTensor);

    /// Runs the stages in sequence and returns the outputs of stage 2, stage 3 and the
    /// final `b5` block, in that order.
    fn forward(&self, xs: GraphTensor) -> Self::Output {
        let stem = self.b1_0.forward(xs);
        let x1 = self.b1_1.forward(stem);

        let x2 = self.b2_0.forward(x1);
        let x2 = self.b2_1.forward(x2);
        let x2 = self.b2_2.forward(x2);

        let x3 = self.b3_1.forward(self.b3_0.forward(x2));
        let x4 = self.b4_1.forward(self.b4_0.forward(x3));
        let x5 = self.b5.forward(x4);
        (x2, x3, x5)
    }
}
|
||||
|
||||
struct YoloNeck {
|
||||
n1: C2F,
|
||||
n2: C2F,
|
||||
n3: ConvBlock,
|
||||
n4: C2F,
|
||||
n5: ConvBlock,
|
||||
n6: C2F,
|
||||
up: Upsample,
|
||||
n1: C2f,
|
||||
n2: C2f,
|
||||
n3: Conv2D,
|
||||
n4: C2f,
|
||||
n5: Conv2D,
|
||||
n6: C2f,
|
||||
}
|
||||
|
||||
impl SerializeModule for YoloNeck {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("n1", &self.n1);
|
||||
s.module("n2", &self.n2);
|
||||
s.module("n3", &self.n3);
|
||||
s.module("n4", &self.n4);
|
||||
s.module("n5", &self.n5);
|
||||
s.module("n6", &self.n6);
|
||||
}
|
||||
}
|
||||
|
||||
impl YoloNeck {
|
||||
pub fn new(w: f64, r: f64, d: f64, cx: &mut Graph) -> Self {
|
||||
let n = (3. * d).round() as usize;
|
||||
Self {
|
||||
up: Upsample::new(2),
|
||||
n1: C2f::new(
|
||||
(512. * w * (1. + r)) as usize,
|
||||
(512. * w) as usize,
|
||||
n,
|
||||
false,
|
||||
cx,
|
||||
),
|
||||
n2: C2f::new((768. * w) as usize, (256. * w) as usize, n, false, cx),
|
||||
n3: Conv2D::new(
|
||||
(256. * w) as usize,
|
||||
(256. * w) as usize,
|
||||
(3, 3),
|
||||
(2, 2),
|
||||
(1, 1),
|
||||
true,
|
||||
cx,
|
||||
),
|
||||
n4: C2f::new((768. * w) as usize, (512. * w) as usize, n, false, cx),
|
||||
n5: Conv2D::new(
|
||||
(512. * w) as usize,
|
||||
(512. * w) as usize,
|
||||
(3, 3),
|
||||
(2, 2),
|
||||
(1, 1),
|
||||
true,
|
||||
cx,
|
||||
),
|
||||
n6: C2f::new(
|
||||
(512. * w * (1. + r)) as usize,
|
||||
(512. * w * r) as usize,
|
||||
n,
|
||||
false,
|
||||
cx,
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Module<(GraphTensor, GraphTensor, GraphTensor)> for YoloNeck {
|
||||
type Output = (GraphTensor, GraphTensor, GraphTensor);
|
||||
fn forward(&self, (p3, p4, p5): (GraphTensor, GraphTensor, GraphTensor)) -> Self::Output {
|
||||
let x = self.n1.forward(self.up.forward(p5).concat_along(p4, 1));
|
||||
let head_1 = self.n2.forward(self.up.forward(x).concat_along(p3, 1));
|
||||
let head_2 = self.n4.forward(self.n3.forward(head_1).concat_along(x, 1));
|
||||
let head_3 = self.n6.forward(self.n5.forward(head_2).concat_along(p5, 1));
|
||||
(head_1, head_2, head_3)
|
||||
}
|
||||
}
|
||||
|
||||
struct DetectionHead {
|
||||
dfl: DFL,
|
||||
cv2: [(ConvBlock, ConvBlock, Conv2D); 3],
|
||||
cv3: [(ConvBlock, ConvBlock, Conv2D); 3],
|
||||
cv2: [(Conv2D, Conv2D, Conv2D); 3],
|
||||
cv3: [(Conv2D, Conv2D, Conv2D); 3],
|
||||
ch: usize,
|
||||
no: usize,
|
||||
}
|
||||
|
||||
pub struct Yolo {
|
||||
net: DarkNet,
|
||||
fpn: YoloNeck,
|
||||
head: DetectionHead,
|
||||
}
|
||||
|
||||
impl<
|
||||
const CHANNELS: usize,
|
||||
const CLASSIFICATION: usize,
|
||||
Batch: Dimension,
|
||||
Height: Dimension,
|
||||
Width: Dimension,
|
||||
> Module<GraphTensor<(Batch, Const<CHANNELS>, Width, Height)>> for Yolo
|
||||
{
|
||||
type Output = GraphTensor<(Batch,)>;
|
||||
|
||||
fn forward(&self, input: GraphTensor<(Batch, Const<CHANNELS>, Width, Height)>) -> Self::Output {
|
||||
impl SerializeModule for DetectionHead {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("dfl", &self.dfl);
|
||||
for (i, m) in self.cv2.iter().enumerate() {
|
||||
s.module(&format!("cv2/{i}"), m);
|
||||
}
|
||||
for (i, m) in self.cv3.iter().enumerate() {
|
||||
s.module(&format!("cv3/{i}"), m);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DetectionHead {
|
||||
pub fn new(nc: usize, filters: (usize, usize, usize), cx: &mut Graph) -> Self {
|
||||
let ch = 16;
|
||||
let c1 = usize::max(filters.0, nc);
|
||||
let c2 = usize::max(filters.0 / 4, ch * 4);
|
||||
Self {
|
||||
dfl: DFL::new(ch, cx),
|
||||
cv2: [
|
||||
Self::new_cv2(c2, ch, filters.0, cx),
|
||||
Self::new_cv2(c2, ch, filters.1, cx),
|
||||
Self::new_cv2(c2, ch, filters.2, cx),
|
||||
],
|
||||
cv3: [
|
||||
Self::new_cv3(c1, nc, filters.0, cx),
|
||||
Self::new_cv3(c1, nc, filters.1, cx),
|
||||
Self::new_cv3(c1, nc, filters.2, cx),
|
||||
],
|
||||
ch,
|
||||
no: nc + ch * 4,
|
||||
}
|
||||
}
|
||||
|
||||
fn new_cv3(c1: usize, nc: usize, filter: usize, cx: &mut Graph) -> (Conv2D, Conv2D, Conv2D) {
|
||||
(
|
||||
Conv2D::new(filter, c1, (3, 3), (1, 1), (1, 1), true, cx),
|
||||
Conv2D::new(c1, c1, (3, 3), (1, 1), (1, 1), true, cx),
|
||||
Conv2D::new(c1, nc, (1, 1), (1, 1), (1, 1), true, cx),
|
||||
)
|
||||
}
|
||||
|
||||
fn new_cv2(c2: usize, ch: usize, filter: usize, cx: &mut Graph) -> (Conv2D, Conv2D, Conv2D) {
|
||||
(
|
||||
Conv2D::new(filter, c2, (3, 3), (1, 1), (1, 1), true, cx),
|
||||
Conv2D::new(c2, c2, (3, 3), (1, 1), (1, 1), true, cx),
|
||||
Conv2D::new(c2, 4 * ch, (1, 1), (1, 1), (1, 1), true, cx),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Module<(GraphTensor, GraphTensor, GraphTensor)> for DetectionHead {
|
||||
type Output = (GraphTensor, GraphTensor, GraphTensor);
|
||||
fn forward(&self, (xs0, xs1, xs2): (GraphTensor, GraphTensor, GraphTensor)) -> Self::Output {
|
||||
let forward_cv = |xs, i: usize| {
|
||||
let xs_2 = self.cv2[i].0.forward(xs);
|
||||
let xs_2 = self.cv2[i].1.forward(xs_2);
|
||||
let xs_2 = self.cv2[i].2.forward(xs_2);
|
||||
|
||||
let xs_3 = self.cv3[i].0.forward(xs);
|
||||
let xs_3 = self.cv3[i].1.forward(xs_3);
|
||||
let xs_3 = self.cv3[i].2.forward(xs_3);
|
||||
xs_2.concat_along(xs_3, 1)
|
||||
};
|
||||
let xs0 = forward_cv(xs0, 0);
|
||||
let xs1 = forward_cv(xs1, 1);
|
||||
let xs2 = forward_cv(xs2, 2);
|
||||
|
||||
let (anchors, strides) = make_anchors(xs0, xs1, xs2, (8, 16, 32), 0.5);
|
||||
let anchors = anchors.permute((1, 0)).expand(0, 1);
|
||||
let strides = strides.permute((1, 0));
|
||||
|
||||
let reshape = |xs: GraphTensor| {
|
||||
let d = xs.dims()[0];
|
||||
let el = xs.shape.n_elements();
|
||||
xs.reshape((d, self.no, el / (d * self.no)))
|
||||
};
|
||||
let ys0 = reshape(xs0);
|
||||
let ys1 = reshape(xs1);
|
||||
let ys2 = reshape(xs2);
|
||||
|
||||
let x_cat = ys0.concat_along(ys1, 2).concat_along(ys2, 2);
|
||||
let box_ = x_cat.slice((.., ..self.ch * 4));
|
||||
let cls = x_cat.slice((.., self.ch * 4..));
|
||||
|
||||
let dbox = dist2bbox(self.dfl.forward(box_), anchors);
|
||||
let dbox = dbox * strides.expand_to(dbox.shape);
|
||||
let pred = dbox.concat_along(cls.sigmoid(), 1);
|
||||
(pred, anchors, strides)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Yolo {
|
||||
net: DarkNet, // Backbone
|
||||
fpn: YoloNeck, // Neck
|
||||
head: DetectionHead, // Head
|
||||
}
|
||||
|
||||
impl SerializeModule for Yolo {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("net", &self.net);
|
||||
s.module("fpn", &self.fpn);
|
||||
s.module("head", &self.head);
|
||||
}
|
||||
}
|
||||
|
||||
impl Yolo {
|
||||
pub fn new(w: f64, r: f64, d: f64, num_classes: usize, cx: &mut Graph) -> Self {
|
||||
let f1 = (256. * w) as usize;
|
||||
let f2 = (512. * r) as usize;
|
||||
let f3 = (512. * w * r) as usize;
|
||||
Self {
|
||||
net: DarkNet::new(w, r, d, cx),
|
||||
fpn: YoloNeck::new(w, r, d, cx),
|
||||
head: DetectionHead::new(num_classes, (f1, f2, f3), cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds the anchor-point grid and the matching stride column for three feature maps.
///
/// For every map only its two trailing dims `(h, w)` are used. Each grid cell yields
/// one anchor `(x + grid_cell_offset, y + grid_cell_offset)` and one stride entry.
///
/// Returns `(anchor_points, stride_tensor)` shaped `(N, 2)` and `(N, 1)`, where `N` is
/// the total number of cells over the three maps.
///
/// Each level is now assembled as `(h * w, 2)` so that levels with different grid sizes
/// can be concatenated along axis 0 and the caller's `permute((1, 0))` yields `(2, N)`.
/// The previous code built `(2, h * w)` per level and concatenated those along axis 0.
fn make_anchors(
    xs0: GraphTensor,
    xs1: GraphTensor,
    xs2: GraphTensor,
    (s0, s1, s2): (usize, usize, usize),
    grid_cell_offset: f64,
) -> (GraphTensor, GraphTensor) {
    let cx = xs0.graph();
    let mut anchor_points = vec![];
    let mut stride_tensor = vec![];
    for (xs, stride) in [(xs0, s0), (xs1, s1), (xs2, s2)] {
        // xs is only used to extract the h and w dimensions.
        let (_, _, h, w) = xs.dims4();
        let sx = cx.arange(w) + grid_cell_offset as f32;
        let sy = cx.arange(h) + grid_cell_offset as f32;
        // Row-major grids flattened to h * w: x varies fastest, y is constant per row.
        let sx = sx.reshape((1, w)).expand(0, h).reshape(h * w);
        let sy = sy.reshape((h, 1)).expand(1, w).reshape(h * w);
        // Pair the coordinates along a new trailing axis: (h * w, 2).
        anchor_points.push(sx.expand(1, 1).concat_along(sy.expand(1, 1), 1));
        stride_tensor.push(cx.constant(1.).expand(0, h * w) * stride as f32);
    }
    let anchor_points = anchor_points
        .into_iter()
        .reduce(|acc, t| acc.concat_along(t, 0))
        .unwrap();
    let stride_tensor = stride_tensor
        .into_iter()
        .reduce(|acc, t| acc.concat_along(t, 0))
        .unwrap()
        .expand(1, 1);
    (anchor_points, stride_tensor)
}
|
||||
|
||||
fn dist2bbox(distance: GraphTensor, anchor_points: GraphTensor) -> GraphTensor {
|
||||
let chunks = chunk(distance, 2, 1);
|
||||
let lt = chunks[0];
|
||||
let rb = chunks[1];
|
||||
let x1y1 = anchor_points - lt;
|
||||
let x2y2 = anchor_points + rb;
|
||||
let c_xy = (x1y1 + x2y2) * 0.5;
|
||||
let wh = x2y2 - x1y1;
|
||||
c_xy.concat_along(wh, 1)
|
||||
}
|
||||
|
||||
impl Module<GraphTensor> for Yolo {
    type Output = GraphTensor;

    /// Runs `xs` through `net`, then `fpn`, then `head`, and returns the first element of the
    /// head's three-element output; the other two elements are discarded.
    fn forward(&self, xs: GraphTensor) -> Self::Output {
        let backbone_feats = self.net.forward(xs);
        let neck_feats = self.fpn.forward(backbone_feats);
        self.head.forward(neck_feats).0
    }
}
|
||||
|
||||
BIN
examples/yolo_v8/yolov8n.safetensors
Normal file
BIN
examples/yolo_v8/yolov8n.safetensors
Normal file
Binary file not shown.
@@ -410,7 +410,7 @@ impl Graph {
|
||||
&& edge
|
||||
.weight()
|
||||
.as_data()
|
||||
.map(|d| !d.2.shape().is_empty())
|
||||
.map(|d| !d.2.is_empty())
|
||||
.unwrap_or_default()
|
||||
{
|
||||
new_graph
|
||||
@@ -418,7 +418,7 @@ impl Graph {
|
||||
.unwrap()
|
||||
.push_str(&format!(
|
||||
" | {:?}",
|
||||
edge.weight().as_data().unwrap().2.shape()
|
||||
edge.weight().as_data().unwrap().2.dims()
|
||||
));
|
||||
}
|
||||
}
|
||||
@@ -686,7 +686,7 @@ fn backtrack_match(
|
||||
pattern_root: NodeIndex,
|
||||
pattern_graph: &StableGraph<(Uuid, SelectOp), Option<u8>>,
|
||||
main_root: NodeIndex,
|
||||
main_graph: &mut MainGraph,
|
||||
main_graph: &mut StorageGraph,
|
||||
) -> Option<FxHashMap<NodeIndex, NodeIndex>> {
|
||||
fn get_parents<N, E>(
|
||||
graph: &petgraph::stable_graph::StableGraph<N, E>,
|
||||
@@ -736,7 +736,7 @@ fn test_node(
|
||||
shape,
|
||||
fake,
|
||||
}: &SelectOp,
|
||||
graph: &mut MainGraph,
|
||||
graph: &mut StorageGraph,
|
||||
graph_node: NodeIndex,
|
||||
) -> bool {
|
||||
let input_shapes = graph
|
||||
@@ -763,7 +763,7 @@ fn test_node(
|
||||
if a_sh.len() != b_sh.dims.len() {
|
||||
return false;
|
||||
}
|
||||
for (a, b) in a_sh.iter().zip(b_sh.shape().iter()) {
|
||||
for (a, b) in a_sh.iter().zip(b_sh.dims().into_iter()) {
|
||||
match a.to_usize() {
|
||||
Some(n) => {
|
||||
if b.to_usize().map(|i| i != n).unwrap_or(true) {
|
||||
@@ -776,11 +776,11 @@ fn test_node(
|
||||
.pop()
|
||||
.expect("Selector dimension must be either a symbol or number");
|
||||
if let Some(expected) = shape_map.get(&c) {
|
||||
if b != expected {
|
||||
if b != *expected {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
shape_map.insert(c, b.clone());
|
||||
shape_map.insert(c, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
18
src/graph.rs
18
src/graph.rs
@@ -11,7 +11,7 @@ use itertools::Itertools;
|
||||
use petgraph::{stable_graph::StableGraph, visit::EdgeRef, Direction};
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
|
||||
pub type MainGraph = StableGraph<Box<dyn Operator>, Dependency>;
|
||||
pub type StorageGraph = StableGraph<Box<dyn Operator>, Dependency>;
|
||||
|
||||
/// A Luminal compute graph.
|
||||
///
|
||||
@@ -24,7 +24,7 @@ pub struct Graph {
|
||||
/// A map of dynamic dimensions to concrete dimension sizes
|
||||
pub dyn_map: FxHashMap<char, usize>,
|
||||
/// Edge weights: (Input index, Output index, Input shape)
|
||||
pub graph: MainGraph,
|
||||
pub graph: StorageGraph,
|
||||
/// Tensors marked in this set will not get deleted when the graph is run
|
||||
pub no_delete: FxHashSet<NodeIndex>,
|
||||
/// Tensors marked in this set need to be retrieved later (mostly for optimizers to insert copy back calls, the graph itself doesn't treat these differently)
|
||||
@@ -37,7 +37,7 @@ pub struct Graph {
|
||||
}
|
||||
|
||||
/// A dependency between two nodes
|
||||
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
pub enum Dependency {
|
||||
/// A data dependency (transferring a tensor from one node to the next)
|
||||
@@ -125,7 +125,7 @@ impl Graph {
|
||||
Box::new(move |_| panic!("You must set a value for this tensor! ({name})")),
|
||||
))),
|
||||
graph_ref: self,
|
||||
shape: ShapeTracker::new(&shape.to_shape()),
|
||||
shape: ShapeTracker::new(shape),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -305,7 +305,7 @@ impl Graph {
|
||||
.map(|(_, s)| {
|
||||
format!(
|
||||
"{:?}",
|
||||
s.shape()
|
||||
s.dims()
|
||||
.into_iter()
|
||||
.map(|i| i.to_usize().unwrap())
|
||||
.collect::<Vec<_>>()
|
||||
@@ -366,7 +366,7 @@ impl Graph {
|
||||
}
|
||||
|
||||
impl Deref for Graph {
|
||||
type Target = MainGraph;
|
||||
type Target = StorageGraph;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.graph
|
||||
}
|
||||
@@ -378,6 +378,12 @@ impl DerefMut for Graph {
|
||||
}
|
||||
}
|
||||
|
||||
/// Dropping a `Graph` triggers cleanup of the symbolic expression storage.
impl Drop for Graph {
|
||||
fn drop(&mut self) {
|
||||
// `expression_cleanup` operates on thread-local state, not on this graph's own fields.
// NOTE(review): this presumably also affects expressions used by any other `Graph`
// living on the same thread — confirm.
expression_cleanup();
|
||||
}
|
||||
}
|
||||
|
||||
/// Get source tensor array for a node
|
||||
fn get_source_tensors<'a>(
|
||||
no_delete: &'a FxHashSet<NodeIndex>,
|
||||
|
||||
@@ -65,7 +65,7 @@ impl GraphTensor {
|
||||
/// ```
|
||||
pub fn set_dyn(self, data: impl Data + Clone, shape: impl ToShape) -> Self {
|
||||
// Report dyn dim values to graph dyn map
|
||||
for (d, s) in self.shape.shape().iter().zip(shape.to_shape().into_iter()) {
|
||||
for (d, s) in self.shape.dims().iter().zip(shape.to_shape().into_iter()) {
|
||||
if let Some(c) = d.to_symbols().pop() {
|
||||
self.graph().dyn_map.insert(c, s.to_usize().unwrap());
|
||||
}
|
||||
@@ -103,39 +103,58 @@ impl GraphTensor {
|
||||
data
|
||||
}
|
||||
|
||||
pub fn shape(&self) -> Vec<BigExpression> {
|
||||
self.shape.shape()
|
||||
pub fn dims(&self) -> Vec<Expression> {
|
||||
self.shape.dims()
|
||||
}
|
||||
|
||||
pub fn dims1(&self) -> Expression {
|
||||
self.shape.shape()[0].small()
|
||||
assert_eq!(
|
||||
self.shape.len(),
|
||||
1,
|
||||
"Shape has {} dimensions, tried to get 1",
|
||||
self.shape.len()
|
||||
);
|
||||
self.dims()[0]
|
||||
}
|
||||
pub fn dims2(&self) -> (Expression, Expression) {
|
||||
(self.shape.shape()[0].small(), self.shape.shape()[1].small())
|
||||
assert_eq!(
|
||||
self.shape.len(),
|
||||
2,
|
||||
"Shape has {} dimensions, tried to get 2",
|
||||
self.shape.len()
|
||||
);
|
||||
let dims = self.dims();
|
||||
(dims[0], dims[1])
|
||||
}
|
||||
pub fn dims3(&self) -> (Expression, Expression, Expression) {
|
||||
(
|
||||
self.shape.shape()[0].small(),
|
||||
self.shape.shape()[1].small(),
|
||||
self.shape.shape()[2].small(),
|
||||
)
|
||||
assert_eq!(
|
||||
self.shape.len(),
|
||||
3,
|
||||
"Shape has {} dimensions, tried to get 3",
|
||||
self.shape.len()
|
||||
);
|
||||
let dims = self.dims();
|
||||
(dims[0], dims[1], dims[2])
|
||||
}
|
||||
pub fn dims4(&self) -> (Expression, Expression, Expression, Expression) {
|
||||
(
|
||||
self.shape.shape()[0].small(),
|
||||
self.shape.shape()[1].small(),
|
||||
self.shape.shape()[2].small(),
|
||||
self.shape.shape()[3].small(),
|
||||
)
|
||||
assert_eq!(
|
||||
self.shape.len(),
|
||||
4,
|
||||
"Shape has {} dimensions, tried to get 4",
|
||||
self.shape.len()
|
||||
);
|
||||
let dims = self.dims();
|
||||
(dims[0], dims[1], dims[2], dims[3])
|
||||
}
|
||||
pub fn dims5(&self) -> (Expression, Expression, Expression, Expression, Expression) {
|
||||
(
|
||||
self.shape.shape()[0].small(),
|
||||
self.shape.shape()[1].small(),
|
||||
self.shape.shape()[2].small(),
|
||||
self.shape.shape()[3].small(),
|
||||
self.shape.shape()[4].small(),
|
||||
)
|
||||
assert_eq!(
|
||||
self.shape.len(),
|
||||
5,
|
||||
"Shape has {} dimensions, tried to get 5",
|
||||
self.shape.len()
|
||||
);
|
||||
let dims = self.dims();
|
||||
(dims[0], dims[1], dims[2], dims[3], dims[4])
|
||||
}
|
||||
|
||||
/// Set the value of the tensor matching the constant shape
|
||||
|
||||
@@ -140,13 +140,10 @@ impl Add<f32> for GraphTensor {
|
||||
}
|
||||
}
|
||||
|
||||
impl<St: ExpressionStorage> Add<GenericExpression<St>> for GraphTensor
|
||||
where
|
||||
GenericExpression<Vec<Term>>: From<GenericExpression<St>>,
|
||||
{
|
||||
impl<S: Into<Expression>> Add<S> for GraphTensor {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn add(self, rhs: GenericExpression<St>) -> Self::Output {
|
||||
fn add(self, rhs: S) -> Self::Output {
|
||||
self + self.graph().constant_expr(rhs).expand_to(self.shape)
|
||||
}
|
||||
}
|
||||
@@ -159,13 +156,10 @@ impl Sub<f32> for GraphTensor {
|
||||
}
|
||||
}
|
||||
|
||||
impl<St: ExpressionStorage> Sub<GenericExpression<St>> for GraphTensor
|
||||
where
|
||||
GenericExpression<Vec<Term>>: From<GenericExpression<St>>,
|
||||
{
|
||||
impl<S: Into<Expression>> Sub<S> for GraphTensor {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn sub(self, rhs: GenericExpression<St>) -> Self::Output {
|
||||
fn sub(self, rhs: S) -> Self::Output {
|
||||
self - self.graph().constant_expr(rhs).expand_to(self.shape)
|
||||
}
|
||||
}
|
||||
@@ -178,13 +172,10 @@ impl Mul<f32> for GraphTensor {
|
||||
}
|
||||
}
|
||||
|
||||
impl<St: ExpressionStorage> Mul<GenericExpression<St>> for GraphTensor
|
||||
where
|
||||
GenericExpression<Vec<Term>>: From<GenericExpression<St>>,
|
||||
{
|
||||
impl<S: Into<Expression>> Mul<S> for GraphTensor {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn mul(self, rhs: GenericExpression<St>) -> Self::Output {
|
||||
fn mul(self, rhs: S) -> Self::Output {
|
||||
self * self.graph().constant_expr(rhs).expand_to(self.shape)
|
||||
}
|
||||
}
|
||||
@@ -198,13 +189,10 @@ impl Div<f32> for GraphTensor {
|
||||
}
|
||||
}
|
||||
|
||||
impl<St: ExpressionStorage> Div<GenericExpression<St>> for GraphTensor
|
||||
where
|
||||
GenericExpression<Vec<Term>>: From<GenericExpression<St>>,
|
||||
{
|
||||
impl<S: Into<Expression>> Div<S> for GraphTensor {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn div(self, rhs: GenericExpression<St>) -> Self::Output {
|
||||
fn div(self, rhs: S) -> Self::Output {
|
||||
self / self.graph().constant_expr(rhs).expand_to(self.shape)
|
||||
}
|
||||
}
|
||||
@@ -217,13 +205,10 @@ impl Rem<f32> for GraphTensor {
|
||||
}
|
||||
}
|
||||
|
||||
impl<St: ExpressionStorage> Rem<GenericExpression<St>> for GraphTensor
|
||||
where
|
||||
GenericExpression<Vec<Term>>: From<GenericExpression<St>>,
|
||||
{
|
||||
impl<S: Into<Expression>> Rem<S> for GraphTensor {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn rem(self, rhs: GenericExpression<St>) -> Self::Output {
|
||||
fn rem(self, rhs: S) -> Self::Output {
|
||||
self % self.graph().constant_expr(rhs).expand_to(self.shape)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,11 +15,11 @@ impl GraphTensor {
|
||||
// Sum Reduce
|
||||
let mut ret = mul.sum_reduce(2);
|
||||
if vec {
|
||||
ret = ret.reshape(ret.shape().last().unwrap());
|
||||
ret = ret.reshape(ret.dims().last().unwrap());
|
||||
}
|
||||
ret
|
||||
} else if self.shape.len() == 3 {
|
||||
let d = rhs.shape().last().unwrap().small();
|
||||
let d = *rhs.dims().last().unwrap();
|
||||
let (a, b, _) = self.dims3();
|
||||
if rhs.shape.len() == 2 {
|
||||
// ABCxCD -> ABD
|
||||
@@ -43,8 +43,8 @@ impl GraphTensor {
|
||||
} else {
|
||||
panic!(
|
||||
"Can't matmul lhs {:?} and rhs {:?}",
|
||||
self.shape(),
|
||||
rhs.shape()
|
||||
self.dims(),
|
||||
rhs.dims()
|
||||
)
|
||||
}
|
||||
} else if self.shape.len() == 4 {
|
||||
@@ -73,8 +73,8 @@ impl GraphTensor {
|
||||
} else {
|
||||
panic!(
|
||||
"Can't matmul lhs {:?} and rhs {:?}",
|
||||
self.shape(),
|
||||
rhs.shape()
|
||||
self.dims(),
|
||||
rhs.dims()
|
||||
)
|
||||
}
|
||||
} else if self.shape.len() == 5 && rhs.shape.len() == 5 {
|
||||
@@ -93,8 +93,8 @@ impl GraphTensor {
|
||||
} else {
|
||||
panic!(
|
||||
"Can't matmul lhs {:?} and rhs {:?}",
|
||||
self.shape(),
|
||||
rhs.shape()
|
||||
self.dims(),
|
||||
rhs.dims()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ impl GraphTensor {
|
||||
pub fn reshape(mut self, new_shape: impl ToShape) -> GraphTensor {
|
||||
// Insert contiguous call
|
||||
self = self.contiguous();
|
||||
self.shape = ShapeTracker::new(&new_shape.to_shape());
|
||||
self.shape = ShapeTracker::new(new_shape);
|
||||
self
|
||||
}
|
||||
|
||||
@@ -60,6 +60,12 @@ impl GraphTensor {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn slice_along(self, slice: impl SliceRange, axis: usize) -> GraphTensor {
|
||||
let mut s = vec![(Expression::from(0), Expression::from(i32::MAX)); axis + 1];
|
||||
s[axis] = slice.bounds();
|
||||
self.slice(s)
|
||||
}
|
||||
|
||||
/// Cut out 'size' elements every 'spacing' elements in the last dimension. 'size' must be smaller than the last dimension
|
||||
pub fn excise(mut self, spacing: usize, size: usize) -> GraphTensor {
|
||||
let n_dims = self.shape.len();
|
||||
@@ -90,38 +96,36 @@ impl GraphTensor {
|
||||
/// Pool elements along the last dimension, pools are exposed as a new dimension
|
||||
pub fn pool_last_dim(
|
||||
mut self,
|
||||
kernel: impl Into<BigExpression>,
|
||||
stride: impl Into<BigExpression>,
|
||||
kernel: impl Into<Expression>,
|
||||
stride: impl Into<Expression>,
|
||||
dilation: usize,
|
||||
) -> GraphTensor {
|
||||
let (kernel, stride) = (kernel.into(), stride.into());
|
||||
let n_dims = self.shape.len();
|
||||
let full_kernel = kernel.clone() + (kernel.clone() - 1) * (dilation - 1);
|
||||
let dim_size = self.shape.shape().pop().unwrap().simplify().small();
|
||||
let number_of_windows = (((dim_size.big() - full_kernel.clone()) / stride.clone()) + 1)
|
||||
.simplify()
|
||||
.small();
|
||||
let full_kernel = kernel + (kernel - 1) * (dilation - 1);
|
||||
let dim_size = self.dims().pop().unwrap().simplify();
|
||||
let number_of_windows = (((dim_size - full_kernel) / stride) + 1).simplify();
|
||||
// Expand new dimension
|
||||
self.shape.expand(n_dims - 1, number_of_windows);
|
||||
self = self.contiguous();
|
||||
if n_dims > 1 {
|
||||
// View as single dimension of matrix with wider width
|
||||
let mat_size = (dim_size.big() + stride.clone()) * number_of_windows;
|
||||
let actual_size = (dim_size.big() * number_of_windows).simplify().small();
|
||||
let mat_size = (dim_size + stride) * number_of_windows;
|
||||
let actual_size = (dim_size * number_of_windows).simplify();
|
||||
// Reshape into single dimension to pad
|
||||
self.shape.remove_dim(n_dims);
|
||||
self.shape.dims[self.shape.indexes[n_dims - 1]] = actual_size;
|
||||
self.shape.padding[self.shape.indexes[n_dims - 1]].1 =
|
||||
(mat_size - actual_size).simplify().small();
|
||||
(mat_size - actual_size).simplify();
|
||||
self = self.contiguous();
|
||||
// Reshape back (mats should be full now)
|
||||
self.shape.add_dim(n_dims, dim_size + stride.clone());
|
||||
self.shape.add_dim(n_dims, dim_size + stride);
|
||||
self.shape.dims[self.shape.indexes[n_dims - 1]] = number_of_windows;
|
||||
} else {
|
||||
self.shape.dims[self.shape.indexes[n_dims]] = dim_size + stride;
|
||||
}
|
||||
// Slice down to kernel size
|
||||
self.shape.mask[self.shape.indexes[n_dims]].1 = full_kernel.simplify().small();
|
||||
self.shape.mask[self.shape.indexes[n_dims]].1 = full_kernel.simplify();
|
||||
self.shape.mask[self.shape.indexes[n_dims - 1]].1 = number_of_windows;
|
||||
self = self.contiguous();
|
||||
|
||||
@@ -147,14 +151,21 @@ impl GraphTensor {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn pad_along(
|
||||
self,
|
||||
left: impl Into<Expression>,
|
||||
right: impl Into<Expression>,
|
||||
axis: usize,
|
||||
) -> GraphTensor {
|
||||
let mut p = vec![(Expression::from(0), Expression::from(0)); axis + 1];
|
||||
p[axis] = (left.into(), right.into());
|
||||
self.pad(p)
|
||||
}
|
||||
|
||||
pub fn concat_along(self, rhs: GraphTensor, axis: usize) -> GraphTensor {
|
||||
// Create padding
|
||||
let mut a_padding = vec![(Expression::default(), Expression::default()); self.shape.len()];
|
||||
a_padding[axis].1 = rhs.shape.shape()[axis].small();
|
||||
let mut b_padding = vec![(Expression::default(), Expression::default()); rhs.shape.len()];
|
||||
b_padding[axis].0 = self.shape.shape()[axis].small();
|
||||
// Pad and add
|
||||
self.pad(a_padding) + rhs.pad(b_padding)
|
||||
self.pad_along(0, rhs.shape.dims()[axis], axis)
|
||||
+ rhs.pad_along(self.shape.dims()[axis], 0, axis)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -273,7 +284,7 @@ mod tests {
|
||||
let a = cx
|
||||
.tensor((3, 2))
|
||||
.set(vec![1.4325, 2.492428, 3.127365, 33.2834, 4.18734, 23.854]);
|
||||
let b = a.slice((.., ..Expression::from(1))).retrieve();
|
||||
let b = a.slice((.., ..1)).retrieve();
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
@@ -431,8 +442,8 @@ mod tests {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((3, 2));
|
||||
a.set(vec![1.4325, 2.492428, 3.127365, 33.2834, 4.18734, 23.854]);
|
||||
let x1 = a.slice((.., ..Expression::from(1))).contiguous();
|
||||
let x2 = a.slice((.., Expression::from(1)..)).contiguous();
|
||||
let x1 = a.slice((.., ..1)).contiguous();
|
||||
let x2 = a.slice((.., 1..)).contiguous();
|
||||
let c = (-x2).concat_along(x1, 1);
|
||||
c.retrieve();
|
||||
cx.execute();
|
||||
|
||||
@@ -48,24 +48,14 @@ impl From<f64> for ConstantValue {
|
||||
ConstantValue::Float(value as f32)
|
||||
}
|
||||
}
|
||||
impl From<BigExpression> for ConstantValue {
|
||||
fn from(value: BigExpression) -> Self {
|
||||
ConstantValue::Expression(value)
|
||||
}
|
||||
}
|
||||
impl From<Expression> for ConstantValue {
|
||||
fn from(value: Expression) -> Self {
|
||||
ConstantValue::Expression((&value).into())
|
||||
}
|
||||
}
|
||||
impl From<&BigExpression> for ConstantValue {
|
||||
fn from(value: &BigExpression) -> Self {
|
||||
ConstantValue::Expression(value.clone())
|
||||
ConstantValue::Expression(value)
|
||||
}
|
||||
}
|
||||
impl From<&Expression> for ConstantValue {
|
||||
fn from(value: &Expression) -> Self {
|
||||
ConstantValue::Expression(value.into())
|
||||
ConstantValue::Expression(*value)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,26 +64,26 @@ impl Graph {
|
||||
pub fn constant(&mut self, i: impl Into<ConstantValue>) -> GraphTensor {
|
||||
GraphTensor::from_id(
|
||||
self.add_op(Constant(i.into(), &self.dyn_map)).finish(),
|
||||
ShapeTracker::default(),
|
||||
ShapeTracker::new(()),
|
||||
self,
|
||||
)
|
||||
}
|
||||
|
||||
/// A scalar constant evaluated from an expression at runtime
|
||||
pub fn constant_expr<E: Into<BigExpression>>(&mut self, expr: E) -> GraphTensor {
|
||||
pub fn constant_expr<E: Into<Expression>>(&mut self, expr: E) -> GraphTensor {
|
||||
GraphTensor::from_id(
|
||||
self.add_op(Constant(
|
||||
ConstantValue::Expression(expr.into().simplify()),
|
||||
&self.dyn_map,
|
||||
))
|
||||
.finish(),
|
||||
ShapeTracker::default(),
|
||||
ShapeTracker::new(()),
|
||||
self,
|
||||
)
|
||||
}
|
||||
|
||||
/// ARange from 0 to N
|
||||
pub fn arange(&mut self, to: impl Into<Expression> + Copy) -> GraphTensor {
|
||||
pub fn arange(&mut self, to: impl Into<Expression>) -> GraphTensor {
|
||||
let to = to.into();
|
||||
if to.to_usize().map(|i| i == 1).unwrap_or_default() {
|
||||
// Single number ARange is just 0
|
||||
@@ -106,7 +96,8 @@ impl Graph {
|
||||
/// Lower left-hand triangle of 1s. Currently required to be square
|
||||
///
|
||||
/// Same API as https://pytorch.org/docs/stable/generated/torch.tril
|
||||
pub fn tril(&mut self, size: impl Into<Expression> + Copy, diagonal: i32) -> GraphTensor {
|
||||
pub fn tril(&mut self, size: impl Into<Expression>, diagonal: i32) -> GraphTensor {
|
||||
let size = size.into();
|
||||
let horizontal = self.arange(size).expand(0, size);
|
||||
let vertical = self.arange(size).expand(1, size);
|
||||
|
||||
@@ -116,7 +107,8 @@ impl Graph {
|
||||
/// Upper right-hand triangle of 1s
|
||||
///
|
||||
/// Same API as https://pytorch.org/docs/stable/generated/torch.triu
|
||||
pub fn triu(&mut self, size: impl Into<Expression> + Copy, diagonal: i32) -> GraphTensor {
|
||||
pub fn triu(&mut self, size: impl Into<Expression>, diagonal: i32) -> GraphTensor {
|
||||
let size = size.into();
|
||||
let horizontal = self.arange(size).expand(0, size);
|
||||
let vertical = self.arange(size).expand(1, size);
|
||||
|
||||
|
||||
@@ -52,23 +52,13 @@ impl GraphTensor {
|
||||
let mul_tensor = self
|
||||
.graph()
|
||||
.add_op(op::Recip)
|
||||
.input(div_tensor, 0, ShapeTracker::default())
|
||||
.input(div_tensor, 0, ShapeTracker::new(()))
|
||||
.finish();
|
||||
node_id = self
|
||||
.graph()
|
||||
.add_op(op::Mul)
|
||||
.input(node_id, 0, shape)
|
||||
.input(
|
||||
mul_tensor,
|
||||
0,
|
||||
ShapeTracker::fake(
|
||||
&shape
|
||||
.shape()
|
||||
.iter()
|
||||
.map(Expression::from)
|
||||
.collect::<Vec<_>>(),
|
||||
),
|
||||
)
|
||||
.input(mul_tensor, 0, ShapeTracker::fake(shape))
|
||||
.finish();
|
||||
}
|
||||
GraphTensor::from_id(node_id, shape, self.graph_ref)
|
||||
|
||||
@@ -128,12 +128,7 @@ impl GraphTensor {
|
||||
let r = self
|
||||
.graph()
|
||||
.constant(1.)
|
||||
.expand_to(ShapeTracker::new(&[self
|
||||
.shape
|
||||
.shape()
|
||||
.last()
|
||||
.unwrap()
|
||||
.small()]))
|
||||
.expand_to(ShapeTracker::new(self.shape.dims().last().unwrap()))
|
||||
.cumsum_last_dim()
|
||||
- 1.;
|
||||
// Multiply one-hot by expanded index arange
|
||||
|
||||
@@ -194,7 +194,7 @@ macro_rules! tuple_impls {
|
||||
|
||||
impl<$($name: SerializeModule,)+> SerializeModule for ($($name,)+) {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
$(s.module(&format!("layer{}", $idx), &self.$idx);)+
|
||||
$(s.module(&format!("{}", $idx), &self.$idx);)+
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -132,9 +132,9 @@ impl Debug for Function {
|
||||
}
|
||||
|
||||
/// A constant value placed on the graph at runtime. Can either be an expression evaluated at runtime, or a constant float
|
||||
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum ConstantValue {
|
||||
Expression(BigExpression),
|
||||
Expression(Expression),
|
||||
Float(f32),
|
||||
}
|
||||
|
||||
@@ -383,7 +383,7 @@ fn get_vec<'a>(tensor: &'a InputTensor<'a>) -> &'a Vec<f32> {
|
||||
|
||||
fn get_index(
|
||||
data: &[f32],
|
||||
(ind, val): &(BigExpression, BigExpression),
|
||||
(ind, val): &(Expression, Expression),
|
||||
stack: &mut Vec<i64>,
|
||||
index: usize,
|
||||
) -> f32 {
|
||||
|
||||
@@ -17,7 +17,7 @@ use std::ops::{Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeTo, RangeTo
|
||||
fn get_start_bound<D: Into<Expression> + Copy>(bound: Bound<D>) -> Expression {
|
||||
match bound {
|
||||
Bound::Included(x) => x.into(),
|
||||
Bound::Excluded(x) => x.into() + Expression::from(1),
|
||||
Bound::Excluded(x) => x.into() + 1,
|
||||
Bound::Unbounded => 0.into(),
|
||||
}
|
||||
}
|
||||
@@ -25,7 +25,7 @@ fn get_start_bound<D: Into<Expression> + Copy>(bound: Bound<D>) -> Expression {
|
||||
fn get_end_bound<D: Into<Expression> + Copy>(bound: Bound<D>) -> Expression {
|
||||
match bound {
|
||||
Bound::Excluded(x) => x.into(),
|
||||
Bound::Included(x) => x.into() + Expression::from(1),
|
||||
Bound::Included(x) => x.into() + 1,
|
||||
Bound::Unbounded => Expression::from(i32::MAX),
|
||||
}
|
||||
}
|
||||
@@ -110,29 +110,29 @@ impl<R: SliceRange> SliceRange for (R,) {
|
||||
}
|
||||
|
||||
pub trait ToSlice {
|
||||
fn to_range_vec(&self) -> Vec<(Expression, Expression)>;
|
||||
fn to_range_vec(self) -> Vec<(Expression, Expression)>;
|
||||
}
|
||||
|
||||
impl<R: SliceRange> ToSlice for R {
|
||||
fn to_range_vec(&self) -> Vec<(Expression, Expression)> {
|
||||
fn to_range_vec(self) -> Vec<(Expression, Expression)> {
|
||||
vec![self.bounds()]
|
||||
}
|
||||
}
|
||||
|
||||
impl<R1: SliceRange, R2: SliceRange> ToSlice for (R1, R2) {
|
||||
fn to_range_vec(&self) -> Vec<(Expression, Expression)> {
|
||||
fn to_range_vec(self) -> Vec<(Expression, Expression)> {
|
||||
vec![self.0.bounds(), self.1.bounds()]
|
||||
}
|
||||
}
|
||||
|
||||
impl<R1: SliceRange, R2: SliceRange, R3: SliceRange> ToSlice for (R1, R2, R3) {
|
||||
fn to_range_vec(&self) -> Vec<(Expression, Expression)> {
|
||||
fn to_range_vec(self) -> Vec<(Expression, Expression)> {
|
||||
vec![self.0.bounds(), self.1.bounds(), self.2.bounds()]
|
||||
}
|
||||
}
|
||||
|
||||
impl<R1: SliceRange, R2: SliceRange, R3: SliceRange, R4: SliceRange> ToSlice for (R1, R2, R3, R4) {
|
||||
fn to_range_vec(&self) -> Vec<(Expression, Expression)> {
|
||||
fn to_range_vec(self) -> Vec<(Expression, Expression)> {
|
||||
vec![
|
||||
self.0.bounds(),
|
||||
self.1.bounds(),
|
||||
@@ -145,7 +145,7 @@ impl<R1: SliceRange, R2: SliceRange, R3: SliceRange, R4: SliceRange> ToSlice for
|
||||
impl<R1: SliceRange, R2: SliceRange, R3: SliceRange, R4: SliceRange, R5: SliceRange> ToSlice
|
||||
for (R1, R2, R3, R4, R5)
|
||||
{
|
||||
fn to_range_vec(&self) -> Vec<(Expression, Expression)> {
|
||||
fn to_range_vec(self) -> Vec<(Expression, Expression)> {
|
||||
vec![
|
||||
self.0.bounds(),
|
||||
self.1.bounds(),
|
||||
@@ -156,6 +156,32 @@ impl<R1: SliceRange, R2: SliceRange, R3: SliceRange, R4: SliceRange, R5: SliceRa
|
||||
}
|
||||
}
|
||||
|
||||
/// An owned vector of `(start, end)` pairs is a slice description: each pair is converted
/// element-wise into a pair of `Expression`s, preserving order.
impl<A: Into<Expression>, B: Into<Expression>> ToSlice for Vec<(A, B)> {
    fn to_range_vec(self) -> Vec<(Expression, Expression)> {
        self.into_iter()
            .map(|(start, end)| (start.into(), end.into()))
            .collect()
    }
}
|
||||
|
||||
/// A borrowed vector of `(start, end)` pairs is a slice description: each pair is copied
/// out and converted element-wise into a pair of `Expression`s, preserving order.
impl<A: Into<Expression> + Copy, B: Into<Expression> + Copy> ToSlice for &Vec<(A, B)> {
    fn to_range_vec(self) -> Vec<(Expression, Expression)> {
        self.iter()
            .copied()
            .map(|(start, end)| (start.into(), end.into()))
            .collect()
    }
}
|
||||
|
||||
/// A slice of `(start, end)` pairs is a slice description: each pair is copied out and
/// converted element-wise into a pair of `Expression`s, preserving order.
impl<A: Into<Expression> + Copy, B: Into<Expression> + Copy> ToSlice for &[(A, B)] {
    fn to_range_vec(self) -> Vec<(Expression, Expression)> {
        self.iter()
            .map(|&(start, end)| (start.into(), end.into()))
            .collect()
    }
}
|
||||
|
||||
impl<const N: usize, A: Into<Expression> + Copy, B: Into<Expression> + Copy> ToSlice
|
||||
for &[(A, B); N]
|
||||
{
|
||||
fn to_range_vec(self) -> Vec<(Expression, Expression)> {
|
||||
self.iter().map(|i| (i.0.into(), i.1.into())).collect()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ToPad {
|
||||
fn to_pad_vec(self) -> Vec<(Expression, Expression)>;
|
||||
}
|
||||
@@ -464,6 +490,6 @@ impl<A: Into<Expression>> ToShape for A {
|
||||
|
||||
impl ToShape for ShapeTracker {
|
||||
fn to_shape(self) -> Vec<Expression> {
|
||||
self.shape().to_shape()
|
||||
self.dims().to_shape()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,20 +1,55 @@
|
||||
use egg::*;
|
||||
use generational_box::{AnyStorage, GenerationalBox, Owner, UnsyncStorage};
|
||||
use rustc_hash::FxHashMap;
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
fmt::Debug,
|
||||
hash::Hash,
|
||||
ops::{
|
||||
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, Div, DivAssign, IndexMut, Mul,
|
||||
MulAssign, Rem, RemAssign, Sub, SubAssign,
|
||||
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, Div, DivAssign, Mul, MulAssign,
|
||||
Rem, RemAssign, Sub, SubAssign,
|
||||
},
|
||||
};
|
||||
use symbolic_expressions::Sexp;
|
||||
use tinyvec::ArrayVec;
|
||||
|
||||
/// A symbolic expression stored on the stack
|
||||
pub type Expression = GenericExpression<ArrayVec<[Term; 20]>>; // We need to figure out how to reduce this, can't be fixed at 20. ShapeTracker would take up 6 dims * 12 pads * 12 slices * 20 terms * 8 bytes = 138kb
|
||||
/// A symbolic expression stored on the heap
|
||||
pub type BigExpression = GenericExpression<Vec<Term>>;
|
||||
// Per-thread owner of the storage into which `Expression::new` inserts term lists.
// Held in an `Option` inside a `RefCell` so `expression_cleanup` can take it out,
// leaving `None` behind.
thread_local! {
|
||||
static EXPRESSION_OWNER: RefCell<Option<Owner<UnsyncStorage>>> = RefCell::new(Some(UnsyncStorage::owner()));
|
||||
}
|
||||
|
||||
/// Clean up symbolic expression storage.
///
/// Takes the owner out of the current thread's `EXPRESSION_OWNER` slot (leaving `None`) and
/// drops it.
/// NOTE(review): existing `Expression` handles created on this thread presumably become
/// invalid once their owner is dropped — confirm against the generational-box docs.
|
||||
pub fn expression_cleanup() {
|
||||
EXPRESSION_OWNER.with(|cell| cell.borrow_mut().take());
|
||||
}
|
||||
|
||||
/// Get the thread-local owner of expression storage
|
||||
fn expression_owner() -> Owner {
|
||||
EXPRESSION_OWNER.with(|cell| cell.borrow().clone().unwrap())
|
||||
}
|
||||
|
||||
/// A symbolic expression.
///
/// This is a small `Copy` handle: the term list itself lives in a `GenerationalBox`
/// allocated from the thread-local expression owner (see `Expression::new`), so copying an
/// `Expression` copies the handle, not the terms.
#[derive(Clone, Copy)]
|
||||
pub struct Expression {
|
||||
// NOTE(review): terms are presumably stored in postfix order, as in the previous
// expression type — confirm.
pub terms: GenerationalBox<Vec<Term>>,
|
||||
}
|
||||
|
||||
impl Expression {
|
||||
/// Creates an expression from a term list by inserting the list into the storage of the
/// current thread's expression owner and keeping the resulting handle.
fn new(terms: Vec<Term>) -> Self {
|
||||
Self {
|
||||
terms: expression_owner().insert(terms),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Hashing is based on the stored term list (read through the handle), not on the handle
/// itself.
impl Hash for Expression {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.terms.read().hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Expression {
|
||||
fn default() -> Self {
|
||||
Expression::new(vec![])
|
||||
}
|
||||
}
|
||||
|
||||
/// A single term of a symbolic expression such as a variable, number or operation.
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
|
||||
@@ -79,99 +114,27 @@ impl Term {
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait implemented on the 2 main symbolic expression storage types, Vec<Term> and ArrayVec<Term>
|
||||
#[allow(clippy::len_without_is_empty)]
|
||||
pub trait ExpressionStorage:
|
||||
Clone
|
||||
+ IndexMut<usize, Output = Term>
|
||||
+ std::iter::Extend<Term>
|
||||
+ Default
|
||||
+ PartialEq
|
||||
+ Debug
|
||||
+ Hash
|
||||
+ Eq
|
||||
{
|
||||
fn len(&self) -> usize;
|
||||
fn push(&mut self, term: Term);
|
||||
fn pop(&mut self) -> Option<Term>;
|
||||
fn remove(&mut self, index: usize) -> Term;
|
||||
fn into_vec(self) -> Vec<Term>;
|
||||
fn iter_ref(&self) -> impl Iterator<Item = &Term>;
|
||||
}
|
||||
|
||||
// Implement the main storage types
|
||||
impl ExpressionStorage for Vec<Term> {
|
||||
fn len(&self) -> usize {
|
||||
Vec::len(self)
|
||||
}
|
||||
fn push(&mut self, term: Term) {
|
||||
Vec::push(self, term)
|
||||
}
|
||||
fn pop(&mut self) -> Option<Term> {
|
||||
Vec::pop(self)
|
||||
}
|
||||
fn remove(&mut self, index: usize) -> Term {
|
||||
Vec::remove(self, index)
|
||||
}
|
||||
fn into_vec(self) -> Vec<Term> {
|
||||
self
|
||||
}
|
||||
fn iter_ref(&self) -> impl Iterator<Item = &Term> {
|
||||
self.iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const C: usize> ExpressionStorage for ArrayVec<[Term; C]>
|
||||
impl<T> PartialEq<T> for Expression
|
||||
where
|
||||
[Term; C]: tinyvec::Array<Item = Term>,
|
||||
{
|
||||
fn len(&self) -> usize {
|
||||
ArrayVec::len(self)
|
||||
}
|
||||
fn push(&mut self, term: Term) {
|
||||
ArrayVec::push(self, term)
|
||||
}
|
||||
fn pop(&mut self) -> Option<Term> {
|
||||
ArrayVec::pop(self)
|
||||
}
|
||||
fn remove(&mut self, index: usize) -> Term {
|
||||
ArrayVec::remove(self, index)
|
||||
}
|
||||
fn into_vec(self) -> Vec<Term> {
|
||||
self.to_vec()
|
||||
}
|
||||
fn iter_ref(&self) -> impl Iterator<Item = &Term> {
|
||||
self.iter()
|
||||
}
|
||||
}
|
||||
|
||||
/// A symbolic expression
|
||||
#[derive(Clone, Copy, Hash, Eq, serde::Serialize, serde::Deserialize)]
|
||||
pub struct GenericExpression<S: ExpressionStorage> {
|
||||
pub terms: S, // Terms in postfix notation
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage, T> PartialEq<T> for GenericExpression<S>
|
||||
where
|
||||
for<'a> &'a T: Into<Self>,
|
||||
for<'a> &'a T: Into<Expression>,
|
||||
{
|
||||
fn eq(&self, other: &T) -> bool {
|
||||
self.terms == other.into().terms
|
||||
*self.terms.read() == *other.into().terms.read()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage> Default for GenericExpression<S> {
|
||||
fn default() -> Self {
|
||||
let mut s = S::default();
|
||||
s.push(Term::Num(0));
|
||||
Self { terms: s }
|
||||
impl From<&Expression> for Expression {
|
||||
fn from(value: &Expression) -> Self {
|
||||
*value
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage + Clone> Debug for GenericExpression<S> {
|
||||
impl Eq for Expression {}
|
||||
|
||||
impl Debug for Expression {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let mut symbols = vec![];
|
||||
for term in self.terms.iter_ref() {
|
||||
for term in self.terms.read().iter() {
|
||||
let new_symbol = match term {
|
||||
Term::Num(n) => n.to_string(),
|
||||
Term::Var(c) => c.to_string(),
|
||||
@@ -197,16 +160,16 @@ impl<S: ExpressionStorage + Clone> Debug for GenericExpression<S> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage + Clone> std::fmt::Display for GenericExpression<S> {
|
||||
impl std::fmt::Display for Expression {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{:?}", self)
|
||||
write!(f, "{self:?}")
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage> GenericExpression<S> {
|
||||
impl Expression {
|
||||
/// Simplify the expression to its minimal terms
|
||||
pub fn simplify(self) -> Self {
|
||||
if self.terms.len() == 1 {
|
||||
if self.terms.read().len() == 1 {
|
||||
return self;
|
||||
}
|
||||
egg_simplify(self)
|
||||
@@ -215,17 +178,17 @@ impl<S: ExpressionStorage> GenericExpression<S> {
|
||||
/// Simplify the expression to its minimal terms, using a cache to retrieve / store the simplification
|
||||
pub fn simplify_cache(self, cache: &mut FxHashMap<Self, Self>) -> Self {
|
||||
if let Some(s) = cache.get(&self) {
|
||||
s.clone()
|
||||
*s
|
||||
} else {
|
||||
let simplified = self.clone().simplify();
|
||||
cache.insert(self, simplified.clone());
|
||||
let simplified = self.simplify();
|
||||
cache.insert(self, simplified);
|
||||
simplified
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_num(&self) -> Option<i32> {
|
||||
if let Term::Num(n) = self.terms[0] {
|
||||
if self.terms.len() == 1 {
|
||||
if let Term::Num(n) = self.terms.read()[0] {
|
||||
if self.terms.read().len() == 1 {
|
||||
return Some(n);
|
||||
}
|
||||
}
|
||||
@@ -233,22 +196,23 @@ impl<S: ExpressionStorage> GenericExpression<S> {
|
||||
}
|
||||
|
||||
/// Minimum
|
||||
pub fn min<E: Into<Self>>(self, rhs: E) -> Self {
|
||||
let mut rhs = rhs.into();
|
||||
pub fn min(self, rhs: impl Into<Self>) -> Self {
|
||||
let rhs = rhs.into();
|
||||
if rhs == self || rhs == i32::MAX {
|
||||
return self;
|
||||
}
|
||||
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
|
||||
return a.min(b).into();
|
||||
}
|
||||
rhs.terms.extend(self.terms.iter_ref().copied());
|
||||
rhs.terms.push(Term::Min);
|
||||
rhs
|
||||
let mut terms = rhs.terms.read().clone();
|
||||
terms.extend(self.terms.read().iter().copied());
|
||||
terms.push(Term::Min);
|
||||
Expression::new(terms)
|
||||
}
|
||||
|
||||
/// Maximum
|
||||
pub fn max<E: Into<Self>>(self, rhs: E) -> Self {
|
||||
let mut rhs = rhs.into();
|
||||
pub fn max<E: Into<Expression>>(self, rhs: E) -> Self {
|
||||
let rhs = rhs.into();
|
||||
if rhs == self || rhs == 0 || self == i32::MAX {
|
||||
return self;
|
||||
}
|
||||
@@ -258,14 +222,15 @@ impl<S: ExpressionStorage> GenericExpression<S> {
|
||||
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
|
||||
return a.max(b).into();
|
||||
}
|
||||
rhs.terms.extend(self.terms.iter_ref().copied());
|
||||
rhs.terms.push(Term::Max);
|
||||
rhs
|
||||
let mut terms = rhs.terms.read().clone();
|
||||
terms.extend(self.terms.read().iter().copied());
|
||||
terms.push(Term::Max);
|
||||
Expression::new(terms)
|
||||
}
|
||||
|
||||
/// Greater than or equals
|
||||
pub fn gte<E: Into<Self>>(self, rhs: E) -> Self {
|
||||
let mut rhs = rhs.into();
|
||||
pub fn gte<E: Into<Expression>>(self, rhs: E) -> Self {
|
||||
let rhs = rhs.into();
|
||||
if rhs == self {
|
||||
return true.into();
|
||||
}
|
||||
@@ -275,37 +240,42 @@ impl<S: ExpressionStorage> GenericExpression<S> {
|
||||
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
|
||||
return (a >= b).into();
|
||||
}
|
||||
rhs.terms.extend(self.terms.iter_ref().copied());
|
||||
rhs.terms.push(Term::Gte);
|
||||
rhs
|
||||
let mut terms = rhs.terms.read().clone();
|
||||
terms.extend(self.terms.read().iter().copied());
|
||||
terms.push(Term::Gte);
|
||||
Expression::new(terms)
|
||||
}
|
||||
|
||||
/// Less than
|
||||
pub fn lt<E: Into<Self>>(self, rhs: E) -> Self {
|
||||
let mut rhs = rhs.into();
|
||||
pub fn lt<E: Into<Expression>>(self, rhs: E) -> Self {
|
||||
let rhs = rhs.into();
|
||||
if rhs == self {
|
||||
return false.into();
|
||||
}
|
||||
if let Term::Num(n) = rhs.terms[0] {
|
||||
if self.terms[self.terms.len() - 1] == Term::Mod && self.terms[0] == Term::Num(n) {
|
||||
if let Term::Num(n) = rhs.terms.read()[0] {
|
||||
if self.terms.read()[self.terms.read().len() - 1] == Term::Mod
|
||||
&& self.terms.read()[0] == Term::Num(n)
|
||||
{
|
||||
return true.into();
|
||||
}
|
||||
}
|
||||
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
|
||||
return (a < b).into();
|
||||
}
|
||||
rhs.terms.extend(self.terms.iter_ref().copied());
|
||||
rhs.terms.push(Term::Lt);
|
||||
rhs
|
||||
let mut terms = rhs.terms.read().clone();
|
||||
terms.extend(self.terms.read().iter().copied());
|
||||
terms.push(Term::Lt);
|
||||
Expression::new(terms)
|
||||
}
|
||||
|
||||
/// Substitute an expression for a variable
|
||||
pub fn substitute<N: ExpressionStorage>(self, var: char, expr: GenericExpression<N>) -> Self {
|
||||
let mut new_terms = S::default();
|
||||
for term in self.terms.iter_ref() {
|
||||
pub fn substitute(self, var: char, expr: impl Into<Expression>) -> Self {
|
||||
let mut new_terms = vec![];
|
||||
let t = expr.into().terms.read();
|
||||
for term in self.terms.read().iter() {
|
||||
match term {
|
||||
Term::Var(c) if *c == var => {
|
||||
for t in expr.terms.iter_ref() {
|
||||
for t in t.iter() {
|
||||
new_terms.push(*t);
|
||||
}
|
||||
}
|
||||
@@ -314,11 +284,11 @@ impl<S: ExpressionStorage> GenericExpression<S> {
|
||||
}
|
||||
}
|
||||
}
|
||||
Self { terms: new_terms }
|
||||
Expression::new(new_terms)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage> GenericExpression<S> {
|
||||
impl Expression {
|
||||
/// Evaluate the expression with no variables. Returns Some(value) if no variables are required, otherwise returns None.
|
||||
pub fn to_usize(&self) -> Option<usize> {
|
||||
self.exec(&FxHashMap::default())
|
||||
@@ -330,7 +300,7 @@ impl<S: ExpressionStorage> GenericExpression<S> {
|
||||
}
|
||||
/// Evaluate the expression with one value for all variables. Uses a provided stack
|
||||
pub fn exec_single_var_stack(&self, value: usize, stack: &mut Vec<i64>) -> usize {
|
||||
for term in self.terms.iter_ref() {
|
||||
for term in self.terms.read().iter() {
|
||||
match term {
|
||||
Term::Num(n) => stack.push(*n as i64),
|
||||
Term::Var(_) => stack.push(value as i64),
|
||||
@@ -353,7 +323,7 @@ impl<S: ExpressionStorage> GenericExpression<S> {
|
||||
variables: &FxHashMap<char, usize>,
|
||||
stack: &mut Vec<i64>,
|
||||
) -> Option<usize> {
|
||||
for term in self.terms.iter_ref() {
|
||||
for term in self.terms.read().iter() {
|
||||
match term {
|
||||
Term::Num(n) => stack.push(*n as i64),
|
||||
Term::Var(c) =>
|
||||
@@ -377,7 +347,8 @@ impl<S: ExpressionStorage> GenericExpression<S> {
|
||||
/// Retrieve all symbols in the expression.
|
||||
pub fn to_symbols(&self) -> Vec<char> {
|
||||
self.terms
|
||||
.iter_ref()
|
||||
.read()
|
||||
.iter()
|
||||
.filter_map(|t| match t {
|
||||
Term::Var(c) => Some(*c),
|
||||
_ => None,
|
||||
@@ -387,108 +358,92 @@ impl<S: ExpressionStorage> GenericExpression<S> {
|
||||
|
||||
/// Check if the '-' variable exists in the expression.
|
||||
pub fn is_unknown(&self) -> bool {
|
||||
self.terms.iter_ref().any(|t| matches!(t, Term::Var('-')))
|
||||
self.terms
|
||||
.read()
|
||||
.iter()
|
||||
.any(|t| matches!(t, Term::Var('-')))
|
||||
}
|
||||
}
|
||||
|
||||
impl Expression {
|
||||
pub fn big(&self) -> BigExpression {
|
||||
BigExpression::from(*self)
|
||||
}
|
||||
}
|
||||
|
||||
impl BigExpression {
|
||||
pub fn small(&self) -> Expression {
|
||||
Expression::from(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage> From<Term> for GenericExpression<S> {
|
||||
impl From<Term> for Expression {
|
||||
fn from(value: Term) -> Self {
|
||||
let mut terms = S::default();
|
||||
terms.push(value);
|
||||
GenericExpression { terms }
|
||||
Expression::new(vec![value])
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage> From<char> for GenericExpression<S> {
|
||||
impl From<char> for Expression {
|
||||
fn from(value: char) -> Self {
|
||||
GenericExpression::from(Term::Var(value))
|
||||
Expression::new(vec![Term::Var(value)])
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage> From<&char> for GenericExpression<S> {
|
||||
impl From<&char> for Expression {
|
||||
fn from(value: &char) -> Self {
|
||||
GenericExpression::from(Term::Var(*value))
|
||||
Expression::new(vec![Term::Var(*value)])
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage> From<usize> for GenericExpression<S> {
|
||||
impl From<usize> for Expression {
|
||||
fn from(value: usize) -> Self {
|
||||
GenericExpression::from(Term::Num(value as i32))
|
||||
Expression::new(vec![Term::Num(value as i32)])
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage> From<&usize> for GenericExpression<S> {
|
||||
impl From<&usize> for Expression {
|
||||
fn from(value: &usize) -> Self {
|
||||
GenericExpression::from(Term::Num(*value as i32))
|
||||
Expression::new(vec![Term::Num(*value as i32)])
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage> From<i32> for GenericExpression<S> {
|
||||
impl From<i32> for Expression {
|
||||
fn from(value: i32) -> Self {
|
||||
GenericExpression::from(value as usize)
|
||||
Expression::new(vec![Term::Num(value)])
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage> From<&i32> for GenericExpression<S> {
|
||||
impl From<&i32> for Expression {
|
||||
fn from(value: &i32) -> Self {
|
||||
GenericExpression::from(*value as usize)
|
||||
Expression::new(vec![Term::Num(*value)])
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage> From<bool> for GenericExpression<S> {
|
||||
impl From<bool> for Expression {
|
||||
fn from(value: bool) -> Self {
|
||||
GenericExpression::from(value as usize)
|
||||
Expression::new(vec![Term::Num(value as i32)])
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage> From<&bool> for GenericExpression<S> {
|
||||
impl From<&bool> for Expression {
|
||||
fn from(value: &bool) -> Self {
|
||||
GenericExpression::from(*value as usize)
|
||||
Expression::new(vec![Term::Num(*value as i32)])
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage, T: ExpressionStorage> From<&GenericExpression<T>>
|
||||
for GenericExpression<S>
|
||||
{
|
||||
fn from(value: &GenericExpression<T>) -> Self {
|
||||
let mut s = S::default();
|
||||
s.extend(value.terms.iter_ref().copied());
|
||||
Self { terms: s }
|
||||
impl Sub<Expression> for usize {
|
||||
type Output = Expression;
|
||||
fn sub(self, rhs: Expression) -> Self::Output {
|
||||
Expression::from(self) - rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Expression> for BigExpression {
|
||||
fn from(value: Expression) -> Self {
|
||||
Self {
|
||||
terms: value.terms.to_vec(),
|
||||
}
|
||||
impl Mul<Expression> for usize {
|
||||
type Output = Expression;
|
||||
fn mul(self, rhs: Expression) -> Self::Output {
|
||||
rhs * self
|
||||
}
|
||||
}
|
||||
|
||||
impl From<BigExpression> for Expression {
|
||||
fn from(value: BigExpression) -> Self {
|
||||
let mut terms = ArrayVec::new();
|
||||
terms.extend(value.terms);
|
||||
Self { terms }
|
||||
impl Div<Expression> for usize {
|
||||
type Output = Expression;
|
||||
fn div(self, rhs: Expression) -> Self::Output {
|
||||
Expression::from(self) / rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage, E: Into<Self>> Add<E> for GenericExpression<S> {
|
||||
impl<E: Into<Expression>> Add<E> for Expression {
|
||||
type Output = Self;
|
||||
fn add(self, rhs: E) -> Self::Output {
|
||||
let mut rhs = rhs.into();
|
||||
let rhs = rhs.into();
|
||||
if rhs == 0 {
|
||||
return self;
|
||||
}
|
||||
@@ -501,16 +456,17 @@ impl<S: ExpressionStorage, E: Into<Self>> Add<E> for GenericExpression<S> {
|
||||
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
|
||||
return (a + b).into();
|
||||
}
|
||||
rhs.terms.extend(self.terms.iter_ref().copied());
|
||||
rhs.terms.push(Term::Add);
|
||||
rhs
|
||||
let mut terms = rhs.terms.read().clone();
|
||||
terms.extend(self.terms.read().iter().copied());
|
||||
terms.push(Term::Add);
|
||||
Expression::new(terms)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage, E: Into<Self>> Sub<E> for GenericExpression<S> {
|
||||
impl<E: Into<Expression>> Sub<E> for Expression {
|
||||
type Output = Self;
|
||||
fn sub(self, rhs: E) -> Self::Output {
|
||||
let mut rhs = rhs.into();
|
||||
let rhs = rhs.into();
|
||||
if rhs == 0 {
|
||||
return self;
|
||||
}
|
||||
@@ -520,16 +476,17 @@ impl<S: ExpressionStorage, E: Into<Self>> Sub<E> for GenericExpression<S> {
|
||||
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
|
||||
return (a - b).into();
|
||||
}
|
||||
rhs.terms.extend(self.terms.iter_ref().copied());
|
||||
rhs.terms.push(Term::Sub);
|
||||
rhs
|
||||
let mut terms = rhs.terms.read().clone();
|
||||
terms.extend(self.terms.read().iter().copied());
|
||||
terms.push(Term::Sub);
|
||||
Expression::new(terms)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage, E: Into<Self>> Mul<E> for GenericExpression<S> {
|
||||
impl<E: Into<Expression>> Mul<E> for Expression {
|
||||
type Output = Self;
|
||||
fn mul(self, rhs: E) -> Self::Output {
|
||||
let mut rhs = rhs.into();
|
||||
let rhs = rhs.into();
|
||||
if rhs == 1 {
|
||||
return self;
|
||||
}
|
||||
@@ -544,16 +501,17 @@ impl<S: ExpressionStorage, E: Into<Self>> Mul<E> for GenericExpression<S> {
|
||||
return c.into();
|
||||
}
|
||||
}
|
||||
rhs.terms.extend(self.terms.iter_ref().copied());
|
||||
rhs.terms.push(Term::Mul);
|
||||
rhs
|
||||
let mut terms = rhs.terms.read().clone();
|
||||
terms.extend(self.terms.read().iter().copied());
|
||||
terms.push(Term::Mul);
|
||||
Expression::new(terms)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage, E: Into<Self>> Div<E> for GenericExpression<S> {
|
||||
impl<E: Into<Expression>> Div<E> for Expression {
|
||||
type Output = Self;
|
||||
fn div(self, rhs: E) -> Self::Output {
|
||||
let mut rhs = rhs.into();
|
||||
let rhs = rhs.into();
|
||||
if rhs == 1 {
|
||||
return self;
|
||||
}
|
||||
@@ -566,32 +524,34 @@ impl<S: ExpressionStorage, E: Into<Self>> Div<E> for GenericExpression<S> {
|
||||
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
|
||||
return (a / b).into();
|
||||
}
|
||||
rhs.terms.extend(self.terms.iter_ref().copied());
|
||||
rhs.terms.push(Term::Div);
|
||||
rhs
|
||||
let mut terms = rhs.terms.read().clone();
|
||||
terms.extend(self.terms.read().iter().copied());
|
||||
terms.push(Term::Div);
|
||||
Expression::new(terms)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage, E: Into<Self>> Rem<E> for GenericExpression<S> {
|
||||
impl<E: Into<Expression>> Rem<E> for Expression {
|
||||
type Output = Self;
|
||||
fn rem(self, rhs: E) -> Self::Output {
|
||||
let mut rhs = rhs.into();
|
||||
let rhs = rhs.into();
|
||||
if rhs == 1 || rhs == self {
|
||||
return 0.into();
|
||||
}
|
||||
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
|
||||
return (a % b).into();
|
||||
}
|
||||
rhs.terms.extend(self.terms.iter_ref().copied());
|
||||
rhs.terms.push(Term::Mod);
|
||||
rhs
|
||||
let mut terms = rhs.terms.read().clone();
|
||||
terms.extend(self.terms.read().iter().copied());
|
||||
terms.push(Term::Mod);
|
||||
Expression::new(terms)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage, E: Into<Self>> BitAnd<E> for GenericExpression<S> {
|
||||
impl<E: Into<Expression>> BitAnd<E> for Expression {
|
||||
type Output = Self;
|
||||
fn bitand(self, rhs: E) -> Self::Output {
|
||||
let mut rhs = rhs.into();
|
||||
let rhs = rhs.into();
|
||||
if rhs == 0 || self == 0 {
|
||||
return 0.into();
|
||||
}
|
||||
@@ -604,30 +564,32 @@ impl<S: ExpressionStorage, E: Into<Self>> BitAnd<E> for GenericExpression<S> {
|
||||
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
|
||||
return (a != 0 && b != 0).into();
|
||||
}
|
||||
rhs.terms.extend(self.terms.iter_ref().copied());
|
||||
rhs.terms.push(Term::And);
|
||||
rhs
|
||||
let mut terms = rhs.terms.read().clone();
|
||||
terms.extend(self.terms.read().iter().copied());
|
||||
terms.push(Term::And);
|
||||
Expression::new(terms)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage, E: Into<Self>> BitOr<E> for GenericExpression<S> {
|
||||
impl<E: Into<Expression>> BitOr<E> for Expression {
|
||||
type Output = Self;
|
||||
fn bitor(self, rhs: E) -> Self::Output {
|
||||
let mut rhs = rhs.into();
|
||||
let rhs = rhs.into();
|
||||
if rhs == 1 || self == 1 {
|
||||
return 1.into();
|
||||
}
|
||||
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
|
||||
return (a != 0 || b != 0).into();
|
||||
}
|
||||
rhs.terms.extend(self.terms.iter_ref().copied());
|
||||
rhs.terms.push(Term::Or);
|
||||
rhs
|
||||
let mut terms = rhs.terms.read().clone();
|
||||
terms.extend(self.terms.read().iter().copied());
|
||||
terms.push(Term::Or);
|
||||
Expression::new(terms)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage> std::iter::Product for GenericExpression<S> {
|
||||
fn product<I: Iterator<Item = GenericExpression<S>>>(mut iter: I) -> Self {
|
||||
impl std::iter::Product for Expression {
|
||||
fn product<I: Iterator<Item = Expression>>(mut iter: I) -> Self {
|
||||
let Some(mut p) = iter.next() else {
|
||||
return 0.into();
|
||||
};
|
||||
@@ -638,45 +600,45 @@ impl<S: ExpressionStorage> std::iter::Product for GenericExpression<S> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage, E: Into<Self>> AddAssign<E> for GenericExpression<S> {
|
||||
impl<E: Into<Expression>> AddAssign<E> for Expression {
|
||||
fn add_assign(&mut self, rhs: E) {
|
||||
*self = self.clone() + rhs;
|
||||
*self = *self + rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage, E: Into<Self>> SubAssign<E> for GenericExpression<S> {
|
||||
impl<E: Into<Expression>> SubAssign<E> for Expression {
|
||||
fn sub_assign(&mut self, rhs: E) {
|
||||
*self = self.clone() - rhs;
|
||||
*self = *self - rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage, E: Into<Self>> MulAssign<E> for GenericExpression<S> {
|
||||
impl<E: Into<Expression>> MulAssign<E> for Expression {
|
||||
fn mul_assign(&mut self, rhs: E) {
|
||||
*self = self.clone() * rhs;
|
||||
*self = *self * rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage, E: Into<Self>> DivAssign<E> for GenericExpression<S> {
|
||||
impl<E: Into<Expression>> DivAssign<E> for Expression {
|
||||
fn div_assign(&mut self, rhs: E) {
|
||||
*self = self.clone() / rhs;
|
||||
*self = *self / rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage, E: Into<Self>> RemAssign<E> for GenericExpression<S> {
|
||||
impl<E: Into<Expression>> RemAssign<E> for Expression {
|
||||
fn rem_assign(&mut self, rhs: E) {
|
||||
*self = self.clone() % rhs;
|
||||
*self = *self % rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage, E: Into<Self>> BitAndAssign<E> for GenericExpression<S> {
|
||||
impl<E: Into<Expression>> BitAndAssign<E> for Expression {
|
||||
fn bitand_assign(&mut self, rhs: E) {
|
||||
*self = self.clone() & rhs;
|
||||
*self = *self & rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: ExpressionStorage, E: Into<Self>> BitOrAssign<E> for GenericExpression<S> {
|
||||
impl<E: Into<Expression>> BitOrAssign<E> for Expression {
|
||||
fn bitor_assign(&mut self, rhs: E) {
|
||||
*self = self.clone() | rhs;
|
||||
*self = *self | rhs;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -689,10 +651,10 @@ define_language! {
|
||||
}
|
||||
}
|
||||
|
||||
fn luminal_to_egg<S: ExpressionStorage>(expr: &GenericExpression<S>) -> RecExpr<Math> {
|
||||
fn luminal_to_egg(expr: &Expression) -> RecExpr<Math> {
|
||||
let mut stack = Vec::new();
|
||||
|
||||
for term in expr.terms.iter_ref() {
|
||||
for term in expr.terms.read().iter() {
|
||||
match term {
|
||||
Term::Num(_) | Term::Var(_) => {
|
||||
stack.push(symbolic_expressions::Sexp::String(format!("{term:?}")))
|
||||
@@ -741,7 +703,7 @@ fn luminal_to_egg<S: ExpressionStorage>(expr: &GenericExpression<S>) -> RecExpr<
|
||||
expr
|
||||
}
|
||||
|
||||
fn egg_to_luminal<S: ExpressionStorage>(expr: RecExpr<Math>) -> GenericExpression<S> {
|
||||
fn egg_to_luminal(expr: RecExpr<Math>) -> Expression {
|
||||
fn create_postfix(expr: &[Math]) -> Vec<Term> {
|
||||
match expr.last().unwrap() {
|
||||
Math::Num(i) => vec![Term::Num(*i)],
|
||||
@@ -814,9 +776,9 @@ fn egg_to_luminal<S: ExpressionStorage>(expr: RecExpr<Math>) -> GenericExpressio
|
||||
Math::Symbol(s) => vec![Term::Var(s.as_str().chars().next().unwrap())],
|
||||
}
|
||||
}
|
||||
let mut terms = S::default();
|
||||
let mut terms = vec![];
|
||||
terms.extend(create_postfix(expr.as_ref()));
|
||||
GenericExpression { terms }
|
||||
Expression::new(terms)
|
||||
}
|
||||
|
||||
type EGraph = egg::EGraph<Math, ConstantFold>;
|
||||
@@ -962,9 +924,9 @@ fn make_rules() -> Vec<Rewrite> {
|
||||
]
|
||||
}
|
||||
|
||||
fn egg_simplify<S: ExpressionStorage>(expr: GenericExpression<S>) -> GenericExpression<S> {
|
||||
fn egg_simplify(e: Expression) -> Expression {
|
||||
// Convert to egg expression
|
||||
let expr = luminal_to_egg(&expr);
|
||||
let expr = luminal_to_egg(&e);
|
||||
// Simplify
|
||||
let runner = Runner::default()
|
||||
.with_expr(&expr)
|
||||
@@ -978,10 +940,11 @@ fn egg_simplify<S: ExpressionStorage>(expr: GenericExpression<S>) -> GenericExpr
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use crate::prelude::*;
|
||||
#[test]
|
||||
fn test_expressions() {
|
||||
let n = Expression::from('x') + (Expression::from(256) - (Expression::from('x') % 256));
|
||||
let n = Expression::from('x') + (256 - (Expression::from('x') % 256));
|
||||
assert_eq!(
|
||||
n.simplify()
|
||||
.exec(&[('x', 767)].into_iter().collect())
|
||||
@@ -992,7 +955,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_minimizations() {
|
||||
let expr = ((BigExpression::from('a') * 1) + 0) / 1 + (1 - 1);
|
||||
let expr = ((Expression::from('a') * 1) + 0) / 1 + (1 - 1);
|
||||
let reduced_expr = expr.simplify();
|
||||
assert_eq!(reduced_expr, 'a');
|
||||
}
|
||||
@@ -1007,9 +970,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_group_terms() {
|
||||
let s = BigExpression::from('s');
|
||||
let expr = (s.clone() * ((s.clone() - 4) + 1))
|
||||
+ (((s.clone() + 1) * ((s.clone() - 4) + 1)) - (s.clone() * ((s.clone() - 4) + 1)));
|
||||
assert_eq!(expr.simplify().terms.len(), 7);
|
||||
let s = Expression::from('s');
|
||||
let expr = (s * ((s - 4) + 1)) + (((s + 1) * ((s - 4) + 1)) - (s * ((s - 4) + 1)));
|
||||
assert_eq!(expr.simplify().terms.read().len(), 7);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,9 +3,7 @@ use tinyvec::ArrayVec;
|
||||
|
||||
use crate::prelude::*;
|
||||
|
||||
#[derive(
|
||||
Debug, Clone, Copy, Default, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize,
|
||||
)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct ShapeTracker {
|
||||
pub dims: ArrayVec<[Expression; 6]>,
|
||||
pub indexes: ArrayVec<[usize; 6]>,
|
||||
@@ -15,7 +13,8 @@ pub struct ShapeTracker {
|
||||
}
|
||||
|
||||
impl ShapeTracker {
|
||||
pub fn new(dims: &[impl Into<Expression> + Copy]) -> Self {
|
||||
#[allow(clippy::not_unsafe_ptr_arg_deref)]
|
||||
pub fn new(dims: impl ToShape) -> Self {
|
||||
let mut s = Self {
|
||||
dims: Default::default(),
|
||||
indexes: Default::default(),
|
||||
@@ -23,8 +22,8 @@ impl ShapeTracker {
|
||||
mask: Default::default(),
|
||||
padding: Default::default(),
|
||||
};
|
||||
for (i, d) in dims.iter().enumerate() {
|
||||
s.dims.push((*d).into());
|
||||
for (i, d) in dims.to_shape().into_iter().enumerate() {
|
||||
s.dims.push(d);
|
||||
s.indexes.push(i);
|
||||
s.fake.push(false);
|
||||
s.mask.push((0.into(), i32::MAX.into())); // Unset upper bound mask are i32::MAX
|
||||
@@ -34,7 +33,7 @@ impl ShapeTracker {
|
||||
}
|
||||
|
||||
/// Create a shape tracker where all dims are fake
|
||||
pub fn fake(dims: &[Expression]) -> Self {
|
||||
pub fn fake(dims: impl ToShape) -> Self {
|
||||
let mut s = Self::new(dims);
|
||||
s.fake.iter_mut().for_each(|i| *i = true);
|
||||
s
|
||||
@@ -82,13 +81,13 @@ impl ShapeTracker {
|
||||
}
|
||||
|
||||
/// Strides without permute applied
|
||||
fn unordered_strides(&self) -> Vec<BigExpression> {
|
||||
fn unordered_strides(&self) -> Vec<Expression> {
|
||||
let mut strides = (0..self.len())
|
||||
.rev()
|
||||
.scan(BigExpression::from(1), |state, i| {
|
||||
let ret = state.clone();
|
||||
.scan(Expression::from(1), |state, i| {
|
||||
let ret = *state;
|
||||
if !self.fake[i] {
|
||||
*state = state.clone() * self.dims[i];
|
||||
*state *= self.dims[i];
|
||||
}
|
||||
Some(ret)
|
||||
})
|
||||
@@ -98,22 +97,19 @@ impl ShapeTracker {
|
||||
}
|
||||
|
||||
/// Compute strides
|
||||
pub fn strides(&self) -> Vec<BigExpression> {
|
||||
pub fn strides(&self) -> Vec<Expression> {
|
||||
let strides = self.unordered_strides();
|
||||
self.indexes
|
||||
.into_iter()
|
||||
.map(|i| strides[i].clone())
|
||||
.collect()
|
||||
self.indexes.into_iter().map(|i| strides[i]).collect()
|
||||
}
|
||||
|
||||
/// Create an expression to translate logical indexes into physical indexes, without expression simplification
|
||||
pub fn index_expression_no_simplify(&self) -> BigExpression {
|
||||
pub fn index_expression_no_simplify(&self) -> Expression {
|
||||
if !self.is_reshaped() {
|
||||
return 'z'.into();
|
||||
}
|
||||
let strides = self.unordered_strides(); // Dimension strides in original order
|
||||
let mut ind_expr = BigExpression::from(0); // The final index expression
|
||||
let mut current_elem_size = BigExpression::from(1); // Keep track of the size of each element of the current dim (last dim elem size: 1)
|
||||
let mut ind_expr = 0.into(); // The final index expression
|
||||
let mut current_elem_size = Expression::from(1); // Keep track of the size of each element of the current dim (last dim elem size: 1)
|
||||
|
||||
// Loop through all dims in reverse order
|
||||
for i in self.indexes.into_iter().rev() {
|
||||
@@ -121,15 +117,15 @@ impl ShapeTracker {
|
||||
let current_size = pad_mask_dim(self.dims[i], self.padding[i], self.mask[i]);
|
||||
// Don't include fake dimensions in the index expression
|
||||
if !self.fake[i] {
|
||||
let mut dim_ind = BigExpression::from('z');
|
||||
let mut dim_ind = Expression::from('z');
|
||||
// Remove other dim components
|
||||
dim_ind /= current_elem_size.clone();
|
||||
dim_ind /= current_elem_size;
|
||||
// Get position in current dim
|
||||
dim_ind %= current_size.clone();
|
||||
dim_ind %= current_size;
|
||||
// Add offset
|
||||
dim_ind += self.mask[i].0 - self.padding[i].0;
|
||||
// Multiply by stride
|
||||
dim_ind *= strides[i].clone();
|
||||
dim_ind *= strides[i];
|
||||
// Add to index expression
|
||||
ind_expr += dim_ind;
|
||||
}
|
||||
@@ -140,28 +136,28 @@ impl ShapeTracker {
|
||||
}
|
||||
|
||||
/// Create an expression to translate logical indexes into physical indexes
|
||||
pub fn index_expression(&self) -> BigExpression {
|
||||
pub fn index_expression(&self) -> Expression {
|
||||
self.index_expression_no_simplify().simplify()
|
||||
}
|
||||
|
||||
/// If this expression evaluates to 0, the logical index is invalid. Otherwise it is valid. No simplification
|
||||
pub fn valid_expression_no_simplify(&self) -> BigExpression {
|
||||
pub fn valid_expression_no_simplify(&self) -> Expression {
|
||||
if !self.is_reshaped() {
|
||||
return true.into();
|
||||
}
|
||||
let mut ret = BigExpression::from(1);
|
||||
let mut acc = BigExpression::from(1);
|
||||
let logical = BigExpression::from('z');
|
||||
let mut ret = Expression::from(1);
|
||||
let mut acc = Expression::from(1);
|
||||
let logical = Expression::from('z');
|
||||
for i in self.indexes.into_iter().rev() {
|
||||
let (bottom_slice, top_slice) = self.mask[i];
|
||||
let logical_sh = pad_mask_dim(self.dims[i], self.padding[i], self.mask[i]);
|
||||
if !self.fake[i] {
|
||||
let dim_ind = (logical.clone() / acc.clone()) % logical_sh.clone();
|
||||
let greater_than = self.padding[i].0.big() - bottom_slice;
|
||||
let dim_ind = (logical / acc) % logical_sh;
|
||||
let greater_than = self.padding[i].0 - bottom_slice;
|
||||
if greater_than != 0 {
|
||||
ret &= dim_ind.clone().gte(greater_than);
|
||||
ret &= dim_ind.gte(greater_than);
|
||||
}
|
||||
ret &= dim_ind.lt(self.dims[i].big() + self.padding[i].0);
|
||||
ret &= dim_ind.lt(self.dims[i] + self.padding[i].0);
|
||||
if top_slice
|
||||
.to_usize()
|
||||
.map(|s| self.dims[i].to_usize().map(|dim| s < dim).unwrap_or(true))
|
||||
@@ -176,22 +172,22 @@ impl ShapeTracker {
|
||||
}
|
||||
|
||||
/// If this expression evaluates to 0, the logical index is invalid. Otherwise it is valid
|
||||
pub fn valid_expression(&self) -> BigExpression {
|
||||
pub fn valid_expression(&self) -> Expression {
|
||||
self.valid_expression_no_simplify().simplify()
|
||||
}
|
||||
|
||||
/// The number of elements in this tensor, including padding and mask
|
||||
pub fn n_elements(&self) -> BigExpression {
|
||||
self.shape().into_iter().product::<BigExpression>().max(1)
|
||||
pub fn n_elements(&self) -> Expression {
|
||||
self.dims().into_iter().product::<Expression>().max(1)
|
||||
}
|
||||
|
||||
/// The number of elements in this tensor, not including pads and mask
|
||||
pub fn n_physical_elements(&self) -> BigExpression {
|
||||
pub fn n_physical_elements(&self) -> Expression {
|
||||
self.indexes
|
||||
.into_iter()
|
||||
.filter(|i| !self.fake[*i])
|
||||
.map(|i| self.dims[i].big())
|
||||
.product::<BigExpression>()
|
||||
.map(|i| self.dims[i])
|
||||
.product::<Expression>()
|
||||
.max(1)
|
||||
}
|
||||
|
||||
@@ -222,10 +218,9 @@ impl ShapeTracker {
|
||||
/// Create a contiguous version
|
||||
pub fn contiguous(self) -> Self {
|
||||
Self::new(
|
||||
&self
|
||||
.shape()
|
||||
self.dims()
|
||||
.into_iter()
|
||||
.map(|i| i.simplify().small())
|
||||
.map(|i| i.simplify())
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
}
|
||||
@@ -241,7 +236,7 @@ impl ShapeTracker {
|
||||
}
|
||||
|
||||
/// Realize the true shape
|
||||
pub fn shape(&self) -> Vec<BigExpression> {
|
||||
pub fn dims(&self) -> Vec<Expression> {
|
||||
self.indexes
|
||||
.into_iter()
|
||||
.map(|i| pad_mask_dim(self.dims[i], self.padding[i], self.mask[i]))
|
||||
@@ -250,7 +245,7 @@ impl ShapeTracker {
|
||||
|
||||
/// Realize the true shape and convert it to usizes. All dyn dims must be replaced already
|
||||
pub fn shape_usize(&self) -> Vec<usize> {
|
||||
self.shape().iter().map(|e| e.to_usize().unwrap()).collect()
|
||||
self.dims().iter().map(|e| e.to_usize().unwrap()).collect()
|
||||
}
|
||||
|
||||
/// Take a slice
|
||||
@@ -284,8 +279,9 @@ impl ShapeTracker {
|
||||
{
|
||||
panic!("Adding padding to a masked shape isn't supported")
|
||||
}
|
||||
self.padding[ind].0 += s.max(0);
|
||||
self.padding[ind].1 += e.max(0);
|
||||
let (s, e) = (s.max(0), e.max(0));
|
||||
self.padding[ind].0 += s;
|
||||
self.padding[ind].1 += e;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -329,11 +325,11 @@ impl ShapeTracker {
|
||||
}
|
||||
|
||||
fn pad_mask_dim(
|
||||
dim: impl Into<BigExpression>,
|
||||
dim: impl Into<Expression>,
|
||||
padding: (Expression, Expression),
|
||||
mask: (Expression, Expression),
|
||||
) -> BigExpression {
|
||||
(dim.into() + padding.0 + padding.1).min(mask.1) - mask.0
|
||||
) -> Expression {
|
||||
(padding.0 + padding.1 + dim).min(mask.1) - mask.0
|
||||
}
|
||||
|
||||
/// Resolve shapes between the two trackers to the best of our ability
|
||||
@@ -364,7 +360,7 @@ mod tests {
|
||||
use crate::prelude::*;
|
||||
#[test]
|
||||
fn test_idx_expr() {
|
||||
let mut tracker = ShapeTracker::new(&[
|
||||
let mut tracker = ShapeTracker::new([
|
||||
Expression::from(10),
|
||||
Expression::from(5),
|
||||
Expression::from(3),
|
||||
|
||||
@@ -50,7 +50,7 @@ fn test_expand() {
|
||||
fn test_slice() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((2, 3)).set([[1., 2., 3.], [1., 2., 3.]]);
|
||||
let b = a.slice((Expression::from(1).., ..)).retrieve();
|
||||
let b = a.slice((1.., ..)).retrieve();
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
@@ -323,3 +323,18 @@ fn test_max_reduce() {
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
assert_close(&d.data(), &d_d.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dot() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor(5).set([34.4, -96.0, 144.0, 43.0, 560.0]);
|
||||
let b = cx.tensor(5).set([43.0, 560.0, 180.0, 700.0, 225.0]);
|
||||
let c = a.dot(b).retrieve();
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let d_a = d_dev.tensor([34.4, -96.0, 144.0, 43.0, 560.0]);
|
||||
let d_b = d_dev.tensor([43.0, 560.0, 180.0, 700.0, 225.0]);
|
||||
let d_c = (d_a * d_b).sum();
|
||||
assert_close(&c.data(), &d_c.as_vec());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user