Compare commits

..

2 Commits

Author SHA1 Message Date
Tucker Morgan
0b1e09cf23 Merge remote-tracking branch 'origin/main' into codex/rust-stdio-benchmark 2026-05-14 17:01:19 +00:00
Tucker Morgan
7402503bd4 Add stdio mode to Rust benchmark examples 2026-05-12 21:21:29 +00:00
11 changed files with 480 additions and 617 deletions

View File

@@ -8,7 +8,7 @@ echo "=========================================="
echo " Luminal Python: Full Test Suite"
echo "=========================================="
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py tests/test_input_layout.py tests/test_dtype_boundary.py tests/test_mutation_alias_contract.py"
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py"
CUDA_TESTS="tests/"
# ── Phase 1: Native Backend ─────────────────────────────────

View File

@@ -16,7 +16,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml
echo "Step 3: Running pytest..."
# it is best not to add the full model tests, they end up running billion parameter models
# on the CPU and it takes far to long
uv run pytest tests/test_hlir_ops.py tests/test_unary.py tests/test_input_layout.py tests/test_dtype_boundary.py tests/test_mutation_alias_contract.py -v
uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v
echo ""
echo "=== Tests Complete ==="

View File

@@ -1,6 +1,5 @@
"""CompiledModel wrapper for the Rust CompiledGraph."""
import warnings
from typing import List
import torch
@@ -9,10 +8,6 @@ from .dtype_util import code_to_torch_dtype
from .dtype_util import torch_dtype_code as _torch_dtype_code
class DTypeBoundaryWarning(UserWarning):
"""Warns when the PyTorch boundary must cast input data before execution."""
class CompiledModel:
"""Wrapper around CompiledGraph that handles PyTorch tensor conversion."""
@@ -100,15 +95,6 @@ class CompiledModel:
for name, tensor, expected_dtype in zip(
self._input_names, user_inputs, self._input_dtypes
):
if tensor.dtype != expected_dtype:
warnings.warn(
"Luminal compiled input "
f"'{name}' has dtype {tensor.dtype}, but the compiled graph "
f"expects {expected_dtype}; converting at every call will "
"allocate/copy input data.",
DTypeBoundaryWarning,
stacklevel=2,
)
if self._supports_device_ptrs and tensor.is_cuda:
t = tensor.detach().contiguous().to(expected_dtype)
n_bytes = t.numel() * t.element_size()

View File

@@ -1,215 +0,0 @@
from dataclasses import dataclass
import warnings
from typing import Callable
import pytest
import torch
from luminal import luminal_backend
from luminal.compiled_model import DTypeBoundaryWarning
class BoundaryNoopModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.dtype is torch.bool:
return x | torch.zeros((), dtype=torch.bool, device=x.device)
return x + torch.zeros((), dtype=x.dtype, device=x.device)
@dataclass(frozen=True)
class DTypeCase:
name: str
dtype: torch.dtype
values: Callable[[], torch.Tensor]
xfail_reason: str | None = None
DTYPE_CASES = [
DTypeCase(
"bool",
torch.bool,
lambda: torch.tensor([True, False, True], dtype=torch.bool),
),
DTypeCase(
"uint8",
torch.uint8,
lambda: torch.tensor([0, 127, 255], dtype=torch.uint8),
),
DTypeCase(
"int8",
torch.int8,
lambda: torch.tensor([-128, -1, 127], dtype=torch.int8),
),
DTypeCase(
"int16",
torch.int16,
lambda: torch.tensor([-32768, -1, 32767], dtype=torch.int16),
),
DTypeCase(
"int32",
torch.int32,
lambda: torch.tensor(
[-2147483648, -1, 2147483647],
dtype=torch.int32,
),
),
DTypeCase(
"int64_i32_range",
torch.int64,
lambda: torch.tensor(
[-2147483648, -1, 2147483647],
dtype=torch.int64,
),
),
DTypeCase(
"float16",
torch.float16,
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.float16),
),
DTypeCase(
"bfloat16",
torch.bfloat16,
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.bfloat16),
),
DTypeCase(
"float32",
torch.float32,
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.float32),
),
DTypeCase(
"float64_f32_exact",
torch.float64,
lambda: torch.tensor([1.0, 1.5, float(2**40)], dtype=torch.float64),
),
DTypeCase(
"int64_outside_i32_range",
torch.int64,
lambda: torch.tensor([-(2**40), -1, 2**40], dtype=torch.int64),
xfail_reason=(
"Luminal currently collapses integer inputs through i32 at the "
"compiled boundary, so out-of-range int64 values lose information."
),
),
DTypeCase(
"float64_precision_sensitive",
torch.float64,
lambda: torch.tensor(
[1.0, 1.0000000000000002, float(2**40) + 0.25],
dtype=torch.float64,
),
xfail_reason=(
"Luminal currently routes float64 no-op computation through f32 "
"storage/outputs before restoring the PyTorch-visible dtype."
),
),
]
def _cuda_skip_reason() -> str | None:
if not torch.cuda.is_available():
return "CUDA is not available"
try:
from luminal.luminal import _cuda_lite_factory_capsule
_cuda_lite_factory_capsule()
except (ImportError, AttributeError, RuntimeError) as exc:
return f"luminal_python was not built with CUDA support: {exc}"
return None
@pytest.fixture(params=["cpu", "cuda"], ids=["cpu", "cuda"])
def boundary_device(request) -> torch.device:
device_name = request.param
if device_name == "cuda":
skip_reason = _cuda_skip_reason()
if skip_reason is not None:
pytest.skip(skip_reason)
return torch.device(device_name)
@pytest.mark.parametrize(
"case",
[
pytest.param(
case,
marks=pytest.mark.xfail(reason=case.xfail_reason, strict=True)
if case.xfail_reason is not None
else (),
id=case.name,
)
for case in DTYPE_CASES
],
)
def test_boundary_noop_preserves_dtype_and_values(
boundary_device: torch.device,
case: DTypeCase,
) -> None:
model = BoundaryNoopModel().to(boundary_device)
compiled = torch.compile(model, backend=luminal_backend)
x = case.values().to(boundary_device)
expected = model(x)
actual = compiled(x)
assert isinstance(actual, torch.Tensor)
assert actual.dtype == expected.dtype
assert torch.equal(actual.cpu(), expected.cpu())
@pytest.mark.parametrize(
"case",
[
pytest.param(case, id=case.name)
for case in DTYPE_CASES
if case.name
in {
"uint8",
"int8",
"int16",
"int64_i32_range",
"int64_outside_i32_range",
"float64_f32_exact",
"float64_precision_sensitive",
}
],
)
def test_boundary_warns_when_input_dtype_requires_conversion(
boundary_device: torch.device,
case: DTypeCase,
) -> None:
model = BoundaryNoopModel().to(boundary_device)
compiled = torch.compile(model, backend=luminal_backend)
x = case.values().to(boundary_device)
with pytest.warns(DTypeBoundaryWarning, match="allocate/copy input data"):
compiled(x)
@pytest.mark.parametrize(
"case",
[
pytest.param(case, id=case.name)
for case in DTYPE_CASES
if case.name in {"bool", "int32", "float16", "bfloat16", "float32"}
],
)
def test_boundary_does_not_warn_when_input_dtype_matches_graph(
boundary_device: torch.device,
case: DTypeCase,
) -> None:
model = BoundaryNoopModel().to(boundary_device)
compiled = torch.compile(model, backend=luminal_backend)
x = case.values().to(boundary_device)
with warnings.catch_warnings(record=True) as records:
warnings.simplefilter("always")
compiled(x)
dtype_boundary_warnings = [
record
for record in records
if issubclass(record.category, DTypeBoundaryWarning)
]
assert dtype_boundary_warnings == []

View File

@@ -1,142 +0,0 @@
import torch
import pytest
from luminal import luminal_backend
class StrideSensitiveInputModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer(
"coeff",
torch.tensor([1.0, 10.0, 100.0], dtype=torch.float32),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x @ self.coeff
class TwoInputReadModel(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x * 2.0 + y * 3.0
class ReturnInputModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x
class ReturnInputAndComputedModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return x, x + 1.0
class CloneThenMutateModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = x.clone()
y.add_(1.0)
return y, x * 2.0
def _base_view(device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
base = torch.arange(12, dtype=torch.float32, device=device).reshape(3, 4)
return base, base.t()
def _assert_non_contiguous_storage_alias(base: torch.Tensor, view: torch.Tensor) -> None:
assert not view.is_contiguous()
assert view.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
def _assert_same(actual, expected) -> None:
if isinstance(expected, tuple):
assert isinstance(actual, tuple)
assert len(actual) == len(expected)
for actual_item, expected_item in zip(actual, expected):
_assert_same(actual_item, expected_item)
return
assert torch.allclose(actual, expected)
def _single_non_contiguous_view(device: torch.device):
base, view = _base_view(device)
_assert_non_contiguous_storage_alias(base, view)
return StrideSensitiveInputModel().to(device), (view,), base
def _same_view_twice(device: torch.device):
base, view = _base_view(device)
_assert_non_contiguous_storage_alias(base, view)
return TwoInputReadModel().to(device), (view, view), base
def _overlapping_views(device: torch.device):
base = torch.arange(20, dtype=torch.float32, device=device).reshape(4, 5)
x = base[:3, :4]
y = base[1:, 1:]
assert not x.is_contiguous()
assert not y.is_contiguous()
assert x.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
assert y.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
return TwoInputReadModel().to(device), (x, y), base
def _return_input(device: torch.device):
base, view = _base_view(device)
_assert_non_contiguous_storage_alias(base, view)
return ReturnInputModel().to(device), (view,), base
def _return_input_and_computed(device: torch.device):
base, view = _base_view(device)
_assert_non_contiguous_storage_alias(base, view)
return ReturnInputAndComputedModel().to(device), (view,), base
def _internal_clone_inplace(device: torch.device):
base, view = _base_view(device)
_assert_non_contiguous_storage_alias(base, view)
return CloneThenMutateModel().to(device), (view,), base
@pytest.mark.parametrize(
"make_case",
[
pytest.param(
_single_non_contiguous_view,
id="single_non_contiguous_view_stride_sensitive_read",
),
pytest.param(_same_view_twice, id="same_view_passed_as_two_read_inputs"),
pytest.param(_overlapping_views, id="overlapping_views_as_two_read_inputs"),
pytest.param(_return_input, id="return_input_boundary_value"),
pytest.param(
_return_input_and_computed,
id="return_input_boundary_value_and_computed_value",
),
pytest.param(_internal_clone_inplace, id="inplace_mutation_on_internal_clone"),
],
)
def test_input_boundary_contiguous_materialization_cases(
device: torch.device, make_case
) -> None:
model, inputs, base = make_case(device)
compiled = torch.compile(model, backend=luminal_backend)
base_before = base.clone()
expected = model(*inputs)
actual = compiled(*inputs)
_assert_same(actual, expected)
assert torch.allclose(base, base_before)
def test_non_contiguous_view_input_fails_if_raw_storage_order_is_used(
device: torch.device,
) -> None:
model, (view,), base = _single_non_contiguous_view(device)
wrong_if_storage_order_used = model(base.reshape(view.shape))
expected = model(view)
assert not torch.allclose(wrong_if_storage_order_used, expected)

View File

@@ -1,138 +0,0 @@
"""Regression coverage for torch.compile mutation and alias contracts.
PyTorch backends are expected to preserve the semantics of the traced graph.
After torch.export functionalization, input mutations are represented as
leading mutation outputs before user outputs. Luminal currently treats every
compiled graph output as a user output and also materializes inputs at the
boundary, so caller-visible mutation and aliasing semantics are not preserved.
"""
import pytest
import torch
from luminal import luminal_backend
class MutateInputThenCompute(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x.add_(1.0)
return x * 2.0
class MutateInputReturnAlias(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x.add_(1.0)
return x
class MutateOverlappingInputAlias(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
x.add_(10.0)
return y * 2.0
class ReturnInputView(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.t()
def _assert_same_storage(a: torch.Tensor, b: torch.Tensor) -> None:
assert a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
@pytest.mark.parametrize("backend", ["eager", "aot_eager", "inductor"])
def test_stock_torch_compile_preserves_input_mutation_writeback(backend: str) -> None:
model = MutateInputThenCompute()
expected_input = torch.arange(6, dtype=torch.float32).reshape(2, 3)
actual_input = expected_input.clone()
expected = model(expected_input)
compiled = torch.compile(model, backend=backend)
actual = compiled(actual_input)
assert torch.equal(actual, expected)
assert torch.equal(actual_input, expected_input)
@pytest.mark.parametrize("backend", ["eager", "aot_eager", "inductor"])
def test_stock_torch_compile_preserves_mutated_return_alias(backend: str) -> None:
model = MutateInputReturnAlias()
x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
compiled = torch.compile(model, backend=backend)
out = compiled(x)
assert torch.equal(x, torch.arange(1, 7, dtype=torch.float32).reshape(2, 3))
_assert_same_storage(out, x)
@pytest.mark.parametrize("backend", ["eager", "aot_eager", "inductor"])
def test_stock_torch_compile_preserves_returned_view_alias(backend: str) -> None:
model = ReturnInputView()
x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
compiled = torch.compile(model, backend=backend)
out = compiled(x)
assert torch.equal(out, x.t())
assert out.stride() == (1, 3)
_assert_same_storage(out, x)
@pytest.mark.xfail(
strict=True,
reason=(
"Luminal currently treats functionalized input-mutation outputs as user "
"outputs and does not copy mutation outputs back to caller inputs."
),
)
def test_luminal_input_mutation_writeback_contract(device: torch.device) -> None:
model = MutateInputThenCompute().to(device)
x = torch.arange(6, dtype=torch.float32, device=device).reshape(2, 3)
compiled = torch.compile(model, backend=luminal_backend)
out = compiled(x)
expected_x = torch.arange(1, 7, dtype=torch.float32, device=device).reshape(2, 3)
expected_out = expected_x * 2.0
assert torch.equal(out, expected_out)
assert torch.equal(x, expected_x)
@pytest.mark.xfail(
strict=True,
reason=(
"Luminal does not preserve caller-visible overlapping input aliasing "
"when one aliased input is mutated."
),
)
def test_luminal_overlapping_input_alias_mutation_contract(
device: torch.device,
) -> None:
model = MutateOverlappingInputAlias().to(device)
eager_base = torch.arange(6, dtype=torch.float32, device=device)
expected = model(eager_base[:4], eager_base[1:5])
base = torch.arange(6, dtype=torch.float32, device=device)
compiled = torch.compile(model, backend=luminal_backend)
actual = compiled(base[:4], base[1:5])
assert torch.equal(actual, expected)
assert torch.equal(base, eager_base)
@pytest.mark.xfail(
strict=True,
reason="Luminal materializes returned input views instead of preserving aliasing.",
)
def test_luminal_returned_view_alias_contract(device: torch.device) -> None:
model = ReturnInputView().to(device)
x = torch.arange(6, dtype=torch.float32, device=device).reshape(2, 3)
compiled = torch.compile(model, backend=luminal_backend)
out = compiled(x)
assert torch.equal(out, x.t())
assert out.stride() == (1, 3)
_assert_same_storage(out, x)

View File

@@ -1,3 +1,5 @@
#[path = "../../../examples_common/benchmark_stdio.rs"]
mod benchmark_stdio;
mod hf;
mod model;
@@ -6,18 +8,14 @@ use luminal::prelude::*;
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use model::*;
use rustc_hash::FxHashSet;
use std::{io::Write, time::Duration};
use std::{
io::Write,
time::{Duration, Instant},
};
use tokenizers::Tokenizer;
const REPO_ID: &str = "google/gemma-4-26B-A4B";
fn env_usize(name: &str, default: usize) -> usize {
std::env::var(name)
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(default)
}
fn env_bool(name: &str) -> bool {
std::env::var(name)
.ok()
@@ -25,9 +23,10 @@ fn env_bool(name: &str) -> bool {
}
fn main() {
let max_seq_len = env_usize("MAX_SEQ_LEN", 4096);
let gen_tokens = env_usize("GEN_TOKENS", 30);
let search_graphs = env_usize("SEARCH_GRAPHS", 50);
let stdio = benchmark_stdio::enabled();
let max_seq_len = benchmark_stdio::env_usize("MAX_SEQ_LEN", 4096);
let gen_tokens = benchmark_stdio::env_usize("GEN_TOKENS", 30);
let search_graphs = benchmark_stdio::env_usize("SEARCH_GRAPHS", 50);
let prompt = std::env::var("PROMPT").unwrap_or_else(|_| "The capital of France is".to_string());
let print_token_ids = env_bool("PRINT_TOKEN_IDS");
@@ -38,11 +37,6 @@ fn main() {
println!("Using model directory: {}", model_dir.display());
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
let prompt_tokens = tokenizer
.encode(prompt.as_str(), true)
.unwrap()
.get_ids()
.to_vec();
let mut cx = Graph::default();
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
@@ -63,11 +57,14 @@ fn main() {
let weights_path = model_dir.join("model_combined.safetensors");
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
for layer in 0..LAYERS {
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
}
let reset_cache = |runtime: &mut CudaRuntime| {
for layer in 0..LAYERS {
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
}
};
reset_cache(&mut runtime);
println!("Compiling...");
cx.set_dim('s', 1);
@@ -75,15 +72,66 @@ fn main() {
runtime.set_data(input, vec![1]);
runtime.set_data(pos_ids, vec![1]);
runtime = cx.search(runtime, search_graphs);
reset_cache(&mut runtime);
for layer in 0..LAYERS {
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
if stdio {
benchmark_stdio::serve(|prompt| {
reset_cache(&mut runtime);
run_prompt(
prompt,
gen_tokens,
print_token_ids,
&tokenizer,
&mut cx,
&mut runtime,
input,
pos_ids,
logits,
&cache_outputs,
&kv_cache,
true,
);
});
} else {
run_prompt(
&prompt,
gen_tokens,
print_token_ids,
&tokenizer,
&mut cx,
&mut runtime,
input,
pos_ids,
logits,
&cache_outputs,
&kv_cache,
false,
);
}
}
print!("{prompt}");
std::io::stdout().flush().unwrap();
#[allow(clippy::too_many_arguments)]
fn run_prompt(
prompt: &str,
gen_tokens: usize,
print_token_ids: bool,
tokenizer: &Tokenizer,
cx: &mut Graph,
runtime: &mut CudaRuntime,
input: GraphTensor,
pos_ids: GraphTensor,
logits: GraphTensor,
cache_outputs: &[(GraphTensor, GraphTensor)],
kv_cache: &KVCache,
stdio: bool,
) {
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
let query_start = Instant::now();
if !stdio {
print!("{prompt}");
std::io::stdout().flush().unwrap();
}
let mut prev_seq = 0usize;
let mut fwd_durations = vec![];
@@ -93,7 +141,7 @@ fn main() {
const EOS_TOKEN: u32 = 1;
let prefill_start = std::time::Instant::now();
let prefill_start = Instant::now();
for &token in &prompt_tokens {
cx.set_dim('s', 1);
cx.set_dim('p', prev_seq);
@@ -121,12 +169,26 @@ fn main() {
.unwrap()
.0 as u32;
generated_token_ids.push(next_token);
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
std::io::stdout().flush().unwrap();
let mut generated = 0usize;
if stdio {
if next_token != EOS_TOKEN {
let decoded = tokenizer.decode(&[next_token], true).unwrap();
benchmark_stdio::emit_token(&decoded);
generated += 1;
}
} else {
let decoded = tokenizer.decode(&[next_token], true).unwrap();
print!("{decoded}");
std::io::stdout().flush().unwrap();
}
seen_tokens.insert(next_token);
for _ in 1..gen_tokens {
let start = std::time::Instant::now();
if stdio && next_token == EOS_TOKEN {
break;
}
let start = Instant::now();
cx.set_dim('s', 1);
cx.set_dim('p', prev_seq);
runtime.set_data(input, vec![next_token as i32]);
@@ -165,10 +227,21 @@ fn main() {
break;
}
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
std::io::stdout().flush().unwrap();
let decoded = tokenizer.decode(&[next_token], true).unwrap();
if stdio {
benchmark_stdio::emit_token(&decoded);
} else {
print!("{decoded}");
std::io::stdout().flush().unwrap();
}
generated += 1;
fwd_durations.push(start.elapsed());
}
if stdio {
benchmark_stdio::emit_eoq(generated, query_start);
return;
}
println!();
if print_token_ids {
println!("Generated token ids: {generated_token_ids:?}");

View File

@@ -1,3 +1,5 @@
#[path = "../../../examples_common/benchmark_stdio.rs"]
mod benchmark_stdio;
mod hf;
mod model;
@@ -7,22 +9,36 @@ use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use luminal_tracing::*;
use model::*;
use rustc_hash::FxHashSet;
use std::{io::Write, time::Duration};
use std::{
io::Write,
time::{Duration, Instant},
};
use tokenizers::Tokenizer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
const REPO_ID: &str = "NousResearch/Meta-Llama-3-8B-Instruct";
fn main() {
let stdio = benchmark_stdio::enabled();
let max_seq_len = 4096;
let gen_tokens = 500;
let search_graphs = 500;
let gen_tokens = if stdio {
benchmark_stdio::env_usize("GEN_TOKENS", 500)
} else {
500
};
let search_graphs = if stdio {
benchmark_stdio::env_usize("SEARCH_GRAPHS", 500)
} else {
500
};
let prompt = "Explain what a neural network is in a paragraph.";
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.with(luminal_filter())
.init();
if !stdio {
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.with(luminal_filter())
.init();
}
let ctx = CudaContext::new(0).unwrap();
let stream = ctx.default_stream();
@@ -31,14 +47,6 @@ fn main() {
println!("Using model directory: {}", model_dir.display());
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
let chat_prompt = format!(
"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
);
let prompt_tokens = tokenizer
.encode(chat_prompt.as_str(), true)
.unwrap()
.get_ids()
.to_vec();
// Build graph
let mut cx = Graph::default();
@@ -66,10 +74,13 @@ fn main() {
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
let reset_cache = |runtime: &mut CudaRuntime| {
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
};
reset_cache(&mut runtime);
println!("Compiling...");
cx.set_dim('s', 1);
@@ -77,12 +88,65 @@ fn main() {
runtime.set_data(input, vec![1]);
runtime.set_data(token_ids, vec![1]);
runtime = cx.search(runtime, search_graphs);
reset_cache(&mut runtime);
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
if stdio {
benchmark_stdio::serve(|prompt| {
reset_cache(&mut runtime);
run_prompt(
prompt,
gen_tokens,
&tokenizer,
&mut cx,
&mut runtime,
input,
token_ids,
logits,
&cache_outputs,
&kv_cache,
true,
);
});
} else {
run_prompt(
prompt,
gen_tokens,
&tokenizer,
&mut cx,
&mut runtime,
input,
token_ids,
logits,
&cache_outputs,
&kv_cache,
false,
);
}
}
#[allow(clippy::too_many_arguments)]
fn run_prompt(
prompt: &str,
gen_tokens: usize,
tokenizer: &Tokenizer,
cx: &mut Graph,
runtime: &mut CudaRuntime,
input: GraphTensor,
token_ids: GraphTensor,
logits: GraphTensor,
cache_outputs: &[(GraphTensor, GraphTensor)],
kv_cache: &KVCache,
stdio: bool,
) {
let chat_prompt = format!(
"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
);
let prompt_tokens = tokenizer
.encode(chat_prompt.as_str(), true)
.unwrap()
.get_ids()
.to_vec();
let query_start = Instant::now();
let mut prev_seq = 1usize;
let mut sentence = vec![prompt_tokens[0]];
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
@@ -94,13 +158,16 @@ fn main() {
const EOS_TOKEN: u32 = 128009;
const STOP_TOKEN: u32 = 128001;
println!(
"Prompt: {} tokens, generating up to {} tokens",
prompt_len, gen_tokens
);
if !stdio {
println!(
"Prompt: {} tokens, generating up to {} tokens",
prompt_len, gen_tokens
);
}
let mut generated = 0usize;
for i in 0..total_steps {
let start = std::time::Instant::now();
let start = Instant::now();
let is_prefill = i < prompt_len - 1;
let seq_len = sentence.len();
@@ -159,12 +226,21 @@ fn main() {
}
let decoded = tokenizer.decode(&[next_token], true).unwrap();
print!("{}", decoded);
std::io::stdout().flush().unwrap();
if stdio {
benchmark_stdio::emit_token(&decoded);
} else {
print!("{}", decoded);
std::io::stdout().flush().unwrap();
}
generated += 1;
}
if stdio {
benchmark_stdio::emit_eoq(generated, query_start);
return;
}
println!();
// Benchmarks
println!();
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
if decode_durations.len() > 2 {
println!(

View File

@@ -1,3 +1,5 @@
#[path = "../../../examples_common/benchmark_stdio.rs"]
mod benchmark_stdio;
mod hf;
mod model;
@@ -7,22 +9,36 @@ use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use luminal_tracing::*;
use model::*;
use rustc_hash::FxHashSet;
use std::{io::Write, time::Duration};
use std::{
io::Write,
time::{Duration, Instant},
};
use tokenizers::Tokenizer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
const REPO_ID: &str = "Qwen/Qwen3-4B";
fn main() {
let stdio = benchmark_stdio::enabled();
let max_seq_len = 4096;
let gen_tokens = 500;
let search_graphs = 500;
let gen_tokens = if stdio {
benchmark_stdio::env_usize("GEN_TOKENS", 500)
} else {
500
};
let search_graphs = if stdio {
benchmark_stdio::env_usize("SEARCH_GRAPHS", 500)
} else {
500
};
let prompt = "Explain what a neural network is in a paragraph.";
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.with(luminal_filter())
.init();
if !stdio {
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.with(luminal_filter())
.init();
}
let ctx = CudaContext::new(0).unwrap();
let stream = ctx.default_stream();
@@ -31,7 +47,6 @@ fn main() {
println!("Using model directory: {}", model_dir.display());
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
// Build graph
let mut cx = Graph::default();
@@ -54,10 +69,13 @@ fn main() {
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
let reset_cache = |runtime: &mut CudaRuntime| {
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
};
reset_cache(&mut runtime);
println!("Compiling...");
cx.set_dim('s', 1);
@@ -65,12 +83,58 @@ fn main() {
runtime.set_data(input, vec![1]);
runtime.set_data(token_ids, vec![1]);
runtime = cx.search(runtime, search_graphs);
reset_cache(&mut runtime);
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
if stdio {
benchmark_stdio::serve(|prompt| {
reset_cache(&mut runtime);
run_prompt(
prompt,
gen_tokens,
&tokenizer,
&mut cx,
&mut runtime,
input,
token_ids,
logits,
&cache_outputs,
&kv_cache,
true,
);
});
} else {
run_prompt(
prompt,
gen_tokens,
&tokenizer,
&mut cx,
&mut runtime,
input,
token_ids,
logits,
&cache_outputs,
&kv_cache,
false,
);
}
}
#[allow(clippy::too_many_arguments)]
fn run_prompt(
prompt: &str,
gen_tokens: usize,
tokenizer: &Tokenizer,
cx: &mut Graph,
runtime: &mut CudaRuntime,
input: GraphTensor,
token_ids: GraphTensor,
logits: GraphTensor,
cache_outputs: &[(GraphTensor, GraphTensor)],
kv_cache: &KVCache,
stdio: bool,
) {
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
let query_start = Instant::now();
let mut prev_seq = 1usize;
let mut sentence = vec![prompt_tokens[0]];
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
@@ -82,13 +146,16 @@ fn main() {
const EOS_TOKEN: u32 = 151645; // <|endoftext|>
const STOP_TOKEN: u32 = 151643; // <|end|>
println!(
"Prompt: {} tokens, generating up to {} tokens",
prompt_len, gen_tokens
);
if !stdio {
println!(
"Prompt: {} tokens, generating up to {} tokens",
prompt_len, gen_tokens
);
}
let mut generated = 0usize;
for i in 0..total_steps {
let start = std::time::Instant::now();
let start = Instant::now();
let is_prefill = i < prompt_len - 1;
let seq_len = sentence.len();
@@ -147,12 +214,21 @@ fn main() {
}
let decoded = tokenizer.decode(&[next_token], true).unwrap();
print!("{}", decoded);
std::io::stdout().flush().unwrap();
if stdio {
benchmark_stdio::emit_token(&decoded);
} else {
print!("{}", decoded);
std::io::stdout().flush().unwrap();
}
generated += 1;
}
if stdio {
benchmark_stdio::emit_eoq(generated, query_start);
return;
}
println!();
// Benchmarks
println!();
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
if decode_durations.len() > 2 {
println!(

View File

@@ -1,3 +1,5 @@
#[path = "../../../examples_common/benchmark_stdio.rs"]
mod benchmark_stdio;
mod hf;
mod model;
@@ -6,15 +8,27 @@ use luminal::prelude::*;
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use model::*;
use rustc_hash::FxHashSet;
use std::{io::Write, time::Duration};
use std::{
io::Write,
time::{Duration, Instant},
};
use tokenizers::Tokenizer;
const REPO_ID: &str = "Qwen/Qwen3-30B-A3B";
fn main() {
let stdio = benchmark_stdio::enabled();
let max_seq_len = 4096;
let gen_tokens = 30;
let search_graphs = 50;
let gen_tokens = if stdio {
benchmark_stdio::env_usize("GEN_TOKENS", 30)
} else {
30
};
let search_graphs = if stdio {
benchmark_stdio::env_usize("SEARCH_GRAPHS", 50)
} else {
50
};
let prompt = "The capital of France is";
let ctx = CudaContext::new(0).unwrap();
@@ -24,7 +38,6 @@ fn main() {
println!("Using model directory: {}", model_dir.display());
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
// Build graph
let mut cx = Graph::default();
@@ -47,10 +60,13 @@ fn main() {
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
let reset_cache = |runtime: &mut CudaRuntime| {
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
};
reset_cache(&mut runtime);
println!("Compiling...");
cx.set_dim('s', 1);
@@ -58,14 +74,63 @@ fn main() {
runtime.set_data(input, vec![1]);
runtime.set_data(pos_ids, vec![1]);
runtime = cx.search(runtime, search_graphs);
reset_cache(&mut runtime);
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
if stdio {
benchmark_stdio::serve(|prompt| {
reset_cache(&mut runtime);
run_prompt(
prompt,
gen_tokens,
&tokenizer,
&mut cx,
&mut runtime,
input,
pos_ids,
logits,
&cache_outputs,
&kv_cache,
true,
);
});
} else {
run_prompt(
prompt,
gen_tokens,
&tokenizer,
&mut cx,
&mut runtime,
input,
pos_ids,
logits,
&cache_outputs,
&kv_cache,
false,
);
}
}
print!("{prompt}");
std::io::stdout().flush().unwrap();
#[allow(clippy::too_many_arguments)]
fn run_prompt(
prompt: &str,
gen_tokens: usize,
tokenizer: &Tokenizer,
cx: &mut Graph,
runtime: &mut CudaRuntime,
input: GraphTensor,
pos_ids: GraphTensor,
logits: GraphTensor,
cache_outputs: &[(GraphTensor, GraphTensor)],
kv_cache: &KVCache,
stdio: bool,
) {
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
let query_start = Instant::now();
if !stdio {
print!("{prompt}");
std::io::stdout().flush().unwrap();
}
let mut prev_seq = 0usize;
let mut fwd_durations = vec![];
@@ -76,7 +141,7 @@ fn main() {
const STOP_TOKEN: u32 = 151643;
// Prefill: process prompt tokens one at a time
let prefill_start = std::time::Instant::now();
let prefill_start = Instant::now();
for &token in &prompt_tokens {
cx.set_dim('s', 1);
cx.set_dim('p', prev_seq);
@@ -105,13 +170,27 @@ fn main() {
.max_by(|(_, a), (_, b)| a.total_cmp(b))
.unwrap()
.0 as u32;
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
std::io::stdout().flush().unwrap();
let mut generated = 0usize;
if stdio {
if next_token != EOS_TOKEN && next_token != STOP_TOKEN {
let decoded = tokenizer.decode(&[next_token], true).unwrap();
benchmark_stdio::emit_token(&decoded);
generated += 1;
}
} else {
let decoded = tokenizer.decode(&[next_token], true).unwrap();
print!("{decoded}");
std::io::stdout().flush().unwrap();
}
seen_tokens.insert(next_token);
// Decode loop
for _ in 1..gen_tokens {
let start = std::time::Instant::now();
if stdio && (next_token == EOS_TOKEN || next_token == STOP_TOKEN) {
break;
}
let start = Instant::now();
cx.set_dim('s', 1);
cx.set_dim('p', prev_seq);
runtime.set_data(input, vec![next_token as i32]);
@@ -150,13 +229,23 @@ fn main() {
break;
}
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
std::io::stdout().flush().unwrap();
let decoded = tokenizer.decode(&[next_token], true).unwrap();
if stdio {
benchmark_stdio::emit_token(&decoded);
} else {
print!("{decoded}");
std::io::stdout().flush().unwrap();
}
generated += 1;
fwd_durations.push(start.elapsed());
}
println!();
if stdio {
benchmark_stdio::emit_eoq(generated, query_start);
return;
}
// Report benchmarks
println!();
println!(
" TTFT: {:.2} ms ({} prompt tokens)",
prefill_duration.as_secs_f64() * 1e3,

View File

@@ -0,0 +1,58 @@
use std::{
io::{BufRead, Write},
time::Instant,
};
pub fn enabled() -> bool {
std::env::args().any(|arg| arg == "--stdio")
}
pub fn env_usize(name: &str, default: usize) -> usize {
std::env::var(name)
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(default)
}
fn emit_ready() {
println!("READY");
std::io::stdout().flush().unwrap();
}
pub fn serve(mut f: impl FnMut(&str)) {
emit_ready();
let stdin = std::io::stdin();
for line in stdin.lock().lines() {
let line = line.unwrap();
f(&line);
}
}
pub fn emit_token(token: &str) {
println!("TOK\t{}", escape_token(token));
std::io::stdout().flush().unwrap();
}
pub fn emit_eoq(generated: usize, query_start: Instant) {
println!(
"EOQ\t{}\t{:.3}",
generated,
query_start.elapsed().as_secs_f64() * 1e3
);
std::io::stdout().flush().unwrap();
}
fn escape_token(s: &str) -> String {
let mut out = String::with_capacity(s.len());
for ch in s.chars() {
match ch {
'\\' => out.push_str("\\\\"),
'\t' => out.push_str("\\t"),
'\n' => out.push_str("\\n"),
'\r' => out.push_str("\\r"),
_ => out.push(ch),
}
}
out
}