whisper example

This commit is contained in:
Joe Fioti
2026-05-02 21:45:15 +00:00
parent 5748ac644e
commit cfedd80c9b
12 changed files with 1589 additions and 1 deletions

View File

@@ -22,7 +22,7 @@ jobs:
strategy:
fail-fast: false
matrix:
example: [llama, gemma, qwen, qwen3_moe, gemma4_moe]
example: [llama, gemma, qwen, qwen3_moe, gemma4_moe, whisper]
gpu:
- { type: "A100-80GB" }
# To add more GPUs, just append another entry:

View File

@@ -40,6 +40,9 @@ EXPECTED_OUTPUT = {
"gemma4_moe": [
"city of romance, art and culture",
],
"whisper": [
"ask not what your country can do for you",
],
}

View File

@@ -782,3 +782,14 @@ identical across all attempts (dtype issue) vs varying (actual numerical issue).
3. **Why "defensive fallback" framing is misleading**: it implies the LLIR is broken. It isn't. The forward-walk-only `body_nodes` definition just doesn't cover this case, because the case requires no per-iter cloning at all. A *node not reachable from any loop input marker has no input-marker ancestor*, so by construction its value doesn't depend on the loop's per-iter state.
4. **Cleaner formulation**: name the concept. Compute an `iteration_invariant_slots: HashSet<LoopStart>` set at the same time `start_meta` is built, with the rule `body_producer ∉ body_nodes ⇒ iteration_invariant`. `resolve_src` and `marker_post_sub` then have explicit branches: if the slot is invariant, use `body_producer` directly; otherwise the standard per-iter clone lookup. The behavior is the same as the `unwrap_or` band-aid, but the code now documents that this is a real, sound case the unroll handles correctly — not a panic suppressor.
5. **Principle**: when an `unwrap_or` papers over a case that turns out to be semantically valid, the right cleanup isn't to keep the `unwrap_or` and add a comment — it's to name the case. Hoist the predicate into a set or enum and branch on it explicitly. The compiler then enforces that every consumer of the per-iter cloning machinery has an opinion on iteration-invariant slots, instead of silently relying on a `Map::get` returning `None` at the right moment.
## 2026-05-02 — Whisper port hit two missing-translator pitfalls
1. **Symptom**: Compiling a PyTorch port of Whisper-tiny.en through `luminal_backend` failed twice in a row at the dispatch table: first with `Unsupported ATen op: torch.ops.aten.gelu.default`, then with `full: unsupported fill value type ... -Infinity`.
2. **Root cause #1**: the dispatch table in `crates/luminal_python/rust/src/translator/dispatch.rs` mapped `sigmoid`, `tanh`, `relu` etc. but not `gelu` or `silu`. Whisper's encoder uses `F.gelu`, so the activation hit a hole.
3. **Root cause #2**: PyTorch serializes `float("-inf")` in PT2 as the string `"-Infinity"` (and `"NaN"`/`"Infinity"` analogously). `translate_full`'s `get_float_arg` only accepts numeric float/int payloads, so any `torch.full((..), -inf)` (the obvious way to write a causal mask) blows up. Decoder mask code is the most common spot.
4. **Why it was tricky**: both errors arrive from inside `pt2_backend` with a stack trace that ends in `process_pt2`, hiding the actual ATen target inside the message. You only see the offending op name in the error string itself, so you have to read `RuntimeError: Failed to translate node N: …` carefully and grep `dispatch.rs` for it.
5. **Fix in this session**:
- Added `aten.gelu.default → a.gelu()` and `aten.silu.default → a.silu()` to `dispatch.rs`.
- Worked around the `-Infinity` issue at the model level by using a finite `-1e10` for the causal mask in the example (matches the Rust example's convention). The cleaner fix (parsing `"-Infinity"`/`"Infinity"`/`"NaN"` strings in `get_float_arg` / `translate_full`) is left for a follow-up.
6. **Principle**: when adding a new model that goes through the PT2 backend, expect to plug small holes in `dispatch.rs` and `translator/tensor.rs::translate_full`. The trace points at the python frame, not the Rust dispatch arm — open `dispatch.rs`, ctrl-F the offending op name, and add the one-liner. For float-shaped sentinel values (`-inf`, `inf`, `nan`), the export pipeline currently only accepts finite floats; either rewrite the model or extend the parser.

View File

@@ -0,0 +1,429 @@
"""Whisper transcription demo using the luminal torch.compile backend.
Implements a small PyTorch port of ``openai/whisper-tiny.en`` that mirrors the
luminal Rust example (``examples/whisper`` in the workspace), loads the official
HuggingFace weights, and runs greedy decoding through the luminal backend via
``torch.compile``.
Usage::
uv run python examples/whisper.py [path/to/audio.wav]
If no path is provided, falls back to the JFK sample bundled with the Rust
``examples/whisper`` crate.
"""
from __future__ import annotations
import os
import sys
import time
import wave
from pathlib import Path
from typing import Optional
import numpy as np
import torch
import torch._dynamo
import torch.nn.functional as F
from transformers import WhisperFeatureExtractor, WhisperForConditionalGeneration, WhisperTokenizer
from luminal.pt2 import compile as luminal_compile
REPO_ID = "openai/whisper-tiny.en"
# whisper-tiny.en hyperparameters
N_MELS = 80
N_AUDIO_CTX = 1500
D_MODEL = 384
N_HEADS = 6
HEAD_DIM = D_MODEL // N_HEADS
N_AUDIO_LAYER = 4
N_TEXT_LAYER = 4
N_TEXT_CTX = 448
FF_DIM = 4 * D_MODEL
N_VOCAB = 51864
LAYER_NORM_EPS = 1e-5
# Decoder special tokens
TOKEN_SOT = 50257
TOKEN_NO_TIMESTAMPS = 50362
TOKEN_EOT = 50256
# ---------------------------------------------------------------------------
# Model — mirrors the HLIR encoder/decoder in examples/whisper/src/model.rs
# ---------------------------------------------------------------------------
class WhisperAttention(torch.nn.Module):
"""Multi-head attention with separate q/k/v projections (no bias on k_proj)."""
def __init__(self, d_model: int = D_MODEL, n_heads: int = N_HEADS):
super().__init__()
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.q_proj = torch.nn.Linear(d_model, d_model, bias=True)
self.k_proj = torch.nn.Linear(d_model, d_model, bias=False)
self.v_proj = torch.nn.Linear(d_model, d_model, bias=True)
self.out_proj = torch.nn.Linear(d_model, d_model, bias=True)
def forward(
self,
x: torch.Tensor,
kv_input: Optional[torch.Tensor] = None,
causal: bool = False,
) -> torch.Tensor:
# x: (seq, d_model). kv_input is None → self-attn; otherwise cross-attn.
kv = x if kv_input is None else kv_input
q = self.q_proj(x)
k = self.k_proj(kv)
v = self.v_proj(kv)
seq_q = q.shape[0]
seq_kv = k.shape[0]
# (seq, d_model) -> (n_heads, seq, head_dim)
q = q.reshape(seq_q, self.n_heads, self.head_dim).transpose(0, 1)
k = k.reshape(seq_kv, self.n_heads, self.head_dim).transpose(0, 1)
v = v.reshape(seq_kv, self.n_heads, self.head_dim).transpose(0, 1)
scale = 1.0 / (self.head_dim ** 0.5)
scores = torch.matmul(q, k.transpose(-2, -1)) * scale # (h, sq, sk)
if causal:
# Use a large finite negative instead of -inf so the export pipeline
# serializes a float instead of the unsupported "-Infinity" sentinel.
mask = torch.triu(
torch.full((seq_q, seq_kv), -1e10, device=x.device),
diagonal=1,
)
scores = scores + mask
weights = torch.softmax(scores, dim=-1)
attn = torch.matmul(weights, v) # (h, sq, hd)
merged = attn.transpose(0, 1).reshape(seq_q, -1)
return self.out_proj(merged)
class EncoderLayer(torch.nn.Module):
def __init__(self):
super().__init__()
self.self_attn = WhisperAttention()
self.self_attn_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
self.fc1 = torch.nn.Linear(D_MODEL, FF_DIM, bias=True)
self.fc2 = torch.nn.Linear(FF_DIM, D_MODEL, bias=True)
self.final_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.self_attn(self.self_attn_layer_norm(x))
h = self.final_layer_norm(x)
h = F.gelu(self.fc1(h))
h = self.fc2(h)
return x + h
class WhisperEncoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv1d(N_MELS, D_MODEL, kernel_size=3, padding=1, bias=True)
self.conv2 = torch.nn.Conv1d(D_MODEL, D_MODEL, kernel_size=3, stride=2, padding=1, bias=True)
# Position embedding stored as a regular parameter (matches HF layout).
self.embed_positions = torch.nn.Embedding(N_AUDIO_CTX, D_MODEL)
self.layers = torch.nn.ModuleList([EncoderLayer() for _ in range(N_AUDIO_LAYER)])
self.layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
def forward(self, mel: torch.Tensor) -> torch.Tensor:
# mel: (n_mels, 3000) -> add batch dim for conv1d
x = mel.unsqueeze(0)
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
# (1, d_model, 1500) -> (1500, d_model)
x = x.squeeze(0).transpose(0, 1)
x = x + self.embed_positions.weight
for layer in self.layers:
x = layer(x)
return self.layer_norm(x)
class DecoderLayer(torch.nn.Module):
def __init__(self):
super().__init__()
self.self_attn = WhisperAttention()
self.self_attn_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
self.encoder_attn = WhisperAttention()
self.encoder_attn_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
self.fc1 = torch.nn.Linear(D_MODEL, FF_DIM, bias=True)
self.fc2 = torch.nn.Linear(FF_DIM, D_MODEL, bias=True)
self.final_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
def forward(self, x: torch.Tensor, xa: torch.Tensor) -> torch.Tensor:
x = x + self.self_attn(self.self_attn_layer_norm(x), causal=True)
x = x + self.encoder_attn(self.encoder_attn_layer_norm(x), kv_input=xa)
h = self.final_layer_norm(x)
h = F.gelu(self.fc1(h))
h = self.fc2(h)
return x + h
class WhisperDecoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.embed_tokens = torch.nn.Embedding(N_VOCAB, D_MODEL)
self.embed_positions = torch.nn.Embedding(N_TEXT_CTX, D_MODEL)
self.layers = torch.nn.ModuleList([DecoderLayer() for _ in range(N_TEXT_LAYER)])
self.layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
def forward(self, tokens: torch.Tensor, xa: torch.Tensor) -> torch.Tensor:
# tokens: (seq,) of int64 — absolute positions are 0..seq-1
seq = tokens.shape[0]
pos = torch.arange(seq, dtype=torch.long, device=tokens.device)
x = self.embed_tokens(tokens) + self.embed_positions(pos)
for layer in self.layers:
x = layer(x, xa)
x = self.layer_norm(x)
# Tied projection
return torch.matmul(x, self.embed_tokens.weight.transpose(0, 1))
class Whisper(torch.nn.Module):
def __init__(self):
super().__init__()
self.encoder = WhisperEncoder()
self.decoder = WhisperDecoder()
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
xa = self.encoder(mel)
return self.decoder(tokens, xa)
class DecoderWithFixedXa(torch.nn.Module):
"""Wraps the decoder with the encoder output stored as a buffer.
The audio is fixed for the whole utterance, so ``xa`` is a constant relative
to the per-token decode loop. Storing it as a buffer lets us compile the
decoder once with a single dynamic-length ``tokens`` input, avoiding a full
recompilation at every step as the sequence grows.
"""
def __init__(self, decoder: WhisperDecoder, xa: torch.Tensor):
super().__init__()
self.decoder = decoder
self.register_buffer("xa", xa)
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
return self.decoder(tokens, self.xa)
# ---------------------------------------------------------------------------
# Weight loading: HF state_dict -> our model
# ---------------------------------------------------------------------------
def load_hf_weights_into(model: Whisper) -> None:
"""Copy HF whisper-tiny.en weights into our matching modules."""
hf = WhisperForConditionalGeneration.from_pretrained(REPO_ID).eval()
sd = hf.state_dict()
def get(name: str) -> torch.Tensor:
return sd[f"model.{name}"].clone()
enc = model.encoder
enc.conv1.weight.data.copy_(get("encoder.conv1.weight"))
enc.conv1.bias.data.copy_(get("encoder.conv1.bias"))
enc.conv2.weight.data.copy_(get("encoder.conv2.weight"))
enc.conv2.bias.data.copy_(get("encoder.conv2.bias"))
enc.embed_positions.weight.data.copy_(get("encoder.embed_positions.weight"))
enc.layer_norm.weight.data.copy_(get("encoder.layer_norm.weight"))
enc.layer_norm.bias.data.copy_(get("encoder.layer_norm.bias"))
for i, layer in enumerate(enc.layers):
prefix = f"encoder.layers.{i}"
layer.self_attn.q_proj.weight.data.copy_(get(f"{prefix}.self_attn.q_proj.weight"))
layer.self_attn.q_proj.bias.data.copy_(get(f"{prefix}.self_attn.q_proj.bias"))
layer.self_attn.k_proj.weight.data.copy_(get(f"{prefix}.self_attn.k_proj.weight"))
layer.self_attn.v_proj.weight.data.copy_(get(f"{prefix}.self_attn.v_proj.weight"))
layer.self_attn.v_proj.bias.data.copy_(get(f"{prefix}.self_attn.v_proj.bias"))
layer.self_attn.out_proj.weight.data.copy_(get(f"{prefix}.self_attn.out_proj.weight"))
layer.self_attn.out_proj.bias.data.copy_(get(f"{prefix}.self_attn.out_proj.bias"))
layer.self_attn_layer_norm.weight.data.copy_(get(f"{prefix}.self_attn_layer_norm.weight"))
layer.self_attn_layer_norm.bias.data.copy_(get(f"{prefix}.self_attn_layer_norm.bias"))
layer.fc1.weight.data.copy_(get(f"{prefix}.fc1.weight"))
layer.fc1.bias.data.copy_(get(f"{prefix}.fc1.bias"))
layer.fc2.weight.data.copy_(get(f"{prefix}.fc2.weight"))
layer.fc2.bias.data.copy_(get(f"{prefix}.fc2.bias"))
layer.final_layer_norm.weight.data.copy_(get(f"{prefix}.final_layer_norm.weight"))
layer.final_layer_norm.bias.data.copy_(get(f"{prefix}.final_layer_norm.bias"))
dec = model.decoder
dec.embed_tokens.weight.data.copy_(get("decoder.embed_tokens.weight"))
dec.embed_positions.weight.data.copy_(get("decoder.embed_positions.weight"))
dec.layer_norm.weight.data.copy_(get("decoder.layer_norm.weight"))
dec.layer_norm.bias.data.copy_(get("decoder.layer_norm.bias"))
for i, layer in enumerate(dec.layers):
prefix = f"decoder.layers.{i}"
layer.self_attn.q_proj.weight.data.copy_(get(f"{prefix}.self_attn.q_proj.weight"))
layer.self_attn.q_proj.bias.data.copy_(get(f"{prefix}.self_attn.q_proj.bias"))
layer.self_attn.k_proj.weight.data.copy_(get(f"{prefix}.self_attn.k_proj.weight"))
layer.self_attn.v_proj.weight.data.copy_(get(f"{prefix}.self_attn.v_proj.weight"))
layer.self_attn.v_proj.bias.data.copy_(get(f"{prefix}.self_attn.v_proj.bias"))
layer.self_attn.out_proj.weight.data.copy_(get(f"{prefix}.self_attn.out_proj.weight"))
layer.self_attn.out_proj.bias.data.copy_(get(f"{prefix}.self_attn.out_proj.bias"))
layer.self_attn_layer_norm.weight.data.copy_(get(f"{prefix}.self_attn_layer_norm.weight"))
layer.self_attn_layer_norm.bias.data.copy_(get(f"{prefix}.self_attn_layer_norm.bias"))
layer.encoder_attn.q_proj.weight.data.copy_(get(f"{prefix}.encoder_attn.q_proj.weight"))
layer.encoder_attn.q_proj.bias.data.copy_(get(f"{prefix}.encoder_attn.q_proj.bias"))
layer.encoder_attn.k_proj.weight.data.copy_(get(f"{prefix}.encoder_attn.k_proj.weight"))
layer.encoder_attn.v_proj.weight.data.copy_(get(f"{prefix}.encoder_attn.v_proj.weight"))
layer.encoder_attn.v_proj.bias.data.copy_(get(f"{prefix}.encoder_attn.v_proj.bias"))
layer.encoder_attn.out_proj.weight.data.copy_(get(f"{prefix}.encoder_attn.out_proj.weight"))
layer.encoder_attn.out_proj.bias.data.copy_(get(f"{prefix}.encoder_attn.out_proj.bias"))
layer.encoder_attn_layer_norm.weight.data.copy_(get(f"{prefix}.encoder_attn_layer_norm.weight"))
layer.encoder_attn_layer_norm.bias.data.copy_(get(f"{prefix}.encoder_attn_layer_norm.bias"))
layer.fc1.weight.data.copy_(get(f"{prefix}.fc1.weight"))
layer.fc1.bias.data.copy_(get(f"{prefix}.fc1.bias"))
layer.fc2.weight.data.copy_(get(f"{prefix}.fc2.weight"))
layer.fc2.bias.data.copy_(get(f"{prefix}.fc2.bias"))
layer.final_layer_norm.weight.data.copy_(get(f"{prefix}.final_layer_norm.weight"))
layer.final_layer_norm.bias.data.copy_(get(f"{prefix}.final_layer_norm.bias"))
# ---------------------------------------------------------------------------
# Audio loading + decoding
# ---------------------------------------------------------------------------
def load_wav_16k_mono(path: Path) -> np.ndarray:
with wave.open(str(path), "rb") as w:
sr = w.getframerate()
n = w.getnframes()
ch = w.getnchannels()
sw = w.getsampwidth()
raw = w.readframes(n)
if sw == 2:
samples = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
elif sw == 4:
samples = np.frombuffer(raw, dtype=np.int32).astype(np.float32) / 2147483648.0
elif sw == 1:
samples = (np.frombuffer(raw, dtype=np.uint8).astype(np.float32) - 128.0) / 128.0
else:
raise ValueError(f"unsupported sample width {sw}")
if ch > 1:
samples = samples.reshape(-1, ch).mean(axis=1)
if sr != 16000:
ratio = sr / 16000
out_len = int(len(samples) / ratio)
idx = np.arange(out_len, dtype=np.float64) * ratio
lo = idx.astype(np.int64)
frac = (idx - lo).astype(np.float32)
hi = np.clip(lo + 1, 0, len(samples) - 1)
samples = samples[lo] * (1.0 - frac) + samples[hi] * frac
return samples.astype(np.float32)
def greedy_decode(logits_row: torch.Tensor, suppress_first_eot: bool) -> int:
masked = logits_row.clone()
masked[TOKEN_SOT:] = float("-inf")
if suppress_first_eot:
masked[TOKEN_EOT] = float("-inf")
return int(torch.argmax(masked).item())
def find_default_audio() -> Optional[Path]:
here = Path(__file__).resolve()
workspace_root = here.parents[3]
candidate = workspace_root / "examples" / "whisper" / "assets" / "jfk.wav"
return candidate if candidate.exists() else None
def main() -> None:
audio_arg = sys.argv[1] if len(sys.argv) > 1 else None
if audio_arg:
audio_path = Path(audio_arg)
else:
audio_path = find_default_audio()
if audio_path is None:
print("error: no audio file given and bundled jfk.wav not found", file=sys.stderr)
sys.exit(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print("Loading audio:", audio_path)
audio = load_wav_16k_mono(audio_path)
print("Computing log-mel features...")
feature_extractor = WhisperFeatureExtractor.from_pretrained(REPO_ID)
features = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
mel: torch.Tensor = features.input_features[0].to(device) # (80, 3000)
assert mel.shape == (N_MELS, 3000), mel.shape
print("Building model and loading weights...")
model = Whisper().eval().to(device)
load_hf_weights_into(model)
model = model.to(device)
tokenizer = WhisperTokenizer.from_pretrained(REPO_ID)
use_compiled = os.environ.get("LUMINAL_DISABLE", "0") != "1"
max_new_tokens = int(os.environ.get("GEN_TOKENS", "100"))
search_iters = int(os.environ.get("SEARCH_ITERATIONS", "10"))
if use_compiled:
# 1. Run the encoder once eagerly. The audio doesn't change during decode,
# so xa is a constant input to the decoder.
with torch.no_grad():
xa = model.encoder(mel)
# 2. Wrap the decoder so its only varying input is `tokens`, then compile
# once with a dynamic length dim. Subsequent calls reuse the same
# compiled graph — no recompile per token.
decoder_only = DecoderWithFixedXa(model.decoder, xa).eval().to(device)
example_tokens = torch.tensor(
[TOKEN_SOT, TOKEN_NO_TIMESTAMPS], dtype=torch.long, device=device
)
print(f"Compiling decoder with dynamic seq dim (search_iters={search_iters})...")
compile_start = time.time()
compiled_decoder = luminal_compile(
decoder_only,
example_tokens,
search_iterations=search_iters,
dynamic_dim=0,
)
print(f"Compiled in {time.time() - compile_start:.1f}s")
def step_logits(decoder_input_ids: torch.Tensor) -> torch.Tensor:
out = compiled_decoder(decoder_input_ids)
return out[0] if isinstance(out, tuple) else out
else:
def step_logits(decoder_input_ids: torch.Tensor) -> torch.Tensor:
return model(mel, decoder_input_ids)
tokens = [TOKEN_SOT, TOKEN_NO_TIMESTAMPS]
print("Transcribing", end="", flush=True)
decode_start = time.time()
for step in range(max_new_tokens):
decoder_input_ids = torch.tensor(tokens, dtype=torch.long, device=device)
with torch.no_grad():
logits = step_logits(decoder_input_ids)
next_token = greedy_decode(logits[-1], suppress_first_eot=(step == 0))
if next_token == TOKEN_EOT:
break
tokens.append(next_token)
piece = tokenizer.decode([next_token], skip_special_tokens=False)
print(piece, end="", flush=True)
elapsed = time.time() - decode_start
print()
transcription = tokenizer.decode(tokens[2:], skip_special_tokens=True)
print(f"\nFinal transcription: {transcription}")
print(
f"Generated {len(tokens) - 2} tokens in {elapsed:.2f}s "
f"({(len(tokens) - 2) / max(elapsed, 1e-6):.1f} tok/s)"
)
if __name__ == "__main__":
main()

View File

@@ -68,6 +68,8 @@ impl<'a> Translator<'a> {
"torch.ops.aten.sigmoid.default" => self.translate_unary_op(node, |a| a.sigmoid())?,
"torch.ops.aten.relu.default" => self.translate_unary_op(node, |a| a.relu())?,
"torch.ops.aten.tanh.default" => self.translate_unary_op(node, |a| a.tanh())?,
"torch.ops.aten.silu.default" => self.translate_unary_op(node, |a| a.silu())?,
"torch.ops.aten.gelu.default" => self.translate_unary_op(node, |a| a.gelu())?,
"torch.ops.aten.abs.default" => self.translate_unary_op(node, |a| a.abs())?,
"torch.ops.aten.log.default" => self.translate_unary_op(node, |a| a.log())?,
"torch.ops.aten.log2.default" => self.translate_unary_op(node, |a| a.log2())?,

View File

@@ -0,0 +1,168 @@
"""Whisper integration tests for the luminal torch.compile backend.
These tests build a PyTorch port of ``openai/whisper-tiny.en`` (the same one
exercised by ``examples/whisper.py``) and verify that running it through
``torch.compile(..., backend=luminal_backend)`` produces logits that match the
eager-mode PyTorch reference, both with random-init small configs and with the
real pretrained tiny.en weights.
"""
from __future__ import annotations
import sys
from pathlib import Path
from typing import Callable
import pytest
import torch
import torch._dynamo
# Reuse the PyTorch port defined in the example script so we test exactly the
# code that runs the demo.
EXAMPLES_DIR = Path(__file__).resolve().parent.parent / "examples"
sys.path.insert(0, str(EXAMPLES_DIR))
import whisper as whisper_demo # noqa: E402 (path-modified import)
from luminal import luminal_backend # noqa: E402
def _make_small_whisper(seed: int = 0) -> whisper_demo.Whisper:
torch.manual_seed(seed)
model = whisper_demo.Whisper().eval()
return model
def _max_diff(a: torch.Tensor, b: torch.Tensor) -> float:
return torch.max(torch.abs(a - b)).item()
def test_whisper_attention_forward(device: torch.device):
"""Whisper self-attention: Q/K/V/out projections + scaled dot-product."""
torch.manual_seed(0)
attn = whisper_demo.WhisperAttention().eval().to(device)
compiled: Callable = torch.compile(attn, backend=luminal_backend)
x = torch.rand((4, whisper_demo.D_MODEL), device=device)
with torch.no_grad():
ref = attn(x)
out = compiled(x)
if isinstance(out, tuple):
out = out[0]
assert torch.allclose(out, ref, atol=1e-4), f"max_diff={_max_diff(out, ref):.2e}"
def test_whisper_encoder_layer(device: torch.device):
"""Single encoder block: pre-norm self-attention + FFN with GELU.
Tolerance is loose because luminal uses the tanh GELU approximation rather
than the exact erf form PyTorch uses for ``aten.gelu.default``.
"""
torch.manual_seed(0)
layer = whisper_demo.EncoderLayer().eval().to(device)
compiled: Callable = torch.compile(layer, backend=luminal_backend)
x = torch.rand((8, whisper_demo.D_MODEL), device=device)
with torch.no_grad():
ref = layer(x)
out = compiled(x)
if isinstance(out, tuple):
out = out[0]
assert torch.allclose(out, ref, atol=1e-3), f"max_diff={_max_diff(out, ref):.2e}"
def test_whisper_decoder_layer(device: torch.device):
"""Single decoder block: causal self-attention + cross-attention + FFN."""
torch.manual_seed(0)
layer = whisper_demo.DecoderLayer().eval().to(device)
compiled: Callable = torch.compile(layer, backend=luminal_backend)
x = torch.rand((4, whisper_demo.D_MODEL), device=device)
xa = torch.rand((16, whisper_demo.D_MODEL), device=device)
with torch.no_grad():
ref = layer(x, xa)
out = compiled(x, xa)
if isinstance(out, tuple):
out = out[0]
assert torch.allclose(out, ref, atol=1e-3), f"max_diff={_max_diff(out, ref):.2e}"
def test_whisper_encoder_random_init(device: torch.device):
"""Full encoder over a random mel: 2 conv stems + 4 transformer blocks."""
model = _make_small_whisper().to(device)
compiled: Callable = torch.compile(model.encoder, backend=luminal_backend)
mel = torch.rand((whisper_demo.N_MELS, 3000), device=device)
with torch.no_grad():
ref = model.encoder(mel)
out = compiled(mel)
if isinstance(out, tuple):
out = out[0]
assert torch.allclose(out, ref, atol=1e-3), f"max_diff={_max_diff(out, ref):.2e}"
def test_whisper_full_random_init_one_step(device: torch.device):
"""End-to-end Whisper forward (encoder + decoder for one step) with random weights.
Tolerance is loose because errors accumulate across the conv stems plus the
8 transformer blocks, and luminal uses the tanh GELU approximation rather
than the exact erf form that PyTorch ``aten.gelu.default`` evaluates.
"""
model = _make_small_whisper().to(device)
compiled: Callable = torch.compile(model, backend=luminal_backend)
mel = torch.rand((whisper_demo.N_MELS, 3000), device=device)
tokens = torch.tensor(
[whisper_demo.TOKEN_SOT, whisper_demo.TOKEN_NO_TIMESTAMPS],
dtype=torch.long,
device=device,
)
with torch.no_grad():
ref = model(mel, tokens)
out = compiled(mel, tokens)
if isinstance(out, tuple):
out = out[0]
assert torch.allclose(out, ref, atol=5e-2, rtol=1e-3), (
f"max_diff={_max_diff(out, ref):.2e}"
)
@pytest.mark.slow
def test_whisper_tiny_en_pretrained_first_token(device: torch.device):
"""Real whisper-tiny.en weights: first generated token must match reference.
Uses the bundled JFK sample if available; otherwise a zero-mel placeholder
(the assertion is purely compiled-vs-reference equality, not transcription
correctness).
"""
model = whisper_demo.Whisper().eval()
whisper_demo.load_hf_weights_into(model)
model = model.to(device)
# Try to use the real audio so the comparison is on a realistic mel.
audio_path = whisper_demo.find_default_audio()
if audio_path is None:
mel = torch.zeros((whisper_demo.N_MELS, 3000), device=device)
else:
from transformers import WhisperFeatureExtractor
audio = whisper_demo.load_wav_16k_mono(audio_path)
fe = WhisperFeatureExtractor.from_pretrained(whisper_demo.REPO_ID)
mel = fe(audio, sampling_rate=16000, return_tensors="pt").input_features[0].to(device)
tokens = torch.tensor(
[whisper_demo.TOKEN_SOT, whisper_demo.TOKEN_NO_TIMESTAMPS],
dtype=torch.long,
device=device,
)
torch._dynamo.reset()
compiled: Callable = torch.compile(model, backend=luminal_backend)
with torch.no_grad():
ref = model(mel, tokens)
out = compiled(mel, tokens)
if isinstance(out, tuple):
out = out[0]
# Logits diverge slightly due to the GELU approximation; what matters end
# to end is that the greedy argmax (with whisper's special-token suppression)
# picks the same token.
ref_tok = whisper_demo.greedy_decode(ref[-1], suppress_first_eot=True)
out_tok = whisper_demo.greedy_decode(out[-1], suppress_first_eot=True)
assert ref_tok == out_tok, (
f"first token mismatch: ref={ref_tok}, compiled={out_tok}, "
f"logits max_diff={_max_diff(out, ref):.2e}"
)

View File

@@ -0,0 +1,30 @@
[package]
name = "whisper"
version = "0.1.0"
edition = "2021"
[[bin]]
name = "whisper"
path = "src/main.rs"
[dependencies]
luminal = { path = "../.." }
luminal_nn = { path = "../../crates/luminal_nn" }
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
luminal_tracing = { path = "../../crates/luminal_tracing" }
tokenizers = "0.15.2"
tracing = "0.1.43"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
# HuggingFace model download
hf-hub = { version = "0.4", default-features = false, features = ["rustls-tls", "ureq"] }
safetensors = "0.7.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
half = { version = "2.7.1", features = ["bytemuck"] }
bytemuck = "1.24.0"
memmap2 = "0.9.9"
# Audio + signal processing
hound = "3.5"
rustfft = "6.2"

Binary file not shown.

View File

@@ -0,0 +1,236 @@
use rustfft::{num_complex::Complex32, FftPlanner};
use std::io::Cursor;
use std::path::Path;
pub const SAMPLE_RATE: usize = 16_000;
pub const N_FFT: usize = 400;
pub const HOP_LENGTH: usize = 160;
pub const N_MELS: usize = 80;
pub const N_SAMPLES: usize = 30 * SAMPLE_RATE; // 480_000
pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000
/// Read a 16-bit / 32-bit / float WAV file, downmix to mono and resample to 16 kHz.
pub fn load_wav<P: AsRef<Path>>(path: P) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
let reader = hound::WavReader::open(path)?;
decode_wav(reader)
}
pub fn load_wav_bytes(bytes: &[u8]) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
let reader = hound::WavReader::new(Cursor::new(bytes))?;
decode_wav(reader)
}
fn decode_wav<R: std::io::Read>(
mut reader: hound::WavReader<R>,
) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
let spec = reader.spec();
let channels = spec.channels as usize;
let samples: Vec<f32> = match spec.sample_format {
hound::SampleFormat::Int => {
let max = (1i64 << (spec.bits_per_sample - 1)) as f32;
reader
.samples::<i32>()
.map(|s| s.map(|v| v as f32 / max))
.collect::<Result<Vec<_>, _>>()?
}
hound::SampleFormat::Float => reader.samples::<f32>().collect::<Result<Vec<_>, _>>()?,
};
// Downmix to mono
let mono: Vec<f32> = if channels == 1 {
samples
} else {
samples
.chunks(channels)
.map(|c| c.iter().sum::<f32>() / channels as f32)
.collect()
};
// Resample to 16 kHz with simple linear interpolation if needed
if spec.sample_rate as usize == SAMPLE_RATE {
Ok(mono)
} else {
Ok(resample_linear(
&mono,
spec.sample_rate as usize,
SAMPLE_RATE,
))
}
}
fn resample_linear(input: &[f32], src_rate: usize, dst_rate: usize) -> Vec<f32> {
let ratio = src_rate as f64 / dst_rate as f64;
let out_len = ((input.len() as f64) / ratio).floor() as usize;
let mut out = Vec::with_capacity(out_len);
for i in 0..out_len {
let pos = i as f64 * ratio;
let lo = pos.floor() as usize;
let frac = (pos - lo as f64) as f32;
let a = input[lo.min(input.len() - 1)];
let b = input[(lo + 1).min(input.len() - 1)];
out.push(a * (1.0 - frac) + b * frac);
}
out
}
pub fn pad_or_trim(audio: &[f32], length: usize) -> Vec<f32> {
if audio.len() >= length {
audio[..length].to_vec()
} else {
let mut out = audio.to_vec();
out.resize(length, 0.0);
out
}
}
fn hz_to_mel_slaney(f: f32) -> f32 {
let f_sp = 200.0 / 3.0;
let min_log_hz = 1000.0_f32;
let min_log_mel = min_log_hz / f_sp;
let logstep = (6.4_f32.ln()) / 27.0;
if f >= min_log_hz {
min_log_mel + (f / min_log_hz).ln() / logstep
} else {
f / f_sp
}
}
fn mel_to_hz_slaney(m: f32) -> f32 {
let f_sp = 200.0 / 3.0;
let min_log_hz = 1000.0_f32;
let min_log_mel = min_log_hz / f_sp;
let logstep = (6.4_f32.ln()) / 27.0;
if m >= min_log_mel {
min_log_hz * (logstep * (m - min_log_mel)).exp()
} else {
f_sp * m
}
}
/// Slaney-style mel filterbank that matches `librosa.filters.mel(sr, n_fft, n_mels)`.
/// Returned shape: (n_mels, n_fft/2 + 1).
pub fn mel_filters(sr: usize, n_fft: usize, n_mels: usize) -> Vec<Vec<f32>> {
let n_freqs = n_fft / 2 + 1;
let fmin = 0.0_f32;
let fmax = sr as f32 / 2.0;
let fft_freqs: Vec<f32> = (0..n_freqs)
.map(|i| i as f32 * (sr as f32 / 2.0) / (n_freqs as f32 - 1.0))
.collect();
let mel_min = hz_to_mel_slaney(fmin);
let mel_max = hz_to_mel_slaney(fmax);
let mel_points: Vec<f32> = (0..n_mels + 2)
.map(|i| {
let m = mel_min + (mel_max - mel_min) * i as f32 / (n_mels + 1) as f32;
mel_to_hz_slaney(m)
})
.collect();
let fdiff: Vec<f32> = (0..n_mels + 1)
.map(|i| mel_points[i + 1] - mel_points[i])
.collect();
let mut weights = vec![vec![0.0_f32; n_freqs]; n_mels];
for i in 0..n_mels {
let enorm = 2.0 / (mel_points[i + 2] - mel_points[i]);
for j in 0..n_freqs {
let lower = (fft_freqs[j] - mel_points[i]) / fdiff[i];
let upper = (mel_points[i + 2] - fft_freqs[j]) / fdiff[i + 1];
let v = lower.min(upper).max(0.0);
weights[i][j] = v * enorm;
}
}
weights
}
fn hann_window(n: usize) -> Vec<f32> {
(0..n)
.map(|i| 0.5 - 0.5 * (2.0 * std::f32::consts::PI * i as f32 / n as f32).cos())
.collect()
}
/// Compute log-mel spectrogram with whisper's preprocessing:
/// - reflect-pad input by N_FFT/2 on each side (matches torch.stft center=True)
/// - hann window of size N_FFT
/// - hop = HOP_LENGTH
/// - drop the last frame (matches whisper's stft[..., :-1])
/// - magnitudes squared, project through mel filterbank, log10
/// - clamp at max - 8.0 then (x + 4) / 4
///
/// Output shape: (n_mels, n_frames) flattened row-major.
pub fn log_mel_spectrogram(audio: &[f32], n_mels: usize) -> Vec<f32> {
assert!(audio.len() == N_SAMPLES, "expected {} samples", N_SAMPLES);
let pad = N_FFT / 2;
let mut padded = vec![0.0_f32; audio.len() + 2 * pad];
// Reflect padding (without endpoint repetition, matching torch.nn.functional.pad reflect)
for i in 0..pad {
padded[pad - 1 - i] = audio[i + 1];
}
padded[pad..pad + audio.len()].copy_from_slice(audio);
let n = audio.len();
for i in 0..pad {
padded[pad + n + i] = audio[n - 2 - i];
}
let n_frames = (padded.len() - N_FFT) / HOP_LENGTH + 1;
debug_assert!(n_frames > N_FRAMES);
let window = hann_window(N_FFT);
let mut planner = FftPlanner::<f32>::new();
let fft = planner.plan_fft_forward(N_FFT);
let n_freqs = N_FFT / 2 + 1;
// magnitudes^2: (n_freqs, n_frames - 1) — drop the trailing frame at the end
let used_frames = n_frames - 1;
let mut magnitudes = vec![0.0_f32; n_freqs * used_frames];
let mut buffer: Vec<Complex32> = vec![Complex32::new(0.0, 0.0); N_FFT];
for f in 0..used_frames {
let start = f * HOP_LENGTH;
for i in 0..N_FFT {
buffer[i] = Complex32::new(padded[start + i] * window[i], 0.0);
}
fft.process(&mut buffer);
for k in 0..n_freqs {
let c = buffer[k];
magnitudes[k * used_frames + f] = c.norm_sqr();
}
}
// Apply mel filterbank: (n_mels, n_freqs) @ (n_freqs, used_frames) → (n_mels, used_frames)
let filters = mel_filters(SAMPLE_RATE, N_FFT, n_mels);
let mut log_spec = vec![0.0_f32; n_mels * used_frames];
for m in 0..n_mels {
for f in 0..used_frames {
let mut acc = 0.0_f32;
for k in 0..n_freqs {
acc += filters[m][k] * magnitudes[k * used_frames + f];
}
log_spec[m * used_frames + f] = acc;
}
}
// log10 with floor
for v in log_spec.iter_mut() {
*v = v.max(1e-10).log10();
}
// clamp at max - 8.0
let max_val = log_spec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let floor_val = max_val - 8.0;
for v in log_spec.iter_mut() {
if *v < floor_val {
*v = floor_val;
}
}
// (x + 4) / 4
for v in log_spec.iter_mut() {
*v = (*v + 4.0) / 4.0;
}
log_spec
}

View File

@@ -0,0 +1,15 @@
use hf_hub::api::sync::Api;
use std::path::PathBuf;
/// Downloads whisper model files (tokenizer.json + model.safetensors) from HuggingFace.
/// Returns the path of the cache directory containing both files.
pub fn prepare_hf_model(repo_id: &str) -> Result<PathBuf, Box<dyn std::error::Error>> {
let api = Api::new()?;
let repo = api.model(repo_id.to_string());
let tokenizer_path = repo.get("tokenizer.json")?;
let model_dir = tokenizer_path.parent().unwrap().to_path_buf();
repo.get("model.safetensors")?;
Ok(model_dir)
}

View File

@@ -0,0 +1,202 @@
mod audio;
mod hf;
mod model;
use audio::{
load_wav, load_wav_bytes, log_mel_spectrogram, pad_or_trim, N_FRAMES, N_MELS, N_SAMPLES,
};
use hf::prepare_hf_model;
use luminal::prelude::*;
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use model::*;
use std::{io::Write, time::Instant};
use tokenizers::Tokenizer;
const REPO_ID: &str = "openai/whisper-tiny.en";
/// Bundled JFK sample (16 kHz mono PCM WAV, ~11 seconds) so the example runs out of the box
/// without needing a local audio file.
const DEFAULT_AUDIO_BYTES: &[u8] = include_bytes!("../assets/jfk.wav");
fn env_usize(name: &str, default: usize) -> usize {
std::env::var(name)
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(default)
}
fn main() {
let max_target_pos = N_TEXT_CTX; // 448
let gen_tokens = env_usize("GEN_TOKENS", 200);
let search_graphs = env_usize("SEARCH_GRAPHS", 50);
let audio_path = std::env::args().nth(1);
let ctx = CudaContext::new(0).unwrap();
let stream = ctx.default_stream();
let model_dir = prepare_hf_model(REPO_ID).expect("Failed to prepare model");
println!("Using model directory: {}", model_dir.display());
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
let audio = match audio_path.as_deref() {
Some(path) => {
println!("Loading audio: {path}");
load_wav(path).expect("failed to load audio")
}
None => {
println!("Using bundled JFK sample audio");
load_wav_bytes(DEFAULT_AUDIO_BYTES).expect("failed to decode bundled audio")
}
};
let audio = pad_or_trim(&audio, N_SAMPLES);
println!("Computing log-mel spectrogram...");
let mel_data = log_mel_spectrogram(&audio, N_MELS);
assert_eq!(mel_data.len(), N_MELS * N_FRAMES);
// Build graph
let mut cx = Graph::default();
let mel_tensor = cx.named_tensor("mel", (N_MELS, N_FRAMES)).persist();
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
let pos_ids = cx.named_tensor("pos_ids", 's').as_dtype(DType::Int);
let kv_cache = KVCache::new(&mut cx, max_target_pos);
let whisper = Whisper::init(&mut cx);
let xa = whisper.encoder.forward(mel_tensor);
let (logits, cache_outputs) = whisper.decoder.forward(input, pos_ids, xa, &kv_cache);
let logits = logits.output();
for (k_out, v_out) in &cache_outputs {
k_out.output();
v_out.output();
}
println!("Building E-Graph...");
cx.build_search_space::<CudaRuntime>();
println!("Loading weights...");
let mut runtime = CudaRuntime::initialize(stream);
let weights_path = model_dir.join("model.safetensors");
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
let cache_bytes_per_layer =
N_TEXT_HEAD * max_target_pos * HEAD_DIM * std::mem::size_of::<f32>();
for i in 0..N_TEXT_LAYER {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes_per_layer);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes_per_layer);
}
// Set the mel spectrogram once.
runtime.set_data(mel_tensor, mel_data.clone());
println!("Compiling...");
cx.set_dim('s', 1);
cx.set_dim('p', 1);
runtime.set_data(input, vec![1i32]);
runtime.set_data(pos_ids, vec![1i32]);
runtime = cx.search(runtime, search_graphs);
// Reset the KV caches and re-set the mel after search (which executes test runs).
for i in 0..N_TEXT_LAYER {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes_per_layer);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes_per_layer);
}
runtime.set_data(mel_tensor, mel_data);
// -- Decoding loop --
// For tiny.en, decoder starts with [<|startoftranscript|>, <|notimestamps|>] and we process
// these prefill tokens one at a time (as the other luminal examples do), then sample
// greedily token-by-token.
let prompt: Vec<u32> = vec![TOKEN_SOT, TOKEN_NO_TIMESTAMPS];
let mut prev_seq = 0usize;
let mut next_input: i32 = prompt[0] as i32;
let mut prompt_idx = 1usize;
let mut generated: Vec<u32> = Vec::new();
let mut step = 0usize;
print!("Transcription:");
std::io::stdout().flush().unwrap();
let start = Instant::now();
loop {
cx.set_dim('s', 1);
cx.set_dim('p', prev_seq);
runtime.set_data(input, vec![next_input]);
runtime.set_data(pos_ids, vec![prev_seq as i32]);
runtime.execute(&cx.dyn_map);
let logits_data = runtime.get_f32(logits);
// Round-trip the KV caches
for (i, (k_out, v_out)) in cache_outputs.iter().enumerate() {
let k_buf = runtime.remove_buffer(*k_out);
let v_buf = runtime.remove_buffer(*v_out);
runtime.set_buffer(kv_cache.k_caches[i], k_buf);
runtime.set_buffer(kv_cache.v_caches[i], v_buf);
}
prev_seq += 1;
if prompt_idx < prompt.len() {
// Still feeding prefill tokens; ignore the logits.
next_input = prompt[prompt_idx] as i32;
prompt_idx += 1;
continue;
}
let last_row = logits_data[logits_data.len() - N_VOCAB..].to_vec();
let next_token = greedy_decode(&last_row, step == 0);
if next_token == TOKEN_EOT {
break;
}
if let Ok(decoded) = tokenizer.decode(&[next_token], false) {
print!("{decoded}");
std::io::stdout().flush().unwrap();
}
generated.push(next_token);
next_input = next_token as i32;
step += 1;
if step >= gen_tokens {
break;
}
if prev_seq >= max_target_pos - 1 {
break;
}
}
let elapsed = start.elapsed();
println!();
println!(
"Decoded {} tokens in {:.2}s ({:.1} tok/s)",
generated.len(),
elapsed.as_secs_f64(),
generated.len() as f64 / elapsed.as_secs_f64().max(1e-6),
);
}
/// Greedy argmax with whisper-style suppression of special tokens.
fn greedy_decode(logits: &[f32], is_first_step: bool) -> u32 {
debug_assert_eq!(logits.len(), N_VOCAB);
let mut best_idx = 0usize;
let mut best_val = f32::NEG_INFINITY;
for (i, &v) in logits.iter().enumerate() {
// Suppress all special / language / timestamp tokens (except <|endoftext|>).
// For tiny.en these live in the >= 50257 range. We allow only TOKEN_EOT.
if i as u32 != TOKEN_EOT && i >= TOKEN_SOT as usize {
continue;
}
// Suppress <|endoftext|> on the very first generated step to avoid empty output.
if is_first_step && i as u32 == TOKEN_EOT {
continue;
}
if v > best_val {
best_val = v;
best_idx = i;
}
}
best_idx as u32
}

View File

@@ -0,0 +1,492 @@
use luminal::{dtype::DType, graph::Graph, prelude::GraphTensor, shape::Expression};
use luminal_nn::LayerNorm;
// whisper-tiny.en hyperparameters
pub const N_MELS: usize = 80;
pub const N_AUDIO_CTX: usize = 1500;
pub const N_AUDIO_STATE: usize = 384;
pub const N_AUDIO_HEAD: usize = 6;
pub const N_AUDIO_LAYER: usize = 4;
pub const N_TEXT_CTX: usize = 448;
pub const N_TEXT_STATE: usize = 384;
pub const N_TEXT_HEAD: usize = 6;
pub const N_TEXT_LAYER: usize = 4;
pub const HEAD_DIM: usize = N_AUDIO_STATE / N_AUDIO_HEAD; // 64
pub const FF_DIM: usize = 4 * N_AUDIO_STATE; // 1536
pub const N_VOCAB: usize = 51864;
pub const LAYER_NORM_EPS: f32 = 1e-5;
pub const TOKEN_SOT: u32 = 50257; // <|startoftranscript|>
pub const TOKEN_NO_TIMESTAMPS: u32 = 50362;
pub const TOKEN_EOT: u32 = 50256;
fn linear_with_bias(x: GraphTensor, w: GraphTensor, b: GraphTensor) -> GraphTensor {
let out = x.matmul(w.t());
let prefix: Vec<Expression> = out.dims()[..out.dims().len() - 1].to_vec();
out + b.expand_lhs(prefix)
}
fn linear_no_bias(x: GraphTensor, w: GraphTensor) -> GraphTensor {
x.matmul(w.t())
}
/// 1D convolution with bias. Input: (ch_in, length). Weight: (ch_out, ch_in*kernel)
/// (HF stores it as (ch_out, ch_in, kernel) which flat-loads identically). Output: (ch_out, out_length).
fn conv1d_bias(
x: GraphTensor,
weight: GraphTensor,
bias: GraphTensor,
kernel: usize,
stride: usize,
padding: usize,
) -> GraphTensor {
let padded = x.pad(
vec![
(Expression::from(0), Expression::from(0)),
(Expression::from(padding), Expression::from(padding)),
],
0.0,
);
let unfolded = padded.unfold([1usize, kernel], [1usize, stride], [1usize, 1usize]);
// unfolded: (ch_in, n_windows, 1, kernel)
let unfolded = unfolded.squeeze(2);
// (ch_in, n_windows, kernel) -> (n_windows, ch_in, kernel) -> (n_windows, ch_in*kernel)
let permuted = unfolded.permute((1, 0, 2));
let flat = permuted.merge_dims(1, 2);
// (n_windows, ch_in*kernel) @ (ch_in*kernel, ch_out) -> (n_windows, ch_out)
let out = flat.matmul(weight.t());
let n_windows = out.dims()[0];
let bias_expanded = bias.expand_dim(0, n_windows);
let out = out + bias_expanded;
// (n_windows, ch_out) -> (ch_out, n_windows)
out.transpose(0, 1)
}
/// Standard LayerNorm with mean-norm, std-norm, weight and bias (matches torch.nn.LayerNorm).
fn standard_layernorm(name: &str, dim: usize, cx: &mut Graph) -> LayerNorm {
LayerNorm::new(
dim,
Some(&format!("{name}.weight")),
Some(&format!("{name}.bias")),
true,
LAYER_NORM_EPS,
cx,
)
}
struct AttentionWeights {
q_proj: GraphTensor,
q_bias: GraphTensor,
k_proj: GraphTensor,
v_proj: GraphTensor,
v_bias: GraphTensor,
out_proj: GraphTensor,
out_bias: GraphTensor,
}
impl AttentionWeights {
fn new(prefix: &str, dim: usize, cx: &mut Graph) -> Self {
Self {
q_proj: cx
.named_tensor(format!("{prefix}.q_proj.weight"), (dim, dim))
.persist(),
q_bias: cx
.named_tensor(format!("{prefix}.q_proj.bias"), dim)
.persist(),
k_proj: cx
.named_tensor(format!("{prefix}.k_proj.weight"), (dim, dim))
.persist(),
v_proj: cx
.named_tensor(format!("{prefix}.v_proj.weight"), (dim, dim))
.persist(),
v_bias: cx
.named_tensor(format!("{prefix}.v_proj.bias"), dim)
.persist(),
out_proj: cx
.named_tensor(format!("{prefix}.out_proj.weight"), (dim, dim))
.persist(),
out_bias: cx
.named_tensor(format!("{prefix}.out_proj.bias"), dim)
.persist(),
}
}
}
fn split_heads(x: GraphTensor) -> GraphTensor {
// (seq, dim) -> (n_heads, seq, head_dim)
x.split_dims(1, HEAD_DIM).transpose(0, 1)
}
fn merge_heads(x: GraphTensor) -> GraphTensor {
// (n_heads, seq, head_dim) -> (seq, n_heads, head_dim) -> (seq, dim)
x.transpose(0, 1).merge_dims(1, 2)
}
/// Encoder self-attention (full, non-causal). Input/output shape (seq, dim).
fn encoder_self_attention(x: GraphTensor, w: &AttentionWeights) -> GraphTensor {
let q = linear_with_bias(x, w.q_proj, w.q_bias);
let k = linear_no_bias(x, w.k_proj);
let v = linear_with_bias(x, w.v_proj, w.v_bias);
let q = split_heads(q);
let k = split_heads(k);
let v = split_heads(v);
let scale = (HEAD_DIM as f32).sqrt().recip();
let scores = q.matmul(k.transpose(1, 2)) * scale;
let weights = scores.softmax(2);
let attn = weights.matmul(v);
let merged = merge_heads(attn);
linear_with_bias(merged, w.out_proj, w.out_bias)
}
/// Decoder self-attention with KV cache. Returns (out, k_cache_out, v_cache_out).
fn decoder_self_attention(
x: GraphTensor,
w: &AttentionWeights,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let cx = x.graph();
let seq = x.dims()[0];
let prev = Expression::from('p');
let total = prev + seq;
let q = linear_with_bias(x, w.q_proj, w.q_bias);
let k = linear_no_bias(x, w.k_proj);
let v = linear_with_bias(x, w.v_proj, w.v_bias);
let k_new = split_heads(k); // (n_heads, seq, head_dim)
let v_new = split_heads(v);
// Build flat scatter indices to write new K/V into the cache at positions [prev..prev+seq).
let h_offset = cx.arange(N_TEXT_HEAD) * (max_seq * HEAD_DIM);
let p_offset = (cx.arange(seq) + prev) * HEAD_DIM;
let d_offset = cx.arange(HEAD_DIM);
let scatter_idx = h_offset.expand_dim(1, seq).expand_dim(2, HEAD_DIM)
+ p_offset.expand_dim(0, N_TEXT_HEAD).expand_dim(2, HEAD_DIM)
+ d_offset.expand_dim(0, N_TEXT_HEAD).expand_dim(1, seq);
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
let k_full = k_cache_out.slice((.., ..total, ..));
let v_full = v_cache_out.slice((.., ..total, ..));
let q = split_heads(q);
let scale = (HEAD_DIM as f32).sqrt().recip();
let scores = q.matmul(k_full.transpose(1, 2)) * scale;
// Causal mask
let q_abs = cx.arange(seq).cast(DType::F32) + prev;
let k_pos = cx.arange(total).cast(DType::F32);
let mask = k_pos.expand_dim(0, seq).gt(q_abs.expand_dim(1, total));
let mask_3d = mask.cast(DType::F32).expand_dim(0, N_TEXT_HEAD);
let masked = scores + mask_3d * (-1e10f32);
let weights = masked.softmax(2);
let attn = weights.matmul(v_full);
let merged = merge_heads(attn);
let out = linear_with_bias(merged, w.out_proj, w.out_bias);
(out, k_cache_out, v_cache_out)
}
/// Cross-attention: query from decoder, key/value from encoder output `xa`.
fn cross_attention(x: GraphTensor, xa: GraphTensor, w: &AttentionWeights) -> GraphTensor {
let q = linear_with_bias(x, w.q_proj, w.q_bias);
let k = linear_no_bias(xa, w.k_proj);
let v = linear_with_bias(xa, w.v_proj, w.v_bias);
let q = split_heads(q);
let k = split_heads(k);
let v = split_heads(v);
let scale = (HEAD_DIM as f32).sqrt().recip();
let scores = q.matmul(k.transpose(1, 2)) * scale;
let weights = scores.softmax(2);
let attn = weights.matmul(v);
let merged = merge_heads(attn);
linear_with_bias(merged, w.out_proj, w.out_bias)
}
struct EncoderLayer {
self_attn: AttentionWeights,
self_attn_ln: LayerNorm,
fc1: GraphTensor,
fc1_b: GraphTensor,
fc2: GraphTensor,
fc2_b: GraphTensor,
final_ln: LayerNorm,
}
impl EncoderLayer {
fn new(idx: usize, cx: &mut Graph) -> Self {
let prefix = format!("model.encoder.layers.{idx}");
Self {
self_attn: AttentionWeights::new(&format!("{prefix}.self_attn"), N_AUDIO_STATE, cx),
self_attn_ln: standard_layernorm(
&format!("{prefix}.self_attn_layer_norm"),
N_AUDIO_STATE,
cx,
),
fc1: cx
.named_tensor(format!("{prefix}.fc1.weight"), (FF_DIM, N_AUDIO_STATE))
.persist(),
fc1_b: cx
.named_tensor(format!("{prefix}.fc1.bias"), FF_DIM)
.persist(),
fc2: cx
.named_tensor(format!("{prefix}.fc2.weight"), (N_AUDIO_STATE, FF_DIM))
.persist(),
fc2_b: cx
.named_tensor(format!("{prefix}.fc2.bias"), N_AUDIO_STATE)
.persist(),
final_ln: standard_layernorm(&format!("{prefix}.final_layer_norm"), N_AUDIO_STATE, cx),
}
}
fn forward(&self, x: GraphTensor) -> GraphTensor {
let h = self.self_attn_ln.forward(x);
let h = encoder_self_attention(h, &self.self_attn);
let x = x + h;
let h = self.final_ln.forward(x);
let h = linear_with_bias(h, self.fc1, self.fc1_b).gelu();
let h = linear_with_bias(h, self.fc2, self.fc2_b);
x + h
}
}
struct DecoderLayer {
self_attn: AttentionWeights,
self_attn_ln: LayerNorm,
cross_attn: AttentionWeights,
cross_attn_ln: LayerNorm,
fc1: GraphTensor,
fc1_b: GraphTensor,
fc2: GraphTensor,
fc2_b: GraphTensor,
final_ln: LayerNorm,
}
impl DecoderLayer {
fn new(idx: usize, cx: &mut Graph) -> Self {
let prefix = format!("model.decoder.layers.{idx}");
Self {
self_attn: AttentionWeights::new(&format!("{prefix}.self_attn"), N_TEXT_STATE, cx),
self_attn_ln: standard_layernorm(
&format!("{prefix}.self_attn_layer_norm"),
N_TEXT_STATE,
cx,
),
cross_attn: AttentionWeights::new(&format!("{prefix}.encoder_attn"), N_TEXT_STATE, cx),
cross_attn_ln: standard_layernorm(
&format!("{prefix}.encoder_attn_layer_norm"),
N_TEXT_STATE,
cx,
),
fc1: cx
.named_tensor(format!("{prefix}.fc1.weight"), (FF_DIM, N_TEXT_STATE))
.persist(),
fc1_b: cx
.named_tensor(format!("{prefix}.fc1.bias"), FF_DIM)
.persist(),
fc2: cx
.named_tensor(format!("{prefix}.fc2.weight"), (N_TEXT_STATE, FF_DIM))
.persist(),
fc2_b: cx
.named_tensor(format!("{prefix}.fc2.bias"), N_TEXT_STATE)
.persist(),
final_ln: standard_layernorm(&format!("{prefix}.final_layer_norm"), N_TEXT_STATE, cx),
}
}
fn forward(
&self,
x: GraphTensor,
xa: GraphTensor,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let h = self.self_attn_ln.forward(x);
let (h, k_out, v_out) =
decoder_self_attention(h, &self.self_attn, k_cache_in, v_cache_in, max_seq);
let x = x + h;
let h = self.cross_attn_ln.forward(x);
let h = cross_attention(h, xa, &self.cross_attn);
let x = x + h;
let h = self.final_ln.forward(x);
let h = linear_with_bias(h, self.fc1, self.fc1_b).gelu();
let h = linear_with_bias(h, self.fc2, self.fc2_b);
(x + h, k_out, v_out)
}
}
pub struct KVCache {
pub k_caches: Vec<GraphTensor>,
pub v_caches: Vec<GraphTensor>,
pub max_seq: usize,
}
impl KVCache {
pub fn new(cx: &mut Graph, max_seq: usize) -> Self {
let mut k_caches = Vec::with_capacity(N_TEXT_LAYER);
let mut v_caches = Vec::with_capacity(N_TEXT_LAYER);
for l in 0..N_TEXT_LAYER {
let k = cx
.named_tensor(format!("kv_cache.{l}.k"), (N_TEXT_HEAD, max_seq, HEAD_DIM))
.persist();
let v = cx
.named_tensor(format!("kv_cache.{l}.v"), (N_TEXT_HEAD, max_seq, HEAD_DIM))
.persist();
k_caches.push(k);
v_caches.push(v);
}
Self {
k_caches,
v_caches,
max_seq,
}
}
}
pub struct WhisperEncoder {
conv1_w: GraphTensor,
conv1_b: GraphTensor,
conv2_w: GraphTensor,
conv2_b: GraphTensor,
positional_embedding: GraphTensor,
layers: Vec<EncoderLayer>,
layer_norm: LayerNorm,
}
impl WhisperEncoder {
pub fn init(cx: &mut Graph) -> Self {
Self {
conv1_w: cx
.named_tensor("model.encoder.conv1.weight", (N_AUDIO_STATE, N_MELS * 3))
.persist(),
conv1_b: cx
.named_tensor("model.encoder.conv1.bias", N_AUDIO_STATE)
.persist(),
conv2_w: cx
.named_tensor(
"model.encoder.conv2.weight",
(N_AUDIO_STATE, N_AUDIO_STATE * 3),
)
.persist(),
conv2_b: cx
.named_tensor("model.encoder.conv2.bias", N_AUDIO_STATE)
.persist(),
positional_embedding: cx
.named_tensor(
"model.encoder.embed_positions.weight",
(N_AUDIO_CTX, N_AUDIO_STATE),
)
.persist(),
layers: (0..N_AUDIO_LAYER)
.map(|i| EncoderLayer::new(i, cx))
.collect(),
layer_norm: standard_layernorm("model.encoder.layer_norm", N_AUDIO_STATE, cx),
}
}
/// Input mel spectrogram: (N_MELS, 3000). Output: (N_AUDIO_CTX=1500, N_AUDIO_STATE).
pub fn forward(&self, mel: GraphTensor) -> GraphTensor {
let h = conv1d_bias(mel, self.conv1_w, self.conv1_b, 3, 1, 1).gelu();
let h = conv1d_bias(h, self.conv2_w, self.conv2_b, 3, 2, 1).gelu();
// h: (N_AUDIO_STATE, N_AUDIO_CTX) -> (N_AUDIO_CTX, N_AUDIO_STATE)
let mut x = h.transpose(0, 1) + self.positional_embedding;
for layer in &self.layers {
x = layer.forward(x);
}
self.layer_norm.forward(x)
}
}
pub struct WhisperDecoder {
embed_tokens: GraphTensor,
embed_positions: GraphTensor,
layers: Vec<DecoderLayer>,
layer_norm: LayerNorm,
}
impl WhisperDecoder {
pub fn init(cx: &mut Graph) -> Self {
Self {
embed_tokens: cx
.named_tensor("model.decoder.embed_tokens.weight", (N_VOCAB, N_TEXT_STATE))
.persist(),
embed_positions: cx
.named_tensor(
"model.decoder.embed_positions.weight",
(N_TEXT_CTX, N_TEXT_STATE),
)
.persist(),
layers: (0..N_TEXT_LAYER)
.map(|i| DecoderLayer::new(i, cx))
.collect(),
layer_norm: standard_layernorm("model.decoder.layer_norm", N_TEXT_STATE, cx),
}
}
pub fn forward(
&self,
token_ids: GraphTensor,
pos_ids: GraphTensor,
xa: GraphTensor,
kv_cache: &KVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = token_ids.dims1();
// Token embedding gather
let mut x = self.embed_tokens.gather(
(token_ids * N_TEXT_STATE).expand_dim(1, N_TEXT_STATE)
+ token_ids.graph().arange(N_TEXT_STATE).expand_dim(0, seq),
);
// Positional embedding gather (using pos_ids)
let pos_emb = self.embed_positions.gather(
(pos_ids * N_TEXT_STATE).expand_dim(1, N_TEXT_STATE)
+ pos_ids.graph().arange(N_TEXT_STATE).expand_dim(0, seq),
);
x += pos_emb;
let mut cache_outputs = Vec::with_capacity(N_TEXT_LAYER);
for (i, layer) in self.layers.iter().enumerate() {
let (x_new, k_out, v_out) = layer.forward(
x,
xa,
kv_cache.k_caches[i],
kv_cache.v_caches[i],
kv_cache.max_seq,
);
x = x_new;
cache_outputs.push((k_out, v_out));
}
let x = self.layer_norm.forward(x);
// Tied embeddings: projection to vocab
let logits = x.matmul(self.embed_tokens.t());
(logits, cache_outputs)
}
}
pub struct Whisper {
pub encoder: WhisperEncoder,
pub decoder: WhisperDecoder,
}
impl Whisper {
pub fn init(cx: &mut Graph) -> Self {
Self {
encoder: WhisperEncoder::init(cx),
decoder: WhisperDecoder::init(cx),
}
}
}