forked from Rust-related/luminal
Switched to runtime shapes
This commit is contained in:
@@ -145,8 +145,8 @@ mod tests {
|
||||
#[test]
|
||||
fn test_matmul() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<(Dyn<'M'>, Dyn<'K'>)>();
|
||||
let b = cx.tensor::<(Dyn<'K'>, Dyn<'N'>)>();
|
||||
let a = cx.tensor(('M', 'K'));
|
||||
let b = cx.tensor(('K', 'N'));
|
||||
let mut c = a.matmul(b).retrieve();
|
||||
|
||||
cx.compile(CPUCompiler::default(), &mut c);
|
||||
@@ -158,8 +158,8 @@ mod tests {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let a_data = random_vec_rng(m * k, &mut rng);
|
||||
let b_data = random_vec_rng(k * n, &mut rng);
|
||||
a.set_dyn(a_data.clone(), &[m, k]);
|
||||
b.set_dyn(b_data.clone(), &[k, n]);
|
||||
a.set_dyn(a_data.clone(), (m, k));
|
||||
b.set_dyn(b_data.clone(), (k, n));
|
||||
|
||||
cx.execute();
|
||||
|
||||
@@ -177,9 +177,9 @@ mod tests {
|
||||
#[test]
|
||||
fn test_cpu_matmul_2d_2() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R2<2, 3>>();
|
||||
let a = cx.tensor((2, 3));
|
||||
a.set(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
|
||||
let b = cx.tensor::<R2<3, 4>>();
|
||||
let b = cx.tensor((3, 4));
|
||||
b.set(vec![1., 2., 3., 1., 2., 3., 1., 2., 3., 1., 2., 3.]);
|
||||
let mut c = a.matmul(b).retrieve();
|
||||
|
||||
|
||||
@@ -530,11 +530,11 @@ mod tests {
|
||||
fn test_subtraction() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx
|
||||
.tensor::<R1<10>>()
|
||||
.tensor(10)
|
||||
.set(vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]);
|
||||
let b = cx.tensor::<R0>().set(vec![1.]);
|
||||
let mut c = (a - b.expand()).retrieve();
|
||||
let mut d = (-a + b.expand()).retrieve();
|
||||
let b = cx.tensor(()).set(vec![1.]);
|
||||
let mut c = (a - b.expand_to(a.shape)).retrieve();
|
||||
let mut d = (-a + b.expand_to(a.shape)).retrieve();
|
||||
|
||||
cx.execute();
|
||||
|
||||
|
||||
@@ -306,9 +306,9 @@ fn test_common_buffer() {
|
||||
|
||||
use crate::MetalCompiler;
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<5>>().set(random_vec(5)).keep();
|
||||
let b = cx.tensor::<R1<5>>().set(random_vec(5)).keep();
|
||||
let c = cx.tensor::<R1<5>>().set(random_vec(5)).keep();
|
||||
let a = cx.tensor(5).set(random_vec(5)).keep();
|
||||
let b = cx.tensor(5).set(random_vec(5)).keep();
|
||||
let c = cx.tensor(5).set(random_vec(5)).keep();
|
||||
let mut d = ((a + b) * c).retrieve();
|
||||
|
||||
cx.execute();
|
||||
|
||||
@@ -118,7 +118,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
|
||||
let mut subexpressions_b = graph
|
||||
.try_get_op::<FusedElementwiseOp<T>>(b)
|
||||
.map(|o| o.subexpressions.clone())
|
||||
.unwrap_or_else(|| vec![(expression_b.clone(), ShapeTracker::new(&[]))]);
|
||||
.unwrap_or_else(|| vec![(expression_b.clone(), ShapeTracker::default())]);
|
||||
let a_to_b_indexes = graph
|
||||
.edges_connecting(a, b)
|
||||
.map(|e| e.weight().as_data().unwrap().0 as usize)
|
||||
@@ -138,7 +138,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
|
||||
let mut subexpressions_a = graph
|
||||
.try_get_op::<FusedElementwiseOp<T>>(a)
|
||||
.map(|o| o.subexpressions.clone())
|
||||
.unwrap_or_else(|| vec![(expression_a.clone(), ShapeTracker::new(&[]))]);
|
||||
.unwrap_or_else(|| vec![(expression_a.clone(), ShapeTracker::default())]);
|
||||
subexpressions_a.last_mut().unwrap().1 = connecting_shape;
|
||||
// Re-reference b intermediates
|
||||
for i in (0..subexpressions_b.len()).rev() {
|
||||
@@ -296,7 +296,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
|
||||
kernel: None,
|
||||
dyn_map: &graph.dyn_map,
|
||||
dyn_chars: vec![],
|
||||
subexpressions: vec![(op_string, ShapeTracker::new(&[]))],
|
||||
subexpressions: vec![(op_string, ShapeTracker::default())],
|
||||
queue: queue.clone(),
|
||||
device: device.clone(),
|
||||
output_buffer_sizes,
|
||||
@@ -593,9 +593,8 @@ mod tests {
|
||||
prelude::{binary::F32Pow, *},
|
||||
tests::{assert_close, assert_close_precision, random_vec, random_vec_rng},
|
||||
};
|
||||
use luminal_nn::*;
|
||||
use luminal_nn::{LayerNorm, Linear};
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
use std::{marker::PhantomData, ops::Div};
|
||||
|
||||
use crate::MetalCompiler;
|
||||
|
||||
@@ -603,7 +602,7 @@ mod tests {
|
||||
fn test_fusion_simple() {
|
||||
let mut cx = Graph::new();
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let inp = cx.tensor::<R1<5>>().set(random_vec_rng(10, &mut rng));
|
||||
let inp = cx.tensor(5).set(random_vec_rng(10, &mut rng));
|
||||
let mut out = inp.exp2().cos().sqrt().retrieve();
|
||||
|
||||
cx.execute();
|
||||
@@ -619,8 +618,8 @@ mod tests {
|
||||
fn test_fusion_binary() {
|
||||
let mut cx = Graph::new();
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let a = cx.tensor::<R1<5>>().set(random_vec_rng(10, &mut rng));
|
||||
let b = cx.tensor::<R1<5>>().set(random_vec_rng(10, &mut rng));
|
||||
let a = cx.tensor(5).set(random_vec_rng(10, &mut rng));
|
||||
let b = cx.tensor(5).set(random_vec_rng(10, &mut rng));
|
||||
let mut out = (a.exp2() + b.cos()).retrieve();
|
||||
|
||||
cx.execute();
|
||||
@@ -636,9 +635,9 @@ mod tests {
|
||||
#[test]
|
||||
fn test_fusion_subexpression_complex() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.named_tensor::<R1<10>>("a").set(random_vec(10)).keep();
|
||||
let b = cx.named_tensor::<R1<10>>("b").set(random_vec(10)).keep();
|
||||
let d = cx.named_tensor::<R1<10>>("d").set(random_vec(10)).keep();
|
||||
let a = cx.named_tensor("a", 10).set(random_vec(10)).keep();
|
||||
let b = cx.named_tensor("b", 10).set(random_vec(10)).keep();
|
||||
let d = cx.named_tensor("d", 10).set(random_vec(10)).keep();
|
||||
let mut out = ((a.exp2() - b.sin()).sin() * 3.4).less_than(d).retrieve();
|
||||
|
||||
cx.execute();
|
||||
@@ -656,12 +655,11 @@ mod tests {
|
||||
let mut cx = Graph::new();
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let inp = random_vec_rng(10, &mut rng);
|
||||
let a = cx.named_tensor::<R2<2, 5>>("a").set(inp);
|
||||
let a = cx.named_tensor("a", (2, 5)).set(inp);
|
||||
let mut padded = a
|
||||
.slice((..Expression::from(1), ..))
|
||||
.realize::<R2<1, 5>>()
|
||||
.cos()
|
||||
.pad::<R2<2, 5>>(((0, 1), (0, 0)))
|
||||
.pad(((0, 1), (0, 0)))
|
||||
.exp2()
|
||||
.retrieve();
|
||||
cx.execute();
|
||||
@@ -682,7 +680,7 @@ mod tests {
|
||||
let mut cx = Graph::new();
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let data = random_vec_rng(10, &mut rng);
|
||||
let a = cx.tensor::<R2<2, 5>>().set(data);
|
||||
let a = cx.tensor((2, 5)).set(data);
|
||||
let mut out = (a.sqrt().exp() + a.sqrt().sin()).retrieve();
|
||||
cx.execute();
|
||||
let unopt_out = out.data();
|
||||
@@ -699,14 +697,10 @@ mod tests {
|
||||
let mut cx = Graph::new();
|
||||
const SEQ: usize = 2;
|
||||
const HEAD_DIM: usize = 4;
|
||||
const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2;
|
||||
let freqs = (cx.arange::<Const<HEAD_DIM_OVER_2>>() * 2.0) / (HEAD_DIM as f32);
|
||||
let freqs = (cx.arange(HEAD_DIM / 2) * 2.0) / (HEAD_DIM as f32);
|
||||
let freqs = 1000000_f32.pow(freqs);
|
||||
let pos = cx.arange::<Const<SEQ>>() + BigExpression::from(0);
|
||||
let mut emb = pos
|
||||
.expand::<(_, Const<1>), _>()
|
||||
.matmul(freqs.expand())
|
||||
.retrieve();
|
||||
let pos = cx.arange(SEQ) + BigExpression::from(0);
|
||||
let mut emb = pos.expand(1, 1).matmul(freqs.expand(0, SEQ)).retrieve();
|
||||
|
||||
cx.execute();
|
||||
let unopt_out = emb.data();
|
||||
@@ -725,27 +719,25 @@ mod tests {
|
||||
const HEAD_DIM: usize = 4;
|
||||
const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2;
|
||||
let a = cx
|
||||
.named_tensor::<R2<SEQ, HEAD_DIM>>("a")
|
||||
.named_tensor("a", (SEQ, HEAD_DIM))
|
||||
.set(random_vec_rng(SEQ * HEAD_DIM, &mut rng))
|
||||
.keep();
|
||||
let b = cx
|
||||
.tensor::<R3<SEQ, HEAD_DIM_OVER_2, 1>>()
|
||||
.tensor((SEQ, HEAD_DIM / 2, 1))
|
||||
.set(random_vec_rng(SEQ * HEAD_DIM_OVER_2, &mut rng))
|
||||
.keep();
|
||||
// Split input into evens and odds
|
||||
let split = a.reshape::<R3<SEQ, HEAD_DIM_OVER_2, 2>>();
|
||||
let x0: GraphTensor<R3<SEQ, HEAD_DIM_OVER_2, 1>> =
|
||||
split.slice((.., .., ..Expression::from(1))).realize();
|
||||
let x1: GraphTensor<R3<SEQ, HEAD_DIM_OVER_2, 1>> =
|
||||
split.slice((.., .., Expression::from(1)..)).realize();
|
||||
let split = a.reshape((SEQ, HEAD_DIM / 2, 2));
|
||||
let x0 = split.slice((.., .., ..1));
|
||||
let x1 = split.slice((.., .., 1..));
|
||||
|
||||
let x0_out = x0 * b - x1 * b.cos();
|
||||
let x1_out = x0 + x1;
|
||||
|
||||
// Combine back into output
|
||||
let mut out: GraphTensor<R2<SEQ, HEAD_DIM>> = x0_out
|
||||
.concat_along::<R3<SEQ, HEAD_DIM_OVER_2, 2>, Axis<2>, _>(x1_out)
|
||||
.reshape()
|
||||
let mut out = x0_out
|
||||
.concat_along(x1_out, 2)
|
||||
.reshape((SEQ, HEAD_DIM))
|
||||
.retrieve();
|
||||
cx.execute();
|
||||
|
||||
@@ -765,34 +757,31 @@ mod tests {
|
||||
const N_HEADS: usize = 8;
|
||||
const SEQ: usize = 2;
|
||||
const HEAD_DIM: usize = 4;
|
||||
const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2;
|
||||
let a = cx
|
||||
.named_tensor::<R4<BATCH, N_HEADS, SEQ, HEAD_DIM>>("a")
|
||||
.named_tensor("a", (BATCH, N_HEADS, SEQ, HEAD_DIM))
|
||||
.set(random_vec_rng(BATCH * N_HEADS * SEQ * HEAD_DIM, &mut rng))
|
||||
.keep();
|
||||
let freqs = (cx.arange::<Const<HEAD_DIM_OVER_2>>() * 2.0) / (HEAD_DIM as f32);
|
||||
let freqs = (cx.arange(HEAD_DIM / 2) * 2.0) / (HEAD_DIM as f32);
|
||||
let freqs = 1000000_f32.pow(freqs);
|
||||
let pos = cx.arange::<Const<SEQ>>() + BigExpression::from(0);
|
||||
let emb = pos.expand::<(_, Const<1>), _>().matmul(freqs.expand());
|
||||
let pos = cx.arange(SEQ) + BigExpression::from(0);
|
||||
let emb = pos.expand(1, 1).matmul(freqs.expand(0, SEQ));
|
||||
// Split input into evens and odds
|
||||
let split = a.reshape::<R5<BATCH, N_HEADS, SEQ, HEAD_DIM_OVER_2, 2>>();
|
||||
let x0: GraphTensor<R5<BATCH, N_HEADS, SEQ, HEAD_DIM_OVER_2, 1>> = split
|
||||
let split = a.reshape((BATCH, N_HEADS, SEQ, HEAD_DIM / 2, 2));
|
||||
let x0 = split
|
||||
.slice((.., .., .., .., ..Expression::from(1)))
|
||||
.contiguous()
|
||||
.realize();
|
||||
let x1: GraphTensor<R5<BATCH, N_HEADS, SEQ, HEAD_DIM_OVER_2, 1>> = split
|
||||
.contiguous();
|
||||
let x1 = split
|
||||
.slice((.., .., .., .., Expression::from(1)..))
|
||||
.contiguous()
|
||||
.realize();
|
||||
.contiguous();
|
||||
|
||||
// Apply sin and cos embeddings
|
||||
let x0_out = x0 * emb.cos().expand() - x1 * emb.sin().expand();
|
||||
let x1_out = x0 * emb.sin().expand() + x1 * emb.cos().expand();
|
||||
let x0_out = x0 * emb.cos().expand_to(x0.shape) - x1 * emb.sin().expand_to(x1.shape);
|
||||
let x1_out = x0 * emb.sin().expand_to(x0.shape) + x1 * emb.cos().expand_to(x1.shape);
|
||||
|
||||
// Combine back into output
|
||||
let mut out: GraphTensor<R4<BATCH, N_HEADS, SEQ, HEAD_DIM>> = x0_out
|
||||
.concat_along::<R5<BATCH, N_HEADS, SEQ, HEAD_DIM_OVER_2, 2>, Axis<4>, _>(x1_out)
|
||||
.reshape()
|
||||
let mut out = x0_out
|
||||
.concat_along(x1_out, 4)
|
||||
.reshape((BATCH, N_HEADS, SEQ, HEAD_DIM))
|
||||
.retrieve();
|
||||
cx.execute();
|
||||
let unopt_out = out.data();
|
||||
@@ -814,176 +803,167 @@ mod tests {
|
||||
pub const SEQ_LEN: usize = 65;
|
||||
pub const N_ATTENTION_GROUPS: usize = N_HEADS / N_KV_HEADS;
|
||||
pub const HEAD_DIM: usize = HIDDEN_DIM / N_HEADS;
|
||||
pub const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2;
|
||||
pub const ATTN_PROJ_DIM: usize = HEAD_DIM * N_KV_HEADS;
|
||||
pub struct Mlp<const I: usize, const H: usize> {
|
||||
pub gate_proj: PermutedLinear<H, I>,
|
||||
pub down_proj: PermutedLinear<I, H>,
|
||||
pub up_proj: PermutedLinear<H, I>,
|
||||
pub type KVCache = (GraphTensor, GraphTensor);
|
||||
|
||||
pub struct Mlp {
|
||||
pub gate_proj: Linear, // hidden -> intermediate
|
||||
pub down_proj: Linear, // intermediate -> hidden
|
||||
pub up_proj: Linear, // hidden -> intermediate
|
||||
}
|
||||
|
||||
pub type KVCache<Batch, Seq> = (
|
||||
GraphTensor<(Batch, Const<N_KV_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
GraphTensor<(Batch, Const<N_KV_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
);
|
||||
impl Module<GraphTensor> for Mlp {
|
||||
type Output = GraphTensor;
|
||||
|
||||
impl<Sh: Shape, Im: Shape, const I: usize, const H: usize> Module<GraphTensor<Sh>> for Mlp<I, H>
|
||||
where
|
||||
GraphTensor<Sh>: Matmul<R2<H, I>, Output = GraphTensor<Im>>,
|
||||
GraphTensor<Im>: Matmul<R2<I, H>, Output = GraphTensor<Sh>>,
|
||||
{
|
||||
type Output = GraphTensor<Sh>;
|
||||
|
||||
fn forward(&self, input: GraphTensor<Sh>) -> Self::Output {
|
||||
fn forward(&self, input: GraphTensor) -> Self::Output {
|
||||
let gate = self.gate_proj.forward(input).swish();
|
||||
let up = self.up_proj.forward(input) * gate;
|
||||
self.down_proj.forward(up)
|
||||
}
|
||||
}
|
||||
impl<const I: usize, const H: usize> InitModule for Mlp<I, H> {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
|
||||
impl Mlp {
|
||||
pub fn new(hidden: usize, intermediate: usize, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
gate_proj: InitModule::initialize(cx),
|
||||
up_proj: InitModule::initialize(cx),
|
||||
down_proj: InitModule::initialize(cx),
|
||||
gate_proj: Linear::new_permuted(hidden, intermediate, false, cx),
|
||||
down_proj: Linear::new_permuted(intermediate, hidden, false, cx),
|
||||
up_proj: Linear::new_permuted(hidden, intermediate, false, cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
fn apply_rotary_embeddings_ggml<const N_HEADS: usize, Batch: Dimension, Seq: Dimension>(
|
||||
input: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
|
||||
impl SerializeModule for Mlp {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("ffn_gate", &self.gate_proj);
|
||||
s.module("ffn_up", &self.up_proj);
|
||||
s.module("ffn_down", &self.down_proj);
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_rotary_embeddings_ggml(
|
||||
input: GraphTensor,
|
||||
prev_seq: BigExpression,
|
||||
) -> GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)> {
|
||||
) -> GraphTensor {
|
||||
assert_eq!(input.shape.len(), 4); // batch, n_heads, seq, head_dim
|
||||
let batch = input.shape()[0].small();
|
||||
let n_heads = input.shape()[1].small();
|
||||
let seq = input.shape()[2].small();
|
||||
let head_dim = input.shape()[3].small();
|
||||
// Get freqs
|
||||
let freqs =
|
||||
(input.graph().arange::<Const<HEAD_DIM_OVER_2>>() * 2.0) / (HEAD_DIM as f32);
|
||||
let freqs = 1000000_f32.pow(freqs);
|
||||
let pos = input.graph().arange::<Seq>() + prev_seq;
|
||||
let emb = pos.expand::<(_, Const<1>), _>().matmul(freqs.expand());
|
||||
(input.graph().arange(head_dim / 2) * 2.0) / (head_dim.to_usize().unwrap() as f32);
|
||||
let freqs = 500_000_f32.pow(freqs);
|
||||
let pos = input.graph().arange(seq) + prev_seq;
|
||||
let emb = pos.expand(1, 1).matmul(freqs.expand(0, seq));
|
||||
|
||||
// Split input into evens and odds
|
||||
let split =
|
||||
input.reshape::<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<2>)>();
|
||||
let x0: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<1>)> =
|
||||
split
|
||||
.slice((.., .., .., .., ..Expression::from(1)))
|
||||
.contiguous()
|
||||
.realize();
|
||||
let x1: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<1>)> =
|
||||
split
|
||||
.slice((.., .., .., .., Expression::from(1)..))
|
||||
.contiguous()
|
||||
.realize();
|
||||
let split = input.reshape((batch, n_heads, seq, head_dim / 2, 2));
|
||||
let x0 = split.slice((.., .., .., .., ..1));
|
||||
let x1 = split.slice((.., .., .., .., 1..));
|
||||
|
||||
// Apply sin and cos embeddings
|
||||
let x0_out = x0 * emb.cos().expand() - x1 * emb.sin().expand();
|
||||
let x1_out = x0 * emb.sin().expand() + x1 * emb.cos().expand();
|
||||
let x0_out = x0 * emb.cos().expand_to(x0.shape) - x1 * emb.sin().expand_to(x1.shape);
|
||||
let x1_out = x0 * emb.sin().expand_to(x0.shape) + x1 * emb.cos().expand_to(x1.shape);
|
||||
|
||||
// Combine back into output
|
||||
x0_out
|
||||
.concat_along::<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<2>), Axis<4>, _>(
|
||||
x1_out,
|
||||
)
|
||||
.reshape()
|
||||
}
|
||||
pub struct SelfAttention {
|
||||
pub q_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
|
||||
pub k_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
|
||||
pub v_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
|
||||
pub o_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
|
||||
x0_out.concat_along(x1_out, 4).reshape(input.shape)
|
||||
}
|
||||
|
||||
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, PrevSeq>,
|
||||
PhantomData<TotSeq>,
|
||||
)> for SelfAttention
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, TotSeq>,
|
||||
);
|
||||
fn forward(
|
||||
&self,
|
||||
(x, (k_cache, v_cache), _): (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, PrevSeq>,
|
||||
PhantomData<TotSeq>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
pub struct SelfAttention {
|
||||
pub q_proj: GraphTensor, // Hidden -> hidden
|
||||
pub k_proj: GraphTensor, // Proj dim -> hidden
|
||||
pub v_proj: GraphTensor, // Proj dim -> hidden
|
||||
pub o_proj: GraphTensor, // Hidden -> hidden
|
||||
}
|
||||
|
||||
impl Module<(GraphTensor, KVCache)> for SelfAttention {
|
||||
type Output = (GraphTensor, KVCache);
|
||||
fn forward(&self, (x, (k_cache, v_cache)): (GraphTensor, KVCache)) -> Self::Output {
|
||||
// x: batch, seq, hidden
|
||||
let batch = x.shape()[0].small();
|
||||
let seq = x.shape()[1].small();
|
||||
let prev_seq = k_cache.shape()[2].small();
|
||||
// Apply the Projections
|
||||
let queries = x
|
||||
.matmul(self.q_proj.permute())
|
||||
.reshape::<(Batch, CurSeq, Const<N_HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
.matmul(self.q_proj.permute((1, 0)))
|
||||
.reshape((batch, seq, N_HEADS, HEAD_DIM))
|
||||
.permute((0, 2, 1, 3));
|
||||
|
||||
let keys = x
|
||||
.matmul(self.k_proj.permute())
|
||||
.reshape::<(Batch, CurSeq, Const<N_KV_HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
.matmul(self.k_proj.permute((1, 0)))
|
||||
.reshape((batch, seq, N_KV_HEADS, HEAD_DIM))
|
||||
.permute((0, 2, 1, 3));
|
||||
|
||||
let values = x
|
||||
.matmul(self.v_proj.permute())
|
||||
.reshape::<(Batch, CurSeq, Const<N_KV_HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
.matmul(self.v_proj.permute((1, 0)))
|
||||
.reshape((batch, seq, N_KV_HEADS, HEAD_DIM))
|
||||
.permute((0, 2, 1, 3));
|
||||
|
||||
// Rotary embed queries and keys
|
||||
let queries = apply_rotary_embeddings_ggml(queries, PrevSeq::size().big());
|
||||
let keys = apply_rotary_embeddings_ggml(keys, PrevSeq::size().big());
|
||||
let queries = apply_rotary_embeddings_ggml(queries, prev_seq.big());
|
||||
let keys = apply_rotary_embeddings_ggml(keys, prev_seq.big());
|
||||
|
||||
// Add KV cache
|
||||
let (keys, values) = (
|
||||
k_cache.concat_along::<_, Axis<2>, _>(keys),
|
||||
v_cache.concat_along::<_, Axis<2>, _>(values),
|
||||
);
|
||||
let keys = k_cache.concat_along(keys, 2);
|
||||
let values = v_cache.concat_along(values, 2);
|
||||
|
||||
// Repeat the KV States for Grouped-Query Attention
|
||||
let repeated_keys = keys.expand::<(_, _, Const<N_ATTENTION_GROUPS>, _, _), _>();
|
||||
let repeated_values = values.expand::<(_, _, Const<N_ATTENTION_GROUPS>, _, _), _>();
|
||||
let repeated_keys = keys.expand(2, N_ATTENTION_GROUPS);
|
||||
let repeated_values = values.expand(2, N_ATTENTION_GROUPS);
|
||||
|
||||
// Calculate attention weights
|
||||
let mut attention_weights = queries
|
||||
.reshape::<(_, Const<N_KV_HEADS>, Const<N_ATTENTION_GROUPS>, _, _)>() // Split query heads into groups
|
||||
.matmul(repeated_keys.permute())
|
||||
.div((HEAD_DIM as f32).sqrt());
|
||||
.reshape((batch, N_KV_HEADS, N_ATTENTION_GROUPS, seq, HEAD_DIM)) // Split query heads into groups
|
||||
.matmul(repeated_keys.permute((0, 1, 2, 4, 3)))
|
||||
/ (HEAD_DIM as f32).sqrt();
|
||||
|
||||
let attention_mask = self.k_proj.graph().triu::<CurSeq>(1) * f16::MIN.to_f32();
|
||||
let attention_mask = self.k_proj.graph().triu(seq, 1) * f16::MIN.to_f32();
|
||||
attention_weights += attention_mask
|
||||
.pad::<(CurSeq, TotSeq)>(((0, 0), (TotSeq::size() - CurSeq::size(), 0)))
|
||||
.expand();
|
||||
.pad(((0, 0), (prev_seq, 0)))
|
||||
.expand(0, batch)
|
||||
.expand(1, N_KV_HEADS)
|
||||
.expand(2, N_ATTENTION_GROUPS);
|
||||
|
||||
// Calculate final outputs
|
||||
let output = attention_weights
|
||||
.softmax::<Axis<4>>()
|
||||
.softmax(4)
|
||||
// Apply distribution to values
|
||||
.matmul(repeated_values)
|
||||
// Merge heads
|
||||
.permute::<_, Axes5<0, 3, 1, 2, 4>>()
|
||||
.reshape::<(Batch, CurSeq, Const<HIDDEN_DIM>)>();
|
||||
.permute((0, 3, 1, 2, 4))
|
||||
.reshape((batch, seq, HIDDEN_DIM));
|
||||
let output = output
|
||||
// Apply output projection
|
||||
.matmul(self.o_proj.permute());
|
||||
.matmul(self.o_proj.permute((1, 0)));
|
||||
(output, (keys.contiguous(), values.contiguous())) // Cache needs to be contiguous for transferring to another graph
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for SelfAttention {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl SelfAttention {
|
||||
pub fn new(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
q_proj: cx
|
||||
.named_tensor("Q Proj")
|
||||
.set(random_vec(HIDDEN_DIM * HIDDEN_DIM)),
|
||||
k_proj: cx
|
||||
.named_tensor("K Proj")
|
||||
.set(random_vec(ATTN_PROJ_DIM * HIDDEN_DIM)),
|
||||
v_proj: cx
|
||||
.named_tensor("V Proj")
|
||||
.set(random_vec(ATTN_PROJ_DIM * HIDDEN_DIM)),
|
||||
o_proj: cx
|
||||
.named_tensor("O Proj")
|
||||
.set(random_vec(HIDDEN_DIM * HIDDEN_DIM)),
|
||||
q_proj: cx.named_tensor("Q Proj", (HIDDEN_DIM, HIDDEN_DIM)),
|
||||
k_proj: cx.named_tensor("K Proj", (ATTN_PROJ_DIM, HIDDEN_DIM)),
|
||||
v_proj: cx.named_tensor("V Proj", (ATTN_PROJ_DIM, HIDDEN_DIM)),
|
||||
o_proj: cx.named_tensor("O Proj", (HIDDEN_DIM, HIDDEN_DIM)),
|
||||
}
|
||||
}
|
||||
|
||||
fn initialize(self) -> Self {
|
||||
self.k_proj.set(random_vec(
|
||||
self.k_proj.shape.n_elements().to_usize().unwrap(),
|
||||
));
|
||||
self.o_proj.set(random_vec(
|
||||
self.o_proj.shape.n_elements().to_usize().unwrap(),
|
||||
));
|
||||
self.v_proj.set(random_vec(
|
||||
self.v_proj.shape.n_elements().to_usize().unwrap(),
|
||||
));
|
||||
self.q_proj.set(random_vec(
|
||||
self.q_proj.shape.n_elements().to_usize().unwrap(),
|
||||
));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl SerializeModule for SelfAttention {
|
||||
@@ -997,35 +977,18 @@ mod tests {
|
||||
|
||||
pub struct TransformerBlock {
|
||||
pub attention: SelfAttention,
|
||||
pub attention_norm: LayerNorm<HIDDEN_DIM>,
|
||||
pub feed_forward: Mlp<MLP_DIM, HIDDEN_DIM>,
|
||||
pub feed_forward_norm: LayerNorm<HIDDEN_DIM>,
|
||||
pub attention_norm: LayerNorm,
|
||||
pub feed_forward: Mlp,
|
||||
pub feed_forward_norm: LayerNorm,
|
||||
}
|
||||
|
||||
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, PrevSeq>,
|
||||
PhantomData<TotSeq>,
|
||||
)> for TransformerBlock
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, TotSeq>,
|
||||
);
|
||||
fn forward(
|
||||
&self,
|
||||
(mut x, cache, _): (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, PrevSeq>,
|
||||
PhantomData<TotSeq>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
impl Module<(GraphTensor, KVCache)> for TransformerBlock {
|
||||
type Output = (GraphTensor, KVCache);
|
||||
fn forward(&self, (mut x, cache): (GraphTensor, KVCache)) -> Self::Output {
|
||||
// Attention
|
||||
let normed = self.attention_norm.forward(x);
|
||||
let (y, cache) = self
|
||||
.attention
|
||||
.forward((normed, cache, PhantomData::<TotSeq>));
|
||||
.forward((self.attention_norm.forward(x), cache));
|
||||
|
||||
// Residual Addition
|
||||
x += y;
|
||||
@@ -1038,94 +1001,85 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for TransformerBlock {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl TransformerBlock {
|
||||
pub fn new(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
attention: InitModule::initialize(cx),
|
||||
attention_norm: LayerNorm::new(false, false, false, 1e-5, cx),
|
||||
feed_forward: InitModule::initialize(cx),
|
||||
feed_forward_norm: LayerNorm::new(false, false, false, 1e-5, cx),
|
||||
attention: SelfAttention::new(cx),
|
||||
attention_norm: LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
|
||||
feed_forward: Mlp::new(HIDDEN_DIM, MLP_DIM, cx),
|
||||
feed_forward_norm: LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn initialize(mut self) -> Self {
|
||||
self.attention_norm = self.attention_norm.initialize();
|
||||
self.feed_forward_norm = self.feed_forward_norm.initialize();
|
||||
self.attention = self.attention.initialize();
|
||||
self.feed_forward.down_proj = self.feed_forward.down_proj.initialize();
|
||||
self.feed_forward.up_proj = self.feed_forward.up_proj.initialize();
|
||||
self.feed_forward.gate_proj = self.feed_forward.gate_proj.initialize();
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MistralLM {
|
||||
pub struct Llama {
|
||||
// Transformer layers
|
||||
pub layers: Vec<TransformerBlock>,
|
||||
// Final Norm layer
|
||||
pub norm: LayerNorm<HIDDEN_DIM>,
|
||||
// Norm + LM head
|
||||
pub head: LayerNorm,
|
||||
}
|
||||
|
||||
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
Vec<KVCache<Batch, PrevSeq>>,
|
||||
PhantomData<TotSeq>,
|
||||
)> for MistralLM
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
Vec<KVCache<Batch, TotSeq>>,
|
||||
);
|
||||
fn forward(
|
||||
&self,
|
||||
(input, cache, _): (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
Vec<KVCache<Batch, PrevSeq>>,
|
||||
PhantomData<TotSeq>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
let mut x = input;
|
||||
|
||||
impl Module<(GraphTensor, &[KVCache])> for Llama {
|
||||
type Output = (GraphTensor, Vec<KVCache>);
|
||||
fn forward(&self, (mut x, cache): (GraphTensor, &[KVCache])) -> Self::Output {
|
||||
// Run through layers and collect new caches
|
||||
let mut new_caches = vec![];
|
||||
let mut new_cache;
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
(x, new_cache) = layer.forward((x, cache[i], PhantomData::<TotSeq>));
|
||||
(x, new_cache) = layer.forward((x, cache[i]));
|
||||
new_caches.push(new_cache);
|
||||
}
|
||||
// Run through last norm and output projection
|
||||
let normed = self.norm.forward(x);
|
||||
(normed, new_caches)
|
||||
(self.head.forward(x), new_caches)
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for MistralLM {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl Llama {
|
||||
pub fn new(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
norm: LayerNorm::new(false, false, false, 1e-5, cx),
|
||||
layers: (0..NUM_LAYERS)
|
||||
.map(|_| InitModule::initialize(cx))
|
||||
.collect(),
|
||||
head: LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
|
||||
layers: (0..NUM_LAYERS).map(|_| TransformerBlock::new(cx)).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
fn initialize(mut self) -> Self {
|
||||
self.head = self.head.initialize();
|
||||
self.layers = self.layers.into_iter().map(|l| l.initialize()).collect();
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
let mut cx = Graph::new();
|
||||
let model = MistralLM::initialize(&mut cx);
|
||||
let model = Llama::new(&mut cx).initialize();
|
||||
let caches = (0..NUM_LAYERS)
|
||||
.map(|_| {
|
||||
(
|
||||
cx.tensor::<(Const<1>, Const<N_KV_HEADS>, Dyn<'p'>, Const<HEAD_DIM>)>()
|
||||
.set_dyn(
|
||||
random_vec(SEQ_LEN * N_KV_HEADS * HEAD_DIM),
|
||||
&[1, N_KV_HEADS, SEQ_LEN, HEAD_DIM],
|
||||
),
|
||||
cx.tensor::<(Const<1>, Const<N_KV_HEADS>, Dyn<'p'>, Const<HEAD_DIM>)>()
|
||||
.set_dyn(
|
||||
random_vec(SEQ_LEN * N_KV_HEADS * HEAD_DIM),
|
||||
&[1, N_KV_HEADS, SEQ_LEN, HEAD_DIM],
|
||||
),
|
||||
cx.tensor((1, N_KV_HEADS, 'p', HEAD_DIM)).set_dyn(
|
||||
random_vec(SEQ_LEN * N_KV_HEADS * HEAD_DIM),
|
||||
(1, N_KV_HEADS, SEQ_LEN, HEAD_DIM),
|
||||
),
|
||||
cx.tensor((1, N_KV_HEADS, 'p', HEAD_DIM)).set_dyn(
|
||||
random_vec(SEQ_LEN * N_KV_HEADS * HEAD_DIM),
|
||||
(1, N_KV_HEADS, SEQ_LEN, HEAD_DIM),
|
||||
),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
.collect::<Vec<_>>();
|
||||
let input = cx
|
||||
.tensor::<(Const<1>, Dyn<'s'>, luminal::shape::Const<HIDDEN_DIM>)>()
|
||||
.set_dyn(random_vec(2 * HIDDEN_DIM), &[1, 2, HIDDEN_DIM]);
|
||||
let (mut out, _) = model.forward((input, caches, PhantomData::<Dyn<'t'>>));
|
||||
.tensor((1, 's', HIDDEN_DIM))
|
||||
.set_dyn(random_vec(2 * HIDDEN_DIM), (1, 2, HIDDEN_DIM));
|
||||
let (mut out, _) = model.forward((input, &caches));
|
||||
out.retrieve();
|
||||
|
||||
cx.set_dyn_dim('t', SEQ_LEN + 2);
|
||||
cx.execute();
|
||||
|
||||
let unopt_out = out.data();
|
||||
|
||||
@@ -429,9 +429,9 @@ mod tests {
|
||||
const N: usize = 256;
|
||||
let mut cx = Graph::new();
|
||||
let (a_vec, b_mat) = (random_vec(M), random_vec(M * N));
|
||||
let mut a = cx.named_tensor::<R2<1, M>>("Vec").set(a_vec.clone());
|
||||
let mut b = cx.named_tensor::<R2<N, M>>("Mat").set(b_mat.clone());
|
||||
let mut c = a.matmul(b.permute()).retrieve();
|
||||
let mut a = cx.named_tensor("Vec", (1, M)).set(a_vec.clone());
|
||||
let mut b = cx.named_tensor("Mat", (N, M)).set(b_mat.clone());
|
||||
let mut c = a.matmul(b.permute((1, 0))).retrieve();
|
||||
|
||||
cx.compile(
|
||||
<(GenericCompiler, MetalCompiler<f16>)>::default(),
|
||||
@@ -454,8 +454,8 @@ mod tests {
|
||||
const N: usize = 256;
|
||||
let mut cx = Graph::new();
|
||||
let (a_vec, b_mat) = (random_vec(M), random_vec(M * N));
|
||||
let mut a = cx.named_tensor::<R3<1, 1, M>>("Vec").set(a_vec.clone());
|
||||
let mut b = cx.named_tensor::<R2<M, N>>("Mat").set(b_mat.clone());
|
||||
let mut a = cx.named_tensor("Vec", (1, 1, M)).set(a_vec.clone());
|
||||
let mut b = cx.named_tensor("Mat", (M, N)).set(b_mat.clone());
|
||||
let mut c = a.matmul(b).retrieve();
|
||||
|
||||
cx.compile(
|
||||
|
||||
@@ -1054,7 +1054,7 @@ impl<T: MetalFloat + 'static> Compiler for PrimitiveCompiler<T> {
|
||||
// Copy outputs to device
|
||||
let copy_node = graph
|
||||
.add_op(MetalCopyToDevice::<T>::new(dev.clone()))
|
||||
.input(function_node, 0, ShapeTracker::new(&[]))
|
||||
.input(function_node, 0, ShapeTracker::default())
|
||||
.finish();
|
||||
|
||||
// Switch outgoing edges from input to copy_node
|
||||
|
||||
@@ -900,9 +900,9 @@ mod tests {
|
||||
let mat_data: Vec<i8> = (0..(1024 * 512)).map(|_| rng.gen_range(0..5)).collect();
|
||||
let vec_data = random_vec_rng(1024, &mut rng);
|
||||
let mut cx = Graph::new();
|
||||
let weights = cx.tensor::<R2<512, 1024>>().keep();
|
||||
let vec = cx.tensor::<R1<1024>>().set(vec_data.clone());
|
||||
let mut out = vec.matmul(weights.permute()).retrieve();
|
||||
let weights = cx.tensor((512, 1024)).keep();
|
||||
let vec = cx.tensor(1024).set(vec_data.clone());
|
||||
let mut out = vec.matmul(weights.permute((1, 0))).retrieve();
|
||||
|
||||
// "Load" weights in 8bit
|
||||
let blocks = mat_data
|
||||
@@ -933,11 +933,11 @@ mod tests {
|
||||
|
||||
let mut cx1 = Graph::new();
|
||||
let weights = cx1
|
||||
.tensor::<R2<512, 1024>>()
|
||||
.tensor((512, 1024))
|
||||
.set(mat_data.into_iter().map(|i| i as f32).collect::<Vec<_>>())
|
||||
.keep();
|
||||
let vec = cx1.tensor::<R1<1024>>().set(vec_data);
|
||||
let out_32 = vec.matmul(weights.permute()).retrieve();
|
||||
let vec = cx1.tensor(1024).set(vec_data);
|
||||
let out_32 = vec.matmul(weights.permute((1, 0))).retrieve();
|
||||
cx1.execute();
|
||||
|
||||
assert_close(&out.data(), &out_32.data());
|
||||
@@ -949,9 +949,9 @@ mod tests {
|
||||
let mat_data: Vec<i8> = (0..(1024 * 512)).map(|_| rng.gen_range(0..5)).collect();
|
||||
let inp_mat_data = random_vec_rng(1024 * 16, &mut rng);
|
||||
let mut cx = Graph::new();
|
||||
let weights = cx.tensor::<R2<512, 1024>>().keep();
|
||||
let inp_mat = cx.tensor::<R2<16, 1024>>().set(inp_mat_data.clone());
|
||||
let mut out = inp_mat.matmul(weights.permute()).retrieve();
|
||||
let weights = cx.tensor((512, 1024)).keep();
|
||||
let inp_mat = cx.tensor((16, 1024)).set(inp_mat_data.clone());
|
||||
let mut out = inp_mat.matmul(weights.permute((1, 0))).retrieve();
|
||||
|
||||
// "Load" weights in 8bit
|
||||
let blocks = mat_data
|
||||
@@ -1000,9 +1000,9 @@ mod tests {
|
||||
let mat_data: Vec<i8> = (0..(1024 * 512)).map(|_| rng.gen_range(0..5)).collect();
|
||||
let inp_mat_data = random_vec_rng(1024 * 16, &mut rng);
|
||||
let mut cx = Graph::new();
|
||||
let weights = cx.tensor::<R2<512, 1024>>().keep();
|
||||
let inp_mat = cx.tensor::<R2<16, 1024>>().set(inp_mat_data.clone());
|
||||
let mut out = inp_mat.matmul(weights.permute()).retrieve();
|
||||
let weights = cx.tensor((512, 1024)).keep();
|
||||
let inp_mat = cx.tensor((16, 1024)).set(inp_mat_data.clone());
|
||||
let mut out = inp_mat.matmul(weights.permute((1, 0))).retrieve();
|
||||
|
||||
// "Load" weights in 8bit
|
||||
let blocks = mat_data
|
||||
|
||||
@@ -390,7 +390,7 @@ fn test_shared_buffers() {
|
||||
use luminal::prelude::*;
|
||||
use luminal::tests::{assert_close_precision, random_vec};
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<5>>().set(random_vec(5)).keep();
|
||||
let a = cx.tensor(5).set(random_vec(5)).keep();
|
||||
let b = a.exp2();
|
||||
let c = a.log2() * b;
|
||||
let d = b.recip();
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use dfdx::prelude::{Module as DfdxModule, *};
|
||||
use metal_rs::objc::rc::autoreleasepool;
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
@@ -23,13 +21,13 @@ unary_test!(
|
||||
);
|
||||
unary_test!(|a| a.exp2(), |a| (a * 2_f32.ln()).exp(), test_exp2, f16);
|
||||
unary_test!(
|
||||
|a| a.softmax::<LAxis<0>>(),
|
||||
|a| a.softmax(0),
|
||||
|a| a.softmax::<DAxis<0>>(),
|
||||
test_softmax,
|
||||
f16
|
||||
);
|
||||
unary_test!(
|
||||
|a| a.layer_norm::<LAxis<0>, _>(1e-5),
|
||||
|a| a.layer_norm(0, 1e-5),
|
||||
|a| a
|
||||
.to_dtype::<f32>()
|
||||
.normalize::<DAxis<0>>(1e-5)
|
||||
@@ -60,8 +58,8 @@ binary_test!(
|
||||
fn test_contiguous() {
|
||||
let mut cx = Graph::new();
|
||||
let data = random_vec(12);
|
||||
let a = cx.tensor::<R2<3, 4>>().set(data.clone());
|
||||
let mut b = a.permute::<R2<4, 3>, _>().reshape::<R2<12, 1>>().retrieve();
|
||||
let a = cx.tensor((3, 4)).set(data.clone());
|
||||
let mut b = a.permute((1, 0)).reshape((12, 1)).retrieve();
|
||||
cx.compile(MetalCompiler::<f16>::default(), &mut b);
|
||||
cx.execute();
|
||||
|
||||
@@ -81,15 +79,10 @@ fn test_rotate() {
|
||||
const F: usize = 3;
|
||||
const D: usize = 4;
|
||||
let data = random_vec(D * B * F);
|
||||
let a = cx
|
||||
.tensor::<R3<F, B, D>>()
|
||||
.set(data.clone())
|
||||
.permute::<_, LAxes3<1, 0, 2>>();
|
||||
let a = cx.tensor((F, B, D)).set(data.clone()).permute((1, 0, 2));
|
||||
let x1 = a.slice((.., .., ..Expression::from(D / 2)));
|
||||
let x2 = a.slice((.., .., Expression::from(D / 2)..));
|
||||
let mut rotated_a = (-x2)
|
||||
.concat_along::<R3<B, F, D>, LAxis<1>, _>(x1)
|
||||
.retrieve();
|
||||
let mut rotated_a = (-x2).concat_along(x1, 1).retrieve();
|
||||
cx.execute();
|
||||
let unopt = rotated_a.data();
|
||||
rotated_a.drop();
|
||||
@@ -121,10 +114,10 @@ fn test_constant() {
|
||||
fn test_sum_reduce() {
|
||||
let data = random_vec(40960);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R3<1, 10, 4096>>().set(data.clone());
|
||||
let mut b = a.sum_reduce::<_, LAxis<2>>().retrieve();
|
||||
let mut c = a.sum_reduce::<_, LAxis<1>>().retrieve();
|
||||
let mut d = a.sum_reduce::<_, LAxis<0>>().retrieve();
|
||||
let a = cx.tensor((1, 10, 4096)).set(data.clone());
|
||||
let mut b = a.sum_reduce(2).retrieve();
|
||||
let mut c = a.sum_reduce(1).retrieve();
|
||||
let mut d = a.sum_reduce(0).retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f16>::default(), (&mut b, &mut c, &mut d));
|
||||
cx.execute();
|
||||
@@ -145,8 +138,8 @@ fn test_sum_reduce() {
|
||||
fn test_sum_reduce2() {
|
||||
let mut cx = Graph::new();
|
||||
let data = random_vec(32 * 10 * 10 * 128);
|
||||
let a = cx.tensor::<R5<1, 32, 10, 10, 128>>().set(data.clone());
|
||||
let mut d = a.sum_reduce::<_, LAxis<2>>().retrieve();
|
||||
let a = cx.tensor((1, 32, 10, 10, 128)).set(data.clone());
|
||||
let mut d = a.sum_reduce(2).retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f16>::default(), &mut d);
|
||||
cx.execute();
|
||||
@@ -173,10 +166,10 @@ fn test_sum_reduce2() {
|
||||
fn test_max_reduce() {
|
||||
let data = random_vec(40960);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R3<1, 10, 4096>>().set(data.clone());
|
||||
let mut b = a.max_reduce::<_, LAxis<2>>().retrieve();
|
||||
let mut c = a.max_reduce::<_, LAxis<1>>().retrieve();
|
||||
let mut d = a.max_reduce::<_, LAxis<0>>().retrieve();
|
||||
let a = cx.tensor((1, 10, 4096)).set(data.clone());
|
||||
let mut b = a.max_reduce(2).retrieve();
|
||||
let mut c = a.max_reduce(1).retrieve();
|
||||
let mut d = a.max_reduce(0).retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f16>::default(), (&mut b, &mut c, &mut d));
|
||||
cx.execute();
|
||||
@@ -197,10 +190,10 @@ fn test_max_reduce() {
|
||||
fn test_mean_reduce() {
|
||||
let data = random_vec(40960);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R3<1, 10, 4096>>().set(data.clone());
|
||||
let mut b = a.mean_reduce::<_, LAxis<2>>().retrieve();
|
||||
let mut c = a.mean_reduce::<_, LAxis<1>>().retrieve();
|
||||
let mut d = a.mean_reduce::<_, LAxis<0>>().retrieve();
|
||||
let a = cx.tensor((1, 10, 4096)).set(data.clone());
|
||||
let mut b = a.mean_reduce(2).retrieve();
|
||||
let mut c = a.mean_reduce(1).retrieve();
|
||||
let mut d = a.mean_reduce(0).retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f16>::default(), (&mut b, &mut c, &mut d));
|
||||
cx.execute();
|
||||
@@ -222,8 +215,8 @@ fn test_matmul_simple() {
|
||||
let mut cx = Graph::new();
|
||||
let a_data = random_vec(256 * 256);
|
||||
let b_data = random_vec(256 * 256);
|
||||
let a = cx.tensor::<R2<256, 256>>().set(a_data.clone());
|
||||
let b = cx.tensor::<R2<256, 256>>().set(b_data.clone());
|
||||
let a = cx.tensor((256, 256)).set(a_data.clone());
|
||||
let b = cx.tensor((256, 256)).set(b_data.clone());
|
||||
let mut c = a.matmul(b).retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f16>::default(), &mut c);
|
||||
@@ -242,8 +235,8 @@ fn test_matmul() {
|
||||
let d_dev = Cpu::default();
|
||||
let mut cx = Graph::new();
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let a = cx.tensor::<(Dyn<'M'>, Dyn<'K'>)>();
|
||||
let b = cx.tensor::<(Dyn<'K'>, Dyn<'N'>)>();
|
||||
let a = cx.tensor(('M', 'K'));
|
||||
let b = cx.tensor(('K', 'N'));
|
||||
let mut c = a.matmul(b).retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f16>::default(), &mut c);
|
||||
@@ -253,8 +246,8 @@ fn test_matmul() {
|
||||
autoreleasepool(|| {
|
||||
let a_data = random_vec_rng(m * k, &mut rng);
|
||||
let b_data = random_vec_rng(k * n, &mut rng);
|
||||
a.set_dyn(a_data.clone(), &[m, k]);
|
||||
b.set_dyn(b_data.clone(), &[k, n]);
|
||||
a.set_dyn(a_data.clone(), (m, k));
|
||||
b.set_dyn(b_data.clone(), (k, n));
|
||||
|
||||
cx.execute();
|
||||
|
||||
@@ -276,11 +269,11 @@ fn test_attn_matmul() {
|
||||
let a_data = random_vec_rng(32 * 11 * 128, &mut rng);
|
||||
let b_data = random_vec_rng(32 * 11 * 128, &mut rng);
|
||||
let a = cx
|
||||
.named_tensor::<R4<1, 32, 11, 128>>("Input")
|
||||
.named_tensor("Input", (1, 32, 11, 128))
|
||||
.set(a_data.clone())
|
||||
.keep();
|
||||
let b = cx
|
||||
.named_tensor::<R4<1, 32, 128, 11>>("Input")
|
||||
.named_tensor("Input", (1, 32, 128, 11))
|
||||
.set(b_data.clone())
|
||||
.keep();
|
||||
let mut c = a.matmul(b).retrieve();
|
||||
@@ -310,8 +303,8 @@ fn test_batch_matmul() {
|
||||
let m = 12;
|
||||
let mut cx = Graph::new();
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let a = cx.tensor::<(Dyn<'B'>, Dyn<'M'>, Dyn<'K'>)>();
|
||||
let b = cx.tensor::<(Dyn<'K'>, Dyn<'N'>)>();
|
||||
let a = cx.tensor(('B', 'M', 'K'));
|
||||
let b = cx.tensor(('K', 'N'));
|
||||
let mut c = a.matmul(b).retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f16>::default(), &mut c);
|
||||
@@ -321,8 +314,8 @@ fn test_batch_matmul() {
|
||||
autoreleasepool(|| {
|
||||
let a_data = random_vec_rng(batch * m * k, &mut rng);
|
||||
let b_data = random_vec_rng(k * n, &mut rng);
|
||||
a.set_dyn(a_data.clone(), &[batch, m, k]);
|
||||
b.set_dyn(b_data.clone(), &[k, n]);
|
||||
a.set_dyn(a_data.clone(), (batch, m, k));
|
||||
b.set_dyn(b_data.clone(), (k, n));
|
||||
cx.execute();
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
@@ -348,26 +341,18 @@ fn test_batch_matmul_transpose() {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
|
||||
let a_data = random_vec_rng(B * M * K, &mut rng);
|
||||
let a = cx.named_tensor::<R3<B, M, K>>("A").set(a_data.clone());
|
||||
let a = cx.named_tensor("A", (B, M, K)).set(a_data.clone());
|
||||
let b_data = random_vec_rng(K * N, &mut rng);
|
||||
let b = cx.named_tensor::<R2<N, K>>("B").set(b_data.clone());
|
||||
let b = cx.named_tensor("B", (N, K)).set(b_data.clone());
|
||||
let a_t_data = random_vec_rng(B * K * M, &mut rng);
|
||||
let a_t = cx.named_tensor::<R3<B, K, M>>("A_T").set(a_t_data.clone());
|
||||
let a_t = cx.named_tensor("A_T", (B, K, M)).set(a_t_data.clone());
|
||||
let b_t_data = random_vec_rng(K * N, &mut rng);
|
||||
let b_t = cx.named_tensor::<R2<K, N>>("B_T").set(b_t_data.clone());
|
||||
let b_t = cx.named_tensor("B_T", (K, N)).set(b_t_data.clone());
|
||||
|
||||
let mut a_b = a
|
||||
.matmul(b.permute::<_, luminal::prelude::Axes2<1, 0>>())
|
||||
.retrieve();
|
||||
let mut a_b = a.matmul(b.permute((1, 0))).retrieve();
|
||||
let mut a_b_t = a.matmul(b_t).retrieve();
|
||||
let mut a_t_b = a_t
|
||||
.permute::<_, luminal::prelude::Axes3<0, 2, 1>>()
|
||||
.matmul(b.permute::<_, luminal::prelude::Axes2<1, 0>>())
|
||||
.retrieve();
|
||||
let mut a_t_b_t = a_t
|
||||
.permute::<_, luminal::prelude::Axes3<0, 2, 1>>()
|
||||
.matmul(b_t)
|
||||
.retrieve();
|
||||
let mut a_t_b = a_t.permute((0, 2, 1)).matmul(b.permute((1, 0))).retrieve();
|
||||
let mut a_t_b_t = a_t.permute((0, 2, 1)).matmul(b_t).retrieve();
|
||||
|
||||
cx.compile(
|
||||
<(GenericCompiler, MetalCompiler<f16>)>::default(),
|
||||
@@ -407,24 +392,18 @@ fn test_matmul_transpose() {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
|
||||
let a_data = random_vec_rng(M * K, &mut rng);
|
||||
let a = cx.tensor::<R2<M, K>>().set(a_data.clone());
|
||||
let a = cx.tensor((M, K)).set(a_data.clone());
|
||||
let b_data = random_vec_rng(K * N, &mut rng);
|
||||
let b = cx.tensor::<R2<N, K>>().set(b_data.clone());
|
||||
let b = cx.tensor((N, K)).set(b_data.clone());
|
||||
let a_t_data = random_vec_rng(K * M, &mut rng);
|
||||
let a_t = cx.tensor::<R2<K, M>>().set(a_t_data.clone());
|
||||
let a_t = cx.tensor((K, M)).set(a_t_data.clone());
|
||||
let b_t_data = random_vec_rng(K * N, &mut rng);
|
||||
let b_t = cx.tensor::<R2<K, N>>().set(b_t_data.clone());
|
||||
let b_t = cx.tensor((K, N)).set(b_t_data.clone());
|
||||
|
||||
let mut a_b = a.matmul(b.permute()).retrieve();
|
||||
let mut a_b = a.matmul(b.permute((1, 0))).retrieve();
|
||||
let mut a_b_t = a.matmul(b_t).retrieve();
|
||||
let mut a_t_b = a_t
|
||||
.permute::<_, luminal::prelude::Axes2<1, 0>>()
|
||||
.matmul(b.permute())
|
||||
.retrieve();
|
||||
let mut a_t_b_t = a_t
|
||||
.permute::<_, luminal::prelude::Axes2<1, 0>>()
|
||||
.matmul(b_t)
|
||||
.retrieve();
|
||||
let mut a_t_b = a_t.permute((1, 0)).matmul(b.permute((1, 0))).retrieve();
|
||||
let mut a_t_b_t = a_t.permute((1, 0)).matmul(b_t).retrieve();
|
||||
|
||||
cx.compile(
|
||||
<(GenericCompiler, MetalCompiler<f16>)>::default(),
|
||||
@@ -468,12 +447,14 @@ fn test_relu_and_linear() {
|
||||
let input_data = random_vec(32);
|
||||
let w1 = random_vec(32 * 64);
|
||||
let w2 = random_vec(32 * 64);
|
||||
let batch = cx
|
||||
.named_tensor::<R2<2, 32>>("Batch")
|
||||
.set(random_vec(32 * 2));
|
||||
let a = cx.named_tensor::<R1<32>>("Single").set(input_data.clone());
|
||||
let batch = cx.named_tensor("Batch", (2, 32)).set(random_vec(32 * 2));
|
||||
let a = cx.named_tensor("Single", 32).set(input_data.clone());
|
||||
|
||||
let model: (Linear<32, 64>, ReLU, Linear<64, 32>) = InitModule::initialize(&mut cx);
|
||||
let model = (
|
||||
Linear::new(32, 64, false, &mut cx),
|
||||
ReLU,
|
||||
Linear::new(64, 32, false, &mut cx),
|
||||
);
|
||||
model.0.weight.set(w1.clone());
|
||||
model.2.weight.set(w2.clone());
|
||||
let mut b = model.forward(a).retrieve();
|
||||
@@ -524,9 +505,9 @@ fn test_rms_norm() {
|
||||
let inp_data = random_vec_rng(15 * 32, &mut rng);
|
||||
let weight_data = random_vec_rng(32, &mut rng);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R2<15, 32>>().set(inp_data.clone());
|
||||
let a = cx.tensor((15, 32)).set(inp_data.clone());
|
||||
|
||||
let model = LayerNorm::<32>::new(true, false, false, 1e-5, &mut cx);
|
||||
let model = LayerNorm::new(32, true, false, false, 1e-5, &mut cx);
|
||||
model.weight.unwrap().set(weight_data.clone());
|
||||
let mut b = model.forward(a).retrieve();
|
||||
|
||||
@@ -553,9 +534,9 @@ fn test_rms_norm() {
|
||||
fn test_layer_norm() {
|
||||
let mut cx = Graph::new();
|
||||
let a_data = random_vec(15 * 16 * 32);
|
||||
let a = cx.tensor::<R3<15, 16, 32>>().set(a_data.clone());
|
||||
let mut b = a.layer_norm::<LAxis<0>, _>(1e-5).retrieve();
|
||||
let mut c = a.layer_norm::<LAxis<2>, _>(1e-5).retrieve();
|
||||
let a = cx.tensor((15, 16, 32)).set(a_data.clone());
|
||||
let mut b = a.layer_norm(0, 1e-5).retrieve();
|
||||
let mut c = a.layer_norm(2, 1e-5).retrieve();
|
||||
cx.compile(
|
||||
<(GenericCompiler, MetalCompiler<f16>)>::default(),
|
||||
(&mut b, &mut c),
|
||||
@@ -574,82 +555,110 @@ fn test_layer_norm() {
|
||||
#[test]
|
||||
fn test_transformer_encoder_block() {
|
||||
let mut cx = Graph::new();
|
||||
let model: luminal_nn::TransformerEncoderBlock<32, 64, 1> = InitModule::initialize(&mut cx);
|
||||
let w_k_weight = random_vec(32 * 32);
|
||||
model.attention.w_k.weight.set(w_k_weight.clone());
|
||||
let w_q_weight = random_vec(32 * 32);
|
||||
model.attention.w_q.weight.set(w_q_weight.clone());
|
||||
let w_v_weight = random_vec(32 * 32);
|
||||
model.attention.w_v.weight.set(w_v_weight.clone());
|
||||
let w_o_weight = random_vec(32 * 32);
|
||||
model.attention.w_o.weight.set(w_o_weight.clone());
|
||||
let ff_0_weight = random_vec(32 * 64);
|
||||
model.ff.0.weight.set(ff_0_weight.clone());
|
||||
let ff_1_weight = random_vec(64 * 32);
|
||||
model.ff.2.weight.set(ff_1_weight.clone());
|
||||
let model = luminal_nn::TransformerEncoderBlock::new(3, 4, 1, &mut cx);
|
||||
model
|
||||
.attention
|
||||
.w_k
|
||||
.weight
|
||||
.set(vec![1., 22., 3., 1., 2., 3., 1., 2., 3.]);
|
||||
model
|
||||
.attention
|
||||
.w_q
|
||||
.weight
|
||||
.set(vec![3., 2., 3., 1.3, 2., 3., 3., 2., 3.]);
|
||||
model
|
||||
.attention
|
||||
.w_v
|
||||
.weight
|
||||
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3.]);
|
||||
model
|
||||
.attention
|
||||
.w_o
|
||||
.weight
|
||||
.set(vec![1., 22., 3., 1., 2., 3., 1., 2., 3.]);
|
||||
model
|
||||
.ff
|
||||
.0
|
||||
.weight
|
||||
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 11., 2., 3.]);
|
||||
model
|
||||
.ff
|
||||
.2
|
||||
.weight
|
||||
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.]);
|
||||
|
||||
let a_data = random_vec(2 * 32);
|
||||
let a = cx
|
||||
.tensor::<(Dyn<'b'>, Dyn<'a'>, luminal::prelude::Const<32>)>()
|
||||
.set_dyn(a_data.clone(), &[1, 2, 3])
|
||||
.keep();
|
||||
cx.keep_tensors(param_dict(&model));
|
||||
.tensor(('a', 3))
|
||||
.set_dyn(vec![-1., 2., 3., 3., 3., -1.], (2, 3));
|
||||
let mut b = model.forward(a).retrieve();
|
||||
cx.execute();
|
||||
let unopt_b = b.data();
|
||||
b.drop();
|
||||
|
||||
cx.compile(<(GenericCompiler, MetalCompiler<f16>)>::default(), &mut b);
|
||||
cx.execute();
|
||||
assert_close_precision(&unopt_b, &b.data(), 1e-2);
|
||||
|
||||
let d_dev = Cpu::default();
|
||||
let mut d_model: dfdx::nn::modules::TransformerEncoderBlock<32, 1, 64, f32, Cpu> =
|
||||
d_dev
|
||||
.build_module::<dfdx::nn::modules::builders::TransformerEncoderBlock<32, 1, 64>, f32>();
|
||||
d_model.self_attn.w_k.bias.copy_from(&[0.; 32]);
|
||||
d_model.self_attn.w_v.bias.copy_from(&[0.; 32]);
|
||||
d_model.self_attn.w_q.bias.copy_from(&[0.; 32]);
|
||||
d_model.self_attn.w_o.bias.copy_from(&[0.; 32]);
|
||||
let mut d_model: dfdx::nn::modules::TransformerEncoderBlock<3, 1, 4, f32, Cpu> =
|
||||
d_dev.build_module::<dfdx::nn::modules::builders::TransformerEncoderBlock<3, 1, 4>, f32>();
|
||||
d_model.self_attn.w_k.bias.copy_from(&[0.0, 0.0, 0.0]);
|
||||
d_model.self_attn.w_v.bias.copy_from(&[0.0, 0.0, 0.0]);
|
||||
d_model.self_attn.w_q.bias.copy_from(&[0.0, 0.0, 0.0]);
|
||||
d_model.self_attn.w_o.bias.copy_from(&[0., 0., 0.]);
|
||||
d_model.self_attn.w_o.weight = d_dev
|
||||
.tensor_from_vec(w_o_weight, (DConst::<32>, DConst::<32>))
|
||||
.tensor_from_vec(
|
||||
vec![1., 22., 3., 1., 2., 3., 1., 2., 3.],
|
||||
(DConst::<3>, DConst::<3>),
|
||||
)
|
||||
.permute();
|
||||
d_model.self_attn.w_k.weight = d_dev
|
||||
.tensor_from_vec(w_k_weight, (DConst::<32>, DConst::<32>))
|
||||
.tensor_from_vec(
|
||||
vec![1., 22., 3., 1., 2., 3., 1., 2., 3.],
|
||||
(DConst::<3>, DConst::<3>),
|
||||
)
|
||||
.permute();
|
||||
d_model.self_attn.w_q.weight = d_dev
|
||||
.tensor_from_vec(w_q_weight, (DConst::<32>, DConst::<32>))
|
||||
.tensor_from_vec(
|
||||
vec![3., 2., 3., 1.3, 2., 3., 3., 2., 3.],
|
||||
(DConst::<3>, DConst::<3>),
|
||||
)
|
||||
.permute();
|
||||
d_model.self_attn.w_v.weight = d_dev
|
||||
.tensor_from_vec(w_v_weight, (DConst::<32>, DConst::<32>))
|
||||
.tensor_from_vec(
|
||||
vec![-1., 12., 3., -1., 2., -3., 11., 2., 3.],
|
||||
(DConst::<3>, DConst::<3>),
|
||||
)
|
||||
.permute();
|
||||
d_model.ff.0 .0.weight = d_dev
|
||||
.tensor_from_vec(ff_0_weight, (DConst::<32>, DConst::<64>))
|
||||
.tensor_from_vec(
|
||||
vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 11., 2., 3.],
|
||||
(DConst::<3>, DConst::<4>),
|
||||
)
|
||||
.permute();
|
||||
d_model.ff.0 .0.bias = d_dev.tensor_from_vec(vec![0.; 64], (DConst::<64>,));
|
||||
d_model.ff.0 .0.bias = d_dev.tensor_from_vec(vec![0., 0., 0., 0.], (DConst::<4>,));
|
||||
d_model.ff.0 .2.weight = d_dev
|
||||
.tensor_from_vec(ff_1_weight, (DConst::<64>, DConst::<32>))
|
||||
.tensor_from_vec(
|
||||
vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.],
|
||||
(DConst::<4>, DConst::<3>),
|
||||
)
|
||||
.permute();
|
||||
d_model.ff.0 .2.bias = d_dev.tensor_from_vec(vec![0.; 32], (DConst::<32>,));
|
||||
d_model.norm1.gamma = d_dev.tensor_from_vec(vec![1.; 32], (DConst::<32>,));
|
||||
d_model.norm2.gamma = d_dev.tensor_from_vec(vec![1.; 32], (DConst::<32>,));
|
||||
d_model.ff.0 .2.bias = d_dev.tensor_from_vec(vec![0., 0., 0.], (DConst::<3>,));
|
||||
d_model.norm1.gamma = d_dev.tensor_from_vec(vec![1., 1., 1.], (DConst::<3>,));
|
||||
d_model.norm2.gamma = d_dev.tensor_from_vec(vec![1., 1., 1.], (DConst::<3>,));
|
||||
d_model.norm1.epsilon = 1e-5;
|
||||
d_model.norm2.beta = d_dev.tensor_from_vec(vec![0.; 32], (DConst::<32>,));
|
||||
d_model.norm1.beta = d_dev.tensor_from_vec(vec![0.; 32], (DConst::<32>,));
|
||||
d_model.norm2.beta = d_dev.tensor_from_vec(vec![0., 0., 0.], (DConst::<3>,));
|
||||
d_model.norm1.beta = d_dev.tensor_from_vec(vec![0., 0., 0.], (DConst::<3>,));
|
||||
d_model.norm2.epsilon = 1e-5;
|
||||
let d_a = d_dev.tensor_from_vec(a_data, (DConst::<2>, DConst::<32>));
|
||||
let d_a = d_dev.tensor_from_vec(vec![-1., 2., 3., 3., 3., -1.], (DConst::<2>, DConst::<3>));
|
||||
let d_b = d_model.forward(d_a);
|
||||
|
||||
assert_close_precision(&b.data(), &d_b.as_vec(), 1e-2);
|
||||
assert_close(&b.data(), &d_b.as_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_common_buffer() {
|
||||
let data = random_vec(32);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<32>>();
|
||||
let a = cx.tensor(32);
|
||||
a.set(data.clone());
|
||||
let a1 = cx.tensor::<R1<32>>();
|
||||
let a1 = cx.tensor(32);
|
||||
a1.set(data.clone());
|
||||
let exped = a * a1;
|
||||
let mut b = exped.log2().retrieve();
|
||||
@@ -663,15 +672,12 @@ fn test_common_buffer() {
|
||||
fn test_embedding() {
|
||||
let mut cx = Graph::new();
|
||||
let batch = cx
|
||||
.named_tensor::<R2<2, 3>>("Batch")
|
||||
.named_tensor("Batch", (2, 3))
|
||||
.set(vec![1.0, 0.0, 2.0, 1.0, 0.0, 1.0])
|
||||
.keep();
|
||||
let a = cx
|
||||
.named_tensor::<R1<3>>("Single")
|
||||
.set(vec![1.0, 0.0, 1.0])
|
||||
.keep();
|
||||
let a = cx.named_tensor("Single", 3).set(vec![1.0, 0.0, 1.0]).keep();
|
||||
|
||||
let model: luminal_nn::Embedding<3, 4> = InitModule::initialize(&mut cx);
|
||||
let model = luminal_nn::Embedding::new(3, 4, &mut cx);
|
||||
model
|
||||
.weight
|
||||
.set(vec![1.1, 2., 3., 1., 2., 3., 14., 2., 33., 1., 2., 3.]);
|
||||
@@ -702,12 +708,8 @@ fn test_embedding() {
|
||||
fn test_slice() {
|
||||
let data = random_vec(256);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<256>>().set(data.clone());
|
||||
let mut c: GraphTensor<R1<20>> = a
|
||||
.slice((..Expression::from(20),))
|
||||
.realize()
|
||||
.contiguous()
|
||||
.retrieve();
|
||||
let a = cx.tensor(256).set(data.clone());
|
||||
let mut c = a.slice(..Expression::from(20)).contiguous().retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f16>::default(), &mut c);
|
||||
cx.execute();
|
||||
@@ -726,8 +728,8 @@ fn test_pad() {
|
||||
// Pad a 8x2 mat to 10x4
|
||||
let data = random_vec(8 * 2);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R2<8, 2>>().set(data.clone());
|
||||
let mut c = a.pad::<R2<10, 4>>(((0, 2), (0, 2))).contiguous().retrieve();
|
||||
let a = cx.tensor((8, 2)).set(data.clone());
|
||||
let mut c = a.pad(((0, 2), (0, 2))).contiguous().retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f16>::default(), &mut c);
|
||||
cx.execute();
|
||||
@@ -748,16 +750,12 @@ fn test_pad_contig() {
|
||||
let mut cx = Graph::new();
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let a_data = random_vec_rng(m * k, &mut rng);
|
||||
let mut a = cx
|
||||
.tensor::<(Dyn<'M'>, Dyn<'K'>)>()
|
||||
.set_dyn(a_data, &[m, k])
|
||||
.retrieve();
|
||||
let mut b: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> = a
|
||||
let mut a = cx.tensor(('M', 'K')).set_dyn(a_data, (m, k)).retrieve();
|
||||
let mut b = a
|
||||
.pad(&[(0, 0.into()), (0, Expression::from(24) - 'K')])
|
||||
.contiguous()
|
||||
.retrieve();
|
||||
let mut c: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> =
|
||||
(a.slice((.., ..Expression::from(k))).realize() / 1.0).retrieve();
|
||||
let mut c = (a.slice((.., ..Expression::from(k))) / 1.0).retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f16>::default(), (&mut a, &mut b, &mut c));
|
||||
cx.execute();
|
||||
@@ -771,13 +769,9 @@ fn test_pad_contig() {
|
||||
fn test_movement() {
|
||||
let data = random_vec(32);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<32>>().set(data.clone());
|
||||
let b: GraphTensor<R1<42>> = a.pad((0, 10)).contiguous().retrieve();
|
||||
let mut c: GraphTensor<R1<25>> = b
|
||||
.slice((..Expression::from(25),))
|
||||
.realize()
|
||||
.contiguous()
|
||||
.retrieve();
|
||||
let a = cx.tensor(32).set(data.clone());
|
||||
let b = a.pad((0, 10)).contiguous().retrieve();
|
||||
let mut c = b.slice((..Expression::from(25),)).contiguous().retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f16>::default(), &mut c);
|
||||
cx.execute();
|
||||
@@ -794,10 +788,9 @@ fn test_movement() {
|
||||
#[test]
|
||||
fn test_slice_add() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor().set(random_array::<256>());
|
||||
let a = cx.tensor(256).set(random_array::<256>());
|
||||
let mut b = (a.slice(0..64) + a.slice(64..128) + a.slice(128..192) + a.slice(192..256))
|
||||
.realize::<R1<64>>()
|
||||
.expand::<R2<4, 64>, _>()
|
||||
.expand(0, 4)
|
||||
.retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f16>::default(), &mut b);
|
||||
@@ -814,14 +807,12 @@ fn test_conv2d() {
|
||||
const KERNELY: usize = 2;
|
||||
const STRIDEX: usize = KERNELX;
|
||||
const STRIDEY: usize = KERNELY;
|
||||
const DILATIONX: usize = 0;
|
||||
const DILATIONY: usize = 0;
|
||||
const DILATIONX: usize = 1;
|
||||
const DILATIONY: usize = 1;
|
||||
const DIMX_IN: usize = 16;
|
||||
const DIMX_OUT: usize = ((DIMX_IN - (DILATIONX + 1) * (KERNELX - 1) - 1) / STRIDEX) + 1;
|
||||
const DIMY_IN: usize = 9;
|
||||
const DIMY_OUT: usize = ((DIMY_IN - (DILATIONY + 1) * (KERNELY - 1) - 1) / STRIDEY) + 1;
|
||||
|
||||
let inp1 = cx.tensor::<R3<CH_IN, DIMX_IN, DIMY_IN>>().set(vec![
|
||||
let inp1 = cx.tensor((CH_IN, DIMX_IN, DIMY_IN)).set(vec![
|
||||
8., 8., 5., 7., 0., 6., 5., 3., 0., 7., 0., 6., 6., 7., 7., 5., 0., 6., 9., 4., 0., 8., 8.,
|
||||
5., 7., 6., 2., 8., 9., 5., 0., 3., 1., 1., 8., 4., 1., 1., 5., 6., 9., 3., 2., 9., 4., 7.,
|
||||
1., 0., 7., 7., 4., 9., 5., 0., 4., 7., 4., 7., 8., 8., 4., 8., 4., 7., 9., 3., 7., 9., 5.,
|
||||
@@ -856,7 +847,14 @@ fn test_conv2d() {
|
||||
1., 2., 1., 1., 4., 7., 2.,
|
||||
]);
|
||||
|
||||
let model = luminal_nn::Conv2D::<CH_IN, CH_OUT, KERNELX, KERNELY>::initialize(&mut cx);
|
||||
let model = luminal_nn::Conv2D::new(
|
||||
CH_IN,
|
||||
CH_OUT,
|
||||
(KERNELX, KERNELY),
|
||||
(STRIDEX, STRIDEY),
|
||||
(DILATIONX, DILATIONY),
|
||||
&mut cx,
|
||||
);
|
||||
model.weight.set(vec![
|
||||
0.1600, 0.2000, 0.1900, -0.1100, 0.0100, -0.0300, -0.1200, -0.0800, -0.1300, -0.0300,
|
||||
0.1600, -0.1700, -0.0000, 0.1900, 0.1300, 0.0300, -0.1500, 0.0900, 0.0100, 0.0200, 0.1500,
|
||||
@@ -864,9 +862,7 @@ fn test_conv2d() {
|
||||
0.0100, -0.2000, 0.2100, -0.0400, -0.1400, 0.1500, 0.0500, -0.1700, 0.1400,
|
||||
]);
|
||||
|
||||
let mut out1 = model
|
||||
.forward::<DIMX_IN, DIMY_IN, DIMX_OUT, DIMY_OUT>(inp1)
|
||||
.retrieve();
|
||||
let mut out1 = model.forward(inp1).retrieve();
|
||||
|
||||
cx.compile(
|
||||
<(GenericCompiler, MetalCompiler<f16>)>::default(),
|
||||
@@ -899,18 +895,21 @@ fn test_conv1d_pad_stride() {
|
||||
const KERNEL: usize = 3;
|
||||
const STRIDE: usize = 1;
|
||||
const PADDING: usize = 1;
|
||||
const DILATION: usize = 1;
|
||||
const DIM_IN: usize = 10;
|
||||
let kernel_data = random_vec_rng(KERNEL * CH_IN * CH_OUT, &mut rng);
|
||||
let input_data = random_vec_rng(CH_IN * DIM_IN, &mut rng);
|
||||
|
||||
let model = Conv1D::<CH_IN, CH_OUT, KERNEL, STRIDE, 0, PADDING>::initialize(&mut cx);
|
||||
let model = Conv1D::new(
|
||||
CH_IN, CH_OUT, KERNEL, STRIDE, DILATION, PADDING, false, &mut cx,
|
||||
);
|
||||
model.weight.set(kernel_data.clone());
|
||||
|
||||
let inp1 = cx
|
||||
.tensor::<(LConst<1>, LConst<CH_IN>, Dyn<'s'>)>()
|
||||
.set_dyn(input_data.clone(), &[1, CH_IN, DIM_IN]);
|
||||
.tensor((CH_IN, 's'))
|
||||
.set_dyn(input_data.clone(), (CH_IN, DIM_IN));
|
||||
|
||||
let mut out1 = model.forward((inp1, PhantomData::<Dyn<'s'>>)).retrieve();
|
||||
let mut out1 = model.forward(inp1).retrieve();
|
||||
cx.compile(crate::MetalCompiler::<f16>::default(), &mut out1);
|
||||
cx.execute();
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use dfdx::prelude::{Module as DfdxModule, *};
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
|
||||
@@ -17,13 +15,13 @@ unary_test!(|a| a.ln(), |a| a.ln(), test_ln, f32);
|
||||
unary_test!(|a| a.log2(), |a| a.ln() / 2_f32.ln(), test_log2, f32);
|
||||
unary_test!(|a| a.exp2(), |a| (a * 2_f32.ln()).exp(), test_exp2, f32);
|
||||
unary_test!(
|
||||
|a| a.softmax::<LAxis<0>>(),
|
||||
|a| a.softmax(0),
|
||||
|a| a.softmax::<DAxis<0>>(),
|
||||
test_softmax,
|
||||
f32
|
||||
);
|
||||
unary_test!(
|
||||
|a| a.mean_norm::<LAxis<0>>().std_norm::<LAxis<0>, _>(1e-5),
|
||||
|a| a.mean_norm(0).std_norm(0, 1e-5),
|
||||
|a| a.normalize::<DAxis<0>>(1e-5),
|
||||
test_norm,
|
||||
f32
|
||||
@@ -46,8 +44,8 @@ binary_test!(|a, b| a.max(b), |a, b| a.maximum(b), test_max, f32);
|
||||
fn test_contiguous() {
|
||||
let mut cx = Graph::new();
|
||||
let data = random_vec(12);
|
||||
let a = cx.tensor::<R2<3, 4>>().set(data.clone());
|
||||
let mut b = a.permute::<R2<4, 3>, _>().reshape::<R2<12, 1>>().retrieve();
|
||||
let a = cx.tensor((3, 4)).set(data.clone());
|
||||
let mut b = a.permute((1, 0)).reshape((12, 1)).retrieve();
|
||||
cx.compile(MetalCompiler::<f32>::default(), &mut b);
|
||||
cx.execute();
|
||||
|
||||
@@ -64,11 +62,11 @@ fn test_contiguous() {
|
||||
fn test_sum_reduce() {
|
||||
let mut cx = Graph::new();
|
||||
let data = random_vec(4 * 4096);
|
||||
let a = cx.tensor::<R3<1, 4, 4096>>();
|
||||
let a = cx.tensor((1, 4, 4096));
|
||||
a.set(data.clone());
|
||||
let mut b = a.sum_reduce::<_, LAxis<1>>().retrieve();
|
||||
let mut c = a.sum_reduce::<_, LAxis<0>>().retrieve();
|
||||
let mut d = a.sum_reduce::<_, LAxis<2>>().retrieve();
|
||||
let mut b = a.sum_reduce(1).retrieve();
|
||||
let mut c = a.sum_reduce(0).retrieve();
|
||||
let mut d = a.sum_reduce(2).retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f32>::default(), (&mut b, &mut c, &mut d));
|
||||
cx.execute();
|
||||
@@ -88,11 +86,11 @@ fn test_sum_reduce() {
|
||||
fn test_max_reduce() {
|
||||
let mut cx = Graph::new();
|
||||
let data = random_vec(12);
|
||||
let a = cx.tensor::<R3<2, 2, 3>>();
|
||||
let a = cx.tensor((2, 2, 3));
|
||||
a.set(data.clone());
|
||||
let mut b = a.max_reduce::<_, LAxis<1>>().retrieve();
|
||||
let mut c = a.max_reduce::<_, LAxis<0>>().retrieve();
|
||||
let mut d = a.max_reduce::<_, LAxis<2>>().retrieve();
|
||||
let mut b = a.max_reduce(1).retrieve();
|
||||
let mut c = a.max_reduce(0).retrieve();
|
||||
let mut d = a.max_reduce(2).retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f32>::default(), (&mut b, &mut c, &mut d));
|
||||
cx.execute();
|
||||
@@ -112,8 +110,8 @@ fn test_max_reduce() {
|
||||
fn test_mean_reduce() {
|
||||
let data = random_vec(40960);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R3<1, 10, 4096>>().set(data.clone());
|
||||
let mut b = a.mean_reduce::<_, LAxis<2>>().retrieve();
|
||||
let a = cx.tensor((1, 10, 4096)).set(data.clone());
|
||||
let mut b = a.mean_reduce(2).retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f32>::default(), &mut b);
|
||||
cx.execute();
|
||||
@@ -129,8 +127,8 @@ fn test_matmul_simple() {
|
||||
let mut cx = Graph::new();
|
||||
let a_data = random_vec(256 * 256);
|
||||
let b_data = random_vec(256 * 256);
|
||||
let a = cx.tensor::<R2<256, 256>>().set(a_data.clone());
|
||||
let b = cx.tensor::<R2<256, 256>>().set(b_data.clone());
|
||||
let a = cx.tensor((256, 256)).set(a_data.clone());
|
||||
let b = cx.tensor((256, 256)).set(b_data.clone());
|
||||
let mut c = a.matmul(b).retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f32>::default(), &mut c);
|
||||
@@ -149,8 +147,8 @@ fn test_matmul() {
|
||||
let mut cx = Graph::new();
|
||||
let a_data = random_vec(512 * 512);
|
||||
let b_data = random_vec(512 * 512);
|
||||
let a = cx.tensor::<R2<512, 512>>().set(a_data.clone());
|
||||
let b = cx.tensor::<R2<512, 512>>().set(b_data.clone());
|
||||
let a = cx.tensor((512, 512)).set(a_data.clone());
|
||||
let b = cx.tensor((512, 512)).set(b_data.clone());
|
||||
let mut c = a.matmul(b).retrieve();
|
||||
|
||||
cx.compile(MetalCompiler::<f32>::default(), &mut c);
|
||||
@@ -168,10 +166,10 @@ fn test_matmul() {
|
||||
fn test_batch_matmul() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx
|
||||
.tensor::<R3<2, 2, 3>>()
|
||||
.tensor((2, 2, 3))
|
||||
.set(vec![1., 2., 3., 1., 2., 1., 1., 2., 3., 1., 2., 1.]);
|
||||
let b = cx
|
||||
.tensor::<R2<3, 4>>()
|
||||
.tensor((3, 4))
|
||||
.set(vec![1., 2., 3., 1., 1., 2., 1., 2., -1., -2., 1., 2.]);
|
||||
let mut c = a.matmul(b).retrieve();
|
||||
|
||||
@@ -195,24 +193,18 @@ fn test_matmul_transpose() {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
|
||||
let a_data = random_vec_rng(M * K, &mut rng);
|
||||
let a = cx.tensor::<R2<M, K>>().set(a_data.clone());
|
||||
let a = cx.tensor((M, K)).set(a_data.clone());
|
||||
let b_data = random_vec_rng(K * N, &mut rng);
|
||||
let b = cx.tensor::<R2<N, K>>().set(b_data.clone());
|
||||
let b = cx.tensor((N, K)).set(b_data.clone());
|
||||
let a_t_data = random_vec_rng(K * M, &mut rng);
|
||||
let a_t = cx.tensor::<R2<K, M>>().set(a_t_data.clone());
|
||||
let a_t = cx.tensor((K, M)).set(a_t_data.clone());
|
||||
let b_t_data = random_vec_rng(K * N, &mut rng);
|
||||
let b_t = cx.tensor::<R2<K, N>>().set(b_t_data.clone());
|
||||
let b_t = cx.tensor((K, N)).set(b_t_data.clone());
|
||||
|
||||
let mut a_b = a.matmul(b.permute()).retrieve();
|
||||
let mut a_b = a.matmul(b.permute((1, 0))).retrieve();
|
||||
let mut a_b_t = a.matmul(b_t).retrieve();
|
||||
let mut a_t_b = a_t
|
||||
.permute::<_, luminal::prelude::Axes2<1, 0>>()
|
||||
.matmul(b.permute())
|
||||
.retrieve();
|
||||
let mut a_t_b_t = a_t
|
||||
.permute::<_, luminal::prelude::Axes2<1, 0>>()
|
||||
.matmul(b_t)
|
||||
.retrieve();
|
||||
let mut a_t_b = a_t.permute((1, 0)).matmul(b.permute((1, 0))).retrieve();
|
||||
let mut a_t_b_t = a_t.permute((1, 0)).matmul(b_t).retrieve();
|
||||
|
||||
cx.compile(
|
||||
MetalCompiler::<f32>::default(),
|
||||
@@ -248,12 +240,14 @@ fn test_relu_and_linear() {
|
||||
let input_data = random_vec(32);
|
||||
let w1 = random_vec(32 * 64);
|
||||
let w2 = random_vec(32 * 64);
|
||||
let batch = cx
|
||||
.named_tensor::<R2<2, 32>>("Batch")
|
||||
.set(random_vec(32 * 2));
|
||||
let a = cx.named_tensor::<R1<32>>("Single").set(input_data.clone());
|
||||
let batch = cx.named_tensor("Batch", (2, 32)).set(random_vec(32 * 2));
|
||||
let a = cx.named_tensor("Single", 32).set(input_data.clone());
|
||||
|
||||
let model: (Linear<32, 64>, ReLU, Linear<64, 32>) = InitModule::initialize(&mut cx);
|
||||
let model = (
|
||||
Linear::new(32, 64, false, &mut cx),
|
||||
ReLU,
|
||||
Linear::new(64, 32, false, &mut cx),
|
||||
);
|
||||
model.0.weight.set(w1.clone());
|
||||
model.2.weight.set(w2.clone());
|
||||
let mut b = model.forward(a).retrieve();
|
||||
@@ -296,7 +290,7 @@ fn test_relu_and_linear() {
|
||||
#[test]
|
||||
fn test_transformer_encoder_block() {
|
||||
let mut cx = Graph::new();
|
||||
let model: luminal_nn::TransformerEncoderBlock<3, 4, 1> = InitModule::initialize(&mut cx);
|
||||
let model = luminal_nn::TransformerEncoderBlock::new(3, 4, 1, &mut cx);
|
||||
model
|
||||
.attention
|
||||
.w_k
|
||||
@@ -329,8 +323,8 @@ fn test_transformer_encoder_block() {
|
||||
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.]);
|
||||
|
||||
let a = cx
|
||||
.tensor::<(Dyn<'b'>, Dyn<'a'>, luminal::prelude::Const<3>)>()
|
||||
.set_dyn(vec![-1., 2., 3., 3., 3., -1.], &[1, 2, 3]);
|
||||
.tensor(('a', 3))
|
||||
.set_dyn(vec![-1., 2., 3., 3., 3., -1.], (2, 3));
|
||||
let mut b = model.forward(a).retrieve();
|
||||
|
||||
cx.compile(<(GenericCompiler, MetalCompiler<f32>)>::default(), &mut b);
|
||||
@@ -397,11 +391,11 @@ fn test_transformer_encoder_block() {
|
||||
fn test_pool_1d_dims() {
|
||||
let mut cx = Graph::new();
|
||||
|
||||
let inp1 = cx.tensor::<R2<4, 4>>().set(vec![
|
||||
let inp1 = cx.tensor((4, 4)).set(vec![
|
||||
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16.,
|
||||
]);
|
||||
// Stride 1
|
||||
let out1 = inp1.pool_last_dim::<R3<4, 2, 3>>(3, 1, 0).retrieve();
|
||||
let out1 = inp1.pool_last_dim(3, 1, 1).retrieve();
|
||||
|
||||
cx.execute();
|
||||
|
||||
@@ -418,21 +412,21 @@ fn test_pool_1d_dims() {
|
||||
fn test_pool_2d() {
|
||||
let mut cx = Graph::new();
|
||||
|
||||
let inp1 = cx.tensor::<R2<4, 4>>().set(vec![
|
||||
let inp1 = cx.tensor((4, 4)).set(vec![
|
||||
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16.,
|
||||
]);
|
||||
// 3x3 kernel
|
||||
let out1 = inp1
|
||||
// Pool first dim first by moving it to end
|
||||
.permute::<_, LAxes2<1, 0>>()
|
||||
.pool_last_dim::<R3<4, 2, 3>>(3, 1, 0)
|
||||
.permute((1, 0))
|
||||
.pool_last_dim(3, 1, 1)
|
||||
// Now move other dim to end
|
||||
.permute::<_, LAxes3<1, 2, 0>>()
|
||||
.pool_last_dim::<R4<2, 3, 2, 3>>(3, 1, 0)
|
||||
.permute((1, 2, 0))
|
||||
.pool_last_dim(3, 1, 1)
|
||||
// Now swap middle two dims
|
||||
.permute::<_, LAxes4<0, 2, 1, 3>>()
|
||||
.permute((0, 2, 1, 3))
|
||||
// Now merge both pooled dimensions
|
||||
.reshape::<R3<4, 3, 3>>()
|
||||
.reshape((4, 3, 3))
|
||||
.retrieve();
|
||||
|
||||
cx.execute();
|
||||
@@ -451,13 +445,13 @@ fn test_pool_2d() {
|
||||
fn test_pool_1d_dilation() {
|
||||
let mut cx = Graph::new();
|
||||
|
||||
let inp1 = cx.tensor::<R1<5>>().set(vec![1., 2., 3., 4., 5.]);
|
||||
let inp1 = cx.tensor(5).set(vec![1., 2., 3., 4., 5.]);
|
||||
// Stride 1
|
||||
let out1 = inp1.pool_last_dim::<R2<3, 2>>(2, 1, 1).retrieve();
|
||||
let out1 = inp1.pool_last_dim(2, 1, 2).retrieve();
|
||||
// Stride 2
|
||||
let out2 = inp1.pool_last_dim::<R2<2, 2>>(2, 2, 1).retrieve();
|
||||
let out2 = inp1.pool_last_dim(2, 2, 2).retrieve();
|
||||
// Stride 3
|
||||
let out3 = inp1.pool_last_dim::<R2<1, 2>>(2, 3, 1).retrieve();
|
||||
let out3 = inp1.pool_last_dim(2, 3, 2).retrieve();
|
||||
|
||||
cx.execute();
|
||||
|
||||
@@ -476,14 +470,12 @@ fn test_conv2d() {
|
||||
const KERNELY: usize = 2;
|
||||
const STRIDEX: usize = KERNELX;
|
||||
const STRIDEY: usize = KERNELY;
|
||||
const DILATIONX: usize = 0;
|
||||
const DILATIONY: usize = 0;
|
||||
const DILATIONX: usize = 1;
|
||||
const DILATIONY: usize = 1;
|
||||
const DIMX_IN: usize = 16;
|
||||
const DIMX_OUT: usize = ((DIMX_IN - (DILATIONX + 1) * (KERNELX - 1) - 1) / STRIDEX) + 1;
|
||||
const DIMY_IN: usize = 9;
|
||||
const DIMY_OUT: usize = ((DIMY_IN - (DILATIONY + 1) * (KERNELY - 1) - 1) / STRIDEY) + 1;
|
||||
|
||||
let inp1 = cx.tensor::<R3<CH_IN, DIMX_IN, DIMY_IN>>().set(vec![
|
||||
let inp1 = cx.tensor((CH_IN, DIMX_IN, DIMY_IN)).set(vec![
|
||||
8., 8., 5., 7., 0., 6., 5., 3., 0., 7., 0., 6., 6., 7., 7., 5., 0., 6., 9., 4., 0., 8., 8.,
|
||||
5., 7., 6., 2., 8., 9., 5., 0., 3., 1., 1., 8., 4., 1., 1., 5., 6., 9., 3., 2., 9., 4., 7.,
|
||||
1., 0., 7., 7., 4., 9., 5., 0., 4., 7., 4., 7., 8., 8., 4., 8., 4., 7., 9., 3., 7., 9., 5.,
|
||||
@@ -518,7 +510,14 @@ fn test_conv2d() {
|
||||
1., 2., 1., 1., 4., 7., 2.,
|
||||
]);
|
||||
|
||||
let model = luminal_nn::Conv2D::<CH_IN, CH_OUT, KERNELX, KERNELY>::initialize(&mut cx);
|
||||
let model = luminal_nn::Conv2D::new(
|
||||
CH_IN,
|
||||
CH_OUT,
|
||||
(KERNELX, KERNELY),
|
||||
(STRIDEX, STRIDEY),
|
||||
(DILATIONX, DILATIONY),
|
||||
&mut cx,
|
||||
);
|
||||
model.weight.set(vec![
|
||||
0.1600, 0.2000, 0.1900, -0.1100, 0.0100, -0.0300, -0.1200, -0.0800, -0.1300, -0.0300,
|
||||
0.1600, -0.1700, -0.0000, 0.1900, 0.1300, 0.0300, -0.1500, 0.0900, 0.0100, 0.0200, 0.1500,
|
||||
@@ -526,9 +525,7 @@ fn test_conv2d() {
|
||||
0.0100, -0.2000, 0.2100, -0.0400, -0.1400, 0.1500, 0.0500, -0.1700, 0.1400,
|
||||
]);
|
||||
|
||||
let mut out1 = model
|
||||
.forward::<DIMX_IN, DIMY_IN, DIMX_OUT, DIMY_OUT>(inp1)
|
||||
.retrieve();
|
||||
let mut out1 = model.forward(inp1).retrieve();
|
||||
|
||||
cx.compile(
|
||||
<(GenericCompiler, MetalCompiler<f32>)>::default(),
|
||||
@@ -564,14 +561,14 @@ fn test_conv1d_pad_stride() {
|
||||
let kernel_data = random_vec_rng(KERNEL * CH_IN * CH_OUT, &mut rng);
|
||||
let input_data = random_vec_rng(CH_IN * DIM_IN, &mut rng);
|
||||
|
||||
let model = Conv1D::<CH_IN, CH_OUT, KERNEL, STRIDE, 0, PADDING>::initialize(&mut cx);
|
||||
let model = Conv1D::new(CH_IN, CH_OUT, KERNEL, STRIDE, 1, PADDING, false, &mut cx);
|
||||
model.weight.set(kernel_data.clone());
|
||||
|
||||
let inp1 = cx
|
||||
.tensor::<(LConst<1>, LConst<CH_IN>, Dyn<'s'>)>()
|
||||
.set_dyn(input_data.clone(), &[1, CH_IN, DIM_IN]);
|
||||
.tensor((1, CH_IN, 's'))
|
||||
.set_dyn(input_data.clone(), (1, CH_IN, DIM_IN));
|
||||
|
||||
let mut out1 = model.forward((inp1, PhantomData::<Dyn<'s'>>)).retrieve();
|
||||
let mut out1 = model.forward(inp1).retrieve();
|
||||
cx.compile(crate::MetalCompiler::<f32>::default(), &mut out1);
|
||||
cx.execute();
|
||||
|
||||
|
||||
@@ -12,8 +12,8 @@ macro_rules! single_unary_test {
|
||||
let mut rng = StdRng::seed_from_u64(1);
|
||||
let data = random_vec_rng($size, &mut rng);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<$size>>().set(data.clone());
|
||||
let f: fn(GraphTensor<R1<$size>>) -> GraphTensor<R1<$size>> = $luminal_func;
|
||||
let a = cx.tensor($size).set(data.clone());
|
||||
let f: fn(GraphTensor) -> GraphTensor = $luminal_func;
|
||||
let mut b = f(a).retrieve();
|
||||
cx.compile(MetalCompiler::<$type>::default(), &mut b);
|
||||
cx.execute();
|
||||
@@ -53,9 +53,9 @@ macro_rules! single_binary_test {
|
||||
let a_data = random_vec_rng($size, &mut rng);
|
||||
let b_data = random_vec_rng($size, &mut rng);
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor::<R1<$size>>().set(a_data.clone());
|
||||
let b = cx.tensor::<R1<$size>>().set(b_data.clone());
|
||||
let f: fn(GraphTensor<R1<$size>>, GraphTensor<R1<$size>>) -> GraphTensor<R1<$size>> =
|
||||
let a = cx.tensor($size).set(a_data.clone());
|
||||
let b = cx.tensor($size).set(b_data.clone());
|
||||
let f: fn(GraphTensor, GraphTensor) -> GraphTensor =
|
||||
$luminal_func;
|
||||
let mut c = f(a, b).retrieve();
|
||||
cx.compile(MetalCompiler::<$type>::default(), &mut c);
|
||||
|
||||
@@ -772,8 +772,8 @@ mod tests {
|
||||
#[test]
|
||||
fn test_norms() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor().set([0.; 32]);
|
||||
let mut b = a.layer_norm::<Axis<0>, _>(1e-5).retrieve();
|
||||
let a = cx.tensor(32).set([0.; 32]);
|
||||
let mut b = a.layer_norm(0, 1e-5).retrieve();
|
||||
|
||||
cx.compile(
|
||||
<(
|
||||
|
||||
@@ -1,85 +1,65 @@
|
||||
use luminal::prelude::*;
|
||||
|
||||
/// Rectified Linear Unit activation function
|
||||
#[derive(Default)]
|
||||
pub struct ReLU;
|
||||
|
||||
impl InitModule for ReLU {
|
||||
fn initialize(_: &mut Graph) -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl SerializeModule for ReLU {
|
||||
fn serialize(&self, _: &mut Serializer) {}
|
||||
}
|
||||
|
||||
impl<S: Shape> Module<GraphTensor<S>> for ReLU {
|
||||
type Output = GraphTensor<S>;
|
||||
impl Module<GraphTensor> for ReLU {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn forward(&self, input: GraphTensor<S>) -> Self::Output {
|
||||
fn forward(&self, input: GraphTensor) -> Self::Output {
|
||||
input.relu()
|
||||
}
|
||||
}
|
||||
|
||||
/// Sigmoid activation function
|
||||
#[derive(Default)]
|
||||
pub struct Sigmoid;
|
||||
|
||||
impl InitModule for Sigmoid {
|
||||
fn initialize(_: &mut Graph) -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl SerializeModule for Sigmoid {
|
||||
fn serialize(&self, _: &mut Serializer) {}
|
||||
}
|
||||
|
||||
impl<S: ConstShape> Module<GraphTensor<S>> for Sigmoid {
|
||||
type Output = GraphTensor<S>;
|
||||
impl Module<GraphTensor> for Sigmoid {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn forward(&self, input: GraphTensor<S>) -> Self::Output {
|
||||
fn forward(&self, input: GraphTensor) -> Self::Output {
|
||||
input.sigmoid()
|
||||
}
|
||||
}
|
||||
|
||||
/// Swish activation function
|
||||
#[derive(Default)]
|
||||
pub struct Swish;
|
||||
|
||||
impl InitModule for Swish {
|
||||
fn initialize(_: &mut Graph) -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl SerializeModule for Swish {
|
||||
fn serialize(&self, _: &mut Serializer) {}
|
||||
}
|
||||
|
||||
impl<S: ConstShape> Module<GraphTensor<S>> for Swish {
|
||||
type Output = GraphTensor<S>;
|
||||
impl Module<GraphTensor> for Swish {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn forward(&self, input: GraphTensor<S>) -> Self::Output {
|
||||
fn forward(&self, input: GraphTensor) -> Self::Output {
|
||||
input.swish()
|
||||
}
|
||||
}
|
||||
|
||||
/// Tanh activation function
|
||||
#[derive(Default)]
|
||||
pub struct Tanh;
|
||||
|
||||
impl InitModule for Tanh {
|
||||
fn initialize(_: &mut Graph) -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl SerializeModule for Tanh {
|
||||
fn serialize(&self, _: &mut Serializer) {}
|
||||
}
|
||||
|
||||
impl<S: ConstShape> Module<GraphTensor<S>> for Tanh {
|
||||
type Output = GraphTensor<S>;
|
||||
impl Module<GraphTensor> for Tanh {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn forward(&self, input: GraphTensor<S>) -> Self::Output {
|
||||
fn forward(&self, input: GraphTensor) -> Self::Output {
|
||||
input.tanh()
|
||||
}
|
||||
}
|
||||
@@ -98,12 +78,14 @@ mod tests {
|
||||
fn test_relu_and_linear() {
|
||||
// Test single and batch, unoptimized and optimized
|
||||
let mut cx = Graph::new();
|
||||
let batch = cx
|
||||
.tensor::<R2<2, 3>>()
|
||||
.set(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
|
||||
let a = cx.tensor::<R1<3>>().set(vec![1.0, 2.0, 3.0]);
|
||||
let batch = cx.tensor((2, 3)).set(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
|
||||
let a = cx.tensor(3).set(vec![1.0, 2.0, 3.0]);
|
||||
|
||||
let model: (Linear<3, 4>, ReLU, Linear<4, 2>) = InitModule::initialize(&mut cx);
|
||||
let model = (
|
||||
Linear::new(3, 4, false, &mut cx),
|
||||
ReLU,
|
||||
Linear::new(4, 2, false, &mut cx),
|
||||
);
|
||||
model
|
||||
.0
|
||||
.weight
|
||||
|
||||
@@ -1,81 +1,45 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use luminal::prelude::*;
|
||||
use rand::{thread_rng, Rng};
|
||||
|
||||
pub struct Conv1D<
|
||||
const CH_IN: usize,
|
||||
const CH_OUT: usize,
|
||||
const KERNEL: usize,
|
||||
const STRIDE: usize = KERNEL,
|
||||
const DILATION: usize = 0,
|
||||
const PADDING: usize = 0,
|
||||
> {
|
||||
pub weight: GraphTensor<R3<CH_OUT, CH_IN, KERNEL>>,
|
||||
pub bias: Option<GraphTensor<R1<CH_OUT>>>,
|
||||
pub struct Conv1D {
|
||||
pub weight: GraphTensor, // ch_out, ch_in * kernel
|
||||
pub bias: Option<GraphTensor>,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
stride: usize,
|
||||
kernel: usize,
|
||||
ch_in: usize,
|
||||
}
|
||||
|
||||
impl<
|
||||
const CH_IN: usize,
|
||||
const CH_OUT: usize,
|
||||
const KERNEL: usize,
|
||||
const STRIDE: usize,
|
||||
const DILATION: usize,
|
||||
const PADDING: usize,
|
||||
> InitModule for Conv1D<CH_IN, CH_OUT, KERNEL, STRIDE, DILATION, PADDING>
|
||||
{
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
// Init weight as uniform(-1, 1)
|
||||
let mut rng = thread_rng();
|
||||
impl Conv1D {
|
||||
/// Create a new 1D convolution layer
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
ch_in: usize,
|
||||
ch_out: usize,
|
||||
kernel: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
padding: usize,
|
||||
bias: bool,
|
||||
cx: &mut Graph,
|
||||
) -> Self {
|
||||
Self {
|
||||
weight: cx.named_tensor("Weight").set(
|
||||
(0..(CH_IN * CH_OUT * KERNEL))
|
||||
.map(|_| rng.gen_range(-1_f32..1_f32))
|
||||
.collect::<Vec<_>>(),
|
||||
),
|
||||
bias: None,
|
||||
weight: cx.named_tensor("Weight", (ch_out, ch_in * kernel)),
|
||||
bias: if bias {
|
||||
Some(cx.named_tensor("Bias", ch_out))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
padding,
|
||||
dilation,
|
||||
stride,
|
||||
kernel,
|
||||
ch_in,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
const CH_IN: usize,
|
||||
const CH_OUT: usize,
|
||||
const KERNEL: usize,
|
||||
const STRIDE: usize,
|
||||
const DILATION: usize,
|
||||
const PADDING: usize,
|
||||
> Conv1D<CH_IN, CH_OUT, KERNEL, STRIDE, DILATION, PADDING>
|
||||
{
|
||||
pub fn initialize_bias(cx: &mut Graph) -> Self {
|
||||
// Init weight as uniform(-1, 1)
|
||||
let mut rng = thread_rng();
|
||||
Self {
|
||||
weight: cx.named_tensor("Weight").set(
|
||||
(0..(CH_IN * CH_OUT * KERNEL))
|
||||
.map(|_| rng.gen_range(-1_f32..1_f32))
|
||||
.collect::<Vec<_>>(),
|
||||
),
|
||||
bias: Some(
|
||||
cx.named_tensor("Bias").set(
|
||||
(0..CH_OUT)
|
||||
.map(|_| rng.gen_range(-1_f32..1_f32))
|
||||
.collect::<Vec<_>>(),
|
||||
),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
const CH_IN: usize,
|
||||
const CH_OUT: usize,
|
||||
const KERNEL: usize,
|
||||
const STRIDE: usize,
|
||||
const DILATION: usize,
|
||||
const PADDING: usize,
|
||||
> SerializeModule for Conv1D<CH_IN, CH_OUT, KERNEL, STRIDE, DILATION, PADDING>
|
||||
{
|
||||
impl SerializeModule for Conv1D {
|
||||
fn serialize(&self, s: &mut luminal::module::Serializer) {
|
||||
s.tensor("weight", self.weight);
|
||||
if let Some(bias) = self.bias {
|
||||
@@ -85,377 +49,193 @@ impl<
|
||||
}
|
||||
|
||||
// Single
|
||||
impl<
|
||||
const CH_IN: usize,
|
||||
const CH_OUT: usize,
|
||||
const KERNEL: usize,
|
||||
const STRIDE: usize,
|
||||
const DILATION: usize,
|
||||
const PADDING: usize,
|
||||
DimIn: Dimension,
|
||||
DimOut: Dimension,
|
||||
> Module<(GraphTensor<(Const<CH_IN>, DimIn)>, PhantomData<DimOut>)>
|
||||
for Conv1D<CH_IN, CH_OUT, KERNEL, STRIDE, DILATION, PADDING>
|
||||
{
|
||||
type Output = GraphTensor<(Const<CH_OUT>, DimOut)>;
|
||||
fn forward(
|
||||
&self,
|
||||
(input, ph): (GraphTensor<(Const<CH_IN>, DimIn)>, PhantomData<DimOut>),
|
||||
) -> Self::Output {
|
||||
<Self as Module<(
|
||||
GraphTensor<(Const<1>, Const<1>, Const<CH_IN>, DimIn)>,
|
||||
PhantomData<DimOut>,
|
||||
)>>::forward(self, (input.expand(), ph))
|
||||
.reshape()
|
||||
}
|
||||
}
|
||||
// Batch 1D
|
||||
impl<
|
||||
const CH_IN: usize,
|
||||
const CH_OUT: usize,
|
||||
const KERNEL: usize,
|
||||
const STRIDE: usize,
|
||||
const DILATION: usize,
|
||||
const PADDING: usize,
|
||||
DimIn: Dimension,
|
||||
DimOut: Dimension,
|
||||
Batch: Dimension,
|
||||
>
|
||||
Module<(
|
||||
GraphTensor<(Batch, Const<CH_IN>, DimIn)>,
|
||||
PhantomData<DimOut>,
|
||||
)> for Conv1D<CH_IN, CH_OUT, KERNEL, STRIDE, DILATION, PADDING>
|
||||
{
|
||||
type Output = GraphTensor<(Batch, Const<CH_OUT>, DimOut)>;
|
||||
fn forward(
|
||||
&self,
|
||||
(input, ph): (
|
||||
GraphTensor<(Batch, Const<CH_IN>, DimIn)>,
|
||||
PhantomData<DimOut>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
<Self as Module<(
|
||||
GraphTensor<(Const<1>, Batch, Const<CH_IN>, DimIn)>,
|
||||
PhantomData<DimOut>,
|
||||
)>>::forward(self, (input.expand(), ph))
|
||||
.reshape()
|
||||
}
|
||||
}
|
||||
// Batch x Batch
|
||||
impl<
|
||||
const CH_IN: usize,
|
||||
const CH_OUT: usize,
|
||||
const KERNEL: usize,
|
||||
const STRIDE: usize,
|
||||
const DILATION: usize,
|
||||
const PADDING: usize,
|
||||
DimIn: Dimension,
|
||||
DimOut: Dimension,
|
||||
Batch1: Dimension,
|
||||
Batch2: Dimension,
|
||||
>
|
||||
Module<(
|
||||
GraphTensor<(Batch1, Batch2, Const<CH_IN>, DimIn)>,
|
||||
PhantomData<DimOut>,
|
||||
)> for Conv1D<CH_IN, CH_OUT, KERNEL, STRIDE, DILATION, PADDING>
|
||||
{
|
||||
type Output = GraphTensor<(Batch1, Batch2, Const<CH_OUT>, DimOut)>;
|
||||
fn forward(
|
||||
&self,
|
||||
(input, _): (
|
||||
GraphTensor<(Batch1, Batch2, Const<CH_IN>, DimIn)>,
|
||||
PhantomData<DimOut>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
let mut o = input
|
||||
impl Module<GraphTensor> for Conv1D {
|
||||
type Output = GraphTensor;
|
||||
fn forward(&self, input: GraphTensor) -> Self::Output {
|
||||
assert_eq!(input.shape()[input.shape.len() - 2], self.ch_in);
|
||||
// Input: batch_dims, ch_in, dim_in
|
||||
// Reshape to 2 batch dims
|
||||
let n_expands = 4 - input.shape.len();
|
||||
let mut inp = input;
|
||||
for _ in 0..n_expands {
|
||||
inp = inp.expand(0, 1);
|
||||
}
|
||||
|
||||
let batch1 = inp.shape()[0].small();
|
||||
let batch2 = inp.shape()[1].small();
|
||||
let dim_in = input.shape().last().unwrap().small();
|
||||
let dim_out = (((dim_in + 2 * self.padding - self.dilation * (self.kernel - 1) - 1)
|
||||
/ self.stride)
|
||||
+ 1)
|
||||
.simplify();
|
||||
let mut out = inp
|
||||
// Add padding
|
||||
.pad::<(Batch1, Batch2, Const<CH_IN>, DimIn)>(((0, 0), (0, 0), (0, 0), (PADDING, 0)))
|
||||
.pad(((0, 0), (0, 0), (0, 0), (self.padding, 0)))
|
||||
.contiguous()
|
||||
.pad::<(Batch1, Batch2, Const<CH_IN>, DimIn)>(((0, 0), (0, 0), (0, 0), (0, PADDING)))
|
||||
.pad(((0, 0), (0, 0), (0, 0), (0, self.padding)))
|
||||
// Pool
|
||||
.pool_last_dim::<(Batch1, Batch2, Const<CH_IN>, DimOut, Const<KERNEL>)>(
|
||||
KERNEL, STRIDE, DILATION,
|
||||
)
|
||||
.pool_last_dim(self.kernel, self.stride, self.dilation)
|
||||
// Combine channel_in and kernel
|
||||
.permute::<_, Axes5<0, 1, 3, 2, 4>>()
|
||||
.dyn_reshape::<(Batch1, Batch2, DimOut, Dyn<'-'>), _>(&[
|
||||
Batch1::size(),
|
||||
Batch2::size(),
|
||||
DimOut::size(),
|
||||
(CH_IN * KERNEL).into(),
|
||||
])
|
||||
.permute((0, 1, 3, 2, 4))
|
||||
.reshape((batch1, batch2, dim_out, self.ch_in * self.kernel))
|
||||
.matmul(
|
||||
self.weight
|
||||
// Combine last two dimensions in kernel
|
||||
.dyn_reshape::<(Const<CH_OUT>, Dyn<'-'>), _>(&[CH_OUT, CH_IN * KERNEL])
|
||||
.permute()
|
||||
.permute((1, 0))
|
||||
// Broadcast along batch dims
|
||||
.expand(),
|
||||
.expand(0, batch1)
|
||||
.expand(1, batch2),
|
||||
)
|
||||
.permute();
|
||||
.permute((0, 1, 3, 2));
|
||||
if let Some(b) = self.bias {
|
||||
o += b.expand();
|
||||
out += b.expand_to(out.shape);
|
||||
}
|
||||
o
|
||||
|
||||
// Reshape back to original shape
|
||||
let mut final_shape = out.shape();
|
||||
for _ in 0..n_expands {
|
||||
final_shape.remove(0);
|
||||
}
|
||||
out.reshape(final_shape) // Output: batch_dims, ch_out, dim_out
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Conv2D<
|
||||
const CH_IN: usize,
|
||||
const CH_OUT: usize,
|
||||
const KERNELX: usize,
|
||||
const KERNELY: usize,
|
||||
const STRIDEX: usize = KERNELX,
|
||||
const STRIDEY: usize = KERNELY,
|
||||
const DILATIONX: usize = 0,
|
||||
const DILATIONY: usize = 0,
|
||||
> {
|
||||
pub weight: GraphTensor<R4<CH_OUT, CH_IN, KERNELX, KERNELY>>,
|
||||
pub struct Conv2D {
|
||||
pub weight: GraphTensor, // ch_out, ch_in * kernel_x * kernel_y
|
||||
kernel: (usize, usize),
|
||||
stride: (usize, usize),
|
||||
dilation: (usize, usize),
|
||||
ch_out: usize,
|
||||
ch_in: usize,
|
||||
}
|
||||
|
||||
impl<
|
||||
const CH_IN: usize,
|
||||
const CH_OUT: usize,
|
||||
const KERNELX: usize,
|
||||
const KERNELY: usize,
|
||||
const STRIDEX: usize,
|
||||
const STRIDEY: usize,
|
||||
const DILATIONX: usize,
|
||||
const DILATIONY: usize,
|
||||
> InitModule
|
||||
for Conv2D<CH_IN, CH_OUT, KERNELX, KERNELY, STRIDEX, STRIDEY, DILATIONX, DILATIONY>
|
||||
{
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
// Init weight as uniform(-1, 1)
|
||||
let mut rng = thread_rng();
|
||||
impl Conv2D {
|
||||
pub fn new(
|
||||
ch_in: usize,
|
||||
ch_out: usize,
|
||||
kernel: (usize, usize),
|
||||
stride: (usize, usize),
|
||||
dilation: (usize, usize),
|
||||
cx: &mut Graph,
|
||||
) -> Self {
|
||||
Self {
|
||||
weight: cx.named_tensor("Weight").set(
|
||||
(0..(CH_IN * CH_OUT * KERNELX * KERNELY))
|
||||
.map(|_| rng.gen_range(-1_f32..1_f32))
|
||||
.collect::<Vec<_>>(),
|
||||
),
|
||||
weight: cx.named_tensor("Weight", (ch_out, ch_in * kernel.0 * kernel.1)),
|
||||
kernel,
|
||||
stride,
|
||||
dilation,
|
||||
ch_out,
|
||||
ch_in,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
const CH_IN: usize,
|
||||
const CH_OUT: usize,
|
||||
const KERNELX: usize,
|
||||
const KERNELY: usize,
|
||||
const STRIDEX: usize,
|
||||
const STRIDEY: usize,
|
||||
const DILATIONX: usize,
|
||||
const DILATIONY: usize,
|
||||
> SerializeModule
|
||||
for Conv2D<CH_IN, CH_OUT, KERNELX, KERNELY, STRIDEX, STRIDEY, DILATIONX, DILATIONY>
|
||||
{
|
||||
impl SerializeModule for Conv2D {
|
||||
fn serialize(&self, s: &mut luminal::module::Serializer) {
|
||||
s.tensor("weight", self.weight);
|
||||
}
|
||||
}
|
||||
|
||||
// Single
|
||||
impl<
|
||||
const CH_IN: usize,
|
||||
const CH_OUT: usize,
|
||||
const KERNELX: usize,
|
||||
const KERNELY: usize,
|
||||
const STRIDEX: usize,
|
||||
const STRIDEY: usize,
|
||||
const DILATIONX: usize,
|
||||
const DILATIONY: usize,
|
||||
> Conv2D<CH_IN, CH_OUT, KERNELX, KERNELY, STRIDEX, STRIDEY, DILATIONX, DILATIONY>
|
||||
{
|
||||
pub fn forward<
|
||||
const DIMX_IN: usize,
|
||||
const DIMY_IN: usize,
|
||||
const DIMX_OUT: usize,
|
||||
const DIMY_OUT: usize,
|
||||
>(
|
||||
&self,
|
||||
input: GraphTensor<R3<CH_IN, DIMX_IN, DIMY_IN>>,
|
||||
) -> GraphTensor<R3<CH_OUT, DIMX_OUT, DIMY_OUT>> {
|
||||
impl Conv2D {
|
||||
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
|
||||
// Input: (ch_in, dimx_in, dimy_in)
|
||||
let dimx_in = input.shape()[1].small();
|
||||
let dimy_in = input.shape()[2].small();
|
||||
let dimx_out = (((dimx_in - self.dilation.0 * (self.kernel.0 - 1) - 1) / self.stride.0)
|
||||
+ 1)
|
||||
.simplify();
|
||||
let dimy_out = (((dimy_in - self.dilation.1 * (self.kernel.1 - 1) - 1) / self.stride.1)
|
||||
+ 1)
|
||||
.simplify();
|
||||
let input_pooled = input
|
||||
.pool_last_dim::<R4<CH_IN, DIMX_IN, DIMY_OUT, KERNELY>>(KERNELY, STRIDEY, DILATIONY)
|
||||
.permute::<_, Axes4<0, 2, 3, 1>>()
|
||||
.pool_last_dim::<R5<CH_IN, DIMY_OUT, KERNELY, DIMX_OUT, KERNELX>>(
|
||||
KERNELX, STRIDEX, DILATIONX,
|
||||
)
|
||||
.permute::<_, Axes5<0, 4, 2, 3, 1>>()
|
||||
.dyn_reshape::<(_, Dyn<'-'>), _>(&[CH_IN * KERNELX * KERNELY, DIMX_OUT * DIMY_OUT]);
|
||||
.pool_last_dim(self.kernel.1, self.stride.1, self.dilation.1)
|
||||
.permute((0, 2, 3, 1))
|
||||
.pool_last_dim(self.kernel.0, self.stride.0, self.dilation.0)
|
||||
.permute((0, 4, 2, 3, 1))
|
||||
.reshape((
|
||||
self.ch_in * self.kernel.0 * self.kernel.1,
|
||||
dimx_out * dimy_out,
|
||||
));
|
||||
|
||||
self.weight
|
||||
.dyn_reshape::<(Const<CH_OUT>, Dyn<'-'>), _>(&[CH_OUT, CH_IN * KERNELX * KERNELY])
|
||||
.matmul(input_pooled)
|
||||
.reshape::<R3<CH_OUT, DIMX_OUT, DIMY_OUT>>()
|
||||
.reshape((self.ch_out, dimx_out, dimy_out))
|
||||
}
|
||||
}
|
||||
pub struct Conv3D<
|
||||
const CH_IN: usize,
|
||||
const CH_OUT: usize,
|
||||
const KERNELX: usize,
|
||||
const KERNELY: usize,
|
||||
const KERNELZ: usize,
|
||||
const STRIDEX: usize,
|
||||
const STRIDEY: usize,
|
||||
const STRIDEZ: usize,
|
||||
const DILATIONX: usize,
|
||||
const DILATIONY: usize,
|
||||
const DILATIONZ: usize,
|
||||
> {
|
||||
pub weight: GraphTensor<R5<CH_OUT, CH_IN, KERNELX, KERNELY, KERNELZ>>,
|
||||
pub struct Conv3D {
|
||||
pub weight: GraphTensor, // ch_out, ch_in * kernel_x * kernel_y * kernel_z
|
||||
kernel: (usize, usize, usize),
|
||||
stride: (usize, usize, usize),
|
||||
dilation: (usize, usize, usize),
|
||||
ch_in: usize,
|
||||
ch_out: usize,
|
||||
}
|
||||
|
||||
impl<
|
||||
const CH_IN: usize,
|
||||
const CH_OUT: usize,
|
||||
const KERNELX: usize,
|
||||
const KERNELY: usize,
|
||||
const KERNELZ: usize,
|
||||
const STRIDEX: usize,
|
||||
const STRIDEY: usize,
|
||||
const STRIDEZ: usize,
|
||||
const DILATIONX: usize,
|
||||
const DILATIONY: usize,
|
||||
const DILATIONZ: usize,
|
||||
> InitModule
|
||||
for Conv3D<
|
||||
CH_IN,
|
||||
CH_OUT,
|
||||
KERNELX,
|
||||
KERNELY,
|
||||
KERNELZ,
|
||||
STRIDEX,
|
||||
STRIDEY,
|
||||
STRIDEZ,
|
||||
DILATIONX,
|
||||
DILATIONY,
|
||||
DILATIONZ,
|
||||
>
|
||||
{
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
// Init weight as uniform(-1, 1)
|
||||
let mut rng = thread_rng();
|
||||
impl Conv3D {
|
||||
pub fn new(
|
||||
ch_in: usize,
|
||||
ch_out: usize,
|
||||
kernel: (usize, usize, usize),
|
||||
stride: (usize, usize, usize),
|
||||
dilation: (usize, usize, usize),
|
||||
cx: &mut Graph,
|
||||
) -> Self {
|
||||
Self {
|
||||
weight: cx.named_tensor("Weight").set(
|
||||
(0..(CH_IN * CH_OUT * KERNELX * KERNELY * KERNELZ))
|
||||
.map(|_| rng.gen_range(-1_f32..1_f32))
|
||||
.collect::<Vec<_>>(),
|
||||
),
|
||||
weight: cx.named_tensor("Weight", (ch_out, ch_in * kernel.0 * kernel.1 * kernel.2)),
|
||||
kernel,
|
||||
stride,
|
||||
dilation,
|
||||
ch_in,
|
||||
ch_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
const CH_IN: usize,
|
||||
const CH_OUT: usize,
|
||||
const KERNELX: usize,
|
||||
const KERNELY: usize,
|
||||
const KERNELZ: usize,
|
||||
const STRIDEX: usize,
|
||||
const STRIDEY: usize,
|
||||
const STRIDEZ: usize,
|
||||
const DILATIONX: usize,
|
||||
const DILATIONY: usize,
|
||||
const DILATIONZ: usize,
|
||||
> SerializeModule
|
||||
for Conv3D<
|
||||
CH_IN,
|
||||
CH_OUT,
|
||||
KERNELX,
|
||||
KERNELY,
|
||||
KERNELZ,
|
||||
STRIDEX,
|
||||
STRIDEY,
|
||||
STRIDEZ,
|
||||
DILATIONX,
|
||||
DILATIONY,
|
||||
DILATIONZ,
|
||||
>
|
||||
{
|
||||
impl SerializeModule for Conv3D {
|
||||
fn serialize(&self, s: &mut luminal::module::Serializer) {
|
||||
s.tensor("weight", self.weight);
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
const CH_IN: usize,
|
||||
const CH_OUT: usize,
|
||||
const KERNELX: usize,
|
||||
const KERNELY: usize,
|
||||
const KERNELZ: usize,
|
||||
const STRIDEX: usize,
|
||||
const STRIDEY: usize,
|
||||
const STRIDEZ: usize,
|
||||
const DILATIONX: usize,
|
||||
const DILATIONY: usize,
|
||||
const DILATIONZ: usize,
|
||||
>
|
||||
Conv3D<
|
||||
CH_IN,
|
||||
CH_OUT,
|
||||
KERNELX,
|
||||
KERNELY,
|
||||
KERNELZ,
|
||||
STRIDEX,
|
||||
STRIDEY,
|
||||
STRIDEZ,
|
||||
DILATIONX,
|
||||
DILATIONY,
|
||||
DILATIONZ,
|
||||
>
|
||||
{
|
||||
pub fn forward<
|
||||
const DIMX_IN: usize,
|
||||
const DIMY_IN: usize,
|
||||
const DIMZ_IN: usize,
|
||||
const DIMX_OUT: usize,
|
||||
const DIMY_OUT: usize,
|
||||
const DIMZ_OUT: usize,
|
||||
>(
|
||||
&self,
|
||||
input: GraphTensor<R4<CH_IN, DIMX_IN, DIMY_IN, DIMZ_IN>>,
|
||||
) -> GraphTensor<R4<CH_OUT, DIMX_OUT, DIMY_OUT, DIMZ_OUT>> {
|
||||
impl Conv3D {
|
||||
pub fn forward(&self, input: GraphTensor) -> GraphTensor {
|
||||
// Input: ch_in, dimx_in, dimy_in, dimz_in
|
||||
let dimx_in = input.shape()[1].small();
|
||||
let dimy_in = input.shape()[2].small();
|
||||
let dimz_in = input.shape()[3].small();
|
||||
let dimx_out = (((dimx_in - self.dilation.0 * (self.kernel.0 - 1) - 1) / self.stride.0)
|
||||
+ 1)
|
||||
.simplify();
|
||||
let dimy_out = (((dimy_in - self.dilation.1 * (self.kernel.1 - 1) - 1) / self.stride.1)
|
||||
+ 1)
|
||||
.simplify();
|
||||
let dimz_out = (((dimz_in - self.dilation.2 * (self.kernel.2 - 1) - 1) / self.stride.2)
|
||||
+ 1)
|
||||
.simplify();
|
||||
|
||||
let input_pooled = input
|
||||
.pool_last_dim::<R5<CH_IN, DIMX_IN, DIMY_OUT, DIMZ_OUT, KERNELY>>(
|
||||
KERNELY, STRIDEY, DILATIONY,
|
||||
)
|
||||
.permute::<_, Axes5<0, 2, 3, 4, 1>>()
|
||||
.pool_last_dim::<R6<CH_IN, DIMY_OUT, DIMZ_OUT, KERNELY, DIMX_OUT, KERNELX>>(
|
||||
KERNELX, STRIDEX, DILATIONX,
|
||||
)
|
||||
.dyn_reshape::<(Const<CH_IN>, Dyn<'-'>), _>(&[
|
||||
CH_IN,
|
||||
DIMZ_OUT,
|
||||
KERNELY,
|
||||
DIMX_OUT * KERNELX,
|
||||
DIMY_IN,
|
||||
]);
|
||||
.pool_last_dim(self.kernel.1, self.stride.1, self.dilation.1)
|
||||
.permute((0, 2, 3, 4, 1))
|
||||
.pool_last_dim(self.kernel.0, self.stride.0, self.dilation.0)
|
||||
.reshape((
|
||||
self.ch_in,
|
||||
dimz_out,
|
||||
self.kernel.1,
|
||||
dimx_out * self.kernel.0,
|
||||
dimy_in,
|
||||
));
|
||||
|
||||
let last_pool = input_pooled
|
||||
.pool_last_dim::<(
|
||||
Const<CH_IN>,
|
||||
Const<DIMZ_OUT>,
|
||||
Const<KERNELY>,
|
||||
Dyn<'-'>,
|
||||
Const<DIMY_IN>,
|
||||
Const<KERNELZ>,
|
||||
)>(KERNELZ, STRIDEZ, DILATIONZ)
|
||||
.permute::<_, Axes6<0, 2, 5, 3, 1, 4>>();
|
||||
.pool_last_dim(self.kernel.2, self.stride.2, self.dilation.2)
|
||||
.permute((0, 2, 5, 3, 1, 4));
|
||||
|
||||
let reshaped = last_pool.dyn_reshape::<(_, Dyn<'-'>), _>(&[
|
||||
CH_IN * KERNELX * KERNELY * KERNELZ,
|
||||
DIMX_OUT * DIMY_OUT * DIMZ_OUT,
|
||||
]);
|
||||
let reshaped = last_pool.reshape((
|
||||
self.ch_in * self.kernel.0 * self.kernel.1 * self.kernel.2,
|
||||
dimx_out * dimy_out * dimz_out,
|
||||
));
|
||||
|
||||
self.weight
|
||||
.dyn_reshape::<(Const<CH_OUT>, Dyn<'-'>), _>(&[
|
||||
CH_OUT,
|
||||
CH_IN * KERNELX * KERNELY * KERNELZ,
|
||||
])
|
||||
.matmul(reshaped)
|
||||
.reshape::<R4<CH_OUT, DIMX_OUT, DIMY_OUT, DIMZ_OUT>>()
|
||||
.reshape((self.ch_out, dimx_out, dimy_out, dimz_out))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -468,7 +248,6 @@ mod tests {
|
||||
tests::{assert_close, random_vec_rng},
|
||||
};
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[test]
|
||||
fn test_conv1d_simple() {
|
||||
@@ -481,18 +260,18 @@ mod tests {
|
||||
const DIM_IN: usize = 6;
|
||||
const DIM_OUT: usize = ((DIM_IN - (KERNEL - 1) - 1) / STRIDE) + 1;
|
||||
|
||||
let model = Conv1D::<CH_IN, CH_OUT, KERNEL>::initialize(&mut cx);
|
||||
let model = Conv1D::new(CH_IN, CH_OUT, KERNEL, KERNEL, 1, 0, false, &mut cx);
|
||||
model.weight.set([[[0.0316, -0.2057]]]);
|
||||
|
||||
let inp1 = cx
|
||||
.tensor::<R2<CH_IN, DIM_IN>>()
|
||||
.set([[3., 0., 9., 6., 0., 6.]]);
|
||||
let inp1 = cx.tensor((CH_IN, DIM_IN)).set([[3., 0., 9., 6., 0., 6.]]);
|
||||
|
||||
let out1 = model
|
||||
.forward((inp1, PhantomData::<Const<DIM_OUT>>))
|
||||
.retrieve();
|
||||
let out1 = model.forward(inp1).retrieve();
|
||||
cx.execute();
|
||||
|
||||
assert_eq!(
|
||||
out1.shape(),
|
||||
vec![BigExpression::from(CH_OUT), BigExpression::from(DIM_OUT)]
|
||||
);
|
||||
assert_close(&out1.data(), &[0.0948, -0.9498, -1.2342]);
|
||||
}
|
||||
|
||||
@@ -510,14 +289,14 @@ mod tests {
|
||||
let kernel_data = random_vec_rng(KERNEL * CH_IN * CH_OUT, &mut rng);
|
||||
let input_data = random_vec_rng(CH_IN * DIM_IN, &mut rng);
|
||||
|
||||
let model = Conv1D::<CH_IN, CH_OUT, KERNEL, STRIDE, 0, PADDING>::initialize(&mut cx);
|
||||
let model = Conv1D::new(CH_IN, CH_OUT, KERNEL, STRIDE, 1, PADDING, false, &mut cx);
|
||||
model.weight.set(kernel_data.clone());
|
||||
|
||||
let inp1 = cx
|
||||
.tensor::<(Const<1>, Const<CH_IN>, Dyn<'s'>)>()
|
||||
.set_dyn(input_data.clone(), &[1, CH_IN, DIM_IN]);
|
||||
.tensor((1, CH_IN, 's'))
|
||||
.set_dyn(input_data.clone(), (1, CH_IN, DIM_IN));
|
||||
|
||||
let out1 = model.forward((inp1, PhantomData::<Dyn<'s'>>)).retrieve();
|
||||
let out1 = model.forward(inp1).retrieve();
|
||||
cx.execute();
|
||||
|
||||
let input = Tensor::from_vec(input_data, (1, CH_IN, DIM_IN), &Device::Cpu).unwrap();
|
||||
@@ -539,9 +318,8 @@ mod tests {
|
||||
const KERNEL: usize = 2;
|
||||
const STRIDE: usize = 2;
|
||||
const DIM_IN: usize = 12;
|
||||
const DIM_OUT: usize = ((DIM_IN - (KERNEL - 1) - 1) / STRIDE) + 1;
|
||||
|
||||
let model = Conv1D::<CH_IN, CH_OUT, KERNEL>::initialize(&mut cx);
|
||||
let model = Conv1D::new(CH_IN, CH_OUT, KERNEL, STRIDE, 1, 0, false, &mut cx);
|
||||
model.weight.set(vec![
|
||||
-0.1700, -0.2000, 0.1000, -0.0200, 0.1000, 0.0200, -0.2100, -0.2300, -0.0600, 0.1500,
|
||||
0.1200, 0.1000, 0.1800, 0.0600, -0.1700, -0.0400, 0.1000, -0.0200, -0.1700, 0.1000,
|
||||
@@ -552,7 +330,7 @@ mod tests {
|
||||
0.0700, -0.1200, 0.1400, 0.2200,
|
||||
]);
|
||||
|
||||
let inp1 = cx.tensor::<R2<CH_IN, DIM_IN>>();
|
||||
let inp1 = cx.tensor((CH_IN, DIM_IN));
|
||||
inp1.set(vec![
|
||||
1., 2., 6., 4., 8., 1., 6., 0., 1., 0., 6., 4., 3., 4., 9., 3., 8., 8., 5., 5., 0., 4.,
|
||||
2., 7., 6., 4., 2., 2., 8., 0., 7., 3., 0., 0., 7., 2., 3., 3., 1., 9., 5., 4., 5., 5.,
|
||||
@@ -562,9 +340,7 @@ mod tests {
|
||||
]);
|
||||
inp1.retrieve();
|
||||
|
||||
let out1 = model
|
||||
.forward((inp1, PhantomData::<Const<DIM_OUT>>))
|
||||
.retrieve();
|
||||
let out1 = model.forward(inp1).retrieve();
|
||||
cx.execute();
|
||||
|
||||
assert_close(
|
||||
@@ -587,14 +363,14 @@ mod tests {
|
||||
const KERNELY: usize = 2;
|
||||
const STRIDEX: usize = KERNELX;
|
||||
const STRIDEY: usize = KERNELY;
|
||||
const DILATIONX: usize = 0;
|
||||
const DILATIONY: usize = 0;
|
||||
const DILATIONX: usize = 1;
|
||||
const DILATIONY: usize = 1;
|
||||
const DIMX_IN: usize = 16;
|
||||
const DIMX_OUT: usize = ((DIMX_IN - (DILATIONX + 1) * (KERNELX - 1) - 1) / STRIDEX) + 1;
|
||||
const DIMX_OUT: usize = ((DIMX_IN - DILATIONX * (KERNELX - 1) - 1) / STRIDEX) + 1;
|
||||
const DIMY_IN: usize = 9;
|
||||
const DIMY_OUT: usize = ((DIMY_IN - (DILATIONY + 1) * (KERNELY - 1) - 1) / STRIDEY) + 1;
|
||||
const DIMY_OUT: usize = ((DIMY_IN - DILATIONY * (KERNELY - 1) - 1) / STRIDEY) + 1;
|
||||
|
||||
let inp1 = cx.tensor::<R3<CH_IN, DIMX_IN, DIMY_IN>>();
|
||||
let inp1 = cx.tensor((CH_IN, DIMX_IN, DIMY_IN));
|
||||
inp1.set(vec![
|
||||
8., 8., 5., 7., 0., 6., 5., 3., 0., 7., 0., 6., 6., 7., 7., 5., 0., 6., 9., 4., 0., 8.,
|
||||
8., 5., 7., 6., 2., 8., 9., 5., 0., 3., 1., 1., 8., 4., 1., 1., 5., 6., 9., 3., 2., 9.,
|
||||
@@ -631,7 +407,7 @@ mod tests {
|
||||
3., 1., 5., 9., 1., 6., 5., 4., 2., 1., 2., 1., 1., 4., 7., 2.,
|
||||
]);
|
||||
|
||||
let exp_out1 = cx.tensor::<R3<CH_OUT, DIMX_OUT, DIMY_OUT>>();
|
||||
let exp_out1 = cx.tensor((CH_OUT, DIMX_OUT, DIMY_OUT));
|
||||
exp_out1.set(vec![
|
||||
3.9600, -0.3300, -1.7800, 4.0400, 1.5300, 0.2900, 2.8700, 3.0000, 0.9600, -1.8700,
|
||||
4.5900, 3.9700, 1.2800, 1.1800, 3.7800, 2.8500, 0.5500, 0.5600, 3.9800, 1.3200,
|
||||
@@ -644,7 +420,14 @@ mod tests {
|
||||
|
||||
exp_out1.retrieve();
|
||||
|
||||
let model = Conv2D::<CH_IN, CH_OUT, KERNELX, KERNELY>::initialize(&mut cx);
|
||||
let model = Conv2D::new(
|
||||
CH_IN,
|
||||
CH_OUT,
|
||||
(KERNELX, KERNELY),
|
||||
(STRIDEX, STRIDEY),
|
||||
(DILATIONX, DILATIONY),
|
||||
&mut cx,
|
||||
);
|
||||
model.weight.set(vec![
|
||||
0.1600, 0.2000, 0.1900, -0.1100, 0.0100, -0.0300, -0.1200, -0.0800, -0.1300, -0.0300,
|
||||
0.1600, -0.1700, -0.0000, 0.1900, 0.1300, 0.0300, -0.1500, 0.0900, 0.0100, 0.0200,
|
||||
@@ -652,9 +435,7 @@ mod tests {
|
||||
-0.2100, 0.0100, -0.2000, 0.2100, -0.0400, -0.1400, 0.1500, 0.0500, -0.1700, 0.1400,
|
||||
]);
|
||||
|
||||
let out1 = model
|
||||
.forward::<DIMX_IN, DIMY_IN, DIMX_OUT, DIMY_OUT>(inp1)
|
||||
.retrieve();
|
||||
let out1 = model.forward(inp1).retrieve();
|
||||
|
||||
cx.execute();
|
||||
|
||||
@@ -673,17 +454,17 @@ mod tests {
|
||||
const STRIDEX: usize = 2;
|
||||
const STRIDEY: usize = 2;
|
||||
const STRIDEZ: usize = 2;
|
||||
const DILATIONX: usize = 0;
|
||||
const DILATIONY: usize = 0;
|
||||
const DILATIONZ: usize = 0;
|
||||
const DILATIONX: usize = 1;
|
||||
const DILATIONY: usize = 1;
|
||||
const DILATIONZ: usize = 1;
|
||||
const DIMX_IN: usize = 2;
|
||||
const DIMY_IN: usize = 3;
|
||||
const DIMZ_IN: usize = 5;
|
||||
const DIMX_OUT: usize = ((DIMX_IN - (DILATIONX + 1) * (KERNELX - 1) - 1) / STRIDEX) + 1;
|
||||
const DIMY_OUT: usize = ((DIMY_IN - (DILATIONY + 1) * (KERNELY - 1) - 1) / STRIDEY) + 1;
|
||||
const DIMZ_OUT: usize = ((DIMZ_IN - (DILATIONZ + 1) * (KERNELZ - 1) - 1) / STRIDEZ) + 1;
|
||||
const DIMX_OUT: usize = ((DIMX_IN - DILATIONX * (KERNELX - 1) - 1) / STRIDEX) + 1;
|
||||
const DIMY_OUT: usize = ((DIMY_IN - DILATIONY * (KERNELY - 1) - 1) / STRIDEY) + 1;
|
||||
const DIMZ_OUT: usize = ((DIMZ_IN - DILATIONZ * (KERNELZ - 1) - 1) / STRIDEZ) + 1;
|
||||
|
||||
let inp1 = cx.tensor::<R4<CH_IN, DIMX_IN, DIMY_IN, DIMZ_IN>>();
|
||||
let inp1 = cx.tensor((CH_IN, DIMX_IN, DIMY_IN, DIMZ_IN));
|
||||
inp1.set(vec![
|
||||
// Example input data (5 channels, 2x3x5 volume)
|
||||
8., 8., 5., 7., 0., 6., 5., 3., 0., 7., 0., 6., 6., 7., 7., 5., 0., 6., 9., 4., 0., 8.,
|
||||
@@ -695,7 +476,7 @@ mod tests {
|
||||
4., 1., 9., 7., 7., 1., 2., 6., 3., 4., 1., 1., 6., 6., 8., 2., 7., 7.,
|
||||
]);
|
||||
|
||||
let exp_out1 = cx.tensor::<R4<CH_OUT, DIMX_OUT, DIMY_OUT, DIMZ_OUT>>();
|
||||
let exp_out1 = cx.tensor((CH_OUT, DIMX_OUT, DIMY_OUT, DIMZ_OUT));
|
||||
exp_out1.set(vec![
|
||||
// Example expected output data (2 channels, 1x1x2 volume)
|
||||
90.6935, 98.7138, 98.8273, 102.6553,
|
||||
@@ -703,19 +484,14 @@ mod tests {
|
||||
|
||||
exp_out1.retrieve();
|
||||
|
||||
let model: Conv3D<
|
||||
let model = Conv3D::new(
|
||||
CH_IN,
|
||||
CH_OUT,
|
||||
KERNELX,
|
||||
KERNELY,
|
||||
KERNELZ,
|
||||
STRIDEX,
|
||||
STRIDEY,
|
||||
STRIDEZ,
|
||||
DILATIONX,
|
||||
DILATIONY,
|
||||
DILATIONZ,
|
||||
> = Conv3D::initialize(&mut cx);
|
||||
(KERNELX, KERNELY, KERNELZ),
|
||||
(STRIDEX, STRIDEY, STRIDEZ),
|
||||
(DILATIONX, DILATIONY, DILATIONZ),
|
||||
&mut cx,
|
||||
);
|
||||
let weights = vec![
|
||||
4.273e-01, 1.388e-01, 3.546e-01, 2.403e-01, 5.572e-01, 2.788e-01, 6.718e-01, 6.935e-01,
|
||||
8.410e-01, 1.297e-01, 7.073e-01, 3.455e-01, 4.166e-01, 9.513e-01, 4.682e-01, 4.546e-02,
|
||||
@@ -730,9 +506,7 @@ mod tests {
|
||||
];
|
||||
model.weight.set(weights);
|
||||
|
||||
let out1 = model
|
||||
.forward::<DIMX_IN, DIMY_IN, DIMZ_IN, DIMX_OUT, DIMY_OUT, DIMZ_OUT>(inp1)
|
||||
.retrieve();
|
||||
let out1 = model.forward(inp1).retrieve();
|
||||
|
||||
cx.execute();
|
||||
|
||||
|
||||
@@ -1,87 +1,56 @@
|
||||
use luminal::{prelude::*, tests::random_vec};
|
||||
|
||||
pub struct Embedding<const N: usize, const DIM: usize> {
|
||||
pub weight: GraphTensor<R2<N, DIM>>,
|
||||
pub struct Embedding {
|
||||
permute: bool,
|
||||
pub weight: GraphTensor, // n embeddings x embedding dim
|
||||
}
|
||||
|
||||
impl<const A: usize, const B: usize> InitModule for Embedding<A, B> {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl Embedding {
|
||||
pub fn new(n_embeddings: usize, embedding_dim: usize, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
weight: cx.named_tensor("Embedding Weight").set(random_vec(A * B)),
|
||||
weight: cx.named_tensor("Embedding Weight", (n_embeddings, embedding_dim)),
|
||||
permute: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_permuted(n_embeddings: usize, embedding_dim: usize, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
weight: cx.named_tensor("Embedding Weight", (n_embeddings, embedding_dim)),
|
||||
permute: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn initialize(self) -> Self {
|
||||
self.weight.set(random_vec(
|
||||
self.weight.shape.n_elements().to_usize().unwrap(),
|
||||
));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<const A: usize, const B: usize> SerializeModule for Embedding<A, B> {
|
||||
impl SerializeModule for Embedding {
|
||||
fn serialize(&self, s: &mut luminal::module::Serializer) {
|
||||
s.tensor("weight", self.weight);
|
||||
}
|
||||
}
|
||||
|
||||
// Single
|
||||
impl<S: Dimension, const N: usize, const DIM: usize> Module<GraphTensor<(S,)>>
|
||||
for Embedding<N, DIM>
|
||||
{
|
||||
type Output = GraphTensor<(S, Const<DIM>)>;
|
||||
impl Module<GraphTensor> for Embedding {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn forward(&self, input: GraphTensor<(S,)>) -> Self::Output {
|
||||
self.weight.gather(input)
|
||||
}
|
||||
}
|
||||
|
||||
// Batch
|
||||
impl<B: Dimension, S: Dimension, const N: usize, const DIM: usize> Module<GraphTensor<(B, S)>>
|
||||
for Embedding<N, DIM>
|
||||
{
|
||||
type Output = GraphTensor<(B, S, Const<DIM>)>;
|
||||
|
||||
fn forward(&self, input: GraphTensor<(B, S)>) -> Self::Output {
|
||||
self.weight
|
||||
.gather(input.dyn_reshape::<(Dyn<'-'>,), _>(&[B::size() * S::size()]))
|
||||
.reshape()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PermutedEmbedding<const N: usize, const DIM: usize> {
|
||||
pub weight: GraphTensor<R2<DIM, N>>,
|
||||
}
|
||||
|
||||
impl<const A: usize, const B: usize> InitModule for PermutedEmbedding<A, B> {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
weight: cx.named_tensor("Embedding Weight").set(random_vec(A * B)),
|
||||
fn forward(&self, input: GraphTensor) -> Self::Output {
|
||||
// Flatten batches
|
||||
let batch_size = input.shape.n_elements();
|
||||
let inp = input.reshape(batch_size);
|
||||
let out = if self.permute {
|
||||
self.weight.permute((1, 0))
|
||||
} else {
|
||||
self.weight
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const A: usize, const B: usize> SerializeModule for PermutedEmbedding<A, B> {
|
||||
fn serialize(&self, s: &mut luminal::module::Serializer) {
|
||||
s.tensor("weight", self.weight);
|
||||
}
|
||||
}
|
||||
|
||||
// Single
|
||||
impl<S: Dimension, const N: usize, const DIM: usize> Module<GraphTensor<(S,)>>
|
||||
for PermutedEmbedding<N, DIM>
|
||||
{
|
||||
type Output = GraphTensor<(S, Const<DIM>)>;
|
||||
|
||||
fn forward(&self, input: GraphTensor<(S,)>) -> Self::Output {
|
||||
self.weight.permute().gather(input)
|
||||
}
|
||||
}
|
||||
|
||||
// Batch
|
||||
impl<B: Dimension, S: Dimension, const N: usize, const DIM: usize> Module<GraphTensor<(B, S)>>
|
||||
for PermutedEmbedding<N, DIM>
|
||||
{
|
||||
type Output = GraphTensor<(B, S, Const<DIM>)>;
|
||||
|
||||
fn forward(&self, input: GraphTensor<(B, S)>) -> Self::Output {
|
||||
self.weight
|
||||
.permute()
|
||||
.gather(input.dyn_reshape::<(Dyn<'-'>,), _>(&[B::size() * S::size()]))
|
||||
.reshape()
|
||||
.gather(inp);
|
||||
// Unflatten
|
||||
let mut new_shape = input.shape();
|
||||
new_shape.push(self.weight.shape()[1].clone());
|
||||
out.reshape(new_shape)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,12 +70,10 @@ mod tests {
|
||||
#[test]
|
||||
fn test_embedding() {
|
||||
let mut cx = Graph::new();
|
||||
let batch = cx
|
||||
.tensor::<R2<2, 3>>()
|
||||
.set(vec![1.0, 0.0, 2.0, 1.0, 0.0, 1.0]);
|
||||
let a = cx.tensor::<R1<3>>().set(vec![1.0, 0.0, 1.0]).retrieve();
|
||||
let batch = cx.tensor((2, 3)).set(vec![1.0, 0.0, 2.0, 1.0, 0.0, 1.0]);
|
||||
let a = cx.tensor(3).set(vec![1.0, 0.0, 1.0]).retrieve();
|
||||
|
||||
let model: Embedding<3, 4> = InitModule::initialize(&mut cx);
|
||||
let model = Embedding::new(3, 4, &mut cx).initialize();
|
||||
model
|
||||
.weight
|
||||
.set(vec![1.1, 2., 3., 1., 2., 3., 14., 2., 33., 1., 2., 3.]);
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
use luminal::prelude::*;
|
||||
|
||||
mod activation;
|
||||
pub use activation::*;
|
||||
mod convolution;
|
||||
@@ -12,34 +10,3 @@ mod norm;
|
||||
pub use norm::*;
|
||||
mod transformer;
|
||||
pub use transformer::*;
|
||||
|
||||
pub struct Repeated<T, const N: usize> {
|
||||
pub modules: Vec<T>,
|
||||
}
|
||||
|
||||
impl<T: InitModule, const N: usize> InitModule for Repeated<T, N> {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
modules: (0..N).map(|_| InitModule::initialize(cx)).collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: SerializeModule, const N: usize> SerializeModule for Repeated<T, N> {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
for (i, l) in self.modules.iter().enumerate() {
|
||||
s.module(&format!("layer{i}"), l);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<I, T: Module<I, Output = I>, const N: usize> Module<I> for Repeated<T, N> {
|
||||
type Output = I;
|
||||
|
||||
fn forward(&self, mut input: I) -> Self::Output {
|
||||
for m in &self.modules {
|
||||
input = m.forward(input);
|
||||
}
|
||||
input
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,74 +3,71 @@ use rand::{thread_rng, Rng};
|
||||
use luminal::prelude::*;
|
||||
|
||||
/// A simple unbiased linear layer
|
||||
pub struct Linear<const A: usize, const B: usize> {
|
||||
pub weight: GraphTensor<R2<A, B>>,
|
||||
pub struct Linear {
|
||||
pub weight: GraphTensor,
|
||||
pub bias: Option<GraphTensor>,
|
||||
permute: bool,
|
||||
}
|
||||
|
||||
impl<const A: usize, const B: usize> InitModule for Linear<A, B> {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl Linear {
|
||||
pub fn new(inp: usize, out: usize, bias: bool, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
weight: cx.named_tensor("Weight", (inp, out)),
|
||||
bias: if bias {
|
||||
Some(cx.named_tensor("Bias", out))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
permute: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_permuted(inp: usize, out: usize, bias: bool, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
weight: cx.named_tensor("Weight", (out, inp)),
|
||||
bias: if bias {
|
||||
Some(cx.named_tensor("Bias", out))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
permute: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn initialize(self) -> Self {
|
||||
// Init weight as uniform(-1, 1)
|
||||
let mut rng = thread_rng();
|
||||
Self {
|
||||
weight: cx.named_tensor("Weight").set(
|
||||
(0..(A * B))
|
||||
.map(|_| rng.gen_range(-1_f32..1_f32))
|
||||
.collect::<Vec<_>>(),
|
||||
),
|
||||
self.weight.set(
|
||||
(0..self.weight.shape.n_elements().to_usize().unwrap())
|
||||
.map(|_| rng.gen_range(-1_f32..1_f32))
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl SerializeModule for Linear {
|
||||
fn serialize(&self, s: &mut luminal::module::Serializer) {
|
||||
s.tensor("weight", self.weight);
|
||||
if let Some(bias) = self.bias {
|
||||
s.tensor("bias", bias);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const A: usize, const B: usize> SerializeModule for Linear<A, B> {
|
||||
fn serialize(&self, s: &mut luminal::module::Serializer) {
|
||||
s.tensor("weight", self.weight);
|
||||
}
|
||||
}
|
||||
impl Module<GraphTensor> for Linear {
|
||||
type Output = GraphTensor;
|
||||
|
||||
impl<const A: usize, const B: usize, S: Shape> Module<GraphTensor<S>> for Linear<A, B>
|
||||
where
|
||||
GraphTensor<S>: Matmul<R2<A, B>>,
|
||||
{
|
||||
type Output = <GraphTensor<S> as Matmul<R2<A, B>>>::Output;
|
||||
|
||||
fn forward(&self, input: GraphTensor<S>) -> Self::Output {
|
||||
input.matmul(self.weight)
|
||||
}
|
||||
}
|
||||
|
||||
/// A simple unbiased linear layer with a permuted weight matrix
|
||||
pub struct PermutedLinear<const A: usize, const B: usize> {
|
||||
pub weight: GraphTensor<R2<B, A>>,
|
||||
}
|
||||
|
||||
impl<const A: usize, const B: usize> InitModule for PermutedLinear<A, B> {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
// Init weight as uniform(-1, 1)
|
||||
let mut rng = thread_rng();
|
||||
Self {
|
||||
weight: cx.named_tensor("Weight").set(
|
||||
(0..(A * B))
|
||||
.map(|_| rng.gen_range(-1_f32..1_f32))
|
||||
.collect::<Vec<_>>(),
|
||||
),
|
||||
fn forward(&self, input: GraphTensor) -> Self::Output {
|
||||
let mut output = input.matmul(if self.permute {
|
||||
self.weight.permute((1, 0))
|
||||
} else {
|
||||
self.weight
|
||||
});
|
||||
if let Some(bias) = self.bias {
|
||||
output += bias.expand_to(output.shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const A: usize, const B: usize> SerializeModule for PermutedLinear<A, B> {
|
||||
fn serialize(&self, s: &mut luminal::module::Serializer) {
|
||||
s.tensor("weight", self.weight);
|
||||
}
|
||||
}
|
||||
|
||||
impl<const A: usize, const B: usize, S: Shape> Module<GraphTensor<S>> for PermutedLinear<A, B>
|
||||
where
|
||||
GraphTensor<S>: Matmul<R2<A, B>>,
|
||||
{
|
||||
type Output = <GraphTensor<S> as Matmul<R2<A, B>>>::Output;
|
||||
|
||||
fn forward(&self, input: GraphTensor<S>) -> Self::Output {
|
||||
input.matmul(self.weight.permute())
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,12 +78,10 @@ mod tests {
|
||||
#[test]
|
||||
fn test_linear() {
|
||||
let mut cx = Graph::new();
|
||||
let batch = cx
|
||||
.tensor::<R2<2, 3>>()
|
||||
.set(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
|
||||
let a = cx.tensor::<R1<3>>().set(vec![1.0, 2.0, 3.0]);
|
||||
let batch = cx.tensor((2, 3)).set([1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
|
||||
let a = cx.tensor(3).set([1.0, 2.0, 3.0]);
|
||||
|
||||
let model: Linear<3, 4> = Linear::initialize(&mut cx);
|
||||
let model = Linear::new(3, 4, false, &mut cx).initialize();
|
||||
let mut b = model.forward(a).retrieve();
|
||||
let mut batch_out = model.forward(batch).retrieve();
|
||||
|
||||
|
||||
@@ -3,23 +3,30 @@ use rand::thread_rng;
|
||||
|
||||
/// A simple layer norm with an optional weight and bias
|
||||
#[derive(Default)]
|
||||
pub struct LayerNorm<const DIM: usize> {
|
||||
pub weight: Option<GraphTensor<R1<DIM>>>,
|
||||
pub bias: Option<GraphTensor<R1<DIM>>>,
|
||||
pub struct LayerNorm {
|
||||
pub weight: Option<GraphTensor>,
|
||||
pub bias: Option<GraphTensor>,
|
||||
mean_norm: bool,
|
||||
epsilon: f32,
|
||||
}
|
||||
|
||||
impl<const DIM: usize> LayerNorm<DIM> {
|
||||
pub fn new(weight: bool, bias: bool, mean_norm: bool, epsilon: f32, cx: &mut Graph) -> Self {
|
||||
impl LayerNorm {
|
||||
pub fn new(
|
||||
dim: usize,
|
||||
weight: bool,
|
||||
bias: bool,
|
||||
mean_norm: bool,
|
||||
epsilon: f32,
|
||||
cx: &mut Graph,
|
||||
) -> Self {
|
||||
Self {
|
||||
weight: if weight {
|
||||
Some(cx.named_tensor("LayerNorm Weight"))
|
||||
Some(cx.named_tensor("LayerNorm Weight", dim))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
bias: if bias {
|
||||
Some(cx.named_tensor("LayerNorm Bias"))
|
||||
Some(cx.named_tensor("LayerNorm Bias", dim))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
@@ -27,53 +34,43 @@ impl<const DIM: usize> LayerNorm<DIM> {
|
||||
epsilon,
|
||||
}
|
||||
}
|
||||
pub fn init(weight: bool, bias: bool, mean_norm: bool, epsilon: f32, cx: &mut Graph) -> Self {
|
||||
pub fn initialize(self) -> Self {
|
||||
// Init weight as uniform(-1, 1)
|
||||
let mut rng = thread_rng();
|
||||
Self {
|
||||
weight: if weight {
|
||||
Some(
|
||||
cx.named_tensor("LayerNorm Weight")
|
||||
.set(random_vec_rng(DIM, &mut rng)),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
bias: if bias {
|
||||
Some(
|
||||
cx.named_tensor("LayerNorm Bias")
|
||||
.set(random_vec_rng(DIM, &mut rng)),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
mean_norm,
|
||||
epsilon,
|
||||
if let Some(w) = self.weight {
|
||||
w.set(random_vec_rng(
|
||||
w.shape.n_elements().to_usize().unwrap(),
|
||||
&mut rng,
|
||||
));
|
||||
}
|
||||
if let Some(b) = self.bias {
|
||||
b.set(random_vec_rng(
|
||||
b.shape.n_elements().to_usize().unwrap(),
|
||||
&mut rng,
|
||||
));
|
||||
}
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<const DIM: usize, S: Shape> Module<GraphTensor<S>> for LayerNorm<DIM>
|
||||
where
|
||||
(Const<DIM>,): BroadcastShapeTo<S, S::AllButLast>,
|
||||
{
|
||||
type Output = GraphTensor<S>;
|
||||
fn forward(&self, mut input: GraphTensor<S>) -> Self::Output {
|
||||
impl Module<GraphTensor> for LayerNorm {
|
||||
type Output = GraphTensor;
|
||||
fn forward(&self, mut input: GraphTensor) -> Self::Output {
|
||||
if self.mean_norm {
|
||||
input = input.mean_norm::<S::LastAxis>();
|
||||
input = input.mean_norm(input.shape.last_axis());
|
||||
}
|
||||
input = input.std_norm(self.epsilon);
|
||||
input = input.std_norm(input.shape.last_axis(), self.epsilon);
|
||||
if let Some(w) = self.weight {
|
||||
input *= w.expand();
|
||||
input *= w.expand_to(input.shape);
|
||||
}
|
||||
if let Some(b) = self.bias {
|
||||
input += b.expand();
|
||||
input += b.expand_to(input.shape);
|
||||
}
|
||||
input
|
||||
}
|
||||
}
|
||||
|
||||
impl<const DIM: usize> SerializeModule for LayerNorm<DIM> {
|
||||
impl SerializeModule for LayerNorm {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
if let Some(w) = self.weight {
|
||||
s.tensor("weight", w);
|
||||
|
||||
@@ -4,34 +4,31 @@ use crate::Linear;
|
||||
use luminal::prelude::*;
|
||||
|
||||
/// Multi-head self attention as layed out in [*Attention Is All You Need*](https://arxiv.org/abs/1706.03762).
|
||||
pub struct MultiHeadSelfAttention<
|
||||
const DIM: usize,
|
||||
const K_DIM: usize,
|
||||
const V_DIM: usize,
|
||||
const HEADS: usize,
|
||||
> {
|
||||
pub w_q: Linear<DIM, K_DIM>,
|
||||
pub w_k: Linear<DIM, K_DIM>,
|
||||
pub w_v: Linear<DIM, V_DIM>,
|
||||
pub w_o: Linear<V_DIM, DIM>,
|
||||
pub struct MultiHeadSelfAttention {
|
||||
pub w_q: Linear, // dim x k_dim
|
||||
pub w_k: Linear, // dim x k_dim
|
||||
pub w_v: Linear, // dim x v_dim
|
||||
pub w_o: Linear, // v_dim x dim
|
||||
k_dim: usize,
|
||||
v_dim: usize,
|
||||
heads: usize,
|
||||
}
|
||||
|
||||
impl<const DIM: usize, const K_DIM: usize, const V_DIM: usize, const HEADS: usize> InitModule
|
||||
for MultiHeadSelfAttention<DIM, K_DIM, V_DIM, HEADS>
|
||||
{
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl MultiHeadSelfAttention {
|
||||
pub fn new(dim: usize, k_dim: usize, v_dim: usize, heads: usize, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
w_q: InitModule::initialize(cx),
|
||||
w_k: InitModule::initialize(cx),
|
||||
w_v: InitModule::initialize(cx),
|
||||
w_o: InitModule::initialize(cx),
|
||||
w_q: Linear::new(dim, k_dim, false, cx),
|
||||
w_k: Linear::new(dim, k_dim, false, cx),
|
||||
w_v: Linear::new(dim, v_dim, false, cx),
|
||||
w_o: Linear::new(v_dim, dim, false, cx),
|
||||
k_dim,
|
||||
v_dim,
|
||||
heads,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const DIM: usize, const K_DIM: usize, const V_DIM: usize, const HEADS: usize> SerializeModule
|
||||
for MultiHeadSelfAttention<DIM, K_DIM, V_DIM, HEADS>
|
||||
{
|
||||
impl SerializeModule for MultiHeadSelfAttention {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("w_q", &self.w_q);
|
||||
s.module("w_k", &self.w_k);
|
||||
@@ -40,146 +37,71 @@ impl<const DIM: usize, const K_DIM: usize, const V_DIM: usize, const HEADS: usiz
|
||||
}
|
||||
}
|
||||
|
||||
// Single
|
||||
impl<
|
||||
const DIM: usize,
|
||||
const K_DIM: usize,
|
||||
const V_DIM: usize,
|
||||
const HEADS: usize,
|
||||
S: Dimension,
|
||||
> Module<GraphTensor<(S, Const<DIM>)>> for MultiHeadSelfAttention<DIM, K_DIM, V_DIM, HEADS>
|
||||
{
|
||||
type Output = GraphTensor<(S, Const<DIM>)>;
|
||||
|
||||
fn forward(&self, input: GraphTensor<(S, Const<DIM>)>) -> Self::Output {
|
||||
// Pass to batched forward
|
||||
<Self as Module<GraphTensor<(Const<1>, S, Const<DIM>)>>>::forward(self, input.expand())
|
||||
.reshape()
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
const DIM: usize,
|
||||
const K_DIM: usize,
|
||||
const V_DIM: usize,
|
||||
const HEADS: usize,
|
||||
S: Dimension,
|
||||
S1: Dimension,
|
||||
>
|
||||
Module<(
|
||||
GraphTensor<(S, Const<DIM>)>,
|
||||
GraphTensor<(S1, Const<DIM>)>,
|
||||
GraphTensor<(S, Const<DIM>)>,
|
||||
)> for MultiHeadSelfAttention<DIM, K_DIM, V_DIM, HEADS>
|
||||
{
|
||||
type Output = GraphTensor<(S1, Const<DIM>)>;
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
(k, q, v): (
|
||||
GraphTensor<(S, Const<DIM>)>,
|
||||
GraphTensor<(S1, Const<DIM>)>,
|
||||
GraphTensor<(S, Const<DIM>)>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
// Pass to batched forward
|
||||
<Self as Module<(
|
||||
GraphTensor<(Const<1>, S, Const<DIM>)>,
|
||||
GraphTensor<(Const<1>, S1, Const<DIM>)>,
|
||||
GraphTensor<(Const<1>, S, Const<DIM>)>,
|
||||
)>>::forward(self, (k.expand(), q.expand(), v.expand()))
|
||||
.reshape()
|
||||
}
|
||||
}
|
||||
|
||||
// Batched
|
||||
impl<
|
||||
const DIM: usize,
|
||||
const K_DIM: usize,
|
||||
const V_DIM: usize,
|
||||
const HEADS: usize,
|
||||
S: Dimension,
|
||||
B: Dimension,
|
||||
> Module<GraphTensor<(B, S, Const<DIM>)>> for MultiHeadSelfAttention<DIM, K_DIM, V_DIM, HEADS>
|
||||
{
|
||||
type Output = GraphTensor<(B, S, Const<DIM>)>;
|
||||
impl Module<GraphTensor> for MultiHeadSelfAttention {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn forward(&self, input: GraphTensor<(B, S, Const<DIM>)>) -> Self::Output {
|
||||
<Self as Module<(
|
||||
GraphTensor<(B, S, Const<DIM>)>,
|
||||
GraphTensor<(B, S, Const<DIM>)>,
|
||||
GraphTensor<(B, S, Const<DIM>)>,
|
||||
)>>::forward(self, (input, input, input))
|
||||
fn forward(&self, input: GraphTensor) -> Self::Output {
|
||||
// Input: batch_dims, sequence, dim
|
||||
<Self as Module<(GraphTensor, GraphTensor, GraphTensor)>>::forward(
|
||||
self,
|
||||
(input, input, input),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Batched different key-query-value
|
||||
impl<
|
||||
const DIM: usize,
|
||||
const K_DIM: usize,
|
||||
const V_DIM: usize,
|
||||
const HEADS: usize,
|
||||
S1: Dimension,
|
||||
S2: Dimension,
|
||||
B: Dimension,
|
||||
>
|
||||
Module<(
|
||||
GraphTensor<(B, S1, Const<DIM>)>,
|
||||
GraphTensor<(B, S2, Const<DIM>)>,
|
||||
GraphTensor<(B, S1, Const<DIM>)>,
|
||||
)> for MultiHeadSelfAttention<DIM, K_DIM, V_DIM, HEADS>
|
||||
{
|
||||
type Output = GraphTensor<(B, S2, Const<DIM>)>;
|
||||
impl Module<(GraphTensor, GraphTensor, GraphTensor)> for MultiHeadSelfAttention {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
(keys, queries, values): (
|
||||
GraphTensor<(B, S1, Const<DIM>)>,
|
||||
GraphTensor<(B, S2, Const<DIM>)>,
|
||||
GraphTensor<(B, S1, Const<DIM>)>,
|
||||
GraphTensor, // batch, s1, dim
|
||||
GraphTensor, // batch, s2, dim
|
||||
GraphTensor, // batch, s1, dim
|
||||
),
|
||||
) -> Self::Output {
|
||||
let orig_query_shape = queries.shape();
|
||||
let s1 = keys.shape()[keys.shape.len() - 2].small();
|
||||
let s2 = queries.shape()[queries.shape.len() - 2].small();
|
||||
let n_batches = queries
|
||||
.shape()
|
||||
.into_iter()
|
||||
.take(queries.shape.len() - 2)
|
||||
.product::<BigExpression>()
|
||||
.max(1)
|
||||
.small();
|
||||
let dim = queries.shape().last().unwrap().small();
|
||||
let keys = keys.reshape((n_batches, s1, dim));
|
||||
let values = values.reshape((n_batches, s1, dim));
|
||||
let queries = queries.reshape((n_batches, s2, dim));
|
||||
let values = self
|
||||
.w_v
|
||||
.forward(values)
|
||||
.dyn_reshape::<(B, S1, Const<HEADS>, Dyn<'-'>), _>(&[
|
||||
B::size(),
|
||||
S1::size(),
|
||||
HEADS.into(),
|
||||
(K_DIM / HEADS).into(),
|
||||
])
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
.reshape((n_batches, s1, self.heads, self.k_dim / self.heads))
|
||||
.permute((0, 2, 1, 3));
|
||||
let keys = self
|
||||
.w_k
|
||||
.forward(keys)
|
||||
.dyn_reshape::<(B, S1, Const<HEADS>, Dyn<'-'>), _>(&[
|
||||
B::size(),
|
||||
S1::size(),
|
||||
HEADS.into(),
|
||||
(K_DIM / HEADS).into(),
|
||||
])
|
||||
.permute::<_, Axes4<0, 2, 3, 1>>();
|
||||
.reshape((n_batches, s1, self.heads, self.k_dim / self.heads))
|
||||
.permute((0, 2, 3, 1));
|
||||
let queries = self
|
||||
.w_q
|
||||
.forward(queries)
|
||||
.dyn_reshape::<(B, S2, Const<HEADS>, Dyn<'-'>), _>(&[
|
||||
B::size(),
|
||||
S2::size(),
|
||||
HEADS.into(),
|
||||
(K_DIM / HEADS).into(),
|
||||
])
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
.reshape((n_batches, s2, self.heads, self.k_dim / self.heads))
|
||||
.permute((0, 2, 1, 3));
|
||||
|
||||
let weights = queries
|
||||
.matmul(keys)
|
||||
.mul((1.0 / ((K_DIM / HEADS) as f64).sqrt()) as f32)
|
||||
.softmax::<Axis<3>>();
|
||||
.mul((1.0 / ((self.k_dim / self.heads) as f64).sqrt()) as f32)
|
||||
.softmax(3);
|
||||
|
||||
let tokens: GraphTensor<(B, S2, Const<V_DIM>)> = weights
|
||||
let tokens = weights
|
||||
.matmul(values)
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>()
|
||||
.reshape();
|
||||
self.w_o.forward(tokens)
|
||||
.permute((0, 2, 1, 3))
|
||||
.reshape((n_batches, s2, self.v_dim));
|
||||
self.w_o.forward(tokens).reshape(orig_query_shape) // batch_dims, s2, dim
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,7 +117,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_self_attention() {
|
||||
let mut cx = Graph::new();
|
||||
let model: MultiHeadSelfAttention<3, 3, 3, 1> = InitModule::initialize(&mut cx);
|
||||
let model = MultiHeadSelfAttention::new(3, 3, 3, 1, &mut cx);
|
||||
model
|
||||
.w_k
|
||||
.weight
|
||||
@@ -213,20 +135,17 @@ mod tests {
|
||||
.weight
|
||||
.set(vec![1., 22., 3., 1., 2., 3., 1., 2., 3.]);
|
||||
|
||||
let a = cx.tensor::<(Dyn<'d'>, luminal::shape::Const<3>)>();
|
||||
let e = cx.tensor::<(Dyn<'e'>, luminal::shape::Const<3>)>();
|
||||
let a = cx.tensor(('d', 3));
|
||||
let e = cx.tensor(('e', 3));
|
||||
let b = model.forward((e, a, e));
|
||||
|
||||
a.set_dyn(
|
||||
vec![
|
||||
0.56587636, -1.4053632, 0.8394869, 0.5916256, -1.4082357, 0.8166099,
|
||||
],
|
||||
&[2, 3],
|
||||
);
|
||||
e.set_dyn(
|
||||
vec![-1.0, 2.0, 3.0, 3.0, 3.0, -1.0, -1.0, 2.0, 3.0],
|
||||
&[3, 3],
|
||||
(2, 3),
|
||||
);
|
||||
e.set_dyn(vec![-1.0, 2.0, 3.0, 3.0, 3.0, -1.0, -1.0, 2.0, 3.0], (3, 3));
|
||||
b.retrieve();
|
||||
|
||||
cx.execute();
|
||||
|
||||
@@ -4,28 +4,21 @@ use luminal::prelude::*;
|
||||
use super::attention::MultiHeadSelfAttention;
|
||||
|
||||
/// A transformer decoder as laid out in [*Attention Is All You Need*](https://arxiv.org/abs/1706.03762).
|
||||
pub struct TransformerDecoder<
|
||||
const DIM: usize,
|
||||
const FF: usize,
|
||||
const HEADS: usize,
|
||||
const LAYERS: usize,
|
||||
> {
|
||||
pub layers: Vec<TransformerDecoderBlock<DIM, FF, HEADS>>,
|
||||
pub struct TransformerDecoder {
|
||||
pub layers: Vec<TransformerDecoderBlock>,
|
||||
}
|
||||
|
||||
impl<const DIM: usize, const FF: usize, const HEADS: usize, const LAYERS: usize> InitModule
|
||||
for TransformerDecoder<DIM, FF, HEADS, LAYERS>
|
||||
{
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl TransformerDecoder {
|
||||
pub fn new(dim: usize, ff: usize, heads: usize, layers: usize, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
layers: (0..LAYERS).map(|_| InitModule::initialize(cx)).collect(),
|
||||
layers: (0..layers)
|
||||
.map(|_| TransformerDecoderBlock::new(dim, ff, heads, cx))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const DIM: usize, const FF: usize, const HEADS: usize, const LAYERS: usize> SerializeModule
|
||||
for TransformerDecoder<DIM, FF, HEADS, LAYERS>
|
||||
{
|
||||
impl SerializeModule for TransformerDecoder {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
for (i, l) in self.layers.iter().enumerate() {
|
||||
s.module(&format!("layer{i}"), l);
|
||||
@@ -33,55 +26,10 @@ impl<const DIM: usize, const FF: usize, const HEADS: usize, const LAYERS: usize>
|
||||
}
|
||||
}
|
||||
|
||||
// Single
|
||||
impl<
|
||||
const DIM: usize,
|
||||
const FF: usize,
|
||||
const HEADS: usize,
|
||||
const LAYERS: usize,
|
||||
S1: Dimension,
|
||||
S2: Dimension,
|
||||
> Module<(GraphTensor<(S1, Const<DIM>)>, GraphTensor<(S2, Const<DIM>)>)>
|
||||
for TransformerDecoder<DIM, FF, HEADS, LAYERS>
|
||||
{
|
||||
type Output = GraphTensor<(S1, Const<DIM>)>;
|
||||
impl Module<(GraphTensor, GraphTensor)> for TransformerDecoder {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
(input, from_enc): (GraphTensor<(S1, Const<DIM>)>, GraphTensor<(S2, Const<DIM>)>),
|
||||
) -> Self::Output {
|
||||
<Self as Module<(
|
||||
GraphTensor<(Const<1>, S1, Const<DIM>)>,
|
||||
GraphTensor<(Const<1>, S2, Const<DIM>)>,
|
||||
)>>::forward(self, (input.expand(), from_enc.expand()))
|
||||
.reshape()
|
||||
}
|
||||
}
|
||||
|
||||
// Batched
|
||||
impl<
|
||||
const DIM: usize,
|
||||
const FF: usize,
|
||||
const HEADS: usize,
|
||||
const LAYERS: usize,
|
||||
B: Dimension,
|
||||
S1: Dimension,
|
||||
S2: Dimension,
|
||||
>
|
||||
Module<(
|
||||
GraphTensor<(B, S1, Const<DIM>)>,
|
||||
GraphTensor<(B, S2, Const<DIM>)>,
|
||||
)> for TransformerDecoder<DIM, FF, HEADS, LAYERS>
|
||||
{
|
||||
type Output = GraphTensor<(B, S1, Const<DIM>)>;
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
(mut input, from_enc): (
|
||||
GraphTensor<(B, S1, Const<DIM>)>,
|
||||
GraphTensor<(B, S2, Const<DIM>)>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
fn forward(&self, (mut input, from_enc): (GraphTensor, GraphTensor)) -> Self::Output {
|
||||
for layer in &self.layers {
|
||||
input = layer.forward((input, from_enc));
|
||||
}
|
||||
@@ -90,27 +38,27 @@ impl<
|
||||
}
|
||||
|
||||
/// A single transformer decoder block
|
||||
pub struct TransformerDecoderBlock<const DIM: usize, const FF: usize, const HEADS: usize> {
|
||||
pub self_attention: MultiHeadSelfAttention<DIM, DIM, DIM, HEADS>,
|
||||
pub cross_attention: MultiHeadSelfAttention<DIM, DIM, DIM, HEADS>,
|
||||
pub ff: (Linear<DIM, FF>, ReLU, Linear<FF, DIM>),
|
||||
pub struct TransformerDecoderBlock {
|
||||
pub self_attention: MultiHeadSelfAttention,
|
||||
pub cross_attention: MultiHeadSelfAttention,
|
||||
pub ff: (Linear, ReLU, Linear),
|
||||
}
|
||||
|
||||
impl<const DIM: usize, const FF: usize, const HEADS: usize> InitModule
|
||||
for TransformerDecoderBlock<DIM, FF, HEADS>
|
||||
{
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl TransformerDecoderBlock {
|
||||
pub fn new(dim: usize, ff: usize, heads: usize, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
cross_attention: InitModule::initialize(cx),
|
||||
self_attention: InitModule::initialize(cx),
|
||||
ff: InitModule::initialize(cx),
|
||||
cross_attention: MultiHeadSelfAttention::new(dim, dim, dim, heads, cx),
|
||||
self_attention: MultiHeadSelfAttention::new(dim, dim, dim, heads, cx),
|
||||
ff: (
|
||||
Linear::new(dim, ff, false, cx),
|
||||
ReLU,
|
||||
Linear::new(ff, dim, false, cx),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const DIM: usize, const FF: usize, const HEADS: usize> SerializeModule
|
||||
for TransformerDecoderBlock<DIM, FF, HEADS>
|
||||
{
|
||||
impl SerializeModule for TransformerDecoderBlock {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("self_attn", &self.self_attention);
|
||||
s.module("cross_attn", &self.cross_attention);
|
||||
@@ -118,55 +66,32 @@ impl<const DIM: usize, const FF: usize, const HEADS: usize> SerializeModule
|
||||
}
|
||||
}
|
||||
|
||||
// Single
|
||||
impl<const DIM: usize, const FF: usize, const HEADS: usize, S1: Dimension, S2: Dimension>
|
||||
Module<(GraphTensor<(S1, Const<DIM>)>, GraphTensor<(S2, Const<DIM>)>)>
|
||||
for TransformerDecoderBlock<DIM, FF, HEADS>
|
||||
{
|
||||
type Output = GraphTensor<(S1, Const<DIM>)>;
|
||||
impl Module<(GraphTensor, GraphTensor)> for TransformerDecoderBlock {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
(input, from_enc): (GraphTensor<(S1, Const<DIM>)>, GraphTensor<(S2, Const<DIM>)>),
|
||||
) -> Self::Output {
|
||||
// Pass to batched forward
|
||||
<Self as Module<(
|
||||
GraphTensor<(Const<1>, S1, Const<DIM>)>,
|
||||
GraphTensor<(Const<1>, S2, Const<DIM>)>,
|
||||
)>>::forward(self, (input.expand(), from_enc.expand()))
|
||||
.reshape()
|
||||
}
|
||||
}
|
||||
|
||||
// Batched
|
||||
impl<
|
||||
const DIM: usize,
|
||||
const FF: usize,
|
||||
const HEADS: usize,
|
||||
S1: Dimension,
|
||||
S2: Dimension,
|
||||
B: Dimension,
|
||||
>
|
||||
Module<(
|
||||
GraphTensor<(B, S1, Const<DIM>)>,
|
||||
GraphTensor<(B, S2, Const<DIM>)>,
|
||||
)> for TransformerDecoderBlock<DIM, FF, HEADS>
|
||||
{
|
||||
type Output = GraphTensor<(B, S1, Const<DIM>)>;
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
(x, from_enc): (
|
||||
GraphTensor<(B, S1, Const<DIM>)>,
|
||||
GraphTensor<(B, S2, Const<DIM>)>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
let y = self.self_attention.forward(x);
|
||||
let x = (y + x).layer_norm::<Axis<2>, _>(1e-5);
|
||||
let y = self.cross_attention.forward((from_enc, x, from_enc));
|
||||
let x = (y + x).layer_norm::<Axis<2>, _>(1e-5);
|
||||
fn forward(&self, (input, from_enc): (GraphTensor, GraphTensor)) -> Self::Output {
|
||||
// Input: batch_dims, seq1, dim
|
||||
// From_enc: batch_dims, seq2, dim
|
||||
// Flatten to single batch dim
|
||||
let seq1 = input.shape()[input.shape.len() - 2].small();
|
||||
let seq2 = from_enc.shape()[from_enc.shape.len() - 2].small();
|
||||
let dim = input.shape().last().unwrap().small();
|
||||
let n_batches = input
|
||||
.shape()
|
||||
.into_iter()
|
||||
.take(input.shape.len() - 2)
|
||||
.product::<BigExpression>()
|
||||
.max(1)
|
||||
.small();
|
||||
let inp = input.reshape((n_batches, seq1, dim));
|
||||
let fe = from_enc.reshape((n_batches, seq2, dim));
|
||||
// Batched forward pass
|
||||
let y = self.self_attention.forward(inp);
|
||||
let x = (y + inp).layer_norm(2, 1e-5);
|
||||
let y = self.cross_attention.forward((fe, x, fe));
|
||||
let x = (y + x).layer_norm(2, 1e-5);
|
||||
let y = self.ff.forward(x);
|
||||
(y + x).layer_norm::<Axis<2>, _>(1e-5)
|
||||
(y + x).layer_norm(2, 1e-5).reshape(input.shape)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,7 +112,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_transformer_decoder_block() {
|
||||
let mut cx = Graph::new();
|
||||
let model: TransformerDecoderBlock<3, 4, 1> = InitModule::initialize(&mut cx);
|
||||
let model = TransformerDecoderBlock::new(3, 4, 1, &mut cx);
|
||||
model
|
||||
.self_attention
|
||||
.w_k
|
||||
@@ -239,12 +164,12 @@ mod tests {
|
||||
.weight
|
||||
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.]);
|
||||
|
||||
let a = cx.tensor::<(Dyn<'d'>, Const<3>)>();
|
||||
let e = cx.tensor::<(Dyn<'e'>, Const<3>)>();
|
||||
let a = cx.tensor(('d', 3));
|
||||
let e = cx.tensor(('e', 3));
|
||||
let b = model.forward((a, e));
|
||||
|
||||
a.set_dyn(vec![-1., 2., 3., 3., 3., -1.], &[2, 3]);
|
||||
e.set_dyn(vec![-1., 2., 3., 3., 3., -1., -1., 2., 3.], &[3, 3]);
|
||||
a.set_dyn(vec![-1., 2., 3., 3., 3., -1.], (2, 3));
|
||||
e.set_dyn(vec![-1., 2., 3., 3., 3., -1., -1., 2., 3.], (3, 3));
|
||||
b.retrieve();
|
||||
|
||||
cx.execute();
|
||||
|
||||
@@ -1,66 +1,88 @@
|
||||
use crate::{Linear, ReLU, Repeated};
|
||||
use crate::{Linear, ReLU};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use super::attention::MultiHeadSelfAttention;
|
||||
|
||||
/// A transformer encoder as laid out in [*Attention Is All You Need*](https://arxiv.org/abs/1706.03762).
|
||||
pub type TransformerEncoder<
|
||||
const DIM: usize,
|
||||
const FF: usize,
|
||||
const HEADS: usize,
|
||||
const LAYERS: usize,
|
||||
> = Repeated<TransformerEncoderBlock<DIM, FF, HEADS>, LAYERS>;
|
||||
|
||||
/// A single transformer encoder block
|
||||
pub struct TransformerEncoderBlock<const DIM: usize, const FF: usize, const HEADS: usize> {
|
||||
pub attention: MultiHeadSelfAttention<DIM, DIM, DIM, HEADS>,
|
||||
pub ff: (Linear<DIM, FF>, ReLU, Linear<FF, DIM>),
|
||||
pub struct TransformerEncoder {
|
||||
pub layers: Vec<TransformerEncoderBlock>,
|
||||
}
|
||||
|
||||
impl<const DIM: usize, const FF: usize, const HEADS: usize> InitModule
|
||||
for TransformerEncoderBlock<DIM, FF, HEADS>
|
||||
{
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl TransformerEncoder {
|
||||
pub fn new(dim: usize, ff: usize, heads: usize, layers: usize, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
attention: InitModule::initialize(cx),
|
||||
ff: InitModule::initialize(cx),
|
||||
layers: (0..layers)
|
||||
.map(|_| TransformerEncoderBlock::new(dim, ff, heads, cx))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const DIM: usize, const FF: usize, const HEADS: usize> SerializeModule
|
||||
for TransformerEncoderBlock<DIM, FF, HEADS>
|
||||
{
|
||||
impl SerializeModule for TransformerEncoder {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
for (i, l) in self.layers.iter().enumerate() {
|
||||
s.module(&format!("layer{i}"), l);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Module<GraphTensor> for TransformerEncoder {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn forward(&self, mut input: GraphTensor) -> Self::Output {
|
||||
for layer in &self.layers {
|
||||
input = layer.forward(input);
|
||||
}
|
||||
input
|
||||
}
|
||||
}
|
||||
|
||||
/// A single transformer encoder block
|
||||
pub struct TransformerEncoderBlock {
|
||||
pub attention: MultiHeadSelfAttention,
|
||||
pub ff: (Linear, ReLU, Linear),
|
||||
}
|
||||
|
||||
impl TransformerEncoderBlock {
|
||||
pub fn new(dim: usize, ff: usize, heads: usize, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
attention: MultiHeadSelfAttention::new(dim, dim, dim, heads, cx),
|
||||
ff: (
|
||||
Linear::new(dim, ff, false, cx),
|
||||
ReLU,
|
||||
Linear::new(ff, dim, false, cx),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SerializeModule for TransformerEncoderBlock {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("self_attn", &self.attention);
|
||||
s.module("ff", &self.ff);
|
||||
}
|
||||
}
|
||||
|
||||
// Single
|
||||
impl<const DIM: usize, const FF: usize, const HEADS: usize, S: Dimension>
|
||||
Module<GraphTensor<(S, Const<DIM>)>> for TransformerEncoderBlock<DIM, FF, HEADS>
|
||||
{
|
||||
type Output = GraphTensor<(S, Const<DIM>)>;
|
||||
|
||||
fn forward(&self, input: GraphTensor<(S, Const<DIM>)>) -> Self::Output {
|
||||
// Pass to batched forward
|
||||
<Self as Module<GraphTensor<(Const<1>, S, Const<DIM>)>>>::forward(self, input.expand())
|
||||
.reshape()
|
||||
}
|
||||
}
|
||||
|
||||
// Batched
|
||||
impl<const DIM: usize, const FF: usize, const HEADS: usize, S: Dimension, B: Dimension>
|
||||
Module<GraphTensor<(B, S, Const<DIM>)>> for TransformerEncoderBlock<DIM, FF, HEADS>
|
||||
{
|
||||
type Output = GraphTensor<(B, S, Const<DIM>)>;
|
||||
impl Module<GraphTensor> for TransformerEncoderBlock {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn forward(&self, x: GraphTensor<(B, S, Const<DIM>)>) -> Self::Output {
|
||||
fn forward(&self, input: GraphTensor) -> Self::Output {
|
||||
// Input: batch_dims, sequence, dim
|
||||
// Reshape to 1 batch dim, sequence, dim
|
||||
let n_batches = input
|
||||
.shape()
|
||||
.into_iter()
|
||||
.take(input.shape.len() - 2)
|
||||
.product::<BigExpression>()
|
||||
.max(1);
|
||||
let sequence = input.shape()[input.shape.len() - 2].small();
|
||||
let dim = input.shape()[input.shape.len() - 1].small();
|
||||
let x = input.reshape((n_batches, sequence, dim));
|
||||
let x = x + self.attention.forward(x);
|
||||
let x = x.layer_norm::<Axis<2>, _>(1e-5);
|
||||
let x = x.layer_norm(2, 1e-5);
|
||||
let x = x + self.ff.forward(x);
|
||||
x.layer_norm::<Axis<2>, _>(1e-5)
|
||||
x.layer_norm(2, 1e-5).reshape(input.shape())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,7 +103,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_transformer_encoder_block() {
|
||||
let mut cx = Graph::new();
|
||||
let model: TransformerEncoderBlock<3, 4, 1> = InitModule::initialize(&mut cx);
|
||||
let model = TransformerEncoderBlock::new(3, 4, 1, &mut cx);
|
||||
model
|
||||
.attention
|
||||
.w_k
|
||||
@@ -114,8 +136,8 @@ mod tests {
|
||||
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.]);
|
||||
|
||||
let a = cx
|
||||
.tensor::<(Dyn<'s'>, luminal::shape::Const<3>)>()
|
||||
.set_dyn(vec![-1., 2., 3., 3., 3., -1.], &[2, 3]);
|
||||
.tensor(('s', 3))
|
||||
.set_dyn(vec![-1., 2., 3., 3., 3., -1.], (2, 3));
|
||||
let b = model.forward(a).retrieve();
|
||||
|
||||
cx.execute();
|
||||
|
||||
@@ -7,44 +7,29 @@ pub use decoder::*;
|
||||
mod encoder;
|
||||
pub use encoder::*;
|
||||
|
||||
pub struct Transformer<
|
||||
const DIM: usize,
|
||||
const FF: usize,
|
||||
const ENC_HEADS: usize,
|
||||
const DEC_HEADS: usize,
|
||||
const ENC_LAYERS: usize,
|
||||
const DEC_LAYERS: usize,
|
||||
> {
|
||||
pub encoder: encoder::TransformerEncoder<DIM, FF, ENC_HEADS, ENC_LAYERS>,
|
||||
pub decoder: decoder::TransformerDecoder<DIM, FF, DEC_HEADS, DEC_LAYERS>,
|
||||
pub struct Transformer {
|
||||
pub encoder: encoder::TransformerEncoder,
|
||||
pub decoder: decoder::TransformerDecoder,
|
||||
}
|
||||
|
||||
impl<
|
||||
const DIM: usize,
|
||||
const FF: usize,
|
||||
const ENC_HEADS: usize,
|
||||
const DEC_HEADS: usize,
|
||||
const ENC_LAYERS: usize,
|
||||
const DEC_LAYERS: usize,
|
||||
> InitModule for Transformer<DIM, FF, ENC_HEADS, DEC_HEADS, ENC_LAYERS, DEC_LAYERS>
|
||||
{
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl Transformer {
|
||||
pub fn new(
|
||||
dim: usize,
|
||||
ff: usize,
|
||||
enc_heads: usize,
|
||||
dec_heads: usize,
|
||||
enc_layers: usize,
|
||||
dec_layers: usize,
|
||||
cx: &mut Graph,
|
||||
) -> Self {
|
||||
Self {
|
||||
encoder: InitModule::initialize(cx),
|
||||
decoder: InitModule::initialize(cx),
|
||||
encoder: TransformerEncoder::new(dim, ff, enc_heads, enc_layers, cx),
|
||||
decoder: TransformerDecoder::new(dim, ff, dec_heads, dec_layers, cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
const DIM: usize,
|
||||
const FF: usize,
|
||||
const ENC_HEADS: usize,
|
||||
const DEC_HEADS: usize,
|
||||
const ENC_LAYERS: usize,
|
||||
const DEC_LAYERS: usize,
|
||||
> SerializeModule for Transformer<DIM, FF, ENC_HEADS, DEC_HEADS, ENC_LAYERS, DEC_LAYERS>
|
||||
{
|
||||
impl SerializeModule for Transformer {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("encoder", &self.encoder);
|
||||
s.module("decoder", &self.decoder);
|
||||
@@ -52,24 +37,10 @@ impl<
|
||||
}
|
||||
|
||||
// Single Sequence
|
||||
impl<
|
||||
const DIM: usize,
|
||||
const FF: usize,
|
||||
const ENC_HEADS: usize,
|
||||
const DEC_HEADS: usize,
|
||||
const ENC_LAYERS: usize,
|
||||
const DEC_LAYERS: usize,
|
||||
S1: Dimension,
|
||||
S2: Dimension,
|
||||
> Module<(GraphTensor<(S1, Const<DIM>)>, GraphTensor<(S2, Const<DIM>)>)>
|
||||
for Transformer<DIM, FF, ENC_HEADS, DEC_HEADS, ENC_LAYERS, DEC_LAYERS>
|
||||
{
|
||||
type Output = GraphTensor<(S2, Const<DIM>)>;
|
||||
impl Module<(GraphTensor, GraphTensor)> for Transformer {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
(input, target): (GraphTensor<(S1, Const<DIM>)>, GraphTensor<(S2, Const<DIM>)>),
|
||||
) -> Self::Output {
|
||||
fn forward(&self, (input, target): (GraphTensor, GraphTensor)) -> Self::Output {
|
||||
let encoded = self.encoder.forward(input);
|
||||
self.decoder.forward((target, encoded))
|
||||
}
|
||||
@@ -92,7 +63,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_transformer_full() {
|
||||
let mut cx = Graph::new();
|
||||
let model: Transformer<3, 4, 1, 1, 1, 1> = InitModule::initialize(&mut cx);
|
||||
let model = Transformer::new(3, 4, 1, 1, 1, 1, &mut cx);
|
||||
model.decoder.layers[0]
|
||||
.self_attention
|
||||
.w_k
|
||||
@@ -143,43 +114,43 @@ mod tests {
|
||||
.2
|
||||
.weight
|
||||
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.]);
|
||||
model.encoder.modules[0]
|
||||
model.encoder.layers[0]
|
||||
.attention
|
||||
.w_k
|
||||
.weight
|
||||
.set(vec![1., 22., 3., 1., 2., 3., 1., 2., 3.]);
|
||||
model.encoder.modules[0]
|
||||
model.encoder.layers[0]
|
||||
.attention
|
||||
.w_q
|
||||
.weight
|
||||
.set(vec![3., 2., 3., 1.3, 2., 3., 3., 2., 3.]);
|
||||
model.encoder.modules[0]
|
||||
model.encoder.layers[0]
|
||||
.attention
|
||||
.w_v
|
||||
.weight
|
||||
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3.]);
|
||||
model.encoder.modules[0]
|
||||
model.encoder.layers[0]
|
||||
.attention
|
||||
.w_o
|
||||
.weight
|
||||
.set(vec![1., 22., 3., 1., 2., 3., 1., 2., 3.]);
|
||||
model.encoder.modules[0]
|
||||
model.encoder.layers[0]
|
||||
.ff
|
||||
.0
|
||||
.weight
|
||||
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 11., 2., 3.]);
|
||||
model.encoder.modules[0]
|
||||
model.encoder.layers[0]
|
||||
.ff
|
||||
.2
|
||||
.weight
|
||||
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.]);
|
||||
|
||||
let a = cx.tensor::<(Dyn<'d'>, luminal::shape::Const<3>)>();
|
||||
let e = cx.tensor::<(Dyn<'e'>, luminal::shape::Const<3>)>();
|
||||
let a = cx.tensor(('d', 3));
|
||||
let e = cx.tensor(('e', 3));
|
||||
let b = model.forward((a, e));
|
||||
|
||||
a.set_dyn(vec![-1., 2., 3., 3., 3., -1.], &[2, 3]);
|
||||
e.set_dyn(vec![-1., 2., 3., 3., 3., -1., -1., 2., 3.], &[3, 3]);
|
||||
a.set_dyn(vec![-1., 2., 3., 3., 3., -1.], (2, 3));
|
||||
e.set_dyn(vec![-1., 2., 3., 3., 3., -1., -1., 2., 3.], (3, 3));
|
||||
b.retrieve();
|
||||
|
||||
cx.execute();
|
||||
|
||||
@@ -16,7 +16,7 @@ use luminal::{
|
||||
pub struct Autograd(Vec<NodeIndex>, NodeIndex);
|
||||
|
||||
impl Autograd {
|
||||
pub fn new<W: ToIds>(params: W, loss: GraphTensor<()>) -> Self {
|
||||
pub fn new<W: ToIds>(params: W, loss: GraphTensor) -> Self {
|
||||
Self(params.to_ids(), loss.id)
|
||||
}
|
||||
}
|
||||
@@ -61,7 +61,7 @@ impl Compiler for Autograd {
|
||||
*loss,
|
||||
(
|
||||
graph.constant(1.0).id,
|
||||
ShapeTracker::new(&[]), // Assume scalar loss for now
|
||||
ShapeTracker::default(), // Assume scalar loss for now
|
||||
),
|
||||
);
|
||||
let weight_set = params.iter().copied().collect::<FxHashSet<_>>();
|
||||
@@ -90,7 +90,7 @@ impl Compiler for Autograd {
|
||||
.edges_directed(fwd_node, Direction::Incoming)
|
||||
.filter_map(|e| e.weight().as_data().map(|i| (e.source(), i)))
|
||||
.sorted_by_key(|(_, (a, _, _))| *a)
|
||||
.map(|(node, (_, _, sh))| GraphTensor::<()>::from_id(node, sh, graph_ref))
|
||||
.map(|(node, (_, _, sh))| GraphTensor::from_id(node, sh, graph_ref))
|
||||
.collect::<Vec<_>>();
|
||||
let mut prev_grad = {
|
||||
let (id, sh) = grads[&fwd_node];
|
||||
@@ -139,7 +139,7 @@ impl Compiler for Autograd {
|
||||
prev_grad
|
||||
.shape
|
||||
.expand(op.0, inps[0].shape.dims[inps[0].shape.indexes[op.0]]);
|
||||
let reduced = GraphTensor::<()>::from_id(fwd_node, prev_grad.shape, graph_ref);
|
||||
let reduced = GraphTensor::from_id(fwd_node, prev_grad.shape, graph_ref);
|
||||
let grad = inps[0].equals(reduced) * prev_grad;
|
||||
add_grad(grad, inps[0], graph, &mut grads);
|
||||
}
|
||||
@@ -184,8 +184,8 @@ impl Compiler for Autograd {
|
||||
}
|
||||
|
||||
fn add_grad(
|
||||
mut grad: GraphTensor<()>,
|
||||
fwd: GraphTensor<()>,
|
||||
mut grad: GraphTensor,
|
||||
fwd: GraphTensor,
|
||||
graph: &mut Graph,
|
||||
grad_map: &mut FxHashMap<NodeIndex, (NodeIndex, ShapeTracker)>,
|
||||
) {
|
||||
@@ -226,9 +226,8 @@ fn add_grad(
|
||||
}
|
||||
|
||||
if let Some((existing_grad_node, existing_grad_shape)) = grad_map.get(&fwd.id).copied() {
|
||||
let grad = GraphTensor::<()>::from_id(grad.id, grad.shape, graph);
|
||||
let existing_grad =
|
||||
GraphTensor::<()>::from_id(existing_grad_node, existing_grad_shape, graph);
|
||||
let grad = GraphTensor::from_id(grad.id, grad.shape, graph);
|
||||
let existing_grad = GraphTensor::from_id(existing_grad_node, existing_grad_shape, graph);
|
||||
let new_grad = grad + existing_grad;
|
||||
grad_map.insert(fwd.id, (new_grad.id, new_grad.shape));
|
||||
} else {
|
||||
@@ -244,14 +243,14 @@ mod tests {
|
||||
luminal::test_imports!();
|
||||
|
||||
fn get_vec(grad: (NodeIndex, ShapeTracker), cx: &mut Graph) -> Vec<f32> {
|
||||
GraphTensor::<()>::from_id(grad.0, grad.1, cx).data()
|
||||
GraphTensor::from_id(grad.0, grad.1, cx).data()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_autograd_max_reduce() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.named_tensor("Input").set([10., 5.]);
|
||||
let b = a.max_reduce();
|
||||
let a = cx.named_tensor("Input", 2).set([10., 5.]);
|
||||
let b = a.max_reduce(0);
|
||||
|
||||
let grads = cx.compile(Autograd::new(a, b), ());
|
||||
cx.keep_tensors(&grads);
|
||||
@@ -268,9 +267,9 @@ mod tests {
|
||||
#[test]
|
||||
fn test_autograd_matmul() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.named_tensor("A").set([[2., 4.], [3., 1.]]);
|
||||
let input = cx.named_tensor("Input").set([10., 5.]);
|
||||
let output = (input.matmul(a)).sum_reduce();
|
||||
let a = cx.named_tensor("A", (2, 2)).set([[2., 4.], [3., 1.]]);
|
||||
let input = cx.named_tensor("Input", 2).set([10., 5.]);
|
||||
let output = (input.matmul(a)).sum_reduce(0);
|
||||
|
||||
let grads = cx.compile(Autograd::new(a, output), ());
|
||||
cx.keep_tensors(&grads);
|
||||
@@ -288,15 +287,15 @@ mod tests {
|
||||
#[test]
|
||||
fn test_autograd_mlp() {
|
||||
let mut cx = Graph::new();
|
||||
let model = <(
|
||||
luminal_nn::Linear<2, 2>,
|
||||
let model = (
|
||||
luminal_nn::Linear::new(2, 2, false, &mut cx),
|
||||
luminal_nn::ReLU,
|
||||
luminal_nn::Linear<2, 1>,
|
||||
)>::initialize(&mut cx);
|
||||
luminal_nn::Linear::new(2, 1, false, &mut cx),
|
||||
);
|
||||
model.0.weight.set([[2., 4.], [3., 1.]]);
|
||||
model.2.weight.set([[6.], [5.]]);
|
||||
let input = cx.named_tensor("Input").set([10., 5.]);
|
||||
let output = model.forward(input).sum_reduce();
|
||||
let input = cx.named_tensor("Input", 2).set([10., 5.]);
|
||||
let output = model.forward(input).sum_reduce(0);
|
||||
|
||||
let mut grads = cx.compile(Autograd::new(params(model), output), ());
|
||||
cx.keep_tensors(&grads);
|
||||
@@ -328,8 +327,8 @@ mod tests {
|
||||
#[test]
|
||||
fn test_autograd_layer_norm() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor().set([-1., 2., 3.]);
|
||||
let mut b = a.layer_norm(1e-5).max_reduce().retrieve();
|
||||
let a = cx.tensor(3).set([-1., 2., 3.]);
|
||||
let mut b = a.layer_norm(0, 1e-5).max_reduce(0).retrieve();
|
||||
|
||||
let grads = cx.compile(Autograd::new(a, b), &mut b);
|
||||
cx.keep_tensors(&grads);
|
||||
@@ -347,8 +346,8 @@ mod tests {
|
||||
#[test]
|
||||
fn test_autograd_softmax() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor().set([-1., 2., 3.]);
|
||||
let mut b = a.softmax().max_reduce().retrieve();
|
||||
let a = cx.tensor(3).set([-1., 2., 3.]);
|
||||
let mut b = a.softmax(0).max_reduce(0).retrieve();
|
||||
|
||||
let mut grads = cx.compile(Autograd::new(a, b), &mut b);
|
||||
cx.keep_tensors(&grads);
|
||||
@@ -366,7 +365,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_autograd_transformer() {
|
||||
let mut cx = Graph::new();
|
||||
let model: luminal_nn::TransformerEncoderBlock<3, 4, 1> = InitModule::initialize(&mut cx);
|
||||
let model = luminal_nn::TransformerEncoderBlock::new(3, 4, 1, &mut cx);
|
||||
model
|
||||
.attention
|
||||
.w_k
|
||||
@@ -398,8 +397,8 @@ mod tests {
|
||||
.weight
|
||||
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.]);
|
||||
|
||||
let a = cx.tensor().set([[-1., 2., 3.], [3., 3., -1.]]);
|
||||
let target = cx.tensor().set([[0., 1., 0.], [0., 0., 1.]]);
|
||||
let a = cx.tensor((2, 3)).set([[-1., 2., 3.], [3., 3., -1.]]);
|
||||
let target = cx.tensor((2, 3)).set([[0., 1., 0.], [0., 0., 1.]]);
|
||||
let out = model.forward(a);
|
||||
let mut loss = crate::cross_entropy_with_logits_loss(out, target).retrieve();
|
||||
|
||||
|
||||
@@ -3,22 +3,26 @@ use luminal::prelude::*;
|
||||
/// [Mean Squared Error](https://en.wikipedia.org/wiki/Mean_squared_error).
|
||||
///
|
||||
/// This computes `(prediction - target).square().mean()`.
|
||||
pub fn mse_loss<S: Shape>(prediction: GraphTensor<S>, target: GraphTensor<S>) -> GraphTensor<()> {
|
||||
(prediction - target).square().mean_reduce()
|
||||
pub fn mse_loss(prediction: GraphTensor, target: GraphTensor) -> GraphTensor {
|
||||
(prediction - target)
|
||||
.square()
|
||||
.mean_reduce(prediction.shape.all_axes())
|
||||
}
|
||||
|
||||
/// [Root Mean square error](https://en.wikipedia.org/wiki/Root-mean-square_deviation).
|
||||
///
|
||||
/// This computes `(prediction - target).square().mean().sqrt()`
|
||||
pub fn rmse_loss<S: Shape>(prediction: GraphTensor<S>, target: GraphTensor<S>) -> GraphTensor<()> {
|
||||
pub fn rmse_loss(prediction: GraphTensor, target: GraphTensor) -> GraphTensor {
|
||||
mse_loss(prediction, target).sqrt()
|
||||
}
|
||||
|
||||
/// [Mean absolute error](https://en.wikipedia.org/wiki/Mean_absolute_error).
|
||||
///
|
||||
/// This computes `(prediction - target).abs().mean()`
|
||||
pub fn mae_loss<S: Shape>(prediction: GraphTensor<S>, target: GraphTensor<S>) -> GraphTensor<()> {
|
||||
(prediction - target).abs().mean_reduce()
|
||||
pub fn mae_loss(prediction: GraphTensor, target: GraphTensor) -> GraphTensor {
|
||||
(prediction - target)
|
||||
.abs()
|
||||
.mean_reduce(prediction.shape.all_axes())
|
||||
}
|
||||
|
||||
/// [Huber Loss](https://en.wikipedia.org/wiki/Huber_loss)
|
||||
@@ -28,18 +32,19 @@ pub fn mae_loss<S: Shape>(prediction: GraphTensor<S>, target: GraphTensor<S>) ->
|
||||
/// It computes:
|
||||
/// 1. if `|x - y| < delta`: `0.5 * (x - y)^2`
|
||||
/// 2. otherwise: `delta * (|x - y| - 0.5 * delta)`
|
||||
pub fn huber_loss<S: Shape>(
|
||||
prediction: GraphTensor<S>,
|
||||
target: GraphTensor<S>,
|
||||
pub fn huber_loss(
|
||||
prediction: GraphTensor,
|
||||
target: GraphTensor,
|
||||
delta: impl Into<f32>,
|
||||
) -> GraphTensor<()> {
|
||||
) -> GraphTensor {
|
||||
let delta: f32 = delta.into();
|
||||
let abs_error = (prediction - target).abs();
|
||||
let delta_tensor = prediction.graph().constant(delta);
|
||||
let huber_error = (0.5 * (prediction - target).square())
|
||||
* abs_error.less_than(delta_tensor.expand())
|
||||
+ (delta * (abs_error - 0.5 * delta)) * abs_error.greater_than_equal(delta_tensor.expand());
|
||||
huber_error.mean_reduce()
|
||||
* abs_error.less_than(delta_tensor.expand_to(abs_error.shape))
|
||||
+ (delta * (abs_error - 0.5 * delta))
|
||||
* abs_error.greater_than_equal(delta_tensor.expand_to(abs_error.shape));
|
||||
huber_error.mean_reduce(huber_error.shape.all_axes())
|
||||
}
|
||||
|
||||
/// Smooth l1 loss (closely related to [Huber Loss](https://en.wikipedia.org/wiki/Huber_loss))
|
||||
@@ -49,11 +54,11 @@ pub fn huber_loss<S: Shape>(
|
||||
/// It computes:
|
||||
/// 1. if `|x - y| < beta`: `0.5 * (x - y)^2 / beta`
|
||||
/// 2. otherwise: `|x - y| - 0.5 * beta`
|
||||
pub fn smooth_l1_loss<S: Shape>(
|
||||
prediction: GraphTensor<S>,
|
||||
target: GraphTensor<S>,
|
||||
delta: impl Copy + Into<f32>,
|
||||
) -> GraphTensor<()> {
|
||||
pub fn smooth_l1_loss(
|
||||
prediction: GraphTensor,
|
||||
target: GraphTensor,
|
||||
delta: impl Into<f32> + Copy,
|
||||
) -> GraphTensor {
|
||||
huber_loss(prediction, target, delta) / delta.into()
|
||||
}
|
||||
|
||||
@@ -67,16 +72,17 @@ pub fn smooth_l1_loss<S: Shape>(
|
||||
///
|
||||
/// - `logits`: The un-normalized output from a model. [log_softmax()] is called **in** this function
|
||||
/// - `target_probabilities`: Target containing probability vectors **NOT** class indices.
|
||||
pub fn cross_entropy_with_logits_loss<S: Shape>(
|
||||
logits: GraphTensor<S>,
|
||||
target_probabilities: GraphTensor<S>,
|
||||
) -> GraphTensor<()> {
|
||||
pub fn cross_entropy_with_logits_loss(
|
||||
logits: GraphTensor,
|
||||
target_probabilities: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let inv_last_axis_numel = 1.0
|
||||
/ logits
|
||||
.graph()
|
||||
.constant(logits.shape.shape().last().unwrap());
|
||||
let probs = logits.log_softmax::<S::LastAxis>();
|
||||
(-(probs * target_probabilities).mean_reduce()) / inv_last_axis_numel
|
||||
let probs = logits.log_softmax(logits.shape.last_axis());
|
||||
(-(probs * target_probabilities).mean_reduce(target_probabilities.shape.all_axes()))
|
||||
/ inv_last_axis_numel
|
||||
}
|
||||
|
||||
/// [KL Divergence loss](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence).
|
||||
@@ -89,16 +95,17 @@ pub fn cross_entropy_with_logits_loss<S: Shape>(
|
||||
///
|
||||
/// - `logits`: The un-normalized output from a model. [log_softmax()] is called **in** this function
|
||||
/// - `target_probabilities`: Target containing probability vectors **NOT** class indices.
|
||||
pub fn kl_div_with_logits_loss<S: Shape>(
|
||||
logits: GraphTensor<S>,
|
||||
target_probabilities: GraphTensor<S>,
|
||||
) -> GraphTensor<()> {
|
||||
pub fn kl_div_with_logits_loss(
|
||||
logits: GraphTensor,
|
||||
target_probabilities: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let inv_last_axis_numel = 1.0
|
||||
/ logits
|
||||
.graph()
|
||||
.constant(logits.shape.shape().last().unwrap());
|
||||
let probs = logits.log_softmax::<S::LastAxis>();
|
||||
(-((probs - target_probabilities.ln()) * target_probabilities).mean_reduce())
|
||||
let probs = logits.log_softmax(logits.shape.last_axis());
|
||||
(-((probs - target_probabilities.ln()) * target_probabilities)
|
||||
.mean_reduce(target_probabilities.shape.all_axes()))
|
||||
/ inv_last_axis_numel
|
||||
}
|
||||
|
||||
@@ -111,10 +118,10 @@ pub fn kl_div_with_logits_loss<S: Shape>(
|
||||
/// ### Inputs
|
||||
/// - `logits` - unnormalized inputs. **NOT** output of sigmoid
|
||||
/// - `target_probabilities` - target values between 0 and 1.
|
||||
pub fn binary_cross_entropy_with_logits_loss<S: Shape>(
|
||||
logits: GraphTensor<S>,
|
||||
target_probabilities: GraphTensor<S>,
|
||||
) -> GraphTensor<()> {
|
||||
pub fn binary_cross_entropy_with_logits_loss(
|
||||
logits: GraphTensor,
|
||||
target_probabilities: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let bce = (1.0 - target_probabilities) * logits + (1.0 + (-logits).exp()).ln();
|
||||
bce.mean_reduce()
|
||||
bce.mean_reduce(bce.shape.all_axes())
|
||||
}
|
||||
|
||||
@@ -12,12 +12,12 @@ pub fn sgd(
|
||||
Vec<NodeIndex>,
|
||||
Vec<NodeIndex>,
|
||||
Graph,
|
||||
GraphTensor<()>,
|
||||
GraphTensor,
|
||||
) {
|
||||
let mut opt_graph = Graph::new();
|
||||
let (old_weights, gradients): (Vec<NodeIndex>, Vec<NodeIndex>) = grads
|
||||
.iter()
|
||||
.map(|_| (opt_graph.tensor::<()>().id, opt_graph.tensor::<()>().id))
|
||||
.map(|_| (opt_graph.tensor(1).id, opt_graph.tensor(1).id))
|
||||
.unzip();
|
||||
|
||||
let (new_weights, lr) = sgd_on_graph(
|
||||
@@ -41,12 +41,12 @@ pub fn sgd_on_graph(
|
||||
graph: &mut Graph,
|
||||
old_weights: impl ToIds,
|
||||
grads: &[(NodeIndex, ShapeTracker)],
|
||||
) -> (Vec<NodeIndex>, GraphTensor<()>) {
|
||||
let lr = graph.named_tensor("Learning Rate").set(3e-4).keep(); // Karpathy constant
|
||||
) -> (Vec<NodeIndex>, GraphTensor) {
|
||||
let lr = graph.named_tensor("Learning Rate", 1).set(3e-4).keep(); // Karpathy constant
|
||||
let mut new_weights = vec![];
|
||||
for ((grad_id, grad_shape), old_weight_id) in grads.iter().copied().zip(old_weights.to_ids()) {
|
||||
let old_weight = GraphTensor::<()>::from_id(old_weight_id, grad_shape, graph);
|
||||
let gradient = GraphTensor::<()>::from_id(grad_id, grad_shape, graph);
|
||||
let old_weight = GraphTensor::from_id(old_weight_id, grad_shape, graph);
|
||||
let gradient = GraphTensor::from_id(grad_id, grad_shape, graph);
|
||||
|
||||
// SGD
|
||||
let new_weight = old_weight - (gradient * lr.expand_to(grad_shape));
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
<|begin_of_text|>Here is an implementation of merge sort:
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||
|
||||
```python
|
||||
You are a helpful AI assistant for travel tips and recommendations<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
What can you help me with?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
@@ -3,5 +3,5 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||
|
||||
echo "Downloading Model and Tokenizer..."
|
||||
curl --location https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer.json?download=true --output $SCRIPT_DIR/tokenizer.json
|
||||
curl --location https://huggingface.co/QuantFactory/Meta-Llama-3-8B-GGUF/resolve/main/Meta-Llama-3-8B.Q8_0.gguf?download=true --output $SCRIPT_DIR/llama3-8b.gguf
|
||||
curl --location https://huggingface.co/QuantFactory/Meta-Llama-3-8B-Instruct-GGUF-v2/resolve/main/Meta-Llama-3-8B-Instruct-v2.Q8_0.gguf?download=true --output $SCRIPT_DIR/llama3-8b.gguf
|
||||
echo "Done!"
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
use std::{
|
||||
io::{self, Write},
|
||||
marker::PhantomData,
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
use clap::Parser;
|
||||
use colored::Colorize;
|
||||
use itertools::Itertools;
|
||||
use model::{HEAD_DIM, N_KV_HEADS};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
mod gguf;
|
||||
@@ -39,15 +39,20 @@ fn main() {
|
||||
|
||||
// Set up graph
|
||||
let mut cx = Graph::new();
|
||||
let mut input = cx.named_tensor::<(Const<1>, Dyn<'s'>)>("Input");
|
||||
let mut cache_src: Vec<KVCache<Const<1>, Dyn<'p'>>> = (0..model::NUM_LAYERS)
|
||||
.map(|_| (cx.named_tensor("Key Cache"), cx.named_tensor("Value Cache")))
|
||||
let mut input = cx.named_tensor("Input", (1, 's'));
|
||||
let mut cache_src: Vec<KVCache> = (0..model::NUM_LAYERS)
|
||||
.map(|_| {
|
||||
(
|
||||
cx.named_tensor("Key Cache", (1, N_KV_HEADS, 'p', HEAD_DIM)),
|
||||
cx.named_tensor("Value Cache", (1, N_KV_HEADS, 'p', HEAD_DIM)),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
cache_src.set_dyn(vec![], &[1, model::N_KV_HEADS, 0, model::HEAD_DIM]);
|
||||
let model = model::Llama::initialize(&mut cx);
|
||||
cache_src.set_dyn(vec![], (1, model::N_KV_HEADS, 0, model::HEAD_DIM));
|
||||
let model = model::Llama::new(&mut cx);
|
||||
let mut model_weights = params(&model);
|
||||
cx.keep_tensors(&model_weights);
|
||||
let (logits, mut cache_dest) = model.forward((input, &cache_src, PhantomData::<Dyn<'t'>>));
|
||||
let (logits, mut cache_dest) = model.forward((input, &cache_src));
|
||||
let mut logits = logits
|
||||
.slice((.., (Expression::from('s') - 1).., ..))
|
||||
.retrieve();
|
||||
@@ -69,8 +74,8 @@ fn main() {
|
||||
GenericCompiler::default(),
|
||||
#[cfg(feature = "metal")]
|
||||
(
|
||||
luminal_metal::MetalCompilerPreBuffer::<f16>::default(),
|
||||
luminal_metal::quantized::MetalQuantizedCompiler::<f16>::new(q_weights),
|
||||
luminal_metal::MetalCompilerPreBuffer::<f32>::default(),
|
||||
luminal_metal::quantized::MetalQuantizedCompiler::<f32>::new(q_weights),
|
||||
luminal_metal::BufferCompilers::default(),
|
||||
),
|
||||
#[cfg(feature = "cuda")]
|
||||
@@ -93,7 +98,7 @@ fn main() {
|
||||
print!("Loading model");
|
||||
io::stdout().flush().unwrap();
|
||||
let now = Instant::now();
|
||||
input.set_dyn(vec![1.], &[1, 1]);
|
||||
input.set_dyn(vec![1.], (1, 1));
|
||||
cx.set_dyn_dim('t', 1);
|
||||
cx.execute();
|
||||
logits.drop();
|
||||
@@ -112,7 +117,7 @@ fn main() {
|
||||
.to_vec();
|
||||
input.set_dyn(
|
||||
input_ids.iter().map(|i| *i as f32).collect::<Vec<_>>(),
|
||||
&[1, input_ids.len()],
|
||||
(1, input_ids.len()),
|
||||
);
|
||||
cx.set_dyn_dim('t', input_ids.len());
|
||||
print!("Processing Prompt");
|
||||
@@ -130,10 +135,8 @@ fn main() {
|
||||
|
||||
// Decode token
|
||||
print!("{}", cli_args.prompt.white().bold());
|
||||
print!(
|
||||
"{}",
|
||||
tokenizer.decode(&output_ids, false).unwrap().bright_green()
|
||||
);
|
||||
let initial = tokenizer.decode(&output_ids, false).unwrap().bright_green();
|
||||
print!("{initial}",);
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
// Swap caches
|
||||
@@ -141,11 +144,10 @@ fn main() {
|
||||
|
||||
// Decode loop
|
||||
let start_decode = std::time::Instant::now();
|
||||
let mut prev_output_len = 0;
|
||||
let mut prev_output_len = initial.len();
|
||||
for _ in 0..cli_args.gen_tokens {
|
||||
input.set_dyn(vec![*output_ids.last().unwrap() as f32], &[1, 1]);
|
||||
input.set_dyn(vec![*output_ids.last().unwrap() as f32], (1, 1));
|
||||
cx.set_dyn_dim('p', input_ids.len() + output_ids.len() - 1);
|
||||
cx.set_dyn_dim('t', input_ids.len() + output_ids.len());
|
||||
cx.execute();
|
||||
|
||||
// Sample tokens
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
use std::{marker::PhantomData, ops::Div};
|
||||
|
||||
use luminal::prelude::{binary::F32Pow, *};
|
||||
use luminal_nn::{Embedding, LayerNorm, PermutedLinear};
|
||||
use luminal_nn::{Embedding, LayerNorm, Linear};
|
||||
|
||||
// Llama3 8B Config
|
||||
pub const VOCAB_SIZE: usize = 128256;
|
||||
@@ -13,49 +11,37 @@ pub const MLP_DIM: usize = 14336;
|
||||
|
||||
pub const N_ATTENTION_GROUPS: usize = N_HEADS / N_KV_HEADS;
|
||||
pub const HEAD_DIM: usize = HIDDEN_DIM / N_HEADS;
|
||||
pub const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2;
|
||||
pub const ATTN_PROJ_DIM: usize = HEAD_DIM * N_KV_HEADS;
|
||||
|
||||
pub type KVCache<Batch, Seq> = (
|
||||
GraphTensor<(Batch, Const<N_KV_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
GraphTensor<(Batch, Const<N_KV_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
);
|
||||
pub type KVCache = (GraphTensor, GraphTensor);
|
||||
|
||||
pub struct Mlp<const I: usize, const H: usize> {
|
||||
pub gate_proj: PermutedLinear<H, I>,
|
||||
pub down_proj: PermutedLinear<I, H>,
|
||||
pub up_proj: PermutedLinear<H, I>,
|
||||
pub struct Mlp {
|
||||
pub gate_proj: Linear, // hidden -> intermediate
|
||||
pub down_proj: Linear, // intermediate -> hidden
|
||||
pub up_proj: Linear, // hidden -> intermediate
|
||||
}
|
||||
|
||||
impl<const I: usize, const H: usize, Batch: Dimension, Batch1: Dimension>
|
||||
Module<GraphTensor<(Batch, Batch1, Const<H>)>> for Mlp<I, H>
|
||||
{
|
||||
type Output = GraphTensor<(Batch, Batch1, Const<H>)>;
|
||||
impl Module<GraphTensor> for Mlp {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn forward(&self, input: GraphTensor<(Batch, Batch1, Const<H>)>) -> Self::Output {
|
||||
fn forward(&self, input: GraphTensor) -> Self::Output {
|
||||
let gate = self.gate_proj.forward(input).swish();
|
||||
let up = self.up_proj.forward(input) * gate;
|
||||
self.down_proj.forward(up)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const I: usize, const H: usize> InitModule for Mlp<I, H> {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl Mlp {
|
||||
pub fn new(hidden: usize, intermediate: usize, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
gate_proj: PermutedLinear {
|
||||
weight: cx.named_tensor("Gate"),
|
||||
},
|
||||
up_proj: PermutedLinear {
|
||||
weight: cx.named_tensor("Up"),
|
||||
},
|
||||
down_proj: PermutedLinear {
|
||||
weight: cx.named_tensor("Down"),
|
||||
},
|
||||
gate_proj: Linear::new_permuted(hidden, intermediate, false, cx),
|
||||
down_proj: Linear::new_permuted(intermediate, hidden, false, cx),
|
||||
up_proj: Linear::new_permuted(hidden, intermediate, false, cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const I: usize, const H: usize> SerializeModule for Mlp<I, H> {
|
||||
impl SerializeModule for Mlp {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("ffn_gate", &self.gate_proj);
|
||||
s.module("ffn_up", &self.up_proj);
|
||||
@@ -63,122 +49,105 @@ impl<const I: usize, const H: usize> SerializeModule for Mlp<I, H> {
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_rotary_embeddings_ggml<const N_HEADS: usize, Batch: Dimension, Seq: Dimension>(
|
||||
input: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
prev_seq: BigExpression,
|
||||
) -> GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)> {
|
||||
fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: BigExpression) -> GraphTensor {
|
||||
assert_eq!(input.shape.len(), 4); // batch, n_heads, seq, head_dim
|
||||
let (batch, n_heads, seq, head_dim) = input.dims4();
|
||||
// Get freqs
|
||||
let freqs = (input.graph().arange::<Const<HEAD_DIM_OVER_2>>() * 2.0) / (HEAD_DIM as f32);
|
||||
let freqs = (input.graph().arange(head_dim / 2) * 2.0) / (head_dim.to_usize().unwrap() as f32);
|
||||
let freqs = 500_000_f32.pow(freqs);
|
||||
let pos = input.graph().arange::<Seq>() + prev_seq;
|
||||
let emb = pos.expand::<(_, Const<1>), _>().matmul(freqs.expand());
|
||||
let pos = input.graph().arange(seq) + prev_seq;
|
||||
let emb = pos.expand(1, 1).matmul(freqs.expand(0, 1));
|
||||
|
||||
// Split input into evens and odds
|
||||
let split = input.reshape::<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<2>)>();
|
||||
let x0: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<1>)> =
|
||||
split.slice((.., .., .., .., ..1)).realize();
|
||||
let x1: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<1>)> =
|
||||
split.slice((.., .., .., .., 1..)).realize();
|
||||
let split = input.reshape((batch, n_heads, seq, head_dim / 2, 2));
|
||||
let x0 = split.slice((.., .., .., .., ..1));
|
||||
let x1 = split.slice((.., .., .., .., 1..));
|
||||
|
||||
// Apply sin and cos embeddings
|
||||
let x0_out = x0 * emb.cos().expand() - x1 * emb.sin().expand();
|
||||
let x1_out = x0 * emb.sin().expand() + x1 * emb.cos().expand();
|
||||
let x0_out = x0 * emb.cos().expand_to(x0.shape) - x1 * emb.sin().expand_to(x1.shape);
|
||||
let x1_out = x0 * emb.sin().expand_to(x0.shape) + x1 * emb.cos().expand_to(x1.shape);
|
||||
|
||||
// Combine back into output
|
||||
x0_out
|
||||
.concat_along::<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<2>), Axis<4>, _>(
|
||||
x1_out,
|
||||
)
|
||||
.reshape()
|
||||
x0_out.concat_along(x1_out, 4).reshape(input.shape)
|
||||
}
|
||||
|
||||
pub struct SelfAttention {
|
||||
pub q_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
|
||||
pub k_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
|
||||
pub v_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
|
||||
pub o_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
|
||||
pub q_proj: GraphTensor, // Hidden -> hidden
|
||||
pub k_proj: GraphTensor, // Proj dim -> hidden
|
||||
pub v_proj: GraphTensor, // Proj dim -> hidden
|
||||
pub o_proj: GraphTensor, // Hidden -> hidden
|
||||
}
|
||||
|
||||
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, PrevSeq>,
|
||||
PhantomData<TotSeq>,
|
||||
)> for SelfAttention
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, TotSeq>,
|
||||
);
|
||||
fn forward(
|
||||
&self,
|
||||
(x, (k_cache, v_cache), _): (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, PrevSeq>,
|
||||
PhantomData<TotSeq>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
impl Module<(GraphTensor, KVCache)> for SelfAttention {
|
||||
type Output = (GraphTensor, KVCache);
|
||||
fn forward(&self, (x, (k_cache, v_cache)): (GraphTensor, KVCache)) -> Self::Output {
|
||||
// x: batch, seq, hidden
|
||||
// cache: batch, kv_heads, prev_seq, head_dim
|
||||
let (batch, seq, _) = x.dims3();
|
||||
let (_, _, prev_seq, _) = k_cache.dims4();
|
||||
// Apply the Projections
|
||||
let queries = x
|
||||
.matmul(self.q_proj.permute())
|
||||
.reshape::<(Batch, CurSeq, Const<N_HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
.matmul(self.q_proj.permute((1, 0)))
|
||||
.reshape((batch, seq, N_HEADS, HEAD_DIM))
|
||||
.permute((0, 2, 1, 3));
|
||||
|
||||
let keys = x
|
||||
.matmul(self.k_proj.permute())
|
||||
.reshape::<(Batch, CurSeq, Const<N_KV_HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
.matmul(self.k_proj.permute((1, 0)))
|
||||
.reshape((batch, seq, N_KV_HEADS, HEAD_DIM))
|
||||
.permute((0, 2, 1, 3));
|
||||
|
||||
let values = x
|
||||
.matmul(self.v_proj.permute())
|
||||
.reshape::<(Batch, CurSeq, Const<N_KV_HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
.matmul(self.v_proj.permute((1, 0)))
|
||||
.reshape((batch, seq, N_KV_HEADS, HEAD_DIM))
|
||||
.permute((0, 2, 1, 3));
|
||||
|
||||
// Rotary embed queries and keys
|
||||
let queries = apply_rotary_embeddings_ggml(queries, PrevSeq::size().into());
|
||||
let keys = apply_rotary_embeddings_ggml(keys, PrevSeq::size().into());
|
||||
let queries = apply_rotary_embeddings_ggml(queries, prev_seq.big());
|
||||
let keys = apply_rotary_embeddings_ggml(keys, prev_seq.big());
|
||||
|
||||
// Add KV cache
|
||||
let keys = k_cache.concat_along::<_, Axis<2>, _>(keys);
|
||||
let values = v_cache.concat_along::<_, Axis<2>, _>(values);
|
||||
let keys = k_cache.concat_along(keys, 2);
|
||||
let values = v_cache.concat_along(values, 2);
|
||||
|
||||
// Repeat the KV States for Grouped-Query Attention
|
||||
let repeated_keys = keys.expand::<(_, _, Const<N_ATTENTION_GROUPS>, _, _), _>();
|
||||
let repeated_values = values.expand::<(_, _, Const<N_ATTENTION_GROUPS>, _, _), _>();
|
||||
let repeated_keys = keys.expand(2, N_ATTENTION_GROUPS);
|
||||
let repeated_values = values.expand(2, N_ATTENTION_GROUPS);
|
||||
|
||||
// Calculate attention weights
|
||||
let mut attention_weights = queries
|
||||
.reshape::<(_, Const<N_KV_HEADS>, Const<N_ATTENTION_GROUPS>, _, _)>() // Split query heads into groups
|
||||
.matmul(repeated_keys.permute())
|
||||
.div((HEAD_DIM as f32).sqrt());
|
||||
.reshape((batch, N_KV_HEADS, N_ATTENTION_GROUPS, seq, HEAD_DIM)) // Split query heads into groups
|
||||
.matmul(repeated_keys.permute((0, 1, 2, 4, 3)))
|
||||
/ (HEAD_DIM as f32).sqrt();
|
||||
|
||||
let attention_mask = self.k_proj.graph().triu::<CurSeq>(1) * f16::MIN.to_f32();
|
||||
let attention_mask = self.k_proj.graph().triu(seq, 1) * f16::MIN.to_f32();
|
||||
attention_weights += attention_mask
|
||||
.pad::<(CurSeq, TotSeq)>(((0, 0), (TotSeq::size() - CurSeq::size(), 0)))
|
||||
.expand();
|
||||
.pad(((0, 0), (prev_seq, 0)))
|
||||
.expand(0, batch)
|
||||
.expand(1, N_KV_HEADS)
|
||||
.expand(2, N_ATTENTION_GROUPS);
|
||||
|
||||
// Calculate final outputs
|
||||
let output = attention_weights
|
||||
.softmax::<Axis<4>>()
|
||||
.softmax(4)
|
||||
// Apply distribution to values
|
||||
.matmul(repeated_values)
|
||||
// Merge heads
|
||||
.permute::<_, Axes5<0, 3, 1, 2, 4>>()
|
||||
.reshape::<(Batch, CurSeq, Const<HIDDEN_DIM>)>();
|
||||
.permute((0, 3, 1, 2, 4))
|
||||
.reshape((batch, seq, HIDDEN_DIM));
|
||||
let output = output
|
||||
// Apply output projection
|
||||
.matmul(self.o_proj.permute());
|
||||
.matmul(self.o_proj.permute((1, 0)));
|
||||
(output, (keys.contiguous(), values.contiguous())) // Cache needs to be contiguous for transferring to another graph
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for SelfAttention {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl SelfAttention {
|
||||
pub fn new(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
q_proj: cx.named_tensor("Q Proj"),
|
||||
k_proj: cx.named_tensor("K Proj"),
|
||||
v_proj: cx.named_tensor("V Proj"),
|
||||
o_proj: cx.named_tensor("O Proj"),
|
||||
q_proj: cx.named_tensor("Q Proj", (HIDDEN_DIM, HIDDEN_DIM)),
|
||||
k_proj: cx.named_tensor("K Proj", (ATTN_PROJ_DIM, HIDDEN_DIM)),
|
||||
v_proj: cx.named_tensor("V Proj", (ATTN_PROJ_DIM, HIDDEN_DIM)),
|
||||
o_proj: cx.named_tensor("O Proj", (HIDDEN_DIM, HIDDEN_DIM)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -194,54 +163,37 @@ impl SerializeModule for SelfAttention {
|
||||
|
||||
pub struct TransformerBlock {
|
||||
pub attention: SelfAttention,
|
||||
pub attention_norm: LayerNorm<HIDDEN_DIM>,
|
||||
pub feed_forward: Mlp<MLP_DIM, HIDDEN_DIM>,
|
||||
pub feed_forward_norm: LayerNorm<HIDDEN_DIM>,
|
||||
pub attention_norm: LayerNorm,
|
||||
pub feed_forward: Mlp,
|
||||
pub feed_forward_norm: LayerNorm,
|
||||
}
|
||||
|
||||
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, PrevSeq>,
|
||||
PhantomData<TotSeq>,
|
||||
)> for TransformerBlock
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, TotSeq>,
|
||||
);
|
||||
fn forward(
|
||||
&self,
|
||||
(mut x, cache, _): (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, PrevSeq>,
|
||||
PhantomData<TotSeq>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
impl Module<(GraphTensor, KVCache)> for TransformerBlock {
|
||||
type Output = (GraphTensor, KVCache);
|
||||
fn forward(&self, (mut x, cache): (GraphTensor, KVCache)) -> Self::Output {
|
||||
// Attention
|
||||
let normed = self.attention_norm.forward(x);
|
||||
let (y, cache) = self
|
||||
.attention
|
||||
.forward((normed, cache, PhantomData::<TotSeq>));
|
||||
.forward((self.attention_norm.forward(x), cache));
|
||||
|
||||
// Residual Addition
|
||||
// Residual
|
||||
x += y;
|
||||
|
||||
// Feed Forward
|
||||
let y = self.feed_forward.forward(self.feed_forward_norm.forward(x));
|
||||
|
||||
// Residual Addition
|
||||
// Residual
|
||||
(x + y, cache)
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for TransformerBlock {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl TransformerBlock {
|
||||
pub fn new(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
attention: InitModule::initialize(cx),
|
||||
attention_norm: LayerNorm::new(true, false, false, 1e-5, cx),
|
||||
feed_forward: InitModule::initialize(cx),
|
||||
feed_forward_norm: LayerNorm::new(true, false, false, 1e-5, cx),
|
||||
attention: SelfAttention::new(cx),
|
||||
attention_norm: LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
|
||||
feed_forward: Mlp::new(HIDDEN_DIM, MLP_DIM, cx),
|
||||
feed_forward_norm: LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -257,35 +209,16 @@ impl SerializeModule for TransformerBlock {
|
||||
|
||||
pub struct Llama {
|
||||
// Token embeddings
|
||||
pub embedding: Embedding<VOCAB_SIZE, HIDDEN_DIM>,
|
||||
pub embedding: Embedding,
|
||||
// Transformer layers
|
||||
pub layers: Vec<TransformerBlock>,
|
||||
// Norm + LM head
|
||||
pub head: (
|
||||
LayerNorm<HIDDEN_DIM>,
|
||||
PermutedLinear<HIDDEN_DIM, VOCAB_SIZE>,
|
||||
),
|
||||
pub head: (LayerNorm, Linear),
|
||||
}
|
||||
|
||||
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq)>,
|
||||
&[KVCache<Batch, PrevSeq>],
|
||||
PhantomData<TotSeq>,
|
||||
)> for Llama
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<VOCAB_SIZE>)>,
|
||||
Vec<KVCache<Batch, TotSeq>>,
|
||||
);
|
||||
fn forward(
|
||||
&self,
|
||||
(input, cache, _): (
|
||||
GraphTensor<(Batch, CurSeq)>,
|
||||
&[KVCache<Batch, PrevSeq>],
|
||||
PhantomData<TotSeq>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
impl Module<(GraphTensor, &[KVCache])> for Llama {
|
||||
type Output = (GraphTensor, Vec<KVCache>);
|
||||
fn forward(&self, (input, cache): (GraphTensor, &[KVCache])) -> Self::Output {
|
||||
// Embed tokens
|
||||
let mut x = self.embedding.forward(input);
|
||||
|
||||
@@ -293,7 +226,7 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
let mut new_caches = vec![];
|
||||
let mut new_cache;
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
(x, new_cache) = layer.forward((x, cache[i], PhantomData::<TotSeq>));
|
||||
(x, new_cache) = layer.forward((x, cache[i]));
|
||||
new_caches.push(new_cache);
|
||||
}
|
||||
// Run through last norm and output projection
|
||||
@@ -301,21 +234,15 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for Llama {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl Llama {
|
||||
pub fn new(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
embedding: Embedding {
|
||||
weight: cx.named_tensor("Embedding Weight"),
|
||||
},
|
||||
embedding: Embedding::new(VOCAB_SIZE, HIDDEN_DIM, cx),
|
||||
head: (
|
||||
LayerNorm::new(true, false, false, 1e-5, cx),
|
||||
PermutedLinear {
|
||||
weight: cx.tensor(),
|
||||
},
|
||||
LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
|
||||
Linear::new_permuted(HIDDEN_DIM, VOCAB_SIZE, false, cx),
|
||||
),
|
||||
layers: (0..NUM_LAYERS)
|
||||
.map(|_| InitModule::initialize(cx))
|
||||
.collect(),
|
||||
layers: (0..NUM_LAYERS).map(|_| TransformerBlock::new(cx)).collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
use std::{marker::PhantomData, ops::Div};
|
||||
|
||||
use luminal::prelude::{binary::F32Pow, *};
|
||||
use luminal_nn::{Embedding, LayerNorm, PermutedLinear};
|
||||
use luminal_nn::{Embedding, LayerNorm, Linear};
|
||||
|
||||
// Llama3 8B Config
|
||||
pub const VOCAB_SIZE: usize = 128256;
|
||||
@@ -13,49 +11,37 @@ pub const MLP_DIM: usize = 14336;
|
||||
|
||||
pub const N_ATTENTION_GROUPS: usize = N_HEADS / N_KV_HEADS;
|
||||
pub const HEAD_DIM: usize = HIDDEN_DIM / N_HEADS;
|
||||
pub const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2;
|
||||
pub const ATTN_PROJ_DIM: usize = HEAD_DIM * N_KV_HEADS;
|
||||
|
||||
pub type KVCache<Batch, Seq> = (
|
||||
GraphTensor<(Batch, Const<N_KV_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
GraphTensor<(Batch, Const<N_KV_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
);
|
||||
pub type KVCache = (GraphTensor, GraphTensor);
|
||||
|
||||
pub struct Mlp<const I: usize, const H: usize> {
|
||||
pub gate_proj: PermutedLinear<H, I>,
|
||||
pub down_proj: PermutedLinear<I, H>,
|
||||
pub up_proj: PermutedLinear<H, I>,
|
||||
pub struct Mlp {
|
||||
pub gate_proj: Linear, // hidden -> intermediate
|
||||
pub down_proj: Linear, // intermediate -> hidden
|
||||
pub up_proj: Linear, // hidden -> intermediate
|
||||
}
|
||||
|
||||
impl<const I: usize, const H: usize, Batch: Dimension, Batch1: Dimension>
|
||||
Module<GraphTensor<(Batch, Batch1, Const<H>)>> for Mlp<I, H>
|
||||
{
|
||||
type Output = GraphTensor<(Batch, Batch1, Const<H>)>;
|
||||
impl Module<GraphTensor> for Mlp {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn forward(&self, input: GraphTensor<(Batch, Batch1, Const<H>)>) -> Self::Output {
|
||||
fn forward(&self, input: GraphTensor) -> Self::Output {
|
||||
let gate = self.gate_proj.forward(input).swish();
|
||||
let up = self.up_proj.forward(input) * gate;
|
||||
self.down_proj.forward(up)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const I: usize, const H: usize> InitModule for Mlp<I, H> {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl Mlp {
|
||||
pub fn new(hidden: usize, intermediate: usize, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
gate_proj: PermutedLinear {
|
||||
weight: cx.named_tensor("Gate"),
|
||||
},
|
||||
up_proj: PermutedLinear {
|
||||
weight: cx.named_tensor("Up"),
|
||||
},
|
||||
down_proj: PermutedLinear {
|
||||
weight: cx.named_tensor("Down"),
|
||||
},
|
||||
gate_proj: Linear::new_permuted(hidden, intermediate, false, cx),
|
||||
down_proj: Linear::new_permuted(intermediate, hidden, false, cx),
|
||||
up_proj: Linear::new_permuted(hidden, intermediate, false, cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const I: usize, const H: usize> SerializeModule for Mlp<I, H> {
|
||||
impl SerializeModule for Mlp {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("ffn_gate", &self.gate_proj);
|
||||
s.module("ffn_up", &self.up_proj);
|
||||
@@ -63,122 +49,105 @@ impl<const I: usize, const H: usize> SerializeModule for Mlp<I, H> {
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_rotary_embeddings_ggml<const N_HEADS: usize, Batch: Dimension, Seq: Dimension>(
|
||||
input: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
prev_seq: BigExpression,
|
||||
) -> GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)> {
|
||||
fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: BigExpression) -> GraphTensor {
|
||||
assert_eq!(input.shape.len(), 4); // batch, n_heads, seq, head_dim
|
||||
let (batch, n_heads, seq, head_dim) = input.dims4();
|
||||
// Get freqs
|
||||
let freqs = (input.graph().arange::<Const<HEAD_DIM_OVER_2>>() * 2.0) / (HEAD_DIM as f32);
|
||||
let freqs = (input.graph().arange(head_dim / 2) * 2.0) / (head_dim.to_usize().unwrap() as f32);
|
||||
let freqs = 500_000_f32.pow(freqs);
|
||||
let pos = input.graph().arange::<Seq>() + prev_seq;
|
||||
let emb = pos.expand::<(_, Const<1>), _>().matmul(freqs.expand());
|
||||
let pos = input.graph().arange(seq) + prev_seq;
|
||||
let emb = pos.expand(1, 1).matmul(freqs.expand(0, 1));
|
||||
|
||||
// Split input into evens and odds
|
||||
let split = input.reshape::<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<2>)>();
|
||||
let x0: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<1>)> =
|
||||
split.slice((.., .., .., .., ..1)).realize();
|
||||
let x1: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<1>)> =
|
||||
split.slice((.., .., .., .., 1..)).realize();
|
||||
let split = input.reshape((batch, n_heads, seq, head_dim / 2, 2));
|
||||
let x0 = split.slice((.., .., .., .., ..1));
|
||||
let x1 = split.slice((.., .., .., .., 1..));
|
||||
|
||||
// Apply sin and cos embeddings
|
||||
let x0_out = x0 * emb.cos().expand() - x1 * emb.sin().expand();
|
||||
let x1_out = x0 * emb.sin().expand() + x1 * emb.cos().expand();
|
||||
let x0_out = x0 * emb.cos().expand_to(x0.shape) - x1 * emb.sin().expand_to(x1.shape);
|
||||
let x1_out = x0 * emb.sin().expand_to(x0.shape) + x1 * emb.cos().expand_to(x1.shape);
|
||||
|
||||
// Combine back into output
|
||||
x0_out
|
||||
.concat_along::<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<2>), Axis<4>, _>(
|
||||
x1_out,
|
||||
)
|
||||
.reshape()
|
||||
x0_out.concat_along(x1_out, 4).reshape(input.shape)
|
||||
}
|
||||
|
||||
pub struct SelfAttention {
|
||||
pub q_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
|
||||
pub k_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
|
||||
pub v_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
|
||||
pub o_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
|
||||
pub q_proj: GraphTensor, // Hidden -> hidden
|
||||
pub k_proj: GraphTensor, // Proj dim -> hidden
|
||||
pub v_proj: GraphTensor, // Proj dim -> hidden
|
||||
pub o_proj: GraphTensor, // Hidden -> hidden
|
||||
}
|
||||
|
||||
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, PrevSeq>,
|
||||
PhantomData<TotSeq>,
|
||||
)> for SelfAttention
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, TotSeq>,
|
||||
);
|
||||
fn forward(
|
||||
&self,
|
||||
(x, (k_cache, v_cache), _): (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, PrevSeq>,
|
||||
PhantomData<TotSeq>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
impl Module<(GraphTensor, KVCache)> for SelfAttention {
|
||||
type Output = (GraphTensor, KVCache);
|
||||
fn forward(&self, (x, (k_cache, v_cache)): (GraphTensor, KVCache)) -> Self::Output {
|
||||
// x: batch, seq, hidden
|
||||
// cache: batch, kv_heads, prev_seq, head_dim
|
||||
let (batch, seq, _) = x.dims3();
|
||||
let (_, _, prev_seq, _) = k_cache.dims4();
|
||||
// Apply the Projections
|
||||
let queries = x
|
||||
.matmul(self.q_proj.permute())
|
||||
.reshape::<(Batch, CurSeq, Const<N_HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
.matmul(self.q_proj.permute((1, 0)))
|
||||
.reshape((batch, seq, N_HEADS, HEAD_DIM))
|
||||
.permute((0, 2, 1, 3));
|
||||
|
||||
let keys = x
|
||||
.matmul(self.k_proj.permute())
|
||||
.reshape::<(Batch, CurSeq, Const<N_KV_HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
.matmul(self.k_proj.permute((1, 0)))
|
||||
.reshape((batch, seq, N_KV_HEADS, HEAD_DIM))
|
||||
.permute((0, 2, 1, 3));
|
||||
|
||||
let values = x
|
||||
.matmul(self.v_proj.permute())
|
||||
.reshape::<(Batch, CurSeq, Const<N_KV_HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
.matmul(self.v_proj.permute((1, 0)))
|
||||
.reshape((batch, seq, N_KV_HEADS, HEAD_DIM))
|
||||
.permute((0, 2, 1, 3));
|
||||
|
||||
// Rotary embed queries and keys
|
||||
let queries = apply_rotary_embeddings_ggml(queries, PrevSeq::size().big());
|
||||
let keys = apply_rotary_embeddings_ggml(keys, PrevSeq::size().big());
|
||||
let queries = apply_rotary_embeddings_ggml(queries, prev_seq.big());
|
||||
let keys = apply_rotary_embeddings_ggml(keys, prev_seq.big());
|
||||
|
||||
// Add KV cache
|
||||
let keys = k_cache.concat_along::<_, Axis<2>, _>(keys);
|
||||
let values = v_cache.concat_along::<_, Axis<2>, _>(values);
|
||||
let keys = k_cache.concat_along(keys, 2);
|
||||
let values = v_cache.concat_along(values, 2);
|
||||
|
||||
// Repeat the KV States for Grouped-Query Attention
|
||||
let repeated_keys = keys.expand::<(_, _, Const<N_ATTENTION_GROUPS>, _, _), _>();
|
||||
let repeated_values = values.expand::<(_, _, Const<N_ATTENTION_GROUPS>, _, _), _>();
|
||||
let repeated_keys = keys.expand(2, N_ATTENTION_GROUPS);
|
||||
let repeated_values = values.expand(2, N_ATTENTION_GROUPS);
|
||||
|
||||
// Calculate attention weights
|
||||
let mut attention_weights = queries
|
||||
.reshape::<(_, Const<N_KV_HEADS>, Const<N_ATTENTION_GROUPS>, _, _)>() // Split query heads into groups
|
||||
.matmul(repeated_keys.permute())
|
||||
.div((HEAD_DIM as f32).sqrt());
|
||||
.reshape((batch, N_KV_HEADS, N_ATTENTION_GROUPS, seq, HEAD_DIM)) // Split query heads into groups
|
||||
.matmul(repeated_keys.permute((0, 1, 2, 4, 3)))
|
||||
/ (HEAD_DIM as f32).sqrt();
|
||||
|
||||
let attention_mask = self.k_proj.graph().triu::<CurSeq>(1) * f16::MIN.to_f32();
|
||||
let attention_mask = self.k_proj.graph().triu(seq, 1) * f16::MIN.to_f32();
|
||||
attention_weights += attention_mask
|
||||
.pad::<(CurSeq, TotSeq)>(((0, 0), (TotSeq::size() - CurSeq::size(), 0)))
|
||||
.expand();
|
||||
.pad(((0, 0), (prev_seq, 0)))
|
||||
.expand(0, batch)
|
||||
.expand(1, N_KV_HEADS)
|
||||
.expand(2, N_ATTENTION_GROUPS);
|
||||
|
||||
// Calculate final outputs
|
||||
let output = attention_weights
|
||||
.softmax::<Axis<4>>()
|
||||
.softmax(4)
|
||||
// Apply distribution to values
|
||||
.matmul(repeated_values)
|
||||
// Merge heads
|
||||
.permute::<_, Axes5<0, 3, 1, 2, 4>>()
|
||||
.reshape::<(Batch, CurSeq, Const<HIDDEN_DIM>)>();
|
||||
.permute((0, 3, 1, 2, 4))
|
||||
.reshape((batch, seq, HIDDEN_DIM));
|
||||
let output = output
|
||||
// Apply output projection
|
||||
.matmul(self.o_proj.permute());
|
||||
.matmul(self.o_proj.permute((1, 0)));
|
||||
(output, (keys.contiguous(), values.contiguous())) // Cache needs to be contiguous for transferring to another graph
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for SelfAttention {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl SelfAttention {
|
||||
pub fn new(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
q_proj: cx.named_tensor("Q Proj"),
|
||||
k_proj: cx.named_tensor("K Proj"),
|
||||
v_proj: cx.named_tensor("V Proj"),
|
||||
o_proj: cx.named_tensor("O Proj"),
|
||||
q_proj: cx.named_tensor("Q Proj", (HIDDEN_DIM, HIDDEN_DIM)),
|
||||
k_proj: cx.named_tensor("K Proj", (ATTN_PROJ_DIM, HIDDEN_DIM)),
|
||||
v_proj: cx.named_tensor("V Proj", (ATTN_PROJ_DIM, HIDDEN_DIM)),
|
||||
o_proj: cx.named_tensor("O Proj", (HIDDEN_DIM, HIDDEN_DIM)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -194,54 +163,37 @@ impl SerializeModule for SelfAttention {
|
||||
|
||||
pub struct TransformerBlock {
|
||||
pub attention: SelfAttention,
|
||||
pub attention_norm: LayerNorm<HIDDEN_DIM>,
|
||||
pub feed_forward: Mlp<MLP_DIM, HIDDEN_DIM>,
|
||||
pub feed_forward_norm: LayerNorm<HIDDEN_DIM>,
|
||||
pub attention_norm: LayerNorm,
|
||||
pub feed_forward: Mlp,
|
||||
pub feed_forward_norm: LayerNorm,
|
||||
}
|
||||
|
||||
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, PrevSeq>,
|
||||
PhantomData<TotSeq>,
|
||||
)> for TransformerBlock
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, TotSeq>,
|
||||
);
|
||||
fn forward(
|
||||
&self,
|
||||
(mut x, cache, _): (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, PrevSeq>,
|
||||
PhantomData<TotSeq>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
impl Module<(GraphTensor, KVCache)> for TransformerBlock {
|
||||
type Output = (GraphTensor, KVCache);
|
||||
fn forward(&self, (mut x, cache): (GraphTensor, KVCache)) -> Self::Output {
|
||||
// Attention
|
||||
let normed = self.attention_norm.forward(x);
|
||||
let (y, cache) = self
|
||||
.attention
|
||||
.forward((normed, cache, PhantomData::<TotSeq>));
|
||||
.forward((self.attention_norm.forward(x), cache));
|
||||
|
||||
// Residual Addition
|
||||
// Residual
|
||||
x += y;
|
||||
|
||||
// Feed Forward
|
||||
let y = self.feed_forward.forward(self.feed_forward_norm.forward(x));
|
||||
|
||||
// Residual Addition
|
||||
// Residual
|
||||
(x + y, cache)
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for TransformerBlock {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl TransformerBlock {
|
||||
pub fn new(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
attention: InitModule::initialize(cx),
|
||||
attention_norm: LayerNorm::new(true, false, false, 1e-5, cx),
|
||||
feed_forward: InitModule::initialize(cx),
|
||||
feed_forward_norm: LayerNorm::new(true, false, false, 1e-5, cx),
|
||||
attention: SelfAttention::new(cx),
|
||||
attention_norm: LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
|
||||
feed_forward: Mlp::new(HIDDEN_DIM, MLP_DIM, cx),
|
||||
feed_forward_norm: LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -255,36 +207,18 @@ impl SerializeModule for TransformerBlock {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MistralLM {
|
||||
pub struct Llama {
|
||||
// Token embeddings
|
||||
pub embedding: Embedding<VOCAB_SIZE, HIDDEN_DIM>,
|
||||
pub embedding: Embedding,
|
||||
// Transformer layers
|
||||
pub layers: Vec<TransformerBlock>,
|
||||
// Final Norm layer
|
||||
pub norm: LayerNorm<HIDDEN_DIM>,
|
||||
// LM Head Layer
|
||||
pub lm_head: GraphTensor<R2<VOCAB_SIZE, HIDDEN_DIM>>,
|
||||
// Norm + LM head
|
||||
pub head: (LayerNorm, Linear),
|
||||
}
|
||||
|
||||
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq)>,
|
||||
&[KVCache<Batch, PrevSeq>],
|
||||
PhantomData<TotSeq>,
|
||||
)> for MistralLM
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<VOCAB_SIZE>)>,
|
||||
Vec<KVCache<Batch, TotSeq>>,
|
||||
);
|
||||
fn forward(
|
||||
&self,
|
||||
(input, cache, _): (
|
||||
GraphTensor<(Batch, CurSeq)>,
|
||||
&[KVCache<Batch, PrevSeq>],
|
||||
PhantomData<TotSeq>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
impl Module<(GraphTensor, &[KVCache])> for Llama {
|
||||
type Output = (GraphTensor, Vec<KVCache>);
|
||||
fn forward(&self, (input, cache): (GraphTensor, &[KVCache])) -> Self::Output {
|
||||
// Embed tokens
|
||||
let mut x = self.embedding.forward(input);
|
||||
|
||||
@@ -292,36 +226,32 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
let mut new_caches = vec![];
|
||||
let mut new_cache;
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
(x, new_cache) = layer.forward((x, cache[i], PhantomData::<TotSeq>));
|
||||
(x, new_cache) = layer.forward((x, cache[i]));
|
||||
new_caches.push(new_cache);
|
||||
}
|
||||
// Run through last norm and output projection
|
||||
let output = self.norm.forward(x).matmul(self.lm_head.permute());
|
||||
|
||||
(output, new_caches)
|
||||
(self.head.forward(x), new_caches)
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for MistralLM {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl Llama {
|
||||
pub fn new(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
embedding: Embedding {
|
||||
weight: cx.named_tensor("Embedding Weight"),
|
||||
},
|
||||
norm: LayerNorm::new(true, false, false, 1e-5, cx),
|
||||
lm_head: cx.named_tensor("LM Head"),
|
||||
layers: (0..NUM_LAYERS)
|
||||
.map(|_| InitModule::initialize(cx))
|
||||
.collect(),
|
||||
embedding: Embedding::new(VOCAB_SIZE, HIDDEN_DIM, cx),
|
||||
head: (
|
||||
LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
|
||||
Linear::new_permuted(HIDDEN_DIM, VOCAB_SIZE, false, cx),
|
||||
),
|
||||
layers: (0..NUM_LAYERS).map(|_| TransformerBlock::new(cx)).collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SerializeModule for MistralLM {
|
||||
impl SerializeModule for Llama {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("token_embd", &self.embedding);
|
||||
s.module("output_norm", &self.norm);
|
||||
s.tensor("output/weight", self.lm_head);
|
||||
s.module("output_norm", &self.head.0);
|
||||
s.module("output", &self.head.1);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
s.module(&format!("blk/{i}"), layer);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use std::{
|
||||
io::{self, Write},
|
||||
marker::PhantomData,
|
||||
path::Path,
|
||||
time::Instant,
|
||||
};
|
||||
@@ -11,18 +10,16 @@ use tokenizers::Tokenizer;
|
||||
|
||||
use crate::llama::{
|
||||
loader,
|
||||
model::{KVCache, MistralLM, HEAD_DIM, NUM_LAYERS, N_KV_HEADS},
|
||||
model::{KVCache, Llama, HEAD_DIM, NUM_LAYERS, N_KV_HEADS},
|
||||
};
|
||||
|
||||
use super::model::VOCAB_SIZE;
|
||||
|
||||
/// Define the model
|
||||
pub struct Model {
|
||||
pub graph: Box<Graph>,
|
||||
pub input: GraphTensor<(Const<1>, Dyn<'s'>)>,
|
||||
pub input: GraphTensor,
|
||||
kv_cache_src_set: Vec<NodeIndex>,
|
||||
kv_cache_dest_set: Vec<NodeIndex>,
|
||||
logits: GraphTensor<R3<1, 1, { VOCAB_SIZE }>>,
|
||||
logits: GraphTensor,
|
||||
pub tokenizer: Tokenizer,
|
||||
pub last_generated_token: Option<u32>,
|
||||
}
|
||||
@@ -50,19 +47,23 @@ impl Model {
|
||||
// Set up graph
|
||||
let mut cx = Box::new(Graph::new());
|
||||
|
||||
let mut input = cx.named_tensor::<(Const<1>, Dyn<'s'>)>("Input");
|
||||
let mut cache_src: Vec<KVCache<Const<1>, Dyn<'p'>>> = (0..NUM_LAYERS)
|
||||
.map(|_| (cx.named_tensor("Key Cache"), cx.named_tensor("Value Cache")))
|
||||
let mut input = cx.named_tensor("Input", (1, 's'));
|
||||
let mut cache_src: Vec<KVCache> = (0..NUM_LAYERS)
|
||||
.map(|_| {
|
||||
(
|
||||
cx.named_tensor("Key Cache", (1, N_KV_HEADS, 'p', HEAD_DIM)),
|
||||
cx.named_tensor("Value Cache", (1, N_KV_HEADS, 'p', HEAD_DIM)),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
cache_src.set_dyn(vec![], &[1, N_KV_HEADS, 0, HEAD_DIM]);
|
||||
let model = MistralLM::initialize(&mut cx);
|
||||
cache_src.set_dyn(vec![], (1, N_KV_HEADS, 0, HEAD_DIM));
|
||||
let model = Llama::new(&mut cx);
|
||||
let mut model_weights = params(&model);
|
||||
cx.keep_tensors(&model_weights);
|
||||
let (logits, mut cache_dest) = model.forward((input, &cache_src, PhantomData::<Dyn<'t'>>));
|
||||
let (logits, mut cache_dest) = model.forward((input, &cache_src));
|
||||
let mut logits = logits
|
||||
.slice((.., (Expression::from('s') - 1).., ..))
|
||||
.retrieve()
|
||||
.realize();
|
||||
.retrieve();
|
||||
cache_dest.keep();
|
||||
|
||||
// Set up model loading
|
||||
@@ -105,7 +106,7 @@ impl Model {
|
||||
print!("Loading model");
|
||||
io::stdout().flush().unwrap();
|
||||
let now = Instant::now();
|
||||
input.set_dyn(vec![1.], &[1, 1]);
|
||||
input.set_dyn(vec![1.], (1, 1));
|
||||
cx.set_dyn_dim('t', 1);
|
||||
cx.execute();
|
||||
logits.drop();
|
||||
@@ -159,7 +160,7 @@ impl Model {
|
||||
self.graph.set_dyn_dim('t', seq_len);
|
||||
self.input.set_dyn(
|
||||
input_ids.iter().map(|i| *i as f32).collect::<Vec<_>>(),
|
||||
&[1, input_ids.len()],
|
||||
(1, input_ids.len()),
|
||||
);
|
||||
|
||||
// First token output (from prompt processing)
|
||||
@@ -183,7 +184,7 @@ impl Model {
|
||||
// Set the data
|
||||
self.graph.set_dyn_dim('p', seq_len - 1);
|
||||
self.graph.set_dyn_dim('t', seq_len);
|
||||
self.input.set_dyn(vec![output_id as f32], &[1, 1]);
|
||||
self.input.set_dyn(vec![output_id as f32], (1, 1));
|
||||
|
||||
// Execute the graph
|
||||
self.graph.execute();
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
use std::{
|
||||
io::{self, Write},
|
||||
marker::PhantomData,
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
use clap::Parser;
|
||||
use colored::Colorize;
|
||||
use itertools::Itertools;
|
||||
use model::{Phi, HEAD_DIM, N_HEADS};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
mod gguf;
|
||||
@@ -39,15 +39,20 @@ fn main() {
|
||||
|
||||
// Set up graph
|
||||
let mut cx = Graph::new();
|
||||
let mut input = cx.named_tensor::<(Const<1>, Dyn<'s'>)>("Input");
|
||||
let mut cache_src: Vec<KVCache<Const<1>, Dyn<'p'>>> = (0..model::NUM_LAYERS)
|
||||
.map(|_| (cx.named_tensor("Key Cache"), cx.named_tensor("Value Cache")))
|
||||
let mut input = cx.named_tensor("Input", (1, 's'));
|
||||
let mut cache_src: Vec<KVCache> = (0..model::NUM_LAYERS)
|
||||
.map(|_| {
|
||||
(
|
||||
cx.named_tensor("Key Cache", (1, N_HEADS, 'p', HEAD_DIM)),
|
||||
cx.named_tensor("Value Cache", (1, N_HEADS, 'p', HEAD_DIM)),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
cache_src.set_dyn(vec![], &[1, model::N_HEADS, 0, model::HEAD_DIM]);
|
||||
let model = model::Phi::initialize(&mut cx);
|
||||
cache_src.set_dyn(vec![], (1, N_HEADS, 0, HEAD_DIM));
|
||||
let model = Phi::new(&mut cx);
|
||||
let mut model_weights = params(&model);
|
||||
cx.keep_tensors(&model_weights);
|
||||
let (logits, mut cache_dest) = model.forward((input, &cache_src, PhantomData::<Dyn<'t'>>));
|
||||
let (logits, mut cache_dest) = model.forward((input, &cache_src));
|
||||
let mut logits = logits
|
||||
.slice((.., (Expression::from('s') - 1).., ..))
|
||||
.retrieve();
|
||||
@@ -73,7 +78,7 @@ fn main() {
|
||||
luminal_metal::BufferCompilers::default(),
|
||||
),
|
||||
#[cfg(feature = "cuda")]
|
||||
luminal_cuda::CudaQuantizedCompiler::<f32>::new(q_weights),
|
||||
luminal_cuda::CudaQuantizedCompiler::<f16>::new(q_weights),
|
||||
#[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
|
||||
luminal_cpu::CPUCompiler::default(),
|
||||
),
|
||||
@@ -86,33 +91,32 @@ fn main() {
|
||||
),
|
||||
);
|
||||
let cache_src = downstream(&cache_src, &cx);
|
||||
let cache_dest = cache_dest.to_ids();
|
||||
println!("\t\t - {}ms", now.elapsed().as_millis());
|
||||
|
||||
// Initial forward pass to load weights
|
||||
print!("Loading model");
|
||||
io::stdout().flush().unwrap();
|
||||
let now = Instant::now();
|
||||
input.set_dyn(vec![1.], &[1, 1]);
|
||||
input.set_dyn(vec![1.], (1, 1));
|
||||
cx.set_dyn_dim('t', 1);
|
||||
cx.execute();
|
||||
logits.drop();
|
||||
cx.drop_tensors(&cache_dest);
|
||||
transfer_data_same_graph(&cache_dest, &cache_src, &mut cx);
|
||||
println!("\t\t - {}ms", now.elapsed().as_millis());
|
||||
|
||||
// Now that weights are loaded, delete the loading nodes so they don't run again
|
||||
delete_inputs(&cache_src, &mut cx);
|
||||
delete_inputs(&downstream(model_weights, &cx), &mut cx);
|
||||
|
||||
// Run prompt processing pass
|
||||
let mut input_ids = tokenizer
|
||||
.encode(&cli_args.prompt as &str, true)
|
||||
let input_ids = tokenizer
|
||||
.encode(&cli_args.prompt as &str, false)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
input_ids.insert(0, 1);
|
||||
input.set_dyn(
|
||||
input_ids.iter().map(|i| *i as f32).collect::<Vec<_>>(),
|
||||
&[1, input_ids.len()],
|
||||
(1, input_ids.len()),
|
||||
);
|
||||
cx.set_dyn_dim('t', input_ids.len());
|
||||
print!("Processing Prompt");
|
||||
@@ -125,14 +129,13 @@ fn main() {
|
||||
1000.0 * (input_ids.len() as f64) / (elapsed_ms as f64),
|
||||
input_ids.len()
|
||||
);
|
||||
delete_inputs(&cache_src, &mut cx);
|
||||
let mut output_ids = vec![argmax(&logits.data())];
|
||||
logits.drop();
|
||||
|
||||
// Decode token
|
||||
print!("{}", cli_args.prompt.white().bold());
|
||||
let out_str = tokenizer.decode(&output_ids, false).unwrap().bright_green();
|
||||
let mut prev_output_len = out_str.len();
|
||||
print!("{out_str}");
|
||||
let out = tokenizer.decode(&output_ids, false).unwrap().bright_green();
|
||||
print!("{out}");
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
// Swap caches
|
||||
@@ -140,10 +143,10 @@ fn main() {
|
||||
|
||||
// Decode loop
|
||||
let start_decode = std::time::Instant::now();
|
||||
let mut prev_output_len = out.len();
|
||||
for _ in 0..cli_args.gen_tokens {
|
||||
input.set_dyn(vec![*output_ids.last().unwrap() as f32], &[1, 1]);
|
||||
input.set_dyn(vec![*output_ids.last().unwrap() as f32], (1, 1));
|
||||
cx.set_dyn_dim('p', input_ids.len() + output_ids.len() - 1);
|
||||
cx.set_dyn_dim('t', input_ids.len() + output_ids.len());
|
||||
cx.execute();
|
||||
|
||||
// Sample tokens
|
||||
@@ -166,9 +169,8 @@ fn main() {
|
||||
}
|
||||
|
||||
println!();
|
||||
let avg_token_time = (std::time::Instant::now() - start_decode).as_micros() as f32
|
||||
/ (output_ids.len() - 1) as f32
|
||||
/ 1000.0;
|
||||
let avg_token_time =
|
||||
start_decode.elapsed().as_micros() as f32 / (output_ids.len() - 1) as f32 / 1000.0;
|
||||
println!(
|
||||
"\nAverage token generated in {:.2}ms\t - ({:.2} tok/s)",
|
||||
avg_token_time,
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use luminal::prelude::{binary::F32Pow, *};
|
||||
use luminal_nn::{Embedding, LayerNorm, PermutedLinear};
|
||||
use luminal_nn::{Embedding, LayerNorm, Linear};
|
||||
|
||||
// Llama3 8B Config
|
||||
// Phi3 mini Config
|
||||
pub const VOCAB_SIZE: usize = 32064;
|
||||
pub const HIDDEN_DIM: usize = 3072;
|
||||
pub const NUM_LAYERS: usize = 32;
|
||||
@@ -11,51 +9,37 @@ pub const N_HEADS: usize = 32;
|
||||
pub const MLP_DIM: usize = 8192;
|
||||
|
||||
pub const HEAD_DIM: usize = HIDDEN_DIM / N_HEADS;
|
||||
pub const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2;
|
||||
pub const ATTN_PROJ_DIM: usize = HEAD_DIM * N_HEADS;
|
||||
|
||||
pub type KVCache<Batch, Seq> = (
|
||||
GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
);
|
||||
pub type KVCache = (GraphTensor, GraphTensor);
|
||||
|
||||
pub struct Mlp<const I: usize, const H: usize> {
|
||||
pub gate_proj: PermutedLinear<H, I>,
|
||||
pub down_proj: PermutedLinear<I, H>,
|
||||
pub up_proj: PermutedLinear<H, I>,
|
||||
pub struct Mlp {
|
||||
pub gate_proj: Linear, // hidden -> intermediate
|
||||
pub down_proj: Linear, // intermediate -> hidden
|
||||
pub up_proj: Linear, // hidden -> intermediate
|
||||
}
|
||||
|
||||
impl<Sh: Shape, Im: Shape, const I: usize, const H: usize> Module<GraphTensor<Sh>> for Mlp<I, H>
|
||||
where
|
||||
GraphTensor<Sh>: Matmul<R2<H, I>, Output = GraphTensor<Im>>,
|
||||
GraphTensor<Im>: Matmul<R2<I, H>, Output = GraphTensor<Sh>>,
|
||||
{
|
||||
type Output = GraphTensor<Sh>;
|
||||
impl Module<GraphTensor> for Mlp {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn forward(&self, input: GraphTensor<Sh>) -> Self::Output {
|
||||
fn forward(&self, input: GraphTensor) -> Self::Output {
|
||||
let gate = self.gate_proj.forward(input).swish();
|
||||
let up = self.up_proj.forward(input) * gate;
|
||||
self.down_proj.forward(up)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const I: usize, const H: usize> InitModule for Mlp<I, H> {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl Mlp {
|
||||
pub fn new(hidden: usize, intermediate: usize, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
gate_proj: PermutedLinear {
|
||||
weight: cx.named_tensor("Gate"),
|
||||
},
|
||||
up_proj: PermutedLinear {
|
||||
weight: cx.named_tensor("Up"),
|
||||
},
|
||||
down_proj: PermutedLinear {
|
||||
weight: cx.named_tensor("Down"),
|
||||
},
|
||||
gate_proj: Linear::new_permuted(hidden, intermediate, false, cx),
|
||||
down_proj: Linear::new_permuted(intermediate, hidden, false, cx),
|
||||
up_proj: Linear::new_permuted(hidden, intermediate, false, cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const I: usize, const H: usize> SerializeModule for Mlp<I, H> {
|
||||
impl SerializeModule for Mlp {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("ffn_gate", &self.gate_proj);
|
||||
s.module("ffn_up", &self.up_proj);
|
||||
@@ -63,115 +47,98 @@ impl<const I: usize, const H: usize> SerializeModule for Mlp<I, H> {
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_rotary_embeddings_ggml<const N_HEADS: usize, Batch: Dimension, Seq: Dimension>(
|
||||
input: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
prev_seq: BigExpression,
|
||||
) -> GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM>)> {
|
||||
fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: BigExpression) -> GraphTensor {
|
||||
assert_eq!(input.shape.len(), 4); // batch, n_heads, seq, head_dim
|
||||
let (batch, n_heads, seq, head_dim) = input.dims4();
|
||||
// Get freqs
|
||||
let freqs = (input.graph().arange::<Const<HEAD_DIM_OVER_2>>() * 2.0) / (HEAD_DIM as f32);
|
||||
let freqs = (input.graph().arange(head_dim / 2) * 2.0) / (head_dim.to_usize().unwrap() as f32);
|
||||
let freqs = 10_000_f32.pow(freqs);
|
||||
let pos = input.graph().arange::<Seq>() + prev_seq;
|
||||
let emb = pos.expand::<(_, Const<1>), _>().matmul(freqs.expand());
|
||||
let pos = input.graph().arange(seq) + prev_seq;
|
||||
let emb = pos.expand(1, 1).matmul(freqs.expand(0, 1));
|
||||
|
||||
// Split input into evens and odds
|
||||
let split = input.reshape::<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<2>)>();
|
||||
let x0: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<1>)> =
|
||||
split.slice((.., .., .., .., ..1)).realize();
|
||||
let x1: GraphTensor<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<1>)> =
|
||||
split.slice((.., .., .., .., 1..)).realize();
|
||||
let split = input.reshape((batch, n_heads, seq, head_dim / 2, 2));
|
||||
let x0 = split.slice((.., .., .., .., ..1));
|
||||
let x1 = split.slice((.., .., .., .., 1..));
|
||||
|
||||
// Apply sin and cos embeddings
|
||||
let x0_out = x0 * emb.cos().expand() - x1 * emb.sin().expand();
|
||||
let x1_out = x0 * emb.sin().expand() + x1 * emb.cos().expand();
|
||||
let x0_out = x0 * emb.cos().expand_to(x0.shape) - x1 * emb.sin().expand_to(x1.shape);
|
||||
let x1_out = x0 * emb.sin().expand_to(x0.shape) + x1 * emb.cos().expand_to(x1.shape);
|
||||
|
||||
// Combine back into output
|
||||
x0_out
|
||||
.concat_along::<(Batch, Const<N_HEADS>, Seq, Const<HEAD_DIM_OVER_2>, Const<2>), Axis<4>, _>(
|
||||
x1_out,
|
||||
)
|
||||
.reshape()
|
||||
x0_out.concat_along(x1_out, 4).reshape(input.shape)
|
||||
}
|
||||
|
||||
pub struct SelfAttention {
|
||||
pub q_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
|
||||
pub k_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
|
||||
pub v_proj: GraphTensor<R2<ATTN_PROJ_DIM, HIDDEN_DIM>>,
|
||||
pub o_proj: GraphTensor<R2<HIDDEN_DIM, HIDDEN_DIM>>,
|
||||
pub q_proj: GraphTensor, // Hidden -> hidden
|
||||
pub k_proj: GraphTensor, // Proj dim -> hidden
|
||||
pub v_proj: GraphTensor, // Proj dim -> hidden
|
||||
pub o_proj: GraphTensor, // Hidden -> hidden
|
||||
}
|
||||
|
||||
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, PrevSeq>,
|
||||
PhantomData<TotSeq>,
|
||||
)> for SelfAttention
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, TotSeq>,
|
||||
);
|
||||
fn forward(
|
||||
&self,
|
||||
(x, (k_cache, v_cache), _): (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, PrevSeq>,
|
||||
PhantomData<TotSeq>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
impl Module<(GraphTensor, KVCache)> for SelfAttention {
|
||||
type Output = (GraphTensor, KVCache);
|
||||
fn forward(&self, (x, (k_cache, v_cache)): (GraphTensor, KVCache)) -> Self::Output {
|
||||
// x: batch, seq, hidden
|
||||
// cache: batch, kv_heads, prev_seq, head_dim
|
||||
let (batch, seq, _) = x.dims3();
|
||||
let (_, _, prev_seq, _) = k_cache.dims4();
|
||||
// Apply the Projections
|
||||
let queries = x
|
||||
.matmul(self.q_proj.permute())
|
||||
.reshape::<(Batch, CurSeq, Const<N_HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
.matmul(self.q_proj.permute((1, 0)))
|
||||
.reshape((batch, seq, N_HEADS, HEAD_DIM))
|
||||
.permute((0, 2, 1, 3));
|
||||
|
||||
let keys = x
|
||||
.matmul(self.k_proj.permute())
|
||||
.reshape::<(Batch, CurSeq, Const<N_HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
.matmul(self.k_proj.permute((1, 0)))
|
||||
.reshape((batch, seq, N_HEADS, HEAD_DIM))
|
||||
.permute((0, 2, 1, 3));
|
||||
|
||||
let values = x
|
||||
.matmul(self.v_proj.permute())
|
||||
.reshape::<(Batch, CurSeq, Const<N_HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
.matmul(self.v_proj.permute((1, 0)))
|
||||
.reshape((batch, seq, N_HEADS, HEAD_DIM))
|
||||
.permute((0, 2, 1, 3));
|
||||
|
||||
// Rotary embed queries and keys
|
||||
let queries = apply_rotary_embeddings_ggml(queries, PrevSeq::size().into());
|
||||
let keys = apply_rotary_embeddings_ggml(keys, PrevSeq::size().into());
|
||||
let queries = apply_rotary_embeddings_ggml(queries, prev_seq.big());
|
||||
let keys = apply_rotary_embeddings_ggml(keys, prev_seq.big());
|
||||
|
||||
// Add KV cache
|
||||
let keys = k_cache.concat_along::<_, Axis<2>, _>(keys);
|
||||
let values = v_cache.concat_along::<_, Axis<2>, _>(values);
|
||||
let keys = k_cache.concat_along(keys, 2);
|
||||
let values = v_cache.concat_along(values, 2);
|
||||
|
||||
// Calculate attention weights
|
||||
let mut attention_weights = queries.matmul(keys.permute()) / (HEAD_DIM as f32).sqrt();
|
||||
let mut attention_weights =
|
||||
queries.matmul(keys.permute((0, 1, 3, 2))) / (HEAD_DIM as f32).sqrt();
|
||||
|
||||
let attention_mask = self.k_proj.graph().triu::<CurSeq>(1) * f16::MIN.to_f32();
|
||||
let attention_mask = self.k_proj.graph().triu(seq, 1) * f16::MIN.to_f32();
|
||||
attention_weights += attention_mask
|
||||
.pad::<(CurSeq, TotSeq)>(((0, 0), (TotSeq::size() - CurSeq::size(), 0)))
|
||||
.expand();
|
||||
.pad(((0, 0), (prev_seq, 0)))
|
||||
.expand(0, batch)
|
||||
.expand(1, N_HEADS);
|
||||
|
||||
// Calculate final outputs
|
||||
let output = attention_weights
|
||||
.softmax::<Axis<3>>()
|
||||
.softmax(3)
|
||||
// Apply distribution to values
|
||||
.matmul(values)
|
||||
// Merge heads
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>()
|
||||
.reshape::<(Batch, CurSeq, Const<HIDDEN_DIM>)>();
|
||||
.permute((0, 2, 1, 3))
|
||||
.reshape((batch, seq, HIDDEN_DIM));
|
||||
let output = output
|
||||
// Apply output projection
|
||||
.matmul(self.o_proj.permute());
|
||||
.matmul(self.o_proj.permute((1, 0)));
|
||||
(output, (keys.contiguous(), values.contiguous())) // Cache needs to be contiguous for transferring to another graph
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for SelfAttention {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl SelfAttention {
|
||||
pub fn new(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
q_proj: cx.named_tensor("Q Proj"),
|
||||
k_proj: cx.named_tensor("K Proj"),
|
||||
v_proj: cx.named_tensor("V Proj"),
|
||||
o_proj: cx.named_tensor("O Proj"),
|
||||
q_proj: cx.named_tensor("Q Proj", (HIDDEN_DIM, HIDDEN_DIM)),
|
||||
k_proj: cx.named_tensor("K Proj", (ATTN_PROJ_DIM, HIDDEN_DIM)),
|
||||
v_proj: cx.named_tensor("V Proj", (ATTN_PROJ_DIM, HIDDEN_DIM)),
|
||||
o_proj: cx.named_tensor("O Proj", (HIDDEN_DIM, HIDDEN_DIM)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -187,54 +154,37 @@ impl SerializeModule for SelfAttention {
|
||||
|
||||
pub struct TransformerBlock {
|
||||
pub attention: SelfAttention,
|
||||
pub attention_norm: LayerNorm<HIDDEN_DIM>,
|
||||
pub feed_forward: Mlp<MLP_DIM, HIDDEN_DIM>,
|
||||
pub feed_forward_norm: LayerNorm<HIDDEN_DIM>,
|
||||
pub attention_norm: LayerNorm,
|
||||
pub feed_forward: Mlp,
|
||||
pub feed_forward_norm: LayerNorm,
|
||||
}
|
||||
|
||||
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, PrevSeq>,
|
||||
PhantomData<TotSeq>,
|
||||
)> for TransformerBlock
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, TotSeq>,
|
||||
);
|
||||
fn forward(
|
||||
&self,
|
||||
(mut x, cache, _): (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN_DIM>)>,
|
||||
KVCache<Batch, PrevSeq>,
|
||||
PhantomData<TotSeq>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
impl Module<(GraphTensor, KVCache)> for TransformerBlock {
|
||||
type Output = (GraphTensor, KVCache);
|
||||
fn forward(&self, (mut x, cache): (GraphTensor, KVCache)) -> Self::Output {
|
||||
// Attention
|
||||
let normed = self.attention_norm.forward(x);
|
||||
let (y, cache) = self
|
||||
.attention
|
||||
.forward((normed, cache, PhantomData::<TotSeq>));
|
||||
.forward((self.attention_norm.forward(x), cache));
|
||||
|
||||
// Residual Addition
|
||||
// Residual
|
||||
x += y;
|
||||
|
||||
// Feed Forward
|
||||
let y = self.feed_forward.forward(self.feed_forward_norm.forward(x));
|
||||
|
||||
// Residual Addition
|
||||
// Residual
|
||||
(x + y, cache)
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for TransformerBlock {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl TransformerBlock {
|
||||
pub fn new(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
attention: InitModule::initialize(cx),
|
||||
attention_norm: LayerNorm::new(true, false, false, 1e-5, cx),
|
||||
feed_forward: InitModule::initialize(cx),
|
||||
feed_forward_norm: LayerNorm::new(true, false, false, 1e-5, cx),
|
||||
attention: SelfAttention::new(cx),
|
||||
attention_norm: LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
|
||||
feed_forward: Mlp::new(HIDDEN_DIM, MLP_DIM, cx),
|
||||
feed_forward_norm: LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -250,34 +200,16 @@ impl SerializeModule for TransformerBlock {
|
||||
|
||||
pub struct Phi {
|
||||
// Token embeddings
|
||||
pub embedding: Embedding<VOCAB_SIZE, HIDDEN_DIM>,
|
||||
pub embedding: Embedding,
|
||||
// Transformer layers
|
||||
pub layers: Vec<TransformerBlock>,
|
||||
// Final Norm layer
|
||||
pub norm: LayerNorm<HIDDEN_DIM>,
|
||||
// LM Head Layer
|
||||
pub lm_head: PermutedLinear<HIDDEN_DIM, VOCAB_SIZE>,
|
||||
// Norm + LM head
|
||||
pub head: (LayerNorm, Linear),
|
||||
}
|
||||
|
||||
impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq)>,
|
||||
&[KVCache<Batch, PrevSeq>],
|
||||
PhantomData<TotSeq>,
|
||||
)> for Phi
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<VOCAB_SIZE>)>,
|
||||
Vec<KVCache<Batch, TotSeq>>,
|
||||
);
|
||||
fn forward(
|
||||
&self,
|
||||
(input, cache, _): (
|
||||
GraphTensor<(Batch, CurSeq)>,
|
||||
&[KVCache<Batch, PrevSeq>],
|
||||
PhantomData<TotSeq>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
impl Module<(GraphTensor, &[KVCache])> for Phi {
|
||||
type Output = (GraphTensor, Vec<KVCache>);
|
||||
fn forward(&self, (input, cache): (GraphTensor, &[KVCache])) -> Self::Output {
|
||||
// Embed tokens
|
||||
let mut x = self.embedding.forward(input);
|
||||
|
||||
@@ -285,29 +217,23 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
|
||||
let mut new_caches = vec![];
|
||||
let mut new_cache;
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
(x, new_cache) = layer.forward((x, cache[i], PhantomData::<TotSeq>));
|
||||
(x, new_cache) = layer.forward((x, cache[i]));
|
||||
new_caches.push(new_cache);
|
||||
}
|
||||
// Run through last norm and output projection
|
||||
let output = self.lm_head.forward(self.norm.forward(x));
|
||||
|
||||
(output, new_caches)
|
||||
(self.head.forward(x), new_caches)
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for Phi {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl Phi {
|
||||
pub fn new(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
embedding: Embedding {
|
||||
weight: cx.named_tensor("Embedding Weight"),
|
||||
},
|
||||
norm: LayerNorm::new(true, false, false, 1e-5, cx),
|
||||
lm_head: PermutedLinear {
|
||||
weight: cx.tensor(),
|
||||
},
|
||||
layers: (0..NUM_LAYERS)
|
||||
.map(|_| InitModule::initialize(cx))
|
||||
.collect(),
|
||||
embedding: Embedding::new(VOCAB_SIZE, HIDDEN_DIM, cx),
|
||||
head: (
|
||||
LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx),
|
||||
Linear::new_permuted(HIDDEN_DIM, VOCAB_SIZE, false, cx),
|
||||
),
|
||||
layers: (0..NUM_LAYERS).map(|_| TransformerBlock::new(cx)).collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -315,8 +241,8 @@ impl InitModule for Phi {
|
||||
impl SerializeModule for Phi {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.module("token_embd", &self.embedding);
|
||||
s.module("output_norm", &self.norm);
|
||||
s.module("output", &self.lm_head);
|
||||
s.module("output_norm", &self.head.0);
|
||||
s.module("output", &self.head.1);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
s.module(&format!("blk/{i}"), layer);
|
||||
}
|
||||
|
||||
@@ -5,9 +5,9 @@ fn main() {
|
||||
// Create a new graph
|
||||
let mut cx = Graph::new();
|
||||
// Randomly initialize a linear layer with an input size of 4 and an output size of 5
|
||||
let model = Linear::<4, 5>::initialize(&mut cx);
|
||||
let model = Linear::new(4, 5, false, &mut cx).initialize();
|
||||
// Make an input tensor
|
||||
let a = cx.tensor::<R1<4>>().set(vec![1., 2., 3., 4.]);
|
||||
let a = cx.tensor(4).set(vec![1., 2., 3., 4.]);
|
||||
// Feed tensor through model
|
||||
let b = model.forward(a).retrieve();
|
||||
|
||||
|
||||
@@ -10,9 +10,15 @@ use rand::{rngs::ThreadRng, thread_rng, Rng};
|
||||
fn main() {
|
||||
// Setup gradient graph
|
||||
let mut cx = Graph::new();
|
||||
let model = <(Linear<8, 16>, Swish, Linear<16, 16>, Swish, Linear<16, 5>)>::initialize(&mut cx);
|
||||
let mut input = cx.tensor::<R1<8>>();
|
||||
let mut target = cx.tensor::<R1<5>>();
|
||||
let model = (
|
||||
Linear::new(8, 16, false, &mut cx).initialize(),
|
||||
Swish,
|
||||
Linear::new(16, 16, false, &mut cx).initialize(),
|
||||
Swish,
|
||||
Linear::new(16, 5, false, &mut cx).initialize(),
|
||||
);
|
||||
let mut input = cx.tensor(8);
|
||||
let mut target = cx.tensor(5);
|
||||
let mut output = model.forward(input).retrieve();
|
||||
let mut loss = mse_loss(output, target).retrieve();
|
||||
|
||||
@@ -38,7 +44,10 @@ fn main() {
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
cx.compile(
|
||||
luminal_metal::MetalCompiler::<f32>::default(),
|
||||
(
|
||||
GenericCompiler::default(),
|
||||
luminal_metal::MetalCompiler::<f32>::default(),
|
||||
),
|
||||
(
|
||||
&mut input,
|
||||
&mut target,
|
||||
|
||||
@@ -4,7 +4,7 @@ use std::{io::Write, marker::PhantomData};
|
||||
use itertools::Itertools;
|
||||
// WIP
|
||||
use luminal::prelude::*;
|
||||
use model::KVCache;
|
||||
use model::{KVCache, D_MODEL, HEADS, HEAD_DIM, N_MEL_BINS};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
mod audio;
|
||||
@@ -20,33 +20,30 @@ fn main() {
|
||||
|
||||
// Construct encoder graph
|
||||
let mut enc_cx = Graph::new();
|
||||
let encoder = model::AudioEncoder::initialize(&mut enc_cx);
|
||||
let encoder = model::AudioEncoder::new(&mut enc_cx);
|
||||
let mut encoder_params = params(&encoder);
|
||||
enc_cx.keep_tensors(&encoder_params);
|
||||
let mut audio_input = enc_cx.tensor::<(Const<1>, Const<{ model::N_MEL_BINS }>, Dyn<'s'>)>();
|
||||
let mut encoded: GraphTensor<(Const<1>, Dyn<'d'>, Const<384>)> = encoder
|
||||
.forward((audio_input, PhantomData::<Dyn<'d'>>))
|
||||
.keep();
|
||||
let mut audio_input = enc_cx.tensor((1, N_MEL_BINS, 's'));
|
||||
let mut encoded = encoder.forward(audio_input).keep();
|
||||
loader::load("setup/whisper-tiny.safetensors", &encoder, &mut enc_cx);
|
||||
|
||||
// Construct decoder graph
|
||||
let mut dec_cx = Graph::new();
|
||||
let decoder = model::TextDecoder::initialize(&mut dec_cx);
|
||||
let decoder = model::TextDecoder::new(&mut dec_cx);
|
||||
let mut decoder_params = params(&decoder);
|
||||
dec_cx.keep_tensors(&decoder_params);
|
||||
let mut text_input = dec_cx.tensor::<(Const<1>, Dyn<'s'>)>();
|
||||
let mut encoder_output =
|
||||
dec_cx.named_tensor::<(Const<1>, Dyn<'e'>, Const<{ model::D_MODEL }>)>("Enc Output");
|
||||
let mut cache_src: Vec<KVCache<Const<1>, Dyn<'p'>>> = (0..model::DEC_LAYERS)
|
||||
.map(|_| (dec_cx.named_tensor("Keys"), dec_cx.named_tensor("Values")))
|
||||
.collect();
|
||||
cache_src.set_dyn(vec![], &[1, 6, 64, 0]);
|
||||
let (logits, _, mut cache_dest) = decoder.forward((
|
||||
encoder_output,
|
||||
text_input,
|
||||
&cache_src,
|
||||
PhantomData::<Dyn<'t'>>,
|
||||
));
|
||||
let mut text_input = dec_cx.tensor((1, 's'));
|
||||
let mut encoder_output = dec_cx.named_tensor("Enc Output", (1, 'e', D_MODEL));
|
||||
let mut cache_src = (0..model::DEC_LAYERS)
|
||||
.map(|_| {
|
||||
(
|
||||
dec_cx.named_tensor("Keys", (1, HEADS, HEAD_DIM, 'p')),
|
||||
dec_cx.named_tensor("Values", (1, HEADS, 'p', HEAD_DIM)),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
cache_src.set_dyn(vec![], (1, 6, 64, 0));
|
||||
let (logits, _, mut cache_dest) = decoder.forward((encoder_output, text_input, &cache_src));
|
||||
let mut logits = logits
|
||||
.slice((.., Expression::from('s') - 1.., ..))
|
||||
.retrieve();
|
||||
@@ -103,11 +100,11 @@ fn main() {
|
||||
print!("Loading weights");
|
||||
std::io::stdout().flush().unwrap();
|
||||
let now = std::time::Instant::now();
|
||||
audio_input.set_dyn(vec![0.; 160], &[1, 80, 2]);
|
||||
audio_input.set_dyn(vec![0.; 160], (1, 80, 2));
|
||||
enc_cx.set_dyn_dim('d', 1);
|
||||
enc_cx.execute();
|
||||
delete_inputs(downstream(encoder_params, &enc_cx), &mut enc_cx);
|
||||
text_input.set_dyn(vec![0.], &[1, 1]);
|
||||
text_input.set_dyn(vec![0.], (1, 1));
|
||||
dec_cx.set_dyn_dim('e', 1);
|
||||
dec_cx.set_dyn_dim('p', 0);
|
||||
dec_cx.set_dyn_dim('t', 1);
|
||||
@@ -132,7 +129,7 @@ fn main() {
|
||||
std::io::stdout().flush().unwrap();
|
||||
let start_encoding = std::time::Instant::now();
|
||||
|
||||
audio_input.set_dyn(mel, &[1, 80, mel_len / 80]);
|
||||
audio_input.set_dyn(mel, (1, 80, mel_len / 80));
|
||||
enc_cx.set_dyn_dim('d', (mel_len / 80) / 2);
|
||||
enc_cx.execute();
|
||||
transfer_data(encoded, &mut enc_cx, encoder_output, &mut dec_cx);
|
||||
@@ -144,7 +141,7 @@ fn main() {
|
||||
dec_cx.set_dyn_dim('p', 0);
|
||||
dec_cx.set_dyn_dim('t', 3);
|
||||
let mut output_ids = vec![];
|
||||
text_input.set_dyn(vec![50257., 50358., 50362.], &[1, 3]);
|
||||
text_input.set_dyn(vec![50257., 50358., 50362.], (1, 3));
|
||||
dec_cx.execute();
|
||||
let mut output_token = argmax(&logits.data());
|
||||
logits.drop();
|
||||
@@ -156,7 +153,7 @@ fn main() {
|
||||
|
||||
for i in 0..100 {
|
||||
transfer_data_same_graph(&cache_dest, &cache_src, &mut dec_cx);
|
||||
text_input.set_dyn(vec![output_token as f32], &[1, 1]);
|
||||
text_input.set_dyn(vec![output_token as f32], (1, 1));
|
||||
dec_cx.set_dyn_dim('p', i + 3);
|
||||
dec_cx.set_dyn_dim('t', i + 4);
|
||||
dec_cx.execute();
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use luminal::prelude::{binary::F32Pow, *};
|
||||
use luminal_nn::{Conv1D, Embedding, LayerNorm, Linear, PermutedEmbedding, PermutedLinear};
|
||||
use luminal_nn::{Conv1D, Embedding, LayerNorm, Linear};
|
||||
use std::marker::PhantomData;
|
||||
use std::ops::{Add, Mul};
|
||||
|
||||
@@ -41,172 +41,150 @@ pub const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
|
||||
pub const EOT_TOKEN: &str = "<|endoftext|>";
|
||||
pub const NO_SPEECH_TOKENS: [&str; 2] = ["<|nocaptions|>", "<|nospeech|>"];
|
||||
|
||||
pub type KVCache<Batch, Seq> = (
|
||||
GraphTensor<(Batch, Const<HEADS>, Const<HEAD_DIM>, Seq)>,
|
||||
GraphTensor<(Batch, Const<HEADS>, Seq, Const<HEAD_DIM>)>,
|
||||
);
|
||||
pub type KVCache = (GraphTensor, GraphTensor);
|
||||
|
||||
pub struct SelfAttention<const HIDDEN: usize> {
|
||||
pub q_proj: GraphTensor<R2<HIDDEN, HIDDEN>>,
|
||||
pub q_proj_bias: GraphTensor<R1<HIDDEN>>,
|
||||
pub k_proj: GraphTensor<R2<HIDDEN, HIDDEN>>,
|
||||
pub v_proj: GraphTensor<R2<HIDDEN, HIDDEN>>,
|
||||
pub v_proj_bias: GraphTensor<R1<HIDDEN>>,
|
||||
pub o_proj: GraphTensor<R2<HIDDEN, HIDDEN>>,
|
||||
pub o_proj_bias: GraphTensor<R1<HIDDEN>>,
|
||||
pub struct SelfAttention {
|
||||
pub q_proj: GraphTensor, // hidden x hidden
|
||||
pub q_proj_bias: GraphTensor, // hidden
|
||||
pub k_proj: GraphTensor, // hidden x hidden
|
||||
pub v_proj: GraphTensor, // hidden x hidden
|
||||
pub v_proj_bias: GraphTensor, // hidden
|
||||
pub o_proj: GraphTensor, // hidden x hidden
|
||||
pub o_proj_bias: GraphTensor, // hidden
|
||||
}
|
||||
|
||||
impl<
|
||||
const HIDDEN: usize,
|
||||
Batch: Dimension,
|
||||
CurSeq: Dimension,
|
||||
PrevSeq: Dimension,
|
||||
TotSeq: Dimension,
|
||||
>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
|
||||
Option<KVCache<Batch, PrevSeq>>,
|
||||
bool,
|
||||
PhantomData<TotSeq>,
|
||||
)> for SelfAttention<HIDDEN>
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
|
||||
KVCache<Batch, TotSeq>,
|
||||
);
|
||||
fn forward(
|
||||
&self,
|
||||
(x, cache, mask, _): (
|
||||
GraphTensor<(Batch, CurSeq, Const<HIDDEN>)>,
|
||||
Option<KVCache<Batch, PrevSeq>>,
|
||||
bool,
|
||||
PhantomData<TotSeq>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
let scale = ((HIDDEN as f32 / HEADS as f32) as f64).powf(-0.25) as f32;
|
||||
impl Module<(GraphTensor, Option<KVCache>, bool)> for SelfAttention {
|
||||
type Output = (GraphTensor, KVCache);
|
||||
fn forward(&self, (x, cache, mask): (GraphTensor, Option<KVCache>, bool)) -> Self::Output {
|
||||
// x: batch, seq, hidden
|
||||
let (batch, seq, hidden) = x.dims3();
|
||||
let scale = ((hidden.to_usize().unwrap() as f32 / HEADS as f32) as f64).powf(-0.25) as f32;
|
||||
// Apply the Projections
|
||||
let queries = x
|
||||
.matmul(self.q_proj.permute())
|
||||
.add(self.q_proj_bias.expand())
|
||||
.matmul(self.q_proj.permute((1, 0)))
|
||||
.add(self.q_proj_bias.expand(0, batch).expand(1, seq))
|
||||
.mul(scale)
|
||||
.reshape::<(Batch, CurSeq, Const<HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
.reshape((batch, seq, HEADS, HEAD_DIM))
|
||||
.permute((0, 2, 1, 3));
|
||||
|
||||
let keys = x
|
||||
.matmul(self.k_proj.permute())
|
||||
.matmul(self.k_proj.permute((1, 0)))
|
||||
.mul(scale)
|
||||
.reshape::<(Batch, CurSeq, Const<HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 3, 1>>()
|
||||
.reshape((batch, seq, HEADS, HEAD_DIM))
|
||||
.permute((0, 2, 3, 1))
|
||||
.contiguous();
|
||||
|
||||
let values = x
|
||||
.matmul(self.v_proj.permute())
|
||||
.add(self.v_proj_bias.expand())
|
||||
.reshape::<(Batch, CurSeq, Const<HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
.matmul(self.v_proj.permute((1, 0)))
|
||||
.add(self.v_proj_bias.expand(0, batch).expand(1, seq))
|
||||
.reshape((batch, seq, HEADS, HEAD_DIM))
|
||||
.permute((0, 2, 1, 3));
|
||||
|
||||
// Add KV cache
|
||||
let (keys, values) = if let Some((k_cache, v_cache)) = cache {
|
||||
(
|
||||
k_cache.concat_along::<_, Axis<3>, _>(keys),
|
||||
v_cache.concat_along::<_, Axis<2>, _>(values),
|
||||
k_cache.concat_along(keys, 3),
|
||||
v_cache.concat_along(values, 2),
|
||||
)
|
||||
} else {
|
||||
(keys.realize(), values.realize())
|
||||
(keys, values)
|
||||
};
|
||||
|
||||
// Calculate attention weights
|
||||
let mut attention_weights = queries.matmul(keys);
|
||||
|
||||
if mask {
|
||||
let attention_mask = self.k_proj.graph().triu::<CurSeq>(1) * f16::MIN.to_f32();
|
||||
attention_weights += attention_mask
|
||||
.pad::<(CurSeq, TotSeq)>(((0, 0), (TotSeq::size() - CurSeq::size(), 0)))
|
||||
.expand();
|
||||
let mut attention_mask = self.k_proj.graph().triu(seq, 1) * f16::MIN.to_f32();
|
||||
if let Some((c, _)) = cache {
|
||||
let (_, _, prev_seq, _) = c.dims4();
|
||||
attention_mask = attention_mask.pad(((0, 0), (prev_seq, 0)));
|
||||
}
|
||||
attention_weights += attention_mask.expand(0, batch).expand(1, HEADS);
|
||||
}
|
||||
|
||||
// Calculate final outputs
|
||||
let output = attention_weights
|
||||
.softmax::<Axis<3>>()
|
||||
.softmax(3)
|
||||
// Apply distribution to values
|
||||
.matmul(values)
|
||||
// Merge heads
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>()
|
||||
.reshape::<(Batch, CurSeq, Const<HIDDEN>)>();
|
||||
.permute((0, 2, 1, 3))
|
||||
.reshape((batch, seq, hidden));
|
||||
let output = output
|
||||
// Apply output projection
|
||||
.matmul(self.o_proj.permute())
|
||||
.add(self.o_proj_bias.expand());
|
||||
.matmul(self.o_proj.permute((1, 0)))
|
||||
.add(self.o_proj_bias.expand(0, batch).expand(1, seq));
|
||||
(output, (keys.contiguous(), values.contiguous())) // Cache needs to be contiguous
|
||||
}
|
||||
}
|
||||
|
||||
impl<const HIDDEN: usize> SelfAttention<HIDDEN> {
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn cross_attention_forward<Batch: Dimension, EncSeq: Dimension, DecSeq: Dimension>(
|
||||
impl SelfAttention {
|
||||
fn cross_attention_forward(
|
||||
&self,
|
||||
queries: GraphTensor<(Batch, DecSeq, Const<HIDDEN>)>,
|
||||
keys: GraphTensor<(Batch, EncSeq, Const<HIDDEN>)>,
|
||||
values: GraphTensor<(Batch, EncSeq, Const<HIDDEN>)>,
|
||||
queries: GraphTensor, // batch, dec_seq, hidden
|
||||
keys: GraphTensor, // batch, enc_seq, hidden
|
||||
values: GraphTensor, // batch, enc_seq, hidden
|
||||
) -> (
|
||||
GraphTensor<(Batch, DecSeq, Const<HIDDEN>)>,
|
||||
KVCache<Batch, EncSeq>,
|
||||
GraphTensor, // batch, dec_seq, hidden
|
||||
KVCache,
|
||||
) {
|
||||
let scale = ((HIDDEN as f32 / HEADS as f32) as f64).powf(-0.25) as f32;
|
||||
let (batch, enc_seq, hidden) = keys.dims3();
|
||||
let (_, dec_seq, hidden) = queries.dims3();
|
||||
let scale = ((hidden.to_usize().unwrap() as f32 / HEADS as f32) as f64).powf(-0.25) as f32;
|
||||
// Apply the projections
|
||||
let queries = queries
|
||||
.matmul(self.q_proj.permute())
|
||||
.add(self.q_proj_bias.expand())
|
||||
.matmul(self.q_proj.permute((1, 0)))
|
||||
.add(self.q_proj_bias.expand(0, batch).expand(1, dec_seq))
|
||||
.mul(scale)
|
||||
.reshape::<(Batch, DecSeq, Const<HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
.reshape((batch, dec_seq, HEADS, HEAD_DIM))
|
||||
.permute((0, 2, 1, 3));
|
||||
let keys = keys
|
||||
.matmul(self.k_proj.permute())
|
||||
.matmul(self.k_proj.permute((1, 0)))
|
||||
.mul(scale)
|
||||
.reshape::<(Batch, EncSeq, Const<HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 3, 1>>()
|
||||
.reshape((batch, enc_seq, HEADS, HEAD_DIM))
|
||||
.permute((0, 2, 3, 1))
|
||||
.contiguous();
|
||||
let values = values
|
||||
.matmul(self.v_proj.permute())
|
||||
.add(self.v_proj_bias.expand())
|
||||
.reshape::<(Batch, EncSeq, Const<HEADS>, Const<HEAD_DIM>)>()
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>();
|
||||
.matmul(self.v_proj.permute((1, 0)))
|
||||
.add(self.v_proj_bias.expand(0, batch).expand(1, enc_seq))
|
||||
.reshape((batch, enc_seq, HEADS, HEAD_DIM))
|
||||
.permute((0, 2, 1, 3));
|
||||
|
||||
// Calculate attention weights
|
||||
let mut attention_weights = queries.matmul(keys);
|
||||
|
||||
// Calculate final outputs
|
||||
let output = attention_weights
|
||||
.softmax::<Axis<3>>()
|
||||
.softmax(3)
|
||||
// Apply distribution to values
|
||||
.matmul(values)
|
||||
// Merge heads
|
||||
.permute::<_, Axes4<0, 2, 1, 3>>()
|
||||
.reshape::<(Batch, DecSeq, Const<HIDDEN>)>();
|
||||
.permute((0, 2, 1, 3))
|
||||
.reshape((batch, dec_seq, hidden));
|
||||
let output = output
|
||||
// Apply output projection
|
||||
.matmul(self.o_proj.permute())
|
||||
.add(self.o_proj_bias.expand());
|
||||
.matmul(self.o_proj.permute((1, 0)))
|
||||
.add(self.o_proj_bias.expand(0, batch).expand(1, dec_seq));
|
||||
|
||||
(output, (keys, values))
|
||||
}
|
||||
}
|
||||
|
||||
impl<const HIDDEN: usize> InitModule for SelfAttention<HIDDEN> {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl SelfAttention {
|
||||
fn new(hidden: usize, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
q_proj: cx.named_tensor("Q Proj"),
|
||||
q_proj_bias: cx.named_tensor("Q Proj Bias"),
|
||||
k_proj: cx.named_tensor("K Proj"),
|
||||
v_proj: cx.named_tensor("V Proj"),
|
||||
v_proj_bias: cx.named_tensor("V Proj Bias"),
|
||||
o_proj: cx.named_tensor("O Proj"),
|
||||
o_proj_bias: cx.named_tensor("O Proj Bias"),
|
||||
q_proj: cx.named_tensor("Q Proj", (hidden, hidden)),
|
||||
q_proj_bias: cx.named_tensor("Q Proj Bias", hidden),
|
||||
k_proj: cx.named_tensor("K Proj", (hidden, hidden)),
|
||||
v_proj: cx.named_tensor("V Proj", (hidden, hidden)),
|
||||
v_proj_bias: cx.named_tensor("V Proj Bias", hidden),
|
||||
o_proj: cx.named_tensor("O Proj", (hidden, hidden)),
|
||||
o_proj_bias: cx.named_tensor("O Proj Bias", hidden),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const HIDDEN: usize> SerializeModule for SelfAttention<HIDDEN> {
|
||||
impl SerializeModule for SelfAttention {
|
||||
fn serialize(&self, s: &mut Serializer) {
|
||||
s.tensor("q_proj/weight", self.q_proj);
|
||||
s.tensor("q_proj/bias", self.q_proj_bias);
|
||||
@@ -219,54 +197,42 @@ impl<const HIDDEN: usize> SerializeModule for SelfAttention<HIDDEN> {
|
||||
}
|
||||
|
||||
pub struct EncoderTransformerBlock {
|
||||
pub attention: SelfAttention<D_MODEL>,
|
||||
pub attention_norm: LayerNorm<D_MODEL>,
|
||||
pub ff1: PermutedLinear<D_MODEL, ENC_FFN_DIM>,
|
||||
pub ff1_bias: GraphTensor<R1<ENC_FFN_DIM>>,
|
||||
pub ff2: PermutedLinear<ENC_FFN_DIM, D_MODEL>,
|
||||
pub ff2_bias: GraphTensor<R1<D_MODEL>>,
|
||||
pub feed_forward_norm: LayerNorm<D_MODEL>,
|
||||
pub attention: SelfAttention,
|
||||
pub attention_norm: LayerNorm,
|
||||
pub ff1: Linear, // hidden -> enc_ffn_dim
|
||||
pub ff2: Linear, // enc_ffn_dim -> hidden
|
||||
pub feed_forward_norm: LayerNorm,
|
||||
}
|
||||
|
||||
impl<Batch: Dimension, Seq: Dimension> Module<GraphTensor<(Batch, Seq, Const<D_MODEL>)>>
|
||||
for EncoderTransformerBlock
|
||||
{
|
||||
type Output = GraphTensor<(Batch, Seq, Const<D_MODEL>)>;
|
||||
fn forward(&self, mut x: GraphTensor<(Batch, Seq, Const<D_MODEL>)>) -> Self::Output {
|
||||
impl Module<GraphTensor> for EncoderTransformerBlock {
|
||||
type Output = GraphTensor;
|
||||
fn forward(&self, mut x: GraphTensor) -> Self::Output {
|
||||
let (batch, seq, _) = x.dims3();
|
||||
// Attention
|
||||
let (y, _) = self.attention.forward((
|
||||
self.attention_norm.forward(x),
|
||||
Option::<KVCache<Batch, Seq>>::None,
|
||||
false,
|
||||
PhantomData::<Seq>,
|
||||
));
|
||||
let (y, _) = self
|
||||
.attention
|
||||
.forward((self.attention_norm.forward(x), None, false));
|
||||
|
||||
// Residual Addition
|
||||
x += y;
|
||||
|
||||
// Feed Forward
|
||||
let y = self.ff1.forward(self.feed_forward_norm.forward(x)) + self.ff1_bias.expand();
|
||||
let y = self.ff2.forward(y.gelu()) + self.ff2_bias.expand();
|
||||
let y = self.ff1.forward(self.feed_forward_norm.forward(x));
|
||||
let y = self.ff2.forward(y.gelu());
|
||||
|
||||
// Residual Addition
|
||||
x + y
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for EncoderTransformerBlock {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl EncoderTransformerBlock {
|
||||
fn new(hidden: usize, ff: usize, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
attention: InitModule::initialize(cx),
|
||||
attention_norm: LayerNorm::new(true, true, true, 1e-5, cx),
|
||||
ff1: PermutedLinear {
|
||||
weight: cx.tensor(),
|
||||
},
|
||||
ff1_bias: cx.tensor(),
|
||||
ff2: PermutedLinear {
|
||||
weight: cx.tensor(),
|
||||
},
|
||||
ff2_bias: cx.tensor(),
|
||||
feed_forward_norm: LayerNorm::new(true, true, true, 1e-5, cx),
|
||||
attention: SelfAttention::new(hidden, cx),
|
||||
attention_norm: LayerNorm::new(hidden, true, true, true, 1e-5, cx),
|
||||
ff1: Linear::new_permuted(hidden, ff, true, cx),
|
||||
ff2: Linear::new_permuted(ff, hidden, true, cx),
|
||||
feed_forward_norm: LayerNorm::new(hidden, true, true, true, 1e-5, cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -277,64 +243,46 @@ impl SerializeModule for EncoderTransformerBlock {
|
||||
s.module("self_attn_layer_norm", &self.attention_norm);
|
||||
s.module("final_layer_norm", &self.feed_forward_norm);
|
||||
s.module("fc1", &self.ff1);
|
||||
s.tensor("fc1/bias", self.ff1_bias);
|
||||
s.module("fc2", &self.ff2);
|
||||
s.tensor("fc2/bias", self.ff2_bias);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AudioEncoder {
|
||||
// Conv layers (based on https://github.com/huggingface/candle/blob/59b18d974ec3cad6963b774aa245e23f8c80414f/candle-transformers/src/models/whisper/model.rs#L246)
|
||||
pub conv1: Conv1D<N_MEL_BINS, D_MODEL, 3, 1, 0, 1>,
|
||||
pub conv2: Conv1D<D_MODEL, D_MODEL, 3, 2, 0, 1>,
|
||||
pub conv1: Conv1D,
|
||||
pub conv2: Conv1D,
|
||||
// Transformer layers
|
||||
pub layers: Vec<EncoderTransformerBlock>,
|
||||
// Post layer norm
|
||||
pub post_ln: LayerNorm<D_MODEL>,
|
||||
pub post_ln: LayerNorm,
|
||||
}
|
||||
|
||||
fn sinusoids<const CHANNELS: usize, Length: Dimension>(
|
||||
cx: &mut Graph,
|
||||
) -> GraphTensor<(Length, Const<CHANNELS>)> {
|
||||
fn sinusoids(channels: usize, length: Expression, cx: &mut Graph) -> GraphTensor {
|
||||
let max_timescale = 10000f32;
|
||||
let log_timescale_increment = max_timescale.ln() / (CHANNELS / 2 - 1) as f32;
|
||||
let inv_timescales = (0..CHANNELS / 2)
|
||||
let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
|
||||
let inv_timescales = (0..channels / 2)
|
||||
.map(|i| (i as f32 * (-log_timescale_increment)).exp())
|
||||
.collect::<Vec<_>>();
|
||||
let mut inv_timescales = cx
|
||||
.tensor::<(Dyn<'-'>,)>()
|
||||
.set_dyn(inv_timescales, &[CHANNELS / 2]);
|
||||
inv_timescales.shape.dims[0] = (CHANNELS / 2).into();
|
||||
let arange = cx.arange::<Length>();
|
||||
.tensor(channels / 2)
|
||||
.set_dyn(inv_timescales, (channels / 2));
|
||||
let arange = cx.arange(length);
|
||||
let mut mul_shape = arange.shape;
|
||||
mul_shape.add_dim(1, CHANNELS / 2);
|
||||
let scaled_time: GraphTensor<(Length, Dyn<'-'>)> =
|
||||
arange.expand_to(mul_shape) * inv_timescales.expand_to(mul_shape);
|
||||
scaled_time
|
||||
.sin()
|
||||
.concat_along::<_, Axis<1>, _>(scaled_time.cos())
|
||||
mul_shape.add_dim(1, channels / 2);
|
||||
let scaled_time = arange.expand_to(mul_shape) * inv_timescales.expand_to(mul_shape);
|
||||
scaled_time.sin().concat_along(scaled_time.cos(), 1)
|
||||
}
|
||||
|
||||
impl<Batch: Dimension, Seq: Dimension, SeqDivTwo: Dimension>
|
||||
Module<(
|
||||
GraphTensor<(Batch, Const<N_MEL_BINS>, Seq)>,
|
||||
PhantomData<SeqDivTwo>,
|
||||
)> for AudioEncoder
|
||||
{
|
||||
type Output = GraphTensor<(Batch, SeqDivTwo, Const<D_MODEL>)>;
|
||||
fn forward(
|
||||
&self,
|
||||
(x, _): (
|
||||
GraphTensor<(Batch, Const<N_MEL_BINS>, Seq)>,
|
||||
PhantomData<SeqDivTwo>,
|
||||
),
|
||||
) -> Self::Output {
|
||||
impl Module<GraphTensor> for AudioEncoder {
|
||||
type Output = GraphTensor;
|
||||
fn forward(&self, x: GraphTensor) -> Self::Output {
|
||||
let (_, seq, _) = x.dims3();
|
||||
// Conv layers
|
||||
let x = self.conv1.forward((x, PhantomData::<Seq>)).gelu();
|
||||
let x = self.conv2.forward((x, PhantomData::<SeqDivTwo>)).gelu();
|
||||
let mut x = x.permute::<_, Axes3<0, 2, 1>>();
|
||||
let x = self.conv1.forward(x).gelu();
|
||||
let x = self.conv2.forward(x).gelu();
|
||||
let mut x = x.permute((0, 2, 1));
|
||||
// Sinusoidal positional embedding
|
||||
x += sinusoids::<D_MODEL, SeqDivTwo>(x.graph()).expand();
|
||||
x += sinusoids(D_MODEL, seq / 2, x.graph()).expand_to(x.shape);
|
||||
|
||||
// Transformer layers
|
||||
let out = self.layers.forward(x);
|
||||
@@ -343,15 +291,15 @@ impl<Batch: Dimension, Seq: Dimension, SeqDivTwo: Dimension>
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for AudioEncoder {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl AudioEncoder {
|
||||
pub fn new(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
conv1: Conv1D::initialize_bias(cx),
|
||||
conv2: Conv1D::initialize_bias(cx),
|
||||
conv1: Conv1D::new(N_MEL_BINS, D_MODEL, 3, 1, 1, 1, true, cx),
|
||||
conv2: Conv1D::new(D_MODEL, D_MODEL, 3, 2, 1, 1, true, cx),
|
||||
layers: (0..ENC_LAYERS)
|
||||
.map(|_| InitModule::initialize(cx))
|
||||
.map(|_| EncoderTransformerBlock::new(D_MODEL, ENC_FFN_DIM, cx))
|
||||
.collect(),
|
||||
post_ln: LayerNorm::new(true, true, true, 1e-5, cx),
|
||||
post_ln: LayerNorm::new(D_MODEL, true, true, true, 1e-5, cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -368,52 +316,25 @@ impl SerializeModule for AudioEncoder {
|
||||
}
|
||||
|
||||
pub struct DecoderTransformerBlock {
|
||||
pub attention: SelfAttention<D_MODEL>,
|
||||
pub attention_norm: LayerNorm<D_MODEL>,
|
||||
pub cross_attention: SelfAttention<D_MODEL>,
|
||||
pub cross_attention_norm: LayerNorm<D_MODEL>,
|
||||
pub ff1: PermutedLinear<D_MODEL, DEC_FFN_DIM>,
|
||||
pub ff1_bias: GraphTensor<R1<DEC_FFN_DIM>>,
|
||||
pub ff2: PermutedLinear<DEC_FFN_DIM, D_MODEL>,
|
||||
pub ff2_bias: GraphTensor<R1<D_MODEL>>,
|
||||
pub feed_forward_norm: LayerNorm<D_MODEL>,
|
||||
pub attention: SelfAttention,
|
||||
pub attention_norm: LayerNorm,
|
||||
pub cross_attention: SelfAttention,
|
||||
pub cross_attention_norm: LayerNorm,
|
||||
pub ff1: Linear,
|
||||
pub ff2: Linear,
|
||||
pub feed_forward_norm: LayerNorm,
|
||||
}
|
||||
|
||||
impl<
|
||||
Batch: Dimension,
|
||||
EncSeq: Dimension,
|
||||
CurSeq: Dimension,
|
||||
PrevSeq: Dimension,
|
||||
TotSeq: Dimension,
|
||||
>
|
||||
Module<(
|
||||
GraphTensor<(Batch, CurSeq, Const<D_MODEL>)>,
|
||||
GraphTensor<(Batch, EncSeq, Const<D_MODEL>)>,
|
||||
KVCache<Batch, PrevSeq>,
|
||||
PhantomData<TotSeq>,
|
||||
)> for DecoderTransformerBlock
|
||||
{
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurSeq, Const<D_MODEL>)>,
|
||||
KVCache<Batch, EncSeq>,
|
||||
KVCache<Batch, TotSeq>,
|
||||
);
|
||||
impl Module<(GraphTensor, GraphTensor, KVCache)> for DecoderTransformerBlock {
|
||||
type Output = (GraphTensor, KVCache, KVCache);
|
||||
fn forward(
|
||||
&self,
|
||||
(mut x, encoded, cache, _): (
|
||||
GraphTensor<(Batch, CurSeq, Const<D_MODEL>)>,
|
||||
GraphTensor<(Batch, EncSeq, Const<D_MODEL>)>,
|
||||
KVCache<Batch, PrevSeq>,
|
||||
PhantomData<TotSeq>,
|
||||
),
|
||||
(mut x, encoded, cache): (GraphTensor, GraphTensor, KVCache),
|
||||
) -> Self::Output {
|
||||
// Self Attention
|
||||
let (y, cache) = self.attention.forward((
|
||||
self.attention_norm.forward(x),
|
||||
Some(cache),
|
||||
true,
|
||||
PhantomData::<TotSeq>,
|
||||
));
|
||||
let (y, cache) =
|
||||
self.attention
|
||||
.forward((self.attention_norm.forward(x), Some(cache), true));
|
||||
|
||||
// Residual Addition
|
||||
x += y;
|
||||
@@ -429,30 +350,24 @@ impl<
|
||||
x += y;
|
||||
|
||||
// Feed Forward
|
||||
let y = self.ff1.forward(self.feed_forward_norm.forward(x)) + self.ff1_bias.expand();
|
||||
let y = self.ff2.forward(y.gelu()) + self.ff2_bias.expand();
|
||||
let y = self.ff1.forward(self.feed_forward_norm.forward(x));
|
||||
let y = self.ff2.forward(y.gelu());
|
||||
|
||||
// Residual Addition
|
||||
(x + y, enc_states, cache)
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for DecoderTransformerBlock {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl DecoderTransformerBlock {
|
||||
fn new(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
attention: InitModule::initialize(cx),
|
||||
attention_norm: LayerNorm::new(true, true, true, 1e-5, cx),
|
||||
cross_attention: InitModule::initialize(cx),
|
||||
cross_attention_norm: LayerNorm::new(true, true, true, 1e-5, cx),
|
||||
ff1: PermutedLinear {
|
||||
weight: cx.tensor(),
|
||||
},
|
||||
ff1_bias: cx.tensor(),
|
||||
ff2: PermutedLinear {
|
||||
weight: cx.tensor(),
|
||||
},
|
||||
ff2_bias: cx.tensor(),
|
||||
feed_forward_norm: LayerNorm::new(true, true, true, 1e-5, cx),
|
||||
attention: SelfAttention::new(D_MODEL, cx),
|
||||
attention_norm: LayerNorm::new(D_MODEL, true, true, true, 1e-5, cx),
|
||||
cross_attention: SelfAttention::new(D_MODEL, cx),
|
||||
cross_attention_norm: LayerNorm::new(D_MODEL, true, true, true, 1e-5, cx),
|
||||
ff1: Linear::new_permuted(D_MODEL, DEC_FFN_DIM, true, cx),
|
||||
ff2: Linear::new_permuted(DEC_FFN_DIM, D_MODEL, true, cx),
|
||||
feed_forward_norm: LayerNorm::new(D_MODEL, true, true, true, 1e-5, cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -464,93 +379,67 @@ impl SerializeModule for DecoderTransformerBlock {
|
||||
s.module("encoder_attn", &self.cross_attention);
|
||||
s.module("encoder_attn_layer_norm", &self.cross_attention_norm);
|
||||
s.module("fc1", &self.ff1);
|
||||
s.tensor("fc1/bias", self.ff1_bias);
|
||||
s.module("fc2", &self.ff2);
|
||||
s.tensor("fc2/bias", self.ff2_bias);
|
||||
s.module("final_layer_norm", &self.feed_forward_norm);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TextDecoder {
|
||||
// Embeddings
|
||||
pub embedding: PermutedEmbedding<VOCAB_SIZE, D_MODEL>,
|
||||
pub pos_embedding: GraphTensor<R2<MAX_TARGET_POSITION, D_MODEL>>,
|
||||
pub embedding: Embedding,
|
||||
pub pos_embedding: GraphTensor,
|
||||
// Transformer layers
|
||||
pub layers: Vec<DecoderTransformerBlock>,
|
||||
// Final layer norm
|
||||
pub layer_norm: LayerNorm<D_MODEL>,
|
||||
pub layer_norm: LayerNorm,
|
||||
}
|
||||
|
||||
impl<
|
||||
Batch: Dimension,
|
||||
EncSeq: Dimension,
|
||||
PrevDecSeq: Dimension,
|
||||
CurDecSeq: Dimension,
|
||||
TotDecSeq: Dimension,
|
||||
>
|
||||
Module<(
|
||||
GraphTensor<(Batch, EncSeq, Const<D_MODEL>)>,
|
||||
GraphTensor<(Batch, CurDecSeq)>,
|
||||
&[KVCache<Batch, PrevDecSeq>],
|
||||
PhantomData<TotDecSeq>,
|
||||
)> for TextDecoder
|
||||
{
|
||||
impl Module<(GraphTensor, GraphTensor, &[KVCache])> for TextDecoder {
|
||||
type Output = (
|
||||
GraphTensor<(Batch, CurDecSeq, Const<VOCAB_SIZE>)>,
|
||||
Vec<KVCache<Batch, EncSeq>>, // Encoder projected states
|
||||
Vec<KVCache<Batch, TotDecSeq>>, // Decoder KV cache
|
||||
GraphTensor,
|
||||
Vec<KVCache>, // Encoder projected states
|
||||
Vec<KVCache>, // Decoder KV cache
|
||||
);
|
||||
fn forward(
|
||||
&self,
|
||||
(enc_output, input, cache, _): (
|
||||
GraphTensor<(Batch, EncSeq, Const<D_MODEL>)>,
|
||||
GraphTensor<(Batch, CurDecSeq)>,
|
||||
&[KVCache<Batch, PrevDecSeq>],
|
||||
PhantomData<TotDecSeq>,
|
||||
),
|
||||
(enc_output, input, cache): (GraphTensor, GraphTensor, &[KVCache]),
|
||||
) -> Self::Output {
|
||||
let (_, cur_dec_seq) = input.dims2();
|
||||
let (_, _, prev_dec_seq, _) = cache[0].0.dims4();
|
||||
// Embed text
|
||||
let mut x = self.embedding.forward(input);
|
||||
x += self
|
||||
.pos_embedding
|
||||
.slice((
|
||||
PrevDecSeq::size()..CurDecSeq::size() + PrevDecSeq::size(),
|
||||
..,
|
||||
))
|
||||
.slice((prev_dec_seq..cur_dec_seq + prev_dec_seq, ..))
|
||||
.contiguous()
|
||||
.realize::<(CurDecSeq, Const<D_MODEL>)>()
|
||||
.expand();
|
||||
.expand_to(x.shape);
|
||||
// Run through layers and collect new caches
|
||||
let (mut new_caches, mut enc_states) = (vec![], vec![]);
|
||||
let (mut new_cache, mut enc_state);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
(x, enc_state, new_cache) =
|
||||
layer.forward((x, enc_output, cache[i], PhantomData::<TotDecSeq>));
|
||||
(x, enc_state, new_cache) = layer.forward((x, enc_output, cache[i]));
|
||||
new_caches.push(new_cache);
|
||||
enc_states.push(enc_state);
|
||||
}
|
||||
// Run through last norm and output projection
|
||||
(
|
||||
self.layer_norm.forward(x).matmul(
|
||||
self.embedding
|
||||
.weight
|
||||
.realize::<R2<VOCAB_SIZE, D_MODEL>>()
|
||||
.permute(),
|
||||
),
|
||||
self.layer_norm
|
||||
.forward(x)
|
||||
.matmul(self.embedding.weight.permute((1, 0))),
|
||||
enc_states,
|
||||
new_caches,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl InitModule for TextDecoder {
|
||||
fn initialize(cx: &mut Graph) -> Self {
|
||||
impl TextDecoder {
|
||||
pub fn new(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
embedding: InitModule::initialize(cx),
|
||||
pos_embedding: cx.tensor(),
|
||||
layer_norm: LayerNorm::new(true, true, true, 1e-5, cx),
|
||||
embedding: Embedding::new_permuted(VOCAB_SIZE, D_MODEL, cx),
|
||||
pos_embedding: cx.tensor((MAX_TARGET_POSITION, D_MODEL)),
|
||||
layer_norm: LayerNorm::new(D_MODEL, true, true, true, 1e-5, cx),
|
||||
layers: (0..DEC_LAYERS)
|
||||
.map(|_| InitModule::initialize(cx))
|
||||
.map(|_| DecoderTransformerBlock::new(cx))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,25 +4,19 @@ use luminal_nn::Conv2D;
|
||||
struct Bottleneck<const CH_IN: usize, const CH_OUT: usize> {
|
||||
cv1: Conv2D<CH_IN, CH_OUT, 3, 3, 1, 1>,
|
||||
cv2: Conv2D<CH_OUT, CH_OUT, 3, 3, 1, 1>,
|
||||
residual: bool,
|
||||
}
|
||||
|
||||
impl<
|
||||
const CH_IN: usize,
|
||||
const CH_OUT: usize,
|
||||
Batch: Dimension,
|
||||
Width: Dimension,
|
||||
Height: Dimension,
|
||||
> Module<GraphTensor<(Batch, Const<CH_IN>, Height, Width)>> for Bottleneck<CH_IN, CH_OUT>
|
||||
impl<const CH_IN: usize, const CH_OUT: usize, Width: Dimension, Height: Dimension>
|
||||
Module<GraphTensor<(Const<CH_IN>, Height, Width)>> for Bottleneck<CH_IN, CH_OUT>
|
||||
{
|
||||
type Output = GraphTensor<(Batch, Const<CH_OUT>, Height, Width)>;
|
||||
fn forward(&self, input: GraphTensor<(Batch, Const<CH_IN>, Height, Width)>) -> Self::Output {
|
||||
let out = self.cv2.forward(self.cv1.forward(input));
|
||||
if self.residual {
|
||||
out + input
|
||||
} else {
|
||||
out
|
||||
}
|
||||
type Output = GraphTensor<(Const<CH_OUT>, Height, Width)>;
|
||||
fn forward(&self, input: GraphTensor<(Const<CH_IN>, Height, Width)>) -> Self::Output {
|
||||
self.cv2
|
||||
.forward(
|
||||
self.cv1
|
||||
.forward::<Width, Height, Width, Height>(input.permute()),
|
||||
)
|
||||
.permute()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
1
examples/yolo_v8/target/.rustc_info.json
Normal file
1
examples/yolo_v8/target/.rustc_info.json
Normal file
@@ -0,0 +1 @@
|
||||
{"rustc_fingerprint":7902964948476546481,"outputs":{"1185988223601034215":{"success":true,"status":"","code":0,"stdout":"___\nlib___.rlib\nlib___.dylib\nlib___.dylib\nlib___.a\nlib___.dylib\n/Users/jafioti/.rustup/toolchains/stable-aarch64-apple-darwin\noff\npacked\nunpacked\n___\nclippy\ndebug_assertions\nfeature=\"cargo-clippy\"\npanic=\"unwind\"\nproc_macro\ntarget_abi=\"\"\ntarget_arch=\"aarch64\"\ntarget_endian=\"little\"\ntarget_env=\"\"\ntarget_family=\"unix\"\ntarget_feature=\"aes\"\ntarget_feature=\"crc\"\ntarget_feature=\"dit\"\ntarget_feature=\"dotprod\"\ntarget_feature=\"dpb\"\ntarget_feature=\"dpb2\"\ntarget_feature=\"fcma\"\ntarget_feature=\"fhm\"\ntarget_feature=\"flagm\"\ntarget_feature=\"fp16\"\ntarget_feature=\"frintts\"\ntarget_feature=\"jsconv\"\ntarget_feature=\"lor\"\ntarget_feature=\"lse\"\ntarget_feature=\"neon\"\ntarget_feature=\"paca\"\ntarget_feature=\"pacg\"\ntarget_feature=\"pan\"\ntarget_feature=\"pmuv3\"\ntarget_feature=\"ras\"\ntarget_feature=\"rcpc\"\ntarget_feature=\"rcpc2\"\ntarget_feature=\"rdm\"\ntarget_feature=\"sb\"\ntarget_feature=\"sha2\"\ntarget_feature=\"sha3\"\ntarget_feature=\"ssbs\"\ntarget_feature=\"vh\"\ntarget_has_atomic=\"128\"\ntarget_has_atomic=\"16\"\ntarget_has_atomic=\"32\"\ntarget_has_atomic=\"64\"\ntarget_has_atomic=\"8\"\ntarget_has_atomic=\"ptr\"\ntarget_os=\"macos\"\ntarget_pointer_width=\"64\"\ntarget_vendor=\"apple\"\nunix\n","stderr":""},"4614504638168534921":{"success":true,"status":"","code":0,"stdout":"rustc 1.78.0 (9b00956e5 2024-04-29)\nbinary: rustc\ncommit-hash: 9b00956e56009bab2aa15d7bff10916599e3d6d6\ncommit-date: 2024-04-29\nhost: aarch64-apple-darwin\nrelease: 1.78.0\nLLVM version: 18.1.2\n","stderr":""}},"successes":{}}
|
||||
3
examples/yolo_v8/target/CACHEDIR.TAG
Normal file
3
examples/yolo_v8/target/CACHEDIR.TAG
Normal file
@@ -0,0 +1,3 @@
|
||||
Signature: 8a477f597d28d172789f06886806bc55
|
||||
# This file is a cache directory tag created by cargo.
|
||||
# For information about cache directory tags see https://bford.info/cachedir/
|
||||
0
examples/yolo_v8/target/debug/.cargo-lock
Normal file
0
examples/yolo_v8/target/debug/.cargo-lock
Normal file
@@ -0,0 +1 @@
|
||||
fa9fe4392d2669ba
|
||||
@@ -0,0 +1 @@
|
||||
{"rustc":792111255936306319,"features":"[]","declared_features":"","target":2297296889237502566,"profile":1200860260873630964,"path":14319948826174377348,"deps":[[16079472387499994964,"version_check",false,11544101210379487998]],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/ahash-278451fb35d9bcf6/dep-build-script-build-script-build"}}],"rustflags":[],"metadata":6548036084630991988,"config":2202906307356721367,"compile_kind":0}
|
||||
Binary file not shown.
@@ -0,0 +1 @@
|
||||
This file has an mtime of when this was started.
|
||||
@@ -0,0 +1 @@
|
||||
9a028b8532108f53
|
||||
@@ -0,0 +1 @@
|
||||
{"rustc":792111255936306319,"features":"","declared_features":"","target":0,"profile":0,"path":0,"deps":[[1385435641494999048,"build_script_build",false,13432309339295883258]],"local":[{"RerunIfChanged":{"output":"debug/build/ahash-8a7ee4fbe4c9c642/output","paths":["build.rs"]}}],"rustflags":[],"metadata":0,"config":0,"compile_kind":0}
|
||||
Binary file not shown.
@@ -0,0 +1 @@
|
||||
This file has an mtime of when this was started.
|
||||
@@ -0,0 +1 @@
|
||||
281ba81841e89abc
|
||||
@@ -0,0 +1 @@
|
||||
{"rustc":792111255936306319,"features":"[]","declared_features":"","target":295758560010665018,"profile":3797293754785534760,"path":12569880893325353741,"deps":[[1385435641494999048,"build_script_build",false,6021049035992531610],[4254328441789853856,"once_cell",false,11839418217726875033],[11228387426131597774,"getrandom",false,3943971895081466132]],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/ahash-fa8dba1057def52f/dep-lib-ahash"}}],"rustflags":[],"metadata":6548036084630991988,"config":2202906307356721367,"compile_kind":0}
|
||||
Binary file not shown.
@@ -0,0 +1 @@
|
||||
This file has an mtime of when this was started.
|
||||
@@ -0,0 +1 @@
|
||||
fcb0e000d91e225c
|
||||
@@ -0,0 +1 @@
|
||||
{"rustc":792111255936306319,"features":"[\"default\", \"perf-literal\", \"std\"]","declared_features":"","target":12812136000324506373,"profile":3797293754785534760,"path":917586311715097687,"deps":[[554324495028472449,"memchr",false,12018462279246210472]],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/aho-corasick-5c94172d53874de6/dep-lib-aho_corasick"}}],"rustflags":[],"metadata":13904389431191498124,"config":2202906307356721367,"compile_kind":0}
|
||||
Binary file not shown.
@@ -0,0 +1 @@
|
||||
This file has an mtime of when this was started.
|
||||
@@ -0,0 +1 @@
|
||||
cd11a3772bf6054d
|
||||
@@ -0,0 +1 @@
|
||||
{"rustc":792111255936306319,"features":"[\"auto\", \"default\", \"wincon\"]","declared_features":"","target":16157420304466204941,"profile":14808885745152445323,"path":13571762662249164363,"deps":[[1200817279721127204,"anstyle_query",false,5233170534109397214],[8720183142424604966,"utf8parse",false,8030522589549628100],[12423115053093093635,"is_terminal_polyfill",false,14350458276302548626],[15873833695005184023,"colorchoice",false,1643507399289422721],[16609132864249042075,"anstyle_parse",false,12395459992756225163],[16999472572377377103,"anstyle",false,17572025060589194804]],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/anstream-1c56c6d6e6c3fc4a/dep-lib-anstream"}}],"rustflags":[],"metadata":7500874485387469444,"config":2202906307356721367,"compile_kind":0}
|
||||
Binary file not shown.
@@ -0,0 +1 @@
|
||||
This file has an mtime of when this was started.
|
||||
@@ -0,0 +1 @@
|
||||
344eae28b15fdcf3
|
||||
@@ -0,0 +1 @@
|
||||
{"rustc":792111255936306319,"features":"[\"default\", \"std\"]","declared_features":"","target":13663407036240438623,"profile":14808885745152445323,"path":15851815754449844370,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/anstyle-580976b7c8c459e7/dep-lib-anstyle"}}],"rustflags":[],"metadata":14064844656010464607,"config":2202906307356721367,"compile_kind":0}
|
||||
Binary file not shown.
@@ -0,0 +1 @@
|
||||
This file has an mtime of when this was started.
|
||||
@@ -0,0 +1 @@
|
||||
8b04cd7f308505ac
|
||||
@@ -0,0 +1 @@
|
||||
{"rustc":792111255936306319,"features":"[\"default\", \"utf8\"]","declared_features":"","target":1993415851866499831,"profile":14808885745152445323,"path":16158110412718354121,"deps":[[8720183142424604966,"utf8parse",false,8030522589549628100]],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/anstyle-parse-c08eba01cda9523e/dep-lib-anstyle-parse"}}],"rustflags":[],"metadata":9799137552285937175,"config":2202906307356721367,"compile_kind":0}
|
||||
Binary file not shown.
@@ -0,0 +1 @@
|
||||
This file has an mtime of when this was started.
|
||||
@@ -0,0 +1 @@
|
||||
dec0a1cedff49f48
|
||||
@@ -0,0 +1 @@
|
||||
{"rustc":792111255936306319,"features":"[]","declared_features":"","target":8921697713841910856,"profile":14808885745152445323,"path":17832198768418750173,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/anstyle-query-549d4ad21e8f4c08/dep-lib-anstyle-query"}}],"rustflags":[],"metadata":10674566383365303417,"config":2202906307356721367,"compile_kind":0}
|
||||
Binary file not shown.
@@ -0,0 +1 @@
|
||||
This file has an mtime of when this was started.
|
||||
@@ -0,0 +1 @@
|
||||
536fba925d9106a8
|
||||
@@ -0,0 +1 @@
|
||||
{"rustc":792111255936306319,"features":"[]","declared_features":"","target":5730234523381508605,"profile":3797293754785534760,"path":7658171823302188217,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/as-any-66dead5baaa2ce79/dep-lib-as-any"}}],"rustflags":[],"metadata":7688359999514647402,"config":2202906307356721367,"compile_kind":0}
|
||||
Binary file not shown.
@@ -0,0 +1 @@
|
||||
This file has an mtime of when this was started.
|
||||
@@ -0,0 +1 @@
|
||||
5ff9718763cf4f55
|
||||
@@ -0,0 +1 @@
|
||||
{"rustc":792111255936306319,"features":"[]","declared_features":"","target":14886237245231788030,"profile":1200860260873630964,"path":16811576535408259928,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/autocfg-0446d81e8a877634/dep-lib-autocfg"}}],"rustflags":[],"metadata":13102859075309379048,"config":2202906307356721367,"compile_kind":0}
|
||||
Binary file not shown.
@@ -0,0 +1 @@
|
||||
This file has an mtime of when this was started.
|
||||
@@ -0,0 +1 @@
|
||||
699e6c6c8950ef05
|
||||
@@ -0,0 +1 @@
|
||||
{"rustc":792111255936306319,"features":"[\"default\", \"std\"]","declared_features":"","target":16778825523953873731,"profile":3797293754785534760,"path":1249145684557563819,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/base64-ea033a9df982ece9/dep-lib-base64"}}],"rustflags":[],"metadata":13936919950537592407,"config":2202906307356721367,"compile_kind":0}
|
||||
Binary file not shown.
@@ -0,0 +1 @@
|
||||
This file has an mtime of when this was started.
|
||||
@@ -0,0 +1 @@
|
||||
72015ae1027454db
|
||||
@@ -0,0 +1 @@
|
||||
{"rustc":792111255936306319,"features":"[\"default\"]","declared_features":"","target":15712369643656012375,"profile":3797293754785534760,"path":12968319784194130589,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/bitflags-9f46fb88979eb9ab/dep-lib-bitflags"}}],"rustflags":[],"metadata":14564035643000669268,"config":2202906307356721367,"compile_kind":0}
|
||||
Binary file not shown.
@@ -0,0 +1 @@
|
||||
This file has an mtime of when this was started.
|
||||
@@ -0,0 +1 @@
|
||||
5d13ff0869c86fc5
|
||||
@@ -0,0 +1 @@
|
||||
{"rustc":792111255936306319,"features":"[\"default\", \"std\"]","declared_features":"","target":18335588937564793828,"profile":3797293754785534760,"path":16883050064718116216,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/byteorder-221dfff828506489/dep-lib-byteorder"}}],"rustflags":[],"metadata":5398730104718078656,"config":2202906307356721367,"compile_kind":0}
|
||||
Binary file not shown.
@@ -0,0 +1 @@
|
||||
This file has an mtime of when this was started.
|
||||
@@ -0,0 +1 @@
|
||||
0a15848cbceaa920
|
||||
@@ -0,0 +1 @@
|
||||
{"rustc":792111255936306319,"features":"[]","declared_features":"","target":15023190189141807623,"profile":1200860260873630964,"path":14622872267285471233,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/cc-a8cbe2cdd92a7374/dep-lib-cc"}}],"rustflags":[],"metadata":5862599371499774553,"config":2202906307356721367,"compile_kind":0}
|
||||
Binary file not shown.
@@ -0,0 +1 @@
|
||||
This file has an mtime of when this was started.
|
||||
@@ -0,0 +1 @@
|
||||
0f5b677c9f3650eb
|
||||
@@ -0,0 +1 @@
|
||||
{"rustc":792111255936306319,"features":"[]","declared_features":"","target":10623512480563079566,"profile":3797293754785534760,"path":16945596033860744345,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"debug/.fingerprint/cfg-if-45b3cfd05c61975a/dep-lib-cfg-if"}}],"rustflags":[],"metadata":8462187951337715540,"config":2202906307356721367,"compile_kind":0}
|
||||
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user