Files
luminal/examples/whisper/src/main.rs
2024-05-27 19:34:38 -05:00

190 lines
6.8 KiB
Rust

#![allow(unused)]
use std::{io::Write, marker::PhantomData};
use itertools::Itertools;
// WIP
use luminal::prelude::*;
use model::KVCache;
use tokenizers::Tokenizer;
mod audio;
mod loader;
mod model;
fn main() {
let tokenizer = Tokenizer::from_file("setup/tokenizer.json").unwrap();
print!("Defining graph");
std::io::stdout().flush().unwrap();
let now = std::time::Instant::now();
// Construct encoder graph
let mut enc_cx = Graph::new();
let encoder = model::AudioEncoder::initialize(&mut enc_cx);
let mut encoder_params = params(&encoder);
enc_cx.keep_tensors(&encoder_params);
let mut audio_input = enc_cx.tensor::<(Const<1>, Const<{ model::N_MEL_BINS }>, Dyn<'s'>)>();
let mut encoded: GraphTensor<(Const<1>, Dyn<'d'>, Const<384>)> = encoder
.forward((audio_input, PhantomData::<Dyn<'d'>>))
.keep();
loader::load("setup/whisper-tiny.safetensors", &encoder, &mut enc_cx);
// Construct decoder graph
let mut dec_cx = Graph::new();
let decoder = model::TextDecoder::initialize(&mut dec_cx);
let mut decoder_params = params(&decoder);
dec_cx.keep_tensors(&decoder_params);
let mut text_input = dec_cx.tensor::<(Const<1>, Dyn<'s'>)>();
let mut encoder_output =
dec_cx.named_tensor::<(Const<1>, Dyn<'e'>, Const<{ model::D_MODEL }>)>("Enc Output");
let mut cache_src: Vec<KVCache<Const<1>, Dyn<'p'>>> = (0..model::DEC_LAYERS)
.map(|_| (dec_cx.named_tensor("Keys"), dec_cx.named_tensor("Values")))
.collect();
cache_src.set_dyn(vec![], &[1, 6, 64, 0]);
let (logits, _, mut cache_dest) = decoder.forward((
encoder_output,
text_input,
&cache_src,
PhantomData::<Dyn<'t'>>,
));
let mut logits = logits
.slice((.., Expression::from('s') - 1.., ..))
.retrieve();
cache_dest.keep();
loader::load("setup/whisper-tiny.safetensors", &decoder, &mut dec_cx);
// Compile graphs
println!("\t\t - {}ms", now.elapsed().as_millis());
print!("Compiling graph");
std::io::stdout().flush().unwrap();
let now = std::time::Instant::now();
enc_cx.compile(
(
GenericCompiler::default(),
#[cfg(feature = "metal")]
luminal_metal::MetalCompiler::<f16>::default(),
#[cfg(feature = "cuda")]
luminal_cuda::CudaCompiler::<f32>::default(),
#[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
luminal_cpu::CPUCompiler::default(),
),
(&mut audio_input, &mut encoded, &mut encoder_params),
);
dec_cx.compile(
(
GenericCompiler::default(),
#[cfg(feature = "metal")]
luminal_metal::MetalCompiler::<f16>::default(),
#[cfg(feature = "cuda")]
luminal_cuda::CudaCompiler::<f32>::default(),
#[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
luminal_cpu::CPUCompiler::default(),
),
(
&mut text_input,
&mut encoder_output,
&mut cache_src,
&mut cache_dest,
&mut logits,
&mut decoder_params,
),
);
let cache_src = downstream(cache_src, &dec_cx);
let encoder_output = downstream(encoder_output, &dec_cx);
dec_cx.keep_tensors(&encoder_output);
delete_inputs(&encoder_output, &mut dec_cx);
println!("\t\t - {}ms", now.elapsed().as_millis());
// Load weights
print!("Loading weights");
std::io::stdout().flush().unwrap();
let now = std::time::Instant::now();
audio_input.set_dyn(vec![0.; 160], &[1, 80, 2]);
enc_cx.set_dyn_dim('d', 1);
enc_cx.execute();
delete_inputs(downstream(encoder_params, &enc_cx), &mut enc_cx);
text_input.set_dyn(vec![0.], &[1, 1]);
dec_cx.set_dyn_dim('e', 1);
dec_cx.set_dyn_dim('p', 0);
dec_cx.set_dyn_dim('t', 1);
transfer_data(encoded, &mut enc_cx, &encoder_output, &mut dec_cx);
dec_cx.execute();
logits.drop();
transfer_data_same_graph(&cache_dest, &cache_src, &mut dec_cx);
delete_inputs(&cache_src, &mut dec_cx);
delete_inputs(&downstream(decoder_params, &dec_cx), &mut dec_cx);
println!("\t\t - {}ms", now.elapsed().as_millis());
// Process audio into mel spectrogram
let mel_bytes = include_bytes!("../setup/melfilters.bytes").as_slice();
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
let (pcm_data, _) = audio::pcm_decode("setup/jfk.wav").unwrap();
let mel = audio::pcm_to_mel(80, &pcm_data, &mel_filters);
let mel_len = mel.len();
// Encode audio
print!("Encoding audio");
std::io::stdout().flush().unwrap();
let start_encoding = std::time::Instant::now();
audio_input.set_dyn(mel, &[1, 80, mel_len / 80]);
enc_cx.set_dyn_dim('d', (mel_len / 80) / 2);
enc_cx.execute();
transfer_data(encoded, &mut enc_cx, encoder_output, &mut dec_cx);
println!("\t\t - {}ms", start_encoding.elapsed().as_millis());
// Decode text
let start_decode = std::time::Instant::now();
dec_cx.set_dyn_dim('e', mel_len / 80 / 2);
dec_cx.set_dyn_dim('p', 0);
dec_cx.set_dyn_dim('t', 3);
let mut output_ids = vec![];
text_input.set_dyn(vec![50257., 50358., 50362.], &[1, 3]);
dec_cx.execute();
let mut output_token = argmax(&logits.data());
logits.drop();
output_ids.push(output_token);
let output_str = tokenizer.decode(&output_ids, false).unwrap();
print!("{}", output_str.trim());
std::io::stdout().flush().unwrap();
let mut prev_output_len = output_str.len();
for i in 0..100 {
transfer_data_same_graph(&cache_dest, &cache_src, &mut dec_cx);
text_input.set_dyn(vec![output_token as f32], &[1, 1]);
dec_cx.set_dyn_dim('p', i + 3);
dec_cx.set_dyn_dim('t', i + 4);
dec_cx.execute();
output_token = argmax(&logits.data());
if output_token == 50256 {
println!();
break;
}
logits.drop();
output_ids.push(output_token);
let output_str = tokenizer.decode(&output_ids, false).unwrap();
print!("{}", &output_str[prev_output_len..]);
std::io::stdout().flush().unwrap();
prev_output_len = output_str.len();
}
let avg_token_time =
start_decode.elapsed().as_micros() as f32 / (output_ids.len() - 3) as f32 / 1000.0;
println!(
"\nAverage token generated in {:.2}ms\t - ({:.2} tok/s)",
avg_token_time,
1000.0 / avg_token_time
);
println!(
"Total transcription time - {}ms",
start_encoding.elapsed().as_millis()
);
}
/// Returns the index of the largest value in `dist`.
///
/// Values that cannot be ordered (i.e. comparisons involving NaN) are treated as equal.
/// When several elements compare as equally maximal, the index of the last one is returned.
///
/// # Panics
///
/// Panics if `dist` is empty.
fn argmax(dist: &[f32]) -> u32 {
    dist.iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(index, _)| index)
        .expect("argmax requires a non-empty slice") as u32
}