mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
2 Commits
strided-in
...
codex/rust
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0b1e09cf23 | ||
|
|
7402503bd4 |
@@ -8,7 +8,7 @@ echo "=========================================="
|
||||
echo " Luminal Python: Full Test Suite"
|
||||
echo "=========================================="
|
||||
|
||||
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py tests/test_input_layout.py tests/test_dtype_boundary.py tests/test_mutation_alias_contract.py"
|
||||
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py"
|
||||
CUDA_TESTS="tests/"
|
||||
|
||||
# ── Phase 1: Native Backend ─────────────────────────────────
|
||||
|
||||
@@ -16,7 +16,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
echo "Step 3: Running pytest..."
|
||||
# it is best not to add the full model tests, they end up running billion parameter models
|
||||
# on the CPU and it takes far to long
|
||||
uv run pytest tests/test_hlir_ops.py tests/test_unary.py tests/test_input_layout.py tests/test_dtype_boundary.py tests/test_mutation_alias_contract.py -v
|
||||
uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""CompiledModel wrapper for the Rust CompiledGraph."""
|
||||
|
||||
import warnings
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
@@ -9,10 +8,6 @@ from .dtype_util import code_to_torch_dtype
|
||||
from .dtype_util import torch_dtype_code as _torch_dtype_code
|
||||
|
||||
|
||||
class DTypeBoundaryWarning(UserWarning):
|
||||
"""Warns when the PyTorch boundary must cast input data before execution."""
|
||||
|
||||
|
||||
class CompiledModel:
|
||||
"""Wrapper around CompiledGraph that handles PyTorch tensor conversion."""
|
||||
|
||||
@@ -100,15 +95,6 @@ class CompiledModel:
|
||||
for name, tensor, expected_dtype in zip(
|
||||
self._input_names, user_inputs, self._input_dtypes
|
||||
):
|
||||
if tensor.dtype != expected_dtype:
|
||||
warnings.warn(
|
||||
"Luminal compiled input "
|
||||
f"'{name}' has dtype {tensor.dtype}, but the compiled graph "
|
||||
f"expects {expected_dtype}; converting at every call will "
|
||||
"allocate/copy input data.",
|
||||
DTypeBoundaryWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if self._supports_device_ptrs and tensor.is_cuda:
|
||||
t = tensor.detach().contiguous().to(expected_dtype)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
|
||||
@@ -1,215 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
import warnings
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from luminal import luminal_backend
|
||||
from luminal.compiled_model import DTypeBoundaryWarning
|
||||
|
||||
|
||||
class BoundaryNoopModel(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if x.dtype is torch.bool:
|
||||
return x | torch.zeros((), dtype=torch.bool, device=x.device)
|
||||
return x + torch.zeros((), dtype=x.dtype, device=x.device)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DTypeCase:
|
||||
name: str
|
||||
dtype: torch.dtype
|
||||
values: Callable[[], torch.Tensor]
|
||||
xfail_reason: str | None = None
|
||||
|
||||
|
||||
DTYPE_CASES = [
|
||||
DTypeCase(
|
||||
"bool",
|
||||
torch.bool,
|
||||
lambda: torch.tensor([True, False, True], dtype=torch.bool),
|
||||
),
|
||||
DTypeCase(
|
||||
"uint8",
|
||||
torch.uint8,
|
||||
lambda: torch.tensor([0, 127, 255], dtype=torch.uint8),
|
||||
),
|
||||
DTypeCase(
|
||||
"int8",
|
||||
torch.int8,
|
||||
lambda: torch.tensor([-128, -1, 127], dtype=torch.int8),
|
||||
),
|
||||
DTypeCase(
|
||||
"int16",
|
||||
torch.int16,
|
||||
lambda: torch.tensor([-32768, -1, 32767], dtype=torch.int16),
|
||||
),
|
||||
DTypeCase(
|
||||
"int32",
|
||||
torch.int32,
|
||||
lambda: torch.tensor(
|
||||
[-2147483648, -1, 2147483647],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
),
|
||||
DTypeCase(
|
||||
"int64_i32_range",
|
||||
torch.int64,
|
||||
lambda: torch.tensor(
|
||||
[-2147483648, -1, 2147483647],
|
||||
dtype=torch.int64,
|
||||
),
|
||||
),
|
||||
DTypeCase(
|
||||
"float16",
|
||||
torch.float16,
|
||||
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.float16),
|
||||
),
|
||||
DTypeCase(
|
||||
"bfloat16",
|
||||
torch.bfloat16,
|
||||
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.bfloat16),
|
||||
),
|
||||
DTypeCase(
|
||||
"float32",
|
||||
torch.float32,
|
||||
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.float32),
|
||||
),
|
||||
DTypeCase(
|
||||
"float64_f32_exact",
|
||||
torch.float64,
|
||||
lambda: torch.tensor([1.0, 1.5, float(2**40)], dtype=torch.float64),
|
||||
),
|
||||
DTypeCase(
|
||||
"int64_outside_i32_range",
|
||||
torch.int64,
|
||||
lambda: torch.tensor([-(2**40), -1, 2**40], dtype=torch.int64),
|
||||
xfail_reason=(
|
||||
"Luminal currently collapses integer inputs through i32 at the "
|
||||
"compiled boundary, so out-of-range int64 values lose information."
|
||||
),
|
||||
),
|
||||
DTypeCase(
|
||||
"float64_precision_sensitive",
|
||||
torch.float64,
|
||||
lambda: torch.tensor(
|
||||
[1.0, 1.0000000000000002, float(2**40) + 0.25],
|
||||
dtype=torch.float64,
|
||||
),
|
||||
xfail_reason=(
|
||||
"Luminal currently routes float64 no-op computation through f32 "
|
||||
"storage/outputs before restoring the PyTorch-visible dtype."
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _cuda_skip_reason() -> str | None:
|
||||
if not torch.cuda.is_available():
|
||||
return "CUDA is not available"
|
||||
|
||||
try:
|
||||
from luminal.luminal import _cuda_lite_factory_capsule
|
||||
|
||||
_cuda_lite_factory_capsule()
|
||||
except (ImportError, AttributeError, RuntimeError) as exc:
|
||||
return f"luminal_python was not built with CUDA support: {exc}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture(params=["cpu", "cuda"], ids=["cpu", "cuda"])
|
||||
def boundary_device(request) -> torch.device:
|
||||
device_name = request.param
|
||||
if device_name == "cuda":
|
||||
skip_reason = _cuda_skip_reason()
|
||||
if skip_reason is not None:
|
||||
pytest.skip(skip_reason)
|
||||
return torch.device(device_name)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
pytest.param(
|
||||
case,
|
||||
marks=pytest.mark.xfail(reason=case.xfail_reason, strict=True)
|
||||
if case.xfail_reason is not None
|
||||
else (),
|
||||
id=case.name,
|
||||
)
|
||||
for case in DTYPE_CASES
|
||||
],
|
||||
)
|
||||
def test_boundary_noop_preserves_dtype_and_values(
|
||||
boundary_device: torch.device,
|
||||
case: DTypeCase,
|
||||
) -> None:
|
||||
model = BoundaryNoopModel().to(boundary_device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
|
||||
x = case.values().to(boundary_device)
|
||||
expected = model(x)
|
||||
actual = compiled(x)
|
||||
|
||||
assert isinstance(actual, torch.Tensor)
|
||||
assert actual.dtype == expected.dtype
|
||||
assert torch.equal(actual.cpu(), expected.cpu())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
pytest.param(case, id=case.name)
|
||||
for case in DTYPE_CASES
|
||||
if case.name
|
||||
in {
|
||||
"uint8",
|
||||
"int8",
|
||||
"int16",
|
||||
"int64_i32_range",
|
||||
"int64_outside_i32_range",
|
||||
"float64_f32_exact",
|
||||
"float64_precision_sensitive",
|
||||
}
|
||||
],
|
||||
)
|
||||
def test_boundary_warns_when_input_dtype_requires_conversion(
|
||||
boundary_device: torch.device,
|
||||
case: DTypeCase,
|
||||
) -> None:
|
||||
model = BoundaryNoopModel().to(boundary_device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
x = case.values().to(boundary_device)
|
||||
|
||||
with pytest.warns(DTypeBoundaryWarning, match="allocate/copy input data"):
|
||||
compiled(x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
pytest.param(case, id=case.name)
|
||||
for case in DTYPE_CASES
|
||||
if case.name in {"bool", "int32", "float16", "bfloat16", "float32"}
|
||||
],
|
||||
)
|
||||
def test_boundary_does_not_warn_when_input_dtype_matches_graph(
|
||||
boundary_device: torch.device,
|
||||
case: DTypeCase,
|
||||
) -> None:
|
||||
model = BoundaryNoopModel().to(boundary_device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
x = case.values().to(boundary_device)
|
||||
|
||||
with warnings.catch_warnings(record=True) as records:
|
||||
warnings.simplefilter("always")
|
||||
compiled(x)
|
||||
|
||||
dtype_boundary_warnings = [
|
||||
record
|
||||
for record in records
|
||||
if issubclass(record.category, DTypeBoundaryWarning)
|
||||
]
|
||||
assert dtype_boundary_warnings == []
|
||||
@@ -1,142 +0,0 @@
|
||||
import torch
|
||||
import pytest
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
class StrideSensitiveInputModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.register_buffer(
|
||||
"coeff",
|
||||
torch.tensor([1.0, 10.0, 100.0], dtype=torch.float32),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x @ self.coeff
|
||||
|
||||
|
||||
class TwoInputReadModel(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x * 2.0 + y * 3.0
|
||||
|
||||
|
||||
class ReturnInputModel(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x
|
||||
|
||||
|
||||
class ReturnInputAndComputedModel(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return x, x + 1.0
|
||||
|
||||
|
||||
class CloneThenMutateModel(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
y = x.clone()
|
||||
y.add_(1.0)
|
||||
return y, x * 2.0
|
||||
|
||||
|
||||
def _base_view(device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
base = torch.arange(12, dtype=torch.float32, device=device).reshape(3, 4)
|
||||
return base, base.t()
|
||||
|
||||
|
||||
def _assert_non_contiguous_storage_alias(base: torch.Tensor, view: torch.Tensor) -> None:
|
||||
assert not view.is_contiguous()
|
||||
assert view.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
|
||||
|
||||
|
||||
def _assert_same(actual, expected) -> None:
|
||||
if isinstance(expected, tuple):
|
||||
assert isinstance(actual, tuple)
|
||||
assert len(actual) == len(expected)
|
||||
for actual_item, expected_item in zip(actual, expected):
|
||||
_assert_same(actual_item, expected_item)
|
||||
return
|
||||
|
||||
assert torch.allclose(actual, expected)
|
||||
|
||||
|
||||
def _single_non_contiguous_view(device: torch.device):
|
||||
base, view = _base_view(device)
|
||||
_assert_non_contiguous_storage_alias(base, view)
|
||||
return StrideSensitiveInputModel().to(device), (view,), base
|
||||
|
||||
|
||||
def _same_view_twice(device: torch.device):
|
||||
base, view = _base_view(device)
|
||||
_assert_non_contiguous_storage_alias(base, view)
|
||||
return TwoInputReadModel().to(device), (view, view), base
|
||||
|
||||
|
||||
def _overlapping_views(device: torch.device):
|
||||
base = torch.arange(20, dtype=torch.float32, device=device).reshape(4, 5)
|
||||
x = base[:3, :4]
|
||||
y = base[1:, 1:]
|
||||
assert not x.is_contiguous()
|
||||
assert not y.is_contiguous()
|
||||
assert x.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
|
||||
assert y.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
|
||||
return TwoInputReadModel().to(device), (x, y), base
|
||||
|
||||
|
||||
def _return_input(device: torch.device):
|
||||
base, view = _base_view(device)
|
||||
_assert_non_contiguous_storage_alias(base, view)
|
||||
return ReturnInputModel().to(device), (view,), base
|
||||
|
||||
|
||||
def _return_input_and_computed(device: torch.device):
|
||||
base, view = _base_view(device)
|
||||
_assert_non_contiguous_storage_alias(base, view)
|
||||
return ReturnInputAndComputedModel().to(device), (view,), base
|
||||
|
||||
|
||||
def _internal_clone_inplace(device: torch.device):
|
||||
base, view = _base_view(device)
|
||||
_assert_non_contiguous_storage_alias(base, view)
|
||||
return CloneThenMutateModel().to(device), (view,), base
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"make_case",
|
||||
[
|
||||
pytest.param(
|
||||
_single_non_contiguous_view,
|
||||
id="single_non_contiguous_view_stride_sensitive_read",
|
||||
),
|
||||
pytest.param(_same_view_twice, id="same_view_passed_as_two_read_inputs"),
|
||||
pytest.param(_overlapping_views, id="overlapping_views_as_two_read_inputs"),
|
||||
pytest.param(_return_input, id="return_input_boundary_value"),
|
||||
pytest.param(
|
||||
_return_input_and_computed,
|
||||
id="return_input_boundary_value_and_computed_value",
|
||||
),
|
||||
pytest.param(_internal_clone_inplace, id="inplace_mutation_on_internal_clone"),
|
||||
],
|
||||
)
|
||||
def test_input_boundary_contiguous_materialization_cases(
|
||||
device: torch.device, make_case
|
||||
) -> None:
|
||||
model, inputs, base = make_case(device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
|
||||
base_before = base.clone()
|
||||
expected = model(*inputs)
|
||||
actual = compiled(*inputs)
|
||||
|
||||
_assert_same(actual, expected)
|
||||
assert torch.allclose(base, base_before)
|
||||
|
||||
|
||||
def test_non_contiguous_view_input_fails_if_raw_storage_order_is_used(
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
model, (view,), base = _single_non_contiguous_view(device)
|
||||
|
||||
wrong_if_storage_order_used = model(base.reshape(view.shape))
|
||||
expected = model(view)
|
||||
|
||||
assert not torch.allclose(wrong_if_storage_order_used, expected)
|
||||
@@ -1,138 +0,0 @@
|
||||
"""Regression coverage for torch.compile mutation and alias contracts.
|
||||
|
||||
PyTorch backends are expected to preserve the semantics of the traced graph.
|
||||
After torch.export functionalization, input mutations are represented as
|
||||
leading mutation outputs before user outputs. Luminal currently treats every
|
||||
compiled graph output as a user output and also materializes inputs at the
|
||||
boundary, so caller-visible mutation and aliasing semantics are not preserved.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
class MutateInputThenCompute(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x.add_(1.0)
|
||||
return x * 2.0
|
||||
|
||||
|
||||
class MutateInputReturnAlias(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x.add_(1.0)
|
||||
return x
|
||||
|
||||
|
||||
class MutateOverlappingInputAlias(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
x.add_(10.0)
|
||||
return y * 2.0
|
||||
|
||||
|
||||
class ReturnInputView(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.t()
|
||||
|
||||
|
||||
def _assert_same_storage(a: torch.Tensor, b: torch.Tensor) -> None:
|
||||
assert a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("backend", ["eager", "aot_eager", "inductor"])
|
||||
def test_stock_torch_compile_preserves_input_mutation_writeback(backend: str) -> None:
|
||||
model = MutateInputThenCompute()
|
||||
expected_input = torch.arange(6, dtype=torch.float32).reshape(2, 3)
|
||||
actual_input = expected_input.clone()
|
||||
|
||||
expected = model(expected_input)
|
||||
compiled = torch.compile(model, backend=backend)
|
||||
actual = compiled(actual_input)
|
||||
|
||||
assert torch.equal(actual, expected)
|
||||
assert torch.equal(actual_input, expected_input)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("backend", ["eager", "aot_eager", "inductor"])
|
||||
def test_stock_torch_compile_preserves_mutated_return_alias(backend: str) -> None:
|
||||
model = MutateInputReturnAlias()
|
||||
x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
|
||||
|
||||
compiled = torch.compile(model, backend=backend)
|
||||
out = compiled(x)
|
||||
|
||||
assert torch.equal(x, torch.arange(1, 7, dtype=torch.float32).reshape(2, 3))
|
||||
_assert_same_storage(out, x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("backend", ["eager", "aot_eager", "inductor"])
|
||||
def test_stock_torch_compile_preserves_returned_view_alias(backend: str) -> None:
|
||||
model = ReturnInputView()
|
||||
x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
|
||||
|
||||
compiled = torch.compile(model, backend=backend)
|
||||
out = compiled(x)
|
||||
|
||||
assert torch.equal(out, x.t())
|
||||
assert out.stride() == (1, 3)
|
||||
_assert_same_storage(out, x)
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
strict=True,
|
||||
reason=(
|
||||
"Luminal currently treats functionalized input-mutation outputs as user "
|
||||
"outputs and does not copy mutation outputs back to caller inputs."
|
||||
),
|
||||
)
|
||||
def test_luminal_input_mutation_writeback_contract(device: torch.device) -> None:
|
||||
model = MutateInputThenCompute().to(device)
|
||||
x = torch.arange(6, dtype=torch.float32, device=device).reshape(2, 3)
|
||||
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
out = compiled(x)
|
||||
|
||||
expected_x = torch.arange(1, 7, dtype=torch.float32, device=device).reshape(2, 3)
|
||||
expected_out = expected_x * 2.0
|
||||
assert torch.equal(out, expected_out)
|
||||
assert torch.equal(x, expected_x)
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
strict=True,
|
||||
reason=(
|
||||
"Luminal does not preserve caller-visible overlapping input aliasing "
|
||||
"when one aliased input is mutated."
|
||||
),
|
||||
)
|
||||
def test_luminal_overlapping_input_alias_mutation_contract(
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
model = MutateOverlappingInputAlias().to(device)
|
||||
|
||||
eager_base = torch.arange(6, dtype=torch.float32, device=device)
|
||||
expected = model(eager_base[:4], eager_base[1:5])
|
||||
|
||||
base = torch.arange(6, dtype=torch.float32, device=device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
actual = compiled(base[:4], base[1:5])
|
||||
|
||||
assert torch.equal(actual, expected)
|
||||
assert torch.equal(base, eager_base)
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
strict=True,
|
||||
reason="Luminal materializes returned input views instead of preserving aliasing.",
|
||||
)
|
||||
def test_luminal_returned_view_alias_contract(device: torch.device) -> None:
|
||||
model = ReturnInputView().to(device)
|
||||
x = torch.arange(6, dtype=torch.float32, device=device).reshape(2, 3)
|
||||
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
out = compiled(x)
|
||||
|
||||
assert torch.equal(out, x.t())
|
||||
assert out.stride() == (1, 3)
|
||||
_assert_same_storage(out, x)
|
||||
@@ -1,3 +1,5 @@
|
||||
#[path = "../../../examples_common/benchmark_stdio.rs"]
|
||||
mod benchmark_stdio;
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
@@ -6,18 +8,14 @@ use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use std::{
|
||||
io::Write,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const REPO_ID: &str = "google/gemma-4-26B-A4B";
|
||||
|
||||
fn env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
fn env_bool(name: &str) -> bool {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
@@ -25,9 +23,10 @@ fn env_bool(name: &str) -> bool {
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let max_seq_len = env_usize("MAX_SEQ_LEN", 4096);
|
||||
let gen_tokens = env_usize("GEN_TOKENS", 30);
|
||||
let search_graphs = env_usize("SEARCH_GRAPHS", 50);
|
||||
let stdio = benchmark_stdio::enabled();
|
||||
let max_seq_len = benchmark_stdio::env_usize("MAX_SEQ_LEN", 4096);
|
||||
let gen_tokens = benchmark_stdio::env_usize("GEN_TOKENS", 30);
|
||||
let search_graphs = benchmark_stdio::env_usize("SEARCH_GRAPHS", 50);
|
||||
let prompt = std::env::var("PROMPT").unwrap_or_else(|_| "The capital of France is".to_string());
|
||||
let print_token_ids = env_bool("PRINT_TOKEN_IDS");
|
||||
|
||||
@@ -38,11 +37,6 @@ fn main() {
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
@@ -63,11 +57,14 @@ fn main() {
|
||||
let weights_path = model_dir.join("model_combined.safetensors");
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
let reset_cache = |runtime: &mut CudaRuntime| {
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
};
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
@@ -75,15 +72,66 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(pos_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
if stdio {
|
||||
benchmark_stdio::serve(|prompt| {
|
||||
reset_cache(&mut runtime);
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
print_token_ids,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
pos_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
true,
|
||||
);
|
||||
});
|
||||
} else {
|
||||
run_prompt(
|
||||
&prompt,
|
||||
gen_tokens,
|
||||
print_token_ids,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
pos_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
false,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_prompt(
|
||||
prompt: &str,
|
||||
gen_tokens: usize,
|
||||
print_token_ids: bool,
|
||||
tokenizer: &Tokenizer,
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
input: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
kv_cache: &KVCache,
|
||||
stdio: bool,
|
||||
) {
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
let query_start = Instant::now();
|
||||
|
||||
if !stdio {
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
let mut prev_seq = 0usize;
|
||||
let mut fwd_durations = vec![];
|
||||
@@ -93,7 +141,7 @@ fn main() {
|
||||
|
||||
const EOS_TOKEN: u32 = 1;
|
||||
|
||||
let prefill_start = std::time::Instant::now();
|
||||
let prefill_start = Instant::now();
|
||||
for &token in &prompt_tokens {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
@@ -121,12 +169,26 @@ fn main() {
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
generated_token_ids.push(next_token);
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
let mut generated = 0usize;
|
||||
if stdio {
|
||||
if next_token != EOS_TOKEN {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
generated += 1;
|
||||
}
|
||||
} else {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
print!("{decoded}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
for _ in 1..gen_tokens {
|
||||
let start = std::time::Instant::now();
|
||||
if stdio && next_token == EOS_TOKEN {
|
||||
break;
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![next_token as i32]);
|
||||
@@ -165,10 +227,21 @@ fn main() {
|
||||
break;
|
||||
}
|
||||
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
} else {
|
||||
print!("{decoded}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
generated += 1;
|
||||
fwd_durations.push(start.elapsed());
|
||||
}
|
||||
if stdio {
|
||||
benchmark_stdio::emit_eoq(generated, query_start);
|
||||
return;
|
||||
}
|
||||
|
||||
println!();
|
||||
if print_token_ids {
|
||||
println!("Generated token ids: {generated_token_ids:?}");
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#[path = "../../../examples_common/benchmark_stdio.rs"]
|
||||
mod benchmark_stdio;
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
@@ -7,22 +9,36 @@ use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use luminal_tracing::*;
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use std::{
|
||||
io::Write,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
const REPO_ID: &str = "NousResearch/Meta-Llama-3-8B-Instruct";
|
||||
|
||||
fn main() {
|
||||
let stdio = benchmark_stdio::enabled();
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 500;
|
||||
let search_graphs = 500;
|
||||
let gen_tokens = if stdio {
|
||||
benchmark_stdio::env_usize("GEN_TOKENS", 500)
|
||||
} else {
|
||||
500
|
||||
};
|
||||
let search_graphs = if stdio {
|
||||
benchmark_stdio::env_usize("SEARCH_GRAPHS", 500)
|
||||
} else {
|
||||
500
|
||||
};
|
||||
let prompt = "Explain what a neural network is in a paragraph.";
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
if !stdio {
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
}
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
@@ -31,14 +47,6 @@ fn main() {
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let chat_prompt = format!(
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
);
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(chat_prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
// Build graph
|
||||
let mut cx = Graph::default();
|
||||
@@ -66,10 +74,13 @@ fn main() {
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
let reset_cache = |runtime: &mut CudaRuntime| {
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
};
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
@@ -77,12 +88,65 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(token_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
if stdio {
|
||||
benchmark_stdio::serve(|prompt| {
|
||||
reset_cache(&mut runtime);
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
true,
|
||||
);
|
||||
});
|
||||
} else {
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
false,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_prompt(
|
||||
prompt: &str,
|
||||
gen_tokens: usize,
|
||||
tokenizer: &Tokenizer,
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
input: GraphTensor,
|
||||
token_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
kv_cache: &KVCache,
|
||||
stdio: bool,
|
||||
) {
|
||||
let chat_prompt = format!(
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
);
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(chat_prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let query_start = Instant::now();
|
||||
let mut prev_seq = 1usize;
|
||||
let mut sentence = vec![prompt_tokens[0]];
|
||||
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
|
||||
@@ -94,13 +158,16 @@ fn main() {
|
||||
const EOS_TOKEN: u32 = 128009;
|
||||
const STOP_TOKEN: u32 = 128001;
|
||||
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
if !stdio {
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
}
|
||||
|
||||
let mut generated = 0usize;
|
||||
for i in 0..total_steps {
|
||||
let start = std::time::Instant::now();
|
||||
let start = Instant::now();
|
||||
let is_prefill = i < prompt_len - 1;
|
||||
let seq_len = sentence.len();
|
||||
|
||||
@@ -159,12 +226,21 @@ fn main() {
|
||||
}
|
||||
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
} else {
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
generated += 1;
|
||||
}
|
||||
if stdio {
|
||||
benchmark_stdio::emit_eoq(generated, query_start);
|
||||
return;
|
||||
}
|
||||
println!();
|
||||
|
||||
// Benchmarks
|
||||
println!();
|
||||
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
|
||||
if decode_durations.len() > 2 {
|
||||
println!(
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#[path = "../../../examples_common/benchmark_stdio.rs"]
|
||||
mod benchmark_stdio;
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
@@ -7,22 +9,36 @@ use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use luminal_tracing::*;
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use std::{
|
||||
io::Write,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
const REPO_ID: &str = "Qwen/Qwen3-4B";
|
||||
|
||||
fn main() {
|
||||
let stdio = benchmark_stdio::enabled();
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 500;
|
||||
let search_graphs = 500;
|
||||
let gen_tokens = if stdio {
|
||||
benchmark_stdio::env_usize("GEN_TOKENS", 500)
|
||||
} else {
|
||||
500
|
||||
};
|
||||
let search_graphs = if stdio {
|
||||
benchmark_stdio::env_usize("SEARCH_GRAPHS", 500)
|
||||
} else {
|
||||
500
|
||||
};
|
||||
let prompt = "Explain what a neural network is in a paragraph.";
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
if !stdio {
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
}
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
@@ -31,7 +47,6 @@ fn main() {
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
|
||||
// Build graph
|
||||
let mut cx = Graph::default();
|
||||
@@ -54,10 +69,13 @@ fn main() {
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
let reset_cache = |runtime: &mut CudaRuntime| {
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
};
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
@@ -65,12 +83,58 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(token_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
if stdio {
|
||||
benchmark_stdio::serve(|prompt| {
|
||||
reset_cache(&mut runtime);
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
true,
|
||||
);
|
||||
});
|
||||
} else {
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
token_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
false,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_prompt(
|
||||
prompt: &str,
|
||||
gen_tokens: usize,
|
||||
tokenizer: &Tokenizer,
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
input: GraphTensor,
|
||||
token_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
kv_cache: &KVCache,
|
||||
stdio: bool,
|
||||
) {
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
let query_start = Instant::now();
|
||||
let mut prev_seq = 1usize;
|
||||
let mut sentence = vec![prompt_tokens[0]];
|
||||
let total_steps = prompt_tokens.len() - 1 + gen_tokens;
|
||||
@@ -82,13 +146,16 @@ fn main() {
|
||||
const EOS_TOKEN: u32 = 151645; // <|endoftext|>
|
||||
const STOP_TOKEN: u32 = 151643; // <|end|>
|
||||
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
if !stdio {
|
||||
println!(
|
||||
"Prompt: {} tokens, generating up to {} tokens",
|
||||
prompt_len, gen_tokens
|
||||
);
|
||||
}
|
||||
|
||||
let mut generated = 0usize;
|
||||
for i in 0..total_steps {
|
||||
let start = std::time::Instant::now();
|
||||
let start = Instant::now();
|
||||
let is_prefill = i < prompt_len - 1;
|
||||
let seq_len = sentence.len();
|
||||
|
||||
@@ -147,12 +214,21 @@ fn main() {
|
||||
}
|
||||
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
} else {
|
||||
print!("{}", decoded);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
generated += 1;
|
||||
}
|
||||
if stdio {
|
||||
benchmark_stdio::emit_eoq(generated, query_start);
|
||||
return;
|
||||
}
|
||||
println!();
|
||||
|
||||
// Benchmarks
|
||||
println!();
|
||||
let decode_durations: Vec<_> = fwd_durations.iter().skip(prompt_len).collect();
|
||||
if decode_durations.len() > 2 {
|
||||
println!(
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#[path = "../../../examples_common/benchmark_stdio.rs"]
|
||||
mod benchmark_stdio;
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
@@ -6,15 +8,27 @@ use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use std::{
|
||||
io::Write,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const REPO_ID: &str = "Qwen/Qwen3-30B-A3B";
|
||||
|
||||
fn main() {
|
||||
let stdio = benchmark_stdio::enabled();
|
||||
let max_seq_len = 4096;
|
||||
let gen_tokens = 30;
|
||||
let search_graphs = 50;
|
||||
let gen_tokens = if stdio {
|
||||
benchmark_stdio::env_usize("GEN_TOKENS", 30)
|
||||
} else {
|
||||
30
|
||||
};
|
||||
let search_graphs = if stdio {
|
||||
benchmark_stdio::env_usize("SEARCH_GRAPHS", 50)
|
||||
} else {
|
||||
50
|
||||
};
|
||||
let prompt = "The capital of France is";
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
@@ -24,7 +38,6 @@ fn main() {
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
|
||||
// Build graph
|
||||
let mut cx = Graph::default();
|
||||
@@ -47,10 +60,13 @@ fn main() {
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
let cache_bytes = N_KV_HEADS * max_seq_len * HEAD_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
let reset_cache = |runtime: &mut CudaRuntime| {
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
};
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
@@ -58,14 +74,63 @@ fn main() {
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(pos_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
reset_cache(&mut runtime);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
if stdio {
|
||||
benchmark_stdio::serve(|prompt| {
|
||||
reset_cache(&mut runtime);
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
pos_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
true,
|
||||
);
|
||||
});
|
||||
} else {
|
||||
run_prompt(
|
||||
prompt,
|
||||
gen_tokens,
|
||||
&tokenizer,
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
input,
|
||||
pos_ids,
|
||||
logits,
|
||||
&cache_outputs,
|
||||
&kv_cache,
|
||||
false,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_prompt(
|
||||
prompt: &str,
|
||||
gen_tokens: usize,
|
||||
tokenizer: &Tokenizer,
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
input: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
logits: GraphTensor,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
kv_cache: &KVCache,
|
||||
stdio: bool,
|
||||
) {
|
||||
let prompt_tokens = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
let query_start = Instant::now();
|
||||
|
||||
if !stdio {
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
let mut prev_seq = 0usize;
|
||||
let mut fwd_durations = vec![];
|
||||
@@ -76,7 +141,7 @@ fn main() {
|
||||
const STOP_TOKEN: u32 = 151643;
|
||||
|
||||
// Prefill: process prompt tokens one at a time
|
||||
let prefill_start = std::time::Instant::now();
|
||||
let prefill_start = Instant::now();
|
||||
for &token in &prompt_tokens {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
@@ -105,13 +170,27 @@ fn main() {
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
let mut generated = 0usize;
|
||||
if stdio {
|
||||
if next_token != EOS_TOKEN && next_token != STOP_TOKEN {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
generated += 1;
|
||||
}
|
||||
} else {
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
print!("{decoded}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
// Decode loop
|
||||
for _ in 1..gen_tokens {
|
||||
let start = std::time::Instant::now();
|
||||
if stdio && (next_token == EOS_TOKEN || next_token == STOP_TOKEN) {
|
||||
break;
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![next_token as i32]);
|
||||
@@ -150,13 +229,23 @@ fn main() {
|
||||
break;
|
||||
}
|
||||
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
let decoded = tokenizer.decode(&[next_token], true).unwrap();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_token(&decoded);
|
||||
} else {
|
||||
print!("{decoded}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
generated += 1;
|
||||
fwd_durations.push(start.elapsed());
|
||||
}
|
||||
println!();
|
||||
if stdio {
|
||||
benchmark_stdio::emit_eoq(generated, query_start);
|
||||
return;
|
||||
}
|
||||
|
||||
// Report benchmarks
|
||||
println!();
|
||||
println!(
|
||||
" TTFT: {:.2} ms ({} prompt tokens)",
|
||||
prefill_duration.as_secs_f64() * 1e3,
|
||||
|
||||
58
examples_common/benchmark_stdio.rs
Normal file
58
examples_common/benchmark_stdio.rs
Normal file
@@ -0,0 +1,58 @@
|
||||
use std::{
|
||||
io::{BufRead, Write},
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
pub fn enabled() -> bool {
|
||||
std::env::args().any(|arg| arg == "--stdio")
|
||||
}
|
||||
|
||||
pub fn env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
fn emit_ready() {
|
||||
println!("READY");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
pub fn serve(mut f: impl FnMut(&str)) {
|
||||
emit_ready();
|
||||
|
||||
let stdin = std::io::stdin();
|
||||
for line in stdin.lock().lines() {
|
||||
let line = line.unwrap();
|
||||
f(&line);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn emit_token(token: &str) {
|
||||
println!("TOK\t{}", escape_token(token));
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
pub fn emit_eoq(generated: usize, query_start: Instant) {
|
||||
println!(
|
||||
"EOQ\t{}\t{:.3}",
|
||||
generated,
|
||||
query_start.elapsed().as_secs_f64() * 1e3
|
||||
);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
fn escape_token(s: &str) -> String {
|
||||
let mut out = String::with_capacity(s.len());
|
||||
for ch in s.chars() {
|
||||
match ch {
|
||||
'\\' => out.push_str("\\\\"),
|
||||
'\t' => out.push_str("\\t"),
|
||||
'\n' => out.push_str("\\n"),
|
||||
'\r' => out.push_str("\\r"),
|
||||
_ => out.push(ch),
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
Reference in New Issue
Block a user