Merge pull request #42 from TheSeamau5/main

Fix tokenizer issue by switching to HF Tokenizers
This commit is contained in:
Joe Fioti
2024-03-13 22:52:17 -05:00
committed by GitHub
4 changed files with 42 additions and 35 deletions

3
.gitignore vendored
View File

@@ -11,4 +11,5 @@ Cargo.lock
/**/mistral-7b-hf
/**/setup_weights/target
*.model
*.gguf
*.gguf
*.json

View File

@@ -8,13 +8,15 @@ metal = ["dep:luminal_metal", "dep:metal-rs"]
cuda = ["dep:luminal_cuda"]
[dependencies]
luminal = {path="../.."}
luminal_metal = {path="../../crates/luminal_metal", optional=true}
luminal_cuda = {path="../../crates/luminal_cuda", optional=true}
rust_tokenizers = "8.1.0"
luminal = { path = "../.." }
luminal_metal = { path = "../../crates/luminal_metal", optional = true }
luminal_cuda = { path = "../../crates/luminal_cuda", optional = true }
clap = { version = "4.4.18", features = ["derive"] }
byteorder = "1.5.0"
memmap2 = "0.9.4"
metal-rs = { version = "0.27.0", package = "metal", features = ["mps"], optional=true }
metal-rs = { version = "0.27.0", package = "metal", features = [
"mps",
], optional = true }
colored = "2.1.0"
itertools = "0.12.1"
tokenizers = "0.15.2"

View File

@@ -2,7 +2,7 @@
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
echo "Downloading Tokenizer"
curl --location https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/tokenizer.model?download=true --output $SCRIPT_DIR/mistral_tokenizer.model
curl --location https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/tokenizer.json?download=true --output $SCRIPT_DIR/mistral_tokenizer.json
echo "Downloading Model"
curl --location https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q8_0.gguf?download=true --output $SCRIPT_DIR/mistral-7b-instruct-v0.2.Q8_0.gguf
echo "Done Downloading Model"

View File

@@ -6,7 +6,7 @@ use std::{
use clap::Parser;
use colored::Colorize;
use rust_tokenizers::tokenizer::{SentencePieceBpeTokenizer, Tokenizer, TruncationStrategy};
use tokenizers::Tokenizer;
mod gguf;
mod loader;
@@ -30,8 +30,7 @@ pub struct CLIArgs {
fn main() {
let cli_args = CLIArgs::parse();
let tokenizer =
SentencePieceBpeTokenizer::from_file("setup/mistral_tokenizer.model", false).unwrap();
let tokenizer = Tokenizer::from_file("setup/mistral_tokenizer.json").unwrap();
print!("Defining graph");
io::stdout().flush().unwrap();
@@ -53,12 +52,11 @@ fn main() {
cache_dest.keep();
// Set up model loading
let model_path = "setup/mistral-7b-instruct-v0.2.Q8_0.gguf";
#[cfg(feature = "metal")]
let quantized_weight_nodes =
loader::MetalQ8Loader::new("setup/mistral-7b-instruct-v0.2.Q8_0.gguf")
.load(&model, &mut cx);
let quantized_weight_nodes = loader::MetalQ8Loader::new(model_path).load(&model, &mut cx);
#[cfg(not(feature = "metal"))]
loader::Q8Loader::new("setup/mistral-7b-instruct-v0.2.Q8_0.gguf").load(&model, &mut cx);
loader::Q8Loader::new(model_path).load(&model, &mut cx);
println!("\t\t - {}ms", now.elapsed().as_millis());
print!("Compiling graph");
@@ -109,20 +107,19 @@ fn main() {
cx.execute();
let elapsed_ms = now.elapsed().as_millis();
println!(
"\t - {elapsed_ms}ms ({:.2} tok/s)",
1000.0 * (input_ids.len() as f64) / (elapsed_ms as f64)
"\t - {elapsed_ms}ms ({:.2} tok/s, {} prompt tokens)",
1000.0 * (input_ids.len() as f64) / (elapsed_ms as f64),
input_ids.len()
);
delete_inputs(&cache_src_set, &mut cx);
let output_id = sample_index(&logits.data());
logits.drop();
input_ids.push(output_id);
let mut output_ids = vec![output_id];
// Decode token
print!(
"{}{}",
cli_args.prompt.white().bold(),
decode(&tokenizer, &[output_id]).bright_green()
);
print!("{}", cli_args.prompt.white().bold());
io::stdout().flush().unwrap();
// Swap caches
@@ -130,11 +127,11 @@ fn main() {
// Decode loop
let mut token_decode_times = vec![];
let mut prev_output = String::new();
for _ in 0..cli_args.gen_tokens {
input.set_dyn(vec![*input_ids.last().unwrap() as f32], &[1, 1]);
cx.set_dyn_dim('p', input_ids.len() - 1);
cx.set_dyn_dim('t', input_ids.len());
let now = Instant::now();
cx.execute();
token_decode_times.push(now.elapsed().as_micros());
@@ -143,12 +140,24 @@ fn main() {
let output_id = sample_index(&logits.data());
logits.drop();
input_ids.push(output_id);
print!("{}", decode(&tokenizer, &[output_id]).bright_green());
output_ids.push(output_id);
// Get the current decoded output
let current_output = decode(&tokenizer, &output_ids);
// Print the new substring added to the decoded output
let new_substring = &current_output[prev_output.len()..];
print!("{}", new_substring.bright_green());
io::stdout().flush().unwrap();
// Update the previous output
prev_output = current_output;
// Swap caches
transfer_data_same_graph(&cache_dest_set, &cache_src_set, &mut cx);
}
println!();
let avg_token_time = token_decode_times
.iter()
.map(|t| *t as f32 / 1000.)
@@ -161,25 +170,20 @@ fn main() {
);
}
fn encode(tokenizer: &SentencePieceBpeTokenizer, text: &str) -> Vec<i64> {
let mut vector = tokenizer
.encode(text, None, text.len(), &TruncationStrategy::LongestFirst, 0)
.token_ids;
vector.insert(0, 1); // Start token
vector
fn encode(tokenizer: &Tokenizer, text: &str) -> Vec<u32> {
let vector = tokenizer.encode(text, false).unwrap();
vector.get_ids().to_owned()
}
fn decode(tokenizer: &SentencePieceBpeTokenizer, token_ids: &[i64]) -> String {
tokenizer
.decode(token_ids, true, false)
.replace("<0x0A>", "\n")
fn decode(tokenizer: &Tokenizer, token_ids: &[u32]) -> String {
tokenizer.decode(&token_ids, false).unwrap()
}
// Currently just an argmax, do actual sampling here
fn sample_index(dist: &[f32]) -> i64 {
fn sample_index(dist: &[f32]) -> u32 {
dist.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap()
.0 as i64
.0 as u32
}