Switched to runtime shapes

This commit is contained in:
Joe Fioti
2024-07-08 12:54:42 -04:00
parent ddc201e1c1
commit 8d7b8c8972
1132 changed files with 5871 additions and 4331 deletions

View File

@@ -145,8 +145,8 @@ mod tests {
#[test]
fn test_matmul() {
let mut cx = Graph::new();
let a = cx.tensor::<(Dyn<'M'>, Dyn<'K'>)>();
let b = cx.tensor::<(Dyn<'K'>, Dyn<'N'>)>();
let a = cx.tensor(('M', 'K'));
let b = cx.tensor(('K', 'N'));
let mut c = a.matmul(b).retrieve();
cx.compile(CPUCompiler::default(), &mut c);
@@ -158,8 +158,8 @@ mod tests {
let mut rng = StdRng::seed_from_u64(0);
let a_data = random_vec_rng(m * k, &mut rng);
let b_data = random_vec_rng(k * n, &mut rng);
a.set_dyn(a_data.clone(), &[m, k]);
b.set_dyn(b_data.clone(), &[k, n]);
a.set_dyn(a_data.clone(), (m, k));
b.set_dyn(b_data.clone(), (k, n));
cx.execute();
@@ -177,9 +177,9 @@ mod tests {
#[test]
fn test_cpu_matmul_2d_2() {
let mut cx = Graph::new();
let a = cx.tensor::<R2<2, 3>>();
let a = cx.tensor((2, 3));
a.set(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
let b = cx.tensor::<R2<3, 4>>();
let b = cx.tensor((3, 4));
b.set(vec![1., 2., 3., 1., 2., 3., 1., 2., 3., 1., 2., 3.]);
let mut c = a.matmul(b).retrieve();

View File

@@ -530,11 +530,11 @@ mod tests {
fn test_subtraction() {
let mut cx = Graph::new();
let a = cx
.tensor::<R1<10>>()
.tensor(10)
.set(vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]);
let b = cx.tensor::<R0>().set(vec![1.]);
let mut c = (a - b.expand()).retrieve();
let mut d = (-a + b.expand()).retrieve();
let b = cx.tensor(()).set(vec![1.]);
let mut c = (a - b.expand_to(a.shape)).retrieve();
let mut d = (-a + b.expand_to(a.shape)).retrieve();
cx.execute();

View File

@@ -306,9 +306,9 @@ fn test_common_buffer() {
use crate::MetalCompiler;
let mut cx = Graph::new();
let a = cx.tensor::<R1<5>>().set(random_vec(5)).keep();
let b = cx.tensor::<R1<5>>().set(random_vec(5)).keep();
let c = cx.tensor::<R1<5>>().set(random_vec(5)).keep();
let a = cx.tensor(5).set(random_vec(5)).keep();
let b = cx.tensor(5).set(random_vec(5)).keep();
let c = cx.tensor(5).set(random_vec(5)).keep();
let mut d = ((a + b) * c).retrieve();
cx.execute();

View File

@@ -118,7 +118,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
let mut subexpressions_b = graph
.try_get_op::<FusedElementwiseOp<T>>(b)
.map(|o| o.subexpressions.clone())
.unwrap_or_else(|| vec![(expression_b.clone(), ShapeTracker::new(&[]))]);
.unwrap_or_else(|| vec![(expression_b.clone(), ShapeTracker::default())]);
let a_to_b_indexes = graph
.edges_connecting(a, b)
.map(|e| e.weight().as_data().unwrap().0 as usize)
@@ -138,7 +138,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
let mut subexpressions_a = graph
.try_get_op::<FusedElementwiseOp<T>>(a)
.map(|o| o.subexpressions.clone())
.unwrap_or_else(|| vec![(expression_a.clone(), ShapeTracker::new(&[]))]);
.unwrap_or_else(|| vec![(expression_a.clone(), ShapeTracker::default())]);
subexpressions_a.last_mut().unwrap().1 = connecting_shape;
// Re-reference b intermediates
for i in (0..subexpressions_b.len()).rev() {
@@ -296,7 +296,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
kernel: None,
dyn_map: &graph.dyn_map,
dyn_chars: vec![],
subexpressions: vec![(op_string, ShapeTracker::new(&[]))],
subexpressions: vec![(op_string, ShapeTracker::default())],
queue: queue.clone(),
device: device.clone(),
output_buffer_sizes,
@@ -593,9 +593,8 @@ mod tests {
prelude::{binary::F32Pow, *},
tests::{assert_close, assert_close_precision, random_vec, random_vec_rng},
};
use luminal_nn::*;
use luminal_nn::{LayerNorm, Linear};
use rand::{rngs::StdRng, SeedableRng};
use std::{marker::PhantomData, ops::Div};
use crate::MetalCompiler;
@@ -603,7 +602,7 @@ mod tests {
fn test_fusion_simple() {
let mut cx = Graph::new();
let mut rng = StdRng::seed_from_u64(0);
let inp = cx.tensor::<R1<5>>().set(random_vec_rng(10, &mut rng));
let inp = cx.tensor(5).set(random_vec_rng(10, &mut rng));
let mut out = inp.exp2().cos().sqrt().retrieve();
cx.execute();
@@ -619,8 +618,8 @@ mod tests {
fn test_fusion_binary() {
let mut cx = Graph::new();
let mut rng = StdRng::seed_from_u64(0);
let a = cx.tensor::<R1<5>>().set(random_vec_rng(10, &mut rng));
let b = cx.tensor::<R1<5>>().set(random_vec_rng(10, &mut rng));
let a = cx.tensor(5).set(random_vec_rng(10, &mut rng));
let b = cx.tensor(5).set(random_vec_rng(10, &mut rng));
let mut out = (a.exp2() + b.cos()).retrieve();
cx.execute();
@@ -636,9 +635,9 @@ mod tests {
#[test]
fn test_fusion_subexpression_complex() {
let mut cx = Graph::new();
let a = cx.named_tensor::<R1<10>>("a").set(random_vec(10)).keep();
let b = cx.named_tensor::<R1<10>>("b").set(random_vec(10)).keep();
let d = cx.named_tensor::<R1<10>>("d").set(random_vec(10)).keep();
let a = cx.named_tensor("a", 10).set(random_vec(10)).keep();
let b = cx.named_tensor("b", 10).set(random_vec(10)).keep();
let d = cx.named_tensor("d", 10).set(random_vec(10)).keep();
let mut out = ((a.exp2() - b.sin()).sin() * 3.4).less_than(d).retrieve();
cx.execute();
@@ -656,12 +655,11 @@ mod tests {
let mut cx = Graph::new();
let mut rng = StdRng::seed_from_u64(0);
let inp = random_vec_rng(10, &mut rng);
let a = cx.named_tensor::<R2<2, 5>>("a").set(inp);
let a = cx.named_tensor("a", (2, 5)).set(inp);
let mut padded = a
.slice((..Expression::from(1), ..))
.realize::<R2<1, 5>>()
.cos()
.pad::<R2<2, 5>>(((0, 1), (0, 0)))
.pad(((0, 1), (0, 0)))
.exp2()
.retrieve();
cx.execute();
@@ -682,7 +680,7 @@ mod tests {
let mut cx = Graph::new();
let mut rng = StdRng::seed_from_u64(0);
let data = random_vec_rng(10, &mut rng);
let a = cx.tensor::<R2<2, 5>>().set(data);
let a = cx.tensor((2, 5)).set(data);
let mut out = (a.sqrt().exp() + a.sqrt().sin()).retrieve();
cx.execute();
let unopt_out = out.data();
@@ -699,14 +697,10 @@ mod tests {
let mut cx = Graph::new();
const SEQ: usize = 2;
const HEAD_DIM: usize = 4;
const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2;
let freqs = (cx.arange::<Const<HEAD_DIM_OVER_2>>() * 2.0) / (HEAD_DIM as f32);
let freqs = (cx.arange(HEAD_DIM / 2) * 2.0) / (HEAD_DIM as f32);
let freqs = 1000000_f32.pow(freqs);
let pos = cx.arange::<Const<SEQ>>() + BigExpression::from(0);
let mut emb = pos
.expand::<(_, Const<1>), _>()
.matmul(freqs.expand())
.retrieve();
let pos = cx.arange(SEQ) + BigExpression::from(0);
let mut emb = pos.expand(1, 1).matmul(freqs.expand(0, SEQ)).retrieve();
cx.execute();
let unopt_out = emb.data();
@@ -725,27 +719,25 @@ mod tests {
const HEAD_DIM: usize = 4;
const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2;
let a = cx
.named_tensor::<R2<SEQ, HEAD_DIM>>("a")
.named_tensor("a", (SEQ, HEAD_DIM))
.set(random_vec_rng(SEQ * HEAD_DIM, &mut rng))
.keep();
let b = cx
.tensor::<R3<SEQ, HEAD_DIM_OVER_2, 1>>()
.tensor((SEQ, HEAD_DIM / 2, 1))
.set(random_vec_rng(SEQ * HEAD_DIM_OVER_2, &mut rng))
.keep();
// Split input into evens and odds
let split = a.reshape::<R3<SEQ, HEAD_DIM_OVER_2, 2>>();
let x0: GraphTensor<R3<SEQ, HEAD_DIM_OVER_2, 1>> =
split.slice((.., .., ..Expression::from(1))).realize();
let x1: GraphTensor<R3<SEQ, HEAD_DIM_OVER_2, 1>> =
split.slice((.., .., Expression::from(1)..)).realize();
let split = a.reshape((SEQ, HEAD_DIM / 2, 2));
let x0 = split.slice((.., .., ..1));
let x1 = split.slice((.., .., 1..));
let x0_out = x0 * b - x1 * b.cos();
let x1_out = x0 + x1;
// Combine back into output
let mut out: GraphTensor<R2<SEQ, HEAD_DIM>> = x0_out
.concat_along::<R3<SEQ, HEAD_DIM_OVER_2, 2>, Axis<2>, _>(x1_out)
.reshape()
let mut out = x0_out
.concat_along(x1_out, 2)
.reshape((SEQ, HEAD_DIM))
.retrieve();
cx.execute();
@@ -765,34 +757,31 @@ mod tests {
const N_HEADS: usize = 8;
const SEQ: usize = 2;
const HEAD_DIM: usize = 4;
const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2;
let a = cx
.named_tensor::<R4<BATCH, N_HEADS, SEQ, HEAD_DIM>>("a")
.named_tensor("a", (BATCH, N_HEADS, SEQ, HEAD_DIM))
.set(random_vec_rng(BATCH * N_HEADS * SEQ * HEAD_DIM, &mut rng))
.keep();
let freqs = (cx.arange::<Const<HEAD_DIM_OVER_2>>() * 2.0) / (HEAD_DIM as f32);
let freqs = (cx.arange(HEAD_DIM / 2) * 2.0) / (HEAD_DIM as f32);
let freqs = 1000000_f32.pow(freqs);
let pos = cx.arange::<Const<SEQ>>() + BigExpression::from(0);
let emb = pos.expand::<(_, Const<1>), _>().matmul(freqs.expand());
let pos = cx.arange(SEQ) + BigExpression::from(0);
let emb = pos.expand(1, 1).matmul(freqs.expand(0, SEQ));
// Split input into evens and odds
let split = a.reshape::<R5<BATCH, N_HEADS, SEQ, HEAD_DIM_OVER_2, 2>>();
let x0: GraphTensor<R5<BATCH, N_HEADS, SEQ, HEAD_DIM_OVER_2, 1>> = split
let split = a.reshape((BATCH, N_HEADS, SEQ, HEAD_DIM / 2, 2));
let x0 = split
.slice((.., .., .., .., ..Expression::from(1)))
.contiguous()
.realize();
let x1: GraphTensor<R5<BATCH, N_HEADS, SEQ, HEAD_DIM_OVER_2, 1>> = split
.contiguous();
let x1 = split
.slice((.., .., .., .., Expression::from(1)..))
.contiguous()
.realize();
.contiguous();
// Apply sin and cos embeddings
let x0_out = x0 * emb.cos().expand() - x1 * emb.sin().expand();
let x1_out = x0 * emb.sin().expand() + x1 * emb.cos().expand();
let x0_out = x0 * emb.cos().expand_to(x0.shape) - x1 * emb.sin().expand_to(x1.shape);
let x1_out = x0 * emb.sin().expand_to(x0.shape) + x1 * emb.cos().expand_to(x1.shape);
// Combine back into output
let mut out: GraphTensor<R4<BATCH, N_HEADS, SEQ, HEAD_DIM>> = x0_out
.concat_along::<R5<BATCH, N_HEADS, SEQ, HEAD_DIM_OVER_2, 2>, Axis<4>, _>(x1_out)
.reshape()
let mut out = x0_out
.concat_along(x1_out, 4)
.reshape((BATCH, N_HEADS, SEQ, HEAD_DIM))
.retrieve();
cx.execute();
let unopt_out = out.data();
@@ -814,176 +803,167 @@ mod tests {
pub const SEQ_LEN: usize = 65;
pub const N_ATTENTION_GROUPS: usize = N_HEADS / N_KV_HEADS;
pub const HEAD_DIM: usize = HIDDEN_DIM / N_HEADS;
pub const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2;
pub const ATTN_PROJ_DIM: usize = HEAD_DIM * N_KV_HEADS;
pub struct Mlp<const I: usize, const H: usize> {
pub gate_proj: PermutedLinear<H, I>,
pub down_proj: PermutedLinear<I, H>,
pub up_proj: PermutedLinear<H, I>,
pub type KVCache = (GraphTensor, GraphTensor);
pub struct Mlp {
pub gate_proj: Linear, // hidden -> intermediate
pub down_proj: Linear, // intermediate -> hidden
pub up_proj: Linear, // hidden -> intermediate
}
pub type KVCache<Batch, Seq> = (
GraphTensor<(Batch, Const<N_KV_HEADS>, Seq, Const<HEAD_DIM>)>,
GraphTensor<(Batch, Const<N_KV_HEADS>, Seq, Const<HEAD_DIM>)>,
);
impl Module<GraphTensor> for Mlp {
type Output = GraphTensor;
impl<Sh: Shape, Im: Shape, const I: usize, const H: usize> Module<GraphTensor<Sh>> for Mlp<I, H>
where
GraphTensor<Sh>: Matmul<R2<H, I>, Output = GraphTensor<Im>>,
GraphTensor<Im>: Matmul<R2<I, H>, Output = GraphTensor<Sh>>,
{
type Output = GraphTensor<Sh>;
fn forward(&self, input: GraphTensor<Sh>) -> Self::Output {
fn forward(&self, input: GraphTensor) -> Self::Output {
let gate = self.gate_proj.forward(input).swish();
let up = self.up_proj.forward(input) * gate;
self.down_proj.forward(up)
}
}
impl<const I: usize, const H: usize> InitModule for Mlp<I, H> {
fn initialize(cx: &mut Graph) -> Self {
impl Mlp {
pub fn new(hidden: usize, intermediate: usize, cx: &mut Graph) -> Self {
Self {
gate_proj: InitModule::initialize(cx),
up_proj: InitModule::initialize(cx),
down_proj: InitModule::initialize(cx),
gate_proj: Linear::new_permuted(hidden, intermediate, false, cx),
down_proj: Linear::new_permuted(intermediate, hidden, false, cx),
up_proj: Linear::new_permuted(hidden, intermediate, false, cx),
}
}
}
fn apply_rotary_embeddings_ggml<const N_HEADS: usize, Batch: Dimension, Seq: Dimension>(
input: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)>,
impl SerializeModule for Mlp {
fn serialize(&self, s: &mut Serializer) {
s.module("ffn_gate", &self.gate_proj);
s.module("ffn_up", &self.up_proj);
s.module("ffn_down", &self.down_proj);
}
}
fn apply_rotary_embeddings_ggml(
input: GraphTensor,
prev_seq: BigExpression,
) -> GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)> {
) -> GraphTensor {
assert_eq!(input.shape.len(), 4); // batch, n_heads, seq, head_dim
let batch = input.shape()[0].small();
let n_heads = input.shape()[1].small();
let seq = input.shape()[2].small();
let head_dim = input.shape()[3].small();
// Get freqs
let freqs =
(input.graph().arange::<Const<HEAD_DIM_OVER_2>>() * 2.0) / (HEAD_DIM as f32);
let freqs = 1000000_f32.pow(freqs);
let pos = input.graph().arange::<Seq>() + prev_seq;
let emb = pos.expand::<(_, Const<1>), _>().matmul(freqs.expand());
(input.graph().arange(head_dim / 2) * 2.0) / (head_dim.to_usize().unwrap() as f32);
let freqs = 500_000_f32.pow(freqs);
let pos = input.graph().arange(seq) + prev_seq;
let emb = pos.expand(1, 1).matmul(freqs.expand(0, seq));
// Split input into evens and odds
let split =
input.reshape::<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<2>)>();
let x0: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<1>)> =
split
.slice((.., .., .., .., ..Expression::from(1)))
.contiguous()
.realize();
let x1: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<1>)> =
split
.slice((.., .., .., .., Expression::from(1)..))
.contiguous()
.realize();
let split = input.reshape((batch, n_heads, seq, head_dim / 2, 2));
let x0 = split.slice((.., .., .., .., ..1));
let x1 = split.slice((.., .., .., .., 1..));
// Apply sin and cos embeddings
let x0_out = x0 * emb.cos().expand() - x1 * emb.sin().expand();
let x1_out = x0 * emb.sin().expand() + x1 * emb.cos().expand();
let x0_out = x0 * emb.cos().expand_to(x0.shape) - x1 * emb.sin().expand_to(x1.shape);
let x1_out = x0 * emb.sin().expand_to(x0.shape) + x1 * emb.cos().expand_to(x1.shape);
// Combine back into output
x0_out
.concat_along::<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<2>), Axis<4>, _>(
x1_out,
)
.reshape()
}
pub struct SelfAttention {
pub q_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
pub k_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
pub v_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
pub o_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
x0_out.concat_along(x1_out, 4).reshape(input.shape)
}
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
Module<(
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, PrevSeq>,
PhantomData<TotSeq>,
)> for SelfAttention
{
type Output = (
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, TotSeq>,
);
fn forward(
&self,
(x, (k_cache, v_cache), _): (
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, PrevSeq>,
PhantomData<TotSeq>,
),
) -> Self::Output {
pub struct SelfAttention {
pub q_proj: GraphTensor, // Hidden -> hidden
pub k_proj: GraphTensor, // Proj dim -> hidden
pub v_proj: GraphTensor, // Proj dim -> hidden
pub o_proj: GraphTensor, // Hidden -> hidden
}
impl Module<(GraphTensor, KVCache)> for SelfAttention {
type Output = (GraphTensor, KVCache);
fn forward(&self, (x, (k_cache, v_cache)): (GraphTensor, KVCache)) -> Self::Output {
// x: batch, seq, hidden
let batch = x.shape()[0].small();
let seq = x.shape()[1].small();
let prev_seq = k_cache.shape()[2].small();
// Apply the Projections
let queries = x
.matmul(self.q_proj.permute())
.reshape::<(Batch, CurSeq, Const<N_HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
.matmul(self.q_proj.permute((1, 0)))
.reshape((batch, seq, N_HEADS, HEAD_DIM))
.permute((0, 2, 1, 3));
let keys = x
.matmul(self.k_proj.permute())
.reshape::<(Batch, CurSeq, Const<N_KV_HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
.matmul(self.k_proj.permute((1, 0)))
.reshape((batch, seq, N_KV_HEADS, HEAD_DIM))
.permute((0, 2, 1, 3));
let values = x
.matmul(self.v_proj.permute())
.reshape::<(Batch, CurSeq, Const<N_KV_HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
.matmul(self.v_proj.permute((1, 0)))
.reshape((batch, seq, N_KV_HEADS, HEAD_DIM))
.permute((0, 2, 1, 3));
// Rotary embed queries and keys
let queries = apply_rotary_embeddings_ggml(queries, PrevSeq::size().big());
let keys = apply_rotary_embeddings_ggml(keys, PrevSeq::size().big());
let queries = apply_rotary_embeddings_ggml(queries, prev_seq.big());
let keys = apply_rotary_embeddings_ggml(keys, prev_seq.big());
// Add KV cache
let (keys, values) = (
k_cache.concat_along::<_, Axis<2>, _>(keys),
v_cache.concat_along::<_, Axis<2>, _>(values),
);
let keys = k_cache.concat_along(keys, 2);
let values = v_cache.concat_along(values, 2);
// Repeat the KV States for Grouped-Query Attention
let repeated_keys = keys.expand::<(_, _, Const<N_ATTENTION_GROUPS>, _, _), _>();
let repeated_values = values.expand::<(_, _, Const<N_ATTENTION_GROUPS>, _, _), _>();
let repeated_keys = keys.expand(2, N_ATTENTION_GROUPS);
let repeated_values = values.expand(2, N_ATTENTION_GROUPS);
// Calculate attention weights
let mut attention_weights = queries
.reshape::<(_, Const<N_KV_HEADS>, Const<N_ATTENTION_GROUPS>, _, _)>() // Split query heads into groups
.matmul(repeated_keys.permute())
.div((HEAD_DIM as f32).sqrt());
.reshape((batch, N_KV_HEADS, N_ATTENTION_GROUPS, seq, HEAD_DIM)) // Split query heads into groups
.matmul(repeated_keys.permute((0, 1, 2, 4, 3)))
/ (HEAD_DIM as f32).sqrt();
let attention_mask = self.k_proj.graph().triu::<CurSeq>(1) * f16::MIN.to_f32();
let attention_mask = self.k_proj.graph().triu(seq, 1) * f16::MIN.to_f32();
attention_weights += attention_mask
.pad::<(CurSeq, TotSeq)>(((0, 0), (TotSeq::size() - CurSeq::size(), 0)))
.expand();
.pad(((0, 0), (prev_seq, 0)))
.expand(0, batch)
.expand(1, N_KV_HEADS)
.expand(2, N_ATTENTION_GROUPS);
// Calculate final outputs
let output = attention_weights
.softmax::<Axis<4>>()
.softmax(4)
// Apply distribution to values
.matmul(repeated_values)
// Merge heads
.permute::<_, Axes5<0, 3, 1, 2, 4>>()
.reshape::<(Batch, CurSeq, Const<HIDDEN_DIM>)>();
.permute((0, 3, 1, 2, 4))
.reshape((batch, seq, HIDDEN_DIM));
let output = output
// Apply output projection
.matmul(self.o_proj.permute());
.matmul(self.o_proj.permute((1, 0)));
(output, (keys.contiguous(), values.contiguous())) // Cache needs to be contiguous for transferring to another graph
}
}
impl InitModule for SelfAttention {
fn initialize(cx: &mut Graph) -> Self {
impl SelfAttention {
pub fn new(cx: &mut Graph) -> Self {
Self {
q_proj: cx
.named_tensor("Q Proj")
.set(random_vec(HIDDEN_DIM * HIDDEN_DIM)),
k_proj: cx
.named_tensor("K Proj")
.set(random_vec(ATTN_PROJ_DIM * HIDDEN_DIM)),
v_proj: cx
.named_tensor("V Proj")
.set(random_vec(ATTN_PROJ_DIM * HIDDEN_DIM)),
o_proj: cx
.named_tensor("O Proj")
.set(random_vec(HIDDEN_DIM * HIDDEN_DIM)),
q_proj: cx.named_tensor("Q Proj", (HIDDEN_DIM, HIDDEN_DIM)),
k_proj: cx.named_tensor("K Proj", (ATTN_PROJ_DIM, HIDDEN_DIM)),
v_proj: cx.named_tensor("V Proj", (ATTN_PROJ_DIM, HIDDEN_DIM)),
o_proj: cx.named_tensor("O Proj", (HIDDEN_DIM, HIDDEN_DIM)),
}
}
fn initialize(self) -> Self {
self.k_proj.set(random_vec(
self.k_proj.shape.n_elements().to_usize().unwrap(),
));
self.o_proj.set(random_vec(
self.o_proj.shape.n_elements().to_usize().unwrap(),
));
self.v_proj.set(random_vec(
self.v_proj.shape.n_elements().to_usize().unwrap(),
));
self.q_proj.set(random_vec(
self.q_proj.shape.n_elements().to_usize().unwrap(),
));
self
}
}
impl SerializeModule for SelfAttention {
@@ -997,35 +977,18 @@ mod tests {
pub struct TransformerBlock {
pub attention: SelfAttention,
pub attention_norm: LayerNorm<HIDDEN_DIM>,
pub feed_forward: Mlp<MLP_DIM, HIDDEN_DIM>,
pub feed_forward_norm: LayerNorm<HIDDEN_DIM>,
pub attention_norm: LayerNorm,
pub feed_forward: Mlp,
pub feed_forward_norm: LayerNorm,
}
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
Module<(
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, PrevSeq>,
PhantomData<TotSeq>,
)> for TransformerBlock
{
type Output = (
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, TotSeq>,
);
fn forward(
&self,
(mut x, cache, _): (
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, PrevSeq>,
PhantomData<TotSeq>,
),
) -> Self::Output {
impl Module<(GraphTensor, KVCache)> for TransformerBlock {
type Output = (GraphTensor, KVCache);
fn forward(&self, (mut x, cache): (GraphTensor, KVCache)) -> Self::Output {
// Attention
let normed = self.attention_norm.forward(x);
let (y, cache) = self
.attention
.forward((normed, cache, PhantomData::<TotSeq>));
.forward((self.attention_norm.forward(x), cache));
// Residual Addition
x += y;
@@ -1038,94 +1001,85 @@ mod tests {
}
}
impl InitModule for TransformerBlock {
fn initialize(cx: &mut Graph) -> Self {
impl TransformerBlock {
pub fn new(cx: &mut Graph) -> Self {
Self {
attention: InitModule::initialize(cx),
attention_norm: LayerNorm::new(false, false, false, 1e-5, cx),
feed_forward: InitModule::initialize(cx),
feed_forward_norm: LayerNorm::new(false, false, false, 1e-5, cx),
attention: SelfAttention::new(cx),
attention_norm: LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
feed_forward: Mlp::new(HIDDEN_DIM, MLP_DIM, cx),
feed_forward_norm: LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
}
}
fn initialize(mut self) -> Self {
self.attention_norm = self.attention_norm.initialize();
self.feed_forward_norm = self.feed_forward_norm.initialize();
self.attention = self.attention.initialize();
self.feed_forward.down_proj = self.feed_forward.down_proj.initialize();
self.feed_forward.up_proj = self.feed_forward.up_proj.initialize();
self.feed_forward.gate_proj = self.feed_forward.gate_proj.initialize();
self
}
}
pub struct MistralLM {
pub struct Llama {
// Transformer layers
pub layers: Vec<TransformerBlock>,
// Final Norm layer
pub norm: LayerNorm<HIDDEN_DIM>,
// Norm + LM head
pub head: LayerNorm,
}
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
Module<(
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
Vec<KVCache<Batch, PrevSeq>>,
PhantomData<TotSeq>,
)> for MistralLM
{
type Output = (
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
Vec<KVCache<Batch, TotSeq>>,
);
fn forward(
&self,
(input, cache, _): (
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
Vec<KVCache<Batch, PrevSeq>>,
PhantomData<TotSeq>,
),
) -> Self::Output {
let mut x = input;
impl Module<(GraphTensor, &[KVCache])> for Llama {
type Output = (GraphTensor, Vec<KVCache>);
fn forward(&self, (mut x, cache): (GraphTensor, &[KVCache])) -> Self::Output {
// Run through layers and collect new caches
let mut new_caches = vec![];
let mut new_cache;
for (i, layer) in self.layers.iter().enumerate() {
(x, new_cache) = layer.forward((x, cache[i], PhantomData::<TotSeq>));
(x, new_cache) = layer.forward((x, cache[i]));
new_caches.push(new_cache);
}
// Run through last norm and output projection
let normed = self.norm.forward(x);
(normed, new_caches)
(self.head.forward(x), new_caches)
}
}
impl InitModule for MistralLM {
fn initialize(cx: &mut Graph) -> Self {
impl Llama {
pub fn new(cx: &mut Graph) -> Self {
Self {
norm: LayerNorm::new(false, false, false, 1e-5, cx),
layers: (0..NUM_LAYERS)
.map(|_| InitModule::initialize(cx))
.collect(),
head: LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
layers: (0..NUM_LAYERS).map(|_| TransformerBlock::new(cx)).collect(),
}
}
fn initialize(mut self) -> Self {
self.head = self.head.initialize();
self.layers = self.layers.into_iter().map(|l| l.initialize()).collect();
self
}
}
let mut cx = Graph::new();
let model = MistralLM::initialize(&mut cx);
let model = Llama::new(&mut cx).initialize();
let caches = (0..NUM_LAYERS)
.map(|_| {
(
cx.tensor::<(Const<1>, Const<N_KV_HEADS>, Dyn<'p'>, Const<HEAD_DIM>)>()
.set_dyn(
random_vec(SEQ_LEN * N_KV_HEADS * HEAD_DIM),
&[1, N_KV_HEADS, SEQ_LEN, HEAD_DIM],
),
cx.tensor::<(Const<1>, Const<N_KV_HEADS>, Dyn<'p'>, Const<HEAD_DIM>)>()
.set_dyn(
random_vec(SEQ_LEN * N_KV_HEADS * HEAD_DIM),
&[1, N_KV_HEADS, SEQ_LEN, HEAD_DIM],
),
cx.tensor((1, N_KV_HEADS, 'p', HEAD_DIM)).set_dyn(
random_vec(SEQ_LEN * N_KV_HEADS * HEAD_DIM),
(1, N_KV_HEADS, SEQ_LEN, HEAD_DIM),
),
cx.tensor((1, N_KV_HEADS, 'p', HEAD_DIM)).set_dyn(
random_vec(SEQ_LEN * N_KV_HEADS * HEAD_DIM),
(1, N_KV_HEADS, SEQ_LEN, HEAD_DIM),
),
)
})
.collect();
.collect::<Vec<_>>();
let input = cx
.tensor::<(Const<1>, Dyn<'s'>, luminal::shape::Const<HIDDEN_DIM>)>()
.set_dyn(random_vec(2 * HIDDEN_DIM), &[1, 2, HIDDEN_DIM]);
let (mut out, _) = model.forward((input, caches, PhantomData::<Dyn<'t'>>));
.tensor((1, 's', HIDDEN_DIM))
.set_dyn(random_vec(2 * HIDDEN_DIM), (1, 2, HIDDEN_DIM));
let (mut out, _) = model.forward((input, &caches));
out.retrieve();
cx.set_dyn_dim('t', SEQ_LEN + 2);
cx.execute();
let unopt_out = out.data();

View File

@@ -429,9 +429,9 @@ mod tests {
const N: usize = 256;
let mut cx = Graph::new();
let (a_vec, b_mat) = (random_vec(M), random_vec(M * N));
let mut a = cx.named_tensor::<R2<1, M>>("Vec").set(a_vec.clone());
let mut b = cx.named_tensor::<R2<N, M>>("Mat").set(b_mat.clone());
let mut c = a.matmul(b.permute()).retrieve();
let mut a = cx.named_tensor("Vec", (1, M)).set(a_vec.clone());
let mut b = cx.named_tensor("Mat", (N, M)).set(b_mat.clone());
let mut c = a.matmul(b.permute((1, 0))).retrieve();
cx.compile(
<(GenericCompiler, MetalCompiler<f16>)>::default(),
@@ -454,8 +454,8 @@ mod tests {
const N: usize = 256;
let mut cx = Graph::new();
let (a_vec, b_mat) = (random_vec(M), random_vec(M * N));
let mut a = cx.named_tensor::<R3<1, 1, M>>("Vec").set(a_vec.clone());
let mut b = cx.named_tensor::<R2<M, N>>("Mat").set(b_mat.clone());
let mut a = cx.named_tensor("Vec", (1, 1, M)).set(a_vec.clone());
let mut b = cx.named_tensor("Mat", (M, N)).set(b_mat.clone());
let mut c = a.matmul(b).retrieve();
cx.compile(

View File

@@ -1054,7 +1054,7 @@ impl<T: MetalFloat + 'static> Compiler for PrimitiveCompiler<T> {
// Copy outputs to device
let copy_node = graph
.add_op(MetalCopyToDevice::<T>::new(dev.clone()))
.input(function_node, 0, ShapeTracker::new(&[]))
.input(function_node, 0, ShapeTracker::default())
.finish();
// Switch outgoing edges from input to copy_node

View File

@@ -900,9 +900,9 @@ mod tests {
let mat_data: Vec<i8> = (0..(1024 * 512)).map(|_| rng.gen_range(0..5)).collect();
let vec_data = random_vec_rng(1024, &mut rng);
let mut cx = Graph::new();
let weights = cx.tensor::<R2<512, 1024>>().keep();
let vec = cx.tensor::<R1<1024>>().set(vec_data.clone());
let mut out = vec.matmul(weights.permute()).retrieve();
let weights = cx.tensor((512, 1024)).keep();
let vec = cx.tensor(1024).set(vec_data.clone());
let mut out = vec.matmul(weights.permute((1, 0))).retrieve();
// "Load" weights in 8bit
let blocks = mat_data
@@ -933,11 +933,11 @@ mod tests {
let mut cx1 = Graph::new();
let weights = cx1
.tensor::<R2<512, 1024>>()
.tensor((512, 1024))
.set(mat_data.into_iter().map(|i| i as f32).collect::<Vec<_>>())
.keep();
let vec = cx1.tensor::<R1<1024>>().set(vec_data);
let out_32 = vec.matmul(weights.permute()).retrieve();
let vec = cx1.tensor(1024).set(vec_data);
let out_32 = vec.matmul(weights.permute((1, 0))).retrieve();
cx1.execute();
assert_close(&out.data(), &out_32.data());
@@ -949,9 +949,9 @@ mod tests {
let mat_data: Vec<i8> = (0..(1024 * 512)).map(|_| rng.gen_range(0..5)).collect();
let inp_mat_data = random_vec_rng(1024 * 16, &mut rng);
let mut cx = Graph::new();
let weights = cx.tensor::<R2<512, 1024>>().keep();
let inp_mat = cx.tensor::<R2<16, 1024>>().set(inp_mat_data.clone());
let mut out = inp_mat.matmul(weights.permute()).retrieve();
let weights = cx.tensor((512, 1024)).keep();
let inp_mat = cx.tensor((16, 1024)).set(inp_mat_data.clone());
let mut out = inp_mat.matmul(weights.permute((1, 0))).retrieve();
// "Load" weights in 8bit
let blocks = mat_data
@@ -1000,9 +1000,9 @@ mod tests {
let mat_data: Vec<i8> = (0..(1024 * 512)).map(|_| rng.gen_range(0..5)).collect();
let inp_mat_data = random_vec_rng(1024 * 16, &mut rng);
let mut cx = Graph::new();
let weights = cx.tensor::<R2<512, 1024>>().keep();
let inp_mat = cx.tensor::<R2<16, 1024>>().set(inp_mat_data.clone());
let mut out = inp_mat.matmul(weights.permute()).retrieve();
let weights = cx.tensor((512, 1024)).keep();
let inp_mat = cx.tensor((16, 1024)).set(inp_mat_data.clone());
let mut out = inp_mat.matmul(weights.permute((1, 0))).retrieve();
// "Load" weights in 8bit
let blocks = mat_data

View File

@@ -390,7 +390,7 @@ fn test_shared_buffers() {
use luminal::prelude::*;
use luminal::tests::{assert_close_precision, random_vec};
let mut cx = Graph::new();
let a = cx.tensor::<R1<5>>().set(random_vec(5)).keep();
let a = cx.tensor(5).set(random_vec(5)).keep();
let b = a.exp2();
let c = a.log2() * b;
let d = b.recip();

View File

@@ -1,5 +1,3 @@
use std::marker::PhantomData;
use dfdx::prelude::{Module as DfdxModule, *};
use metal_rs::objc::rc::autoreleasepool;
use rand::{rngs::StdRng, SeedableRng};
@@ -23,13 +21,13 @@ unary_test!(
);
unary_test!(|a| a.exp2(), |a| (a * 2_f32.ln()).exp(), test_exp2, f16);
unary_test!(
|a| a.softmax::<LAxis<0>>(),
|a| a.softmax(0),
|a| a.softmax::<DAxis<0>>(),
test_softmax,
f16
);
unary_test!(
|a| a.layer_norm::<LAxis<0>, _>(1e-5),
|a| a.layer_norm(0, 1e-5),
|a| a
.to_dtype::<f32>()
.normalize::<DAxis<0>>(1e-5)
@@ -60,8 +58,8 @@ binary_test!(
fn test_contiguous() {
let mut cx = Graph::new();
let data = random_vec(12);
let a = cx.tensor::<R2<3, 4>>().set(data.clone());
let mut b = a.permute::<R2<4, 3>, _>().reshape::<R2<12, 1>>().retrieve();
let a = cx.tensor((3, 4)).set(data.clone());
let mut b = a.permute((1, 0)).reshape((12, 1)).retrieve();
cx.compile(MetalCompiler::<f16>::default(), &mut b);
cx.execute();
@@ -81,15 +79,10 @@ fn test_rotate() {
const F: usize = 3;
const D: usize = 4;
let data = random_vec(D * B * F);
let a = cx
.tensor::<R3<F, B, D>>()
.set(data.clone())
.permute::<_, LAxes3<1, 0, 2>>();
let a = cx.tensor((F, B, D)).set(data.clone()).permute((1, 0, 2));
let x1 = a.slice((.., .., ..Expression::from(D / 2)));
let x2 = a.slice((.., .., Expression::from(D / 2)..));
let mut rotated_a = (-x2)
.concat_along::<R3<B, F, D>, LAxis<1>, _>(x1)
.retrieve();
let mut rotated_a = (-x2).concat_along(x1, 1).retrieve();
cx.execute();
let unopt = rotated_a.data();
rotated_a.drop();
@@ -121,10 +114,10 @@ fn test_constant() {
fn test_sum_reduce() {
let data = random_vec(40960);
let mut cx = Graph::new();
let a = cx.tensor::<R3<1, 10, 4096>>().set(data.clone());
let mut b = a.sum_reduce::<_, LAxis<2>>().retrieve();
let mut c = a.sum_reduce::<_, LAxis<1>>().retrieve();
let mut d = a.sum_reduce::<_, LAxis<0>>().retrieve();
let a = cx.tensor((1, 10, 4096)).set(data.clone());
let mut b = a.sum_reduce(2).retrieve();
let mut c = a.sum_reduce(1).retrieve();
let mut d = a.sum_reduce(0).retrieve();
cx.compile(MetalCompiler::<f16>::default(), (&mut b, &mut c, &mut d));
cx.execute();
@@ -145,8 +138,8 @@ fn test_sum_reduce() {
fn test_sum_reduce2() {
let mut cx = Graph::new();
let data = random_vec(32 * 10 * 10 * 128);
let a = cx.tensor::<R5<1, 32, 10, 10, 128>>().set(data.clone());
let mut d = a.sum_reduce::<_, LAxis<2>>().retrieve();
let a = cx.tensor((1, 32, 10, 10, 128)).set(data.clone());
let mut d = a.sum_reduce(2).retrieve();
cx.compile(MetalCompiler::<f16>::default(), &mut d);
cx.execute();
@@ -173,10 +166,10 @@ fn test_sum_reduce2() {
fn test_max_reduce() {
let data = random_vec(40960);
let mut cx = Graph::new();
let a = cx.tensor::<R3<1, 10, 4096>>().set(data.clone());
let mut b = a.max_reduce::<_, LAxis<2>>().retrieve();
let mut c = a.max_reduce::<_, LAxis<1>>().retrieve();
let mut d = a.max_reduce::<_, LAxis<0>>().retrieve();
let a = cx.tensor((1, 10, 4096)).set(data.clone());
let mut b = a.max_reduce(2).retrieve();
let mut c = a.max_reduce(1).retrieve();
let mut d = a.max_reduce(0).retrieve();
cx.compile(MetalCompiler::<f16>::default(), (&mut b, &mut c, &mut d));
cx.execute();
@@ -197,10 +190,10 @@ fn test_max_reduce() {
fn test_mean_reduce() {
let data = random_vec(40960);
let mut cx = Graph::new();
let a = cx.tensor::<R3<1, 10, 4096>>().set(data.clone());
let mut b = a.mean_reduce::<_, LAxis<2>>().retrieve();
let mut c = a.mean_reduce::<_, LAxis<1>>().retrieve();
let mut d = a.mean_reduce::<_, LAxis<0>>().retrieve();
let a = cx.tensor((1, 10, 4096)).set(data.clone());
let mut b = a.mean_reduce(2).retrieve();
let mut c = a.mean_reduce(1).retrieve();
let mut d = a.mean_reduce(0).retrieve();
cx.compile(MetalCompiler::<f16>::default(), (&mut b, &mut c, &mut d));
cx.execute();
@@ -222,8 +215,8 @@ fn test_matmul_simple() {
let mut cx = Graph::new();
let a_data = random_vec(256 * 256);
let b_data = random_vec(256 * 256);
let a = cx.tensor::<R2<256, 256>>().set(a_data.clone());
let b = cx.tensor::<R2<256, 256>>().set(b_data.clone());
let a = cx.tensor((256, 256)).set(a_data.clone());
let b = cx.tensor((256, 256)).set(b_data.clone());
let mut c = a.matmul(b).retrieve();
cx.compile(MetalCompiler::<f16>::default(), &mut c);
@@ -242,8 +235,8 @@ fn test_matmul() {
let d_dev = Cpu::default();
let mut cx = Graph::new();
let mut rng = StdRng::seed_from_u64(0);
let a = cx.tensor::<(Dyn<'M'>, Dyn<'K'>)>();
let b = cx.tensor::<(Dyn<'K'>, Dyn<'N'>)>();
let a = cx.tensor(('M', 'K'));
let b = cx.tensor(('K', 'N'));
let mut c = a.matmul(b).retrieve();
cx.compile(MetalCompiler::<f16>::default(), &mut c);
@@ -253,8 +246,8 @@ fn test_matmul() {
autoreleasepool(|| {
let a_data = random_vec_rng(m * k, &mut rng);
let b_data = random_vec_rng(k * n, &mut rng);
a.set_dyn(a_data.clone(), &[m, k]);
b.set_dyn(b_data.clone(), &[k, n]);
a.set_dyn(a_data.clone(), (m, k));
b.set_dyn(b_data.clone(), (k, n));
cx.execute();
@@ -276,11 +269,11 @@ fn test_attn_matmul() {
let a_data = random_vec_rng(32 * 11 * 128, &mut rng);
let b_data = random_vec_rng(32 * 11 * 128, &mut rng);
let a = cx
.named_tensor::<R4<1, 32, 11, 128>>("Input")
.named_tensor("Input", (1, 32, 11, 128))
.set(a_data.clone())
.keep();
let b = cx
.named_tensor::<R4<1, 32, 128, 11>>("Input")
.named_tensor("Input", (1, 32, 128, 11))
.set(b_data.clone())
.keep();
let mut c = a.matmul(b).retrieve();
@@ -310,8 +303,8 @@ fn test_batch_matmul() {
let m = 12;
let mut cx = Graph::new();
let mut rng = StdRng::seed_from_u64(0);
let a = cx.tensor::<(Dyn<'B'>, Dyn<'M'>, Dyn<'K'>)>();
let b = cx.tensor::<(Dyn<'K'>, Dyn<'N'>)>();
let a = cx.tensor(('B', 'M', 'K'));
let b = cx.tensor(('K', 'N'));
let mut c = a.matmul(b).retrieve();
cx.compile(MetalCompiler::<f16>::default(), &mut c);
@@ -321,8 +314,8 @@ fn test_batch_matmul() {
autoreleasepool(|| {
let a_data = random_vec_rng(batch * m * k, &mut rng);
let b_data = random_vec_rng(k * n, &mut rng);
a.set_dyn(a_data.clone(), &[batch, m, k]);
b.set_dyn(b_data.clone(), &[k, n]);
a.set_dyn(a_data.clone(), (batch, m, k));
b.set_dyn(b_data.clone(), (k, n));
cx.execute();
let d_dev = Cpu::default();
@@ -348,26 +341,18 @@ fn test_batch_matmul_transpose() {
let mut rng = StdRng::seed_from_u64(0);
let a_data = random_vec_rng(B * M * K, &mut rng);
let a = cx.named_tensor::<R3<B, M, K>>("A").set(a_data.clone());
let a = cx.named_tensor("A", (B, M, K)).set(a_data.clone());
let b_data = random_vec_rng(K * N, &mut rng);
let b = cx.named_tensor::<R2<N, K>>("B").set(b_data.clone());
let b = cx.named_tensor("B", (N, K)).set(b_data.clone());
let a_t_data = random_vec_rng(B * K * M, &mut rng);
let a_t = cx.named_tensor::<R3<B, K, M>>("A_T").set(a_t_data.clone());
let a_t = cx.named_tensor("A_T", (B, K, M)).set(a_t_data.clone());
let b_t_data = random_vec_rng(K * N, &mut rng);
let b_t = cx.named_tensor::<R2<K, N>>("B_T").set(b_t_data.clone());
let b_t = cx.named_tensor("B_T", (K, N)).set(b_t_data.clone());
let mut a_b = a
.matmul(b.permute::<_, luminal::prelude::Axes2<1, 0>>())
.retrieve();
let mut a_b = a.matmul(b.permute((1, 0))).retrieve();
let mut a_b_t = a.matmul(b_t).retrieve();
let mut a_t_b = a_t
.permute::<_, luminal::prelude::Axes3<0, 2, 1>>()
.matmul(b.permute::<_, luminal::prelude::Axes2<1, 0>>())
.retrieve();
let mut a_t_b_t = a_t
.permute::<_, luminal::prelude::Axes3<0, 2, 1>>()
.matmul(b_t)
.retrieve();
let mut a_t_b = a_t.permute((0, 2, 1)).matmul(b.permute((1, 0))).retrieve();
let mut a_t_b_t = a_t.permute((0, 2, 1)).matmul(b_t).retrieve();
cx.compile(
<(GenericCompiler, MetalCompiler<f16>)>::default(),
@@ -407,24 +392,18 @@ fn test_matmul_transpose() {
let mut rng = StdRng::seed_from_u64(0);
let a_data = random_vec_rng(M * K, &mut rng);
let a = cx.tensor::<R2<M, K>>().set(a_data.clone());
let a = cx.tensor((M, K)).set(a_data.clone());
let b_data = random_vec_rng(K * N, &mut rng);
let b = cx.tensor::<R2<N, K>>().set(b_data.clone());
let b = cx.tensor((N, K)).set(b_data.clone());
let a_t_data = random_vec_rng(K * M, &mut rng);
let a_t = cx.tensor::<R2<K, M>>().set(a_t_data.clone());
let a_t = cx.tensor((K, M)).set(a_t_data.clone());
let b_t_data = random_vec_rng(K * N, &mut rng);
let b_t = cx.tensor::<R2<K, N>>().set(b_t_data.clone());
let b_t = cx.tensor((K, N)).set(b_t_data.clone());
let mut a_b = a.matmul(b.permute()).retrieve();
let mut a_b = a.matmul(b.permute((1, 0))).retrieve();
let mut a_b_t = a.matmul(b_t).retrieve();
let mut a_t_b = a_t
.permute::<_, luminal::prelude::Axes2<1, 0>>()
.matmul(b.permute())
.retrieve();
let mut a_t_b_t = a_t
.permute::<_, luminal::prelude::Axes2<1, 0>>()
.matmul(b_t)
.retrieve();
let mut a_t_b = a_t.permute((1, 0)).matmul(b.permute((1, 0))).retrieve();
let mut a_t_b_t = a_t.permute((1, 0)).matmul(b_t).retrieve();
cx.compile(
<(GenericCompiler, MetalCompiler<f16>)>::default(),
@@ -468,12 +447,14 @@ fn test_relu_and_linear() {
let input_data = random_vec(32);
let w1 = random_vec(32 * 64);
let w2 = random_vec(32 * 64);
let batch = cx
.named_tensor::<R2<2, 32>>("Batch")
.set(random_vec(32 * 2));
let a = cx.named_tensor::<R1<32>>("Single").set(input_data.clone());
let batch = cx.named_tensor("Batch", (2, 32)).set(random_vec(32 * 2));
let a = cx.named_tensor("Single", 32).set(input_data.clone());
let model: (Linear<32, 64>, ReLU, Linear<64, 32>) = InitModule::initialize(&mut cx);
let model = (
Linear::new(32, 64, false, &mut cx),
ReLU,
Linear::new(64, 32, false, &mut cx),
);
model.0.weight.set(w1.clone());
model.2.weight.set(w2.clone());
let mut b = model.forward(a).retrieve();
@@ -524,9 +505,9 @@ fn test_rms_norm() {
let inp_data = random_vec_rng(15 * 32, &mut rng);
let weight_data = random_vec_rng(32, &mut rng);
let mut cx = Graph::new();
let a = cx.tensor::<R2<15, 32>>().set(inp_data.clone());
let a = cx.tensor((15, 32)).set(inp_data.clone());
let model = LayerNorm::<32>::new(true, false, false, 1e-5, &mut cx);
let model = LayerNorm::new(32, true, false, false, 1e-5, &mut cx);
model.weight.unwrap().set(weight_data.clone());
let mut b = model.forward(a).retrieve();
@@ -553,9 +534,9 @@ fn test_rms_norm() {
fn test_layer_norm() {
let mut cx = Graph::new();
let a_data = random_vec(15 * 16 * 32);
let a = cx.tensor::<R3<15, 16, 32>>().set(a_data.clone());
let mut b = a.layer_norm::<LAxis<0>, _>(1e-5).retrieve();
let mut c = a.layer_norm::<LAxis<2>, _>(1e-5).retrieve();
let a = cx.tensor((15, 16, 32)).set(a_data.clone());
let mut b = a.layer_norm(0, 1e-5).retrieve();
let mut c = a.layer_norm(2, 1e-5).retrieve();
cx.compile(
<(GenericCompiler, MetalCompiler<f16>)>::default(),
(&mut b, &mut c),
@@ -574,82 +555,110 @@ fn test_layer_norm() {
#[test]
fn test_transformer_encoder_block() {
let mut cx = Graph::new();
let model: luminal_nn::TransformerEncoderBlock<32, 64, 1> = InitModule::initialize(&mut cx);
let w_k_weight = random_vec(32 * 32);
model.attention.w_k.weight.set(w_k_weight.clone());
let w_q_weight = random_vec(32 * 32);
model.attention.w_q.weight.set(w_q_weight.clone());
let w_v_weight = random_vec(32 * 32);
model.attention.w_v.weight.set(w_v_weight.clone());
let w_o_weight = random_vec(32 * 32);
model.attention.w_o.weight.set(w_o_weight.clone());
let ff_0_weight = random_vec(32 * 64);
model.ff.0.weight.set(ff_0_weight.clone());
let ff_1_weight = random_vec(64 * 32);
model.ff.2.weight.set(ff_1_weight.clone());
let model = luminal_nn::TransformerEncoderBlock::new(3, 4, 1, &mut cx);
model
.attention
.w_k
.weight
.set(vec![1., 22., 3., 1., 2., 3., 1., 2., 3.]);
model
.attention
.w_q
.weight
.set(vec![3., 2., 3., 1.3, 2., 3., 3., 2., 3.]);
model
.attention
.w_v
.weight
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3.]);
model
.attention
.w_o
.weight
.set(vec![1., 22., 3., 1., 2., 3., 1., 2., 3.]);
model
.ff
.0
.weight
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 11., 2., 3.]);
model
.ff
.2
.weight
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.]);
let a_data = random_vec(2 * 32);
let a = cx
.tensor::<(Dyn<'b'>, Dyn<'a'>, luminal::prelude::Const<32>)>()
.set_dyn(a_data.clone(), &[1, 2, 3])
.keep();
cx.keep_tensors(param_dict(&model));
.tensor(('a', 3))
.set_dyn(vec![-1., 2., 3., 3., 3., -1.], (2, 3));
let mut b = model.forward(a).retrieve();
cx.execute();
let unopt_b = b.data();
b.drop();
cx.compile(<(GenericCompiler, MetalCompiler<f16>)>::default(), &mut b);
cx.execute();
assert_close_precision(&unopt_b, &b.data(), 1e-2);
let d_dev = Cpu::default();
let mut d_model: dfdx::nn::modules::TransformerEncoderBlock<32, 1, 64, f32, Cpu> =
d_dev
.build_module::<dfdx::nn::modules::builders::TransformerEncoderBlock<32, 1, 64>, f32>();
d_model.self_attn.w_k.bias.copy_from(&[0.; 32]);
d_model.self_attn.w_v.bias.copy_from(&[0.; 32]);
d_model.self_attn.w_q.bias.copy_from(&[0.; 32]);
d_model.self_attn.w_o.bias.copy_from(&[0.; 32]);
let mut d_model: dfdx::nn::modules::TransformerEncoderBlock<3, 1, 4, f32, Cpu> =
d_dev.build_module::<dfdx::nn::modules::builders::TransformerEncoderBlock<3, 1, 4>, f32>();
d_model.self_attn.w_k.bias.copy_from(&[0.0, 0.0, 0.0]);
d_model.self_attn.w_v.bias.copy_from(&[0.0, 0.0, 0.0]);
d_model.self_attn.w_q.bias.copy_from(&[0.0, 0.0, 0.0]);
d_model.self_attn.w_o.bias.copy_from(&[0., 0., 0.]);
d_model.self_attn.w_o.weight = d_dev
.tensor_from_vec(w_o_weight, (DConst::<32>, DConst::<32>))
.tensor_from_vec(
vec![1., 22., 3., 1., 2., 3., 1., 2., 3.],
(DConst::<3>, DConst::<3>),
)
.permute();
d_model.self_attn.w_k.weight = d_dev
.tensor_from_vec(w_k_weight, (DConst::<32>, DConst::<32>))
.tensor_from_vec(
vec![1., 22., 3., 1., 2., 3., 1., 2., 3.],
(DConst::<3>, DConst::<3>),
)
.permute();
d_model.self_attn.w_q.weight = d_dev
.tensor_from_vec(w_q_weight, (DConst::<32>, DConst::<32>))
.tensor_from_vec(
vec![3., 2., 3., 1.3, 2., 3., 3., 2., 3.],
(DConst::<3>, DConst::<3>),
)
.permute();
d_model.self_attn.w_v.weight = d_dev
.tensor_from_vec(w_v_weight, (DConst::<32>, DConst::<32>))
.tensor_from_vec(
vec![-1., 12., 3., -1., 2., -3., 11., 2., 3.],
(DConst::<3>, DConst::<3>),
)
.permute();
d_model.ff.0 .0.weight = d_dev
.tensor_from_vec(ff_0_weight, (DConst::<32>, DConst::<64>))
.tensor_from_vec(
vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 11., 2., 3.],
(DConst::<3>, DConst::<4>),
)
.permute();
d_model.ff.0 .0.bias = d_dev.tensor_from_vec(vec![0.; 64], (DConst::<64>,));
d_model.ff.0 .0.bias = d_dev.tensor_from_vec(vec![0., 0., 0., 0.], (DConst::<4>,));
d_model.ff.0 .2.weight = d_dev
.tensor_from_vec(ff_1_weight, (DConst::<64>, DConst::<32>))
.tensor_from_vec(
vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.],
(DConst::<4>, DConst::<3>),
)
.permute();
d_model.ff.0 .2.bias = d_dev.tensor_from_vec(vec![0.; 32], (DConst::<32>,));
d_model.norm1.gamma = d_dev.tensor_from_vec(vec![1.; 32], (DConst::<32>,));
d_model.norm2.gamma = d_dev.tensor_from_vec(vec![1.; 32], (DConst::<32>,));
d_model.ff.0 .2.bias = d_dev.tensor_from_vec(vec![0., 0., 0.], (DConst::<3>,));
d_model.norm1.gamma = d_dev.tensor_from_vec(vec![1., 1., 1.], (DConst::<3>,));
d_model.norm2.gamma = d_dev.tensor_from_vec(vec![1., 1., 1.], (DConst::<3>,));
d_model.norm1.epsilon = 1e-5;
d_model.norm2.beta = d_dev.tensor_from_vec(vec![0.; 32], (DConst::<32>,));
d_model.norm1.beta = d_dev.tensor_from_vec(vec![0.; 32], (DConst::<32>,));
d_model.norm2.beta = d_dev.tensor_from_vec(vec![0., 0., 0.], (DConst::<3>,));
d_model.norm1.beta = d_dev.tensor_from_vec(vec![0., 0., 0.], (DConst::<3>,));
d_model.norm2.epsilon = 1e-5;
let d_a = d_dev.tensor_from_vec(a_data, (DConst::<2>, DConst::<32>));
let d_a = d_dev.tensor_from_vec(vec![-1., 2., 3., 3., 3., -1.], (DConst::<2>, DConst::<3>));
let d_b = d_model.forward(d_a);
assert_close_precision(&b.data(), &d_b.as_vec(), 1e-2);
assert_close(&b.data(), &d_b.as_vec());
}
#[test]
fn test_common_buffer() {
let data = random_vec(32);
let mut cx = Graph::new();
let a = cx.tensor::<R1<32>>();
let a = cx.tensor(32);
a.set(data.clone());
let a1 = cx.tensor::<R1<32>>();
let a1 = cx.tensor(32);
a1.set(data.clone());
let exped = a * a1;
let mut b = exped.log2().retrieve();
@@ -663,15 +672,12 @@ fn test_common_buffer() {
fn test_embedding() {
let mut cx = Graph::new();
let batch = cx
.named_tensor::<R2<2, 3>>("Batch")
.named_tensor("Batch", (2, 3))
.set(vec![1.0, 0.0, 2.0, 1.0, 0.0, 1.0])
.keep();
let a = cx
.named_tensor::<R1<3>>("Single")
.set(vec![1.0, 0.0, 1.0])
.keep();
let a = cx.named_tensor("Single", 3).set(vec![1.0, 0.0, 1.0]).keep();
let model: luminal_nn::Embedding<3, 4> = InitModule::initialize(&mut cx);
let model = luminal_nn::Embedding::new(3, 4, &mut cx);
model
.weight
.set(vec![1.1, 2., 3., 1., 2., 3., 14., 2., 33., 1., 2., 3.]);
@@ -702,12 +708,8 @@ fn test_embedding() {
fn test_slice() {
let data = random_vec(256);
let mut cx = Graph::new();
let a = cx.tensor::<R1<256>>().set(data.clone());
let mut c: GraphTensor<R1<20>> = a
.slice((..Expression::from(20),))
.realize()
.contiguous()
.retrieve();
let a = cx.tensor(256).set(data.clone());
let mut c = a.slice(..Expression::from(20)).contiguous().retrieve();
cx.compile(MetalCompiler::<f16>::default(), &mut c);
cx.execute();
@@ -726,8 +728,8 @@ fn test_pad() {
// Pad a 8x2 mat to 10x4
let data = random_vec(8 * 2);
let mut cx = Graph::new();
let a = cx.tensor::<R2<8, 2>>().set(data.clone());
let mut c = a.pad::<R2<10, 4>>(((0, 2), (0, 2))).contiguous().retrieve();
let a = cx.tensor((8, 2)).set(data.clone());
let mut c = a.pad(((0, 2), (0, 2))).contiguous().retrieve();
cx.compile(MetalCompiler::<f16>::default(), &mut c);
cx.execute();
@@ -748,16 +750,12 @@ fn test_pad_contig() {
let mut cx = Graph::new();
let mut rng = StdRng::seed_from_u64(0);
let a_data = random_vec_rng(m * k, &mut rng);
let mut a = cx
.tensor::<(Dyn<'M'>, Dyn<'K'>)>()
.set_dyn(a_data, &[m, k])
.retrieve();
let mut b: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> = a
let mut a = cx.tensor(('M', 'K')).set_dyn(a_data, (m, k)).retrieve();
let mut b = a
.pad(&[(0, 0.into()), (0, Expression::from(24) - 'K')])
.contiguous()
.retrieve();
let mut c: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> =
(a.slice((.., ..Expression::from(k))).realize() / 1.0).retrieve();
let mut c = (a.slice((.., ..Expression::from(k))) / 1.0).retrieve();
cx.compile(MetalCompiler::<f16>::default(), (&mut a, &mut b, &mut c));
cx.execute();
@@ -771,13 +769,9 @@ fn test_pad_contig() {
fn test_movement() {
let data = random_vec(32);
let mut cx = Graph::new();
let a = cx.tensor::<R1<32>>().set(data.clone());
let b: GraphTensor<R1<42>> = a.pad((0, 10)).contiguous().retrieve();
let mut c: GraphTensor<R1<25>> = b
.slice((..Expression::from(25),))
.realize()
.contiguous()
.retrieve();
let a = cx.tensor(32).set(data.clone());
let b = a.pad((0, 10)).contiguous().retrieve();
let mut c = b.slice((..Expression::from(25),)).contiguous().retrieve();
cx.compile(MetalCompiler::<f16>::default(), &mut c);
cx.execute();
@@ -794,10 +788,9 @@ fn test_movement() {
#[test]
fn test_slice_add() {
let mut cx = Graph::new();
let a = cx.tensor().set(random_array::<256>());
let a = cx.tensor(256).set(random_array::<256>());
let mut b = (a.slice(0..64) + a.slice(64..128) + a.slice(128..192) + a.slice(192..256))
.realize::<R1<64>>()
.expand::<R2<4, 64>, _>()
.expand(0, 4)
.retrieve();
cx.compile(MetalCompiler::<f16>::default(), &mut b);
@@ -814,14 +807,12 @@ fn test_conv2d() {
const KERNELY: usize = 2;
const STRIDEX: usize = KERNELX;
const STRIDEY: usize = KERNELY;
const DILATIONX: usize = 0;
const DILATIONY: usize = 0;
const DILATIONX: usize = 1;
const DILATIONY: usize = 1;
const DIMX_IN: usize = 16;
const DIMX_OUT: usize = ((DIMX_IN - (DILATIONX + 1) * (KERNELX - 1) - 1) / STRIDEX) + 1;
const DIMY_IN: usize = 9;
const DIMY_OUT: usize = ((DIMY_IN - (DILATIONY + 1) * (KERNELY - 1) - 1) / STRIDEY) + 1;
let inp1 = cx.tensor::<R3<CH_IN, DIMX_IN, DIMY_IN>>().set(vec![
let inp1 = cx.tensor((CH_IN, DIMX_IN, DIMY_IN)).set(vec![
8., 8., 5., 7., 0., 6., 5., 3., 0., 7., 0., 6., 6., 7., 7., 5., 0., 6., 9., 4., 0., 8., 8.,
5., 7., 6., 2., 8., 9., 5., 0., 3., 1., 1., 8., 4., 1., 1., 5., 6., 9., 3., 2., 9., 4., 7.,
1., 0., 7., 7., 4., 9., 5., 0., 4., 7., 4., 7., 8., 8., 4., 8., 4., 7., 9., 3., 7., 9., 5.,
@@ -856,7 +847,14 @@ fn test_conv2d() {
1., 2., 1., 1., 4., 7., 2.,
]);
let model = luminal_nn::Conv2D::<CH_IN, CH_OUT, KERNELX, KERNELY>::initialize(&mut cx);
let model = luminal_nn::Conv2D::new(
CH_IN,
CH_OUT,
(KERNELX, KERNELY),
(STRIDEX, STRIDEY),
(DILATIONX, DILATIONY),
&mut cx,
);
model.weight.set(vec![
0.1600, 0.2000, 0.1900, -0.1100, 0.0100, -0.0300, -0.1200, -0.0800, -0.1300, -0.0300,
0.1600, -0.1700, -0.0000, 0.1900, 0.1300, 0.0300, -0.1500, 0.0900, 0.0100, 0.0200, 0.1500,
@@ -864,9 +862,7 @@ fn test_conv2d() {
0.0100, -0.2000, 0.2100, -0.0400, -0.1400, 0.1500, 0.0500, -0.1700, 0.1400,
]);
let mut out1 = model
.forward::<DIMX_IN, DIMY_IN, DIMX_OUT, DIMY_OUT>(inp1)
.retrieve();
let mut out1 = model.forward(inp1).retrieve();
cx.compile(
<(GenericCompiler, MetalCompiler<f16>)>::default(),
@@ -899,18 +895,21 @@ fn test_conv1d_pad_stride() {
const KERNEL: usize = 3;
const STRIDE: usize = 1;
const PADDING: usize = 1;
const DILATION: usize = 1;
const DIM_IN: usize = 10;
let kernel_data = random_vec_rng(KERNEL * CH_IN * CH_OUT, &mut rng);
let input_data = random_vec_rng(CH_IN * DIM_IN, &mut rng);
let model = Conv1D::<CH_IN, CH_OUT, KERNEL, STRIDE, 0, PADDING>::initialize(&mut cx);
let model = Conv1D::new(
CH_IN, CH_OUT, KERNEL, STRIDE, DILATION, PADDING, false, &mut cx,
);
model.weight.set(kernel_data.clone());
let inp1 = cx
.tensor::<(LConst<1>, LConst<CH_IN>, Dyn<'s'>)>()
.set_dyn(input_data.clone(), &[1, CH_IN, DIM_IN]);
.tensor((CH_IN, 's'))
.set_dyn(input_data.clone(), (CH_IN, DIM_IN));
let mut out1 = model.forward((inp1, PhantomData::<Dyn<'s'>>)).retrieve();
let mut out1 = model.forward(inp1).retrieve();
cx.compile(crate::MetalCompiler::<f16>::default(), &mut out1);
cx.execute();

View File

@@ -1,5 +1,3 @@
use std::marker::PhantomData;
use dfdx::prelude::{Module as DfdxModule, *};
use rand::{rngs::StdRng, SeedableRng};
@@ -17,13 +15,13 @@ unary_test!(|a| a.ln(), |a| a.ln(), test_ln, f32);
unary_test!(|a| a.log2(), |a| a.ln() / 2_f32.ln(), test_log2, f32);
unary_test!(|a| a.exp2(), |a| (a * 2_f32.ln()).exp(), test_exp2, f32);
unary_test!(
|a| a.softmax::<LAxis<0>>(),
|a| a.softmax(0),
|a| a.softmax::<DAxis<0>>(),
test_softmax,
f32
);
unary_test!(
|a| a.mean_norm::<LAxis<0>>().std_norm::<LAxis<0>, _>(1e-5),
|a| a.mean_norm(0).std_norm(0, 1e-5),
|a| a.normalize::<DAxis<0>>(1e-5),
test_norm,
f32
@@ -46,8 +44,8 @@ binary_test!(|a, b| a.max(b), |a, b| a.maximum(b), test_max, f32);
fn test_contiguous() {
let mut cx = Graph::new();
let data = random_vec(12);
let a = cx.tensor::<R2<3, 4>>().set(data.clone());
let mut b = a.permute::<R2<4, 3>, _>().reshape::<R2<12, 1>>().retrieve();
let a = cx.tensor((3, 4)).set(data.clone());
let mut b = a.permute((1, 0)).reshape((12, 1)).retrieve();
cx.compile(MetalCompiler::<f32>::default(), &mut b);
cx.execute();
@@ -64,11 +62,11 @@ fn test_contiguous() {
fn test_sum_reduce() {
let mut cx = Graph::new();
let data = random_vec(4 * 4096);
let a = cx.tensor::<R3<1, 4, 4096>>();
let a = cx.tensor((1, 4, 4096));
a.set(data.clone());
let mut b = a.sum_reduce::<_, LAxis<1>>().retrieve();
let mut c = a.sum_reduce::<_, LAxis<0>>().retrieve();
let mut d = a.sum_reduce::<_, LAxis<2>>().retrieve();
let mut b = a.sum_reduce(1).retrieve();
let mut c = a.sum_reduce(0).retrieve();
let mut d = a.sum_reduce(2).retrieve();
cx.compile(MetalCompiler::<f32>::default(), (&mut b, &mut c, &mut d));
cx.execute();
@@ -88,11 +86,11 @@ fn test_sum_reduce() {
fn test_max_reduce() {
let mut cx = Graph::new();
let data = random_vec(12);
let a = cx.tensor::<R3<2, 2, 3>>();
let a = cx.tensor((2, 2, 3));
a.set(data.clone());
let mut b = a.max_reduce::<_, LAxis<1>>().retrieve();
let mut c = a.max_reduce::<_, LAxis<0>>().retrieve();
let mut d = a.max_reduce::<_, LAxis<2>>().retrieve();
let mut b = a.max_reduce(1).retrieve();
let mut c = a.max_reduce(0).retrieve();
let mut d = a.max_reduce(2).retrieve();
cx.compile(MetalCompiler::<f32>::default(), (&mut b, &mut c, &mut d));
cx.execute();
@@ -112,8 +110,8 @@ fn test_max_reduce() {
fn test_mean_reduce() {
let data = random_vec(40960);
let mut cx = Graph::new();
let a = cx.tensor::<R3<1, 10, 4096>>().set(data.clone());
let mut b = a.mean_reduce::<_, LAxis<2>>().retrieve();
let a = cx.tensor((1, 10, 4096)).set(data.clone());
let mut b = a.mean_reduce(2).retrieve();
cx.compile(MetalCompiler::<f32>::default(), &mut b);
cx.execute();
@@ -129,8 +127,8 @@ fn test_matmul_simple() {
let mut cx = Graph::new();
let a_data = random_vec(256 * 256);
let b_data = random_vec(256 * 256);
let a = cx.tensor::<R2<256, 256>>().set(a_data.clone());
let b = cx.tensor::<R2<256, 256>>().set(b_data.clone());
let a = cx.tensor((256, 256)).set(a_data.clone());
let b = cx.tensor((256, 256)).set(b_data.clone());
let mut c = a.matmul(b).retrieve();
cx.compile(MetalCompiler::<f32>::default(), &mut c);
@@ -149,8 +147,8 @@ fn test_matmul() {
let mut cx = Graph::new();
let a_data = random_vec(512 * 512);
let b_data = random_vec(512 * 512);
let a = cx.tensor::<R2<512, 512>>().set(a_data.clone());
let b = cx.tensor::<R2<512, 512>>().set(b_data.clone());
let a = cx.tensor((512, 512)).set(a_data.clone());
let b = cx.tensor((512, 512)).set(b_data.clone());
let mut c = a.matmul(b).retrieve();
cx.compile(MetalCompiler::<f32>::default(), &mut c);
@@ -168,10 +166,10 @@ fn test_matmul() {
fn test_batch_matmul() {
let mut cx = Graph::new();
let a = cx
.tensor::<R3<2, 2, 3>>()
.tensor((2, 2, 3))
.set(vec![1., 2., 3., 1., 2., 1., 1., 2., 3., 1., 2., 1.]);
let b = cx
.tensor::<R2<3, 4>>()
.tensor((3, 4))
.set(vec![1., 2., 3., 1., 1., 2., 1., 2., -1., -2., 1., 2.]);
let mut c = a.matmul(b).retrieve();
@@ -195,24 +193,18 @@ fn test_matmul_transpose() {
let mut rng = StdRng::seed_from_u64(0);
let a_data = random_vec_rng(M * K, &mut rng);
let a = cx.tensor::<R2<M, K>>().set(a_data.clone());
let a = cx.tensor((M, K)).set(a_data.clone());
let b_data = random_vec_rng(K * N, &mut rng);
let b = cx.tensor::<R2<N, K>>().set(b_data.clone());
let b = cx.tensor((N, K)).set(b_data.clone());
let a_t_data = random_vec_rng(K * M, &mut rng);
let a_t = cx.tensor::<R2<K, M>>().set(a_t_data.clone());
let a_t = cx.tensor((K, M)).set(a_t_data.clone());
let b_t_data = random_vec_rng(K * N, &mut rng);
let b_t = cx.tensor::<R2<K, N>>().set(b_t_data.clone());
let b_t = cx.tensor((K, N)).set(b_t_data.clone());
let mut a_b = a.matmul(b.permute()).retrieve();
let mut a_b = a.matmul(b.permute((1, 0))).retrieve();
let mut a_b_t = a.matmul(b_t).retrieve();
let mut a_t_b = a_t
.permute::<_, luminal::prelude::Axes2<1, 0>>()
.matmul(b.permute())
.retrieve();
let mut a_t_b_t = a_t
.permute::<_, luminal::prelude::Axes2<1, 0>>()
.matmul(b_t)
.retrieve();
let mut a_t_b = a_t.permute((1, 0)).matmul(b.permute((1, 0))).retrieve();
let mut a_t_b_t = a_t.permute((1, 0)).matmul(b_t).retrieve();
cx.compile(
MetalCompiler::<f32>::default(),
@@ -248,12 +240,14 @@ fn test_relu_and_linear() {
let input_data = random_vec(32);
let w1 = random_vec(32 * 64);
let w2 = random_vec(32 * 64);
let batch = cx
.named_tensor::<R2<2, 32>>("Batch")
.set(random_vec(32 * 2));
let a = cx.named_tensor::<R1<32>>("Single").set(input_data.clone());
let batch = cx.named_tensor("Batch", (2, 32)).set(random_vec(32 * 2));
let a = cx.named_tensor("Single", 32).set(input_data.clone());
let model: (Linear<32, 64>, ReLU, Linear<64, 32>) = InitModule::initialize(&mut cx);
let model = (
Linear::new(32, 64, false, &mut cx),
ReLU,
Linear::new(64, 32, false, &mut cx),
);
model.0.weight.set(w1.clone());
model.2.weight.set(w2.clone());
let mut b = model.forward(a).retrieve();
@@ -296,7 +290,7 @@ fn test_relu_and_linear() {
#[test]
fn test_transformer_encoder_block() {
let mut cx = Graph::new();
let model: luminal_nn::TransformerEncoderBlock<3, 4, 1> = InitModule::initialize(&mut cx);
let model = luminal_nn::TransformerEncoderBlock::new(3, 4, 1, &mut cx);
model
.attention
.w_k
@@ -329,8 +323,8 @@ fn test_transformer_encoder_block() {
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.]);
let a = cx
.tensor::<(Dyn<'b'>, Dyn<'a'>, luminal::prelude::Const<3>)>()
.set_dyn(vec![-1., 2., 3., 3., 3., -1.], &[1, 2, 3]);
.tensor(('a', 3))
.set_dyn(vec![-1., 2., 3., 3., 3., -1.], (2, 3));
let mut b = model.forward(a).retrieve();
cx.compile(<(GenericCompiler, MetalCompiler<f32>)>::default(), &mut b);
@@ -397,11 +391,11 @@ fn test_transformer_encoder_block() {
fn test_pool_1d_dims() {
let mut cx = Graph::new();
let inp1 = cx.tensor::<R2<4, 4>>().set(vec![
let inp1 = cx.tensor((4, 4)).set(vec![
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16.,
]);
// Stride 1
let out1 = inp1.pool_last_dim::<R3<4, 2, 3>>(3, 1, 0).retrieve();
let out1 = inp1.pool_last_dim(3, 1, 1).retrieve();
cx.execute();
@@ -418,21 +412,21 @@ fn test_pool_1d_dims() {
fn test_pool_2d() {
let mut cx = Graph::new();
let inp1 = cx.tensor::<R2<4, 4>>().set(vec![
let inp1 = cx.tensor((4, 4)).set(vec![
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16.,
]);
// 3x3 kernel
let out1 = inp1
// Pool first dim first by moving it to end
.permute::<_, LAxes2<1, 0>>()
.pool_last_dim::<R3<4, 2, 3>>(3, 1, 0)
.permute((1, 0))
.pool_last_dim(3, 1, 1)
// Now move other dim to end
.permute::<_, LAxes3<1, 2, 0>>()
.pool_last_dim::<R4<2, 3, 2, 3>>(3, 1, 0)
.permute((1, 2, 0))
.pool_last_dim(3, 1, 1)
// Now swap middle two dims
.permute::<_, LAxes4<0, 2, 1, 3>>()
.permute((0, 2, 1, 3))
// Now merge both pooled dimensions
.reshape::<R3<4, 3, 3>>()
.reshape((4, 3, 3))
.retrieve();
cx.execute();
@@ -451,13 +445,13 @@ fn test_pool_2d() {
fn test_pool_1d_dilation() {
let mut cx = Graph::new();
let inp1 = cx.tensor::<R1<5>>().set(vec![1., 2., 3., 4., 5.]);
let inp1 = cx.tensor(5).set(vec![1., 2., 3., 4., 5.]);
// Stride 1
let out1 = inp1.pool_last_dim::<R2<3, 2>>(2, 1, 1).retrieve();
let out1 = inp1.pool_last_dim(2, 1, 2).retrieve();
// Stride 2
let out2 = inp1.pool_last_dim::<R2<2, 2>>(2, 2, 1).retrieve();
let out2 = inp1.pool_last_dim(2, 2, 2).retrieve();
// Stride 3
let out3 = inp1.pool_last_dim::<R2<1, 2>>(2, 3, 1).retrieve();
let out3 = inp1.pool_last_dim(2, 3, 2).retrieve();
cx.execute();
@@ -476,14 +470,12 @@ fn test_conv2d() {
const KERNELY: usize = 2;
const STRIDEX: usize = KERNELX;
const STRIDEY: usize = KERNELY;
const DILATIONX: usize = 0;
const DILATIONY: usize = 0;
const DILATIONX: usize = 1;
const DILATIONY: usize = 1;
const DIMX_IN: usize = 16;
const DIMX_OUT: usize = ((DIMX_IN - (DILATIONX + 1) * (KERNELX - 1) - 1) / STRIDEX) + 1;
const DIMY_IN: usize = 9;
const DIMY_OUT: usize = ((DIMY_IN - (DILATIONY + 1) * (KERNELY - 1) - 1) / STRIDEY) + 1;
let inp1 = cx.tensor::<R3<CH_IN, DIMX_IN, DIMY_IN>>().set(vec![
let inp1 = cx.tensor((CH_IN, DIMX_IN, DIMY_IN)).set(vec![
8., 8., 5., 7., 0., 6., 5., 3., 0., 7., 0., 6., 6., 7., 7., 5., 0., 6., 9., 4., 0., 8., 8.,
5., 7., 6., 2., 8., 9., 5., 0., 3., 1., 1., 8., 4., 1., 1., 5., 6., 9., 3., 2., 9., 4., 7.,
1., 0., 7., 7., 4., 9., 5., 0., 4., 7., 4., 7., 8., 8., 4., 8., 4., 7., 9., 3., 7., 9., 5.,
@@ -518,7 +510,14 @@ fn test_conv2d() {
1., 2., 1., 1., 4., 7., 2.,
]);
let model = luminal_nn::Conv2D::<CH_IN, CH_OUT, KERNELX, KERNELY>::initialize(&mut cx);
let model = luminal_nn::Conv2D::new(
CH_IN,
CH_OUT,
(KERNELX, KERNELY),
(STRIDEX, STRIDEY),
(DILATIONX, DILATIONY),
&mut cx,
);
model.weight.set(vec![
0.1600, 0.2000, 0.1900, -0.1100, 0.0100, -0.0300, -0.1200, -0.0800, -0.1300, -0.0300,
0.1600, -0.1700, -0.0000, 0.1900, 0.1300, 0.0300, -0.1500, 0.0900, 0.0100, 0.0200, 0.1500,
@@ -526,9 +525,7 @@ fn test_conv2d() {
0.0100, -0.2000, 0.2100, -0.0400, -0.1400, 0.1500, 0.0500, -0.1700, 0.1400,
]);
let mut out1 = model
.forward::<DIMX_IN, DIMY_IN, DIMX_OUT, DIMY_OUT>(inp1)
.retrieve();
let mut out1 = model.forward(inp1).retrieve();
cx.compile(
<(GenericCompiler, MetalCompiler<f32>)>::default(),
@@ -564,14 +561,14 @@ fn test_conv1d_pad_stride() {
let kernel_data = random_vec_rng(KERNEL * CH_IN * CH_OUT, &mut rng);
let input_data = random_vec_rng(CH_IN * DIM_IN, &mut rng);
let model = Conv1D::<CH_IN, CH_OUT, KERNEL, STRIDE, 0, PADDING>::initialize(&mut cx);
let model = Conv1D::new(CH_IN, CH_OUT, KERNEL, STRIDE, 1, PADDING, false, &mut cx);
model.weight.set(kernel_data.clone());
let inp1 = cx
.tensor::<(LConst<1>, LConst<CH_IN>, Dyn<'s'>)>()
.set_dyn(input_data.clone(), &[1, CH_IN, DIM_IN]);
.tensor((1, CH_IN, 's'))
.set_dyn(input_data.clone(), (1, CH_IN, DIM_IN));
let mut out1 = model.forward((inp1, PhantomData::<Dyn<'s'>>)).retrieve();
let mut out1 = model.forward(inp1).retrieve();
cx.compile(crate::MetalCompiler::<f32>::default(), &mut out1);
cx.execute();

View File

@@ -12,8 +12,8 @@ macro_rules! single_unary_test {
let mut rng = StdRng::seed_from_u64(1);
let data = random_vec_rng($size, &mut rng);
let mut cx = Graph::new();
let a = cx.tensor::<R1<$size>>().set(data.clone());
let f: fn(GraphTensor<R1<$size>>) -> GraphTensor<R1<$size>> = $luminal_func;
let a = cx.tensor($size).set(data.clone());
let f: fn(GraphTensor) -> GraphTensor = $luminal_func;
let mut b = f(a).retrieve();
cx.compile(MetalCompiler::<$type>::default(), &mut b);
cx.execute();
@@ -53,9 +53,9 @@ macro_rules! single_binary_test {
let a_data = random_vec_rng($size, &mut rng);
let b_data = random_vec_rng($size, &mut rng);
let mut cx = Graph::new();
let a = cx.tensor::<R1<$size>>().set(a_data.clone());
let b = cx.tensor::<R1<$size>>().set(b_data.clone());
let f: fn(GraphTensor<R1<$size>>, GraphTensor<R1<$size>>) -> GraphTensor<R1<$size>> =
let a = cx.tensor($size).set(a_data.clone());
let b = cx.tensor($size).set(b_data.clone());
let f: fn(GraphTensor, GraphTensor) -> GraphTensor =
$luminal_func;
let mut c = f(a, b).retrieve();
cx.compile(MetalCompiler::<$type>::default(), &mut c);

View File

@@ -772,8 +772,8 @@ mod tests {
#[test]
fn test_norms() {
let mut cx = Graph::new();
let a = cx.tensor().set([0.; 32]);
let mut b = a.layer_norm::<Axis<0>, _>(1e-5).retrieve();
let a = cx.tensor(32).set([0.; 32]);
let mut b = a.layer_norm(0, 1e-5).retrieve();
cx.compile(
<(

View File

@@ -1,85 +1,65 @@
use luminal::prelude::*;
/// Rectified Linear Unit activation function
#[derive(Default)]
pub struct ReLU;
impl InitModule for ReLU {
fn initialize(_: &mut Graph) -> Self {
Self
}
}
impl SerializeModule for ReLU {
fn serialize(&self, _: &mut Serializer) {}
}
impl<S: Shape> Module<GraphTensor<S>> for ReLU {
type Output = GraphTensor<S>;
impl Module<GraphTensor> for ReLU {
type Output = GraphTensor;
fn forward(&self, input: GraphTensor<S>) -> Self::Output {
fn forward(&self, input: GraphTensor) -> Self::Output {
input.relu()
}
}
/// Sigmoid activation function
#[derive(Default)]
pub struct Sigmoid;
impl InitModule for Sigmoid {
fn initialize(_: &mut Graph) -> Self {
Self
}
}
impl SerializeModule for Sigmoid {
fn serialize(&self, _: &mut Serializer) {}
}
impl<S: ConstShape> Module<GraphTensor<S>> for Sigmoid {
type Output = GraphTensor<S>;
impl Module<GraphTensor> for Sigmoid {
type Output = GraphTensor;
fn forward(&self, input: GraphTensor<S>) -> Self::Output {
fn forward(&self, input: GraphTensor) -> Self::Output {
input.sigmoid()
}
}
/// Swish activation function
#[derive(Default)]
pub struct Swish;
impl InitModule for Swish {
fn initialize(_: &mut Graph) -> Self {
Self
}
}
impl SerializeModule for Swish {
fn serialize(&self, _: &mut Serializer) {}
}
impl<S: ConstShape> Module<GraphTensor<S>> for Swish {
type Output = GraphTensor<S>;
impl Module<GraphTensor> for Swish {
type Output = GraphTensor;
fn forward(&self, input: GraphTensor<S>) -> Self::Output {
fn forward(&self, input: GraphTensor) -> Self::Output {
input.swish()
}
}
/// Tanh activation function
#[derive(Default)]
pub struct Tanh;
impl InitModule for Tanh {
fn initialize(_: &mut Graph) -> Self {
Self
}
}
impl SerializeModule for Tanh {
fn serialize(&self, _: &mut Serializer) {}
}
impl<S: ConstShape> Module<GraphTensor<S>> for Tanh {
type Output = GraphTensor<S>;
impl Module<GraphTensor> for Tanh {
type Output = GraphTensor;
fn forward(&self, input: GraphTensor<S>) -> Self::Output {
fn forward(&self, input: GraphTensor) -> Self::Output {
input.tanh()
}
}
@@ -98,12 +78,14 @@ mod tests {
fn test_relu_and_linear() {
// Test single and batch, unoptimized and optimized
let mut cx = Graph::new();
let batch = cx
.tensor::<R2<2, 3>>()
.set(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
let a = cx.tensor::<R1<3>>().set(vec![1.0, 2.0, 3.0]);
let batch = cx.tensor((2, 3)).set(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
let a = cx.tensor(3).set(vec![1.0, 2.0, 3.0]);
let model: (Linear<3, 4>, ReLU, Linear<4, 2>) = InitModule::initialize(&mut cx);
let model = (
Linear::new(3, 4, false, &mut cx),
ReLU,
Linear::new(4, 2, false, &mut cx),
);
model
.0
.weight

View File

@@ -1,81 +1,45 @@
use std::marker::PhantomData;
use luminal::prelude::*;
use rand::{thread_rng, Rng};
pub struct Conv1D<
const CH_IN: usize,
const CH_OUT: usize,
const KERNEL: usize,
const STRIDE: usize = KERNEL,
const DILATION: usize = 0,
const PADDING: usize = 0,
> {
pub weight: GraphTensor<R3<CH_OUT, CH_IN, KERNEL>>,
pub bias: Option<GraphTensor<R1<CH_OUT>>>,
pub struct Conv1D {
pub weight: GraphTensor, // ch_out, ch_in * kernel
pub bias: Option<GraphTensor>,
padding: usize,
dilation: usize,
stride: usize,
kernel: usize,
ch_in: usize,
}
impl<
const CH_IN: usize,
const CH_OUT: usize,
const KERNEL: usize,
const STRIDE: usize,
const DILATION: usize,
const PADDING: usize,
> InitModule for Conv1D<CH_IN, CH_OUT, KERNEL, STRIDE, DILATION, PADDING>
{
fn initialize(cx: &mut Graph) -> Self {
// Init weight as uniform(-1, 1)
let mut rng = thread_rng();
impl Conv1D {
/// Create a new 1D convolution layer
#[allow(clippy::too_many_arguments)]
pub fn new(
ch_in: usize,
ch_out: usize,
kernel: usize,
stride: usize,
dilation: usize,
padding: usize,
bias: bool,
cx: &mut Graph,
) -> Self {
Self {
weight: cx.named_tensor("Weight").set(
(0..(CH_IN * CH_OUT * KERNEL))
.map(|_| rng.gen_range(-1_f32..1_f32))
.collect::<Vec<_>>(),
),
bias: None,
weight: cx.named_tensor("Weight", (ch_out, ch_in * kernel)),
bias: if bias {
Some(cx.named_tensor("Bias", ch_out))
} else {
None
},
padding,
dilation,
stride,
kernel,
ch_in,
}
}
}
impl<
const CH_IN: usize,
const CH_OUT: usize,
const KERNEL: usize,
const STRIDE: usize,
const DILATION: usize,
const PADDING: usize,
> Conv1D<CH_IN, CH_OUT, KERNEL, STRIDE, DILATION, PADDING>
{
pub fn initialize_bias(cx: &mut Graph) -> Self {
// Init weight as uniform(-1, 1)
let mut rng = thread_rng();
Self {
weight: cx.named_tensor("Weight").set(
(0..(CH_IN * CH_OUT * KERNEL))
.map(|_| rng.gen_range(-1_f32..1_f32))
.collect::<Vec<_>>(),
),
bias: Some(
cx.named_tensor("Bias").set(
(0..CH_OUT)
.map(|_| rng.gen_range(-1_f32..1_f32))
.collect::<Vec<_>>(),
),
),
}
}
}
impl<
const CH_IN: usize,
const CH_OUT: usize,
const KERNEL: usize,
const STRIDE: usize,
const DILATION: usize,
const PADDING: usize,
> SerializeModule for Conv1D<CH_IN, CH_OUT, KERNEL, STRIDE, DILATION, PADDING>
{
impl SerializeModule for Conv1D {
fn serialize(&self, s: &mut luminal::module::Serializer) {
s.tensor("weight", self.weight);
if let Some(bias) = self.bias {
@@ -85,377 +49,193 @@ impl<
}
// Single
impl<
const CH_IN: usize,
const CH_OUT: usize,
const KERNEL: usize,
const STRIDE: usize,
const DILATION: usize,
const PADDING: usize,
DimIn: Dimension,
DimOut: Dimension,
> Module<(GraphTensor<(Const<CH_IN>, DimIn)>, PhantomData<DimOut>)>
for Conv1D<CH_IN, CH_OUT, KERNEL, STRIDE, DILATION, PADDING>
{
type Output = GraphTensor<(Const<CH_OUT>, DimOut)>;
fn forward(
&self,
(input, ph): (GraphTensor<(Const<CH_IN>, DimIn)>, PhantomData<DimOut>),
) -> Self::Output {
<Self as Module<(
GraphTensor<(Const<1>, Const<1>, Const<CH_IN>, DimIn)>,
PhantomData<DimOut>,
)>>::forward(self, (input.expand(), ph))
.reshape()
}
}
// Batch 1D
impl<
const CH_IN: usize,
const CH_OUT: usize,
const KERNEL: usize,
const STRIDE: usize,
const DILATION: usize,
const PADDING: usize,
DimIn: Dimension,
DimOut: Dimension,
Batch: Dimension,
>
Module<(
GraphTensor<(Batch, Const<CH_IN>, DimIn)>,
PhantomData<DimOut>,
)> for Conv1D<CH_IN, CH_OUT, KERNEL, STRIDE, DILATION, PADDING>
{
type Output = GraphTensor<(Batch, Const<CH_OUT>, DimOut)>;
fn forward(
&self,
(input, ph): (
GraphTensor<(Batch, Const<CH_IN>, DimIn)>,
PhantomData<DimOut>,
),
) -> Self::Output {
<Self as Module<(
GraphTensor<(Const<1>, Batch, Const<CH_IN>, DimIn)>,
PhantomData<DimOut>,
)>>::forward(self, (input.expand(), ph))
.reshape()
}
}
// Batch x Batch
impl<
const CH_IN: usize,
const CH_OUT: usize,
const KERNEL: usize,
const STRIDE: usize,
const DILATION: usize,
const PADDING: usize,
DimIn: Dimension,
DimOut: Dimension,
Batch1: Dimension,
Batch2: Dimension,
>
Module<(
GraphTensor<(Batch1, Batch2, Const<CH_IN>, DimIn)>,
PhantomData<DimOut>,
)> for Conv1D<CH_IN, CH_OUT, KERNEL, STRIDE, DILATION, PADDING>
{
type Output = GraphTensor<(Batch1, Batch2, Const<CH_OUT>, DimOut)>;
fn forward(
&self,
(input, _): (
GraphTensor<(Batch1, Batch2, Const<CH_IN>, DimIn)>,
PhantomData<DimOut>,
),
) -> Self::Output {
let mut o = input
impl Module<GraphTensor> for Conv1D {
type Output = GraphTensor;
fn forward(&self, input: GraphTensor) -> Self::Output {
assert_eq!(input.shape()[input.shape.len() - 2], self.ch_in);
// Input: batch_dims, ch_in, dim_in
// Reshape to 2 batch dims
let n_expands = 4 - input.shape.len();
let mut inp = input;
for _ in 0..n_expands {
inp = inp.expand(0, 1);
}
let batch1 = inp.shape()[0].small();
let batch2 = inp.shape()[1].small();
let dim_in = input.shape().last().unwrap().small();
let dim_out = (((dim_in + 2 * self.padding - self.dilation * (self.kernel - 1) - 1)
/ self.stride)
+ 1)
.simplify();
let mut out = inp
// Add padding
.pad::<(Batch1, Batch2, Const<CH_IN>, DimIn)>(((0, 0), (0, 0), (0, 0), (PADDING, 0)))
.pad(((0, 0), (0, 0), (0, 0), (self.padding, 0)))
.contiguous()
.pad::<(Batch1, Batch2, Const<CH_IN>, DimIn)>(((0, 0), (0, 0), (0, 0), (0, PADDING)))
.pad(((0, 0), (0, 0), (0, 0), (0, self.padding)))
// Pool
.pool_last_dim::<(Batch1, Batch2, Const<CH_IN>, DimOut, Const<KERNEL>)>(
KERNEL, STRIDE, DILATION,
)
.pool_last_dim(self.kernel, self.stride, self.dilation)
// Combine channel_in and kernel
.permute::<_, Axes5<0, 1, 3, 2, 4>>()
.dyn_reshape::<(Batch1, Batch2, DimOut, Dyn<'-'>), _>(&[
Batch1::size(),
Batch2::size(),
DimOut::size(),
(CH_IN * KERNEL).into(),
])
.permute((0, 1, 3, 2, 4))
.reshape((batch1, batch2, dim_out, self.ch_in * self.kernel))
.matmul(
self.weight
// Combine last two dimensions in kernel
.dyn_reshape::<(Const<CH_OUT>, Dyn<'-'>), _>(&[CH_OUT, CH_IN * KERNEL])
.permute()
.permute((1, 0))
// Broadcast along batch dims
.expand(),
.expand(0, batch1)
.expand(1, batch2),
)
.permute();
.permute((0, 1, 3, 2));
if let Some(b) = self.bias {
o += b.expand();
out += b.expand_to(out.shape);
}
o
// Reshape back to original shape
let mut final_shape = out.shape();
for _ in 0..n_expands {
final_shape.remove(0);
}
out.reshape(final_shape) // Output: batch_dims, ch_out, dim_out
}
}
pub struct Conv2D<
const CH_IN: usize,
const CH_OUT: usize,
const KERNELX: usize,
const KERNELY: usize,
const STRIDEX: usize = KERNELX,
const STRIDEY: usize = KERNELY,
const DILATIONX: usize = 0,
const DILATIONY: usize = 0,
> {
pub weight: GraphTensor<R4<CH_OUT, CH_IN, KERNELX, KERNELY>>,
pub struct Conv2D {
pub weight: GraphTensor, // ch_out, ch_in * kernel_x * kernel_y
kernel: (usize, usize),
stride: (usize, usize),
dilation: (usize, usize),
ch_out: usize,
ch_in: usize,
}
impl<
const CH_IN: usize,
const CH_OUT: usize,
const KERNELX: usize,
const KERNELY: usize,
const STRIDEX: usize,
const STRIDEY: usize,
const DILATIONX: usize,
const DILATIONY: usize,
> InitModule
for Conv2D<CH_IN, CH_OUT, KERNELX, KERNELY, STRIDEX, STRIDEY, DILATIONX, DILATIONY>
{
fn initialize(cx: &mut Graph) -> Self {
// Init weight as uniform(-1, 1)
let mut rng = thread_rng();
impl Conv2D {
pub fn new(
ch_in: usize,
ch_out: usize,
kernel: (usize, usize),
stride: (usize, usize),
dilation: (usize, usize),
cx: &mut Graph,
) -> Self {
Self {
weight: cx.named_tensor("Weight").set(
(0..(CH_IN * CH_OUT * KERNELX * KERNELY))
.map(|_| rng.gen_range(-1_f32..1_f32))
.collect::<Vec<_>>(),
),
weight: cx.named_tensor("Weight", (ch_out, ch_in * kernel.0 * kernel.1)),
kernel,
stride,
dilation,
ch_out,
ch_in,
}
}
}
impl<
const CH_IN: usize,
const CH_OUT: usize,
const KERNELX: usize,
const KERNELY: usize,
const STRIDEX: usize,
const STRIDEY: usize,
const DILATIONX: usize,
const DILATIONY: usize,
> SerializeModule
for Conv2D<CH_IN, CH_OUT, KERNELX, KERNELY, STRIDEX, STRIDEY, DILATIONX, DILATIONY>
{
impl SerializeModule for Conv2D {
fn serialize(&self, s: &mut luminal::module::Serializer) {
s.tensor("weight", self.weight);
}
}
// Single
impl<
const CH_IN: usize,
const CH_OUT: usize,
const KERNELX: usize,
const KERNELY: usize,
const STRIDEX: usize,
const STRIDEY: usize,
const DILATIONX: usize,
const DILATIONY: usize,
> Conv2D<CH_IN, CH_OUT, KERNELX, KERNELY, STRIDEX, STRIDEY, DILATIONX, DILATIONY>
{
pub fn forward<
const DIMX_IN: usize,
const DIMY_IN: usize,
const DIMX_OUT: usize,
const DIMY_OUT: usize,
>(
&self,
input: GraphTensor<R3<CH_IN, DIMX_IN, DIMY_IN>>,
) -> GraphTensor<R3<CH_OUT, DIMX_OUT, DIMY_OUT>> {
impl Conv2D {
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
// Input: (ch_in, dimx_in, dimy_in)
let dimx_in = input.shape()[1].small();
let dimy_in = input.shape()[2].small();
let dimx_out = (((dimx_in - self.dilation.0 * (self.kernel.0 - 1) - 1) / self.stride.0)
+ 1)
.simplify();
let dimy_out = (((dimy_in - self.dilation.1 * (self.kernel.1 - 1) - 1) / self.stride.1)
+ 1)
.simplify();
let input_pooled = input
.pool_last_dim::<R4<CH_IN, DIMX_IN, DIMY_OUT, KERNELY>>(KERNELY, STRIDEY, DILATIONY)
.permute::<_, Axes4<0, 2, 3, 1>>()
.pool_last_dim::<R5<CH_IN, DIMY_OUT, KERNELY, DIMX_OUT, KERNELX>>(
KERNELX, STRIDEX, DILATIONX,
)
.permute::<_, Axes5<0, 4, 2, 3, 1>>()
.dyn_reshape::<(_, Dyn<'-'>), _>(&[CH_IN * KERNELX * KERNELY, DIMX_OUT * DIMY_OUT]);
.pool_last_dim(self.kernel.1, self.stride.1, self.dilation.1)
.permute((0, 2, 3, 1))
.pool_last_dim(self.kernel.0, self.stride.0, self.dilation.0)
.permute((0, 4, 2, 3, 1))
.reshape((
self.ch_in * self.kernel.0 * self.kernel.1,
dimx_out * dimy_out,
));
self.weight
.dyn_reshape::<(Const<CH_OUT>, Dyn<'-'>), _>(&[CH_OUT, CH_IN * KERNELX * KERNELY])
.matmul(input_pooled)
.reshape::<R3<CH_OUT, DIMX_OUT, DIMY_OUT>>()
.reshape((self.ch_out, dimx_out, dimy_out))
}
}
pub struct Conv3D<
const CH_IN: usize,
const CH_OUT: usize,
const KERNELX: usize,
const KERNELY: usize,
const KERNELZ: usize,
const STRIDEX: usize,
const STRIDEY: usize,
const STRIDEZ: usize,
const DILATIONX: usize,
const DILATIONY: usize,
const DILATIONZ: usize,
> {
pub weight: GraphTensor<R5<CH_OUT, CH_IN, KERNELX, KERNELY, KERNELZ>>,
pub struct Conv3D {
pub weight: GraphTensor, // ch_out, ch_in * kernel_x * kernel_y * kernel_z
kernel: (usize, usize, usize),
stride: (usize, usize, usize),
dilation: (usize, usize, usize),
ch_in: usize,
ch_out: usize,
}
impl<
const CH_IN: usize,
const CH_OUT: usize,
const KERNELX: usize,
const KERNELY: usize,
const KERNELZ: usize,
const STRIDEX: usize,
const STRIDEY: usize,
const STRIDEZ: usize,
const DILATIONX: usize,
const DILATIONY: usize,
const DILATIONZ: usize,
> InitModule
for Conv3D<
CH_IN,
CH_OUT,
KERNELX,
KERNELY,
KERNELZ,
STRIDEX,
STRIDEY,
STRIDEZ,
DILATIONX,
DILATIONY,
DILATIONZ,
>
{
fn initialize(cx: &mut Graph) -> Self {
// Init weight as uniform(-1, 1)
let mut rng = thread_rng();
impl Conv3D {
pub fn new(
ch_in: usize,
ch_out: usize,
kernel: (usize, usize, usize),
stride: (usize, usize, usize),
dilation: (usize, usize, usize),
cx: &mut Graph,
) -> Self {
Self {
weight: cx.named_tensor("Weight").set(
(0..(CH_IN * CH_OUT * KERNELX * KERNELY * KERNELZ))
.map(|_| rng.gen_range(-1_f32..1_f32))
.collect::<Vec<_>>(),
),
weight: cx.named_tensor("Weight", (ch_out, ch_in * kernel.0 * kernel.1 * kernel.2)),
kernel,
stride,
dilation,
ch_in,
ch_out,
}
}
}
impl<
const CH_IN: usize,
const CH_OUT: usize,
const KERNELX: usize,
const KERNELY: usize,
const KERNELZ: usize,
const STRIDEX: usize,
const STRIDEY: usize,
const STRIDEZ: usize,
const DILATIONX: usize,
const DILATIONY: usize,
const DILATIONZ: usize,
> SerializeModule
for Conv3D<
CH_IN,
CH_OUT,
KERNELX,
KERNELY,
KERNELZ,
STRIDEX,
STRIDEY,
STRIDEZ,
DILATIONX,
DILATIONY,
DILATIONZ,
>
{
impl SerializeModule for Conv3D {
fn serialize(&self, s: &mut luminal::module::Serializer) {
s.tensor("weight", self.weight);
}
}
impl<
const CH_IN: usize,
const CH_OUT: usize,
const KERNELX: usize,
const KERNELY: usize,
const KERNELZ: usize,
const STRIDEX: usize,
const STRIDEY: usize,
const STRIDEZ: usize,
const DILATIONX: usize,
const DILATIONY: usize,
const DILATIONZ: usize,
>
Conv3D<
CH_IN,
CH_OUT,
KERNELX,
KERNELY,
KERNELZ,
STRIDEX,
STRIDEY,
STRIDEZ,
DILATIONX,
DILATIONY,
DILATIONZ,
>
{
pub fn forward<
const DIMX_IN: usize,
const DIMY_IN: usize,
const DIMZ_IN: usize,
const DIMX_OUT: usize,
const DIMY_OUT: usize,
const DIMZ_OUT: usize,
>(
&self,
input: GraphTensor<R4<CH_IN, DIMX_IN, DIMY_IN, DIMZ_IN>>,
) -> GraphTensor<R4<CH_OUT, DIMX_OUT, DIMY_OUT, DIMZ_OUT>> {
impl Conv3D {
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
// Input: ch_in, dimx_in, dimy_in, dimz_in
let dimx_in = input.shape()[1].small();
let dimy_in = input.shape()[2].small();
let dimz_in = input.shape()[3].small();
let dimx_out = (((dimx_in - self.dilation.0 * (self.kernel.0 - 1) - 1) / self.stride.0)
+ 1)
.simplify();
let dimy_out = (((dimy_in - self.dilation.1 * (self.kernel.1 - 1) - 1) / self.stride.1)
+ 1)
.simplify();
let dimz_out = (((dimz_in - self.dilation.2 * (self.kernel.2 - 1) - 1) / self.stride.2)
+ 1)
.simplify();
let input_pooled = input
.pool_last_dim::<R5<CH_IN, DIMX_IN, DIMY_OUT, DIMZ_OUT, KERNELY>>(
KERNELY, STRIDEY, DILATIONY,
)
.permute::<_, Axes5<0, 2, 3, 4, 1>>()
.pool_last_dim::<R6<CH_IN, DIMY_OUT, DIMZ_OUT, KERNELY, DIMX_OUT, KERNELX>>(
KERNELX, STRIDEX, DILATIONX,
)
.dyn_reshape::<(Const<CH_IN>, Dyn<'-'>), _>(&[
CH_IN,
DIMZ_OUT,
KERNELY,
DIMX_OUT * KERNELX,
DIMY_IN,
]);
.pool_last_dim(self.kernel.1, self.stride.1, self.dilation.1)
.permute((0, 2, 3, 4, 1))
.pool_last_dim(self.kernel.0, self.stride.0, self.dilation.0)
.reshape((
self.ch_in,
dimz_out,
self.kernel.1,
dimx_out * self.kernel.0,
dimy_in,
));
let last_pool = input_pooled
.pool_last_dim::<(
Const<CH_IN>,
Const<DIMZ_OUT>,
Const<KERNELY>,
Dyn<'-'>,
Const<DIMY_IN>,
Const<KERNELZ>,
)>(KERNELZ, STRIDEZ, DILATIONZ)
.permute::<_, Axes6<0, 2, 5, 3, 1, 4>>();
.pool_last_dim(self.kernel.2, self.stride.2, self.dilation.2)
.permute((0, 2, 5, 3, 1, 4));
let reshaped = last_pool.dyn_reshape::<(_, Dyn<'-'>), _>(&[
CH_IN * KERNELX * KERNELY * KERNELZ,
DIMX_OUT * DIMY_OUT * DIMZ_OUT,
]);
let reshaped = last_pool.reshape((
self.ch_in * self.kernel.0 * self.kernel.1 * self.kernel.2,
dimx_out * dimy_out * dimz_out,
));
self.weight
.dyn_reshape::<(Const<CH_OUT>, Dyn<'-'>), _>(&[
CH_OUT,
CH_IN * KERNELX * KERNELY * KERNELZ,
])
.matmul(reshaped)
.reshape::<R4<CH_OUT, DIMX_OUT, DIMY_OUT, DIMZ_OUT>>()
.reshape((self.ch_out, dimx_out, dimy_out, dimz_out))
}
}
@@ -468,7 +248,6 @@ mod tests {
tests::{assert_close, random_vec_rng},
};
use rand::{rngs::StdRng, SeedableRng};
use std::marker::PhantomData;
#[test]
fn test_conv1d_simple() {
@@ -481,18 +260,18 @@ mod tests {
const DIM_IN: usize = 6;
const DIM_OUT: usize = ((DIM_IN - (KERNEL - 1) - 1) / STRIDE) + 1;
let model = Conv1D::<CH_IN, CH_OUT, KERNEL>::initialize(&mut cx);
let model = Conv1D::new(CH_IN, CH_OUT, KERNEL, KERNEL, 1, 0, false, &mut cx);
model.weight.set([[[0.0316, -0.2057]]]);
let inp1 = cx
.tensor::<R2<CH_IN, DIM_IN>>()
.set([[3., 0., 9., 6., 0., 6.]]);
let inp1 = cx.tensor((CH_IN, DIM_IN)).set([[3., 0., 9., 6., 0., 6.]]);
let out1 = model
.forward((inp1, PhantomData::<Const<DIM_OUT>>))
.retrieve();
let out1 = model.forward(inp1).retrieve();
cx.execute();
assert_eq!(
out1.shape(),
vec![BigExpression::from(CH_OUT), BigExpression::from(DIM_OUT)]
);
assert_close(&out1.data(), &[0.0948, -0.9498, -1.2342]);
}
@@ -510,14 +289,14 @@ mod tests {
let kernel_data = random_vec_rng(KERNEL * CH_IN * CH_OUT, &mut rng);
let input_data = random_vec_rng(CH_IN * DIM_IN, &mut rng);
let model = Conv1D::<CH_IN, CH_OUT, KERNEL, STRIDE, 0, PADDING>::initialize(&mut cx);
let model = Conv1D::new(CH_IN, CH_OUT, KERNEL, STRIDE, 1, PADDING, false, &mut cx);
model.weight.set(kernel_data.clone());
let inp1 = cx
.tensor::<(Const<1>, Const<CH_IN>, Dyn<'s'>)>()
.set_dyn(input_data.clone(), &[1, CH_IN, DIM_IN]);
.tensor((1, CH_IN, 's'))
.set_dyn(input_data.clone(), (1, CH_IN, DIM_IN));
let out1 = model.forward((inp1, PhantomData::<Dyn<'s'>>)).retrieve();
let out1 = model.forward(inp1).retrieve();
cx.execute();
let input = Tensor::from_vec(input_data, (1, CH_IN, DIM_IN), &Device::Cpu).unwrap();
@@ -539,9 +318,8 @@ mod tests {
const KERNEL: usize = 2;
const STRIDE: usize = 2;
const DIM_IN: usize = 12;
const DIM_OUT: usize = ((DIM_IN - (KERNEL - 1) - 1) / STRIDE) + 1;
let model = Conv1D::<CH_IN, CH_OUT, KERNEL>::initialize(&mut cx);
let model = Conv1D::new(CH_IN, CH_OUT, KERNEL, STRIDE, 1, 0, false, &mut cx);
model.weight.set(vec![
-0.1700, -0.2000, 0.1000, -0.0200, 0.1000, 0.0200, -0.2100, -0.2300, -0.0600, 0.1500,
0.1200, 0.1000, 0.1800, 0.0600, -0.1700, -0.0400, 0.1000, -0.0200, -0.1700, 0.1000,
@@ -552,7 +330,7 @@ mod tests {
0.0700, -0.1200, 0.1400, 0.2200,
]);
let inp1 = cx.tensor::<R2<CH_IN, DIM_IN>>();
let inp1 = cx.tensor((CH_IN, DIM_IN));
inp1.set(vec![
1., 2., 6., 4., 8., 1., 6., 0., 1., 0., 6., 4., 3., 4., 9., 3., 8., 8., 5., 5., 0., 4.,
2., 7., 6., 4., 2., 2., 8., 0., 7., 3., 0., 0., 7., 2., 3., 3., 1., 9., 5., 4., 5., 5.,
@@ -562,9 +340,7 @@ mod tests {
]);
inp1.retrieve();
let out1 = model
.forward((inp1, PhantomData::<Const<DIM_OUT>>))
.retrieve();
let out1 = model.forward(inp1).retrieve();
cx.execute();
assert_close(
@@ -587,14 +363,14 @@ mod tests {
const KERNELY: usize = 2;
const STRIDEX: usize = KERNELX;
const STRIDEY: usize = KERNELY;
const DILATIONX: usize = 0;
const DILATIONY: usize = 0;
const DILATIONX: usize = 1;
const DILATIONY: usize = 1;
const DIMX_IN: usize = 16;
const DIMX_OUT: usize = ((DIMX_IN - (DILATIONX + 1) * (KERNELX - 1) - 1) / STRIDEX) + 1;
const DIMX_OUT: usize = ((DIMX_IN - DILATIONX * (KERNELX - 1) - 1) / STRIDEX) + 1;
const DIMY_IN: usize = 9;
const DIMY_OUT: usize = ((DIMY_IN - (DILATIONY + 1) * (KERNELY - 1) - 1) / STRIDEY) + 1;
const DIMY_OUT: usize = ((DIMY_IN - DILATIONY * (KERNELY - 1) - 1) / STRIDEY) + 1;
let inp1 = cx.tensor::<R3<CH_IN, DIMX_IN, DIMY_IN>>();
let inp1 = cx.tensor((CH_IN, DIMX_IN, DIMY_IN));
inp1.set(vec![
8., 8., 5., 7., 0., 6., 5., 3., 0., 7., 0., 6., 6., 7., 7., 5., 0., 6., 9., 4., 0., 8.,
8., 5., 7., 6., 2., 8., 9., 5., 0., 3., 1., 1., 8., 4., 1., 1., 5., 6., 9., 3., 2., 9.,
@@ -631,7 +407,7 @@ mod tests {
3., 1., 5., 9., 1., 6., 5., 4., 2., 1., 2., 1., 1., 4., 7., 2.,
]);
let exp_out1 = cx.tensor::<R3<CH_OUT, DIMX_OUT, DIMY_OUT>>();
let exp_out1 = cx.tensor((CH_OUT, DIMX_OUT, DIMY_OUT));
exp_out1.set(vec![
3.9600, -0.3300, -1.7800, 4.0400, 1.5300, 0.2900, 2.8700, 3.0000, 0.9600, -1.8700,
4.5900, 3.9700, 1.2800, 1.1800, 3.7800, 2.8500, 0.5500, 0.5600, 3.9800, 1.3200,
@@ -644,7 +420,14 @@ mod tests {
exp_out1.retrieve();
let model = Conv2D::<CH_IN, CH_OUT, KERNELX, KERNELY>::initialize(&mut cx);
let model = Conv2D::new(
CH_IN,
CH_OUT,
(KERNELX, KERNELY),
(STRIDEX, STRIDEY),
(DILATIONX, DILATIONY),
&mut cx,
);
model.weight.set(vec![
0.1600, 0.2000, 0.1900, -0.1100, 0.0100, -0.0300, -0.1200, -0.0800, -0.1300, -0.0300,
0.1600, -0.1700, -0.0000, 0.1900, 0.1300, 0.0300, -0.1500, 0.0900, 0.0100, 0.0200,
@@ -652,9 +435,7 @@ mod tests {
-0.2100, 0.0100, -0.2000, 0.2100, -0.0400, -0.1400, 0.1500, 0.0500, -0.1700, 0.1400,
]);
let out1 = model
.forward::<DIMX_IN, DIMY_IN, DIMX_OUT, DIMY_OUT>(inp1)
.retrieve();
let out1 = model.forward(inp1).retrieve();
cx.execute();
@@ -673,17 +454,17 @@ mod tests {
const STRIDEX: usize = 2;
const STRIDEY: usize = 2;
const STRIDEZ: usize = 2;
const DILATIONX: usize = 0;
const DILATIONY: usize = 0;
const DILATIONZ: usize = 0;
const DILATIONX: usize = 1;
const DILATIONY: usize = 1;
const DILATIONZ: usize = 1;
const DIMX_IN: usize = 2;
const DIMY_IN: usize = 3;
const DIMZ_IN: usize = 5;
const DIMX_OUT: usize = ((DIMX_IN - (DILATIONX + 1) * (KERNELX - 1) - 1) / STRIDEX) + 1;
const DIMY_OUT: usize = ((DIMY_IN - (DILATIONY + 1) * (KERNELY - 1) - 1) / STRIDEY) + 1;
const DIMZ_OUT: usize = ((DIMZ_IN - (DILATIONZ + 1) * (KERNELZ - 1) - 1) / STRIDEZ) + 1;
const DIMX_OUT: usize = ((DIMX_IN - DILATIONX * (KERNELX - 1) - 1) / STRIDEX) + 1;
const DIMY_OUT: usize = ((DIMY_IN - DILATIONY * (KERNELY - 1) - 1) / STRIDEY) + 1;
const DIMZ_OUT: usize = ((DIMZ_IN - DILATIONZ * (KERNELZ - 1) - 1) / STRIDEZ) + 1;
let inp1 = cx.tensor::<R4<CH_IN, DIMX_IN, DIMY_IN, DIMZ_IN>>();
let inp1 = cx.tensor((CH_IN, DIMX_IN, DIMY_IN, DIMZ_IN));
inp1.set(vec![
// Example input data (5 channels, 2x3x5 volume)
8., 8., 5., 7., 0., 6., 5., 3., 0., 7., 0., 6., 6., 7., 7., 5., 0., 6., 9., 4., 0., 8.,
@@ -695,7 +476,7 @@ mod tests {
4., 1., 9., 7., 7., 1., 2., 6., 3., 4., 1., 1., 6., 6., 8., 2., 7., 7.,
]);
let exp_out1 = cx.tensor::<R4<CH_OUT, DIMX_OUT, DIMY_OUT, DIMZ_OUT>>();
let exp_out1 = cx.tensor((CH_OUT, DIMX_OUT, DIMY_OUT, DIMZ_OUT));
exp_out1.set(vec![
// Example expected output data (2 channels, 1x1x2 volume)
90.6935, 98.7138, 98.8273, 102.6553,
@@ -703,19 +484,14 @@ mod tests {
exp_out1.retrieve();
let model: Conv3D<
let model = Conv3D::new(
CH_IN,
CH_OUT,
KERNELX,
KERNELY,
KERNELZ,
STRIDEX,
STRIDEY,
STRIDEZ,
DILATIONX,
DILATIONY,
DILATIONZ,
> = Conv3D::initialize(&mut cx);
(KERNELX, KERNELY, KERNELZ),
(STRIDEX, STRIDEY, STRIDEZ),
(DILATIONX, DILATIONY, DILATIONZ),
&mut cx,
);
let weights = vec![
4.273e-01, 1.388e-01, 3.546e-01, 2.403e-01, 5.572e-01, 2.788e-01, 6.718e-01, 6.935e-01,
8.410e-01, 1.297e-01, 7.073e-01, 3.455e-01, 4.166e-01, 9.513e-01, 4.682e-01, 4.546e-02,
@@ -730,9 +506,7 @@ mod tests {
];
model.weight.set(weights);
let out1 = model
.forward::<DIMX_IN, DIMY_IN, DIMZ_IN, DIMX_OUT, DIMY_OUT, DIMZ_OUT>(inp1)
.retrieve();
let out1 = model.forward(inp1).retrieve();
cx.execute();

View File

@@ -1,87 +1,56 @@
use luminal::{prelude::*, tests::random_vec};
pub struct Embedding<const N: usize, const DIM: usize> {
pub weight: GraphTensor<R2<N, DIM>>,
pub struct Embedding {
permute: bool,
pub weight: GraphTensor, // n embeddings x embedding dim
}
impl<const A: usize, const B: usize> InitModule for Embedding<A, B> {
fn initialize(cx: &mut Graph) -> Self {
impl Embedding {
pub fn new(n_embeddings: usize, embedding_dim: usize, cx: &mut Graph) -> Self {
Self {
weight: cx.named_tensor("Embedding Weight").set(random_vec(A * B)),
weight: cx.named_tensor("Embedding Weight", (n_embeddings, embedding_dim)),
permute: false,
}
}
pub fn new_permuted(n_embeddings: usize, embedding_dim: usize, cx: &mut Graph) -> Self {
Self {
weight: cx.named_tensor("Embedding Weight", (n_embeddings, embedding_dim)),
permute: true,
}
}
pub fn initialize(self) -> Self {
self.weight.set(random_vec(
self.weight.shape.n_elements().to_usize().unwrap(),
));
self
}
}
impl<const A: usize, const B: usize> SerializeModule for Embedding<A, B> {
impl SerializeModule for Embedding {
fn serialize(&self, s: &mut luminal::module::Serializer) {
s.tensor("weight", self.weight);
}
}
// Single
impl<S: Dimension, const N: usize, const DIM: usize> Module<GraphTensor<(S,)>>
for Embedding<N, DIM>
{
type Output = GraphTensor<(S, Const<DIM>)>;
impl Module<GraphTensor> for Embedding {
type Output = GraphTensor;
fn forward(&self, input: GraphTensor<(S,)>) -> Self::Output {
self.weight.gather(input)
}
}
// Batch
impl<B: Dimension, S: Dimension, const N: usize, const DIM: usize> Module<GraphTensor<(B, S)>>
for Embedding<N, DIM>
{
type Output = GraphTensor<(B, S, Const<DIM>)>;
fn forward(&self, input: GraphTensor<(B, S)>) -> Self::Output {
self.weight
.gather(input.dyn_reshape::<(Dyn<'-'>,), _>(&[B::size() * S::size()]))
.reshape()
}
}
pub struct PermutedEmbedding<const N: usize, const DIM: usize> {
pub weight: GraphTensor<R2<DIM, N>>,
}
impl<const A: usize, const B: usize> InitModule for PermutedEmbedding<A, B> {
fn initialize(cx: &mut Graph) -> Self {
Self {
weight: cx.named_tensor("Embedding Weight").set(random_vec(A * B)),
fn forward(&self, input: GraphTensor) -> Self::Output {
// Flatten batches
let batch_size = input.shape.n_elements();
let inp = input.reshape(batch_size);
let out = if self.permute {
self.weight.permute((1, 0))
} else {
self.weight
}
}
}
impl<const A: usize, const B: usize> SerializeModule for PermutedEmbedding<A, B> {
fn serialize(&self, s: &mut luminal::module::Serializer) {
s.tensor("weight", self.weight);
}
}
// Single
impl<S: Dimension, const N: usize, const DIM: usize> Module<GraphTensor<(S,)>>
for PermutedEmbedding<N, DIM>
{
type Output = GraphTensor<(S, Const<DIM>)>;
fn forward(&self, input: GraphTensor<(S,)>) -> Self::Output {
self.weight.permute().gather(input)
}
}
// Batch
impl<B: Dimension, S: Dimension, const N: usize, const DIM: usize> Module<GraphTensor<(B, S)>>
for PermutedEmbedding<N, DIM>
{
type Output = GraphTensor<(B, S, Const<DIM>)>;
fn forward(&self, input: GraphTensor<(B, S)>) -> Self::Output {
self.weight
.permute()
.gather(input.dyn_reshape::<(Dyn<'-'>,), _>(&[B::size() * S::size()]))
.reshape()
.gather(inp);
// Unflatten
let mut new_shape = input.shape();
new_shape.push(self.weight.shape()[1].clone());
out.reshape(new_shape)
}
}
@@ -101,12 +70,10 @@ mod tests {
#[test]
fn test_embedding() {
let mut cx = Graph::new();
let batch = cx
.tensor::<R2<2, 3>>()
.set(vec![1.0, 0.0, 2.0, 1.0, 0.0, 1.0]);
let a = cx.tensor::<R1<3>>().set(vec![1.0, 0.0, 1.0]).retrieve();
let batch = cx.tensor((2, 3)).set(vec![1.0, 0.0, 2.0, 1.0, 0.0, 1.0]);
let a = cx.tensor(3).set(vec![1.0, 0.0, 1.0]).retrieve();
let model: Embedding<3, 4> = InitModule::initialize(&mut cx);
let model = Embedding::new(3, 4, &mut cx).initialize();
model
.weight
.set(vec![1.1, 2., 3., 1., 2., 3., 14., 2., 33., 1., 2., 3.]);

View File

@@ -1,5 +1,3 @@
use luminal::prelude::*;
mod activation;
pub use activation::*;
mod convolution;
@@ -12,34 +10,3 @@ mod norm;
pub use norm::*;
mod transformer;
pub use transformer::*;
pub struct Repeated<T, const N: usize> {
pub modules: Vec<T>,
}
impl<T: InitModule, const N: usize> InitModule for Repeated<T, N> {
fn initialize(cx: &mut Graph) -> Self {
Self {
modules: (0..N).map(|_| InitModule::initialize(cx)).collect(),
}
}
}
impl<T: SerializeModule, const N: usize> SerializeModule for Repeated<T, N> {
fn serialize(&self, s: &mut Serializer) {
for (i, l) in self.modules.iter().enumerate() {
s.module(&format!("layer{i}"), l);
}
}
}
impl<I, T: Module<I, Output = I>, const N: usize> Module<I> for Repeated<T, N> {
type Output = I;
fn forward(&self, mut input: I) -> Self::Output {
for m in &self.modules {
input = m.forward(input);
}
input
}
}

View File

@@ -3,74 +3,71 @@ use rand::{thread_rng, Rng};
use luminal::prelude::*;
/// A simple unbiased linear layer
pub struct Linear<const A: usize, const B: usize> {
pub weight: GraphTensor<R2<A, B>>,
pub struct Linear {
pub weight: GraphTensor,
pub bias: Option<GraphTensor>,
permute: bool,
}
impl<const A: usize, const B: usize> InitModule for Linear<A, B> {
fn initialize(cx: &mut Graph) -> Self {
impl Linear {
pub fn new(inp: usize, out: usize, bias: bool, cx: &mut Graph) -> Self {
Self {
weight: cx.named_tensor("Weight", (inp, out)),
bias: if bias {
Some(cx.named_tensor("Bias", out))
} else {
None
},
permute: false,
}
}
pub fn new_permuted(inp: usize, out: usize, bias: bool, cx: &mut Graph) -> Self {
Self {
weight: cx.named_tensor("Weight", (out, inp)),
bias: if bias {
Some(cx.named_tensor("Bias", out))
} else {
None
},
permute: true,
}
}
pub fn initialize(self) -> Self {
// Init weight as uniform(-1, 1)
let mut rng = thread_rng();
Self {
weight: cx.named_tensor("Weight").set(
(0..(A * B))
.map(|_| rng.gen_range(-1_f32..1_f32))
.collect::<Vec<_>>(),
),
self.weight.set(
(0..self.weight.shape.n_elements().to_usize().unwrap())
.map(|_| rng.gen_range(-1_f32..1_f32))
.collect::<Vec<_>>(),
);
self
}
}
impl SerializeModule for Linear {
fn serialize(&self, s: &mut luminal::module::Serializer) {
s.tensor("weight", self.weight);
if let Some(bias) = self.bias {
s.tensor("bias", bias);
}
}
}
impl<const A: usize, const B: usize> SerializeModule for Linear<A, B> {
fn serialize(&self, s: &mut luminal::module::Serializer) {
s.tensor("weight", self.weight);
}
}
impl Module<GraphTensor> for Linear {
type Output = GraphTensor;
impl<const A: usize, const B: usize, S: Shape> Module<GraphTensor<S>> for Linear<A, B>
where
GraphTensor<S>: Matmul<R2<A, B>>,
{
type Output = <GraphTensor<S> as Matmul<R2<A, B>>>::Output;
fn forward(&self, input: GraphTensor<S>) -> Self::Output {
input.matmul(self.weight)
}
}
/// A simple unbiased linear layer with a permuted weight matrix
pub struct PermutedLinear<const A: usize, const B: usize> {
pub weight: GraphTensor<R2<B, A>>,
}
impl<const A: usize, const B: usize> InitModule for PermutedLinear<A, B> {
fn initialize(cx: &mut Graph) -> Self {
// Init weight as uniform(-1, 1)
let mut rng = thread_rng();
Self {
weight: cx.named_tensor("Weight").set(
(0..(A * B))
.map(|_| rng.gen_range(-1_f32..1_f32))
.collect::<Vec<_>>(),
),
fn forward(&self, input: GraphTensor) -> Self::Output {
let mut output = input.matmul(if self.permute {
self.weight.permute((1, 0))
} else {
self.weight
});
if let Some(bias) = self.bias {
output += bias.expand_to(output.shape);
}
}
}
impl<const A: usize, const B: usize> SerializeModule for PermutedLinear<A, B> {
fn serialize(&self, s: &mut luminal::module::Serializer) {
s.tensor("weight", self.weight);
}
}
impl<const A: usize, const B: usize, S: Shape> Module<GraphTensor<S>> for PermutedLinear<A, B>
where
GraphTensor<S>: Matmul<R2<A, B>>,
{
type Output = <GraphTensor<S> as Matmul<R2<A, B>>>::Output;
fn forward(&self, input: GraphTensor<S>) -> Self::Output {
input.matmul(self.weight.permute())
output
}
}
@@ -81,12 +78,10 @@ mod tests {
#[test]
fn test_linear() {
let mut cx = Graph::new();
let batch = cx
.tensor::<R2<2, 3>>()
.set(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
let a = cx.tensor::<R1<3>>().set(vec![1.0, 2.0, 3.0]);
let batch = cx.tensor((2, 3)).set([1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
let a = cx.tensor(3).set([1.0, 2.0, 3.0]);
let model: Linear<3, 4> = Linear::initialize(&mut cx);
let model = Linear::new(3, 4, false, &mut cx).initialize();
let mut b = model.forward(a).retrieve();
let mut batch_out = model.forward(batch).retrieve();

View File

@@ -3,23 +3,30 @@ use rand::thread_rng;
/// A simple layer norm with an optional weight and bias
#[derive(Default)]
pub struct LayerNorm<const DIM: usize> {
pub weight: Option<GraphTensor<R1<DIM>>>,
pub bias: Option<GraphTensor<R1<DIM>>>,
pub struct LayerNorm {
pub weight: Option<GraphTensor>,
pub bias: Option<GraphTensor>,
mean_norm: bool,
epsilon: f32,
}
impl<const DIM: usize> LayerNorm<DIM> {
pub fn new(weight: bool, bias: bool, mean_norm: bool, epsilon: f32, cx: &mut Graph) -> Self {
impl LayerNorm {
pub fn new(
dim: usize,
weight: bool,
bias: bool,
mean_norm: bool,
epsilon: f32,
cx: &mut Graph,
) -> Self {
Self {
weight: if weight {
Some(cx.named_tensor("LayerNorm Weight"))
Some(cx.named_tensor("LayerNorm Weight", dim))
} else {
None
},
bias: if bias {
Some(cx.named_tensor("LayerNorm Bias"))
Some(cx.named_tensor("LayerNorm Bias", dim))
} else {
None
},
@@ -27,53 +34,43 @@ impl<const DIM: usize> LayerNorm<DIM> {
epsilon,
}
}
pub fn init(weight: bool, bias: bool, mean_norm: bool, epsilon: f32, cx: &mut Graph) -> Self {
pub fn initialize(self) -> Self {
// Init weight as uniform(-1, 1)
let mut rng = thread_rng();
Self {
weight: if weight {
Some(
cx.named_tensor("LayerNorm Weight")
.set(random_vec_rng(DIM, &mut rng)),
)
} else {
None
},
bias: if bias {
Some(
cx.named_tensor("LayerNorm Bias")
.set(random_vec_rng(DIM, &mut rng)),
)
} else {
None
},
mean_norm,
epsilon,
if let Some(w) = self.weight {
w.set(random_vec_rng(
w.shape.n_elements().to_usize().unwrap(),
&mut rng,
));
}
if let Some(b) = self.bias {
b.set(random_vec_rng(
b.shape.n_elements().to_usize().unwrap(),
&mut rng,
));
}
self
}
}
impl<const DIM: usize, S: Shape> Module<GraphTensor<S>> for LayerNorm<DIM>
where
(Const<DIM>,): BroadcastShapeTo<S, S::AllButLast>,
{
type Output = GraphTensor<S>;
fn forward(&self, mut input: GraphTensor<S>) -> Self::Output {
impl Module<GraphTensor> for LayerNorm {
type Output = GraphTensor;
fn forward(&self, mut input: GraphTensor) -> Self::Output {
if self.mean_norm {
input = input.mean_norm::<S::LastAxis>();
input = input.mean_norm(input.shape.last_axis());
}
input = input.std_norm(self.epsilon);
input = input.std_norm(input.shape.last_axis(), self.epsilon);
if let Some(w) = self.weight {
input *= w.expand();
input *= w.expand_to(input.shape);
}
if let Some(b) = self.bias {
input += b.expand();
input += b.expand_to(input.shape);
}
input
}
}
impl<const DIM: usize> SerializeModule for LayerNorm<DIM> {
impl SerializeModule for LayerNorm {
fn serialize(&self, s: &mut Serializer) {
if let Some(w) = self.weight {
s.tensor("weight", w);

View File

@@ -4,34 +4,31 @@ use crate::Linear;
use luminal::prelude::*;
/// Multi-head self attention as layed out in [*Attention Is All You Need*](https://arxiv.org/abs/1706.03762).
pub struct MultiHeadSelfAttention<
const DIM: usize,
const K_DIM: usize,
const V_DIM: usize,
const HEADS: usize,
> {
pub w_q: Linear<DIM, K_DIM>,
pub w_k: Linear<DIM, K_DIM>,
pub w_v: Linear<DIM, V_DIM>,
pub w_o: Linear<V_DIM, DIM>,
pub struct MultiHeadSelfAttention {
pub w_q: Linear, // dim x k_dim
pub w_k: Linear, // dim x k_dim
pub w_v: Linear, // dim x v_dim
pub w_o: Linear, // v_dim x dim
k_dim: usize,
v_dim: usize,
heads: usize,
}
impl<const DIM: usize, const K_DIM: usize, const V_DIM: usize, const HEADS: usize> InitModule
for MultiHeadSelfAttention<DIM, K_DIM, V_DIM, HEADS>
{
fn initialize(cx: &mut Graph) -> Self {
impl MultiHeadSelfAttention {
pub fn new(dim: usize, k_dim: usize, v_dim: usize, heads: usize, cx: &mut Graph) -> Self {
Self {
w_q: InitModule::initialize(cx),
w_k: InitModule::initialize(cx),
w_v: InitModule::initialize(cx),
w_o: InitModule::initialize(cx),
w_q: Linear::new(dim, k_dim, false, cx),
w_k: Linear::new(dim, k_dim, false, cx),
w_v: Linear::new(dim, v_dim, false, cx),
w_o: Linear::new(v_dim, dim, false, cx),
k_dim,
v_dim,
heads,
}
}
}
impl<const DIM: usize, const K_DIM: usize, const V_DIM: usize, const HEADS: usize> SerializeModule
for MultiHeadSelfAttention<DIM, K_DIM, V_DIM, HEADS>
{
impl SerializeModule for MultiHeadSelfAttention {
fn serialize(&self, s: &mut Serializer) {
s.module("w_q", &self.w_q);
s.module("w_k", &self.w_k);
@@ -40,146 +37,71 @@ impl<const DIM: usize, const K_DIM: usize, const V_DIM: usize, const HEADS: usiz
}
}
// Single
impl<
const DIM: usize,
const K_DIM: usize,
const V_DIM: usize,
const HEADS: usize,
S: Dimension,
> Module<GraphTensor<(S, Const<DIM>)>> for MultiHeadSelfAttention<DIM, K_DIM, V_DIM, HEADS>
{
type Output = GraphTensor<(S, Const<DIM>)>;
fn forward(&self, input: GraphTensor<(S, Const<DIM>)>) -> Self::Output {
// Pass to batched forward
<Self as Module<GraphTensor<(Const<1>, S, Const<DIM>)>>>::forward(self, input.expand())
.reshape()
}
}
impl<
const DIM: usize,
const K_DIM: usize,
const V_DIM: usize,
const HEADS: usize,
S: Dimension,
S1: Dimension,
>
Module<(
GraphTensor<(S, Const<DIM>)>,
GraphTensor<(S1, Const<DIM>)>,
GraphTensor<(S, Const<DIM>)>,
)> for MultiHeadSelfAttention<DIM, K_DIM, V_DIM, HEADS>
{
type Output = GraphTensor<(S1, Const<DIM>)>;
fn forward(
&self,
(k, q, v): (
GraphTensor<(S, Const<DIM>)>,
GraphTensor<(S1, Const<DIM>)>,
GraphTensor<(S, Const<DIM>)>,
),
) -> Self::Output {
// Pass to batched forward
<Self as Module<(
GraphTensor<(Const<1>, S, Const<DIM>)>,
GraphTensor<(Const<1>, S1, Const<DIM>)>,
GraphTensor<(Const<1>, S, Const<DIM>)>,
)>>::forward(self, (k.expand(), q.expand(), v.expand()))
.reshape()
}
}
// Batched
impl<
const DIM: usize,
const K_DIM: usize,
const V_DIM: usize,
const HEADS: usize,
S: Dimension,
B: Dimension,
> Module<GraphTensor<(B, S, Const<DIM>)>> for MultiHeadSelfAttention<DIM, K_DIM, V_DIM, HEADS>
{
type Output = GraphTensor<(B, S, Const<DIM>)>;
impl Module<GraphTensor> for MultiHeadSelfAttention {
type Output = GraphTensor;
fn forward(&self, input: GraphTensor<(B, S, Const<DIM>)>) -> Self::Output {
<Self as Module<(
GraphTensor<(B, S, Const<DIM>)>,
GraphTensor<(B, S, Const<DIM>)>,
GraphTensor<(B, S, Const<DIM>)>,
)>>::forward(self, (input, input, input))
fn forward(&self, input: GraphTensor) -> Self::Output {
// Input: batch_dims, sequence, dim
<Self as Module<(GraphTensor, GraphTensor, GraphTensor)>>::forward(
self,
(input, input, input),
)
}
}
// Batched different key-query-value
impl<
const DIM: usize,
const K_DIM: usize,
const V_DIM: usize,
const HEADS: usize,
S1: Dimension,
S2: Dimension,
B: Dimension,
>
Module<(
GraphTensor<(B, S1, Const<DIM>)>,
GraphTensor<(B, S2, Const<DIM>)>,
GraphTensor<(B, S1, Const<DIM>)>,
)> for MultiHeadSelfAttention<DIM, K_DIM, V_DIM, HEADS>
{
type Output = GraphTensor<(B, S2, Const<DIM>)>;
impl Module<(GraphTensor, GraphTensor, GraphTensor)> for MultiHeadSelfAttention {
type Output = GraphTensor;
fn forward(
&self,
(keys, queries, values): (
GraphTensor<(B, S1, Const<DIM>)>,
GraphTensor<(B, S2, Const<DIM>)>,
GraphTensor<(B, S1, Const<DIM>)>,
GraphTensor, // batch, s1, dim
GraphTensor, // batch, s2, dim
GraphTensor, // batch, s1, dim
),
) -> Self::Output {
let orig_query_shape = queries.shape();
let s1 = keys.shape()[keys.shape.len() - 2].small();
let s2 = queries.shape()[queries.shape.len() - 2].small();
let n_batches = queries
.shape()
.into_iter()
.take(queries.shape.len() - 2)
.product::<BigExpression>()
.max(1)
.small();
let dim = queries.shape().last().unwrap().small();
let keys = keys.reshape((n_batches, s1, dim));
let values = values.reshape((n_batches, s1, dim));
let queries = queries.reshape((n_batches, s2, dim));
let values = self
.w_v
.forward(values)
.dyn_reshape::<(B, S1, Const<HEADS>, Dyn<'-'>), _>(&[
B::size(),
S1::size(),
HEADS.into(),
(K_DIM / HEADS).into(),
])
.permute::<_, Axes4<0, 2, 1, 3>>();
.reshape((n_batches, s1, self.heads, self.k_dim / self.heads))
.permute((0, 2, 1, 3));
let keys = self
.w_k
.forward(keys)
.dyn_reshape::<(B, S1, Const<HEADS>, Dyn<'-'>), _>(&[
B::size(),
S1::size(),
HEADS.into(),
(K_DIM / HEADS).into(),
])
.permute::<_, Axes4<0, 2, 3, 1>>();
.reshape((n_batches, s1, self.heads, self.k_dim / self.heads))
.permute((0, 2, 3, 1));
let queries = self
.w_q
.forward(queries)
.dyn_reshape::<(B, S2, Const<HEADS>, Dyn<'-'>), _>(&[
B::size(),
S2::size(),
HEADS.into(),
(K_DIM / HEADS).into(),
])
.permute::<_, Axes4<0, 2, 1, 3>>();
.reshape((n_batches, s2, self.heads, self.k_dim / self.heads))
.permute((0, 2, 1, 3));
let weights = queries
.matmul(keys)
.mul((1.0 / ((K_DIM / HEADS) as f64).sqrt()) as f32)
.softmax::<Axis<3>>();
.mul((1.0 / ((self.k_dim / self.heads) as f64).sqrt()) as f32)
.softmax(3);
let tokens: GraphTensor<(B, S2, Const<V_DIM>)> = weights
let tokens = weights
.matmul(values)
.permute::<_, Axes4<0, 2, 1, 3>>()
.reshape();
self.w_o.forward(tokens)
.permute((0, 2, 1, 3))
.reshape((n_batches, s2, self.v_dim));
self.w_o.forward(tokens).reshape(orig_query_shape) // batch_dims, s2, dim
}
}
@@ -195,7 +117,7 @@ mod tests {
#[test]
fn test_self_attention() {
let mut cx = Graph::new();
let model: MultiHeadSelfAttention<3, 3, 3, 1> = InitModule::initialize(&mut cx);
let model = MultiHeadSelfAttention::new(3, 3, 3, 1, &mut cx);
model
.w_k
.weight
@@ -213,20 +135,17 @@ mod tests {
.weight
.set(vec![1., 22., 3., 1., 2., 3., 1., 2., 3.]);
let a = cx.tensor::<(Dyn<'d'>, luminal::shape::Const<3>)>();
let e = cx.tensor::<(Dyn<'e'>, luminal::shape::Const<3>)>();
let a = cx.tensor(('d', 3));
let e = cx.tensor(('e', 3));
let b = model.forward((e, a, e));
a.set_dyn(
vec![
0.56587636, -1.4053632, 0.8394869, 0.5916256, -1.4082357, 0.8166099,
],
&[2, 3],
);
e.set_dyn(
vec![-1.0, 2.0, 3.0, 3.0, 3.0, -1.0, -1.0, 2.0, 3.0],
&[3, 3],
(2, 3),
);
e.set_dyn(vec![-1.0, 2.0, 3.0, 3.0, 3.0, -1.0, -1.0, 2.0, 3.0], (3, 3));
b.retrieve();
cx.execute();

View File

@@ -4,28 +4,21 @@ use luminal::prelude::*;
use super::attention::MultiHeadSelfAttention;
/// A transformer decoder as layed out in [*Attention Is All You Need*](https://arxiv.org/abs/1706.03762).
pub struct TransformerDecoder<
const DIM: usize,
const FF: usize,
const HEADS: usize,
const LAYERS: usize,
> {
pub layers: Vec<TransformerDecoderBlock<DIM, FF, HEADS>>,
pub struct TransformerDecoder {
pub layers: Vec<TransformerDecoderBlock>,
}
impl<const DIM: usize, const FF: usize, const HEADS: usize, const LAYERS: usize> InitModule
for TransformerDecoder<DIM, FF, HEADS, LAYERS>
{
fn initialize(cx: &mut Graph) -> Self {
impl TransformerDecoder {
pub fn new(dim: usize, ff: usize, heads: usize, layers: usize, cx: &mut Graph) -> Self {
Self {
layers: (0..LAYERS).map(|_| InitModule::initialize(cx)).collect(),
layers: (0..layers)
.map(|_| TransformerDecoderBlock::new(dim, ff, heads, cx))
.collect(),
}
}
}
impl<const DIM: usize, const FF: usize, const HEADS: usize, const LAYERS: usize> SerializeModule
for TransformerDecoder<DIM, FF, HEADS, LAYERS>
{
impl SerializeModule for TransformerDecoder {
fn serialize(&self, s: &mut Serializer) {
for (i, l) in self.layers.iter().enumerate() {
s.module(&format!("layer{i}"), l);
@@ -33,55 +26,10 @@ impl<const DIM: usize, const FF: usize, const HEADS: usize, const LAYERS: usize>
}
}
// Single
impl<
const DIM: usize,
const FF: usize,
const HEADS: usize,
const LAYERS: usize,
S1: Dimension,
S2: Dimension,
> Module<(GraphTensor<(S1, Const<DIM>)>, GraphTensor<(S2, Const<DIM>)>)>
for TransformerDecoder<DIM, FF, HEADS, LAYERS>
{
type Output = GraphTensor<(S1, Const<DIM>)>;
impl Module<(GraphTensor, GraphTensor)> for TransformerDecoder {
type Output = GraphTensor;
fn forward(
&self,
(input, from_enc): (GraphTensor<(S1, Const<DIM>)>, GraphTensor<(S2, Const<DIM>)>),
) -> Self::Output {
<Self as Module<(
GraphTensor<(Const<1>, S1, Const<DIM>)>,
GraphTensor<(Const<1>, S2, Const<DIM>)>,
)>>::forward(self, (input.expand(), from_enc.expand()))
.reshape()
}
}
// Batched
impl<
const DIM: usize,
const FF: usize,
const HEADS: usize,
const LAYERS: usize,
B: Dimension,
S1: Dimension,
S2: Dimension,
>
Module<(
GraphTensor<(B, S1, Const<DIM>)>,
GraphTensor<(B, S2, Const<DIM>)>,
)> for TransformerDecoder<DIM, FF, HEADS, LAYERS>
{
type Output = GraphTensor<(B, S1, Const<DIM>)>;
fn forward(
&self,
(mut input, from_enc): (
GraphTensor<(B, S1, Const<DIM>)>,
GraphTensor<(B, S2, Const<DIM>)>,
),
) -> Self::Output {
fn forward(&self, (mut input, from_enc): (GraphTensor, GraphTensor)) -> Self::Output {
for layer in &self.layers {
input = layer.forward((input, from_enc));
}
@@ -90,27 +38,27 @@ impl<
}
/// A single transformer decoder block
pub struct TransformerDecoderBlock<const DIM: usize, const FF: usize, const HEADS: usize> {
pub self_attention: MultiHeadSelfAttention<DIM, DIM, DIM, HEADS>,
pub cross_attention: MultiHeadSelfAttention<DIM, DIM, DIM, HEADS>,
pub ff: (Linear<DIM, FF>, ReLU, Linear<FF, DIM>),
pub struct TransformerDecoderBlock {
pub self_attention: MultiHeadSelfAttention,
pub cross_attention: MultiHeadSelfAttention,
pub ff: (Linear, ReLU, Linear),
}
impl<const DIM: usize, const FF: usize, const HEADS: usize> InitModule
for TransformerDecoderBlock<DIM, FF, HEADS>
{
fn initialize(cx: &mut Graph) -> Self {
impl TransformerDecoderBlock {
pub fn new(dim: usize, ff: usize, heads: usize, cx: &mut Graph) -> Self {
Self {
cross_attention: InitModule::initialize(cx),
self_attention: InitModule::initialize(cx),
ff: InitModule::initialize(cx),
cross_attention: MultiHeadSelfAttention::new(dim, dim, dim, heads, cx),
self_attention: MultiHeadSelfAttention::new(dim, dim, dim, heads, cx),
ff: (
Linear::new(dim, ff, false, cx),
ReLU,
Linear::new(ff, dim, false, cx),
),
}
}
}
impl<const DIM: usize, const FF: usize, const HEADS: usize> SerializeModule
for TransformerDecoderBlock<DIM, FF, HEADS>
{
impl SerializeModule for TransformerDecoderBlock {
fn serialize(&self, s: &mut Serializer) {
s.module("self_attn", &self.self_attention);
s.module("cross_attn", &self.cross_attention);
@@ -118,55 +66,32 @@ impl<const DIM: usize, const FF: usize, const HEADS: usize> SerializeModule
}
}
// Single
impl<const DIM: usize, const FF: usize, const HEADS: usize, S1: Dimension, S2: Dimension>
Module<(GraphTensor<(S1, Const<DIM>)>, GraphTensor<(S2, Const<DIM>)>)>
for TransformerDecoderBlock<DIM, FF, HEADS>
{
type Output = GraphTensor<(S1, Const<DIM>)>;
impl Module<(GraphTensor, GraphTensor)> for TransformerDecoderBlock {
type Output = GraphTensor;
fn forward(
&self,
(input, from_enc): (GraphTensor<(S1, Const<DIM>)>, GraphTensor<(S2, Const<DIM>)>),
) -> Self::Output {
// Pass to batched forward
<Self as Module<(
GraphTensor<(Const<1>, S1, Const<DIM>)>,
GraphTensor<(Const<1>, S2, Const<DIM>)>,
)>>::forward(self, (input.expand(), from_enc.expand()))
.reshape()
}
}
// Batched
impl<
const DIM: usize,
const FF: usize,
const HEADS: usize,
S1: Dimension,
S2: Dimension,
B: Dimension,
>
Module<(
GraphTensor<(B, S1, Const<DIM>)>,
GraphTensor<(B, S2, Const<DIM>)>,
)> for TransformerDecoderBlock<DIM, FF, HEADS>
{
type Output = GraphTensor<(B, S1, Const<DIM>)>;
fn forward(
&self,
(x, from_enc): (
GraphTensor<(B, S1, Const<DIM>)>,
GraphTensor<(B, S2, Const<DIM>)>,
),
) -> Self::Output {
let y = self.self_attention.forward(x);
let x = (y + x).layer_norm::<Axis<2>, _>(1e-5);
let y = self.cross_attention.forward((from_enc, x, from_enc));
let x = (y + x).layer_norm::<Axis<2>, _>(1e-5);
fn forward(&self, (input, from_enc): (GraphTensor, GraphTensor)) -> Self::Output {
// Input: batch_dims, seq1, dim
// From_enc: batch_dims, seq2, dim
// Flatten to single batch dim
let seq1 = input.shape()[input.shape.len() - 2].small();
let seq2 = from_enc.shape()[from_enc.shape.len() - 2].small();
let dim = input.shape().last().unwrap().small();
let n_batches = input
.shape()
.into_iter()
.take(input.shape.len() - 2)
.product::<BigExpression>()
.max(1)
.small();
let inp = input.reshape((n_batches, seq1, dim));
let fe = from_enc.reshape((n_batches, seq2, dim));
// Batched forward pass
let y = self.self_attention.forward(inp);
let x = (y + inp).layer_norm(2, 1e-5);
let y = self.cross_attention.forward((fe, x, fe));
let x = (y + x).layer_norm(2, 1e-5);
let y = self.ff.forward(x);
(y + x).layer_norm::<Axis<2>, _>(1e-5)
(y + x).layer_norm(2, 1e-5).reshape(input.shape)
}
}
@@ -187,7 +112,7 @@ mod tests {
#[test]
fn test_transformer_decoder_block() {
let mut cx = Graph::new();
let model: TransformerDecoderBlock<3, 4, 1> = InitModule::initialize(&mut cx);
let model = TransformerDecoderBlock::new(3, 4, 1, &mut cx);
model
.self_attention
.w_k
@@ -239,12 +164,12 @@ mod tests {
.weight
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.]);
let a = cx.tensor::<(Dyn<'d'>, Const<3>)>();
let e = cx.tensor::<(Dyn<'e'>, Const<3>)>();
let a = cx.tensor(('d', 3));
let e = cx.tensor(('e', 3));
let b = model.forward((a, e));
a.set_dyn(vec![-1., 2., 3., 3., 3., -1.], &[2, 3]);
e.set_dyn(vec![-1., 2., 3., 3., 3., -1., -1., 2., 3.], &[3, 3]);
a.set_dyn(vec![-1., 2., 3., 3., 3., -1.], (2, 3));
e.set_dyn(vec![-1., 2., 3., 3., 3., -1., -1., 2., 3.], (3, 3));
b.retrieve();
cx.execute();

View File

@@ -1,66 +1,88 @@
use crate::{Linear, ReLU, Repeated};
use crate::{Linear, ReLU};
use luminal::prelude::*;
use super::attention::MultiHeadSelfAttention;
/// A transformer encoder as layed out in [*Attention Is All You Need*](https://arxiv.org/abs/1706.03762).
pub type TransformerEncoder<
const DIM: usize,
const FF: usize,
const HEADS: usize,
const LAYERS: usize,
> = Repeated<TransformerEncoderBlock<DIM, FF, HEADS>, LAYERS>;
/// A single transformer encoder block
pub struct TransformerEncoderBlock<const DIM: usize, const FF: usize, const HEADS: usize> {
pub attention: MultiHeadSelfAttention<DIM, DIM, DIM, HEADS>,
pub ff: (Linear<DIM, FF>, ReLU, Linear<FF, DIM>),
pub struct TransformerEncoder {
pub layers: Vec<TransformerEncoderBlock>,
}
impl<const DIM: usize, const FF: usize, const HEADS: usize> InitModule
for TransformerEncoderBlock<DIM, FF, HEADS>
{
fn initialize(cx: &mut Graph) -> Self {
impl TransformerEncoder {
pub fn new(dim: usize, ff: usize, heads: usize, layers: usize, cx: &mut Graph) -> Self {
Self {
attention: InitModule::initialize(cx),
ff: InitModule::initialize(cx),
layers: (0..layers)
.map(|_| TransformerEncoderBlock::new(dim, ff, heads, cx))
.collect(),
}
}
}
impl<const DIM: usize, const FF: usize, const HEADS: usize> SerializeModule
for TransformerEncoderBlock<DIM, FF, HEADS>
{
impl SerializeModule for TransformerEncoder {
fn serialize(&self, s: &mut Serializer) {
for (i, l) in self.layers.iter().enumerate() {
s.module(&format!("layer{i}"), l);
}
}
}
impl Module<GraphTensor> for TransformerEncoder {
type Output = GraphTensor;
fn forward(&self, mut input: GraphTensor) -> Self::Output {
for layer in &self.layers {
input = layer.forward(input);
}
input
}
}
/// A single transformer encoder block
pub struct TransformerEncoderBlock {
pub attention: MultiHeadSelfAttention,
pub ff: (Linear, ReLU, Linear),
}
impl TransformerEncoderBlock {
pub fn new(dim: usize, ff: usize, heads: usize, cx: &mut Graph) -> Self {
Self {
attention: MultiHeadSelfAttention::new(dim, dim, dim, heads, cx),
ff: (
Linear::new(dim, ff, false, cx),
ReLU,
Linear::new(ff, dim, false, cx),
),
}
}
}
impl SerializeModule for TransformerEncoderBlock {
fn serialize(&self, s: &mut Serializer) {
s.module("self_attn", &self.attention);
s.module("ff", &self.ff);
}
}
// Single
impl<const DIM: usize, const FF: usize, const HEADS: usize, S: Dimension>
Module<GraphTensor<(S, Const<DIM>)>> for TransformerEncoderBlock<DIM, FF, HEADS>
{
type Output = GraphTensor<(S, Const<DIM>)>;
fn forward(&self, input: GraphTensor<(S, Const<DIM>)>) -> Self::Output {
// Pass to batched forward
<Self as Module<GraphTensor<(Const<1>, S, Const<DIM>)>>>::forward(self, input.expand())
.reshape()
}
}
// Batched
impl<const DIM: usize, const FF: usize, const HEADS: usize, S: Dimension, B: Dimension>
Module<GraphTensor<(B, S, Const<DIM>)>> for TransformerEncoderBlock<DIM, FF, HEADS>
{
type Output = GraphTensor<(B, S, Const<DIM>)>;
impl Module<GraphTensor> for TransformerEncoderBlock {
type Output = GraphTensor;
fn forward(&self, x: GraphTensor<(B, S, Const<DIM>)>) -> Self::Output {
fn forward(&self, input: GraphTensor) -> Self::Output {
// Input: batch_dims, sequence, dim
// Reshape to 1 batch dim, sequence, dim
let n_batches = input
.shape()
.into_iter()
.take(input.shape.len() - 2)
.product::<BigExpression>()
.max(1);
let sequence = input.shape()[input.shape.len() - 2].small();
let dim = input.shape()[input.shape.len() - 1].small();
let x = input.reshape((n_batches, sequence, dim));
let x = x + self.attention.forward(x);
let x = x.layer_norm::<Axis<2>, _>(1e-5);
let x = x.layer_norm(2, 1e-5);
let x = x + self.ff.forward(x);
x.layer_norm::<Axis<2>, _>(1e-5)
x.layer_norm(2, 1e-5).reshape(input.shape())
}
}
@@ -81,7 +103,7 @@ mod tests {
#[test]
fn test_transformer_encoder_block() {
let mut cx = Graph::new();
let model: TransformerEncoderBlock<3, 4, 1> = InitModule::initialize(&mut cx);
let model = TransformerEncoderBlock::new(3, 4, 1, &mut cx);
model
.attention
.w_k
@@ -114,8 +136,8 @@ mod tests {
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.]);
let a = cx
.tensor::<(Dyn<'s'>, luminal::shape::Const<3>)>()
.set_dyn(vec![-1., 2., 3., 3., 3., -1.], &[2, 3]);
.tensor(('s', 3))
.set_dyn(vec![-1., 2., 3., 3., 3., -1.], (2, 3));
let b = model.forward(a).retrieve();
cx.execute();

View File

@@ -7,44 +7,29 @@ pub use decoder::*;
mod encoder;
pub use encoder::*;
pub struct Transformer<
const DIM: usize,
const FF: usize,
const ENC_HEADS: usize,
const DEC_HEADS: usize,
const ENC_LAYERS: usize,
const DEC_LAYERS: usize,
> {
pub encoder: encoder::TransformerEncoder<DIM, FF, ENC_HEADS, ENC_LAYERS>,
pub decoder: decoder::TransformerDecoder<DIM, FF, DEC_HEADS, DEC_LAYERS>,
pub struct Transformer {
pub encoder: encoder::TransformerEncoder,
pub decoder: decoder::TransformerDecoder,
}
impl<
const DIM: usize,
const FF: usize,
const ENC_HEADS: usize,
const DEC_HEADS: usize,
const ENC_LAYERS: usize,
const DEC_LAYERS: usize,
> InitModule for Transformer<DIM, FF, ENC_HEADS, DEC_HEADS, ENC_LAYERS, DEC_LAYERS>
{
fn initialize(cx: &mut Graph) -> Self {
impl Transformer {
pub fn new(
dim: usize,
ff: usize,
enc_heads: usize,
dec_heads: usize,
enc_layers: usize,
dec_layers: usize,
cx: &mut Graph,
) -> Self {
Self {
encoder: InitModule::initialize(cx),
decoder: InitModule::initialize(cx),
encoder: TransformerEncoder::new(dim, ff, enc_heads, enc_layers, cx),
decoder: TransformerDecoder::new(dim, ff, dec_heads, dec_layers, cx),
}
}
}
impl<
const DIM: usize,
const FF: usize,
const ENC_HEADS: usize,
const DEC_HEADS: usize,
const ENC_LAYERS: usize,
const DEC_LAYERS: usize,
> SerializeModule for Transformer<DIM, FF, ENC_HEADS, DEC_HEADS, ENC_LAYERS, DEC_LAYERS>
{
impl SerializeModule for Transformer {
fn serialize(&self, s: &mut Serializer) {
s.module("encoder", &self.encoder);
s.module("decoder", &self.decoder);
@@ -52,24 +37,10 @@ impl<
}
// Single Sequence
impl<
const DIM: usize,
const FF: usize,
const ENC_HEADS: usize,
const DEC_HEADS: usize,
const ENC_LAYERS: usize,
const DEC_LAYERS: usize,
S1: Dimension,
S2: Dimension,
> Module<(GraphTensor<(S1, Const<DIM>)>, GraphTensor<(S2, Const<DIM>)>)>
for Transformer<DIM, FF, ENC_HEADS, DEC_HEADS, ENC_LAYERS, DEC_LAYERS>
{
type Output = GraphTensor<(S2, Const<DIM>)>;
impl Module<(GraphTensor, GraphTensor)> for Transformer {
type Output = GraphTensor;
fn forward(
&self,
(input, target): (GraphTensor<(S1, Const<DIM>)>, GraphTensor<(S2, Const<DIM>)>),
) -> Self::Output {
fn forward(&self, (input, target): (GraphTensor, GraphTensor)) -> Self::Output {
let encoded = self.encoder.forward(input);
self.decoder.forward((target, encoded))
}
@@ -92,7 +63,7 @@ mod tests {
#[test]
fn test_transformer_full() {
let mut cx = Graph::new();
let model: Transformer<3, 4, 1, 1, 1, 1> = InitModule::initialize(&mut cx);
let model = Transformer::new(3, 4, 1, 1, 1, 1, &mut cx);
model.decoder.layers[0]
.self_attention
.w_k
@@ -143,43 +114,43 @@ mod tests {
.2
.weight
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.]);
model.encoder.modules[0]
model.encoder.layers[0]
.attention
.w_k
.weight
.set(vec![1., 22., 3., 1., 2., 3., 1., 2., 3.]);
model.encoder.modules[0]
model.encoder.layers[0]
.attention
.w_q
.weight
.set(vec![3., 2., 3., 1.3, 2., 3., 3., 2., 3.]);
model.encoder.modules[0]
model.encoder.layers[0]
.attention
.w_v
.weight
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3.]);
model.encoder.modules[0]
model.encoder.layers[0]
.attention
.w_o
.weight
.set(vec![1., 22., 3., 1., 2., 3., 1., 2., 3.]);
model.encoder.modules[0]
model.encoder.layers[0]
.ff
.0
.weight
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 11., 2., 3.]);
model.encoder.modules[0]
model.encoder.layers[0]
.ff
.2
.weight
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.]);
let a = cx.tensor::<(Dyn<'d'>, luminal::shape::Const<3>)>();
let e = cx.tensor::<(Dyn<'e'>, luminal::shape::Const<3>)>();
let a = cx.tensor(('d', 3));
let e = cx.tensor(('e', 3));
let b = model.forward((a, e));
a.set_dyn(vec![-1., 2., 3., 3., 3., -1.], &[2, 3]);
e.set_dyn(vec![-1., 2., 3., 3., 3., -1., -1., 2., 3.], &[3, 3]);
a.set_dyn(vec![-1., 2., 3., 3., 3., -1.], (2, 3));
e.set_dyn(vec![-1., 2., 3., 3., 3., -1., -1., 2., 3.], (3, 3));
b.retrieve();
cx.execute();

View File

@@ -16,7 +16,7 @@ use luminal::{
pub struct Autograd(Vec<NodeIndex>, NodeIndex);
impl Autograd {
pub fn new<W: ToIds>(params: W, loss: GraphTensor<()>) -> Self {
pub fn new<W: ToIds>(params: W, loss: GraphTensor) -> Self {
Self(params.to_ids(), loss.id)
}
}
@@ -61,7 +61,7 @@ impl Compiler for Autograd {
*loss,
(
graph.constant(1.0).id,
ShapeTracker::new(&[]), // Assume scalar loss for now
ShapeTracker::default(), // Assume scalar loss for now
),
);
let weight_set = params.iter().copied().collect::<FxHashSet<_>>();
@@ -90,7 +90,7 @@ impl Compiler for Autograd {
.edges_directed(fwd_node, Direction::Incoming)
.filter_map(|e| e.weight().as_data().map(|i| (e.source(), i)))
.sorted_by_key(|(_, (a, _, _))| *a)
.map(|(node, (_, _, sh))| GraphTensor::<()>::from_id(node, sh, graph_ref))
.map(|(node, (_, _, sh))| GraphTensor::from_id(node, sh, graph_ref))
.collect::<Vec<_>>();
let mut prev_grad = {
let (id, sh) = grads[&fwd_node];
@@ -139,7 +139,7 @@ impl Compiler for Autograd {
prev_grad
.shape
.expand(op.0, inps[0].shape.dims[inps[0].shape.indexes[op.0]]);
let reduced = GraphTensor::<()>::from_id(fwd_node, prev_grad.shape, graph_ref);
let reduced = GraphTensor::from_id(fwd_node, prev_grad.shape, graph_ref);
let grad = inps[0].equals(reduced) * prev_grad;
add_grad(grad, inps[0], graph, &mut grads);
}
@@ -184,8 +184,8 @@ impl Compiler for Autograd {
}
fn add_grad(
mut grad: GraphTensor<()>,
fwd: GraphTensor<()>,
mut grad: GraphTensor,
fwd: GraphTensor,
graph: &mut Graph,
grad_map: &mut FxHashMap<NodeIndex, (NodeIndex, ShapeTracker)>,
) {
@@ -226,9 +226,8 @@ fn add_grad(
}
if let Some((existing_grad_node, existing_grad_shape)) = grad_map.get(&fwd.id).copied() {
let grad = GraphTensor::<()>::from_id(grad.id, grad.shape, graph);
let existing_grad =
GraphTensor::<()>::from_id(existing_grad_node, existing_grad_shape, graph);
let grad = GraphTensor::from_id(grad.id, grad.shape, graph);
let existing_grad = GraphTensor::from_id(existing_grad_node, existing_grad_shape, graph);
let new_grad = grad + existing_grad;
grad_map.insert(fwd.id, (new_grad.id, new_grad.shape));
} else {
@@ -244,14 +243,14 @@ mod tests {
luminal::test_imports!();
fn get_vec(grad: (NodeIndex, ShapeTracker), cx: &mut Graph) -> Vec<f32> {
GraphTensor::<()>::from_id(grad.0, grad.1, cx).data()
GraphTensor::from_id(grad.0, grad.1, cx).data()
}
#[test]
fn test_autograd_max_reduce() {
let mut cx = Graph::new();
let a = cx.named_tensor("Input").set([10., 5.]);
let b = a.max_reduce();
let a = cx.named_tensor("Input", 2).set([10., 5.]);
let b = a.max_reduce(0);
let grads = cx.compile(Autograd::new(a, b), ());
cx.keep_tensors(&grads);
@@ -268,9 +267,9 @@ mod tests {
#[test]
fn test_autograd_matmul() {
let mut cx = Graph::new();
let a = cx.named_tensor("A").set([[2., 4.], [3., 1.]]);
let input = cx.named_tensor("Input").set([10., 5.]);
let output = (input.matmul(a)).sum_reduce();
let a = cx.named_tensor("A", (2, 2)).set([[2., 4.], [3., 1.]]);
let input = cx.named_tensor("Input", 2).set([10., 5.]);
let output = (input.matmul(a)).sum_reduce(0);
let grads = cx.compile(Autograd::new(a, output), ());
cx.keep_tensors(&grads);
@@ -288,15 +287,15 @@ mod tests {
#[test]
fn test_autograd_mlp() {
let mut cx = Graph::new();
let model = <(
luminal_nn::Linear<2, 2>,
let model = (
luminal_nn::Linear::new(2, 2, false, &mut cx),
luminal_nn::ReLU,
luminal_nn::Linear<2, 1>,
)>::initialize(&mut cx);
luminal_nn::Linear::new(2, 1, false, &mut cx),
);
model.0.weight.set([[2., 4.], [3., 1.]]);
model.2.weight.set([[6.], [5.]]);
let input = cx.named_tensor("Input").set([10., 5.]);
let output = model.forward(input).sum_reduce();
let input = cx.named_tensor("Input", 2).set([10., 5.]);
let output = model.forward(input).sum_reduce(0);
let mut grads = cx.compile(Autograd::new(params(model), output), ());
cx.keep_tensors(&grads);
@@ -328,8 +327,8 @@ mod tests {
#[test]
fn test_autograd_layer_norm() {
let mut cx = Graph::new();
let a = cx.tensor().set([-1., 2., 3.]);
let mut b = a.layer_norm(1e-5).max_reduce().retrieve();
let a = cx.tensor(3).set([-1., 2., 3.]);
let mut b = a.layer_norm(0, 1e-5).max_reduce(0).retrieve();
let grads = cx.compile(Autograd::new(a, b), &mut b);
cx.keep_tensors(&grads);
@@ -347,8 +346,8 @@ mod tests {
#[test]
fn test_autograd_softmax() {
let mut cx = Graph::new();
let a = cx.tensor().set([-1., 2., 3.]);
let mut b = a.softmax().max_reduce().retrieve();
let a = cx.tensor(3).set([-1., 2., 3.]);
let mut b = a.softmax(0).max_reduce(0).retrieve();
let mut grads = cx.compile(Autograd::new(a, b), &mut b);
cx.keep_tensors(&grads);
@@ -366,7 +365,7 @@ mod tests {
#[test]
fn test_autograd_transformer() {
let mut cx = Graph::new();
let model: luminal_nn::TransformerEncoderBlock<3, 4, 1> = InitModule::initialize(&mut cx);
let model = luminal_nn::TransformerEncoderBlock::new(3, 4, 1, &mut cx);
model
.attention
.w_k
@@ -398,8 +397,8 @@ mod tests {
.weight
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.]);
let a = cx.tensor().set([[-1., 2., 3.], [3., 3., -1.]]);
let target = cx.tensor().set([[0., 1., 0.], [0., 0., 1.]]);
let a = cx.tensor((2, 3)).set([[-1., 2., 3.], [3., 3., -1.]]);
let target = cx.tensor((2, 3)).set([[0., 1., 0.], [0., 0., 1.]]);
let out = model.forward(a);
let mut loss = crate::cross_entropy_with_logits_loss(out, target).retrieve();

View File

@@ -3,22 +3,26 @@ use luminal::prelude::*;
/// [Mean Squared Error](https://en.wikipedia.org/wiki/Mean_squared_error).
///
/// This computes `(prediction - target).square().mean()`.
pub fn mse_loss<S: Shape>(prediction: GraphTensor<S>, target: GraphTensor<S>) -> GraphTensor<()> {
(prediction - target).square().mean_reduce()
pub fn mse_loss(prediction: GraphTensor, target: GraphTensor) -> GraphTensor {
(prediction - target)
.square()
.mean_reduce(prediction.shape.all_axes())
}
/// [Root Mean square error](https://en.wikipedia.org/wiki/Root-mean-square_deviation).
///
/// This computes `(prediction - target).square().mean().sqrt()`
pub fn rmse_loss<S: Shape>(prediction: GraphTensor<S>, target: GraphTensor<S>) -> GraphTensor<()> {
pub fn rmse_loss(prediction: GraphTensor, target: GraphTensor) -> GraphTensor {
mse_loss(prediction, target).sqrt()
}
/// [Mean absolute error](https://en.wikipedia.org/wiki/Mean_absolute_error).
///
/// This computes `(prediction - target).abs().mean()`
pub fn mae_loss<S: Shape>(prediction: GraphTensor<S>, target: GraphTensor<S>) -> GraphTensor<()> {
(prediction - target).abs().mean_reduce()
pub fn mae_loss(prediction: GraphTensor, target: GraphTensor) -> GraphTensor {
(prediction - target)
.abs()
.mean_reduce(prediction.shape.all_axes())
}
/// [Huber Loss](https://en.wikipedia.org/wiki/Huber_loss)
@@ -28,18 +32,19 @@ pub fn mae_loss<S: Shape>(prediction: GraphTensor<S>, target: GraphTensor<S>) ->
/// It computes:
/// 1. if `|x - y| < delta`: `0.5 * (x - y)^2`
/// 2. otherwise: `delta * (|x - y| - 0.5 * delta)`
pub fn huber_loss<S: Shape>(
prediction: GraphTensor<S>,
target: GraphTensor<S>,
pub fn huber_loss(
prediction: GraphTensor,
target: GraphTensor,
delta: impl Into<f32>,
) -> GraphTensor<()> {
) -> GraphTensor {
let delta: f32 = delta.into();
let abs_error = (prediction - target).abs();
let delta_tensor = prediction.graph().constant(delta);
let huber_error = (0.5 * (prediction - target).square())
* abs_error.less_than(delta_tensor.expand())
+ (delta * (abs_error - 0.5 * delta)) * abs_error.greater_than_equal(delta_tensor.expand());
huber_error.mean_reduce()
* abs_error.less_than(delta_tensor.expand_to(abs_error.shape))
+ (delta * (abs_error - 0.5 * delta))
* abs_error.greater_than_equal(delta_tensor.expand_to(abs_error.shape));
huber_error.mean_reduce(huber_error.shape.all_axes())
}
/// Smooth l1 loss (closely related to [Huber Loss](https://en.wikipedia.org/wiki/Huber_loss))
@@ -49,11 +54,11 @@ pub fn huber_loss<S: Shape>(
/// It computes:
/// 1. if `|x - y| < beta`: `0.5 * (x - y)^2 / beta`
/// 2. otherwise: `|x - y| - 0.5 * beta`
pub fn smooth_l1_loss<S: Shape>(
prediction: GraphTensor<S>,
target: GraphTensor<S>,
delta: impl Copy + Into<f32>,
) -> GraphTensor<()> {
pub fn smooth_l1_loss(
prediction: GraphTensor,
target: GraphTensor,
delta: impl Into<f32> + Copy,
) -> GraphTensor {
huber_loss(prediction, target, delta) / delta.into()
}
@@ -67,16 +72,17 @@ pub fn smooth_l1_loss<S: Shape>(
///
/// - `logits`: The un-normalized output from a model. [log_softmax()] is called **in** this function
/// - `target_probabilities`: Target containing probability vectors **NOT** class indices.
pub fn cross_entropy_with_logits_loss<S: Shape>(
logits: GraphTensor<S>,
target_probabilities: GraphTensor<S>,
) -> GraphTensor<()> {
pub fn cross_entropy_with_logits_loss(
logits: GraphTensor,
target_probabilities: GraphTensor,
) -> GraphTensor {
let inv_last_axis_numel = 1.0
/ logits
.graph()
.constant(logits.shape.shape().last().unwrap());
let probs = logits.log_softmax::<S::LastAxis>();
(-(probs * target_probabilities).mean_reduce()) / inv_last_axis_numel
let probs = logits.log_softmax(logits.shape.last_axis());
(-(probs * target_probabilities).mean_reduce(target_probabilities.shape.all_axes()))
/ inv_last_axis_numel
}
/// [KL Divergence loss](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence).
@@ -89,16 +95,17 @@ pub fn cross_entropy_with_logits_loss<S: Shape>(
///
/// - `logits`: The un-normalized output from a model. [log_softmax()] is called **in** this function
/// - `target_probs`: Target containing probability vectors **NOT** class indices.
pub fn kl_div_with_logits_loss<S: Shape>(
logits: GraphTensor<S>,
target_probabilities: GraphTensor<S>,
) -> GraphTensor<()> {
pub fn kl_div_with_logits_loss(
logits: GraphTensor,
target_probabilities: GraphTensor,
) -> GraphTensor {
let inv_last_axis_numel = 1.0
/ logits
.graph()
.constant(logits.shape.shape().last().unwrap());
let probs = logits.log_softmax::<S::LastAxis>();
(-((probs - target_probabilities.ln()) * target_probabilities).mean_reduce())
let probs = logits.log_softmax(logits.shape.last_axis());
(-((probs - target_probabilities.ln()) * target_probabilities)
.mean_reduce(target_probabilities.shape.all_axes()))
/ inv_last_axis_numel
}
@@ -111,10 +118,10 @@ pub fn kl_div_with_logits_loss<S: Shape>(
/// ### Inputs
/// - `logits` - unnormalized inputs. **NOT** output of sigmoid
/// - `target_probs` - target values between 0 and 1.
pub fn binary_cross_entropy_with_logits_loss<S: Shape>(
logits: GraphTensor<S>,
target_probabilities: GraphTensor<S>,
) -> GraphTensor<()> {
pub fn binary_cross_entropy_with_logits_loss(
logits: GraphTensor,
target_probabilities: GraphTensor,
) -> GraphTensor {
let bce = (1.0 - target_probabilities) * logits + (1.0 + (-logits).exp()).ln();
bce.mean_reduce()
bce.mean_reduce(bce.shape.all_axes())
}

View File

@@ -12,12 +12,12 @@ pub fn sgd(
Vec<NodeIndex>,
Vec<NodeIndex>,
Graph,
GraphTensor<()>,
GraphTensor,
) {
let mut opt_graph = Graph::new();
let (old_weights, gradients): (Vec<NodeIndex>, Vec<NodeIndex>) = grads
.iter()
.map(|_| (opt_graph.tensor::<()>().id, opt_graph.tensor::<()>().id))
.map(|_| (opt_graph.tensor(1).id, opt_graph.tensor(1).id))
.unzip();
let (new_weights, lr) = sgd_on_graph(
@@ -41,12 +41,12 @@ pub fn sgd_on_graph(
graph: &mut Graph,
old_weights: impl ToIds,
grads: &[(NodeIndex, ShapeTracker)],
) -> (Vec<NodeIndex>, GraphTensor<()>) {
let lr = graph.named_tensor("Learning Rate").set(3e-4).keep(); // Karpathy constant
) -> (Vec<NodeIndex>, GraphTensor) {
let lr = graph.named_tensor("Learning Rate", 1).set(3e-4).keep(); // Karpathy constant
let mut new_weights = vec![];
for ((grad_id, grad_shape), old_weight_id) in grads.iter().copied().zip(old_weights.to_ids()) {
let old_weight = GraphTensor::<()>::from_id(old_weight_id, grad_shape, graph);
let gradient = GraphTensor::<()>::from_id(grad_id, grad_shape, graph);
let old_weight = GraphTensor::from_id(old_weight_id, grad_shape, graph);
let gradient = GraphTensor::from_id(grad_id, grad_shape, graph);
// SGD
let new_weight = old_weight - (gradient * lr.expand_to(grad_shape));

View File

@@ -1,3 +1,5 @@
<|begin_of_text|>Here is an implementation of merge sort:
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
```python
You are a helpful AI assistant for travel tips and recommendations<|eot_id|><|start_header_id|>user<|end_header_id|>
What can you help me with?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

View File

@@ -3,5 +3,5 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
echo "Downloading Model and Tokenizer..."
curl --location https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer.json?download=true --output $SCRIPT_DIR/tokenizer.json
curl --location https://huggingface.co/QuantFactory/Meta-Llama-3-8B-GGUF/resolve/main/Meta-Llama-3-8B.Q8_0.gguf?download=true --output $SCRIPT_DIR/llama3-8b.gguf
curl --location https://huggingface.co/QuantFactory/Meta-Llama-3-8B-Instruct-GGUF-v2/resolve/main/Meta-Llama-3-8B-Instruct-v2.Q8_0.gguf?download=true --output $SCRIPT_DIR/llama3-8b.gguf
echo "Done!"

View File

@@ -1,12 +1,12 @@
use std::{
io::{self, Write},
marker::PhantomData,
time::Instant,
};
use clap::Parser;
use colored::Colorize;
use itertools::Itertools;
use model::{HEAD_DIM, N_KV_HEADS};
use tokenizers::Tokenizer;
mod gguf;
@@ -39,15 +39,20 @@ fn main() {
// Set up graph
let mut cx = Graph::new();
let mut input = cx.named_tensor::<(Const<1>, Dyn<'s'>)>("Input");
let mut cache_src: Vec<KVCache<Const<1>, Dyn<'p'>>> = (0..model::NUM_LAYERS)
.map(|_| (cx.named_tensor("Key Cache"), cx.named_tensor("Value Cache")))
let mut input = cx.named_tensor("Input", (1, 's'));
let mut cache_src: Vec<KVCache> = (0..model::NUM_LAYERS)
.map(|_| {
(
cx.named_tensor("Key Cache", (1, N_KV_HEADS, 'p', HEAD_DIM)),
cx.named_tensor("Value Cache", (1, N_KV_HEADS, 'p', HEAD_DIM)),
)
})
.collect();
cache_src.set_dyn(vec![], &[1, model::N_KV_HEADS, 0, model::HEAD_DIM]);
let model = model::Llama::initialize(&mut cx);
cache_src.set_dyn(vec![], (1, model::N_KV_HEADS, 0, model::HEAD_DIM));
let model = model::Llama::new(&mut cx);
let mut model_weights = params(&model);
cx.keep_tensors(&model_weights);
let (logits, mut cache_dest) = model.forward((input, &cache_src, PhantomData::<Dyn<'t'>>));
let (logits, mut cache_dest) = model.forward((input, &cache_src));
let mut logits = logits
.slice((.., (Expression::from('s') - 1).., ..))
.retrieve();
@@ -69,8 +74,8 @@ fn main() {
GenericCompiler::default(),
#[cfg(feature = "metal")]
(
luminal_metal::MetalCompilerPreBuffer::<f16>::default(),
luminal_metal::quantized::MetalQuantizedCompiler::<f16>::new(q_weights),
luminal_metal::MetalCompilerPreBuffer::<f32>::default(),
luminal_metal::quantized::MetalQuantizedCompiler::<f32>::new(q_weights),
luminal_metal::BufferCompilers::default(),
),
#[cfg(feature = "cuda")]
@@ -93,7 +98,7 @@ fn main() {
print!("Loading model");
io::stdout().flush().unwrap();
let now = Instant::now();
input.set_dyn(vec![1.], &[1, 1]);
input.set_dyn(vec![1.], (1, 1));
cx.set_dyn_dim('t', 1);
cx.execute();
logits.drop();
@@ -112,7 +117,7 @@ fn main() {
.to_vec();
input.set_dyn(
input_ids.iter().map(|i| *i as f32).collect::<Vec<_>>(),
&[1, input_ids.len()],
(1, input_ids.len()),
);
cx.set_dyn_dim('t', input_ids.len());
print!("Processing Prompt");
@@ -130,10 +135,8 @@ fn main() {
// Decode token
print!("{}", cli_args.prompt.white().bold());
print!(
"{}",
tokenizer.decode(&output_ids, false).unwrap().bright_green()
);
let initial = tokenizer.decode(&output_ids, false).unwrap().bright_green();
print!("{initial}",);
io::stdout().flush().unwrap();
// Swap caches
@@ -141,11 +144,10 @@ fn main() {
// Decode loop
let start_decode = std::time::Instant::now();
let mut prev_output_len = 0;
let mut prev_output_len = initial.len();
for _ in 0..cli_args.gen_tokens {
input.set_dyn(vec![*output_ids.last().unwrap() as f32], &[1, 1]);
input.set_dyn(vec![*output_ids.last().unwrap() as f32], (1, 1));
cx.set_dyn_dim('p', input_ids.len() + output_ids.len() - 1);
cx.set_dyn_dim('t', input_ids.len() + output_ids.len());
cx.execute();
// Sample tokens

View File

@@ -1,7 +1,5 @@
use std::{marker::PhantomData, ops::Div};
use luminal::prelude::{binary::F32Pow, *};
use luminal_nn::{Embedding, LayerNorm, PermutedLinear};
use luminal_nn::{Embedding, LayerNorm, Linear};
// Llama3 8B Config
pub const VOCAB_SIZE: usize = 128256;
@@ -13,49 +11,37 @@ pub const MLP_DIM: usize = 14336;
pub const N_ATTENTION_GROUPS: usize = N_HEADS / N_KV_HEADS;
pub const HEAD_DIM: usize = HIDDEN_DIM / N_HEADS;
pub const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2;
pub const ATTN_PROJ_DIM: usize = HEAD_DIM * N_KV_HEADS;
pub type KVCache<Batch, Seq> = (
GraphTensor<(Batch, Const<N_KV_HEADS>, Seq, Const<HEAD_DIM>)>,
GraphTensor<(Batch, Const<N_KV_HEADS>, Seq, Const<HEAD_DIM>)>,
);
pub type KVCache = (GraphTensor, GraphTensor);
pub struct Mlp<const I: usize, const H: usize> {
pub gate_proj: PermutedLinear<H, I>,
pub down_proj: PermutedLinear<I, H>,
pub up_proj: PermutedLinear<H, I>,
pub struct Mlp {
pub gate_proj: Linear, // hidden -> intermediate
pub down_proj: Linear, // intermediate -> hidden
pub up_proj: Linear, // hidden -> intermediate
}
impl<const I: usize, const H: usize, Batch: Dimension, Batch1: Dimension>
Module<GraphTensor<(Batch, Batch1, Const<H>)>> for Mlp<I, H>
{
type Output = GraphTensor<(Batch, Batch1, Const<H>)>;
impl Module<GraphTensor> for Mlp {
type Output = GraphTensor;
fn forward(&self, input: GraphTensor<(Batch, Batch1, Const<H>)>) -> Self::Output {
fn forward(&self, input: GraphTensor) -> Self::Output {
let gate = self.gate_proj.forward(input).swish();
let up = self.up_proj.forward(input) * gate;
self.down_proj.forward(up)
}
}
impl<const I: usize, const H: usize> InitModule for Mlp<I, H> {
fn initialize(cx: &mut Graph) -> Self {
impl Mlp {
pub fn new(hidden: usize, intermediate: usize, cx: &mut Graph) -> Self {
Self {
gate_proj: PermutedLinear {
weight: cx.named_tensor("Gate"),
},
up_proj: PermutedLinear {
weight: cx.named_tensor("Up"),
},
down_proj: PermutedLinear {
weight: cx.named_tensor("Down"),
},
gate_proj: Linear::new_permuted(hidden, intermediate, false, cx),
down_proj: Linear::new_permuted(intermediate, hidden, false, cx),
up_proj: Linear::new_permuted(hidden, intermediate, false, cx),
}
}
}
impl<const I: usize, const H: usize> SerializeModule for Mlp<I, H> {
impl SerializeModule for Mlp {
fn serialize(&self, s: &mut Serializer) {
s.module("ffn_gate", &self.gate_proj);
s.module("ffn_up", &self.up_proj);
@@ -63,122 +49,105 @@ impl<const I: usize, const H: usize> SerializeModule for Mlp<I, H> {
}
}
fn apply_rotary_embeddings_ggml<const N_HEADS: usize, Batch: Dimension, Seq: Dimension>(
input: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)>,
prev_seq: BigExpression,
) -> GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)> {
fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: BigExpression) -> GraphTensor {
assert_eq!(input.shape.len(), 4); // batch, n_heads, seq, head_dim
let (batch, n_heads, seq, head_dim) = input.dims4();
// Get freqs
let freqs = (input.graph().arange::<Const<HEAD_DIM_OVER_2>>() * 2.0) / (HEAD_DIM as f32);
let freqs = (input.graph().arange(head_dim / 2) * 2.0) / (head_dim.to_usize().unwrap() as f32);
let freqs = 500_000_f32.pow(freqs);
let pos = input.graph().arange::<Seq>() + prev_seq;
let emb = pos.expand::<(_, Const<1>), _>().matmul(freqs.expand());
let pos = input.graph().arange(seq) + prev_seq;
let emb = pos.expand(1, 1).matmul(freqs.expand(0, 1));
// Split input into evens and odds
let split = input.reshape::<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<2>)>();
let x0: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<1>)> =
split.slice((.., .., .., .., ..1)).realize();
let x1: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<1>)> =
split.slice((.., .., .., .., 1..)).realize();
let split = input.reshape((batch, n_heads, seq, head_dim / 2, 2));
let x0 = split.slice((.., .., .., .., ..1));
let x1 = split.slice((.., .., .., .., 1..));
// Apply sin and cos embeddings
let x0_out = x0 * emb.cos().expand() - x1 * emb.sin().expand();
let x1_out = x0 * emb.sin().expand() + x1 * emb.cos().expand();
let x0_out = x0 * emb.cos().expand_to(x0.shape) - x1 * emb.sin().expand_to(x1.shape);
let x1_out = x0 * emb.sin().expand_to(x0.shape) + x1 * emb.cos().expand_to(x1.shape);
// Combine back into output
x0_out
.concat_along::<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<2>), Axis<4>, _>(
x1_out,
)
.reshape()
x0_out.concat_along(x1_out, 4).reshape(input.shape)
}
pub struct SelfAttention {
pub q_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
pub k_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
pub v_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
pub o_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
pub q_proj: GraphTensor, // Hidden -> hidden
pub k_proj: GraphTensor, // Proj dim -> hidden
pub v_proj: GraphTensor, // Proj dim -> hidden
pub o_proj: GraphTensor, // Hidden -> hidden
}
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
Module<(
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, PrevSeq>,
PhantomData<TotSeq>,
)> for SelfAttention
{
type Output = (
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, TotSeq>,
);
fn forward(
&self,
(x, (k_cache, v_cache), _): (
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, PrevSeq>,
PhantomData<TotSeq>,
),
) -> Self::Output {
impl Module<(GraphTensor, KVCache)> for SelfAttention {
type Output = (GraphTensor, KVCache);
fn forward(&self, (x, (k_cache, v_cache)): (GraphTensor, KVCache)) -> Self::Output {
// x: batch, seq, hidden
// cache: batch, kv_heads, prev_seq, head_dim
let (batch, seq, _) = x.dims3();
let (_, _, prev_seq, _) = k_cache.dims4();
// Apply the Projections
let queries = x
.matmul(self.q_proj.permute())
.reshape::<(Batch, CurSeq, Const<N_HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
.matmul(self.q_proj.permute((1, 0)))
.reshape((batch, seq, N_HEADS, HEAD_DIM))
.permute((0, 2, 1, 3));
let keys = x
.matmul(self.k_proj.permute())
.reshape::<(Batch, CurSeq, Const<N_KV_HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
.matmul(self.k_proj.permute((1, 0)))
.reshape((batch, seq, N_KV_HEADS, HEAD_DIM))
.permute((0, 2, 1, 3));
let values = x
.matmul(self.v_proj.permute())
.reshape::<(Batch, CurSeq, Const<N_KV_HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
.matmul(self.v_proj.permute((1, 0)))
.reshape((batch, seq, N_KV_HEADS, HEAD_DIM))
.permute((0, 2, 1, 3));
// Rotary embed queries and keys
let queries = apply_rotary_embeddings_ggml(queries, PrevSeq::size().into());
let keys = apply_rotary_embeddings_ggml(keys, PrevSeq::size().into());
let queries = apply_rotary_embeddings_ggml(queries, prev_seq.big());
let keys = apply_rotary_embeddings_ggml(keys, prev_seq.big());
// Add KV cache
let keys = k_cache.concat_along::<_, Axis<2>, _>(keys);
let values = v_cache.concat_along::<_, Axis<2>, _>(values);
let keys = k_cache.concat_along(keys, 2);
let values = v_cache.concat_along(values, 2);
// Repeat the KV States for Grouped-Query Attention
let repeated_keys = keys.expand::<(_, _, Const<N_ATTENTION_GROUPS>, _, _), _>();
let repeated_values = values.expand::<(_, _, Const<N_ATTENTION_GROUPS>, _, _), _>();
let repeated_keys = keys.expand(2, N_ATTENTION_GROUPS);
let repeated_values = values.expand(2, N_ATTENTION_GROUPS);
// Calculate attention weights
let mut attention_weights = queries
.reshape::<(_, Const<N_KV_HEADS>, Const<N_ATTENTION_GROUPS>, _, _)>() // Split query heads into groups
.matmul(repeated_keys.permute())
.div((HEAD_DIM as f32).sqrt());
.reshape((batch, N_KV_HEADS, N_ATTENTION_GROUPS, seq, HEAD_DIM)) // Split query heads into groups
.matmul(repeated_keys.permute((0, 1, 2, 4, 3)))
/ (HEAD_DIM as f32).sqrt();
let attention_mask = self.k_proj.graph().triu::<CurSeq>(1) * f16::MIN.to_f32();
let attention_mask = self.k_proj.graph().triu(seq, 1) * f16::MIN.to_f32();
attention_weights += attention_mask
.pad::<(CurSeq, TotSeq)>(((0, 0), (TotSeq::size() - CurSeq::size(), 0)))
.expand();
.pad(((0, 0), (prev_seq, 0)))
.expand(0, batch)
.expand(1, N_KV_HEADS)
.expand(2, N_ATTENTION_GROUPS);
// Calculate final outputs
let output = attention_weights
.softmax::<Axis<4>>()
.softmax(4)
// Apply distribution to values
.matmul(repeated_values)
// Merge heads
.permute::<_, Axes5<0, 3, 1, 2, 4>>()
.reshape::<(Batch, CurSeq, Const<HIDDEN_DIM>)>();
.permute((0, 3, 1, 2, 4))
.reshape((batch, seq, HIDDEN_DIM));
let output = output
// Apply output projection
.matmul(self.o_proj.permute());
.matmul(self.o_proj.permute((1, 0)));
(output, (keys.contiguous(), values.contiguous())) // Cache needs to be contiguous for transferring to another graph
}
}
impl InitModule for SelfAttention {
fn initialize(cx: &mut Graph) -> Self {
impl SelfAttention {
pub fn new(cx: &mut Graph) -> Self {
Self {
q_proj: cx.named_tensor("Q Proj"),
k_proj: cx.named_tensor("K Proj"),
v_proj: cx.named_tensor("V Proj"),
o_proj: cx.named_tensor("O Proj"),
q_proj: cx.named_tensor("Q Proj", (HIDDEN_DIM, HIDDEN_DIM)),
k_proj: cx.named_tensor("K Proj", (ATTN_PROJ_DIM, HIDDEN_DIM)),
v_proj: cx.named_tensor("V Proj", (ATTN_PROJ_DIM, HIDDEN_DIM)),
o_proj: cx.named_tensor("O Proj", (HIDDEN_DIM, HIDDEN_DIM)),
}
}
}
@@ -194,54 +163,37 @@ impl SerializeModule for SelfAttention {
pub struct TransformerBlock {
pub attention: SelfAttention,
pub attention_norm: LayerNorm<HIDDEN_DIM>,
pub feed_forward: Mlp<MLP_DIM, HIDDEN_DIM>,
pub feed_forward_norm: LayerNorm<HIDDEN_DIM>,
pub attention_norm: LayerNorm,
pub feed_forward: Mlp,
pub feed_forward_norm: LayerNorm,
}
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
Module<(
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, PrevSeq>,
PhantomData<TotSeq>,
)> for TransformerBlock
{
type Output = (
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, TotSeq>,
);
fn forward(
&self,
(mut x, cache, _): (
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, PrevSeq>,
PhantomData<TotSeq>,
),
) -> Self::Output {
impl Module<(GraphTensor, KVCache)> for TransformerBlock {
type Output = (GraphTensor, KVCache);
fn forward(&self, (mut x, cache): (GraphTensor, KVCache)) -> Self::Output {
// Attention
let normed = self.attention_norm.forward(x);
let (y, cache) = self
.attention
.forward((normed, cache, PhantomData::<TotSeq>));
.forward((self.attention_norm.forward(x), cache));
// Residual Addition
// Residual
x += y;
// Feed Forward
let y = self.feed_forward.forward(self.feed_forward_norm.forward(x));
// Residual Addition
// Residual
(x + y, cache)
}
}
impl InitModule for TransformerBlock {
fn initialize(cx: &mut Graph) -> Self {
impl TransformerBlock {
pub fn new(cx: &mut Graph) -> Self {
Self {
attention: InitModule::initialize(cx),
attention_norm: LayerNorm::new(true, false, false, 1e-5, cx),
feed_forward: InitModule::initialize(cx),
feed_forward_norm: LayerNorm::new(true, false, false, 1e-5, cx),
attention: SelfAttention::new(cx),
attention_norm: LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
feed_forward: Mlp::new(HIDDEN_DIM, MLP_DIM, cx),
feed_forward_norm: LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
}
}
}
@@ -257,35 +209,16 @@ impl SerializeModule for TransformerBlock {
pub struct Llama {
// Token embeddings
pub embedding: Embedding<VOCAB_SIZE, HIDDEN_DIM>,
pub embedding: Embedding,
// Transformer layers
pub layers: Vec<TransformerBlock>,
// Norm + LM head
pub head: (
LayerNorm<HIDDEN_DIM>,
PermutedLinear<HIDDEN_DIM, VOCAB_SIZE>,
),
pub head: (LayerNorm, Linear),
}
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
Module<(
GraphTensor<(Batch, CurSeq)>,
&[KVCache<Batch, PrevSeq>],
PhantomData<TotSeq>,
)> for Llama
{
type Output = (
GraphTensor<(Batch, CurSeq, Const<VOCAB_SIZE>)>,
Vec<KVCache<Batch, TotSeq>>,
);
fn forward(
&self,
(input, cache, _): (
GraphTensor<(Batch, CurSeq)>,
&[KVCache<Batch, PrevSeq>],
PhantomData<TotSeq>,
),
) -> Self::Output {
impl Module<(GraphTensor, &[KVCache])> for Llama {
type Output = (GraphTensor, Vec<KVCache>);
fn forward(&self, (input, cache): (GraphTensor, &[KVCache])) -> Self::Output {
// Embed tokens
let mut x = self.embedding.forward(input);
@@ -293,7 +226,7 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
let mut new_caches = vec![];
let mut new_cache;
for (i, layer) in self.layers.iter().enumerate() {
(x, new_cache) = layer.forward((x, cache[i], PhantomData::<TotSeq>));
(x, new_cache) = layer.forward((x, cache[i]));
new_caches.push(new_cache);
}
// Run through last norm and output projection
@@ -301,21 +234,15 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
}
}
impl InitModule for Llama {
fn initialize(cx: &mut Graph) -> Self {
impl Llama {
pub fn new(cx: &mut Graph) -> Self {
Self {
embedding: Embedding {
weight: cx.named_tensor("Embedding Weight"),
},
embedding: Embedding::new(VOCAB_SIZE, HIDDEN_DIM, cx),
head: (
LayerNorm::new(true, false, false, 1e-5, cx),
PermutedLinear {
weight: cx.tensor(),
},
LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
Linear::new_permuted(HIDDEN_DIM, VOCAB_SIZE, false, cx),
),
layers: (0..NUM_LAYERS)
.map(|_| InitModule::initialize(cx))
.collect(),
layers: (0..NUM_LAYERS).map(|_| TransformerBlock::new(cx)).collect(),
}
}
}

View File

@@ -1,7 +1,5 @@
use std::{marker::PhantomData, ops::Div};
use luminal::prelude::{binary::F32Pow, *};
use luminal_nn::{Embedding, LayerNorm, PermutedLinear};
use luminal_nn::{Embedding, LayerNorm, Linear};
// Llama3 8B Config
pub const VOCAB_SIZE: usize = 128256;
@@ -13,49 +11,37 @@ pub const MLP_DIM: usize = 14336;
pub const N_ATTENTION_GROUPS: usize = N_HEADS / N_KV_HEADS;
pub const HEAD_DIM: usize = HIDDEN_DIM / N_HEADS;
pub const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2;
pub const ATTN_PROJ_DIM: usize = HEAD_DIM * N_KV_HEADS;
pub type KVCache<Batch, Seq> = (
GraphTensor<(Batch, Const<N_KV_HEADS>, Seq, Const<HEAD_DIM>)>,
GraphTensor<(Batch, Const<N_KV_HEADS>, Seq, Const<HEAD_DIM>)>,
);
pub type KVCache = (GraphTensor, GraphTensor);
pub struct Mlp<const I: usize, const H: usize> {
pub gate_proj: PermutedLinear<H, I>,
pub down_proj: PermutedLinear<I, H>,
pub up_proj: PermutedLinear<H, I>,
pub struct Mlp {
pub gate_proj: Linear, // hidden -> intermediate
pub down_proj: Linear, // intermediate -> hidden
pub up_proj: Linear, // hidden -> intermediate
}
impl<const I: usize, const H: usize, Batch: Dimension, Batch1: Dimension>
Module<GraphTensor<(Batch, Batch1, Const<H>)>> for Mlp<I, H>
{
type Output = GraphTensor<(Batch, Batch1, Const<H>)>;
impl Module<GraphTensor> for Mlp {
type Output = GraphTensor;
fn forward(&self, input: GraphTensor<(Batch, Batch1, Const<H>)>) -> Self::Output {
fn forward(&self, input: GraphTensor) -> Self::Output {
let gate = self.gate_proj.forward(input).swish();
let up = self.up_proj.forward(input) * gate;
self.down_proj.forward(up)
}
}
impl<const I: usize, const H: usize> InitModule for Mlp<I, H> {
fn initialize(cx: &mut Graph) -> Self {
impl Mlp {
pub fn new(hidden: usize, intermediate: usize, cx: &mut Graph) -> Self {
Self {
gate_proj: PermutedLinear {
weight: cx.named_tensor("Gate"),
},
up_proj: PermutedLinear {
weight: cx.named_tensor("Up"),
},
down_proj: PermutedLinear {
weight: cx.named_tensor("Down"),
},
gate_proj: Linear::new_permuted(hidden, intermediate, false, cx),
down_proj: Linear::new_permuted(intermediate, hidden, false, cx),
up_proj: Linear::new_permuted(hidden, intermediate, false, cx),
}
}
}
impl<const I: usize, const H: usize> SerializeModule for Mlp<I, H> {
impl SerializeModule for Mlp {
fn serialize(&self, s: &mut Serializer) {
s.module("ffn_gate", &self.gate_proj);
s.module("ffn_up", &self.up_proj);
@@ -63,122 +49,105 @@ impl<const I: usize, const H: usize> SerializeModule for Mlp<I, H> {
}
}
fn apply_rotary_embeddings_ggml<const N_HEADS: usize, Batch: Dimension, Seq: Dimension>(
input: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)>,
prev_seq: BigExpression,
) -> GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)> {
fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: BigExpression) -> GraphTensor {
assert_eq!(input.shape.len(), 4); // batch, n_heads, seq, head_dim
let (batch, n_heads, seq, head_dim) = input.dims4();
// Get freqs
let freqs = (input.graph().arange::<Const<HEAD_DIM_OVER_2>>() * 2.0) / (HEAD_DIM as f32);
let freqs = (input.graph().arange(head_dim / 2) * 2.0) / (head_dim.to_usize().unwrap() as f32);
let freqs = 500_000_f32.pow(freqs);
let pos = input.graph().arange::<Seq>() + prev_seq;
let emb = pos.expand::<(_, Const<1>), _>().matmul(freqs.expand());
let pos = input.graph().arange(seq) + prev_seq;
let emb = pos.expand(1, 1).matmul(freqs.expand(0, 1));
// Split input into evens and odds
let split = input.reshape::<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<2>)>();
let x0: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<1>)> =
split.slice((.., .., .., .., ..1)).realize();
let x1: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<1>)> =
split.slice((.., .., .., .., 1..)).realize();
let split = input.reshape((batch, n_heads, seq, head_dim / 2, 2));
let x0 = split.slice((.., .., .., .., ..1));
let x1 = split.slice((.., .., .., .., 1..));
// Apply sin and cos embeddings
let x0_out = x0 * emb.cos().expand() - x1 * emb.sin().expand();
let x1_out = x0 * emb.sin().expand() + x1 * emb.cos().expand();
let x0_out = x0 * emb.cos().expand_to(x0.shape) - x1 * emb.sin().expand_to(x1.shape);
let x1_out = x0 * emb.sin().expand_to(x0.shape) + x1 * emb.cos().expand_to(x1.shape);
// Combine back into output
x0_out
.concat_along::<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<2>), Axis<4>, _>(
x1_out,
)
.reshape()
x0_out.concat_along(x1_out, 4).reshape(input.shape)
}
pub struct SelfAttention {
pub q_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
pub k_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
pub v_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
pub o_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
pub q_proj: GraphTensor, // Hidden -> hidden
pub k_proj: GraphTensor, // Proj dim -> hidden
pub v_proj: GraphTensor, // Proj dim -> hidden
pub o_proj: GraphTensor, // Hidden -> hidden
}
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
Module<(
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, PrevSeq>,
PhantomData<TotSeq>,
)> for SelfAttention
{
type Output = (
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, TotSeq>,
);
fn forward(
&self,
(x, (k_cache, v_cache), _): (
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, PrevSeq>,
PhantomData<TotSeq>,
),
) -> Self::Output {
impl Module<(GraphTensor, KVCache)> for SelfAttention {
type Output = (GraphTensor, KVCache);
fn forward(&self, (x, (k_cache, v_cache)): (GraphTensor, KVCache)) -> Self::Output {
// x: batch, seq, hidden
// cache: batch, kv_heads, prev_seq, head_dim
let (batch, seq, _) = x.dims3();
let (_, _, prev_seq, _) = k_cache.dims4();
// Apply the Projections
let queries = x
.matmul(self.q_proj.permute())
.reshape::<(Batch, CurSeq, Const<N_HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
.matmul(self.q_proj.permute((1, 0)))
.reshape((batch, seq, N_HEADS, HEAD_DIM))
.permute((0, 2, 1, 3));
let keys = x
.matmul(self.k_proj.permute())
.reshape::<(Batch, CurSeq, Const<N_KV_HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
.matmul(self.k_proj.permute((1, 0)))
.reshape((batch, seq, N_KV_HEADS, HEAD_DIM))
.permute((0, 2, 1, 3));
let values = x
.matmul(self.v_proj.permute())
.reshape::<(Batch, CurSeq, Const<N_KV_HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
.matmul(self.v_proj.permute((1, 0)))
.reshape((batch, seq, N_KV_HEADS, HEAD_DIM))
.permute((0, 2, 1, 3));
// Rotary embed queries and keys
let queries = apply_rotary_embeddings_ggml(queries, PrevSeq::size().big());
let keys = apply_rotary_embeddings_ggml(keys, PrevSeq::size().big());
let queries = apply_rotary_embeddings_ggml(queries, prev_seq.big());
let keys = apply_rotary_embeddings_ggml(keys, prev_seq.big());
// Add KV cache
let keys = k_cache.concat_along::<_, Axis<2>, _>(keys);
let values = v_cache.concat_along::<_, Axis<2>, _>(values);
let keys = k_cache.concat_along(keys, 2);
let values = v_cache.concat_along(values, 2);
// Repeat the KV States for Grouped-Query Attention
let repeated_keys = keys.expand::<(_, _, Const<N_ATTENTION_GROUPS>, _, _), _>();
let repeated_values = values.expand::<(_, _, Const<N_ATTENTION_GROUPS>, _, _), _>();
let repeated_keys = keys.expand(2, N_ATTENTION_GROUPS);
let repeated_values = values.expand(2, N_ATTENTION_GROUPS);
// Calculate attention weights
let mut attention_weights = queries
.reshape::<(_, Const<N_KV_HEADS>, Const<N_ATTENTION_GROUPS>, _, _)>() // Split query heads into groups
.matmul(repeated_keys.permute())
.div((HEAD_DIM as f32).sqrt());
.reshape((batch, N_KV_HEADS, N_ATTENTION_GROUPS, seq, HEAD_DIM)) // Split query heads into groups
.matmul(repeated_keys.permute((0, 1, 2, 4, 3)))
/ (HEAD_DIM as f32).sqrt();
let attention_mask = self.k_proj.graph().triu::<CurSeq>(1) * f16::MIN.to_f32();
let attention_mask = self.k_proj.graph().triu(seq, 1) * f16::MIN.to_f32();
attention_weights += attention_mask
.pad::<(CurSeq, TotSeq)>(((0, 0), (TotSeq::size() - CurSeq::size(), 0)))
.expand();
.pad(((0, 0), (prev_seq, 0)))
.expand(0, batch)
.expand(1, N_KV_HEADS)
.expand(2, N_ATTENTION_GROUPS);
// Calculate final outputs
let output = attention_weights
.softmax::<Axis<4>>()
.softmax(4)
// Apply distribution to values
.matmul(repeated_values)
// Merge heads
.permute::<_, Axes5<0, 3, 1, 2, 4>>()
.reshape::<(Batch, CurSeq, Const<HIDDEN_DIM>)>();
.permute((0, 3, 1, 2, 4))
.reshape((batch, seq, HIDDEN_DIM));
let output = output
// Apply output projection
.matmul(self.o_proj.permute());
.matmul(self.o_proj.permute((1, 0)));
(output, (keys.contiguous(), values.contiguous())) // Cache needs to be contiguous for transferring to another graph
}
}
impl InitModule for SelfAttention {
fn initialize(cx: &mut Graph) -> Self {
impl SelfAttention {
pub fn new(cx: &mut Graph) -> Self {
Self {
q_proj: cx.named_tensor("Q Proj"),
k_proj: cx.named_tensor("K Proj"),
v_proj: cx.named_tensor("V Proj"),
o_proj: cx.named_tensor("O Proj"),
q_proj: cx.named_tensor("Q Proj", (HIDDEN_DIM, HIDDEN_DIM)),
k_proj: cx.named_tensor("K Proj", (ATTN_PROJ_DIM, HIDDEN_DIM)),
v_proj: cx.named_tensor("V Proj", (ATTN_PROJ_DIM, HIDDEN_DIM)),
o_proj: cx.named_tensor("O Proj", (HIDDEN_DIM, HIDDEN_DIM)),
}
}
}
@@ -194,54 +163,37 @@ impl SerializeModule for SelfAttention {
pub struct TransformerBlock {
pub attention: SelfAttention,
pub attention_norm: LayerNorm<HIDDEN_DIM>,
pub feed_forward: Mlp<MLP_DIM, HIDDEN_DIM>,
pub feed_forward_norm: LayerNorm<HIDDEN_DIM>,
pub attention_norm: LayerNorm,
pub feed_forward: Mlp,
pub feed_forward_norm: LayerNorm,
}
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
Module<(
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, PrevSeq>,
PhantomData<TotSeq>,
)> for TransformerBlock
{
type Output = (
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, TotSeq>,
);
fn forward(
&self,
(mut x, cache, _): (
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, PrevSeq>,
PhantomData<TotSeq>,
),
) -> Self::Output {
impl Module<(GraphTensor, KVCache)> for TransformerBlock {
type Output = (GraphTensor, KVCache);
fn forward(&self, (mut x, cache): (GraphTensor, KVCache)) -> Self::Output {
// Attention
let normed = self.attention_norm.forward(x);
let (y, cache) = self
.attention
.forward((normed, cache, PhantomData::<TotSeq>));
.forward((self.attention_norm.forward(x), cache));
// Residual Addition
// Residual
x += y;
// Feed Forward
let y = self.feed_forward.forward(self.feed_forward_norm.forward(x));
// Residual Addition
// Residual
(x + y, cache)
}
}
impl InitModule for TransformerBlock {
fn initialize(cx: &mut Graph) -> Self {
impl TransformerBlock {
pub fn new(cx: &mut Graph) -> Self {
Self {
attention: InitModule::initialize(cx),
attention_norm: LayerNorm::new(true, false, false, 1e-5, cx),
feed_forward: InitModule::initialize(cx),
feed_forward_norm: LayerNorm::new(true, false, false, 1e-5, cx),
attention: SelfAttention::new(cx),
attention_norm: LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
feed_forward: Mlp::new(HIDDEN_DIM, MLP_DIM, cx),
feed_forward_norm: LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
}
}
}
@@ -255,36 +207,18 @@ impl SerializeModule for TransformerBlock {
}
}
pub struct MistralLM {
pub struct Llama {
// Token embeddings
pub embedding: Embedding<VOCAB_SIZE, HIDDEN_DIM>,
pub embedding: Embedding,
// Transformer layers
pub layers: Vec<TransformerBlock>,
// Final Norm layer
pub norm: LayerNorm<HIDDEN_DIM>,
// LM Head Layer
pub lm_head: GraphTensor<R2<VOCAB_SIZE, HIDDEN_DIM>>,
// Norm + LM head
pub head: (LayerNorm, Linear),
}
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
Module<(
GraphTensor<(Batch, CurSeq)>,
&[KVCache<Batch, PrevSeq>],
PhantomData<TotSeq>,
)> for MistralLM
{
type Output = (
GraphTensor<(Batch, CurSeq, Const<VOCAB_SIZE>)>,
Vec<KVCache<Batch, TotSeq>>,
);
fn forward(
&self,
(input, cache, _): (
GraphTensor<(Batch, CurSeq)>,
&[KVCache<Batch, PrevSeq>],
PhantomData<TotSeq>,
),
) -> Self::Output {
impl Module<(GraphTensor, &[KVCache])> for Llama {
type Output = (GraphTensor, Vec<KVCache>);
fn forward(&self, (input, cache): (GraphTensor, &[KVCache])) -> Self::Output {
// Embed tokens
let mut x = self.embedding.forward(input);
@@ -292,36 +226,32 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
let mut new_caches = vec![];
let mut new_cache;
for (i, layer) in self.layers.iter().enumerate() {
(x, new_cache) = layer.forward((x, cache[i], PhantomData::<TotSeq>));
(x, new_cache) = layer.forward((x, cache[i]));
new_caches.push(new_cache);
}
// Run through last norm and output projection
let output = self.norm.forward(x).matmul(self.lm_head.permute());
(output, new_caches)
(self.head.forward(x), new_caches)
}
}
impl InitModule for MistralLM {
fn initialize(cx: &mut Graph) -> Self {
impl Llama {
pub fn new(cx: &mut Graph) -> Self {
Self {
embedding: Embedding {
weight: cx.named_tensor("Embedding Weight"),
},
norm: LayerNorm::new(true, false, false, 1e-5, cx),
lm_head: cx.named_tensor("LM Head"),
layers: (0..NUM_LAYERS)
.map(|_| InitModule::initialize(cx))
.collect(),
embedding: Embedding::new(VOCAB_SIZE, HIDDEN_DIM, cx),
head: (
LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
Linear::new_permuted(HIDDEN_DIM, VOCAB_SIZE, false, cx),
),
layers: (0..NUM_LAYERS).map(|_| TransformerBlock::new(cx)).collect(),
}
}
}
impl SerializeModule for MistralLM {
impl SerializeModule for Llama {
fn serialize(&self, s: &mut Serializer) {
s.module("token_embd", &self.embedding);
s.module("output_norm", &self.norm);
s.tensor("output/weight", self.lm_head);
s.module("output_norm", &self.head.0);
s.module("output", &self.head.1);
for (i, layer) in self.layers.iter().enumerate() {
s.module(&format!("blk/{i}"), layer);
}

View File

@@ -1,6 +1,5 @@
use std::{
io::{self, Write},
marker::PhantomData,
path::Path,
time::Instant,
};
@@ -11,18 +10,16 @@ use tokenizers::Tokenizer;
use crate::llama::{
loader,
model::{KVCache, MistralLM, HEAD_DIM, NUM_LAYERS, N_KV_HEADS},
model::{KVCache, Llama, HEAD_DIM, NUM_LAYERS, N_KV_HEADS},
};
use super::model::VOCAB_SIZE;
/// Define the model
pub struct Model {
pub graph: Box<Graph>,
pub input: GraphTensor<(Const<1>, Dyn<'s'>)>,
pub input: GraphTensor,
kv_cache_src_set: Vec<NodeIndex>,
kv_cache_dest_set: Vec<NodeIndex>,
logits: GraphTensor<R3<1, 1, { VOCAB_SIZE }>>,
logits: GraphTensor,
pub tokenizer: Tokenizer,
pub last_generated_token: Option<u32>,
}
@@ -50,19 +47,23 @@ impl Model {
// Set up graph
let mut cx = Box::new(Graph::new());
let mut input = cx.named_tensor::<(Const<1>, Dyn<'s'>)>("Input");
let mut cache_src: Vec<KVCache<Const<1>, Dyn<'p'>>> = (0..NUM_LAYERS)
.map(|_| (cx.named_tensor("Key Cache"), cx.named_tensor("Value Cache")))
let mut input = cx.named_tensor("Input", (1, 's'));
let mut cache_src: Vec<KVCache> = (0..NUM_LAYERS)
.map(|_| {
(
cx.named_tensor("Key Cache", (1, N_KV_HEADS, 'p', HEAD_DIM)),
cx.named_tensor("Value Cache", (1, N_KV_HEADS, 'p', HEAD_DIM)),
)
})
.collect();
cache_src.set_dyn(vec![], &[1, N_KV_HEADS, 0, HEAD_DIM]);
let model = MistralLM::initialize(&mut cx);
cache_src.set_dyn(vec![], (1, N_KV_HEADS, 0, HEAD_DIM));
let model = Llama::new(&mut cx);
let mut model_weights = params(&model);
cx.keep_tensors(&model_weights);
let (logits, mut cache_dest) = model.forward((input, &cache_src, PhantomData::<Dyn<'t'>>));
let (logits, mut cache_dest) = model.forward((input, &cache_src));
let mut logits = logits
.slice((.., (Expression::from('s') - 1).., ..))
.retrieve()
.realize();
.retrieve();
cache_dest.keep();
// Set up model loading
@@ -105,7 +106,7 @@ impl Model {
print!("Loading model");
io::stdout().flush().unwrap();
let now = Instant::now();
input.set_dyn(vec![1.], &[1, 1]);
input.set_dyn(vec![1.], (1, 1));
cx.set_dyn_dim('t', 1);
cx.execute();
logits.drop();
@@ -159,7 +160,7 @@ impl Model {
self.graph.set_dyn_dim('t', seq_len);
self.input.set_dyn(
input_ids.iter().map(|i| *i as f32).collect::<Vec<_>>(),
&[1, input_ids.len()],
(1, input_ids.len()),
);
// First token output (from prompt processing)
@@ -183,7 +184,7 @@ impl Model {
// Set the data
self.graph.set_dyn_dim('p', seq_len - 1);
self.graph.set_dyn_dim('t', seq_len);
self.input.set_dyn(vec![output_id as f32], &[1, 1]);
self.input.set_dyn(vec![output_id as f32], (1, 1));
// Execute the graph
self.graph.execute();

View File

@@ -1,12 +1,12 @@
use std::{
io::{self, Write},
marker::PhantomData,
time::Instant,
};
use clap::Parser;
use colored::Colorize;
use itertools::Itertools;
use model::{Phi, HEAD_DIM, N_HEADS};
use tokenizers::Tokenizer;
mod gguf;
@@ -39,15 +39,20 @@ fn main() {
// Set up graph
let mut cx = Graph::new();
let mut input = cx.named_tensor::<(Const<1>, Dyn<'s'>)>("Input");
let mut cache_src: Vec<KVCache<Const<1>, Dyn<'p'>>> = (0..model::NUM_LAYERS)
.map(|_| (cx.named_tensor("Key Cache"), cx.named_tensor("Value Cache")))
let mut input = cx.named_tensor("Input", (1, 's'));
let mut cache_src: Vec<KVCache> = (0..model::NUM_LAYERS)
.map(|_| {
(
cx.named_tensor("Key Cache", (1, N_HEADS, 'p', HEAD_DIM)),
cx.named_tensor("Value Cache", (1, N_HEADS, 'p', HEAD_DIM)),
)
})
.collect();
cache_src.set_dyn(vec![], &[1, model::N_HEADS, 0, model::HEAD_DIM]);
let model = model::Phi::initialize(&mut cx);
cache_src.set_dyn(vec![], (1, N_HEADS, 0, HEAD_DIM));
let model = Phi::new(&mut cx);
let mut model_weights = params(&model);
cx.keep_tensors(&model_weights);
let (logits, mut cache_dest) = model.forward((input, &cache_src, PhantomData::<Dyn<'t'>>));
let (logits, mut cache_dest) = model.forward((input, &cache_src));
let mut logits = logits
.slice((.., (Expression::from('s') - 1).., ..))
.retrieve();
@@ -73,7 +78,7 @@ fn main() {
luminal_metal::BufferCompilers::default(),
),
#[cfg(feature = "cuda")]
luminal_cuda::CudaQuantizedCompiler::<f32>::new(q_weights),
luminal_cuda::CudaQuantizedCompiler::<f16>::new(q_weights),
#[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
luminal_cpu::CPUCompiler::default(),
),
@@ -86,33 +91,32 @@ fn main() {
),
);
let cache_src = downstream(&cache_src, &cx);
let cache_dest = cache_dest.to_ids();
println!("\t\t - {}ms", now.elapsed().as_millis());
// Initial forward pass to load weights
print!("Loading model");
io::stdout().flush().unwrap();
let now = Instant::now();
input.set_dyn(vec![1.], &[1, 1]);
input.set_dyn(vec![1.], (1, 1));
cx.set_dyn_dim('t', 1);
cx.execute();
logits.drop();
cx.drop_tensors(&cache_dest);
transfer_data_same_graph(&cache_dest, &cache_src, &mut cx);
println!("\t\t - {}ms", now.elapsed().as_millis());
// Now that weights are loaded, delete the loading nodes so they don't run again
delete_inputs(&cache_src, &mut cx);
delete_inputs(&downstream(model_weights, &cx), &mut cx);
// Run prompt processing pass
let mut input_ids = tokenizer
.encode(&cli_args.prompt as &str, true)
let input_ids = tokenizer
.encode(&cli_args.prompt as &str, false)
.unwrap()
.get_ids()
.to_vec();
input_ids.insert(0, 1);
input.set_dyn(
input_ids.iter().map(|i| *i as f32).collect::<Vec<_>>(),
&[1, input_ids.len()],
(1, input_ids.len()),
);
cx.set_dyn_dim('t', input_ids.len());
print!("Processing Prompt");
@@ -125,14 +129,13 @@ fn main() {
1000.0 * (input_ids.len() as f64) / (elapsed_ms as f64),
input_ids.len()
);
delete_inputs(&cache_src, &mut cx);
let mut output_ids = vec![argmax(&logits.data())];
logits.drop();
// Decode token
print!("{}", cli_args.prompt.white().bold());
let out_str = tokenizer.decode(&output_ids, false).unwrap().bright_green();
let mut prev_output_len = out_str.len();
print!("{out_str}");
let out = tokenizer.decode(&output_ids, false).unwrap().bright_green();
print!("{out}");
io::stdout().flush().unwrap();
// Swap caches
@@ -140,10 +143,10 @@ fn main() {
// Decode loop
let start_decode = std::time::Instant::now();
let mut prev_output_len = out.len();
for _ in 0..cli_args.gen_tokens {
input.set_dyn(vec![*output_ids.last().unwrap() as f32], &[1, 1]);
input.set_dyn(vec![*output_ids.last().unwrap() as f32], (1, 1));
cx.set_dyn_dim('p', input_ids.len() + output_ids.len() - 1);
cx.set_dyn_dim('t', input_ids.len() + output_ids.len());
cx.execute();
// Sample tokens
@@ -166,9 +169,8 @@ fn main() {
}
println!();
let avg_token_time = (std::time::Instant::now() - start_decode).as_micros() as f32
/ (output_ids.len() - 1) as f32
/ 1000.0;
let avg_token_time =
start_decode.elapsed().as_micros() as f32 / (output_ids.len() - 1) as f32 / 1000.0;
println!(
"\nAverage token generated in {:.2}ms\t - ({:.2} tok/s)",
avg_token_time,

View File

@@ -1,9 +1,7 @@
use std::marker::PhantomData;
use luminal::prelude::{binary::F32Pow, *};
use luminal_nn::{Embedding, LayerNorm, PermutedLinear};
use luminal_nn::{Embedding, LayerNorm, Linear};
// Llama3 8B Config
// Phi3 mini Config
pub const VOCAB_SIZE: usize = 32064;
pub const HIDDEN_DIM: usize = 3072;
pub const NUM_LAYERS: usize = 32;
@@ -11,51 +9,37 @@ pub const N_HEADS: usize = 32;
pub const MLP_DIM: usize = 8192;
pub const HEAD_DIM: usize = HIDDEN_DIM / N_HEADS;
pub const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2;
pub const ATTN_PROJ_DIM: usize = HEAD_DIM * N_HEADS;
pub type KVCache<Batch, Seq> = (
GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)>,
GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)>,
);
pub type KVCache = (GraphTensor, GraphTensor);
pub struct Mlp<const I: usize, const H: usize> {
pub gate_proj: PermutedLinear<H, I>,
pub down_proj: PermutedLinear<I, H>,
pub up_proj: PermutedLinear<H, I>,
pub struct Mlp {
pub gate_proj: Linear, // hidden -> intermediate
pub down_proj: Linear, // intermediate -> hidden
pub up_proj: Linear, // hidden -> intermediate
}
impl<Sh: Shape, Im: Shape, const I: usize, const H: usize> Module<GraphTensor<Sh>> for Mlp<I, H>
where
GraphTensor<Sh>: Matmul<R2<H, I>, Output = GraphTensor<Im>>,
GraphTensor<Im>: Matmul<R2<I, H>, Output = GraphTensor<Sh>>,
{
type Output = GraphTensor<Sh>;
impl Module<GraphTensor> for Mlp {
type Output = GraphTensor;
fn forward(&self, input: GraphTensor<Sh>) -> Self::Output {
fn forward(&self, input: GraphTensor) -> Self::Output {
let gate = self.gate_proj.forward(input).swish();
let up = self.up_proj.forward(input) * gate;
self.down_proj.forward(up)
}
}
impl<const I: usize, const H: usize> InitModule for Mlp<I, H> {
fn initialize(cx: &mut Graph) -> Self {
impl Mlp {
pub fn new(hidden: usize, intermediate: usize, cx: &mut Graph) -> Self {
Self {
gate_proj: PermutedLinear {
weight: cx.named_tensor("Gate"),
},
up_proj: PermutedLinear {
weight: cx.named_tensor("Up"),
},
down_proj: PermutedLinear {
weight: cx.named_tensor("Down"),
},
gate_proj: Linear::new_permuted(hidden, intermediate, false, cx),
down_proj: Linear::new_permuted(intermediate, hidden, false, cx),
up_proj: Linear::new_permuted(hidden, intermediate, false, cx),
}
}
}
impl<const I: usize, const H: usize> SerializeModule for Mlp<I, H> {
impl SerializeModule for Mlp {
fn serialize(&self, s: &mut Serializer) {
s.module("ffn_gate", &self.gate_proj);
s.module("ffn_up", &self.up_proj);
@@ -63,115 +47,98 @@ impl<const I: usize, const H: usize> SerializeModule for Mlp<I, H> {
}
}
fn apply_rotary_embeddings_ggml<const N_HEADS: usize, Batch: Dimension, Seq: Dimension>(
input: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)>,
prev_seq: BigExpression,
) -> GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)> {
fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: BigExpression) -> GraphTensor {
assert_eq!(input.shape.len(), 4); // batch, n_heads, seq, head_dim
let (batch, n_heads, seq, head_dim) = input.dims4();
// Get freqs
let freqs = (input.graph().arange::<Const<HEAD_DIM_OVER_2>>() * 2.0) / (HEAD_DIM as f32);
let freqs = (input.graph().arange(head_dim / 2) * 2.0) / (head_dim.to_usize().unwrap() as f32);
let freqs = 10_000_f32.pow(freqs);
let pos = input.graph().arange::<Seq>() + prev_seq;
let emb = pos.expand::<(_, Const<1>), _>().matmul(freqs.expand());
let pos = input.graph().arange(seq) + prev_seq;
let emb = pos.expand(1, 1).matmul(freqs.expand(0, 1));
// Split input into evens and odds
let split = input.reshape::<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<2>)>();
let x0: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<1>)> =
split.slice((.., .., .., .., ..1)).realize();
let x1: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<1>)> =
split.slice((.., .., .., .., 1..)).realize();
let split = input.reshape((batch, n_heads, seq, head_dim / 2, 2));
let x0 = split.slice((.., .., .., .., ..1));
let x1 = split.slice((.., .., .., .., 1..));
// Apply sin and cos embeddings
let x0_out = x0 * emb.cos().expand() - x1 * emb.sin().expand();
let x1_out = x0 * emb.sin().expand() + x1 * emb.cos().expand();
let x0_out = x0 * emb.cos().expand_to(x0.shape) - x1 * emb.sin().expand_to(x1.shape);
let x1_out = x0 * emb.sin().expand_to(x0.shape) + x1 * emb.cos().expand_to(x1.shape);
// Combine back into output
x0_out
.concat_along::<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<2>), Axis<4>, _>(
x1_out,
)
.reshape()
x0_out.concat_along(x1_out, 4).reshape(input.shape)
}
pub struct SelfAttention {
pub q_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
pub k_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
pub v_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
pub o_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
pub q_proj: GraphTensor, // Hidden -> hidden
pub k_proj: GraphTensor, // Proj dim -> hidden
pub v_proj: GraphTensor, // Proj dim -> hidden
pub o_proj: GraphTensor, // Hidden -> hidden
}
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
Module<(
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, PrevSeq>,
PhantomData<TotSeq>,
)> for SelfAttention
{
type Output = (
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, TotSeq>,
);
fn forward(
&self,
(x, (k_cache, v_cache), _): (
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, PrevSeq>,
PhantomData<TotSeq>,
),
) -> Self::Output {
impl Module<(GraphTensor, KVCache)> for SelfAttention {
type Output = (GraphTensor, KVCache);
fn forward(&self, (x, (k_cache, v_cache)): (GraphTensor, KVCache)) -> Self::Output {
// x: batch, seq, hidden
// cache: batch, kv_heads, prev_seq, head_dim
let (batch, seq, _) = x.dims3();
let (_, _, prev_seq, _) = k_cache.dims4();
// Apply the Projections
let queries = x
.matmul(self.q_proj.permute())
.reshape::<(Batch, CurSeq, Const<N_HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
.matmul(self.q_proj.permute((1, 0)))
.reshape((batch, seq, N_HEADS, HEAD_DIM))
.permute((0, 2, 1, 3));
let keys = x
.matmul(self.k_proj.permute())
.reshape::<(Batch, CurSeq, Const<N_HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
.matmul(self.k_proj.permute((1, 0)))
.reshape((batch, seq, N_HEADS, HEAD_DIM))
.permute((0, 2, 1, 3));
let values = x
.matmul(self.v_proj.permute())
.reshape::<(Batch, CurSeq, Const<N_HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
.matmul(self.v_proj.permute((1, 0)))
.reshape((batch, seq, N_HEADS, HEAD_DIM))
.permute((0, 2, 1, 3));
// Rotary embed queries and keys
let queries = apply_rotary_embeddings_ggml(queries, PrevSeq::size().into());
let keys = apply_rotary_embeddings_ggml(keys, PrevSeq::size().into());
let queries = apply_rotary_embeddings_ggml(queries, prev_seq.big());
let keys = apply_rotary_embeddings_ggml(keys, prev_seq.big());
// Add KV cache
let keys = k_cache.concat_along::<_, Axis<2>, _>(keys);
let values = v_cache.concat_along::<_, Axis<2>, _>(values);
let keys = k_cache.concat_along(keys, 2);
let values = v_cache.concat_along(values, 2);
// Calculate attention weights
let mut attention_weights = queries.matmul(keys.permute()) / (HEAD_DIM as f32).sqrt();
let mut attention_weights =
queries.matmul(keys.permute((0, 1, 3, 2))) / (HEAD_DIM as f32).sqrt();
let attention_mask = self.k_proj.graph().triu::<CurSeq>(1) * f16::MIN.to_f32();
let attention_mask = self.k_proj.graph().triu(seq, 1) * f16::MIN.to_f32();
attention_weights += attention_mask
.pad::<(CurSeq, TotSeq)>(((0, 0), (TotSeq::size() - CurSeq::size(), 0)))
.expand();
.pad(((0, 0), (prev_seq, 0)))
.expand(0, batch)
.expand(1, N_HEADS);
// Calculate final outputs
let output = attention_weights
.softmax::<Axis<3>>()
.softmax(3)
// Apply distribution to values
.matmul(values)
// Merge heads
.permute::<_, Axes4<0, 2, 1, 3>>()
.reshape::<(Batch, CurSeq, Const<HIDDEN_DIM>)>();
.permute((0, 2, 1, 3))
.reshape((batch, seq, HIDDEN_DIM));
let output = output
// Apply output projection
.matmul(self.o_proj.permute());
.matmul(self.o_proj.permute((1, 0)));
(output, (keys.contiguous(), values.contiguous())) // Cache needs to be contiguous for transferring to another graph
}
}
impl InitModule for SelfAttention {
fn initialize(cx: &mut Graph) -> Self {
impl SelfAttention {
pub fn new(cx: &mut Graph) -> Self {
Self {
q_proj: cx.named_tensor("Q Proj"),
k_proj: cx.named_tensor("K Proj"),
v_proj: cx.named_tensor("V Proj"),
o_proj: cx.named_tensor("O Proj"),
q_proj: cx.named_tensor("Q Proj", (HIDDEN_DIM, HIDDEN_DIM)),
k_proj: cx.named_tensor("K Proj", (ATTN_PROJ_DIM, HIDDEN_DIM)),
v_proj: cx.named_tensor("V Proj", (ATTN_PROJ_DIM, HIDDEN_DIM)),
o_proj: cx.named_tensor("O Proj", (HIDDEN_DIM, HIDDEN_DIM)),
}
}
}
@@ -187,54 +154,37 @@ impl SerializeModule for SelfAttention {
pub struct TransformerBlock {
pub attention: SelfAttention,
pub attention_norm: LayerNorm<HIDDEN_DIM>,
pub feed_forward: Mlp<MLP_DIM, HIDDEN_DIM>,
pub feed_forward_norm: LayerNorm<HIDDEN_DIM>,
pub attention_norm: LayerNorm,
pub feed_forward: Mlp,
pub feed_forward_norm: LayerNorm,
}
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
Module<(
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, PrevSeq>,
PhantomData<TotSeq>,
)> for TransformerBlock
{
type Output = (
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, TotSeq>,
);
fn forward(
&self,
(mut x, cache, _): (
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
KVCache<Batch, PrevSeq>,
PhantomData<TotSeq>,
),
) -> Self::Output {
impl Module<(GraphTensor, KVCache)> for TransformerBlock {
type Output = (GraphTensor, KVCache);
fn forward(&self, (mut x, cache): (GraphTensor, KVCache)) -> Self::Output {
// Attention
let normed = self.attention_norm.forward(x);
let (y, cache) = self
.attention
.forward((normed, cache, PhantomData::<TotSeq>));
.forward((self.attention_norm.forward(x), cache));
// Residual Addition
// Residual
x += y;
// Feed Forward
let y = self.feed_forward.forward(self.feed_forward_norm.forward(x));
// Residual Addition
// Residual
(x + y, cache)
}
}
impl InitModule for TransformerBlock {
fn initialize(cx: &mut Graph) -> Self {
impl TransformerBlock {
pub fn new(cx: &mut Graph) -> Self {
Self {
attention: InitModule::initialize(cx),
attention_norm: LayerNorm::new(true, false, false, 1e-5, cx),
feed_forward: InitModule::initialize(cx),
feed_forward_norm: LayerNorm::new(true, false, false, 1e-5, cx),
attention: SelfAttention::new(cx),
attention_norm: LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
feed_forward: Mlp::new(HIDDEN_DIM, MLP_DIM, cx),
feed_forward_norm: LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
}
}
}
@@ -250,34 +200,16 @@ impl SerializeModule for TransformerBlock {
pub struct Phi {
// Token embeddings
pub embedding: Embedding<VOCAB_SIZE, HIDDEN_DIM>,
pub embedding: Embedding,
// Transformer layers
pub layers: Vec<TransformerBlock>,
// Final Norm layer
pub norm: LayerNorm<HIDDEN_DIM>,
// LM Head Layer
pub lm_head: PermutedLinear<HIDDEN_DIM, VOCAB_SIZE>,
// Norm + LM head
pub head: (LayerNorm, Linear),
}
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
Module<(
GraphTensor<(Batch, CurSeq)>,
&[KVCache<Batch, PrevSeq>],
PhantomData<TotSeq>,
)> for Phi
{
type Output = (
GraphTensor<(Batch, CurSeq, Const<VOCAB_SIZE>)>,
Vec<KVCache<Batch, TotSeq>>,
);
fn forward(
&self,
(input, cache, _): (
GraphTensor<(Batch, CurSeq)>,
&[KVCache<Batch, PrevSeq>],
PhantomData<TotSeq>,
),
) -> Self::Output {
impl Module<(GraphTensor, &[KVCache])> for Phi {
type Output = (GraphTensor, Vec<KVCache>);
fn forward(&self, (input, cache): (GraphTensor, &[KVCache])) -> Self::Output {
// Embed tokens
let mut x = self.embedding.forward(input);
@@ -285,29 +217,23 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
let mut new_caches = vec![];
let mut new_cache;
for (i, layer) in self.layers.iter().enumerate() {
(x, new_cache) = layer.forward((x, cache[i], PhantomData::<TotSeq>));
(x, new_cache) = layer.forward((x, cache[i]));
new_caches.push(new_cache);
}
// Run through last norm and output projection
let output = self.lm_head.forward(self.norm.forward(x));
(output, new_caches)
(self.head.forward(x), new_caches)
}
}
impl InitModule for Phi {
fn initialize(cx: &mut Graph) -> Self {
impl Phi {
pub fn new(cx: &mut Graph) -> Self {
Self {
embedding: Embedding {
weight: cx.named_tensor("Embedding Weight"),
},
norm: LayerNorm::new(true, false, false, 1e-5, cx),
lm_head: PermutedLinear {
weight: cx.tensor(),
},
layers: (0..NUM_LAYERS)
.map(|_| InitModule::initialize(cx))
.collect(),
embedding: Embedding::new(VOCAB_SIZE, HIDDEN_DIM, cx),
head: (
LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
Linear::new_permuted(HIDDEN_DIM, VOCAB_SIZE, false, cx),
),
layers: (0..NUM_LAYERS).map(|_| TransformerBlock::new(cx)).collect(),
}
}
}
@@ -315,8 +241,8 @@ impl InitModule for Phi {
impl SerializeModule for Phi {
fn serialize(&self, s: &mut Serializer) {
s.module("token_embd", &self.embedding);
s.module("output_norm", &self.norm);
s.module("output", &self.lm_head);
s.module("output_norm", &self.head.0);
s.module("output", &self.head.1);
for (i, layer) in self.layers.iter().enumerate() {
s.module(&format!("blk/{i}"), layer);
}

View File

@@ -5,9 +5,9 @@ fn main() {
// Create a new graph
let mut cx = Graph::new();
// Randomly initialize a linear layer with an input size of 4 and an output size of 5
let model = Linear::<4, 5>::initialize(&mut cx);
let model = Linear::new(4, 5, false, &mut cx).initialize();
// Make an input tensor
let a = cx.tensor::<R1<4>>().set(vec![1., 2., 3., 4.]);
let a = cx.tensor(4).set(vec![1., 2., 3., 4.]);
// Feed tensor through model
let b = model.forward(a).retrieve();

View File

@@ -10,9 +10,15 @@ use rand::{rngs::ThreadRng, thread_rng, Rng};
fn main() {
// Setup gradient graph
let mut cx = Graph::new();
let model = <(Linear<8, 16>, Swish, Linear<16, 16>, Swish, Linear<16, 5>)>::initialize(&mut cx);
let mut input = cx.tensor::<R1<8>>();
let mut target = cx.tensor::<R1<5>>();
let model = (
Linear::new(8, 16, false, &mut cx).initialize(),
Swish,
Linear::new(16, 16, false, &mut cx).initialize(),
Swish,
Linear::new(16, 5, false, &mut cx).initialize(),
);
let mut input = cx.tensor(8);
let mut target = cx.tensor(5);
let mut output = model.forward(input).retrieve();
let mut loss = mse_loss(output, target).retrieve();
@@ -38,7 +44,10 @@ fn main() {
#[cfg(feature = "metal")]
cx.compile(
luminal_metal::MetalCompiler::<f32>::default(),
(
GenericCompiler::default(),
luminal_metal::MetalCompiler::<f32>::default(),
),
(
&mut input,
&mut target,

View File

@@ -4,7 +4,7 @@ use std::{io::Write, marker::PhantomData};
use itertools::Itertools;
// WIP
use luminal::prelude::*;
use model::KVCache;
use model::{KVCache, D_MODEL, HEADS, HEAD_DIM, N_MEL_BINS};
use tokenizers::Tokenizer;
mod audio;
@@ -20,33 +20,30 @@ fn main() {
// Construct encoder graph
let mut enc_cx = Graph::new();
let encoder = model::AudioEncoder::initialize(&mut enc_cx);
let encoder = model::AudioEncoder::new(&mut enc_cx);
let mut encoder_params = params(&encoder);
enc_cx.keep_tensors(&encoder_params);
let mut audio_input = enc_cx.tensor::<(Const<1>, Const<{ model::N_MEL_BINS }>, Dyn<'s'>)>();
let mut encoded: GraphTensor<(Const<1>, Dyn<'d'>, Const<384>)> = encoder
.forward((audio_input, PhantomData::<Dyn<'d'>>))
.keep();
let mut audio_input = enc_cx.tensor((1, N_MEL_BINS, 's'));
let mut encoded = encoder.forward(audio_input).keep();
loader::load("setup/whisper-tiny.safetensors", &encoder, &mut enc_cx);
// Construct decoder graph
let mut dec_cx = Graph::new();
let decoder = model::TextDecoder::initialize(&mut dec_cx);
let decoder = model::TextDecoder::new(&mut dec_cx);
let mut decoder_params = params(&decoder);
dec_cx.keep_tensors(&decoder_params);
let mut text_input = dec_cx.tensor::<(Const<1>, Dyn<'s'>)>();
let mut encoder_output =
dec_cx.named_tensor::<(Const<1>, Dyn<'e'>, Const<{ model::D_MODEL }>)>("Enc Output");
let mut cache_src: Vec<KVCache<Const<1>, Dyn<'p'>>> = (0..model::DEC_LAYERS)
.map(|_| (dec_cx.named_tensor("Keys"), dec_cx.named_tensor("Values")))
.collect();
cache_src.set_dyn(vec![], &[1, 6, 64, 0]);
let (logits, _, mut cache_dest) = decoder.forward((
encoder_output,
text_input,
&cache_src,
PhantomData::<Dyn<'t'>>,
));
let mut text_input = dec_cx.tensor((1, 's'));
let mut encoder_output = dec_cx.named_tensor("Enc Output", (1, 'e', D_MODEL));
let mut cache_src = (0..model::DEC_LAYERS)
.map(|_| {
(
dec_cx.named_tensor("Keys", (1, HEADS, HEAD_DIM, 'p')),
dec_cx.named_tensor("Values", (1, HEADS, 'p', HEAD_DIM)),
)
})
.collect::<Vec<_>>();
cache_src.set_dyn(vec![], (1, 6, 64, 0));
let (logits, _, mut cache_dest) = decoder.forward((encoder_output, text_input, &cache_src));
let mut logits = logits
.slice((.., Expression::from('s') - 1.., ..))
.retrieve();
@@ -103,11 +100,11 @@ fn main() {
print!("Loading weights");
std::io::stdout().flush().unwrap();
let now = std::time::Instant::now();
audio_input.set_dyn(vec![0.; 160], &[1, 80, 2]);
audio_input.set_dyn(vec![0.; 160], (1, 80, 2));
enc_cx.set_dyn_dim('d', 1);
enc_cx.execute();
delete_inputs(downstream(encoder_params, &enc_cx), &mut enc_cx);
text_input.set_dyn(vec![0.], &[1, 1]);
text_input.set_dyn(vec![0.], (1, 1));
dec_cx.set_dyn_dim('e', 1);
dec_cx.set_dyn_dim('p', 0);
dec_cx.set_dyn_dim('t', 1);
@@ -132,7 +129,7 @@ fn main() {
std::io::stdout().flush().unwrap();
let start_encoding = std::time::Instant::now();
audio_input.set_dyn(mel, &[1, 80, mel_len / 80]);
audio_input.set_dyn(mel, (1, 80, mel_len / 80));
enc_cx.set_dyn_dim('d', (mel_len / 80) / 2);
enc_cx.execute();
transfer_data(encoded, &mut enc_cx, encoder_output, &mut dec_cx);
@@ -144,7 +141,7 @@ fn main() {
dec_cx.set_dyn_dim('p', 0);
dec_cx.set_dyn_dim('t', 3);
let mut output_ids = vec![];
text_input.set_dyn(vec![50257., 50358., 50362.], &[1, 3]);
text_input.set_dyn(vec![50257., 50358., 50362.], (1, 3));
dec_cx.execute();
let mut output_token = argmax(&logits.data());
logits.drop();
@@ -156,7 +153,7 @@ fn main() {
for i in 0..100 {
transfer_data_same_graph(&cache_dest, &cache_src, &mut dec_cx);
text_input.set_dyn(vec![output_token as f32], &[1, 1]);
text_input.set_dyn(vec![output_token as f32], (1, 1));
dec_cx.set_dyn_dim('p', i + 3);
dec_cx.set_dyn_dim('t', i + 4);
dec_cx.execute();

View File

@@ -1,5 +1,5 @@
use luminal::prelude::{binary::F32Pow, *};
use luminal_nn::{Conv1D, Embedding, LayerNorm, Linear, PermutedEmbedding, PermutedLinear};
use luminal_nn::{Conv1D, Embedding, LayerNorm, Linear};
use std::marker::PhantomData;
use std::ops::{Add, Mul};
@@ -41,172 +41,150 @@ pub const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
pub const EOT_TOKEN: &str = "<|endoftext|>";
pub const NO_SPEECH_TOKENS: [&str; 2] = ["<|nocaptions|>", "<|nospeech|>"];
pub type KVCache<Batch, Seq> = (
GraphTensor<(Batch, Const<HEADS>, Const<HEAD_DIM>, Seq)>,
GraphTensor<(Batch, Const<HEADS>, Seq, Const<HEAD_DIM>)>,
);
pub type KVCache = (GraphTensor, GraphTensor);
pub struct SelfAttention<const HIDDEN: usize> {
pub q_proj: GraphTensor<R2<HIDDEN, HIDDEN>>,
pub q_proj_bias: GraphTensor<R1<HIDDEN>>,
pub k_proj: GraphTensor<R2<HIDDEN, HIDDEN>>,
pub v_proj: GraphTensor<R2<HIDDEN, HIDDEN>>,
pub v_proj_bias: GraphTensor<R1<HIDDEN>>,
pub o_proj: GraphTensor<R2<HIDDEN, HIDDEN>>,
pub o_proj_bias: GraphTensor<R1<HIDDEN>>,
pub struct SelfAttention {
pub q_proj: GraphTensor, // hidden x hidden
pub q_proj_bias: GraphTensor, // hidden
pub k_proj: GraphTensor, // hidden x hidden
pub v_proj: GraphTensor, // hidden x hidden
pub v_proj_bias: GraphTensor, // hidden
pub o_proj: GraphTensor, // hidden x hidden
pub o_proj_bias: GraphTensor, // hidden
}
impl<
const HIDDEN: usize,
Batch: Dimension,
CurSeq: Dimension,
PrevSeq: Dimension,
TotSeq: Dimension,
>
Module<(
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
Option<KVCache<Batch, PrevSeq>>,
bool,
PhantomData<TotSeq>,
)> for SelfAttention<HIDDEN>
{
type Output = (
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
KVCache<Batch, TotSeq>,
);
fn forward(
&self,
(x, cache, mask, _): (
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
Option<KVCache<Batch, PrevSeq>>,
bool,
PhantomData<TotSeq>,
),
) -> Self::Output {
let scale = ((HIDDEN as f32 / HEADS as f32) as f64).powf(-0.25) as f32;
impl Module<(GraphTensor, Option<KVCache>, bool)> for SelfAttention {
type Output = (GraphTensor, KVCache);
fn forward(&self, (x, cache, mask): (GraphTensor, Option<KVCache>, bool)) -> Self::Output {
// x: batch, seq, hidden
let (batch, seq, hidden) = x.dims3();
let scale = ((hidden.to_usize().unwrap() as f32 / HEADS as f32) as f64).powf(-0.25) as f32;
// Apply the Projections
let queries = x
.matmul(self.q_proj.permute())
.add(self.q_proj_bias.expand())
.matmul(self.q_proj.permute((1, 0)))
.add(self.q_proj_bias.expand(0, batch).expand(1, seq))
.mul(scale)
.reshape::<(Batch, CurSeq, Const<HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
.reshape((batch, seq, HEADS, HEAD_DIM))
.permute((0, 2, 1, 3));
let keys = x
.matmul(self.k_proj.permute())
.matmul(self.k_proj.permute((1, 0)))
.mul(scale)
.reshape::<(Batch, CurSeq, Const<HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 3, 1>>()
.reshape((batch, seq, HEADS, HEAD_DIM))
.permute((0, 2, 3, 1))
.contiguous();
let values = x
.matmul(self.v_proj.permute())
.add(self.v_proj_bias.expand())
.reshape::<(Batch, CurSeq, Const<HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
.matmul(self.v_proj.permute((1, 0)))
.add(self.v_proj_bias.expand(0, batch).expand(1, seq))
.reshape((batch, seq, HEADS, HEAD_DIM))
.permute((0, 2, 1, 3));
// Add KV cache
let (keys, values) = if let Some((k_cache, v_cache)) = cache {
(
k_cache.concat_along::<_, Axis<3>, _>(keys),
v_cache.concat_along::<_, Axis<2>, _>(values),
k_cache.concat_along(keys, 3),
v_cache.concat_along(values, 2),
)
} else {
(keys.realize(), values.realize())
(keys, values)
};
// Calculate attention weights
let mut attention_weights = queries.matmul(keys);
if mask {
let attention_mask = self.k_proj.graph().triu::<CurSeq>(1) * f16::MIN.to_f32();
attention_weights += attention_mask
.pad::<(CurSeq, TotSeq)>(((0, 0), (TotSeq::size() - CurSeq::size(), 0)))
.expand();
let mut attention_mask = self.k_proj.graph().triu(seq, 1) * f16::MIN.to_f32();
if let Some((c, _)) = cache {
let (_, _, prev_seq, _) = c.dims4();
attention_mask = attention_mask.pad(((0, 0), (prev_seq, 0)));
}
attention_weights += attention_mask.expand(0, batch).expand(1, HEADS);
}
// Calculate final outputs
let output = attention_weights
.softmax::<Axis<3>>()
.softmax(3)
// Apply distribution to values
.matmul(values)
// Merge heads
.permute::<_, Axes4<0, 2, 1, 3>>()
.reshape::<(Batch, CurSeq, Const<HIDDEN>)>();
.permute((0, 2, 1, 3))
.reshape((batch, seq, hidden));
let output = output
// Apply output projection
.matmul(self.o_proj.permute())
.add(self.o_proj_bias.expand());
.matmul(self.o_proj.permute((1, 0)))
.add(self.o_proj_bias.expand(0, batch).expand(1, seq));
(output, (keys.contiguous(), values.contiguous())) // Cache needs to be contiguous
}
}
impl<const HIDDEN: usize> SelfAttention<HIDDEN> {
#[allow(clippy::type_complexity)]
fn cross_attention_forward<Batch: Dimension, EncSeq: Dimension, DecSeq: Dimension>(
impl SelfAttention {
fn cross_attention_forward(
&self,
queries: GraphTensor<(Batch, DecSeq, Const<HIDDEN>)>,
keys: GraphTensor<(Batch, EncSeq, Const<HIDDEN>)>,
values: GraphTensor<(Batch, EncSeq, Const<HIDDEN>)>,
queries: GraphTensor, // batch, dec_seq, hidden
keys: GraphTensor, // batch, enc_seq, hidden
values: GraphTensor, // batch, enc_seq, hidden
) -> (
GraphTensor<(Batch, DecSeq, Const<HIDDEN>)>,
KVCache<Batch, EncSeq>,
GraphTensor, // batch, dec_seq, hidden
KVCache,
) {
let scale = ((HIDDEN as f32 / HEADS as f32) as f64).powf(-0.25) as f32;
let (batch, enc_seq, hidden) = keys.dims3();
let (_, dec_seq, hidden) = queries.dims3();
let scale = ((hidden.to_usize().unwrap() as f32 / HEADS as f32) as f64).powf(-0.25) as f32;
// Apply the projections
let queries = queries
.matmul(self.q_proj.permute())
.add(self.q_proj_bias.expand())
.matmul(self.q_proj.permute((1, 0)))
.add(self.q_proj_bias.expand(0, batch).expand(1, dec_seq))
.mul(scale)
.reshape::<(Batch, DecSeq, Const<HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
.reshape((batch, dec_seq, HEADS, HEAD_DIM))
.permute((0, 2, 1, 3));
let keys = keys
.matmul(self.k_proj.permute())
.matmul(self.k_proj.permute((1, 0)))
.mul(scale)
.reshape::<(Batch, EncSeq, Const<HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 3, 1>>()
.reshape((batch, enc_seq, HEADS, HEAD_DIM))
.permute((0, 2, 3, 1))
.contiguous();
let values = values
.matmul(self.v_proj.permute())
.add(self.v_proj_bias.expand())
.reshape::<(Batch, EncSeq, Const<HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();
.matmul(self.v_proj.permute((1, 0)))
.add(self.v_proj_bias.expand(0, batch).expand(1, enc_seq))
.reshape((batch, enc_seq, HEADS, HEAD_DIM))
.permute((0, 2, 1, 3));
// Calculate attention weights
let mut attention_weights = queries.matmul(keys);
// Calculate final outputs
let output = attention_weights
.softmax::<Axis<3>>()
.softmax(3)
// Apply distribution to values
.matmul(values)
// Merge heads
.permute::<_, Axes4<0, 2, 1, 3>>()
.reshape::<(Batch, DecSeq, Const<HIDDEN>)>();
.permute((0, 2, 1, 3))
.reshape((batch, dec_seq, hidden));
let output = output
// Apply output projection
.matmul(self.o_proj.permute())
.add(self.o_proj_bias.expand());
.matmul(self.o_proj.permute((1, 0)))
.add(self.o_proj_bias.expand(0, batch).expand(1, dec_seq));
(output, (keys, values))
}
}
impl<const HIDDEN: usize> InitModule for SelfAttention<HIDDEN> {
fn initialize(cx: &mut Graph) -> Self {
impl SelfAttention {
fn new(hidden: usize, cx: &mut Graph) -> Self {
Self {
q_proj: cx.named_tensor("Q Proj"),
q_proj_bias: cx.named_tensor("Q Proj Bias"),
k_proj: cx.named_tensor("K Proj"),
v_proj: cx.named_tensor("V Proj"),
v_proj_bias: cx.named_tensor("V Proj Bias"),
o_proj: cx.named_tensor("O Proj"),
o_proj_bias: cx.named_tensor("O Proj Bias"),
q_proj: cx.named_tensor("Q Proj", (hidden, hidden)),
q_proj_bias: cx.named_tensor("Q Proj Bias", hidden),
k_proj: cx.named_tensor("K Proj", (hidden, hidden)),
v_proj: cx.named_tensor("V Proj", (hidden, hidden)),
v_proj_bias: cx.named_tensor("V Proj Bias", hidden),
o_proj: cx.named_tensor("O Proj", (hidden, hidden)),
o_proj_bias: cx.named_tensor("O Proj Bias", hidden),
}
}
}
impl<const HIDDEN: usize> SerializeModule for SelfAttention<HIDDEN> {
impl SerializeModule for SelfAttention {
fn serialize(&self, s: &mut Serializer) {
s.tensor("q_proj/weight", self.q_proj);
s.tensor("q_proj/bias", self.q_proj_bias);
@@ -219,54 +197,42 @@ impl<const HIDDEN: usize> SerializeModule for SelfAttention<HIDDEN> {
}
pub struct EncoderTransformerBlock {
pub attention: SelfAttention<D_MODEL>,
pub attention_norm: LayerNorm<D_MODEL>,
pub ff1: PermutedLinear<D_MODEL, ENC_FFN_DIM>,
pub ff1_bias: GraphTensor<R1<ENC_FFN_DIM>>,
pub ff2: PermutedLinear<ENC_FFN_DIM, D_MODEL>,
pub ff2_bias: GraphTensor<R1<D_MODEL>>,
pub feed_forward_norm: LayerNorm<D_MODEL>,
pub attention: SelfAttention,
pub attention_norm: LayerNorm,
pub ff1: Linear, // hidden -> enc_ffn_dim
pub ff2: Linear, // enc_ffn_dim -> hidden
pub feed_forward_norm: LayerNorm,
}
impl<Batch: Dimension, Seq: Dimension> Module<GraphTensor<(Batch, Seq, Const<D_MODEL>)>>
for EncoderTransformerBlock
{
type Output = GraphTensor<(Batch, Seq, Const<D_MODEL>)>;
fn forward(&self, mut x: GraphTensor<(Batch, Seq, Const<D_MODEL>)>) -> Self::Output {
impl Module<GraphTensor> for EncoderTransformerBlock {
type Output = GraphTensor;
fn forward(&self, mut x: GraphTensor) -> Self::Output {
let (batch, seq, _) = x.dims3();
// Attention
let (y, _) = self.attention.forward((
self.attention_norm.forward(x),
Option::<KVCache<Batch, Seq>>::None,
false,
PhantomData::<Seq>,
));
let (y, _) = self
.attention
.forward((self.attention_norm.forward(x), None, false));
// Residual Addition
x += y;
// Feed Forward
let y = self.ff1.forward(self.feed_forward_norm.forward(x)) + self.ff1_bias.expand();
let y = self.ff2.forward(y.gelu()) + self.ff2_bias.expand();
let y = self.ff1.forward(self.feed_forward_norm.forward(x));
let y = self.ff2.forward(y.gelu());
// Residual Addition
x + y
}
}
impl InitModule for EncoderTransformerBlock {
fn initialize(cx: &mut Graph) -> Self {
impl EncoderTransformerBlock {
fn new(hidden: usize, ff: usize, cx: &mut Graph) -> Self {
Self {
attention: InitModule::initialize(cx),
attention_norm: LayerNorm::new(true, true, true, 1e-5, cx),
ff1: PermutedLinear {
weight: cx.tensor(),
},
ff1_bias: cx.tensor(),
ff2: PermutedLinear {
weight: cx.tensor(),
},
ff2_bias: cx.tensor(),
feed_forward_norm: LayerNorm::new(true, true, true, 1e-5, cx),
attention: SelfAttention::new(hidden, cx),
attention_norm: LayerNorm::new(hidden, true, true, true, 1e-5, cx),
ff1: Linear::new_permuted(hidden, ff, true, cx),
ff2: Linear::new_permuted(ff, hidden, true, cx),
feed_forward_norm: LayerNorm::new(hidden, true, true, true, 1e-5, cx),
}
}
}
@@ -277,64 +243,46 @@ impl SerializeModule for EncoderTransformerBlock {
s.module("self_attn_layer_norm", &self.attention_norm);
s.module("final_layer_norm", &self.feed_forward_norm);
s.module("fc1", &self.ff1);
s.tensor("fc1/bias", self.ff1_bias);
s.module("fc2", &self.ff2);
s.tensor("fc2/bias", self.ff2_bias);
}
}
pub struct AudioEncoder {
// Conv layers (based on https://github.com/huggingface/candle/blob/59b18d974ec3cad6963b774aa245e23f8c80414f/candle-transformers/src/models/whisper/model.rs#L246)
pub conv1: Conv1D<N_MEL_BINS, D_MODEL, 3, 1, 0, 1>,
pub conv2: Conv1D<D_MODEL, D_MODEL, 3, 2, 0, 1>,
pub conv1: Conv1D,
pub conv2: Conv1D,
// Transformer layers
pub layers: Vec<EncoderTransformerBlock>,
// Post layer norm
pub post_ln: LayerNorm<D_MODEL>,
pub post_ln: LayerNorm,
}
fn sinusoids<const CHANNELS: usize, Length: Dimension>(
cx: &mut Graph,
) -> GraphTensor<(Length, Const<CHANNELS>)> {
fn sinusoids(channels: usize, length: Expression, cx: &mut Graph) -> GraphTensor {
let max_timescale = 10000f32;
let log_timescale_increment = max_timescale.ln() / (CHANNELS / 2 - 1) as f32;
let inv_timescales = (0..CHANNELS / 2)
let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
let inv_timescales = (0..channels / 2)
.map(|i| (i as f32 * (-log_timescale_increment)).exp())
.collect::<Vec<_>>();
let mut inv_timescales = cx
.tensor::<(Dyn<'-'>,)>()
.set_dyn(inv_timescales, &[CHANNELS / 2]);
inv_timescales.shape.dims[0] = (CHANNELS / 2).into();
let arange = cx.arange::<Length>();
.tensor(channels / 2)
.set_dyn(inv_timescales, (channels / 2));
let arange = cx.arange(length);
let mut mul_shape = arange.shape;
mul_shape.add_dim(1, CHANNELS / 2);
let scaled_time: GraphTensor<(Length, Dyn<'-'>)> =
arange.expand_to(mul_shape) * inv_timescales.expand_to(mul_shape);
scaled_time
.sin()
.concat_along::<_, Axis<1>, _>(scaled_time.cos())
mul_shape.add_dim(1, channels / 2);
let scaled_time = arange.expand_to(mul_shape) * inv_timescales.expand_to(mul_shape);
scaled_time.sin().concat_along(scaled_time.cos(), 1)
}
impl<Batch: Dimension, Seq: Dimension, SeqDivTwo: Dimension>
Module<(
GraphTensor<(Batch, Const<N_MEL_BINS>, Seq)>,
PhantomData<SeqDivTwo>,
)> for AudioEncoder
{
type Output = GraphTensor<(Batch, SeqDivTwo, Const<D_MODEL>)>;
fn forward(
&self,
(x, _): (
GraphTensor<(Batch, Const<N_MEL_BINS>, Seq)>,
PhantomData<SeqDivTwo>,
),
) -> Self::Output {
impl Module<GraphTensor> for AudioEncoder {
type Output = GraphTensor;
fn forward(&self, x: GraphTensor) -> Self::Output {
let (_, seq, _) = x.dims3();
// Conv layers
let x = self.conv1.forward((x, PhantomData::<Seq>)).gelu();
let x = self.conv2.forward((x, PhantomData::<SeqDivTwo>)).gelu();
let mut x = x.permute::<_, Axes3<0, 2, 1>>();
let x = self.conv1.forward(x).gelu();
let x = self.conv2.forward(x).gelu();
let mut x = x.permute((0, 2, 1));
// Sinusoidal positional embedding
x += sinusoids::<D_MODEL, SeqDivTwo>(x.graph()).expand();
x += sinusoids(D_MODEL, seq / 2, x.graph()).expand_to(x.shape);
// Transformer layers
let out = self.layers.forward(x);
@@ -343,15 +291,15 @@ impl<Batch: Dimension, Seq: Dimension, SeqDivTwo: Dimension>
}
}
impl InitModule for AudioEncoder {
fn initialize(cx: &mut Graph) -> Self {
impl AudioEncoder {
pub fn new(cx: &mut Graph) -> Self {
Self {
conv1: Conv1D::initialize_bias(cx),
conv2: Conv1D::initialize_bias(cx),
conv1: Conv1D::new(N_MEL_BINS, D_MODEL, 3, 1, 1, 1, true, cx),
conv2: Conv1D::new(D_MODEL, D_MODEL, 3, 2, 1, 1, true, cx),
layers: (0..ENC_LAYERS)
.map(|_| InitModule::initialize(cx))
.map(|_| EncoderTransformerBlock::new(D_MODEL, ENC_FFN_DIM, cx))
.collect(),
post_ln: LayerNorm::new(true, true, true, 1e-5, cx),
post_ln: LayerNorm::new(D_MODEL, true, true, true, 1e-5, cx),
}
}
}
@@ -368,52 +316,25 @@ impl SerializeModule for AudioEncoder {
}
pub struct DecoderTransformerBlock {
pub attention: SelfAttention<D_MODEL>,
pub attention_norm: LayerNorm<D_MODEL>,
pub cross_attention: SelfAttention<D_MODEL>,
pub cross_attention_norm: LayerNorm<D_MODEL>,
pub ff1: PermutedLinear<D_MODEL, DEC_FFN_DIM>,
pub ff1_bias: GraphTensor<R1<DEC_FFN_DIM>>,
pub ff2: PermutedLinear<DEC_FFN_DIM, D_MODEL>,
pub ff2_bias: GraphTensor<R1<D_MODEL>>,
pub feed_forward_norm: LayerNorm<D_MODEL>,
pub attention: SelfAttention,
pub attention_norm: LayerNorm,
pub cross_attention: SelfAttention,
pub cross_attention_norm: LayerNorm,
pub ff1: Linear,
pub ff2: Linear,
pub feed_forward_norm: LayerNorm,
}
impl<
Batch: Dimension,
EncSeq: Dimension,
CurSeq: Dimension,
PrevSeq: Dimension,
TotSeq: Dimension,
>
Module<(
GraphTensor<(Batch, CurSeq, Const<D_MODEL>)>,
GraphTensor<(Batch, EncSeq, Const<D_MODEL>)>,
KVCache<Batch, PrevSeq>,
PhantomData<TotSeq>,
)> for DecoderTransformerBlock
{
type Output = (
GraphTensor<(Batch, CurSeq, Const<D_MODEL>)>,
KVCache<Batch, EncSeq>,
KVCache<Batch, TotSeq>,
);
impl Module<(GraphTensor, GraphTensor, KVCache)> for DecoderTransformerBlock {
type Output = (GraphTensor, KVCache, KVCache);
fn forward(
&self,
(mut x, encoded, cache, _): (
GraphTensor<(Batch, CurSeq, Const<D_MODEL>)>,
GraphTensor<(Batch, EncSeq, Const<D_MODEL>)>,
KVCache<Batch, PrevSeq>,
PhantomData<TotSeq>,
),
(mut x, encoded, cache): (GraphTensor, GraphTensor, KVCache),
) -> Self::Output {
// Self Attention
let (y, cache) = self.attention.forward((
self.attention_norm.forward(x),
Some(cache),
true,
PhantomData::<TotSeq>,
));
let (y, cache) =
self.attention
.forward((self.attention_norm.forward(x), Some(cache), true));
// Residual Addition
x += y;
@@ -429,30 +350,24 @@ impl<
x += y;
// Feed Forward
let y = self.ff1.forward(self.feed_forward_norm.forward(x)) + self.ff1_bias.expand();
let y = self.ff2.forward(y.gelu()) + self.ff2_bias.expand();
let y = self.ff1.forward(self.feed_forward_norm.forward(x));
let y = self.ff2.forward(y.gelu());
// Residual Addition
(x + y, enc_states, cache)
}
}
impl InitModule for DecoderTransformerBlock {
fn initialize(cx: &mut Graph) -> Self {
impl DecoderTransformerBlock {
fn new(cx: &mut Graph) -> Self {
Self {
attention: InitModule::initialize(cx),
attention_norm: LayerNorm::new(true, true, true, 1e-5, cx),
cross_attention: InitModule::initialize(cx),
cross_attention_norm: LayerNorm::new(true, true, true, 1e-5, cx),
ff1: PermutedLinear {
weight: cx.tensor(),
},
ff1_bias: cx.tensor(),
ff2: PermutedLinear {
weight: cx.tensor(),
},
ff2_bias: cx.tensor(),
feed_forward_norm: LayerNorm::new(true, true, true, 1e-5, cx),
attention: SelfAttention::new(D_MODEL, cx),
attention_norm: LayerNorm::new(D_MODEL, true, true, true, 1e-5, cx),
cross_attention: SelfAttention::new(D_MODEL, cx),
cross_attention_norm: LayerNorm::new(D_MODEL, true, true, true, 1e-5, cx),
ff1: Linear::new_permuted(D_MODEL, DEC_FFN_DIM, true, cx),
ff2: Linear::new_permuted(DEC_FFN_DIM, D_MODEL, true, cx),
feed_forward_norm: LayerNorm::new(D_MODEL, true, true, true, 1e-5, cx),
}
}
}
@@ -464,93 +379,67 @@ impl SerializeModule for DecoderTransformerBlock {
s.module("encoder_attn", &self.cross_attention);
s.module("encoder_attn_layer_norm", &self.cross_attention_norm);
s.module("fc1", &self.ff1);
s.tensor("fc1/bias", self.ff1_bias);
s.module("fc2", &self.ff2);
s.tensor("fc2/bias", self.ff2_bias);
s.module("final_layer_norm", &self.feed_forward_norm);
}
}
pub struct TextDecoder {
// Embeddings
pub embedding: PermutedEmbedding<VOCAB_SIZE, D_MODEL>,
pub pos_embedding: GraphTensor<R2<MAX_TARGET_POSITION, D_MODEL>>,
pub embedding: Embedding,
pub pos_embedding: GraphTensor,
// Transformer layers
pub layers: Vec<DecoderTransformerBlock>,
// Final layer norm
pub layer_norm: LayerNorm<D_MODEL>,
pub layer_norm: LayerNorm,
}
impl<
Batch: Dimension,
EncSeq: Dimension,
PrevDecSeq: Dimension,
CurDecSeq: Dimension,
TotDecSeq: Dimension,
>
Module<(
GraphTensor<(Batch, EncSeq, Const<D_MODEL>)>,
GraphTensor<(Batch, CurDecSeq)>,
&[KVCache<Batch, PrevDecSeq>],
PhantomData<TotDecSeq>,
)> for TextDecoder
{
impl Module<(GraphTensor, GraphTensor, &[KVCache])> for TextDecoder {
type Output = (
GraphTensor<(Batch, CurDecSeq, Const<VOCAB_SIZE>)>,
Vec<KVCache<Batch, EncSeq>>, // Encoder projected states
Vec<KVCache<Batch, TotDecSeq>>, // Decoder KV cache
GraphTensor,
Vec<KVCache>, // Encoder projected states
Vec<KVCache>, // Decoder KV cache
);
fn forward(
&self,
(enc_output, input, cache, _): (
GraphTensor<(Batch, EncSeq, Const<D_MODEL>)>,
GraphTensor<(Batch, CurDecSeq)>,
&[KVCache<Batch, PrevDecSeq>],
PhantomData<TotDecSeq>,
),
(enc_output, input, cache): (GraphTensor, GraphTensor, &[KVCache]),
) -> Self::Output {
let (_, cur_dec_seq) = input.dims2();
let (_, _, prev_dec_seq, _) = cache[0].0.dims4();
// Embed text
let mut x = self.embedding.forward(input);
x += self
.pos_embedding
.slice((
PrevDecSeq::size()..CurDecSeq::size() + PrevDecSeq::size(),
..,
))
.slice((prev_dec_seq..cur_dec_seq + prev_dec_seq, ..))
.contiguous()
.realize::<(CurDecSeq, Const<D_MODEL>)>()
.expand();
.expand_to(x.shape);
// Run through layers and collect new caches
let (mut new_caches, mut enc_states) = (vec![], vec![]);
let (mut new_cache, mut enc_state);
for (i, layer) in self.layers.iter().enumerate() {
(x, enc_state, new_cache) =
layer.forward((x, enc_output, cache[i], PhantomData::<TotDecSeq>));
(x, enc_state, new_cache) = layer.forward((x, enc_output, cache[i]));
new_caches.push(new_cache);
enc_states.push(enc_state);
}
// Run through last norm and output projection
(
self.layer_norm.forward(x).matmul(
self.embedding
.weight
.realize::<R2<VOCAB_SIZE, D_MODEL>>()
.permute(),
),
self.layer_norm
.forward(x)
.matmul(self.embedding.weight.permute((1, 0))),
enc_states,
new_caches,
)
}
}
impl InitModule for TextDecoder {
fn initialize(cx: &mut Graph) -> Self {
impl TextDecoder {
pub fn new(cx: &mut Graph) -> Self {
Self {
embedding: InitModule::initialize(cx),
pos_embedding: cx.tensor(),
layer_norm: LayerNorm::new(true, true, true, 1e-5, cx),
embedding: Embedding::new_permuted(VOCAB_SIZE, D_MODEL, cx),
pos_embedding: cx.tensor((MAX_TARGET_POSITION, D_MODEL)),
layer_norm: LayerNorm::new(D_MODEL, true, true, true, 1e-5, cx),
layers: (0..DEC_LAYERS)
.map(|_| InitModule::initialize(cx))
.map(|_| DecoderTransformerBlock::new(cx))
.collect(),
}
}

View File

@@ -4,25 +4,19 @@ use luminal_nn::Conv2D;
struct Bottleneck<const CH_IN: usize, const CH_OUT: usize> {
cv1: Conv2D<CH_IN, CH_OUT, 3, 3, 1, 1>,
cv2: Conv2D<CH_OUT, CH_OUT, 3, 3, 1, 1>,
residual: bool,
}
impl<
const CH_IN: usize,
const CH_OUT: usize,
Batch: Dimension,
Width: Dimension,
Height: Dimension,
> Module<GraphTensor<(Batch, Const<CH_IN>, Height, Width)>> for Bottleneck<CH_IN, CH_OUT>
impl<const CH_IN: usize, const CH_OUT: usize, Width: Dimension, Height: Dimension>
Module<GraphTensor<(Const<CH_IN>, Height, Width)>> for Bottleneck<CH_IN, CH_OUT>
{
type Output = GraphTensor<(Batch, Const<CH_OUT>, Height, Width)>;
fn forward(&self, input: GraphTensor<(Batch, Const<CH_IN>, Height, Width)>) -> Self::Output {
let out = self.cv2.forward(self.cv1.forward(input));
if self.residual {
out + input
} else {
out
}
type Output = GraphTensor<(Const<CH_OUT>, Height, Width)>;
fn forward(&self, input: GraphTensor<(Const<CH_IN>, Height, Width)>) -> Self::Output {
self.cv2
.forward(
self.cv1
.forward::<Width, Height, Width, Height>(input.permute()),
)
.permute()
}
}

View File

@@ -0,0 +1 @@
{"rustc_fingerprint":7902964948476546481,"outputs":{"1185988223601034215":{"success":true,"status":"","code":0,"stdout":"___\nlib___.rlib\nlib___.dylib\nlib___.dylib\nlib___.a\nlib___.dylib\n/Users/jafioti/.rustup/toolchains/stable-aarch64-apple-darwin\noff\npacked\nunpacked\n___\nclippy\ndebug_assertions\nfeature=\"cargo-clippy\"\npanic=\"unwind\"\nproc_macro\ntarget_abi=\"\"\ntarget_arch=\"aarch64\"\ntarget_endian=\"little\"\ntarget_env=\"\"\ntarget_family=\"unix\"\ntarget_feature=\"aes\"\ntarget_feature=\"crc\"\ntarget_feature=\"dit\"\ntarget_feature=\"dotprod\"\ntarget_feature=\"dpb\"\ntarget_feature=\"dpb2\"\ntarget_feature=\"fcma\"\ntarget_feature=\"fhm\"\ntarget_feature=\"flagm\"\ntarget_feature=\"fp16\"\ntarget_feature=\"frintts\"\ntarget_feature=\"jsconv\"\ntarget_feature=\"lor\"\ntarget_feature=\"lse\"\ntarget_feature=\"neon\"\ntarget_feature=\"paca\"\ntarget_feature=\"pacg\"\ntarget_feature=\"pan\"\ntarget_feature=\"pmuv3\"\ntarget_feature=\"ras\"\ntarget_feature=\"rcpc\"\ntarget_feature=\"rcpc2\"\ntarget_feature=\"rdm\"\ntarget_feature=\"sb\"\ntarget_feature=\"sha2\"\ntarget_feature=\"sha3\"\ntarget_feature=\"ssbs\"\ntarget_feature=\"vh\"\ntarget_has_atomic=\"128\"\ntarget_has_atomic=\"16\"\ntarget_has_atomic=\"32\"\ntarget_has_atomic=\"64\"\ntarget_has_atomic=\"8\"\ntarget_has_atomic=\"ptr\"\ntarget_os=\"macos\"\ntarget_pointer_width=\"64\"\ntarget_vendor=\"apple\"\nunix\n","stderr":""},"4614504638168534921":{"success":true,"status":"","code":0,"stdout":"rustc 1.78.0 (9b00956e5 2024-04-29)\nbinary: rustc\ncommit-hash: 9b00956e56009bab2aa15d7bff10916599e3d6d6\ncommit-date: 2024-04-29\nhost: aarch64-apple-darwin\nrelease: 1.78.0\nLLVM version: 18.1.2\n","stderr":""}},"successes":{}}

View File

@@ -0,0 +1,3 @@
Signature: 8a477f597d28d172789f06886806bc55
# This file is a cache directory tag created by cargo.
# For information about cache directory tags see https://bford.info/cachedir/

View File

@@ -0,0 +1 @@
{"rustc":792111255936306319,"features":"[]","declared_features":"","target":2297296889237502566,"profile":1200860260873630964,"path":14319948826174377348,"deps":[[16079472387499994964,"version_check",false,11544101210379487998]],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/ahash-278451fb35d9bcf6/dep-build-script-build-script-build"}}],"rustflags":[],"metadata":6548036084630991988,"config":2202906307356721367,"compile_kind":0}

View File

@@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@@ -0,0 +1 @@
{"rustc":792111255936306319,"features":"","declared_features":"","target":0,"profile":0,"path":0,"deps":[[1385435641494999048,"build_script_build",false,13432309339295883258]],"local":[{"RerunIfChanged":{"output":"debug/build/ahash-8a7ee4fbe4c9c642/output","paths":["build.rs"]}}],"rustflags":[],"metadata":0,"config":0,"compile_kind":0}

View File

@@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@@ -0,0 +1 @@
281ba81841e89abc

View File

@@ -0,0 +1 @@
{"rustc":792111255936306319,"features":"[]","declared_features":"","target":295758560010665018,"profile":3797293754785534760,"path":12569880893325353741,"deps":[[1385435641494999048,"build_script_build",false,6021049035992531610],[4254328441789853856,"once_cell",false,11839418217726875033],[11228387426131597774,"getrandom",false,3943971895081466132]],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/ahash-fa8dba1057def52f/dep-lib-ahash"}}],"rustflags":[],"metadata":6548036084630991988,"config":2202906307356721367,"compile_kind":0}

View File

@@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@@ -0,0 +1 @@
{"rustc":792111255936306319,"features":"[\"default\", \"perf-literal\", \"std\"]","declared_features":"","target":12812136000324506373,"profile":3797293754785534760,"path":917586311715097687,"deps":[[554324495028472449,"memchr",false,12018462279246210472]],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/aho-corasick-5c94172d53874de6/dep-lib-aho_corasick"}}],"rustflags":[],"metadata":13904389431191498124,"config":2202906307356721367,"compile_kind":0}

View File

@@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@@ -0,0 +1 @@
cd11a3772bf6054d

View File

@@ -0,0 +1 @@
{"rustc":792111255936306319,"features":"[\"auto\", \"default\", \"wincon\"]","declared_features":"","target":16157420304466204941,"profile":14808885745152445323,"path":13571762662249164363,"deps":[[1200817279721127204,"anstyle_query",false,5233170534109397214],[8720183142424604966,"utf8parse",false,8030522589549628100],[12423115053093093635,"is_terminal_polyfill",false,14350458276302548626],[15873833695005184023,"colorchoice",false,1643507399289422721],[16609132864249042075,"anstyle_parse",false,12395459992756225163],[16999472572377377103,"anstyle",false,17572025060589194804]],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/anstream-1c56c6d6e6c3fc4a/dep-lib-anstream"}}],"rustflags":[],"metadata":7500874485387469444,"config":2202906307356721367,"compile_kind":0}

View File

@@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@@ -0,0 +1 @@
344eae28b15fdcf3

View File

@@ -0,0 +1 @@
{"rustc":792111255936306319,"features":"[\"default\", \"std\"]","declared_features":"","target":13663407036240438623,"profile":14808885745152445323,"path":15851815754449844370,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/anstyle-580976b7c8c459e7/dep-lib-anstyle"}}],"rustflags":[],"metadata":14064844656010464607,"config":2202906307356721367,"compile_kind":0}

View File

@@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@@ -0,0 +1 @@
{"rustc":792111255936306319,"features":"[\"default\", \"utf8\"]","declared_features":"","target":1993415851866499831,"profile":14808885745152445323,"path":16158110412718354121,"deps":[[8720183142424604966,"utf8parse",false,8030522589549628100]],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/anstyle-parse-c08eba01cda9523e/dep-lib-anstyle-parse"}}],"rustflags":[],"metadata":9799137552285937175,"config":2202906307356721367,"compile_kind":0}

View File

@@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@@ -0,0 +1 @@
{"rustc":792111255936306319,"features":"[]","declared_features":"","target":8921697713841910856,"profile":14808885745152445323,"path":17832198768418750173,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/anstyle-query-549d4ad21e8f4c08/dep-lib-anstyle-query"}}],"rustflags":[],"metadata":10674566383365303417,"config":2202906307356721367,"compile_kind":0}

View File

@@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@@ -0,0 +1 @@
536fba925d9106a8

View File

@@ -0,0 +1 @@
{"rustc":792111255936306319,"features":"[]","declared_features":"","target":5730234523381508605,"profile":3797293754785534760,"path":7658171823302188217,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/as-any-66dead5baaa2ce79/dep-lib-as-any"}}],"rustflags":[],"metadata":7688359999514647402,"config":2202906307356721367,"compile_kind":0}

View File

@@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@@ -0,0 +1 @@
5ff9718763cf4f55

View File

@@ -0,0 +1 @@
{"rustc":792111255936306319,"features":"[]","declared_features":"","target":14886237245231788030,"profile":1200860260873630964,"path":16811576535408259928,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/autocfg-0446d81e8a877634/dep-lib-autocfg"}}],"rustflags":[],"metadata":13102859075309379048,"config":2202906307356721367,"compile_kind":0}

View File

@@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@@ -0,0 +1 @@
699e6c6c8950ef05

View File

@@ -0,0 +1 @@
{"rustc":792111255936306319,"features":"[\"default\", \"std\"]","declared_features":"","target":16778825523953873731,"profile":3797293754785534760,"path":1249145684557563819,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/base64-ea033a9df982ece9/dep-lib-base64"}}],"rustflags":[],"metadata":13936919950537592407,"config":2202906307356721367,"compile_kind":0}

View File

@@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@@ -0,0 +1 @@
72015ae1027454db

View File

@@ -0,0 +1 @@
{"rustc":792111255936306319,"features":"[\"default\"]","declared_features":"","target":15712369643656012375,"profile":3797293754785534760,"path":12968319784194130589,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/bitflags-9f46fb88979eb9ab/dep-lib-bitflags"}}],"rustflags":[],"metadata":14564035643000669268,"config":2202906307356721367,"compile_kind":0}

View File

@@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@@ -0,0 +1 @@
5d13ff0869c86fc5

View File

@@ -0,0 +1 @@
{"rustc":792111255936306319,"features":"[\"default\", \"std\"]","declared_features":"","target":18335588937564793828,"profile":3797293754785534760,"path":16883050064718116216,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/byteorder-221dfff828506489/dep-lib-byteorder"}}],"rustflags":[],"metadata":5398730104718078656,"config":2202906307356721367,"compile_kind":0}

View File

@@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@@ -0,0 +1 @@
0a15848cbceaa920

View File

@@ -0,0 +1 @@
{"rustc":792111255936306319,"features":"[]","declared_features":"","target":15023190189141807623,"profile":1200860260873630964,"path":14622872267285471233,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/cc-a8cbe2cdd92a7374/dep-lib-cc"}}],"rustflags":[],"metadata":5862599371499774553,"config":2202906307356721367,"compile_kind":0}

View File

@@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@@ -0,0 +1 @@
0f5b677c9f3650eb

View File

@@ -0,0 +1 @@
{"rustc":792111255936306319,"features":"[]","declared_features":"","target":10623512480563079566,"profile":3797293754785534760,"path":16945596033860744345,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/cfg-if-45b3cfd05c61975a/dep-lib-cfg-if"}}],"rustflags":[],"metadata":8462187951337715540,"config":2202906307356721367,"compile_kind":0}

Some files were not shown because too many files have changed in this diff Show More