forked from Rust-related/luminal
Merge pull request #42 from TheSeamau5/main
Fix tokenizer issue by switching to HF Tokenizers
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -11,4 +11,5 @@ Cargo.lock
|
||||
/**/mistral-7b-hf
|
||||
/**/setup_weights/target
|
||||
*.model
|
||||
*.gguf
|
||||
*.gguf
|
||||
*.json
|
||||
@@ -8,13 +8,15 @@ metal = ["dep:luminal_metal", "dep:metal-rs"]
|
||||
cuda = ["dep:luminal_cuda"]
|
||||
|
||||
[dependencies]
|
||||
luminal = {path="../.."}
|
||||
luminal_metal = {path="../../crates/luminal_metal", optional=true}
|
||||
luminal_cuda = {path="../../crates/luminal_cuda", optional=true}
|
||||
rust_tokenizers = "8.1.0"
|
||||
luminal = { path = "../.." }
|
||||
luminal_metal = { path = "../../crates/luminal_metal", optional = true }
|
||||
luminal_cuda = { path = "../../crates/luminal_cuda", optional = true }
|
||||
clap = { version = "4.4.18", features = ["derive"] }
|
||||
byteorder = "1.5.0"
|
||||
memmap2 = "0.9.4"
|
||||
metal-rs = { version = "0.27.0", package = "metal", features = ["mps"], optional=true }
|
||||
metal-rs = { version = "0.27.0", package = "metal", features = [
|
||||
"mps",
|
||||
], optional = true }
|
||||
colored = "2.1.0"
|
||||
itertools = "0.12.1"
|
||||
tokenizers = "0.15.2"
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||
|
||||
echo "Downloading Tokenizer"
|
||||
curl --location https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/tokenizer.model?download=true --output $SCRIPT_DIR/mistral_tokenizer.model
|
||||
curl --location https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/tokenizer.json?download=true --output $SCRIPT_DIR/mistral_tokenizer.json
|
||||
echo "Downloading Model"
|
||||
curl --location https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q8_0.gguf?download=true --output $SCRIPT_DIR/mistral-7b-instruct-v0.2.Q8_0.gguf
|
||||
echo "Done Downloading Model"
|
||||
|
||||
@@ -6,7 +6,7 @@ use std::{
|
||||
|
||||
use clap::Parser;
|
||||
use colored::Colorize;
|
||||
use rust_tokenizers::tokenizer::{SentencePieceBpeTokenizer, Tokenizer, TruncationStrategy};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
mod gguf;
|
||||
mod loader;
|
||||
@@ -30,8 +30,7 @@ pub struct CLIArgs {
|
||||
|
||||
fn main() {
|
||||
let cli_args = CLIArgs::parse();
|
||||
let tokenizer =
|
||||
SentencePieceBpeTokenizer::from_file("setup/mistral_tokenizer.model", false).unwrap();
|
||||
let tokenizer = Tokenizer::from_file("setup/mistral_tokenizer.json").unwrap();
|
||||
|
||||
print!("Defining graph");
|
||||
io::stdout().flush().unwrap();
|
||||
@@ -53,12 +52,11 @@ fn main() {
|
||||
cache_dest.keep();
|
||||
|
||||
// Set up model loading
|
||||
let model_path = "setup/mistral-7b-instruct-v0.2.Q8_0.gguf";
|
||||
#[cfg(feature = "metal")]
|
||||
let quantized_weight_nodes =
|
||||
loader::MetalQ8Loader::new("setup/mistral-7b-instruct-v0.2.Q8_0.gguf")
|
||||
.load(&model, &mut cx);
|
||||
let quantized_weight_nodes = loader::MetalQ8Loader::new(model_path).load(&model, &mut cx);
|
||||
#[cfg(not(feature = "metal"))]
|
||||
loader::Q8Loader::new("setup/mistral-7b-instruct-v0.2.Q8_0.gguf").load(&model, &mut cx);
|
||||
loader::Q8Loader::new(model_path).load(&model, &mut cx);
|
||||
println!("\t\t - {}ms", now.elapsed().as_millis());
|
||||
|
||||
print!("Compiling graph");
|
||||
@@ -109,20 +107,19 @@ fn main() {
|
||||
cx.execute();
|
||||
let elapsed_ms = now.elapsed().as_millis();
|
||||
println!(
|
||||
"\t - {elapsed_ms}ms ({:.2} tok/s)",
|
||||
1000.0 * (input_ids.len() as f64) / (elapsed_ms as f64)
|
||||
"\t - {elapsed_ms}ms ({:.2} tok/s, {} prompt tokens)",
|
||||
1000.0 * (input_ids.len() as f64) / (elapsed_ms as f64),
|
||||
input_ids.len()
|
||||
);
|
||||
delete_inputs(&cache_src_set, &mut cx);
|
||||
let output_id = sample_index(&logits.data());
|
||||
logits.drop();
|
||||
input_ids.push(output_id);
|
||||
|
||||
let mut output_ids = vec![output_id];
|
||||
|
||||
// Decode token
|
||||
print!(
|
||||
"{}{}",
|
||||
cli_args.prompt.white().bold(),
|
||||
decode(&tokenizer, &[output_id]).bright_green()
|
||||
);
|
||||
print!("{}", cli_args.prompt.white().bold());
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
// Swap caches
|
||||
@@ -130,11 +127,11 @@ fn main() {
|
||||
|
||||
// Decode loop
|
||||
let mut token_decode_times = vec![];
|
||||
let mut prev_output = String::new();
|
||||
for _ in 0..cli_args.gen_tokens {
|
||||
input.set_dyn(vec![*input_ids.last().unwrap() as f32], &[1, 1]);
|
||||
cx.set_dyn_dim('p', input_ids.len() - 1);
|
||||
cx.set_dyn_dim('t', input_ids.len());
|
||||
|
||||
let now = Instant::now();
|
||||
cx.execute();
|
||||
token_decode_times.push(now.elapsed().as_micros());
|
||||
@@ -143,12 +140,24 @@ fn main() {
|
||||
let output_id = sample_index(&logits.data());
|
||||
logits.drop();
|
||||
input_ids.push(output_id);
|
||||
print!("{}", decode(&tokenizer, &[output_id]).bright_green());
|
||||
output_ids.push(output_id);
|
||||
|
||||
// Get the current decoded output
|
||||
let current_output = decode(&tokenizer, &output_ids);
|
||||
|
||||
// Print the new substring added to the decoded output
|
||||
let new_substring = &current_output[prev_output.len()..];
|
||||
print!("{}", new_substring.bright_green());
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
// Update the previous output
|
||||
prev_output = current_output;
|
||||
|
||||
// Swap caches
|
||||
transfer_data_same_graph(&cache_dest_set, &cache_src_set, &mut cx);
|
||||
}
|
||||
|
||||
println!();
|
||||
let avg_token_time = token_decode_times
|
||||
.iter()
|
||||
.map(|t| *t as f32 / 1000.)
|
||||
@@ -161,25 +170,20 @@ fn main() {
|
||||
);
|
||||
}
|
||||
|
||||
fn encode(tokenizer: &SentencePieceBpeTokenizer, text: &str) -> Vec<i64> {
|
||||
let mut vector = tokenizer
|
||||
.encode(text, None, text.len(), &TruncationStrategy::LongestFirst, 0)
|
||||
.token_ids;
|
||||
vector.insert(0, 1); // Start token
|
||||
vector
|
||||
fn encode(tokenizer: &Tokenizer, text: &str) -> Vec<u32> {
|
||||
let vector = tokenizer.encode(text, false).unwrap();
|
||||
vector.get_ids().to_owned()
|
||||
}
|
||||
|
||||
fn decode(tokenizer: &SentencePieceBpeTokenizer, token_ids: &[i64]) -> String {
|
||||
tokenizer
|
||||
.decode(token_ids, true, false)
|
||||
.replace("<0x0A>", "\n")
|
||||
fn decode(tokenizer: &Tokenizer, token_ids: &[u32]) -> String {
|
||||
tokenizer.decode(&token_ids, false).unwrap()
|
||||
}
|
||||
|
||||
// Currently just an argmax, do actual sampling here
|
||||
fn sample_index(dist: &[f32]) -> i64 {
|
||||
fn sample_index(dist: &[f32]) -> u32 {
|
||||
dist.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.unwrap()
|
||||
.0 as i64
|
||||
.0 as u32
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user