mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
2 Commits
flashinfer
...
rust-examp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9311c59b4c | ||
|
|
69f21f2a43 |
7
examples/example_common/Cargo.toml
Normal file
7
examples/example_common/Cargo.toml
Normal file
@@ -0,0 +1,7 @@
|
||||
[package]
|
||||
name = "example_common"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
rustc-hash = "2"
|
||||
167
examples/example_common/src/lib.rs
Normal file
167
examples/example_common/src/lib.rs
Normal file
@@ -0,0 +1,167 @@
|
||||
//! Shared helpers for the rust example binaries.
|
||||
//!
|
||||
//! - `--stdio` arg detection and READY/TOK/EOQ protocol used by the
|
||||
//! luminal-benchmarks harness to drive a long-lived subprocess.
|
||||
//! - Env-var parsing for benchmark knobs (`GEN_TOKENS`, `SEARCH_GRAPHS`).
|
||||
//! - `info!` routing — stderr in stdio mode, stdout otherwise.
|
||||
//! - Greedy sampling with a repetition penalty.
|
||||
|
||||
use rustc_hash::FxHashSet;
|
||||
|
||||
pub fn env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
pub fn env_bool(name: &str) -> bool {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.is_some_and(|s| matches!(s.as_str(), "1" | "true" | "TRUE" | "yes" | "YES"))
|
||||
}
|
||||
|
||||
pub fn has_arg(name: &str) -> bool {
|
||||
std::env::args().any(|a| a == name)
|
||||
}
|
||||
|
||||
/// Route an info message to stderr in stdio mode (so the protocol channel
|
||||
/// stays clean) or stdout otherwise.
|
||||
pub fn info(stdio_mode: bool, msg: impl AsRef<str>) {
|
||||
let msg = msg.as_ref();
|
||||
if stdio_mode {
|
||||
eprintln!("{msg}");
|
||||
} else {
|
||||
println!("{msg}");
|
||||
}
|
||||
}
|
||||
|
||||
/// Greedy argmax with a multiplicative repetition penalty applied to
|
||||
/// previously-seen tokens.
|
||||
pub fn sample_greedy_with_penalty(
|
||||
logits_row: &[f32],
|
||||
seen: &FxHashSet<u32>,
|
||||
repetition_penalty: f32,
|
||||
) -> u32 {
|
||||
let mut row = logits_row.to_vec();
|
||||
for &tok in seen {
|
||||
let logit = &mut row[tok as usize];
|
||||
if *logit > 0.0 {
|
||||
*logit /= repetition_penalty;
|
||||
} else {
|
||||
*logit *= repetition_penalty;
|
||||
}
|
||||
}
|
||||
row.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32
|
||||
}
|
||||
|
||||
/// Escape a token's UTF-8 for one-line TOK transport: \ → \\, \t → \\t,
|
||||
/// \n → \\n, \r → \\r. Inverted on the python side.
|
||||
pub fn escape_tok(s: &str) -> String {
|
||||
let mut out = String::with_capacity(s.len());
|
||||
for c in s.chars() {
|
||||
match c {
|
||||
'\\' => out.push_str("\\\\"),
|
||||
'\t' => out.push_str("\\t"),
|
||||
'\n' => out.push_str("\\n"),
|
||||
'\r' => out.push_str("\\r"),
|
||||
_ => out.push(c),
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
pub mod stdio {
|
||||
//! READY / TOK\t<text> / EOQ\t<n>\t<elapsed_ms> protocol shared with
|
||||
//! `luminal-benchmarks/sut/rust.py`.
|
||||
|
||||
use super::escape_tok;
|
||||
use std::io::{BufRead, Write};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Print the one-shot READY line that tells the harness init is done.
|
||||
pub fn ready() {
|
||||
let mut out = std::io::stdout().lock();
|
||||
let _ = writeln!(out, "READY");
|
||||
let _ = out.flush();
|
||||
}
|
||||
|
||||
/// One generated token, emitted as it's produced.
|
||||
pub fn emit_tok(text: &str) {
|
||||
let mut out = std::io::stdout().lock();
|
||||
let _ = writeln!(out, "TOK\t{}", escape_tok(text));
|
||||
let _ = out.flush();
|
||||
}
|
||||
|
||||
/// End-of-query marker with the total tokens produced for this prompt
|
||||
/// and the elapsed time. (The harness uses LoadGen's own timestamps,
|
||||
/// but the line is required to mark the boundary.)
|
||||
pub fn emit_eoq(n_tokens: usize, elapsed: Duration) {
|
||||
let mut out = std::io::stdout().lock();
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"EOQ\t{}\t{:.3}",
|
||||
n_tokens,
|
||||
elapsed.as_secs_f64() * 1e3
|
||||
);
|
||||
let _ = out.flush();
|
||||
}
|
||||
|
||||
/// Read one prompt line. Blank lines are skipped (the harness writes
|
||||
/// one prompt per non-empty line). Returns `None` on EOF / read error.
|
||||
pub fn next_prompt<R: BufRead>(reader: &mut R, buf: &mut String) -> Option<String> {
|
||||
loop {
|
||||
buf.clear();
|
||||
match reader.read_line(buf) {
|
||||
Ok(0) => return None,
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
eprintln!("stdio read error: {e}");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
let prompt = buf
|
||||
.trim_end_matches('\n')
|
||||
.trim_end_matches('\r')
|
||||
.to_string();
|
||||
if !prompt.is_empty() {
|
||||
return Some(prompt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Drive the per-prompt protocol: print READY, then for each stdin
|
||||
/// line call `run` with the prompt and post-call emit EOQ. `run`
|
||||
/// returns `(n_tokens_generated, prompt_elapsed)` and is expected to
|
||||
/// have called `emit_tok` once per generated token.
|
||||
pub fn serve(mut run: impl FnMut(&str) -> (usize, Duration)) {
|
||||
ready();
|
||||
let stdin = std::io::stdin();
|
||||
let mut handle = stdin.lock();
|
||||
let mut buf = String::new();
|
||||
while let Some(prompt) = next_prompt(&mut handle, &mut buf) {
|
||||
let (n, elapsed) = run(&prompt);
|
||||
emit_eoq(n, elapsed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Pull the standard benchmark knobs from env vars in one call.
|
||||
pub struct BenchEnv {
|
||||
pub gen_tokens: usize,
|
||||
pub search_graphs: usize,
|
||||
}
|
||||
|
||||
impl BenchEnv {
|
||||
pub fn from_env(default_gen_tokens: usize, default_search_graphs: usize) -> Self {
|
||||
Self {
|
||||
gen_tokens: env_usize("GEN_TOKENS", default_gen_tokens),
|
||||
search_graphs: env_usize("SEARCH_GRAPHS", default_search_graphs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ edition = "2024"
|
||||
luminal = { path = "../.." }
|
||||
luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
example_common = { path = "../example_common" }
|
||||
tokenizers = "0.22.2"
|
||||
rustc-hash = "2"
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
use example_common::{BenchEnv, env_bool, has_arg, info, sample_greedy_with_penalty, stdio};
|
||||
use hf::prepare_hf_model;
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
@@ -10,32 +11,140 @@ use std::{io::Write, time::Duration};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const REPO_ID: &str = "google/gemma-4-26B-A4B";
|
||||
const STDIO_MAX_PREFILL: usize = 512;
|
||||
const DEFAULT_GEN_TOKENS: usize = 30;
|
||||
const DEFAULT_SEARCH_GRAPHS: usize = 50;
|
||||
const EOS_TOKEN: u32 = 1;
|
||||
|
||||
fn env_bool(name: &str) -> bool {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.is_some_and(|s| matches!(s.as_str(), "1" | "true" | "TRUE" | "yes" | "YES"))
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_one_prompt(
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
tokenizer: &Tokenizer,
|
||||
input: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
max_seq_len: usize,
|
||||
prompt_tokens: &[u32],
|
||||
gen_tokens: usize,
|
||||
repetition_penalty: f32,
|
||||
emit_tok: &mut dyn FnMut(&str),
|
||||
) -> (usize, Duration) {
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
|
||||
let prompt_len = prompt_tokens.len();
|
||||
if prompt_len == 0 || gen_tokens == 0 {
|
||||
return (0, Duration::default());
|
||||
}
|
||||
|
||||
let mut seen_tokens: FxHashSet<u32> = FxHashSet::default();
|
||||
let mut generated = 0usize;
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
cx.set_dim('s', prompt_len);
|
||||
cx.set_dim('p', 0);
|
||||
runtime.set_data(
|
||||
input,
|
||||
prompt_tokens.iter().map(|t| *t as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
runtime.set_data(pos_ids, (0..prompt_len as i32).collect::<Vec<_>>());
|
||||
runtime.execute(&cx.dyn_map);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
let mut prev_seq = prompt_len;
|
||||
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let row_start = (prompt_len - 1) * VOCAB_SIZE;
|
||||
let mut next_token = sample_greedy_with_penalty(
|
||||
&logits_data[row_start..row_start + VOCAB_SIZE],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
seen_tokens.insert(next_token);
|
||||
generated += 1;
|
||||
if next_token != EOS_TOKEN {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
emit_tok(&decoded);
|
||||
}
|
||||
|
||||
while generated < gen_tokens {
|
||||
if next_token == EOS_TOKEN {
|
||||
break;
|
||||
}
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![next_token as i32]);
|
||||
runtime.set_data(pos_ids, vec![prev_seq as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
|
||||
prev_seq += 1;
|
||||
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
next_token = sample_greedy_with_penalty(
|
||||
&logits_data[..VOCAB_SIZE],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
seen_tokens.insert(next_token);
|
||||
generated += 1;
|
||||
|
||||
if next_token == EOS_TOKEN {
|
||||
break;
|
||||
}
|
||||
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
emit_tok(&decoded);
|
||||
}
|
||||
|
||||
(generated, start.elapsed())
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 30;
|
||||
let search_graphs = 50;
|
||||
let bench = BenchEnv::from_env(DEFAULT_GEN_TOKENS, DEFAULT_SEARCH_GRAPHS);
|
||||
let stdio_mode = has_arg("--stdio");
|
||||
let prompt = std::env::var("PROMPT").unwrap_or_else(|_| "The capital of France is".to_string());
|
||||
let print_token_ids = env_bool("PRINT_TOKEN_IDS");
|
||||
|
||||
let log = |s: &str| info(stdio_mode, s);
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
let model_dir = prepare_hf_model(REPO_ID).expect("Failed to prepare model");
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
log(&format!("Using model directory: {}", model_dir.display()));
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let (default_prompt_tokens, prompt_len) = if stdio_mode {
|
||||
(Vec::<u32>::new(), 0usize)
|
||||
} else {
|
||||
let toks = tokenizer
|
||||
.encode(prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let len = toks.len();
|
||||
(toks, len)
|
||||
};
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
@@ -48,10 +157,10 @@ fn main() {
|
||||
v_out.output();
|
||||
}
|
||||
|
||||
println!("Building E-Graph...");
|
||||
log("Building E-Graph...");
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
println!("Loading weights...");
|
||||
log("Loading weights...");
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
let weights_path = model_dir.join("model_combined.safetensors");
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
@@ -62,10 +171,12 @@ fn main() {
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
|
||||
println!("Compiling...");
|
||||
let max_prefill = (prompt_tokens.len() + 16)
|
||||
.next_power_of_two()
|
||||
.min(max_seq_len);
|
||||
log("Compiling...");
|
||||
let max_prefill = if stdio_mode {
|
||||
STDIO_MAX_PREFILL.min(max_seq_len)
|
||||
} else {
|
||||
(prompt_len + 16).next_power_of_two().min(max_seq_len)
|
||||
};
|
||||
let search_s = 16.min(max_prefill).max(2);
|
||||
cx.set_dim_buckets(
|
||||
's',
|
||||
@@ -78,7 +189,7 @@ fn main() {
|
||||
cx.set_dim('p', 0);
|
||||
runtime.set_data(input, vec![1; search_s]);
|
||||
runtime.set_data(pos_ids, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
runtime = cx.search(runtime, bench.search_graphs);
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
@@ -86,25 +197,56 @@ fn main() {
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
|
||||
let repetition_penalty: f32 = 1.05;
|
||||
|
||||
if stdio_mode {
|
||||
stdio::serve(|user_prompt| {
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(user_prompt, true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
run_one_prompt(
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
&tokenizer,
|
||||
input,
|
||||
pos_ids,
|
||||
logits,
|
||||
&kv_cache,
|
||||
&cache_outputs,
|
||||
max_seq_len,
|
||||
&prompt_tokens,
|
||||
bench.gen_tokens,
|
||||
repetition_penalty,
|
||||
&mut |s| stdio::emit_tok(s),
|
||||
)
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Legacy single-prompt flow.
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
|
||||
let mut prev_seq: usize;
|
||||
let mut fwd_durations = vec![];
|
||||
let mut seen_tokens = FxHashSet::default();
|
||||
let mut generated_token_ids = vec![];
|
||||
let repetition_penalty: f32 = 1.05;
|
||||
|
||||
const EOS_TOKEN: u32 = 1;
|
||||
|
||||
let prefill_start = std::time::Instant::now();
|
||||
cx.set_dim('s', prompt_tokens.len());
|
||||
cx.set_dim('s', default_prompt_tokens.len());
|
||||
cx.set_dim('p', 0);
|
||||
runtime.set_data(
|
||||
input,
|
||||
prompt_tokens.iter().map(|t| *t as i32).collect::<Vec<_>>(),
|
||||
default_prompt_tokens
|
||||
.iter()
|
||||
.map(|t| *t as i32)
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
runtime.set_data(
|
||||
pos_ids,
|
||||
(0..default_prompt_tokens.len() as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
runtime.set_data(pos_ids, (0..prompt_tokens.len() as i32).collect::<Vec<_>>());
|
||||
runtime.execute(&cx.dyn_map);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
@@ -113,24 +255,22 @@ fn main() {
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
prev_seq = prompt_tokens.len();
|
||||
let mut prev_seq = default_prompt_tokens.len();
|
||||
let prefill_duration = prefill_start.elapsed();
|
||||
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let row_start = (prompt_tokens.len() - 1) * VOCAB_SIZE;
|
||||
let last_row = &logits_data[row_start..row_start + VOCAB_SIZE];
|
||||
let mut next_token = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
let row_start = (default_prompt_tokens.len() - 1) * VOCAB_SIZE;
|
||||
let mut next_token = sample_greedy_with_penalty(
|
||||
&logits_data[row_start..row_start + VOCAB_SIZE],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
generated_token_ids.push(next_token);
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
for _ in 1..gen_tokens {
|
||||
for _ in 1..bench.gen_tokens {
|
||||
let start = std::time::Instant::now();
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
@@ -148,21 +288,11 @@ fn main() {
|
||||
prev_seq += 1;
|
||||
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let mut last_row = logits_data[..VOCAB_SIZE].to_vec();
|
||||
for &tok in &seen_tokens {
|
||||
let logit = &mut last_row[tok as usize];
|
||||
if *logit > 0.0 {
|
||||
*logit /= repetition_penalty;
|
||||
} else {
|
||||
*logit *= repetition_penalty;
|
||||
}
|
||||
}
|
||||
next_token = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
next_token = sample_greedy_with_penalty(
|
||||
&logits_data[..VOCAB_SIZE],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
generated_token_ids.push(next_token);
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
@@ -182,7 +312,7 @@ fn main() {
|
||||
println!(
|
||||
" TTFT: {:.2} ms ({} prompt tokens)",
|
||||
prefill_duration.as_secs_f64() * 1e3,
|
||||
prompt_tokens.len()
|
||||
default_prompt_tokens.len()
|
||||
);
|
||||
if fwd_durations.len() > 1 {
|
||||
println!(
|
||||
|
||||
@@ -14,6 +14,7 @@ luminal = { path = "../.." }
|
||||
luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
luminal_tracing = {path="../../crates/luminal_tracing"}
|
||||
example_common = { path = "../example_common" }
|
||||
tokenizers = "0.15.2"
|
||||
tracing = "0.1.43"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
use example_common::{BenchEnv, info, sample_greedy_with_penalty, stdio};
|
||||
use hf::{WeightFormat, prepare_hf_model};
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
@@ -15,14 +16,18 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
const FP32_REPO_ID: &str = "NousResearch/Meta-Llama-3-8B-Instruct";
|
||||
const FP8_REPO_ID: &str = "nvidia/Llama-3.1-8B-Instruct-FP8";
|
||||
const MAX_SEQ_LEN: usize = 4096;
|
||||
const GEN_TOKENS: usize = 500;
|
||||
const SEARCH_GRAPHS: usize = 500;
|
||||
const DEFAULT_GEN_TOKENS: usize = 500;
|
||||
const DEFAULT_SEARCH_GRAPHS: usize = 500;
|
||||
const STDIO_MAX_PREFILL: usize = 512;
|
||||
const SEARCH_TRIALS: usize = 1;
|
||||
const SEARCH_KEEP_BEST: usize = 4;
|
||||
const SEARCH_MEMORY_MIB: usize = 2048;
|
||||
const SEARCH_SEED: u64 = 0;
|
||||
const PROMPT: &str = "Explain what a neural network is in a paragraph.";
|
||||
|
||||
const EOS_TOKEN: u32 = 128009;
|
||||
const STOP_TOKEN: u32 = 128001;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum LlamaWeightMode {
|
||||
Fp32,
|
||||
@@ -45,20 +50,28 @@ impl LlamaWeightMode {
|
||||
}
|
||||
}
|
||||
|
||||
struct CliArgs {
|
||||
weight_mode: LlamaWeightMode,
|
||||
stdio: bool,
|
||||
}
|
||||
|
||||
fn print_usage(program: &str) {
|
||||
println!("Usage: {program} [--fp8]");
|
||||
println!("Usage: {program} [--fp8] [--stdio]");
|
||||
println!();
|
||||
println!(" --fp8 Use {FP8_REPO_ID} with FP8 projection weights");
|
||||
println!(" --stdio Long-lived stdio benchmark protocol (READY/TOK/EOQ)");
|
||||
println!(" -h,--help Show this help");
|
||||
}
|
||||
|
||||
fn parse_args() -> LlamaWeightMode {
|
||||
let mut mode = LlamaWeightMode::Fp32;
|
||||
fn parse_args() -> CliArgs {
|
||||
let mut weight_mode = LlamaWeightMode::Fp32;
|
||||
let mut stdio = false;
|
||||
let mut args = env::args();
|
||||
let program = args.next().unwrap_or_else(|| "llama".to_string());
|
||||
for arg in args {
|
||||
match arg.as_str() {
|
||||
"--fp8" => mode = LlamaWeightMode::Fp8,
|
||||
"--fp8" => weight_mode = LlamaWeightMode::Fp8,
|
||||
"--stdio" => stdio = true,
|
||||
"-h" | "--help" => {
|
||||
print_usage(&program);
|
||||
std::process::exit(0);
|
||||
@@ -70,7 +83,7 @@ fn parse_args() -> LlamaWeightMode {
|
||||
}
|
||||
}
|
||||
}
|
||||
mode
|
||||
CliArgs { weight_mode, stdio }
|
||||
}
|
||||
|
||||
fn llama3_chat_prompt(user_prompt: &str) -> String {
|
||||
@@ -121,7 +134,10 @@ fn print_profile(label: &str, profile: &StepProfile, n: usize) {
|
||||
);
|
||||
}
|
||||
|
||||
fn print_host_op_summary(runtime: &CudaRuntime, label: &str) {
|
||||
fn print_host_op_summary(runtime: &CudaRuntime, label: &str, quiet: bool) {
|
||||
if quiet {
|
||||
return;
|
||||
}
|
||||
let host_ops = runtime.host_ops();
|
||||
let debug_ops = host_ops
|
||||
.iter()
|
||||
@@ -155,23 +171,6 @@ fn print_host_op_summary(runtime: &CudaRuntime, label: &str) {
|
||||
);
|
||||
}
|
||||
|
||||
fn sample_greedy(logits_row: &[f32], seen: &FxHashSet<u32>, repetition_penalty: f32) -> u32 {
|
||||
let mut row = logits_row.to_vec();
|
||||
for &tok in seen {
|
||||
let logit = &mut row[tok as usize];
|
||||
if *logit > 0.0 {
|
||||
*logit /= repetition_penalty;
|
||||
} else {
|
||||
*logit *= repetition_penalty;
|
||||
}
|
||||
}
|
||||
row.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_model_step(
|
||||
cx: &mut Graph,
|
||||
@@ -238,31 +237,154 @@ fn causal_mask(q_pos: &[usize], context_len: usize) -> Vec<f32> {
|
||||
mask
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_one_prompt(
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
tokenizer: &Tokenizer,
|
||||
input: GraphTensor,
|
||||
q_pos_t: GraphTensor,
|
||||
scatter_idx_t: GraphTensor,
|
||||
gather_idx_t: GraphTensor,
|
||||
attn_mask_t: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
cache_bytes: usize,
|
||||
prompt_tokens: &[u32],
|
||||
gen_tokens: usize,
|
||||
repetition_penalty: f32,
|
||||
emit_tok: &mut dyn FnMut(&str),
|
||||
) -> (usize, Duration) {
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
let prompt_len = prompt_tokens.len();
|
||||
let mut seen_tokens: FxHashSet<u32> = FxHashSet::default();
|
||||
let mut context_len = 0usize;
|
||||
let mut generated = 0usize;
|
||||
let mut next_token: Option<u32> = None;
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
if gen_tokens > 0 && prompt_len > 0 {
|
||||
let positions: Vec<usize> = (0..prompt_len).collect();
|
||||
let q_pos: Vec<i32> = positions.iter().map(|&p| p as i32).collect();
|
||||
let scatter_idx = q_pos.clone();
|
||||
let gather_idx = q_pos.clone();
|
||||
let mask = causal_mask(&positions, prompt_len);
|
||||
let (logits_data, _profile) = run_model_step(
|
||||
cx,
|
||||
runtime,
|
||||
input,
|
||||
q_pos_t,
|
||||
scatter_idx_t,
|
||||
gather_idx_t,
|
||||
attn_mask_t,
|
||||
logits,
|
||||
kv_cache,
|
||||
cache_outputs,
|
||||
prompt_tokens,
|
||||
&q_pos,
|
||||
&scatter_idx,
|
||||
&gather_idx,
|
||||
&mask,
|
||||
);
|
||||
context_len = prompt_len;
|
||||
let token = sample_greedy_with_penalty(
|
||||
&logits_data[logits_data.len() - VOCAB_SIZE..],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
seen_tokens.insert(token);
|
||||
next_token = Some(token);
|
||||
generated = 1;
|
||||
if token != EOS_TOKEN && token != STOP_TOKEN {
|
||||
let decoded = tokenizer.decode(&[token], true).unwrap();
|
||||
emit_tok(&decoded);
|
||||
}
|
||||
}
|
||||
|
||||
while generated < gen_tokens {
|
||||
let current_token = match next_token {
|
||||
Some(token) if token != EOS_TOKEN && token != STOP_TOKEN => token,
|
||||
_ => break,
|
||||
};
|
||||
let (logits_data, _profile) = run_model_step(
|
||||
cx,
|
||||
runtime,
|
||||
input,
|
||||
q_pos_t,
|
||||
scatter_idx_t,
|
||||
gather_idx_t,
|
||||
attn_mask_t,
|
||||
logits,
|
||||
kv_cache,
|
||||
cache_outputs,
|
||||
&[current_token],
|
||||
&[context_len as i32],
|
||||
&[context_len as i32],
|
||||
&(0..=context_len as i32).collect::<Vec<_>>(),
|
||||
&causal_mask(&[context_len], context_len + 1),
|
||||
);
|
||||
context_len += 1;
|
||||
let token = sample_greedy_with_penalty(
|
||||
&logits_data[logits_data.len() - VOCAB_SIZE..],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
seen_tokens.insert(token);
|
||||
next_token = Some(token);
|
||||
generated += 1;
|
||||
if token == EOS_TOKEN || token == STOP_TOKEN {
|
||||
break;
|
||||
}
|
||||
let decoded = tokenizer.decode(&[token], true).unwrap();
|
||||
emit_tok(&decoded);
|
||||
}
|
||||
(generated, start.elapsed())
|
||||
}
|
||||
|
||||
fn main() {
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
|
||||
let weight_mode = parse_args();
|
||||
let cli = parse_args();
|
||||
let stdio_mode = cli.stdio;
|
||||
let weight_mode = cli.weight_mode;
|
||||
|
||||
let bench = BenchEnv::from_env(DEFAULT_GEN_TOKENS, DEFAULT_SEARCH_GRAPHS);
|
||||
let log = |s: &str| info(stdio_mode, s);
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
let prepared = prepare_hf_model(weight_mode.repo_id(), weight_mode.weight_format())
|
||||
.expect("Failed to prepare model");
|
||||
println!("Using model: {}", weight_mode.repo_id());
|
||||
println!("Using model directory: {}", prepared.model_dir.display());
|
||||
log(&format!("Using model: {}", weight_mode.repo_id()));
|
||||
log(&format!(
|
||||
"Using model directory: {}",
|
||||
prepared.model_dir.display()
|
||||
));
|
||||
|
||||
let tokenizer = Tokenizer::from_file(prepared.model_dir.join("tokenizer.json")).unwrap();
|
||||
let chat_prompt = llama3_chat_prompt(PROMPT);
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(chat_prompt.as_str(), false)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let prompt_len = prompt_tokens.len();
|
||||
|
||||
// Build graph
|
||||
let (prompt_tokens_default, prompt_len) = if stdio_mode {
|
||||
(Vec::<u32>::new(), 0usize)
|
||||
} else {
|
||||
let chat_prompt = llama3_chat_prompt(PROMPT);
|
||||
let toks = tokenizer
|
||||
.encode(chat_prompt.as_str(), false)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let len = toks.len();
|
||||
(toks, len)
|
||||
};
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
let q_pos_t = cx.named_tensor("q_pos", 's').as_dtype(DType::Int);
|
||||
@@ -291,24 +413,27 @@ fn main() {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('c', 1);
|
||||
|
||||
println!("Building E-Graph...");
|
||||
log("Building E-Graph...");
|
||||
let egraph_start = std::time::Instant::now();
|
||||
cx.build_search_space_with_options::<CudaRuntime>(
|
||||
BuildSearchSpaceOptions::new().max_memory_mib(SEARCH_MEMORY_MIB),
|
||||
);
|
||||
println!(
|
||||
log(&format!(
|
||||
" E-Graph build: {:.2} s",
|
||||
egraph_start.elapsed().as_secs_f64()
|
||||
);
|
||||
));
|
||||
|
||||
println!("Loading weights...");
|
||||
log("Loading weights...");
|
||||
let load_start = std::time::Instant::now();
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
for weights_path in &prepared.weight_files {
|
||||
println!(" Loading {}", weights_path.display());
|
||||
log(&format!(" Loading {}", weights_path.display()));
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
}
|
||||
println!(" Weight load: {:.2} s", load_start.elapsed().as_secs_f64());
|
||||
log(&format!(
|
||||
" Weight load: {:.2} s",
|
||||
load_start.elapsed().as_secs_f64()
|
||||
));
|
||||
|
||||
let cache_bytes = MAX_SEQ_LEN * KV_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..LAYERS {
|
||||
@@ -316,9 +441,13 @@ fn main() {
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
println!("Compiling...");
|
||||
log("Compiling...");
|
||||
let compile_start = std::time::Instant::now();
|
||||
let max_prefill = (prompt_len + 16).next_power_of_two().min(MAX_SEQ_LEN);
|
||||
let max_prefill = if stdio_mode {
|
||||
STDIO_MAX_PREFILL.min(MAX_SEQ_LEN)
|
||||
} else {
|
||||
(prompt_len + 16).next_power_of_two().min(MAX_SEQ_LEN)
|
||||
};
|
||||
let search_s = 16.min(max_prefill).max(2);
|
||||
cx.set_dim_buckets(
|
||||
's',
|
||||
@@ -334,45 +463,74 @@ fn main() {
|
||||
runtime.set_data(scatter_idx_t, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(gather_idx_t, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(attn_mask_t, vec![0.0f32; search_s * search_s]);
|
||||
println!(" Search seed: {SEARCH_SEED}");
|
||||
println!(" Search trials: {SEARCH_TRIALS}");
|
||||
println!(" Search keep-best: {SEARCH_KEEP_BEST}");
|
||||
log(&format!(" Search seed: {SEARCH_SEED}"));
|
||||
log(&format!(" Search trials: {SEARCH_TRIALS}"));
|
||||
log(&format!(" Search keep-best: {SEARCH_KEEP_BEST}"));
|
||||
let mut rng = StdRng::seed_from_u64(SEARCH_SEED);
|
||||
runtime = cx.search_options(
|
||||
runtime,
|
||||
SearchOptions::new(SEARCH_GRAPHS)
|
||||
SearchOptions::new(bench.search_graphs)
|
||||
.trials(SEARCH_TRIALS)
|
||||
.keep_best(SEARCH_KEEP_BEST),
|
||||
&mut rng,
|
||||
);
|
||||
println!(
|
||||
log(&format!(
|
||||
" Search/compile: {:.2} s",
|
||||
compile_start.elapsed().as_secs_f64()
|
||||
);
|
||||
print_host_op_summary(&runtime, "post-compile active bucket");
|
||||
));
|
||||
print_host_op_summary(&runtime, "post-compile active bucket", stdio_mode);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
let repetition_penalty: f32 = 1.05;
|
||||
|
||||
if stdio_mode {
|
||||
stdio::serve(|user_prompt| {
|
||||
let chat_prompt = llama3_chat_prompt(user_prompt);
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(chat_prompt.as_str(), false)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
run_one_prompt(
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
&tokenizer,
|
||||
input,
|
||||
q_pos_t,
|
||||
scatter_idx_t,
|
||||
gather_idx_t,
|
||||
attn_mask_t,
|
||||
logits,
|
||||
&kv_cache,
|
||||
&cache_outputs,
|
||||
cache_bytes,
|
||||
&prompt_tokens,
|
||||
bench.gen_tokens,
|
||||
repetition_penalty,
|
||||
&mut |s| stdio::emit_tok(s),
|
||||
)
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Non-stdio: legacy single-prompt flow with profiling output.
|
||||
let mut context_len = 0usize;
|
||||
let mut fwd_durations = vec![];
|
||||
let mut step_profiles = vec![];
|
||||
let mut seen_tokens = FxHashSet::default();
|
||||
let repetition_penalty: f32 = 1.05;
|
||||
|
||||
const EOS_TOKEN: u32 = 128009;
|
||||
const STOP_TOKEN: u32 = 128001;
|
||||
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, GEN_TOKENS
|
||||
prompt_len, bench.gen_tokens
|
||||
);
|
||||
|
||||
let mut generated = 0usize;
|
||||
let mut next_token = None;
|
||||
if GEN_TOKENS > 0 && prompt_len > 0 {
|
||||
if bench.gen_tokens > 0 && prompt_len > 0 {
|
||||
let positions: Vec<usize> = (0..prompt_len).collect();
|
||||
let q_pos: Vec<i32> = positions.iter().map(|&p| p as i32).collect();
|
||||
let scatter_idx = q_pos.clone();
|
||||
@@ -389,17 +547,17 @@ fn main() {
|
||||
logits,
|
||||
&kv_cache,
|
||||
&cache_outputs,
|
||||
&prompt_tokens,
|
||||
&prompt_tokens_default,
|
||||
&q_pos,
|
||||
&scatter_idx,
|
||||
&gather_idx,
|
||||
&mask,
|
||||
);
|
||||
print_host_op_summary(&runtime, "after prefill");
|
||||
print_host_op_summary(&runtime, "after prefill", false);
|
||||
context_len = prompt_len;
|
||||
|
||||
let sample_start = std::time::Instant::now();
|
||||
let token = sample_greedy(
|
||||
let token = sample_greedy_with_penalty(
|
||||
&logits_data[logits_data.len() - VOCAB_SIZE..],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
@@ -420,7 +578,7 @@ fn main() {
|
||||
}
|
||||
}
|
||||
|
||||
while generated < GEN_TOKENS {
|
||||
while generated < bench.gen_tokens {
|
||||
let current_token = match next_token {
|
||||
Some(token) if token != EOS_TOKEN && token != STOP_TOKEN => token,
|
||||
_ => break,
|
||||
@@ -444,12 +602,12 @@ fn main() {
|
||||
&causal_mask(&[context_len], context_len + 1),
|
||||
);
|
||||
if generated == 1 {
|
||||
print_host_op_summary(&runtime, "after first decode");
|
||||
print_host_op_summary(&runtime, "after first decode", false);
|
||||
}
|
||||
context_len += 1;
|
||||
|
||||
let sample_start = std::time::Instant::now();
|
||||
let token = sample_greedy(
|
||||
let token = sample_greedy_with_penalty(
|
||||
&logits_data[logits_data.len() - VOCAB_SIZE..],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
|
||||
@@ -18,6 +18,7 @@ luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_tracing = { path = "../../crates/luminal_tracing" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite", optional = true }
|
||||
luminal_metal = { path = "../../crates/luminal_metal", optional = true }
|
||||
example_common = { path = "../example_common" }
|
||||
tokenizers = "0.22.2"
|
||||
tracing = "0.1.43"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
pub mod hf;
|
||||
pub mod model;
|
||||
|
||||
use example_common::{BenchEnv, info, sample_greedy_with_penalty, stdio};
|
||||
use hf::prepare_hf_model;
|
||||
pub use luminal::prelude::Runtime;
|
||||
use luminal::prelude::*;
|
||||
@@ -13,6 +14,7 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
const EOS_TOKEN: u32 = 151645; // <|im_end|>
|
||||
const STOP_TOKEN: u32 = 151643; // <|endoftext|>
|
||||
const STDIO_MAX_PREFILL: usize = 512;
|
||||
|
||||
pub struct QwenRunConfig {
|
||||
pub repo_id: String,
|
||||
@@ -22,6 +24,7 @@ pub struct QwenRunConfig {
|
||||
pub prompt: String,
|
||||
pub repetition_penalty: f32,
|
||||
pub layers: usize,
|
||||
pub stdio: bool,
|
||||
}
|
||||
|
||||
fn qwen3_chat_prompt(user_prompt: &str) -> String {
|
||||
@@ -32,14 +35,16 @@ fn qwen3_chat_prompt(user_prompt: &str) -> String {
|
||||
|
||||
impl Default for QwenRunConfig {
|
||||
fn default() -> Self {
|
||||
let bench = BenchEnv::from_env(500, 500);
|
||||
Self {
|
||||
repo_id: "Qwen/Qwen3-4B".to_string(),
|
||||
max_seq_len: 4096,
|
||||
gen_tokens: 500,
|
||||
search_graphs: 500,
|
||||
gen_tokens: bench.gen_tokens,
|
||||
search_graphs: bench.search_graphs,
|
||||
prompt: "Explain what a neural network is in a paragraph.".to_string(),
|
||||
repetition_penalty: 1.05,
|
||||
layers: LAYERS,
|
||||
stdio: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -119,6 +124,127 @@ impl QwenRuntime for luminal_metal::MetalRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_one_prompt<R: QwenRuntime>(
|
||||
cx: &mut Graph,
|
||||
runtime: &mut R,
|
||||
tokenizer: &Tokenizer,
|
||||
input: GraphTensor,
|
||||
token_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
cache_bytes: usize,
|
||||
layers: usize,
|
||||
prompt_tokens: &[u32],
|
||||
gen_tokens: usize,
|
||||
repetition_penalty: f32,
|
||||
emit_tok: &mut dyn FnMut(&str),
|
||||
) -> Result<(usize, Duration), Box<dyn Error>> {
|
||||
for i in 0..layers {
|
||||
runtime.set_zeros(kv_cache.k_caches[i].id, cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i].id, cache_bytes);
|
||||
}
|
||||
let prompt_len = prompt_tokens.len();
|
||||
let mut prev_seq = 0usize;
|
||||
let mut seen_tokens: FxHashSet<u32> = FxHashSet::default();
|
||||
let mut generated = 0usize;
|
||||
let mut sentence: Vec<u32> = Vec::new();
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
if gen_tokens > 0 && prompt_len > 0 {
|
||||
cx.set_dim('s', prompt_len);
|
||||
cx.set_dim('p', 0);
|
||||
|
||||
runtime.set_i32_data(
|
||||
input.id,
|
||||
prompt_tokens.iter().map(|t| *t as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
runtime.set_i32_data(token_ids.id, (0..prompt_len as i32).collect::<Vec<_>>());
|
||||
runtime.prepare_execute(&cx.dyn_map);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let logits_data = runtime.get_f32(logits.id);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(k_out.id);
|
||||
let v_buf = runtime.remove_buffer(v_out.id);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx].id, k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx].id, v_buf);
|
||||
}
|
||||
prev_seq = prompt_len;
|
||||
|
||||
let row_start = (prompt_len - 1) * VOCAB_SIZE;
|
||||
let next_token = sample_greedy_with_penalty(
|
||||
&logits_data[row_start..row_start + VOCAB_SIZE],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
sentence = vec![next_token];
|
||||
seen_tokens.insert(next_token);
|
||||
generated = 1;
|
||||
|
||||
if next_token != EOS_TOKEN && next_token != STOP_TOKEN {
|
||||
let decoded = tokenizer
|
||||
.decode(&[next_token], true)
|
||||
.map_err(|err| err as Box<dyn Error>)?;
|
||||
emit_tok(&decoded);
|
||||
}
|
||||
}
|
||||
|
||||
while generated < gen_tokens && !sentence.is_empty() {
|
||||
let seq_len = sentence.len();
|
||||
let current_token = sentence[0];
|
||||
|
||||
if current_token == EOS_TOKEN || current_token == STOP_TOKEN {
|
||||
break;
|
||||
}
|
||||
|
||||
cx.set_dim('s', seq_len);
|
||||
cx.set_dim('p', prev_seq);
|
||||
|
||||
runtime.set_i32_data(
|
||||
input.id,
|
||||
sentence.iter().map(|t| *t as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
runtime.set_i32_data(
|
||||
token_ids.id,
|
||||
(prev_seq as i32..(seq_len + prev_seq) as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
runtime.prepare_execute(&cx.dyn_map);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let logits_data = runtime.get_f32(logits.id);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(k_out.id);
|
||||
let v_buf = runtime.remove_buffer(v_out.id);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx].id, k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx].id, v_buf);
|
||||
}
|
||||
|
||||
prev_seq += seq_len;
|
||||
|
||||
let next_token = sample_greedy_with_penalty(
|
||||
&logits_data[logits_data.len() - VOCAB_SIZE..],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
sentence = vec![next_token];
|
||||
seen_tokens.insert(next_token);
|
||||
generated += 1;
|
||||
|
||||
if next_token == EOS_TOKEN || next_token == STOP_TOKEN {
|
||||
break;
|
||||
}
|
||||
|
||||
let decoded = tokenizer
|
||||
.decode(&[next_token], true)
|
||||
.map_err(|err| err as Box<dyn Error>)?;
|
||||
emit_tok(&decoded);
|
||||
}
|
||||
|
||||
Ok((generated, start.elapsed()))
|
||||
}
|
||||
|
||||
pub fn run_qwen<R>(mut runtime: R, config: QwenRunConfig) -> Result<(), Box<dyn Error>>
|
||||
where
|
||||
R: QwenRuntime + 'static,
|
||||
@@ -128,17 +254,27 @@ where
|
||||
.with(luminal_filter())
|
||||
.try_init();
|
||||
|
||||
let stdio_mode = config.stdio;
|
||||
let log = |s: &str| info(stdio_mode, s);
|
||||
|
||||
let model_dir = prepare_hf_model(&config.repo_id)?;
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
log(&format!("Using model directory: {}", model_dir.display()));
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json"))
|
||||
.map_err(|err| err as Box<dyn Error>)?;
|
||||
let prompt = qwen3_chat_prompt(&config.prompt);
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(prompt.as_str(), false)
|
||||
.map_err(|err| err as Box<dyn Error>)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let (default_prompt_tokens, prompt_len) = if stdio_mode {
|
||||
(Vec::<u32>::new(), 0usize)
|
||||
} else {
|
||||
let prompt = qwen3_chat_prompt(&config.prompt);
|
||||
let toks = tokenizer
|
||||
.encode(prompt.as_str(), false)
|
||||
.map_err(|err| err as Box<dyn Error>)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let len = toks.len();
|
||||
(toks, len)
|
||||
};
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
@@ -152,10 +288,10 @@ where
|
||||
v_out.output();
|
||||
}
|
||||
|
||||
println!("Building E-Graph...");
|
||||
log("Building E-Graph...");
|
||||
cx.build_search_space::<R>();
|
||||
|
||||
println!("Loading weights...");
|
||||
log("Loading weights...");
|
||||
let weights_path = model_dir.join("model_combined.safetensors");
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
@@ -165,10 +301,14 @@ where
|
||||
runtime.set_zeros(kv_cache.v_caches[i].id, cache_bytes);
|
||||
}
|
||||
|
||||
println!("Compiling...");
|
||||
let max_prefill = (prompt_tokens.len() + 16)
|
||||
.next_power_of_two()
|
||||
.min(config.max_seq_len);
|
||||
log("Compiling...");
|
||||
let max_prefill = if stdio_mode {
|
||||
STDIO_MAX_PREFILL.min(config.max_seq_len)
|
||||
} else {
|
||||
(prompt_len + 16)
|
||||
.next_power_of_two()
|
||||
.min(config.max_seq_len)
|
||||
};
|
||||
let search_s = 16.min(max_prefill).max(2);
|
||||
cx.set_dim_buckets(
|
||||
's',
|
||||
@@ -188,7 +328,36 @@ where
|
||||
runtime.set_zeros(kv_cache.v_caches[i].id, cache_bytes);
|
||||
}
|
||||
|
||||
let prompt_len = prompt_tokens.len();
|
||||
if stdio_mode {
|
||||
stdio::serve(|user_prompt| {
|
||||
let chat_prompt = qwen3_chat_prompt(user_prompt);
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(chat_prompt.as_str(), false)
|
||||
.expect("tokenize failed")
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
run_one_prompt(
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
&tokenizer,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&kv_cache,
|
||||
&cache_outputs,
|
||||
cache_bytes,
|
||||
config.layers,
|
||||
&prompt_tokens,
|
||||
config.gen_tokens,
|
||||
config.repetition_penalty,
|
||||
&mut |s| stdio::emit_tok(s),
|
||||
)
|
||||
.expect("decode failed")
|
||||
});
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let prompt_tokens = default_prompt_tokens;
|
||||
let mut prev_seq = 0usize;
|
||||
let mut fwd_durations = vec![];
|
||||
let mut seen_tokens = FxHashSet::default();
|
||||
@@ -228,21 +397,11 @@ where
|
||||
fwd_durations.push(start.elapsed());
|
||||
|
||||
let row_start = (prompt_len - 1) * VOCAB_SIZE;
|
||||
let mut last_row = logits_data[row_start..row_start + VOCAB_SIZE].to_vec();
|
||||
for &tok in &seen_tokens {
|
||||
let logit = &mut last_row[tok as usize];
|
||||
if *logit > 0.0 {
|
||||
*logit /= config.repetition_penalty;
|
||||
} else {
|
||||
*logit *= config.repetition_penalty;
|
||||
}
|
||||
}
|
||||
let next_token = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
let next_token = sample_greedy_with_penalty(
|
||||
&logits_data[row_start..row_start + VOCAB_SIZE],
|
||||
&seen_tokens,
|
||||
config.repetition_penalty,
|
||||
);
|
||||
sentence = vec![next_token];
|
||||
seen_tokens.insert(next_token);
|
||||
generated = 1;
|
||||
@@ -291,21 +450,11 @@ where
|
||||
prev_seq += seq_len;
|
||||
fwd_durations.push(start.elapsed());
|
||||
|
||||
let mut last_row = logits_data[logits_data.len() - VOCAB_SIZE..].to_vec();
|
||||
for &tok in &seen_tokens {
|
||||
let logit = &mut last_row[tok as usize];
|
||||
if *logit > 0.0 {
|
||||
*logit /= config.repetition_penalty;
|
||||
} else {
|
||||
*logit *= config.repetition_penalty;
|
||||
}
|
||||
}
|
||||
let next_token = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
let next_token = sample_greedy_with_penalty(
|
||||
&logits_data[logits_data.len() - VOCAB_SIZE..],
|
||||
&seen_tokens,
|
||||
config.repetition_penalty,
|
||||
);
|
||||
sentence = vec![next_token];
|
||||
seen_tokens.insert(next_token);
|
||||
generated += 1;
|
||||
@@ -337,5 +486,6 @@ where
|
||||
);
|
||||
}
|
||||
|
||||
let _ = generated;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -14,16 +14,38 @@ use luminal_metal::MetalRuntime;
|
||||
))]
|
||||
use qwen::{QwenRunConfig, Runtime, run_qwen};
|
||||
|
||||
#[cfg(any(
|
||||
all(feature = "cuda", not(feature = "metal")),
|
||||
all(feature = "metal", not(feature = "cuda"), target_vendor = "apple")
|
||||
))]
|
||||
fn parse_cli() -> QwenRunConfig {
|
||||
let mut cfg = QwenRunConfig::default();
|
||||
for arg in std::env::args().skip(1) {
|
||||
match arg.as_str() {
|
||||
"--stdio" => cfg.stdio = true,
|
||||
"-h" | "--help" => {
|
||||
println!("Usage: qwen [--stdio]");
|
||||
std::process::exit(0);
|
||||
}
|
||||
other => {
|
||||
eprintln!("Unknown argument: {other}");
|
||||
std::process::exit(2);
|
||||
}
|
||||
}
|
||||
}
|
||||
cfg
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "cuda", not(feature = "metal")))]
|
||||
fn main() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
run_qwen(CudaRuntime::initialize(stream), QwenRunConfig::default()).unwrap();
|
||||
run_qwen(CudaRuntime::initialize(stream), parse_cli()).unwrap();
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "metal", not(feature = "cuda"), target_vendor = "apple"))]
|
||||
fn main() {
|
||||
run_qwen(MetalRuntime::initialize(()), QwenRunConfig::default()).unwrap();
|
||||
run_qwen(MetalRuntime::initialize(()), parse_cli()).unwrap();
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "metal", not(feature = "cuda"), not(target_vendor = "apple")))]
|
||||
|
||||
@@ -9,6 +9,7 @@ edition = "2024"
|
||||
luminal = { path = "../.." }
|
||||
luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
example_common = { path = "../example_common" }
|
||||
tokenizers = "0.22.2"
|
||||
rustc-hash = "2"
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
use example_common::{BenchEnv, has_arg, info, sample_greedy_with_penalty, stdio};
|
||||
use hf::prepare_hf_model;
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
@@ -12,6 +13,11 @@ use tokenizers::Tokenizer;
|
||||
|
||||
const REPO_ID: &str = "Qwen/Qwen3-30B-A3B";
|
||||
const SEARCH_SEED: u64 = 0;
|
||||
const STDIO_MAX_PREFILL: usize = 512;
|
||||
const DEFAULT_GEN_TOKENS: usize = 30;
|
||||
const DEFAULT_SEARCH_GRAPHS: usize = 50;
|
||||
const EOS_TOKEN: u32 = 151645; // <|im_end|>
|
||||
const STOP_TOKEN: u32 = 151643; // <|endoftext|>
|
||||
|
||||
fn qwen3_chat_prompt(user_prompt: &str) -> String {
|
||||
format!(
|
||||
@@ -19,27 +25,134 @@ fn qwen3_chat_prompt(user_prompt: &str) -> String {
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_one_prompt(
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
tokenizer: &Tokenizer,
|
||||
input: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
cache_bytes: usize,
|
||||
prompt_tokens: &[u32],
|
||||
gen_tokens: usize,
|
||||
repetition_penalty: f32,
|
||||
emit_tok: &mut dyn FnMut(&str),
|
||||
) -> (usize, Duration) {
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
let prompt_len = prompt_tokens.len();
|
||||
if prompt_len == 0 || gen_tokens == 0 {
|
||||
return (0, Duration::default());
|
||||
}
|
||||
|
||||
let mut seen_tokens: FxHashSet<u32> = FxHashSet::default();
|
||||
let mut generated = 0usize;
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
cx.set_dim('s', prompt_len);
|
||||
cx.set_dim('p', 0);
|
||||
runtime.set_data(
|
||||
input,
|
||||
prompt_tokens.iter().map(|t| *t as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
runtime.set_data(pos_ids, (0..prompt_len as i32).collect::<Vec<_>>());
|
||||
runtime.execute(&cx.dyn_map);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
let mut prev_seq = prompt_len;
|
||||
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let mut next_token = sample_greedy_with_penalty(
|
||||
&logits_data[logits_data.len() - VOCAB_SIZE..],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
seen_tokens.insert(next_token);
|
||||
generated += 1;
|
||||
if next_token != EOS_TOKEN && next_token != STOP_TOKEN {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
emit_tok(&decoded);
|
||||
}
|
||||
|
||||
while generated < gen_tokens {
|
||||
if next_token == EOS_TOKEN || next_token == STOP_TOKEN {
|
||||
break;
|
||||
}
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![next_token as i32]);
|
||||
runtime.set_data(pos_ids, vec![prev_seq as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
|
||||
prev_seq += 1;
|
||||
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
next_token = sample_greedy_with_penalty(
|
||||
&logits_data[..VOCAB_SIZE],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
seen_tokens.insert(next_token);
|
||||
generated += 1;
|
||||
|
||||
if next_token == EOS_TOKEN || next_token == STOP_TOKEN {
|
||||
break;
|
||||
}
|
||||
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
emit_tok(&decoded);
|
||||
}
|
||||
|
||||
(generated, start.elapsed())
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 30;
|
||||
let search_graphs = 50;
|
||||
let bench = BenchEnv::from_env(DEFAULT_GEN_TOKENS, DEFAULT_SEARCH_GRAPHS);
|
||||
let stdio_mode = has_arg("--stdio");
|
||||
let prompt = "What is the capital of France?";
|
||||
|
||||
let log = |s: &str| info(stdio_mode, s);
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
let model_dir = prepare_hf_model(REPO_ID).expect("Failed to prepare model");
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
log(&format!("Using model directory: {}", model_dir.display()));
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let chat_prompt = qwen3_chat_prompt(prompt);
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(chat_prompt.as_str(), false)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
// Build graph
|
||||
let (default_prompt_tokens, prompt_len) = if stdio_mode {
|
||||
(Vec::<u32>::new(), 0usize)
|
||||
} else {
|
||||
let chat_prompt = qwen3_chat_prompt(prompt);
|
||||
let toks = tokenizer
|
||||
.encode(chat_prompt.as_str(), false)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let len = toks.len();
|
||||
(toks, len)
|
||||
};
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
let pos_ids = cx.named_tensor("pos_ids", 's').as_dtype(DType::Int);
|
||||
@@ -51,10 +164,10 @@ fn main() {
|
||||
v_out.output();
|
||||
}
|
||||
|
||||
println!("Building E-Graph...");
|
||||
log("Building E-Graph...");
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
println!("Loading weights...");
|
||||
log("Loading weights...");
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
let weights_path = model_dir.join("model_combined.safetensors");
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
@@ -65,10 +178,12 @@ fn main() {
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
println!("Compiling...");
|
||||
let max_prefill = (prompt_tokens.len() + 16)
|
||||
.next_power_of_two()
|
||||
.min(max_seq_len);
|
||||
log("Compiling...");
|
||||
let max_prefill = if stdio_mode {
|
||||
STDIO_MAX_PREFILL.min(max_seq_len)
|
||||
} else {
|
||||
(prompt_len + 16).next_power_of_two().min(max_seq_len)
|
||||
};
|
||||
let search_s = 16.min(max_prefill).max(2);
|
||||
cx.set_dim_buckets(
|
||||
's',
|
||||
@@ -82,33 +197,64 @@ fn main() {
|
||||
runtime.set_data(input, vec![1; search_s]);
|
||||
runtime.set_data(pos_ids, (0..search_s as i32).collect::<Vec<_>>());
|
||||
let mut rng = SmallRng::seed_from_u64(SEARCH_SEED);
|
||||
runtime = cx.search_options(runtime, SearchOptions::new(search_graphs), &mut rng);
|
||||
runtime = cx.search_options(runtime, SearchOptions::new(bench.search_graphs), &mut rng);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
let repetition_penalty: f32 = 1.05;
|
||||
|
||||
if stdio_mode {
|
||||
stdio::serve(|user_prompt| {
|
||||
let chat_prompt = qwen3_chat_prompt(user_prompt);
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(chat_prompt.as_str(), false)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
run_one_prompt(
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
&tokenizer,
|
||||
input,
|
||||
pos_ids,
|
||||
logits,
|
||||
&kv_cache,
|
||||
&cache_outputs,
|
||||
cache_bytes,
|
||||
&prompt_tokens,
|
||||
bench.gen_tokens,
|
||||
repetition_penalty,
|
||||
&mut |s| stdio::emit_tok(s),
|
||||
)
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Legacy single-prompt flow.
|
||||
println!("Prompt: {prompt}");
|
||||
print!("Response: ");
|
||||
std::io::stdout().flush().unwrap();
|
||||
|
||||
let mut prev_seq: usize;
|
||||
let mut fwd_durations = vec![];
|
||||
let mut seen_tokens = FxHashSet::default();
|
||||
let repetition_penalty: f32 = 1.05;
|
||||
|
||||
const EOS_TOKEN: u32 = 151645; // <|im_end|>
|
||||
const STOP_TOKEN: u32 = 151643; // <|endoftext|>
|
||||
|
||||
let prefill_start = std::time::Instant::now();
|
||||
cx.set_dim('s', prompt_tokens.len());
|
||||
cx.set_dim('s', default_prompt_tokens.len());
|
||||
cx.set_dim('p', 0);
|
||||
runtime.set_data(
|
||||
input,
|
||||
prompt_tokens.iter().map(|t| *t as i32).collect::<Vec<_>>(),
|
||||
default_prompt_tokens
|
||||
.iter()
|
||||
.map(|t| *t as i32)
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
runtime.set_data(
|
||||
pos_ids,
|
||||
(0..default_prompt_tokens.len() as i32).collect::<Vec<_>>(),
|
||||
);
|
||||
runtime.set_data(pos_ids, (0..prompt_tokens.len() as i32).collect::<Vec<_>>());
|
||||
runtime.execute(&cx.dyn_map);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
@@ -117,24 +263,20 @@ fn main() {
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
prev_seq = prompt_tokens.len();
|
||||
let mut prev_seq = default_prompt_tokens.len();
|
||||
let prefill_duration = prefill_start.elapsed();
|
||||
|
||||
// Get logits from the last prompt row and sample first new token
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let last_row = &logits_data[logits_data.len() - VOCAB_SIZE..];
|
||||
let mut next_token = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
let mut next_token = sample_greedy_with_penalty(
|
||||
&logits_data[logits_data.len() - VOCAB_SIZE..],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
// Decode loop
|
||||
for _ in 1..gen_tokens {
|
||||
for _ in 1..bench.gen_tokens {
|
||||
let start = std::time::Instant::now();
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
@@ -142,7 +284,6 @@ fn main() {
|
||||
runtime.set_data(pos_ids, vec![prev_seq as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
|
||||
// Round-trip KV cache
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
@@ -153,21 +294,11 @@ fn main() {
|
||||
prev_seq += 1;
|
||||
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let mut last_row = logits_data[..VOCAB_SIZE].to_vec();
|
||||
for &tok in &seen_tokens {
|
||||
let logit = &mut last_row[tok as usize];
|
||||
if *logit > 0.0 {
|
||||
*logit /= repetition_penalty;
|
||||
} else {
|
||||
*logit *= repetition_penalty;
|
||||
}
|
||||
}
|
||||
next_token = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
next_token = sample_greedy_with_penalty(
|
||||
&logits_data[..VOCAB_SIZE],
|
||||
&seen_tokens,
|
||||
repetition_penalty,
|
||||
);
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
if next_token == EOS_TOKEN || next_token == STOP_TOKEN {
|
||||
@@ -180,11 +311,10 @@ fn main() {
|
||||
}
|
||||
println!();
|
||||
|
||||
// Report benchmarks
|
||||
println!(
|
||||
" TTFT: {:.2} ms ({} prompt tokens)",
|
||||
prefill_duration.as_secs_f64() * 1e3,
|
||||
prompt_tokens.len()
|
||||
default_prompt_tokens.len()
|
||||
);
|
||||
if fwd_durations.len() > 1 {
|
||||
println!(
|
||||
|
||||
Reference in New Issue
Block a user