Compare commits

...

2 Commits

Author SHA1 Message Date
Tucker Morgan
9311c59b4c Extract example_common crate to dedup rust example boilerplate
Pulls escape_tok, env_usize/env_bool/has_arg, info(stdio_mode, msg),
sample_greedy_with_penalty, BenchEnv::from_env, and the whole
stdio::{ready, emit_tok, emit_eoq, next_prompt, serve} protocol loop into
a single examples/example_common crate. The four binaries (llama, qwen,
gemma4_moe, qwen3_moe) now reduce their stdio mode to a 4-line
stdio::serve(|prompt| { ... }) call.

Net: examples shrink by 539 lines, common crate adds 167 — ~370 fewer
lines and one source of truth for the protocol + sampling. Bench numbers
unchanged within noise.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 22:23:18 +00:00
Tucker Morgan
69f21f2a43 Add --stdio benchmark protocol to rust examples
Lets luminal-benchmarks drive each binary as a long-lived subprocess: after
init the binary prints READY, then per stdin prompt emits TOK\t<escaped>\n
per generated token and EOQ\t<n_tokens>\t<elapsed_ms>\n at the end.
GEN_TOKENS and SEARCH_GRAPHS env vars override the per-binary defaults so
sweeps can vary the search budget without rebuilding. In stdio mode all
informational output is routed to stderr so stdout stays clean for the
protocol; the legacy single-prompt path is unchanged.

Applies to llama, qwen, gemma4_moe, qwen3_moe.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 21:31:11 +00:00
11 changed files with 985 additions and 217 deletions

View File

@@ -0,0 +1,7 @@
[package]
name = "example_common"
version = "0.1.0"
edition = "2024"
[dependencies]
rustc-hash = "2"

View File

@@ -0,0 +1,167 @@
//! Shared helpers for the rust example binaries.
//!
//! - `--stdio` arg detection and READY/TOK/EOQ protocol used by the
//! luminal-benchmarks harness to drive a long-lived subprocess.
//! - Env-var parsing for benchmark knobs (`GEN_TOKENS`, `SEARCH_GRAPHS`).
//! - `info` routing — stderr in stdio mode, stdout otherwise.
//! - Greedy sampling with a repetition penalty.
use rustc_hash::FxHashSet;
/// Reads a `usize` benchmark knob from the environment variable `name`.
///
/// Surrounding whitespace in the value is ignored. Returns `default` when the
/// variable is unset, is not valid Unicode, or is empty after trimming.
///
/// A value that is present but does not parse as a `usize` also yields
/// `default`, and a warning is written to stderr so a mistyped knob does not
/// silently change what a benchmark sweep measures. Stderr is used because
/// stdout may be carrying the stdio protocol.
pub fn env_usize(name: &str, default: usize) -> usize {
    let Ok(raw) = std::env::var(name) else {
        return default;
    };
    let value = raw.trim();
    if value.is_empty() {
        return default;
    }
    match value.parse::<usize>() {
        Ok(parsed) => parsed,
        Err(err) => {
            eprintln!("warning: ignoring {name}={raw:?} ({err}); using default {default}");
            default
        }
    }
}
/// Reads a boolean flag from the environment variable `name`.
///
/// Returns `true` when the value, after trimming surrounding whitespace, is
/// `1`, `true`, or `yes` in any ASCII letter case (so `True`, as produced by
/// Python's `str(True)`, is accepted). Every other value, an unset variable,
/// and a non-Unicode value all yield `false`.
pub fn env_bool(name: &str) -> bool {
    const TRUTHY: [&str; 3] = ["1", "true", "yes"];
    std::env::var(name).is_ok_and(|raw| {
        let value = raw.trim();
        TRUTHY
            .iter()
            .any(|accepted| value.eq_ignore_ascii_case(accepted))
    })
}
/// Returns `true` if any process argument is exactly equal to `name`.
///
/// Every argument is considered, including the program name in position 0.
/// Arguments are read with `args_os`, so an unrelated argument that is not
/// valid Unicode cannot make this check panic; such an argument simply never
/// matches.
pub fn has_arg(name: &str) -> bool {
    std::env::args_os().any(|arg| arg == name)
}
/// Prints an informational line on the channel appropriate for the mode.
///
/// With `stdio_mode` set, the message goes to stderr, leaving stdout free for
/// the benchmark protocol; otherwise it goes to stdout.
pub fn info(stdio_mode: bool, msg: impl AsRef<str>) {
    let text = msg.as_ref();
    if !stdio_mode {
        println!("{text}");
        return;
    }
    eprintln!("{text}");
}
/// Picks the index of the largest logit after discounting already-seen tokens.
///
/// Works on a private copy of `logits_row`. For each token id in `seen`, a
/// positive logit is divided by `repetition_penalty` and any other logit is
/// multiplied by it. The index of the maximum under `f32::total_cmp` is then
/// returned; when several entries compare equal, the last one wins.
///
/// # Panics
///
/// Panics if `logits_row` is empty or if `seen` contains a token id that is
/// not a valid index into `logits_row`.
pub fn sample_greedy_with_penalty(
    logits_row: &[f32],
    seen: &FxHashSet<u32>,
    repetition_penalty: f32,
) -> u32 {
    let mut adjusted = logits_row.to_vec();
    for token in seen.iter().copied() {
        let slot = &mut adjusted[token as usize];
        if *slot > 0.0 {
            *slot /= repetition_penalty;
        } else {
            *slot *= repetition_penalty;
        }
    }
    let best = adjusted
        .iter()
        .enumerate()
        .max_by(|lhs, rhs| lhs.1.total_cmp(rhs.1))
        .map(|(index, _)| index);
    best.unwrap() as u32
}
/// Makes a token's text safe to send on a single tab-separated `TOK` line.
///
/// A backslash becomes two backslashes, and tab, line feed and carriage
/// return become a backslash followed by `t`, `n` and `r` respectively. All
/// other characters pass through unchanged. The python side of the harness
/// reverses this mapping.
pub fn escape_tok(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        let replacement = match ch {
            '\\' => "\\\\",
            '\t' => "\\t",
            '\n' => "\\n",
            '\r' => "\\r",
            _ => {
                escaped.push(ch);
                continue;
            }
        };
        escaped.push_str(replacement);
    }
    escaped
}
pub mod stdio {
    //! Line protocol spoken with `luminal-benchmarks/sut/rust.py`: a single
    //! `READY`, then `TOK\t<text>` per generated token and
    //! `EOQ\t<n>\t<elapsed_ms>` to close each prompt.
    use super::escape_tok;
    use std::io::{BufRead, Write};
    use std::time::Duration;

    /// Writes one protocol line to stdout and flushes it right away. Write
    /// and flush errors are intentionally discarded.
    fn send_line(line: std::fmt::Arguments<'_>) {
        let mut stdout = std::io::stdout().lock();
        let _ = writeln!(stdout, "{line}");
        let _ = stdout.flush();
    }

    /// Announces, once, that initialization has finished.
    pub fn ready() {
        send_line(format_args!("READY"));
    }

    /// Reports a single generated token as soon as it is available. The text
    /// is escaped so it always fits on one line.
    pub fn emit_tok(text: &str) {
        send_line(format_args!("TOK\t{}", escape_tok(text)));
    }

    /// Closes the current prompt, reporting how many tokens were produced and
    /// the elapsed time in milliseconds with three decimals. (The harness
    /// uses LoadGen's own timestamps, but the line is still needed to mark
    /// the boundary.)
    pub fn emit_eoq(n_tokens: usize, elapsed: Duration) {
        let elapsed_ms = elapsed.as_secs_f64() * 1e3;
        send_line(format_args!("EOQ\t{}\t{:.3}", n_tokens, elapsed_ms));
    }

    /// Fetches the next prompt from `reader`, using `buf` as scratch space.
    ///
    /// Trailing `\n` and `\r` characters are stripped, and lines that are
    /// empty after stripping are skipped. Returns `None` at end of input or
    /// when reading fails; a read failure is also reported on stderr.
    pub fn next_prompt<R: BufRead>(reader: &mut R, buf: &mut String) -> Option<String> {
        loop {
            buf.clear();
            let bytes_read = match reader.read_line(buf) {
                Ok(count) => count,
                Err(e) => {
                    eprintln!("stdio read error: {e}");
                    return None;
                }
            };
            if bytes_read == 0 {
                return None;
            }
            let line = buf.trim_end_matches('\n').trim_end_matches('\r');
            if line.is_empty() {
                continue;
            }
            return Some(line.to_string());
        }
    }

    /// Runs the whole protocol session: emits READY, then hands every prompt
    /// read from stdin to `run` and follows it with an EOQ line.
    ///
    /// `run` returns `(n_tokens_generated, prompt_elapsed)` and is expected
    /// to have called `emit_tok` for the tokens it produced. The loop ends
    /// when stdin reaches end of input or a read fails.
    pub fn serve(mut run: impl FnMut(&str) -> (usize, Duration)) {
        ready();
        let mut input = std::io::stdin().lock();
        let mut scratch = String::new();
        loop {
            let Some(prompt) = next_prompt(&mut input, &mut scratch) else {
                break;
            };
            let (produced, took) = run(&prompt);
            emit_eoq(produced, took);
        }
    }
}
/// The standard benchmark knobs, resolved from environment variables in one
/// call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BenchEnv {
    /// Value of `GEN_TOKENS`, or the caller-supplied default.
    pub gen_tokens: usize,
    /// Value of `SEARCH_GRAPHS`, or the caller-supplied default.
    pub search_graphs: usize,
}
impl BenchEnv {
    /// Reads `GEN_TOKENS` and `SEARCH_GRAPHS` via [`env_usize`], substituting
    /// `default_gen_tokens` / `default_search_graphs` for whichever variable
    /// does not provide a usable value.
    pub fn from_env(default_gen_tokens: usize, default_search_graphs: usize) -> Self {
        Self {
            gen_tokens: env_usize("GEN_TOKENS", default_gen_tokens),
            search_graphs: env_usize("SEARCH_GRAPHS", default_search_graphs),
        }
    }
}

View File

@@ -9,6 +9,7 @@ edition = "2024"
luminal = { path = "../.." }
luminal_nn = { path = "../../crates/luminal_nn" }
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
example_common = { path = "../example_common" }
tokenizers = "0.22.2"
rustc-hash = "2"

View File

@@ -1,6 +1,7 @@
mod hf;
mod model;
use example_common::{BenchEnv, env_bool, has_arg, info, sample_greedy_with_penalty, stdio};
use hf::prepare_hf_model;
use luminal::prelude::*;
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
@@ -10,32 +11,140 @@ use std::{io::Write, time::Duration};
use tokenizers::Tokenizer;
const REPO_ID: &str = "google/gemma-4-26B-A4B";
const STDIO_MAX_PREFILL: usize = 512;
const DEFAULT_GEN_TOKENS: usize = 30;
const DEFAULT_SEARCH_GRAPHS: usize = 50;
const EOS_TOKEN: u32 = 1;
fn env_bool(name: &str) -> bool {
std::env::var(name)
.ok()
.is_some_and(|s| matches!(s.as_str(), "1" | "true" | "TRUE" | "yes" | "YES"))
/// Runs one prompt through prefill and greedy decode, handing each decoded
/// token to `emit_tok`.
///
/// Every layer's K/V cache is zeroed first, so each call starts from an empty
/// cache. Returns `(generated, elapsed)`. `generated` counts every sampled
/// token, including a terminating `EOS_TOKEN`, which is counted but never
/// passed to `emit_tok`. `elapsed` is measured from just before prefill.
/// When the prompt is empty or `gen_tokens` is 0 the graph is not executed
/// and `(0, Duration::default())` is returned.
///
/// Panics if the tokenizer fails to decode a sampled token.
// NOTE(review): `max_seq_len` is only used here to size the cache zeroing;
// neither `prompt_len` nor `prev_seq` is checked against it or against the
// compiled prefill bucket — confirm callers bound the prompt length.
#[allow(clippy::too_many_arguments)]
fn run_one_prompt(
cx: &mut Graph,
runtime: &mut CudaRuntime,
tokenizer: &Tokenizer,
input: GraphTensor,
pos_ids: GraphTensor,
logits: GraphTensor,
kv_cache: &KVCache,
cache_outputs: &[(GraphTensor, GraphTensor)],
max_seq_len: usize,
prompt_tokens: &[u32],
gen_tokens: usize,
repetition_penalty: f32,
emit_tok: &mut dyn FnMut(&str),
) -> (usize, Duration) {
// Reset the per-layer K/V caches so state from a previous prompt is cleared.
for layer in 0..LAYERS {
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
}
let prompt_len = prompt_tokens.len();
// Nothing to do: report zero tokens and zero elapsed time.
if prompt_len == 0 || gen_tokens == 0 {
return (0, Duration::default());
}
let mut seen_tokens: FxHashSet<u32> = FxHashSet::default();
let mut generated = 0usize;
let start = std::time::Instant::now();
// Prefill: feed the whole prompt at positions 0..prompt_len, with dim 's'
// set to the prompt length and dim 'p' set to 0.
cx.set_dim('s', prompt_len);
cx.set_dim('p', 0);
runtime.set_data(
input,
prompt_tokens.iter().map(|t| *t as i32).collect::<Vec<_>>(),
);
runtime.set_data(pos_ids, (0..prompt_len as i32).collect::<Vec<_>>());
runtime.execute(&cx.dyn_map);
// Move each layer's K/V output buffers into the corresponding cache tensors
// (presumably so the next step reads the updated cache).
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
let k_buf = runtime.remove_buffer(*k_out);
let v_buf = runtime.remove_buffer(*v_out);
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
}
// Number of positions already fed; doubles as the next position id.
let mut prev_seq = prompt_len;
let logits_data = runtime.get_f32(logits);
// Only the logits row of the last prompt position is sampled.
let row_start = (prompt_len - 1) * VOCAB_SIZE;
let mut next_token = sample_greedy_with_penalty(
&logits_data[row_start..row_start + VOCAB_SIZE],
&seen_tokens,
repetition_penalty,
);
seen_tokens.insert(next_token);
// The token is counted before the EOS check, so EOS is included in the count.
generated += 1;
if next_token != EOS_TOKEN {
let decoded = tokenizer.decode(&[next_token], true).unwrap();
emit_tok(&decoded);
}
// Decode: one token per step until the budget is spent or EOS is sampled.
while generated < gen_tokens {
// Catches an EOS sampled during prefill; later iterations break at the
// bottom of the loop instead.
if next_token == EOS_TOKEN {
break;
}
cx.set_dim('s', 1);
cx.set_dim('p', prev_seq);
runtime.set_data(input, vec![next_token as i32]);
runtime.set_data(pos_ids, vec![prev_seq as i32]);
runtime.execute(&cx.dyn_map);
// Same K/V output-to-cache buffer move as after prefill.
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
let k_buf = runtime.remove_buffer(*k_out);
let v_buf = runtime.remove_buffer(*v_out);
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
}
prev_seq += 1;
let logits_data = runtime.get_f32(logits);
// With a single input position the first VOCAB_SIZE values are sampled.
next_token = sample_greedy_with_penalty(
&logits_data[..VOCAB_SIZE],
&seen_tokens,
repetition_penalty,
);
seen_tokens.insert(next_token);
generated += 1;
if next_token == EOS_TOKEN {
break;
}
let decoded = tokenizer.decode(&[next_token], true).unwrap();
emit_tok(&decoded);
}
(generated, start.elapsed())
}
fn main() {
let max_seq_len = 4096;
let gen_tokens = 30;
let search_graphs = 50;
let bench = BenchEnv::from_env(DEFAULT_GEN_TOKENS, DEFAULT_SEARCH_GRAPHS);
let stdio_mode = has_arg("--stdio");
let prompt = std::env::var("PROMPT").unwrap_or_else(|_| "The capital of France is".to_string());
let print_token_ids = env_bool("PRINT_TOKEN_IDS");
let log = |s: &str| info(stdio_mode, s);
let ctx = CudaContext::new(0).unwrap();
let stream = ctx.default_stream();
let model_dir = prepare_hf_model(REPO_ID).expect("Failed to prepare model");
println!("Using model directory: {}", model_dir.display());
log(&format!("Using model directory: {}", model_dir.display()));
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
let prompt_tokens = tokenizer
.encode(prompt.as_str(), true)
.unwrap()
.get_ids()
.to_vec();
let (default_prompt_tokens, prompt_len) = if stdio_mode {
(Vec::<u32>::new(), 0usize)
} else {
let toks = tokenizer
.encode(prompt.as_str(), true)
.unwrap()
.get_ids()
.to_vec();
let len = toks.len();
(toks, len)
};
let mut cx = Graph::default();
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
@@ -48,10 +157,10 @@ fn main() {
v_out.output();
}
println!("Building E-Graph...");
log("Building E-Graph...");
cx.build_search_space::<CudaRuntime>();
println!("Loading weights...");
log("Loading weights...");
let mut runtime = CudaRuntime::initialize(stream);
let weights_path = model_dir.join("model_combined.safetensors");
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
@@ -62,10 +171,12 @@ fn main() {
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
}
println!("Compiling...");
let max_prefill = (prompt_tokens.len() + 16)
.next_power_of_two()
.min(max_seq_len);
log("Compiling...");
let max_prefill = if stdio_mode {
STDIO_MAX_PREFILL.min(max_seq_len)
} else {
(prompt_len + 16).next_power_of_two().min(max_seq_len)
};
let search_s = 16.min(max_prefill).max(2);
cx.set_dim_buckets(
's',
@@ -78,7 +189,7 @@ fn main() {
cx.set_dim('p', 0);
runtime.set_data(input, vec![1; search_s]);
runtime.set_data(pos_ids, (0..search_s as i32).collect::<Vec<_>>());
runtime = cx.search(runtime, search_graphs);
runtime = cx.search(runtime, bench.search_graphs);
for layer in 0..LAYERS {
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
@@ -86,25 +197,56 @@ fn main() {
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
}
let repetition_penalty: f32 = 1.05;
if stdio_mode {
stdio::serve(|user_prompt| {
let prompt_tokens = tokenizer
.encode(user_prompt, true)
.unwrap()
.get_ids()
.to_vec();
run_one_prompt(
&mut cx,
&mut runtime,
&tokenizer,
input,
pos_ids,
logits,
&kv_cache,
&cache_outputs,
max_seq_len,
&prompt_tokens,
bench.gen_tokens,
repetition_penalty,
&mut |s| stdio::emit_tok(s),
)
});
return;
}
// Legacy single-prompt flow.
print!("{prompt}");
std::io::stdout().flush().unwrap();
let mut prev_seq: usize;
let mut fwd_durations = vec![];
let mut seen_tokens = FxHashSet::default();
let mut generated_token_ids = vec![];
let repetition_penalty: f32 = 1.05;
const EOS_TOKEN: u32 = 1;
let prefill_start = std::time::Instant::now();
cx.set_dim('s', prompt_tokens.len());
cx.set_dim('s', default_prompt_tokens.len());
cx.set_dim('p', 0);
runtime.set_data(
input,
prompt_tokens.iter().map(|t| *t as i32).collect::<Vec<_>>(),
default_prompt_tokens
.iter()
.map(|t| *t as i32)
.collect::<Vec<_>>(),
);
runtime.set_data(
pos_ids,
(0..default_prompt_tokens.len() as i32).collect::<Vec<_>>(),
);
runtime.set_data(pos_ids, (0..prompt_tokens.len() as i32).collect::<Vec<_>>());
runtime.execute(&cx.dyn_map);
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
@@ -113,24 +255,22 @@ fn main() {
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
}
prev_seq = prompt_tokens.len();
let mut prev_seq = default_prompt_tokens.len();
let prefill_duration = prefill_start.elapsed();
let logits_data = runtime.get_f32(logits);
let row_start = (prompt_tokens.len() - 1) * VOCAB_SIZE;
let last_row = &logits_data[row_start..row_start + VOCAB_SIZE];
let mut next_token = last_row
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.total_cmp(b))
.unwrap()
.0 as u32;
let row_start = (default_prompt_tokens.len() - 1) * VOCAB_SIZE;
let mut next_token = sample_greedy_with_penalty(
&logits_data[row_start..row_start + VOCAB_SIZE],
&seen_tokens,
repetition_penalty,
);
generated_token_ids.push(next_token);
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
std::io::stdout().flush().unwrap();
seen_tokens.insert(next_token);
for _ in 1..gen_tokens {
for _ in 1..bench.gen_tokens {
let start = std::time::Instant::now();
cx.set_dim('s', 1);
cx.set_dim('p', prev_seq);
@@ -148,21 +288,11 @@ fn main() {
prev_seq += 1;
let logits_data = runtime.get_f32(logits);
let mut last_row = logits_data[..VOCAB_SIZE].to_vec();
for &tok in &seen_tokens {
let logit = &mut last_row[tok as usize];
if *logit > 0.0 {
*logit /= repetition_penalty;
} else {
*logit *= repetition_penalty;
}
}
next_token = last_row
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.total_cmp(b))
.unwrap()
.0 as u32;
next_token = sample_greedy_with_penalty(
&logits_data[..VOCAB_SIZE],
&seen_tokens,
repetition_penalty,
);
generated_token_ids.push(next_token);
seen_tokens.insert(next_token);
@@ -182,7 +312,7 @@ fn main() {
println!(
" TTFT: {:.2} ms ({} prompt tokens)",
prefill_duration.as_secs_f64() * 1e3,
prompt_tokens.len()
default_prompt_tokens.len()
);
if fwd_durations.len() > 1 {
println!(

View File

@@ -14,6 +14,7 @@ luminal = { path = "../.." }
luminal_nn = { path = "../../crates/luminal_nn" }
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
luminal_tracing = {path="../../crates/luminal_tracing"}
example_common = { path = "../example_common" }
tokenizers = "0.15.2"
tracing = "0.1.43"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

View File

@@ -1,6 +1,7 @@
mod hf;
mod model;
use example_common::{BenchEnv, info, sample_greedy_with_penalty, stdio};
use hf::{WeightFormat, prepare_hf_model};
use luminal::prelude::*;
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
@@ -15,14 +16,18 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
const FP32_REPO_ID: &str = "NousResearch/Meta-Llama-3-8B-Instruct";
const FP8_REPO_ID: &str = "nvidia/Llama-3.1-8B-Instruct-FP8";
const MAX_SEQ_LEN: usize = 4096;
const GEN_TOKENS: usize = 500;
const SEARCH_GRAPHS: usize = 500;
const DEFAULT_GEN_TOKENS: usize = 500;
const DEFAULT_SEARCH_GRAPHS: usize = 500;
const STDIO_MAX_PREFILL: usize = 512;
const SEARCH_TRIALS: usize = 1;
const SEARCH_KEEP_BEST: usize = 4;
const SEARCH_MEMORY_MIB: usize = 2048;
const SEARCH_SEED: u64 = 0;
const PROMPT: &str = "Explain what a neural network is in a paragraph.";
const EOS_TOKEN: u32 = 128009;
const STOP_TOKEN: u32 = 128001;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LlamaWeightMode {
Fp32,
@@ -45,20 +50,28 @@ impl LlamaWeightMode {
}
}
/// Command-line options collected by `parse_args`.
struct CliArgs {
/// Weight variant to load: `Fp8` when `--fp8` is passed, `Fp32` otherwise.
weight_mode: LlamaWeightMode,
/// `true` when `--stdio` is passed.
stdio: bool,
}
fn print_usage(program: &str) {
println!("Usage: {program} [--fp8]");
println!("Usage: {program} [--fp8] [--stdio]");
println!();
println!(" --fp8 Use {FP8_REPO_ID} with FP8 projection weights");
println!(" --stdio Long-lived stdio benchmark protocol (READY/TOK/EOQ)");
println!(" -h,--help Show this help");
}
fn parse_args() -> LlamaWeightMode {
let mut mode = LlamaWeightMode::Fp32;
fn parse_args() -> CliArgs {
let mut weight_mode = LlamaWeightMode::Fp32;
let mut stdio = false;
let mut args = env::args();
let program = args.next().unwrap_or_else(|| "llama".to_string());
for arg in args {
match arg.as_str() {
"--fp8" => mode = LlamaWeightMode::Fp8,
"--fp8" => weight_mode = LlamaWeightMode::Fp8,
"--stdio" => stdio = true,
"-h" | "--help" => {
print_usage(&program);
std::process::exit(0);
@@ -70,7 +83,7 @@ fn parse_args() -> LlamaWeightMode {
}
}
}
mode
CliArgs { weight_mode, stdio }
}
fn llama3_chat_prompt(user_prompt: &str) -> String {
@@ -121,7 +134,10 @@ fn print_profile(label: &str, profile: &StepProfile, n: usize) {
);
}
fn print_host_op_summary(runtime: &CudaRuntime, label: &str) {
fn print_host_op_summary(runtime: &CudaRuntime, label: &str, quiet: bool) {
if quiet {
return;
}
let host_ops = runtime.host_ops();
let debug_ops = host_ops
.iter()
@@ -155,23 +171,6 @@ fn print_host_op_summary(runtime: &CudaRuntime, label: &str) {
);
}
fn sample_greedy(logits_row: &[f32], seen: &FxHashSet<u32>, repetition_penalty: f32) -> u32 {
let mut row = logits_row.to_vec();
for &tok in seen {
let logit = &mut row[tok as usize];
if *logit > 0.0 {
*logit /= repetition_penalty;
} else {
*logit *= repetition_penalty;
}
}
row.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.total_cmp(b))
.unwrap()
.0 as u32
}
#[allow(clippy::too_many_arguments)]
fn run_model_step(
cx: &mut Graph,
@@ -238,31 +237,154 @@ fn causal_mask(q_pos: &[usize], context_len: usize) -> Vec<f32> {
mask
}
/// Runs one prompt through prefill and greedy decode via `run_model_step`,
/// handing each decoded token to `emit_tok`.
///
/// Every layer's K/V cache is zeroed first, so each call starts from an empty
/// cache. Returns `(generated, elapsed)`. `generated` counts every sampled
/// token, including a terminating `EOS_TOKEN` or `STOP_TOKEN`, which is
/// counted but never passed to `emit_tok`. `elapsed` is measured from after
/// the cache reset. With an empty prompt or `gen_tokens == 0` no model step
/// runs and `generated` is 0.
///
/// Panics if the tokenizer fails to decode a sampled token.
// NOTE(review): `context_len` is not checked here against the cache capacity
// or the compiled prefill bucket — confirm callers bound the prompt length.
#[allow(clippy::too_many_arguments)]
fn run_one_prompt(
cx: &mut Graph,
runtime: &mut CudaRuntime,
tokenizer: &Tokenizer,
input: GraphTensor,
q_pos_t: GraphTensor,
scatter_idx_t: GraphTensor,
gather_idx_t: GraphTensor,
attn_mask_t: GraphTensor,
logits: GraphTensor,
kv_cache: &KVCache,
cache_outputs: &[(GraphTensor, GraphTensor)],
cache_bytes: usize,
prompt_tokens: &[u32],
gen_tokens: usize,
repetition_penalty: f32,
emit_tok: &mut dyn FnMut(&str),
) -> (usize, Duration) {
// Reset the per-layer K/V caches so state from a previous prompt is cleared.
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
let prompt_len = prompt_tokens.len();
let mut seen_tokens: FxHashSet<u32> = FxHashSet::default();
// Number of positions fed so far; also the position of the next token.
let mut context_len = 0usize;
let mut generated = 0usize;
// Stays `None` when prefill is skipped, which makes the decode loop exit
// immediately.
let mut next_token: Option<u32> = None;
let start = std::time::Instant::now();
if gen_tokens > 0 && prompt_len > 0 {
// Prefill: the query, scatter and gather indices are all 0..prompt_len,
// with a causal mask over the prompt.
let positions: Vec<usize> = (0..prompt_len).collect();
let q_pos: Vec<i32> = positions.iter().map(|&p| p as i32).collect();
let scatter_idx = q_pos.clone();
let gather_idx = q_pos.clone();
let mask = causal_mask(&positions, prompt_len);
let (logits_data, _profile) = run_model_step(
cx,
runtime,
input,
q_pos_t,
scatter_idx_t,
gather_idx_t,
attn_mask_t,
logits,
kv_cache,
cache_outputs,
prompt_tokens,
&q_pos,
&scatter_idx,
&gather_idx,
&mask,
);
context_len = prompt_len;
// Sample from the last VOCAB_SIZE logits, i.e. the final row returned.
let token = sample_greedy_with_penalty(
&logits_data[logits_data.len() - VOCAB_SIZE..],
&seen_tokens,
repetition_penalty,
);
seen_tokens.insert(token);
next_token = Some(token);
// The token is counted even when it is a stop token.
generated = 1;
if token != EOS_TOKEN && token != STOP_TOKEN {
let decoded = tokenizer.decode(&[token], true).unwrap();
emit_tok(&decoded);
}
}
// Decode: one token per step until the budget is spent or a stop token is
// sampled.
while generated < gen_tokens {
let current_token = match next_token {
Some(token) if token != EOS_TOKEN && token != STOP_TOKEN => token,
_ => break,
};
// Single-token step at position `context_len`: the query and scatter index
// are that position, the gather indices cover 0..=context_len, and the
// mask spans context_len + 1 positions.
let (logits_data, _profile) = run_model_step(
cx,
runtime,
input,
q_pos_t,
scatter_idx_t,
gather_idx_t,
attn_mask_t,
logits,
kv_cache,
cache_outputs,
&[current_token],
&[context_len as i32],
&[context_len as i32],
&(0..=context_len as i32).collect::<Vec<_>>(),
&causal_mask(&[context_len], context_len + 1),
);
context_len += 1;
let token = sample_greedy_with_penalty(
&logits_data[logits_data.len() - VOCAB_SIZE..],
&seen_tokens,
repetition_penalty,
);
seen_tokens.insert(token);
next_token = Some(token);
generated += 1;
if token == EOS_TOKEN || token == STOP_TOKEN {
break;
}
let decoded = tokenizer.decode(&[token], true).unwrap();
emit_tok(&decoded);
}
(generated, start.elapsed())
}
fn main() {
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.with(luminal_filter())
.init();
let weight_mode = parse_args();
let cli = parse_args();
let stdio_mode = cli.stdio;
let weight_mode = cli.weight_mode;
let bench = BenchEnv::from_env(DEFAULT_GEN_TOKENS, DEFAULT_SEARCH_GRAPHS);
let log = |s: &str| info(stdio_mode, s);
let ctx = CudaContext::new(0).unwrap();
let stream = ctx.default_stream();
let prepared = prepare_hf_model(weight_mode.repo_id(), weight_mode.weight_format())
.expect("Failed to prepare model");
println!("Using model: {}", weight_mode.repo_id());
println!("Using model directory: {}", prepared.model_dir.display());
log(&format!("Using model: {}", weight_mode.repo_id()));
log(&format!(
"Using model directory: {}",
prepared.model_dir.display()
));
let tokenizer = Tokenizer::from_file(prepared.model_dir.join("tokenizer.json")).unwrap();
let chat_prompt = llama3_chat_prompt(PROMPT);
let prompt_tokens = tokenizer
.encode(chat_prompt.as_str(), false)
.unwrap()
.get_ids()
.to_vec();
let prompt_len = prompt_tokens.len();
// Build graph
let (prompt_tokens_default, prompt_len) = if stdio_mode {
(Vec::<u32>::new(), 0usize)
} else {
let chat_prompt = llama3_chat_prompt(PROMPT);
let toks = tokenizer
.encode(chat_prompt.as_str(), false)
.unwrap()
.get_ids()
.to_vec();
let len = toks.len();
(toks, len)
};
let mut cx = Graph::default();
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
let q_pos_t = cx.named_tensor("q_pos", 's').as_dtype(DType::Int);
@@ -291,24 +413,27 @@ fn main() {
cx.set_dim('s', 1);
cx.set_dim('c', 1);
println!("Building E-Graph...");
log("Building E-Graph...");
let egraph_start = std::time::Instant::now();
cx.build_search_space_with_options::<CudaRuntime>(
BuildSearchSpaceOptions::new().max_memory_mib(SEARCH_MEMORY_MIB),
);
println!(
log(&format!(
" E-Graph build: {:.2} s",
egraph_start.elapsed().as_secs_f64()
);
));
println!("Loading weights...");
log("Loading weights...");
let load_start = std::time::Instant::now();
let mut runtime = CudaRuntime::initialize(stream);
for weights_path in &prepared.weight_files {
println!(" Loading {}", weights_path.display());
log(&format!(" Loading {}", weights_path.display()));
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
}
println!(" Weight load: {:.2} s", load_start.elapsed().as_secs_f64());
log(&format!(
" Weight load: {:.2} s",
load_start.elapsed().as_secs_f64()
));
let cache_bytes = MAX_SEQ_LEN * KV_DIM * std::mem::size_of::<f32>();
for i in 0..LAYERS {
@@ -316,9 +441,13 @@ fn main() {
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
println!("Compiling...");
log("Compiling...");
let compile_start = std::time::Instant::now();
let max_prefill = (prompt_len + 16).next_power_of_two().min(MAX_SEQ_LEN);
let max_prefill = if stdio_mode {
STDIO_MAX_PREFILL.min(MAX_SEQ_LEN)
} else {
(prompt_len + 16).next_power_of_two().min(MAX_SEQ_LEN)
};
let search_s = 16.min(max_prefill).max(2);
cx.set_dim_buckets(
's',
@@ -334,45 +463,74 @@ fn main() {
runtime.set_data(scatter_idx_t, (0..search_s as i32).collect::<Vec<_>>());
runtime.set_data(gather_idx_t, (0..search_s as i32).collect::<Vec<_>>());
runtime.set_data(attn_mask_t, vec![0.0f32; search_s * search_s]);
println!(" Search seed: {SEARCH_SEED}");
println!(" Search trials: {SEARCH_TRIALS}");
println!(" Search keep-best: {SEARCH_KEEP_BEST}");
log(&format!(" Search seed: {SEARCH_SEED}"));
log(&format!(" Search trials: {SEARCH_TRIALS}"));
log(&format!(" Search keep-best: {SEARCH_KEEP_BEST}"));
let mut rng = StdRng::seed_from_u64(SEARCH_SEED);
runtime = cx.search_options(
runtime,
SearchOptions::new(SEARCH_GRAPHS)
SearchOptions::new(bench.search_graphs)
.trials(SEARCH_TRIALS)
.keep_best(SEARCH_KEEP_BEST),
&mut rng,
);
println!(
log(&format!(
" Search/compile: {:.2} s",
compile_start.elapsed().as_secs_f64()
);
print_host_op_summary(&runtime, "post-compile active bucket");
));
print_host_op_summary(&runtime, "post-compile active bucket", stdio_mode);
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
let repetition_penalty: f32 = 1.05;
if stdio_mode {
stdio::serve(|user_prompt| {
let chat_prompt = llama3_chat_prompt(user_prompt);
let prompt_tokens = tokenizer
.encode(chat_prompt.as_str(), false)
.unwrap()
.get_ids()
.to_vec();
run_one_prompt(
&mut cx,
&mut runtime,
&tokenizer,
input,
q_pos_t,
scatter_idx_t,
gather_idx_t,
attn_mask_t,
logits,
&kv_cache,
&cache_outputs,
cache_bytes,
&prompt_tokens,
bench.gen_tokens,
repetition_penalty,
&mut |s| stdio::emit_tok(s),
)
});
return;
}
// Non-stdio: legacy single-prompt flow with profiling output.
let mut context_len = 0usize;
let mut fwd_durations = vec![];
let mut step_profiles = vec![];
let mut seen_tokens = FxHashSet::default();
let repetition_penalty: f32 = 1.05;
const EOS_TOKEN: u32 = 128009;
const STOP_TOKEN: u32 = 128001;
println!(
"Prompt: {} tokens, generating up to {} tokens",
prompt_len, GEN_TOKENS
prompt_len, bench.gen_tokens
);
let mut generated = 0usize;
let mut next_token = None;
if GEN_TOKENS > 0 && prompt_len > 0 {
if bench.gen_tokens > 0 && prompt_len > 0 {
let positions: Vec<usize> = (0..prompt_len).collect();
let q_pos: Vec<i32> = positions.iter().map(|&p| p as i32).collect();
let scatter_idx = q_pos.clone();
@@ -389,17 +547,17 @@ fn main() {
logits,
&kv_cache,
&cache_outputs,
&prompt_tokens,
&prompt_tokens_default,
&q_pos,
&scatter_idx,
&gather_idx,
&mask,
);
print_host_op_summary(&runtime, "after prefill");
print_host_op_summary(&runtime, "after prefill", false);
context_len = prompt_len;
let sample_start = std::time::Instant::now();
let token = sample_greedy(
let token = sample_greedy_with_penalty(
&logits_data[logits_data.len() - VOCAB_SIZE..],
&seen_tokens,
repetition_penalty,
@@ -420,7 +578,7 @@ fn main() {
}
}
while generated < GEN_TOKENS {
while generated < bench.gen_tokens {
let current_token = match next_token {
Some(token) if token != EOS_TOKEN && token != STOP_TOKEN => token,
_ => break,
@@ -444,12 +602,12 @@ fn main() {
&causal_mask(&[context_len], context_len + 1),
);
if generated == 1 {
print_host_op_summary(&runtime, "after first decode");
print_host_op_summary(&runtime, "after first decode", false);
}
context_len += 1;
let sample_start = std::time::Instant::now();
let token = sample_greedy(
let token = sample_greedy_with_penalty(
&logits_data[logits_data.len() - VOCAB_SIZE..],
&seen_tokens,
repetition_penalty,

View File

@@ -18,6 +18,7 @@ luminal_nn = { path = "../../crates/luminal_nn" }
luminal_tracing = { path = "../../crates/luminal_tracing" }
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite", optional = true }
luminal_metal = { path = "../../crates/luminal_metal", optional = true }
example_common = { path = "../example_common" }
tokenizers = "0.22.2"
tracing = "0.1.43"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

View File

@@ -1,6 +1,7 @@
pub mod hf;
pub mod model;
use example_common::{BenchEnv, info, sample_greedy_with_penalty, stdio};
use hf::prepare_hf_model;
pub use luminal::prelude::Runtime;
use luminal::prelude::*;
@@ -13,6 +14,7 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
const EOS_TOKEN: u32 = 151645; // <|im_end|>
const STOP_TOKEN: u32 = 151643; // <|endoftext|>
const STDIO_MAX_PREFILL: usize = 512;
pub struct QwenRunConfig {
pub repo_id: String,
@@ -22,6 +24,7 @@ pub struct QwenRunConfig {
pub prompt: String,
pub repetition_penalty: f32,
pub layers: usize,
pub stdio: bool,
}
fn qwen3_chat_prompt(user_prompt: &str) -> String {
@@ -32,14 +35,16 @@ fn qwen3_chat_prompt(user_prompt: &str) -> String {
impl Default for QwenRunConfig {
fn default() -> Self {
let bench = BenchEnv::from_env(500, 500);
Self {
repo_id: "Qwen/Qwen3-4B".to_string(),
max_seq_len: 4096,
gen_tokens: 500,
search_graphs: 500,
gen_tokens: bench.gen_tokens,
search_graphs: bench.search_graphs,
prompt: "Explain what a neural network is in a paragraph.".to_string(),
repetition_penalty: 1.05,
layers: LAYERS,
stdio: false,
}
}
}
@@ -119,6 +124,127 @@ impl QwenRuntime for luminal_metal::MetalRuntime {
}
}
/// Runs one prompt through prefill and greedy decode on any `QwenRuntime`,
/// handing each decoded token to `emit_tok`.
///
/// The K/V caches of the first `layers` layers are zeroed first, so each call
/// starts from an empty cache. On success returns `(generated, elapsed)`.
/// `generated` counts every sampled token, including a terminating
/// `EOS_TOKEN` or `STOP_TOKEN`, which is counted but never passed to
/// `emit_tok`. `elapsed` is measured from after the cache reset. With an
/// empty prompt or `gen_tokens == 0` the graph is not executed and
/// `generated` is 0.
///
/// # Errors
///
/// Returns the tokenizer's error if decoding a sampled token fails.
// NOTE(review): `prev_seq` is not checked here against the cache capacity or
// the compiled prefill bucket — confirm callers bound the prompt length.
#[allow(clippy::too_many_arguments)]
fn run_one_prompt<R: QwenRuntime>(
cx: &mut Graph,
runtime: &mut R,
tokenizer: &Tokenizer,
input: GraphTensor,
token_ids: GraphTensor,
logits: GraphTensor,
kv_cache: &KVCache,
cache_outputs: &[(GraphTensor, GraphTensor)],
cache_bytes: usize,
layers: usize,
prompt_tokens: &[u32],
gen_tokens: usize,
repetition_penalty: f32,
emit_tok: &mut dyn FnMut(&str),
) -> Result<(usize, Duration), Box<dyn Error>> {
// Reset the per-layer K/V caches so state from a previous prompt is cleared.
for i in 0..layers {
runtime.set_zeros(kv_cache.k_caches[i].id, cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i].id, cache_bytes);
}
let prompt_len = prompt_tokens.len();
// Number of positions already fed; used as dim 'p' and as the first
// position id of the next step.
let mut prev_seq = 0usize;
let mut seen_tokens: FxHashSet<u32> = FxHashSet::default();
let mut generated = 0usize;
// Tokens to feed on the next decode step. It stays empty when prefill is
// skipped, which keeps the decode loop from running.
let mut sentence: Vec<u32> = Vec::new();
let start = std::time::Instant::now();
if gen_tokens > 0 && prompt_len > 0 {
// Prefill: feed the whole prompt at positions 0..prompt_len.
cx.set_dim('s', prompt_len);
cx.set_dim('p', 0);
runtime.set_i32_data(
input.id,
prompt_tokens.iter().map(|t| *t as i32).collect::<Vec<_>>(),
);
runtime.set_i32_data(token_ids.id, (0..prompt_len as i32).collect::<Vec<_>>());
runtime.prepare_execute(&cx.dyn_map);
runtime.execute(&cx.dyn_map);
let logits_data = runtime.get_f32(logits.id);
// Move each layer's K/V output buffers into the corresponding cache tensors
// (presumably so the next step reads the updated cache).
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
let k_buf = runtime.remove_buffer(k_out.id);
let v_buf = runtime.remove_buffer(v_out.id);
runtime.set_buffer(kv_cache.k_caches[layer_idx].id, k_buf);
runtime.set_buffer(kv_cache.v_caches[layer_idx].id, v_buf);
}
prev_seq = prompt_len;
// Only the logits row of the last prompt position is sampled.
let row_start = (prompt_len - 1) * VOCAB_SIZE;
let next_token = sample_greedy_with_penalty(
&logits_data[row_start..row_start + VOCAB_SIZE],
&seen_tokens,
repetition_penalty,
);
sentence = vec![next_token];
seen_tokens.insert(next_token);
// The token is counted even when it is a stop token.
generated = 1;
if next_token != EOS_TOKEN && next_token != STOP_TOKEN {
let decoded = tokenizer
.decode(&[next_token], true)
.map_err(|err| err as Box<dyn Error>)?;
emit_tok(&decoded);
}
}
// Decode: feed the pending token(s) until the budget is spent or a stop
// token is sampled.
while generated < gen_tokens && !sentence.is_empty() {
let seq_len = sentence.len();
let current_token = sentence[0];
// Catches a stop token sampled during prefill; later iterations break at
// the bottom of the loop instead.
if current_token == EOS_TOKEN || current_token == STOP_TOKEN {
break;
}
cx.set_dim('s', seq_len);
cx.set_dim('p', prev_seq);
runtime.set_i32_data(
input.id,
sentence.iter().map(|t| *t as i32).collect::<Vec<_>>(),
);
// Position ids continue from prev_seq for each pending token.
runtime.set_i32_data(
token_ids.id,
(prev_seq as i32..(seq_len + prev_seq) as i32).collect::<Vec<_>>(),
);
runtime.prepare_execute(&cx.dyn_map);
runtime.execute(&cx.dyn_map);
let logits_data = runtime.get_f32(logits.id);
// Same K/V output-to-cache buffer move as after prefill.
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
let k_buf = runtime.remove_buffer(k_out.id);
let v_buf = runtime.remove_buffer(v_out.id);
runtime.set_buffer(kv_cache.k_caches[layer_idx].id, k_buf);
runtime.set_buffer(kv_cache.v_caches[layer_idx].id, v_buf);
}
prev_seq += seq_len;
// Sample from the last VOCAB_SIZE logits, i.e. the final row returned.
let next_token = sample_greedy_with_penalty(
&logits_data[logits_data.len() - VOCAB_SIZE..],
&seen_tokens,
repetition_penalty,
);
sentence = vec![next_token];
seen_tokens.insert(next_token);
generated += 1;
if next_token == EOS_TOKEN || next_token == STOP_TOKEN {
break;
}
let decoded = tokenizer
.decode(&[next_token], true)
.map_err(|err| err as Box<dyn Error>)?;
emit_tok(&decoded);
}
Ok((generated, start.elapsed()))
}
pub fn run_qwen<R>(mut runtime: R, config: QwenRunConfig) -> Result<(), Box<dyn Error>>
where
R: QwenRuntime + 'static,
@@ -128,17 +254,27 @@ where
.with(luminal_filter())
.try_init();
let stdio_mode = config.stdio;
let log = |s: &str| info(stdio_mode, s);
let model_dir = prepare_hf_model(&config.repo_id)?;
println!("Using model directory: {}", model_dir.display());
log(&format!("Using model directory: {}", model_dir.display()));
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json"))
.map_err(|err| err as Box<dyn Error>)?;
let prompt = qwen3_chat_prompt(&config.prompt);
let prompt_tokens = tokenizer
.encode(prompt.as_str(), false)
.map_err(|err| err as Box<dyn Error>)?
.get_ids()
.to_vec();
let (default_prompt_tokens, prompt_len) = if stdio_mode {
(Vec::<u32>::new(), 0usize)
} else {
let prompt = qwen3_chat_prompt(&config.prompt);
let toks = tokenizer
.encode(prompt.as_str(), false)
.map_err(|err| err as Box<dyn Error>)?
.get_ids()
.to_vec();
let len = toks.len();
(toks, len)
};
let mut cx = Graph::default();
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
@@ -152,10 +288,10 @@ where
v_out.output();
}
println!("Building E-Graph...");
log("Building E-Graph...");
cx.build_search_space::<R>();
println!("Loading weights...");
log("Loading weights...");
let weights_path = model_dir.join("model_combined.safetensors");
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
@@ -165,10 +301,14 @@ where
runtime.set_zeros(kv_cache.v_caches[i].id, cache_bytes);
}
println!("Compiling...");
let max_prefill = (prompt_tokens.len() + 16)
.next_power_of_two()
.min(config.max_seq_len);
log("Compiling...");
let max_prefill = if stdio_mode {
STDIO_MAX_PREFILL.min(config.max_seq_len)
} else {
(prompt_len + 16)
.next_power_of_two()
.min(config.max_seq_len)
};
let search_s = 16.min(max_prefill).max(2);
cx.set_dim_buckets(
's',
@@ -188,7 +328,36 @@ where
runtime.set_zeros(kv_cache.v_caches[i].id, cache_bytes);
}
let prompt_len = prompt_tokens.len();
if stdio_mode {
stdio::serve(|user_prompt| {
let chat_prompt = qwen3_chat_prompt(user_prompt);
let prompt_tokens = tokenizer
.encode(chat_prompt.as_str(), false)
.expect("tokenize failed")
.get_ids()
.to_vec();
run_one_prompt(
&mut cx,
&mut runtime,
&tokenizer,
input,
token_ids,
logits,
&kv_cache,
&cache_outputs,
cache_bytes,
config.layers,
&prompt_tokens,
config.gen_tokens,
config.repetition_penalty,
&mut |s| stdio::emit_tok(s),
)
.expect("decode failed")
});
return Ok(());
}
let prompt_tokens = default_prompt_tokens;
let mut prev_seq = 0usize;
let mut fwd_durations = vec![];
let mut seen_tokens = FxHashSet::default();
@@ -228,21 +397,11 @@ where
fwd_durations.push(start.elapsed());
let row_start = (prompt_len - 1) * VOCAB_SIZE;
let mut last_row = logits_data[row_start..row_start + VOCAB_SIZE].to_vec();
for &tok in &seen_tokens {
let logit = &mut last_row[tok as usize];
if *logit > 0.0 {
*logit /= config.repetition_penalty;
} else {
*logit *= config.repetition_penalty;
}
}
let next_token = last_row
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.total_cmp(b))
.unwrap()
.0 as u32;
let next_token = sample_greedy_with_penalty(
&logits_data[row_start..row_start + VOCAB_SIZE],
&seen_tokens,
config.repetition_penalty,
);
sentence = vec![next_token];
seen_tokens.insert(next_token);
generated = 1;
@@ -291,21 +450,11 @@ where
prev_seq += seq_len;
fwd_durations.push(start.elapsed());
let mut last_row = logits_data[logits_data.len() - VOCAB_SIZE..].to_vec();
for &tok in &seen_tokens {
let logit = &mut last_row[tok as usize];
if *logit > 0.0 {
*logit /= config.repetition_penalty;
} else {
*logit *= config.repetition_penalty;
}
}
let next_token = last_row
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.total_cmp(b))
.unwrap()
.0 as u32;
let next_token = sample_greedy_with_penalty(
&logits_data[logits_data.len() - VOCAB_SIZE..],
&seen_tokens,
config.repetition_penalty,
);
sentence = vec![next_token];
seen_tokens.insert(next_token);
generated += 1;
@@ -337,5 +486,6 @@ where
);
}
let _ = generated;
Ok(())
}

View File

@@ -14,16 +14,38 @@ use luminal_metal::MetalRuntime;
))]
use qwen::{QwenRunConfig, Runtime, run_qwen};
/// Builds the run configuration from the process arguments.
///
/// Starts from `QwenRunConfig::default()` and walks every argument after the
/// program name:
/// - `--stdio` sets `stdio` on the config,
/// - `-h` / `--help` prints the usage line to stdout and exits with status 0,
/// - anything else is reported on stderr and exits with status 2.
#[cfg(any(
    all(feature = "cuda", not(feature = "metal")),
    all(feature = "metal", not(feature = "cuda"), target_vendor = "apple")
))]
fn parse_cli() -> QwenRunConfig {
    let mut cfg = QwenRunConfig::default();
    for arg in std::env::args().skip(1) {
        if arg == "--stdio" {
            cfg.stdio = true;
        } else if matches!(arg.as_str(), "-h" | "--help") {
            println!("Usage: qwen [--stdio]");
            std::process::exit(0);
        } else {
            // Unrecognised flags are fatal rather than silently ignored.
            eprintln!("Unknown argument: {other}", other = arg);
            std::process::exit(2);
        }
    }
    cfg
}
// Entry point for builds with the `cuda` feature enabled and `metal` disabled:
// creates a CUDA context for device 0, takes its default stream, and hands a
// `CudaRuntime` built on that stream to `run_qwen`.
// NOTE(review): the two `run_qwen` calls below look like the removed/added pair
// of a diff hunk whose +/- markers were lost — confirm that only the
// `parse_cli()` variant belongs in the real file (`stream` cannot be used twice).
#[cfg(all(feature = "cuda", not(feature = "metal")))]
fn main() {
let ctx = CudaContext::new(0).unwrap();
let stream = ctx.default_stream();
run_qwen(CudaRuntime::initialize(stream), QwenRunConfig::default()).unwrap();
run_qwen(CudaRuntime::initialize(stream), parse_cli()).unwrap();
}
// Entry point for Apple builds with the `metal` feature enabled and `cuda`
// disabled: hands a freshly initialized `MetalRuntime` to `run_qwen`.
// NOTE(review): the two `run_qwen` calls below look like the removed/added pair
// of a diff hunk whose +/- markers were lost — confirm that only the
// `parse_cli()` variant belongs in the real file.
#[cfg(all(feature = "metal", not(feature = "cuda"), target_vendor = "apple"))]
fn main() {
run_qwen(MetalRuntime::initialize(()), QwenRunConfig::default()).unwrap();
run_qwen(MetalRuntime::initialize(()), parse_cli()).unwrap();
}
#[cfg(all(feature = "metal", not(feature = "cuda"), not(target_vendor = "apple")))]

View File

@@ -9,6 +9,7 @@ edition = "2024"
luminal = { path = "../.." }
luminal_nn = { path = "../../crates/luminal_nn" }
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
example_common = { path = "../example_common" }
tokenizers = "0.22.2"
rustc-hash = "2"

View File

@@ -1,6 +1,7 @@
mod hf;
mod model;
use example_common::{BenchEnv, has_arg, info, sample_greedy_with_penalty, stdio};
use hf::prepare_hf_model;
use luminal::prelude::*;
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
@@ -12,6 +13,11 @@ use tokenizers::Tokenizer;
// Model repository identifier handed to `prepare_hf_model`.
const REPO_ID: &str = "Qwen/Qwen3-30B-A3B";
// Seed for the `SmallRng` passed to `cx.search_options`.
const SEARCH_SEED: u64 = 0;
// In `--stdio` mode no prompt is known at compile time, so `max_prefill` is
// set to this value (capped by `max_seq_len`) instead of being derived from
// the prompt length.
const STDIO_MAX_PREFILL: usize = 512;
// Per-binary defaults passed to `BenchEnv::from_env`; per the benchmark
// protocol the `GEN_TOKENS` and `SEARCH_GRAPHS` env vars override them.
const DEFAULT_GEN_TOKENS: usize = 30;
const DEFAULT_SEARCH_GRAPHS: usize = 50;
// Generation stops as soon as either of these token ids is sampled.
const EOS_TOKEN: u32 = 151645; // <|im_end|>
const STOP_TOKEN: u32 = 151643; // <|endoftext|>
fn qwen3_chat_prompt(user_prompt: &str) -> String {
format!(
@@ -19,27 +25,134 @@ fn qwen3_chat_prompt(user_prompt: &str) -> String {
)
}
/// Runs one prompt through prefill and greedy decoding, streaming decoded text.
///
/// The KV cache is zeroed first (even when the call then returns early), the
/// whole prompt is executed in a single prefill pass, and tokens are then
/// decoded one at a time with `sample_greedy_with_penalty` until `gen_tokens`
/// tokens have been sampled or `EOS_TOKEN` / `STOP_TOKEN` is produced.
///
/// # Arguments
/// * `cx` – graph whose `'s'` and `'p'` dims are set before every execution.
/// * `runtime` – runtime that receives the inputs and executes the graph.
/// * `tokenizer` – used to decode each sampled token id to text.
/// * `input`, `pos_ids`, `logits` – graph tensors for token ids, position ids
///   and the output logits.
/// * `kv_cache`, `cache_outputs` – per-layer cache inputs and the matching
///   per-layer `(k, v)` outputs that are moved back into them after each pass.
/// * `cache_bytes` – size passed to `set_zeros` for every cache tensor.
/// * `prompt_tokens` – token ids of the already-templated prompt.
/// * `gen_tokens` – upper bound on the number of sampled tokens.
/// * `repetition_penalty` – forwarded to `sample_greedy_with_penalty`.
/// * `emit_tok` – called with the decoded text of every non-stop token.
///
/// # Returns
/// `(generated, elapsed)`: the number of sampled tokens (a terminating stop
/// token is counted but never emitted) and the wall time measured from just
/// before prefill. Returns `(0, Duration::default())` when the prompt is empty
/// or `gen_tokens` is zero.
///
/// # Panics
/// Panics if the tokenizer fails to decode a sampled token.
// The argument list mirrors the graph/runtime handles owned by `main`;
// bundling them into a struct is out of scope for this example binary.
#[allow(clippy::too_many_arguments)]
fn run_one_prompt(
    cx: &mut Graph,
    runtime: &mut CudaRuntime,
    tokenizer: &Tokenizer,
    input: GraphTensor,
    pos_ids: GraphTensor,
    logits: GraphTensor,
    kv_cache: &KVCache,
    cache_outputs: &[(GraphTensor, GraphTensor)],
    cache_bytes: usize,
    prompt_tokens: &[u32],
    gen_tokens: usize,
    repetition_penalty: f32,
    emit_tok: &mut dyn FnMut(&str),
) -> (usize, Duration) {
    // Moves each layer's K/V output buffers into the corresponding KV-cache
    // tensors so the next execution reads them as its cache input.
    fn rotate_kv_buffers(
        runtime: &mut CudaRuntime,
        kv_cache: &KVCache,
        cache_outputs: &[(GraphTensor, GraphTensor)],
    ) {
        for (layer, (k_src, v_src)) in cache_outputs.iter().enumerate() {
            let k_buf = runtime.remove_buffer(*k_src);
            let v_buf = runtime.remove_buffer(*v_src);
            runtime.set_buffer(kv_cache.k_caches[layer], k_buf);
            runtime.set_buffer(kv_cache.v_caches[layer], v_buf);
        }
    }

    // Every prompt starts from an all-zero cache.
    for layer in 0..LAYERS {
        runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
        runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
    }

    let prompt_len = prompt_tokens.len();
    if gen_tokens == 0 || prompt_len == 0 {
        return (0, Duration::default());
    }

    let is_stop = |tok: u32| tok == EOS_TOKEN || tok == STOP_TOKEN;
    let mut emit = |tok: u32| {
        let text = tokenizer.decode(&[tok], true).unwrap();
        emit_tok(&text);
    };
    let mut seen_tokens: FxHashSet<u32> = FxHashSet::default();

    let timer = std::time::Instant::now();

    // Prefill: feed the whole prompt at positions 0..prompt_len.
    cx.set_dim('s', prompt_len);
    cx.set_dim('p', 0);
    let prompt_ids: Vec<i32> = prompt_tokens.iter().map(|&tok| tok as i32).collect();
    runtime.set_data(input, prompt_ids);
    let prompt_positions: Vec<i32> = (0..prompt_len as i32).collect();
    runtime.set_data(pos_ids, prompt_positions);
    runtime.execute(&cx.dyn_map);
    rotate_kv_buffers(runtime, kv_cache, cache_outputs);
    let mut prev_seq = prompt_len;

    // The first new token is sampled from the final VOCAB_SIZE logits.
    let prefill_logits = runtime.get_f32(logits);
    let mut next_token = sample_greedy_with_penalty(
        &prefill_logits[prefill_logits.len() - VOCAB_SIZE..],
        &seen_tokens,
        repetition_penalty,
    );
    seen_tokens.insert(next_token);
    let mut generated = 1usize;
    if is_stop(next_token) {
        return (generated, timer.elapsed());
    }
    emit(next_token);

    // Decode: one token per step, positioned right after what is cached.
    while generated < gen_tokens {
        cx.set_dim('s', 1);
        cx.set_dim('p', prev_seq);
        runtime.set_data(input, vec![next_token as i32]);
        runtime.set_data(pos_ids, vec![prev_seq as i32]);
        runtime.execute(&cx.dyn_map);
        rotate_kv_buffers(runtime, kv_cache, cache_outputs);
        prev_seq += 1;

        let step_logits = runtime.get_f32(logits);
        next_token = sample_greedy_with_penalty(
            &step_logits[..VOCAB_SIZE],
            &seen_tokens,
            repetition_penalty,
        );
        seen_tokens.insert(next_token);
        generated += 1;
        if is_stop(next_token) {
            break;
        }
        emit(next_token);
    }

    (generated, timer.elapsed())
}
fn main() {
let max_seq_len = 4096;
let gen_tokens = 30;
let search_graphs = 50;
let bench = BenchEnv::from_env(DEFAULT_GEN_TOKENS, DEFAULT_SEARCH_GRAPHS);
let stdio_mode = has_arg("--stdio");
let prompt = "What is the capital of France?";
let log = |s: &str| info(stdio_mode, s);
let ctx = CudaContext::new(0).unwrap();
let stream = ctx.default_stream();
let model_dir = prepare_hf_model(REPO_ID).expect("Failed to prepare model");
println!("Using model directory: {}", model_dir.display());
log(&format!("Using model directory: {}", model_dir.display()));
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
let chat_prompt = qwen3_chat_prompt(prompt);
let prompt_tokens = tokenizer
.encode(chat_prompt.as_str(), false)
.unwrap()
.get_ids()
.to_vec();
// Build graph
let (default_prompt_tokens, prompt_len) = if stdio_mode {
(Vec::<u32>::new(), 0usize)
} else {
let chat_prompt = qwen3_chat_prompt(prompt);
let toks = tokenizer
.encode(chat_prompt.as_str(), false)
.unwrap()
.get_ids()
.to_vec();
let len = toks.len();
(toks, len)
};
let mut cx = Graph::default();
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
let pos_ids = cx.named_tensor("pos_ids", 's').as_dtype(DType::Int);
@@ -51,10 +164,10 @@ fn main() {
v_out.output();
}
println!("Building E-Graph...");
log("Building E-Graph...");
cx.build_search_space::<CudaRuntime>();
println!("Loading weights...");
log("Loading weights...");
let mut runtime = CudaRuntime::initialize(stream);
let weights_path = model_dir.join("model_combined.safetensors");
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
@@ -65,10 +178,12 @@ fn main() {
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
println!("Compiling...");
let max_prefill = (prompt_tokens.len() + 16)
.next_power_of_two()
.min(max_seq_len);
log("Compiling...");
let max_prefill = if stdio_mode {
STDIO_MAX_PREFILL.min(max_seq_len)
} else {
(prompt_len + 16).next_power_of_two().min(max_seq_len)
};
let search_s = 16.min(max_prefill).max(2);
cx.set_dim_buckets(
's',
@@ -82,33 +197,64 @@ fn main() {
runtime.set_data(input, vec![1; search_s]);
runtime.set_data(pos_ids, (0..search_s as i32).collect::<Vec<_>>());
let mut rng = SmallRng::seed_from_u64(SEARCH_SEED);
runtime = cx.search_options(runtime, SearchOptions::new(search_graphs), &mut rng);
runtime = cx.search_options(runtime, SearchOptions::new(bench.search_graphs), &mut rng);
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
let repetition_penalty: f32 = 1.05;
if stdio_mode {
stdio::serve(|user_prompt| {
let chat_prompt = qwen3_chat_prompt(user_prompt);
let prompt_tokens = tokenizer
.encode(chat_prompt.as_str(), false)
.unwrap()
.get_ids()
.to_vec();
run_one_prompt(
&mut cx,
&mut runtime,
&tokenizer,
input,
pos_ids,
logits,
&kv_cache,
&cache_outputs,
cache_bytes,
&prompt_tokens,
bench.gen_tokens,
repetition_penalty,
&mut |s| stdio::emit_tok(s),
)
});
return;
}
// Legacy single-prompt flow.
println!("Prompt: {prompt}");
print!("Response: ");
std::io::stdout().flush().unwrap();
let mut prev_seq: usize;
let mut fwd_durations = vec![];
let mut seen_tokens = FxHashSet::default();
let repetition_penalty: f32 = 1.05;
const EOS_TOKEN: u32 = 151645; // <|im_end|>
const STOP_TOKEN: u32 = 151643; // <|endoftext|>
let prefill_start = std::time::Instant::now();
cx.set_dim('s', prompt_tokens.len());
cx.set_dim('s', default_prompt_tokens.len());
cx.set_dim('p', 0);
runtime.set_data(
input,
prompt_tokens.iter().map(|t| *t as i32).collect::<Vec<_>>(),
default_prompt_tokens
.iter()
.map(|t| *t as i32)
.collect::<Vec<_>>(),
);
runtime.set_data(
pos_ids,
(0..default_prompt_tokens.len() as i32).collect::<Vec<_>>(),
);
runtime.set_data(pos_ids, (0..prompt_tokens.len() as i32).collect::<Vec<_>>());
runtime.execute(&cx.dyn_map);
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
@@ -117,24 +263,20 @@ fn main() {
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
}
prev_seq = prompt_tokens.len();
let mut prev_seq = default_prompt_tokens.len();
let prefill_duration = prefill_start.elapsed();
// Get logits from the last prompt row and sample first new token
let logits_data = runtime.get_f32(logits);
let last_row = &logits_data[logits_data.len() - VOCAB_SIZE..];
let mut next_token = last_row
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.total_cmp(b))
.unwrap()
.0 as u32;
let mut next_token = sample_greedy_with_penalty(
&logits_data[logits_data.len() - VOCAB_SIZE..],
&seen_tokens,
repetition_penalty,
);
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
std::io::stdout().flush().unwrap();
seen_tokens.insert(next_token);
// Decode loop
for _ in 1..gen_tokens {
for _ in 1..bench.gen_tokens {
let start = std::time::Instant::now();
cx.set_dim('s', 1);
cx.set_dim('p', prev_seq);
@@ -142,7 +284,6 @@ fn main() {
runtime.set_data(pos_ids, vec![prev_seq as i32]);
runtime.execute(&cx.dyn_map);
// Round-trip KV cache
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
let k_buf = runtime.remove_buffer(*k_out);
let v_buf = runtime.remove_buffer(*v_out);
@@ -153,21 +294,11 @@ fn main() {
prev_seq += 1;
let logits_data = runtime.get_f32(logits);
let mut last_row = logits_data[..VOCAB_SIZE].to_vec();
for &tok in &seen_tokens {
let logit = &mut last_row[tok as usize];
if *logit > 0.0 {
*logit /= repetition_penalty;
} else {
*logit *= repetition_penalty;
}
}
next_token = last_row
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.total_cmp(b))
.unwrap()
.0 as u32;
next_token = sample_greedy_with_penalty(
&logits_data[..VOCAB_SIZE],
&seen_tokens,
repetition_penalty,
);
seen_tokens.insert(next_token);
if next_token == EOS_TOKEN || next_token == STOP_TOKEN {
@@ -180,11 +311,10 @@ fn main() {
}
println!();
// Report benchmarks
println!(
" TTFT: {:.2} ms ({} prompt tokens)",
prefill_duration.as_secs_f64() * 1e3,
prompt_tokens.len()
default_prompt_tokens.len()
);
if fwd_durations.len() > 1 {
println!(