mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
2 Commits
flashinfer
...
codex/rust
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0b1e09cf23 | ||
|
|
7402503bd4 |
@@ -1,3 +1,5 @@
|
||||
#[path = "../../../examples_common/benchmark_stdio.rs"]
|
||||
mod benchmark_stdio;
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
@@ -6,18 +8,14 @@ use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use std::{
|
||||
io::Write,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const REPO_ID: &str = "google/gemma-4-26B-A4B";
|
||||
|
||||
fn env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
fn env_bool(name: &str) -> bool {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
@@ -25,9 +23,10 @@ fn env_bool(name: &str) -> bool {
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let max_seq_len = env_usize("MAX_SEQ_LEN", 4096);
|
||||
let gen_tokens = env_usize("GEN_TOKENS", 30);
|
||||
let search_graphs = env_usize("SEARCH_GRAPHS", 50);
|
||||
let stdio = benchmark_stdio::enabled();
|
||||
let max_seq_len = benchmark_stdio::env_usize("MAX_SEQ_LEN", 4096);
|
||||
let gen_tokens = benchmark_stdio::env_usize("GEN_TOKENS", 30);
|
||||
let search_graphs = benchmark_stdio::env_usize("SEARCH_GRAPHS", 50);
|
||||
let prompt = std::env::var("PROMPT").unwrap_or_else(|_| "The capital of France is".to_string());
|
||||
let print_token_ids = env_bool("PRINT_TOKEN_IDS");
|
||||
|
||||
@@ -38,11 +37,6 @@ fn main() {
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
@@ -63,11 +57,14 @@ fn main() {
|
||||
let weights_path = model_dir.join("model_combined.safetensors");
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
let reset_cache = |runtime: &mut CudaRuntime| {
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
};
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
@@ -75,15 +72,66 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(pos_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
if stdio {
|
||||
benchmark_stdio::serve(|prompt| {
|
||||
reset_cache(&mut runtime);
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
print_token_ids,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
pos_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
true,
|
||||
);
|
||||
});
|
||||
} else {
|
||||
run_prompt(
|
||||
&prompt,
|
||||
gen_tokens,
|
||||
print_token_ids,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
pos_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
false,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_prompt(
|
||||
prompt: &str,
|
||||
gen_tokens: usize,
|
||||
print_token_ids: bool,
|
||||
tokenizer: &Tokenizer,
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
input: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
kv_cache: &KVCache,
|
||||
stdio: bool,
|
||||
) {
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
let query_start = Instant::now();
|
||||
|
||||
if !stdio {
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
let mut prev_seq = 0usize;
|
||||
let mut fwd_durations = vec![];
|
||||
@@ -93,7 +141,7 @@ fn main() {
|
||||
|
||||
const EOS_TOKEN: u32 = 1;
|
||||
|
||||
let prefill_start = std::time::Instant::now();
|
||||
let prefill_start = Instant::now();
|
||||
for &token in &prompt_tokens {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
@@ -121,12 +169,26 @@ fn main() {
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
generated_token_ids.push(next_token);
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
let mut generated = 0usize;
|
||||
if stdio {
|
||||
if next_token != EOS_TOKEN {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
generated += 1;
|
||||
}
|
||||
} else {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
print!("{decoded}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
for _ in 1..gen_tokens {
|
||||
let start = std::time::Instant::now();
|
||||
if stdio && next_token == EOS_TOKEN {
|
||||
break;
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![next_token as i32]);
|
||||
@@ -165,10 +227,21 @@ fn main() {
|
||||
break;
|
||||
}
|
||||
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
} else {
|
||||
print!("{decoded}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
generated += 1;
|
||||
fwd_durations.push(start.elapsed());
|
||||
}
|
||||
if stdio {
|
||||
benchmark_stdio::emit_eoq(generated, query_start);
|
||||
return;
|
||||
}
|
||||
|
||||
println!();
|
||||
if print_token_ids {
|
||||
println!("Generated token ids: {generated_token_ids:?}");
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#[path = "../../../examples_common/benchmark_stdio.rs"]
|
||||
mod benchmark_stdio;
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
@@ -7,22 +9,36 @@ use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use luminal_tracing::*;
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use std::{
|
||||
io::Write,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
const REPO_ID: &str = "NousResearch/Meta-Llama-3-8B-Instruct";
|
||||
|
||||
fn main() {
|
||||
let stdio = benchmark_stdio::enabled();
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 500;
|
||||
let search_graphs = 500;
|
||||
let gen_tokens = if stdio {
|
||||
benchmark_stdio::env_usize("GEN_TOKENS", 500)
|
||||
} else {
|
||||
500
|
||||
};
|
||||
let search_graphs = if stdio {
|
||||
benchmark_stdio::env_usize("SEARCH_GRAPHS", 500)
|
||||
} else {
|
||||
500
|
||||
};
|
||||
let prompt = "Explain what a neural network is in a paragraph.";
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
if !stdio {
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
}
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
@@ -31,14 +47,6 @@ fn main() {
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let chat_prompt = format!(
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
);
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(chat_prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
// Build graph
|
||||
let mut cx = Graph::default();
|
||||
@@ -66,10 +74,13 @@ fn main() {
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
let reset_cache = |runtime: &mut CudaRuntime| {
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
};
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
@@ -77,12 +88,65 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(token_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
if stdio {
|
||||
benchmark_stdio::serve(|prompt| {
|
||||
reset_cache(&mut runtime);
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
true,
|
||||
);
|
||||
});
|
||||
} else {
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
false,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_prompt(
|
||||
prompt: &str,
|
||||
gen_tokens: usize,
|
||||
tokenizer: &Tokenizer,
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
input: GraphTensor,
|
||||
token_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
kv_cache: &KVCache,
|
||||
stdio: bool,
|
||||
) {
|
||||
let chat_prompt = format!(
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
);
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(chat_prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let query_start = Instant::now();
|
||||
let mut prev_seq = 1usize;
|
||||
let mut sentence = vec![prompt_tokens[0]];
|
||||
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
|
||||
@@ -94,13 +158,16 @@ fn main() {
|
||||
const EOS_TOKEN: u32 = 128009;
|
||||
const STOP_TOKEN: u32 = 128001;
|
||||
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
if !stdio {
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
}
|
||||
|
||||
let mut generated = 0usize;
|
||||
for i in 0..total_steps {
|
||||
let start = std::time::Instant::now();
|
||||
let start = Instant::now();
|
||||
let is_prefill = i < prompt_len - 1;
|
||||
let seq_len = sentence.len();
|
||||
|
||||
@@ -159,12 +226,21 @@ fn main() {
|
||||
}
|
||||
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
} else {
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
generated += 1;
|
||||
}
|
||||
if stdio {
|
||||
benchmark_stdio::emit_eoq(generated, query_start);
|
||||
return;
|
||||
}
|
||||
println!();
|
||||
|
||||
// Benchmarks
|
||||
println!();
|
||||
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
|
||||
if decode_durations.len() > 2 {
|
||||
println!(
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#[path = "../../../examples_common/benchmark_stdio.rs"]
|
||||
mod benchmark_stdio;
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
@@ -7,22 +9,36 @@ use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use luminal_tracing::*;
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use std::{
|
||||
io::Write,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
const REPO_ID: &str = "Qwen/Qwen3-4B";
|
||||
|
||||
fn main() {
|
||||
let stdio = benchmark_stdio::enabled();
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 500;
|
||||
let search_graphs = 500;
|
||||
let gen_tokens = if stdio {
|
||||
benchmark_stdio::env_usize("GEN_TOKENS", 500)
|
||||
} else {
|
||||
500
|
||||
};
|
||||
let search_graphs = if stdio {
|
||||
benchmark_stdio::env_usize("SEARCH_GRAPHS", 500)
|
||||
} else {
|
||||
500
|
||||
};
|
||||
let prompt = "Explain what a neural network is in a paragraph.";
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
if !stdio {
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
}
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
@@ -31,7 +47,6 @@ fn main() {
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
|
||||
// Build graph
|
||||
let mut cx = Graph::default();
|
||||
@@ -54,10 +69,13 @@ fn main() {
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
let reset_cache = |runtime: &mut CudaRuntime| {
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
};
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
@@ -65,12 +83,58 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(token_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
if stdio {
|
||||
benchmark_stdio::serve(|prompt| {
|
||||
reset_cache(&mut runtime);
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
true,
|
||||
);
|
||||
});
|
||||
} else {
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
false,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_prompt(
|
||||
prompt: &str,
|
||||
gen_tokens: usize,
|
||||
tokenizer: &Tokenizer,
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
input: GraphTensor,
|
||||
token_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
kv_cache: &KVCache,
|
||||
stdio: bool,
|
||||
) {
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
let query_start = Instant::now();
|
||||
let mut prev_seq = 1usize;
|
||||
let mut sentence = vec![prompt_tokens[0]];
|
||||
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
|
||||
@@ -82,13 +146,16 @@ fn main() {
|
||||
const EOS_TOKEN: u32 = 151645; // <|endoftext|>
|
||||
const STOP_TOKEN: u32 = 151643; // <|end|>
|
||||
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
if !stdio {
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
}
|
||||
|
||||
let mut generated = 0usize;
|
||||
for i in 0..total_steps {
|
||||
let start = std::time::Instant::now();
|
||||
let start = Instant::now();
|
||||
let is_prefill = i < prompt_len - 1;
|
||||
let seq_len = sentence.len();
|
||||
|
||||
@@ -147,12 +214,21 @@ fn main() {
|
||||
}
|
||||
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
} else {
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
generated += 1;
|
||||
}
|
||||
if stdio {
|
||||
benchmark_stdio::emit_eoq(generated, query_start);
|
||||
return;
|
||||
}
|
||||
println!();
|
||||
|
||||
// Benchmarks
|
||||
println!();
|
||||
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
|
||||
if decode_durations.len() > 2 {
|
||||
println!(
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#[path = "../../../examples_common/benchmark_stdio.rs"]
|
||||
mod benchmark_stdio;
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
@@ -6,15 +8,27 @@ use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use std::{
|
||||
io::Write,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const REPO_ID: &str = "Qwen/Qwen3-30B-A3B";
|
||||
|
||||
fn main() {
|
||||
let stdio = benchmark_stdio::enabled();
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 30;
|
||||
let search_graphs = 50;
|
||||
let gen_tokens = if stdio {
|
||||
benchmark_stdio::env_usize("GEN_TOKENS", 30)
|
||||
} else {
|
||||
30
|
||||
};
|
||||
let search_graphs = if stdio {
|
||||
benchmark_stdio::env_usize("SEARCH_GRAPHS", 50)
|
||||
} else {
|
||||
50
|
||||
};
|
||||
let prompt = "The capital of France is";
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
@@ -24,7 +38,6 @@ fn main() {
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
|
||||
// Build graph
|
||||
let mut cx = Graph::default();
|
||||
@@ -47,10 +60,13 @@ fn main() {
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
let reset_cache = |runtime: &mut CudaRuntime| {
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
};
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
@@ -58,14 +74,63 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(pos_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
if stdio {
|
||||
benchmark_stdio::serve(|prompt| {
|
||||
reset_cache(&mut runtime);
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
pos_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
true,
|
||||
);
|
||||
});
|
||||
} else {
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
pos_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
false,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_prompt(
|
||||
prompt: &str,
|
||||
gen_tokens: usize,
|
||||
tokenizer: &Tokenizer,
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
input: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
kv_cache: &KVCache,
|
||||
stdio: bool,
|
||||
) {
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
let query_start = Instant::now();
|
||||
|
||||
if !stdio {
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
let mut prev_seq = 0usize;
|
||||
let mut fwd_durations = vec![];
|
||||
@@ -76,7 +141,7 @@ fn main() {
|
||||
const STOP_TOKEN: u32 = 151643;
|
||||
|
||||
// Prefill: process prompt tokens one at a time
|
||||
let prefill_start = std::time::Instant::now();
|
||||
let prefill_start = Instant::now();
|
||||
for &token in &prompt_tokens {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
@@ -105,13 +170,27 @@ fn main() {
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
let mut generated = 0usize;
|
||||
if stdio {
|
||||
if next_token != EOS_TOKEN && next_token != STOP_TOKEN {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
generated += 1;
|
||||
}
|
||||
} else {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
print!("{decoded}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
// Decode loop
|
||||
for _ in 1..gen_tokens {
|
||||
let start = std::time::Instant::now();
|
||||
if stdio && (next_token == EOS_TOKEN || next_token == STOP_TOKEN) {
|
||||
break;
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![next_token as i32]);
|
||||
@@ -150,13 +229,23 @@ fn main() {
|
||||
break;
|
||||
}
|
||||
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
} else {
|
||||
print!("{decoded}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
generated += 1;
|
||||
fwd_durations.push(start.elapsed());
|
||||
}
|
||||
println!();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_eoq(generated, query_start);
|
||||
return;
|
||||
}
|
||||
|
||||
// Report benchmarks
|
||||
println!();
|
||||
println!(
|
||||
" TTFT: {:.2} ms ({} prompt tokens)",
|
||||
prefill_duration.as_secs_f64() * 1e3,
|
||||
|
||||
58
examples_common/benchmark_stdio.rs
Normal file
58
examples_common/benchmark_stdio.rs
Normal file
@@ -0,0 +1,58 @@
|
||||
use std::{
|
||||
io::{BufRead, Write},
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
pub fn enabled() -> bool {
|
||||
std::env::args().any(|arg| arg == "--stdio")
|
||||
}
|
||||
|
||||
pub fn env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
fn emit_ready() {
|
||||
println!("READY");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
pub fn serve(mut f: impl FnMut(&str)) {
|
||||
emit_ready();
|
||||
|
||||
let stdin = std::io::stdin();
|
||||
for line in stdin.lock().lines() {
|
||||
let line = line.unwrap();
|
||||
f(&line);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn emit_token(token: &str) {
|
||||
println!("TOK\t{}", escape_token(token));
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
pub fn emit_eoq(generated: usize, query_start: Instant) {
|
||||
println!(
|
||||
"EOQ\t{}\t{:.3}",
|
||||
generated,
|
||||
query_start.elapsed().as_secs_f64() * 1e3
|
||||
);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
fn escape_token(s: &str) -> String {
|
||||
let mut out = String::with_capacity(s.len());
|
||||
for ch in s.chars() {
|
||||
match ch {
|
||||
'\\' => out.push_str("\\\\"),
|
||||
'\t' => out.push_str("\\t"),
|
||||
'\n' => out.push_str("\\n"),
|
||||
'\r' => out.push_str("\\r"),
|
||||
_ => out.push(ch),
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
Reference in New Issue
Block a user