mirror of
https://git.teahaven.kr/Rust-related/luminal.git
synced 2026-06-04 16:49:49 +09:00
whisper example
This commit is contained in:
2
.github/workflows/modal-examples.yml
vendored
2
.github/workflows/modal-examples.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
example: [llama, gemma, qwen, qwen3_moe, gemma4_moe]
|
||||
example: [llama, gemma, qwen, qwen3_moe, gemma4_moe, whisper]
|
||||
gpu:
|
||||
- { type: "A100-80GB" }
|
||||
# To add more GPUs, just append another entry:
|
||||
|
||||
@@ -40,6 +40,9 @@ EXPECTED_OUTPUT = {
|
||||
"gemma4_moe": [
|
||||
"city of romance, art and culture",
|
||||
],
|
||||
"whisper": [
|
||||
"ask not what your country can do for you",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -782,3 +782,14 @@ identical across all attempts (dtype issue) vs varying (actual numerical issue).
|
||||
3. **Why "defensive fallback" framing is misleading**: it implies the LLIR is broken. It isn't. The forward-walk-only `body_nodes` definition just doesn't cover this case, because the case requires no per-iter cloning at all. A *node not reachable from any loop input marker has no input-marker ancestor*, so by construction its value doesn't depend on the loop's per-iter state.
|
||||
4. **Cleaner formulation**: name the concept. Compute an `iteration_invariant_slots: HashSet<LoopStart>` set at the same time `start_meta` is built, with the rule `body_producer ∉ body_nodes ⇒ iteration_invariant`. `resolve_src` and `marker_post_sub` then have explicit branches: if the slot is invariant, use `body_producer` directly; otherwise the standard per-iter clone lookup. The behavior is the same as the `unwrap_or` band-aid, but the code now documents that this is a real, sound case the unroll handles correctly — not a panic suppressor.
|
||||
5. **Principle**: when an `unwrap_or` papers over a case that turns out to be semantically valid, the right cleanup isn't to keep the `unwrap_or` and add a comment — it's to name the case. Hoist the predicate into a set or enum and branch on it explicitly. The compiler then enforces that every consumer of the per-iter cloning machinery has an opinion on iteration-invariant slots, instead of silently relying on a `Map::get` returning `None` at the right moment.
|
||||
|
||||
## 2026-05-02 — Whisper port hit two missing-translator pitfalls
|
||||
|
||||
1. **Symptom**: Compiling a PyTorch port of Whisper-tiny.en through `luminal_backend` failed twice in a row at the dispatch table: first with `Unsupported ATen op: torch.ops.aten.gelu.default`, then with `full: unsupported fill value type ... -Infinity`.
|
||||
2. **Root cause #1**: the dispatch table in `crates/luminal_python/rust/src/translator/dispatch.rs` mapped `sigmoid`, `tanh`, `relu` etc. but not `gelu` or `silu`. Whisper's encoder uses `F.gelu`, so the activation hit a hole.
|
||||
3. **Root cause #2**: PyTorch serializes `float("-inf")` in PT2 as the string `"-Infinity"` (and `"NaN"`/`"Infinity"` analogously). `translate_full`'s `get_float_arg` only accepts numeric float/int payloads, so any `torch.full((..), -inf)` (the obvious way to write a causal mask) blows up. Decoder mask code is the most common spot.
|
||||
4. **Why it was tricky**: both errors arrive from inside `pt2_backend` with a stack trace that ends in `process_pt2`, hiding the actual ATen target inside the message. You only see the offending op name in the error string itself, so you have to read `RuntimeError: Failed to translate node N: …` carefully and grep `dispatch.rs` for it.
|
||||
5. **Fix in this session**:
|
||||
- Added `aten.gelu.default → a.gelu()` and `aten.silu.default → a.silu()` to `dispatch.rs`.
|
||||
- Worked around the `-Infinity` issue at the model level by using a finite `-1e10` for the causal mask in the example (matches the Rust example's convention). The cleaner fix (parsing `"-Infinity"`/`"Infinity"`/`"NaN"` strings in `get_float_arg` / `translate_full`) is left for a follow-up.
|
||||
6. **Principle**: when adding a new model that goes through the PT2 backend, expect to plug small holes in `dispatch.rs` and `translator/tensor.rs::translate_full`. The trace points at the python frame, not the Rust dispatch arm — open `dispatch.rs`, ctrl-F the offending op name, and add the one-liner. For float-shaped sentinel values (`-inf`, `inf`, `nan`), the export pipeline currently only accepts finite floats; either rewrite the model or extend the parser.
|
||||
|
||||
429
crates/luminal_python/examples/whisper.py
Normal file
429
crates/luminal_python/examples/whisper.py
Normal file
@@ -0,0 +1,429 @@
|
||||
"""Whisper transcription demo using the luminal torch.compile backend.
|
||||
|
||||
Implements a small PyTorch port of ``openai/whisper-tiny.en`` that mirrors the
|
||||
luminal Rust example (``examples/whisper`` in the workspace), loads the official
|
||||
HuggingFace weights, and runs greedy decoding through the luminal backend via
|
||||
``torch.compile``.
|
||||
|
||||
Usage::
|
||||
|
||||
uv run python examples/whisper.py [path/to/audio.wav]
|
||||
|
||||
If no path is provided, falls back to the JFK sample bundled with the Rust
|
||||
``examples/whisper`` crate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch._dynamo
|
||||
import torch.nn.functional as F
|
||||
from transformers import WhisperFeatureExtractor, WhisperForConditionalGeneration, WhisperTokenizer
|
||||
|
||||
from luminal.pt2 import compile as luminal_compile
|
||||
|
||||
REPO_ID = "openai/whisper-tiny.en"
|
||||
|
||||
# whisper-tiny.en hyperparameters
|
||||
N_MELS = 80
|
||||
N_AUDIO_CTX = 1500
|
||||
D_MODEL = 384
|
||||
N_HEADS = 6
|
||||
HEAD_DIM = D_MODEL // N_HEADS
|
||||
N_AUDIO_LAYER = 4
|
||||
N_TEXT_LAYER = 4
|
||||
N_TEXT_CTX = 448
|
||||
FF_DIM = 4 * D_MODEL
|
||||
N_VOCAB = 51864
|
||||
LAYER_NORM_EPS = 1e-5
|
||||
|
||||
# Decoder special tokens
|
||||
TOKEN_SOT = 50257
|
||||
TOKEN_NO_TIMESTAMPS = 50362
|
||||
TOKEN_EOT = 50256
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model — mirrors the HLIR encoder/decoder in examples/whisper/src/model.rs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class WhisperAttention(torch.nn.Module):
|
||||
"""Multi-head attention with separate q/k/v projections (no bias on k_proj)."""
|
||||
|
||||
def __init__(self, d_model: int = D_MODEL, n_heads: int = N_HEADS):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
self.head_dim = d_model // n_heads
|
||||
self.q_proj = torch.nn.Linear(d_model, d_model, bias=True)
|
||||
self.k_proj = torch.nn.Linear(d_model, d_model, bias=False)
|
||||
self.v_proj = torch.nn.Linear(d_model, d_model, bias=True)
|
||||
self.out_proj = torch.nn.Linear(d_model, d_model, bias=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
kv_input: Optional[torch.Tensor] = None,
|
||||
causal: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# x: (seq, d_model). kv_input is None → self-attn; otherwise cross-attn.
|
||||
kv = x if kv_input is None else kv_input
|
||||
q = self.q_proj(x)
|
||||
k = self.k_proj(kv)
|
||||
v = self.v_proj(kv)
|
||||
|
||||
seq_q = q.shape[0]
|
||||
seq_kv = k.shape[0]
|
||||
|
||||
# (seq, d_model) -> (n_heads, seq, head_dim)
|
||||
q = q.reshape(seq_q, self.n_heads, self.head_dim).transpose(0, 1)
|
||||
k = k.reshape(seq_kv, self.n_heads, self.head_dim).transpose(0, 1)
|
||||
v = v.reshape(seq_kv, self.n_heads, self.head_dim).transpose(0, 1)
|
||||
|
||||
scale = 1.0 / (self.head_dim ** 0.5)
|
||||
scores = torch.matmul(q, k.transpose(-2, -1)) * scale # (h, sq, sk)
|
||||
if causal:
|
||||
# Use a large finite negative instead of -inf so the export pipeline
|
||||
# serializes a float instead of the unsupported "-Infinity" sentinel.
|
||||
mask = torch.triu(
|
||||
torch.full((seq_q, seq_kv), -1e10, device=x.device),
|
||||
diagonal=1,
|
||||
)
|
||||
scores = scores + mask
|
||||
weights = torch.softmax(scores, dim=-1)
|
||||
attn = torch.matmul(weights, v) # (h, sq, hd)
|
||||
merged = attn.transpose(0, 1).reshape(seq_q, -1)
|
||||
return self.out_proj(merged)
|
||||
|
||||
|
||||
class EncoderLayer(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.self_attn = WhisperAttention()
|
||||
self.self_attn_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
self.fc1 = torch.nn.Linear(D_MODEL, FF_DIM, bias=True)
|
||||
self.fc2 = torch.nn.Linear(FF_DIM, D_MODEL, bias=True)
|
||||
self.final_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.self_attn(self.self_attn_layer_norm(x))
|
||||
h = self.final_layer_norm(x)
|
||||
h = F.gelu(self.fc1(h))
|
||||
h = self.fc2(h)
|
||||
return x + h
|
||||
|
||||
|
||||
class WhisperEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv1d(N_MELS, D_MODEL, kernel_size=3, padding=1, bias=True)
|
||||
self.conv2 = torch.nn.Conv1d(D_MODEL, D_MODEL, kernel_size=3, stride=2, padding=1, bias=True)
|
||||
# Position embedding stored as a regular parameter (matches HF layout).
|
||||
self.embed_positions = torch.nn.Embedding(N_AUDIO_CTX, D_MODEL)
|
||||
self.layers = torch.nn.ModuleList([EncoderLayer() for _ in range(N_AUDIO_LAYER)])
|
||||
self.layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
|
||||
def forward(self, mel: torch.Tensor) -> torch.Tensor:
|
||||
# mel: (n_mels, 3000) -> add batch dim for conv1d
|
||||
x = mel.unsqueeze(0)
|
||||
x = F.gelu(self.conv1(x))
|
||||
x = F.gelu(self.conv2(x))
|
||||
# (1, d_model, 1500) -> (1500, d_model)
|
||||
x = x.squeeze(0).transpose(0, 1)
|
||||
x = x + self.embed_positions.weight
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return self.layer_norm(x)
|
||||
|
||||
|
||||
class DecoderLayer(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.self_attn = WhisperAttention()
|
||||
self.self_attn_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
self.encoder_attn = WhisperAttention()
|
||||
self.encoder_attn_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
self.fc1 = torch.nn.Linear(D_MODEL, FF_DIM, bias=True)
|
||||
self.fc2 = torch.nn.Linear(FF_DIM, D_MODEL, bias=True)
|
||||
self.final_layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
|
||||
def forward(self, x: torch.Tensor, xa: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.self_attn(self.self_attn_layer_norm(x), causal=True)
|
||||
x = x + self.encoder_attn(self.encoder_attn_layer_norm(x), kv_input=xa)
|
||||
h = self.final_layer_norm(x)
|
||||
h = F.gelu(self.fc1(h))
|
||||
h = self.fc2(h)
|
||||
return x + h
|
||||
|
||||
|
||||
class WhisperDecoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.embed_tokens = torch.nn.Embedding(N_VOCAB, D_MODEL)
|
||||
self.embed_positions = torch.nn.Embedding(N_TEXT_CTX, D_MODEL)
|
||||
self.layers = torch.nn.ModuleList([DecoderLayer() for _ in range(N_TEXT_LAYER)])
|
||||
self.layer_norm = torch.nn.LayerNorm(D_MODEL, eps=LAYER_NORM_EPS)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, xa: torch.Tensor) -> torch.Tensor:
|
||||
# tokens: (seq,) of int64 — absolute positions are 0..seq-1
|
||||
seq = tokens.shape[0]
|
||||
pos = torch.arange(seq, dtype=torch.long, device=tokens.device)
|
||||
x = self.embed_tokens(tokens) + self.embed_positions(pos)
|
||||
for layer in self.layers:
|
||||
x = layer(x, xa)
|
||||
x = self.layer_norm(x)
|
||||
# Tied projection
|
||||
return torch.matmul(x, self.embed_tokens.weight.transpose(0, 1))
|
||||
|
||||
|
||||
class Whisper(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.encoder = WhisperEncoder()
|
||||
self.decoder = WhisperDecoder()
|
||||
|
||||
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
|
||||
xa = self.encoder(mel)
|
||||
return self.decoder(tokens, xa)
|
||||
|
||||
|
||||
class DecoderWithFixedXa(torch.nn.Module):
|
||||
"""Wraps the decoder with the encoder output stored as a buffer.
|
||||
|
||||
The audio is fixed for the whole utterance, so ``xa`` is a constant relative
|
||||
to the per-token decode loop. Storing it as a buffer lets us compile the
|
||||
decoder once with a single dynamic-length ``tokens`` input, avoiding a full
|
||||
recompilation at every step as the sequence grows.
|
||||
"""
|
||||
|
||||
def __init__(self, decoder: WhisperDecoder, xa: torch.Tensor):
|
||||
super().__init__()
|
||||
self.decoder = decoder
|
||||
self.register_buffer("xa", xa)
|
||||
|
||||
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||
return self.decoder(tokens, self.xa)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Weight loading: HF state_dict -> our model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def load_hf_weights_into(model: Whisper) -> None:
|
||||
"""Copy HF whisper-tiny.en weights into our matching modules."""
|
||||
hf = WhisperForConditionalGeneration.from_pretrained(REPO_ID).eval()
|
||||
sd = hf.state_dict()
|
||||
|
||||
def get(name: str) -> torch.Tensor:
|
||||
return sd[f"model.{name}"].clone()
|
||||
|
||||
enc = model.encoder
|
||||
enc.conv1.weight.data.copy_(get("encoder.conv1.weight"))
|
||||
enc.conv1.bias.data.copy_(get("encoder.conv1.bias"))
|
||||
enc.conv2.weight.data.copy_(get("encoder.conv2.weight"))
|
||||
enc.conv2.bias.data.copy_(get("encoder.conv2.bias"))
|
||||
enc.embed_positions.weight.data.copy_(get("encoder.embed_positions.weight"))
|
||||
enc.layer_norm.weight.data.copy_(get("encoder.layer_norm.weight"))
|
||||
enc.layer_norm.bias.data.copy_(get("encoder.layer_norm.bias"))
|
||||
for i, layer in enumerate(enc.layers):
|
||||
prefix = f"encoder.layers.{i}"
|
||||
layer.self_attn.q_proj.weight.data.copy_(get(f"{prefix}.self_attn.q_proj.weight"))
|
||||
layer.self_attn.q_proj.bias.data.copy_(get(f"{prefix}.self_attn.q_proj.bias"))
|
||||
layer.self_attn.k_proj.weight.data.copy_(get(f"{prefix}.self_attn.k_proj.weight"))
|
||||
layer.self_attn.v_proj.weight.data.copy_(get(f"{prefix}.self_attn.v_proj.weight"))
|
||||
layer.self_attn.v_proj.bias.data.copy_(get(f"{prefix}.self_attn.v_proj.bias"))
|
||||
layer.self_attn.out_proj.weight.data.copy_(get(f"{prefix}.self_attn.out_proj.weight"))
|
||||
layer.self_attn.out_proj.bias.data.copy_(get(f"{prefix}.self_attn.out_proj.bias"))
|
||||
layer.self_attn_layer_norm.weight.data.copy_(get(f"{prefix}.self_attn_layer_norm.weight"))
|
||||
layer.self_attn_layer_norm.bias.data.copy_(get(f"{prefix}.self_attn_layer_norm.bias"))
|
||||
layer.fc1.weight.data.copy_(get(f"{prefix}.fc1.weight"))
|
||||
layer.fc1.bias.data.copy_(get(f"{prefix}.fc1.bias"))
|
||||
layer.fc2.weight.data.copy_(get(f"{prefix}.fc2.weight"))
|
||||
layer.fc2.bias.data.copy_(get(f"{prefix}.fc2.bias"))
|
||||
layer.final_layer_norm.weight.data.copy_(get(f"{prefix}.final_layer_norm.weight"))
|
||||
layer.final_layer_norm.bias.data.copy_(get(f"{prefix}.final_layer_norm.bias"))
|
||||
|
||||
dec = model.decoder
|
||||
dec.embed_tokens.weight.data.copy_(get("decoder.embed_tokens.weight"))
|
||||
dec.embed_positions.weight.data.copy_(get("decoder.embed_positions.weight"))
|
||||
dec.layer_norm.weight.data.copy_(get("decoder.layer_norm.weight"))
|
||||
dec.layer_norm.bias.data.copy_(get("decoder.layer_norm.bias"))
|
||||
for i, layer in enumerate(dec.layers):
|
||||
prefix = f"decoder.layers.{i}"
|
||||
layer.self_attn.q_proj.weight.data.copy_(get(f"{prefix}.self_attn.q_proj.weight"))
|
||||
layer.self_attn.q_proj.bias.data.copy_(get(f"{prefix}.self_attn.q_proj.bias"))
|
||||
layer.self_attn.k_proj.weight.data.copy_(get(f"{prefix}.self_attn.k_proj.weight"))
|
||||
layer.self_attn.v_proj.weight.data.copy_(get(f"{prefix}.self_attn.v_proj.weight"))
|
||||
layer.self_attn.v_proj.bias.data.copy_(get(f"{prefix}.self_attn.v_proj.bias"))
|
||||
layer.self_attn.out_proj.weight.data.copy_(get(f"{prefix}.self_attn.out_proj.weight"))
|
||||
layer.self_attn.out_proj.bias.data.copy_(get(f"{prefix}.self_attn.out_proj.bias"))
|
||||
layer.self_attn_layer_norm.weight.data.copy_(get(f"{prefix}.self_attn_layer_norm.weight"))
|
||||
layer.self_attn_layer_norm.bias.data.copy_(get(f"{prefix}.self_attn_layer_norm.bias"))
|
||||
layer.encoder_attn.q_proj.weight.data.copy_(get(f"{prefix}.encoder_attn.q_proj.weight"))
|
||||
layer.encoder_attn.q_proj.bias.data.copy_(get(f"{prefix}.encoder_attn.q_proj.bias"))
|
||||
layer.encoder_attn.k_proj.weight.data.copy_(get(f"{prefix}.encoder_attn.k_proj.weight"))
|
||||
layer.encoder_attn.v_proj.weight.data.copy_(get(f"{prefix}.encoder_attn.v_proj.weight"))
|
||||
layer.encoder_attn.v_proj.bias.data.copy_(get(f"{prefix}.encoder_attn.v_proj.bias"))
|
||||
layer.encoder_attn.out_proj.weight.data.copy_(get(f"{prefix}.encoder_attn.out_proj.weight"))
|
||||
layer.encoder_attn.out_proj.bias.data.copy_(get(f"{prefix}.encoder_attn.out_proj.bias"))
|
||||
layer.encoder_attn_layer_norm.weight.data.copy_(get(f"{prefix}.encoder_attn_layer_norm.weight"))
|
||||
layer.encoder_attn_layer_norm.bias.data.copy_(get(f"{prefix}.encoder_attn_layer_norm.bias"))
|
||||
layer.fc1.weight.data.copy_(get(f"{prefix}.fc1.weight"))
|
||||
layer.fc1.bias.data.copy_(get(f"{prefix}.fc1.bias"))
|
||||
layer.fc2.weight.data.copy_(get(f"{prefix}.fc2.weight"))
|
||||
layer.fc2.bias.data.copy_(get(f"{prefix}.fc2.bias"))
|
||||
layer.final_layer_norm.weight.data.copy_(get(f"{prefix}.final_layer_norm.weight"))
|
||||
layer.final_layer_norm.bias.data.copy_(get(f"{prefix}.final_layer_norm.bias"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Audio loading + decoding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def load_wav_16k_mono(path: Path) -> np.ndarray:
|
||||
with wave.open(str(path), "rb") as w:
|
||||
sr = w.getframerate()
|
||||
n = w.getnframes()
|
||||
ch = w.getnchannels()
|
||||
sw = w.getsampwidth()
|
||||
raw = w.readframes(n)
|
||||
|
||||
if sw == 2:
|
||||
samples = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
elif sw == 4:
|
||||
samples = np.frombuffer(raw, dtype=np.int32).astype(np.float32) / 2147483648.0
|
||||
elif sw == 1:
|
||||
samples = (np.frombuffer(raw, dtype=np.uint8).astype(np.float32) - 128.0) / 128.0
|
||||
else:
|
||||
raise ValueError(f"unsupported sample width {sw}")
|
||||
|
||||
if ch > 1:
|
||||
samples = samples.reshape(-1, ch).mean(axis=1)
|
||||
|
||||
if sr != 16000:
|
||||
ratio = sr / 16000
|
||||
out_len = int(len(samples) / ratio)
|
||||
idx = np.arange(out_len, dtype=np.float64) * ratio
|
||||
lo = idx.astype(np.int64)
|
||||
frac = (idx - lo).astype(np.float32)
|
||||
hi = np.clip(lo + 1, 0, len(samples) - 1)
|
||||
samples = samples[lo] * (1.0 - frac) + samples[hi] * frac
|
||||
|
||||
return samples.astype(np.float32)
|
||||
|
||||
|
||||
def greedy_decode(logits_row: torch.Tensor, suppress_first_eot: bool) -> int:
|
||||
masked = logits_row.clone()
|
||||
masked[TOKEN_SOT:] = float("-inf")
|
||||
if suppress_first_eot:
|
||||
masked[TOKEN_EOT] = float("-inf")
|
||||
return int(torch.argmax(masked).item())
|
||||
|
||||
|
||||
def find_default_audio() -> Optional[Path]:
|
||||
here = Path(__file__).resolve()
|
||||
workspace_root = here.parents[3]
|
||||
candidate = workspace_root / "examples" / "whisper" / "assets" / "jfk.wav"
|
||||
return candidate if candidate.exists() else None
|
||||
|
||||
|
||||
def main() -> None:
|
||||
audio_arg = sys.argv[1] if len(sys.argv) > 1 else None
|
||||
if audio_arg:
|
||||
audio_path = Path(audio_arg)
|
||||
else:
|
||||
audio_path = find_default_audio()
|
||||
if audio_path is None:
|
||||
print("error: no audio file given and bundled jfk.wav not found", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
print("Loading audio:", audio_path)
|
||||
audio = load_wav_16k_mono(audio_path)
|
||||
|
||||
print("Computing log-mel features...")
|
||||
feature_extractor = WhisperFeatureExtractor.from_pretrained(REPO_ID)
|
||||
features = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
|
||||
mel: torch.Tensor = features.input_features[0].to(device) # (80, 3000)
|
||||
assert mel.shape == (N_MELS, 3000), mel.shape
|
||||
|
||||
print("Building model and loading weights...")
|
||||
model = Whisper().eval().to(device)
|
||||
load_hf_weights_into(model)
|
||||
model = model.to(device)
|
||||
tokenizer = WhisperTokenizer.from_pretrained(REPO_ID)
|
||||
|
||||
use_compiled = os.environ.get("LUMINAL_DISABLE", "0") != "1"
|
||||
max_new_tokens = int(os.environ.get("GEN_TOKENS", "100"))
|
||||
search_iters = int(os.environ.get("SEARCH_ITERATIONS", "10"))
|
||||
|
||||
if use_compiled:
|
||||
# 1. Run the encoder once eagerly. The audio doesn't change during decode,
|
||||
# so xa is a constant input to the decoder.
|
||||
with torch.no_grad():
|
||||
xa = model.encoder(mel)
|
||||
|
||||
# 2. Wrap the decoder so its only varying input is `tokens`, then compile
|
||||
# once with a dynamic length dim. Subsequent calls reuse the same
|
||||
# compiled graph — no recompile per token.
|
||||
decoder_only = DecoderWithFixedXa(model.decoder, xa).eval().to(device)
|
||||
example_tokens = torch.tensor(
|
||||
[TOKEN_SOT, TOKEN_NO_TIMESTAMPS], dtype=torch.long, device=device
|
||||
)
|
||||
print(f"Compiling decoder with dynamic seq dim (search_iters={search_iters})...")
|
||||
compile_start = time.time()
|
||||
compiled_decoder = luminal_compile(
|
||||
decoder_only,
|
||||
example_tokens,
|
||||
search_iterations=search_iters,
|
||||
dynamic_dim=0,
|
||||
)
|
||||
print(f"Compiled in {time.time() - compile_start:.1f}s")
|
||||
|
||||
def step_logits(decoder_input_ids: torch.Tensor) -> torch.Tensor:
|
||||
out = compiled_decoder(decoder_input_ids)
|
||||
return out[0] if isinstance(out, tuple) else out
|
||||
else:
|
||||
def step_logits(decoder_input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return model(mel, decoder_input_ids)
|
||||
|
||||
tokens = [TOKEN_SOT, TOKEN_NO_TIMESTAMPS]
|
||||
|
||||
print("Transcribing", end="", flush=True)
|
||||
decode_start = time.time()
|
||||
for step in range(max_new_tokens):
|
||||
decoder_input_ids = torch.tensor(tokens, dtype=torch.long, device=device)
|
||||
with torch.no_grad():
|
||||
logits = step_logits(decoder_input_ids)
|
||||
|
||||
next_token = greedy_decode(logits[-1], suppress_first_eot=(step == 0))
|
||||
if next_token == TOKEN_EOT:
|
||||
break
|
||||
tokens.append(next_token)
|
||||
piece = tokenizer.decode([next_token], skip_special_tokens=False)
|
||||
print(piece, end="", flush=True)
|
||||
elapsed = time.time() - decode_start
|
||||
print()
|
||||
|
||||
transcription = tokenizer.decode(tokens[2:], skip_special_tokens=True)
|
||||
print(f"\nFinal transcription: {transcription}")
|
||||
print(
|
||||
f"Generated {len(tokens) - 2} tokens in {elapsed:.2f}s "
|
||||
f"({(len(tokens) - 2) / max(elapsed, 1e-6):.1f} tok/s)"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -68,6 +68,8 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.sigmoid.default" => self.translate_unary_op(node, |a| a.sigmoid())?,
|
||||
"torch.ops.aten.relu.default" => self.translate_unary_op(node, |a| a.relu())?,
|
||||
"torch.ops.aten.tanh.default" => self.translate_unary_op(node, |a| a.tanh())?,
|
||||
"torch.ops.aten.silu.default" => self.translate_unary_op(node, |a| a.silu())?,
|
||||
"torch.ops.aten.gelu.default" => self.translate_unary_op(node, |a| a.gelu())?,
|
||||
"torch.ops.aten.abs.default" => self.translate_unary_op(node, |a| a.abs())?,
|
||||
"torch.ops.aten.log.default" => self.translate_unary_op(node, |a| a.log())?,
|
||||
"torch.ops.aten.log2.default" => self.translate_unary_op(node, |a| a.log2())?,
|
||||
|
||||
168
crates/luminal_python/tests/test_whisper.py
Normal file
168
crates/luminal_python/tests/test_whisper.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Whisper integration tests for the luminal torch.compile backend.
|
||||
|
||||
These tests build a PyTorch port of ``openai/whisper-tiny.en`` (the same one
|
||||
exercised by ``examples/whisper.py``) and verify that running it through
|
||||
``torch.compile(..., backend=luminal_backend)`` produces logits that match the
|
||||
eager-mode PyTorch reference, both with random-init small configs and with the
|
||||
real pretrained tiny.en weights.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
|
||||
# Reuse the PyTorch port defined in the example script so we test exactly the
|
||||
# code that runs the demo.
|
||||
EXAMPLES_DIR = Path(__file__).resolve().parent.parent / "examples"
|
||||
sys.path.insert(0, str(EXAMPLES_DIR))
|
||||
import whisper as whisper_demo # noqa: E402 (path-modified import)
|
||||
|
||||
from luminal import luminal_backend # noqa: E402
|
||||
|
||||
|
||||
def _make_small_whisper(seed: int = 0) -> whisper_demo.Whisper:
|
||||
torch.manual_seed(seed)
|
||||
model = whisper_demo.Whisper().eval()
|
||||
return model
|
||||
|
||||
|
||||
def _max_diff(a: torch.Tensor, b: torch.Tensor) -> float:
|
||||
return torch.max(torch.abs(a - b)).item()
|
||||
|
||||
|
||||
def test_whisper_attention_forward(device: torch.device):
|
||||
"""Whisper self-attention: Q/K/V/out projections + scaled dot-product."""
|
||||
torch.manual_seed(0)
|
||||
attn = whisper_demo.WhisperAttention().eval().to(device)
|
||||
compiled: Callable = torch.compile(attn, backend=luminal_backend)
|
||||
x = torch.rand((4, whisper_demo.D_MODEL), device=device)
|
||||
with torch.no_grad():
|
||||
ref = attn(x)
|
||||
out = compiled(x)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
assert torch.allclose(out, ref, atol=1e-4), f"max_diff={_max_diff(out, ref):.2e}"
|
||||
|
||||
|
||||
def test_whisper_encoder_layer(device: torch.device):
|
||||
"""Single encoder block: pre-norm self-attention + FFN with GELU.
|
||||
|
||||
Tolerance is loose because luminal uses the tanh GELU approximation rather
|
||||
than the exact erf form PyTorch uses for ``aten.gelu.default``.
|
||||
"""
|
||||
torch.manual_seed(0)
|
||||
layer = whisper_demo.EncoderLayer().eval().to(device)
|
||||
compiled: Callable = torch.compile(layer, backend=luminal_backend)
|
||||
x = torch.rand((8, whisper_demo.D_MODEL), device=device)
|
||||
with torch.no_grad():
|
||||
ref = layer(x)
|
||||
out = compiled(x)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
assert torch.allclose(out, ref, atol=1e-3), f"max_diff={_max_diff(out, ref):.2e}"
|
||||
|
||||
|
||||
def test_whisper_decoder_layer(device: torch.device):
|
||||
"""Single decoder block: causal self-attention + cross-attention + FFN."""
|
||||
torch.manual_seed(0)
|
||||
layer = whisper_demo.DecoderLayer().eval().to(device)
|
||||
compiled: Callable = torch.compile(layer, backend=luminal_backend)
|
||||
x = torch.rand((4, whisper_demo.D_MODEL), device=device)
|
||||
xa = torch.rand((16, whisper_demo.D_MODEL), device=device)
|
||||
with torch.no_grad():
|
||||
ref = layer(x, xa)
|
||||
out = compiled(x, xa)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
assert torch.allclose(out, ref, atol=1e-3), f"max_diff={_max_diff(out, ref):.2e}"
|
||||
|
||||
|
||||
def test_whisper_encoder_random_init(device: torch.device):
|
||||
"""Full encoder over a random mel: 2 conv stems + 4 transformer blocks."""
|
||||
model = _make_small_whisper().to(device)
|
||||
compiled: Callable = torch.compile(model.encoder, backend=luminal_backend)
|
||||
mel = torch.rand((whisper_demo.N_MELS, 3000), device=device)
|
||||
with torch.no_grad():
|
||||
ref = model.encoder(mel)
|
||||
out = compiled(mel)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
assert torch.allclose(out, ref, atol=1e-3), f"max_diff={_max_diff(out, ref):.2e}"
|
||||
|
||||
|
||||
def test_whisper_full_random_init_one_step(device: torch.device):
|
||||
"""End-to-end Whisper forward (encoder + decoder for one step) with random weights.
|
||||
|
||||
Tolerance is loose because errors accumulate across the conv stems plus the
|
||||
8 transformer blocks, and luminal uses the tanh GELU approximation rather
|
||||
than the exact erf form that PyTorch ``aten.gelu.default`` evaluates.
|
||||
"""
|
||||
model = _make_small_whisper().to(device)
|
||||
compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
mel = torch.rand((whisper_demo.N_MELS, 3000), device=device)
|
||||
tokens = torch.tensor(
|
||||
[whisper_demo.TOKEN_SOT, whisper_demo.TOKEN_NO_TIMESTAMPS],
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
with torch.no_grad():
|
||||
ref = model(mel, tokens)
|
||||
out = compiled(mel, tokens)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
assert torch.allclose(out, ref, atol=5e-2, rtol=1e-3), (
|
||||
f"max_diff={_max_diff(out, ref):.2e}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_whisper_tiny_en_pretrained_first_token(device: torch.device):
|
||||
"""Real whisper-tiny.en weights: first generated token must match reference.
|
||||
|
||||
Uses the bundled JFK sample if available; otherwise a zero-mel placeholder
|
||||
(the assertion is purely compiled-vs-reference equality, not transcription
|
||||
correctness).
|
||||
"""
|
||||
model = whisper_demo.Whisper().eval()
|
||||
whisper_demo.load_hf_weights_into(model)
|
||||
model = model.to(device)
|
||||
|
||||
# Try to use the real audio so the comparison is on a realistic mel.
|
||||
audio_path = whisper_demo.find_default_audio()
|
||||
if audio_path is None:
|
||||
mel = torch.zeros((whisper_demo.N_MELS, 3000), device=device)
|
||||
else:
|
||||
from transformers import WhisperFeatureExtractor
|
||||
|
||||
audio = whisper_demo.load_wav_16k_mono(audio_path)
|
||||
fe = WhisperFeatureExtractor.from_pretrained(whisper_demo.REPO_ID)
|
||||
mel = fe(audio, sampling_rate=16000, return_tensors="pt").input_features[0].to(device)
|
||||
|
||||
tokens = torch.tensor(
|
||||
[whisper_demo.TOKEN_SOT, whisper_demo.TOKEN_NO_TIMESTAMPS],
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
|
||||
torch._dynamo.reset()
|
||||
compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
ref = model(mel, tokens)
|
||||
out = compiled(mel, tokens)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
# Logits diverge slightly due to the GELU approximation; what matters end
|
||||
# to end is that the greedy argmax (with whisper's special-token suppression)
|
||||
# picks the same token.
|
||||
ref_tok = whisper_demo.greedy_decode(ref[-1], suppress_first_eot=True)
|
||||
out_tok = whisper_demo.greedy_decode(out[-1], suppress_first_eot=True)
|
||||
assert ref_tok == out_tok, (
|
||||
f"first token mismatch: ref={ref_tok}, compiled={out_tok}, "
|
||||
f"logits max_diff={_max_diff(out, ref):.2e}"
|
||||
)
|
||||
30
examples/whisper/Cargo.toml
Normal file
30
examples/whisper/Cargo.toml
Normal file
@@ -0,0 +1,30 @@
|
||||
[package]
|
||||
name = "whisper"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[[bin]]
|
||||
name = "whisper"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
luminal_tracing = { path = "../../crates/luminal_tracing" }
|
||||
tokenizers = "0.15.2"
|
||||
tracing = "0.1.43"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
||||
# HuggingFace model download
|
||||
hf-hub = { version = "0.4", default-features = false, features = ["rustls-tls", "ureq"] }
|
||||
safetensors = "0.7.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
half = { version = "2.7.1", features = ["bytemuck"] }
|
||||
bytemuck = "1.24.0"
|
||||
memmap2 = "0.9.9"
|
||||
|
||||
# Audio + signal processing
|
||||
hound = "3.5"
|
||||
rustfft = "6.2"
|
||||
BIN
examples/whisper/assets/jfk.wav
Normal file
BIN
examples/whisper/assets/jfk.wav
Normal file
Binary file not shown.
236
examples/whisper/src/audio.rs
Normal file
236
examples/whisper/src/audio.rs
Normal file
@@ -0,0 +1,236 @@
|
||||
use rustfft::{num_complex::Complex32, FftPlanner};
|
||||
use std::io::Cursor;
|
||||
use std::path::Path;
|
||||
|
||||
pub const SAMPLE_RATE: usize = 16_000;
|
||||
pub const N_FFT: usize = 400;
|
||||
pub const HOP_LENGTH: usize = 160;
|
||||
pub const N_MELS: usize = 80;
|
||||
pub const N_SAMPLES: usize = 30 * SAMPLE_RATE; // 480_000
|
||||
pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000
|
||||
|
||||
/// Read a 16-bit / 32-bit / float WAV file, downmix to mono and resample to 16 kHz.
|
||||
pub fn load_wav<P: AsRef<Path>>(path: P) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
|
||||
let reader = hound::WavReader::open(path)?;
|
||||
decode_wav(reader)
|
||||
}
|
||||
|
||||
pub fn load_wav_bytes(bytes: &[u8]) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
|
||||
let reader = hound::WavReader::new(Cursor::new(bytes))?;
|
||||
decode_wav(reader)
|
||||
}
|
||||
|
||||
fn decode_wav<R: std::io::Read>(
|
||||
mut reader: hound::WavReader<R>,
|
||||
) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
|
||||
let spec = reader.spec();
|
||||
let channels = spec.channels as usize;
|
||||
|
||||
let samples: Vec<f32> = match spec.sample_format {
|
||||
hound::SampleFormat::Int => {
|
||||
let max = (1i64 << (spec.bits_per_sample - 1)) as f32;
|
||||
reader
|
||||
.samples::<i32>()
|
||||
.map(|s| s.map(|v| v as f32 / max))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
}
|
||||
hound::SampleFormat::Float => reader.samples::<f32>().collect::<Result<Vec<_>, _>>()?,
|
||||
};
|
||||
|
||||
// Downmix to mono
|
||||
let mono: Vec<f32> = if channels == 1 {
|
||||
samples
|
||||
} else {
|
||||
samples
|
||||
.chunks(channels)
|
||||
.map(|c| c.iter().sum::<f32>() / channels as f32)
|
||||
.collect()
|
||||
};
|
||||
|
||||
// Resample to 16 kHz with simple linear interpolation if needed
|
||||
if spec.sample_rate as usize == SAMPLE_RATE {
|
||||
Ok(mono)
|
||||
} else {
|
||||
Ok(resample_linear(
|
||||
&mono,
|
||||
spec.sample_rate as usize,
|
||||
SAMPLE_RATE,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn resample_linear(input: &[f32], src_rate: usize, dst_rate: usize) -> Vec<f32> {
|
||||
let ratio = src_rate as f64 / dst_rate as f64;
|
||||
let out_len = ((input.len() as f64) / ratio).floor() as usize;
|
||||
let mut out = Vec::with_capacity(out_len);
|
||||
for i in 0..out_len {
|
||||
let pos = i as f64 * ratio;
|
||||
let lo = pos.floor() as usize;
|
||||
let frac = (pos - lo as f64) as f32;
|
||||
let a = input[lo.min(input.len() - 1)];
|
||||
let b = input[(lo + 1).min(input.len() - 1)];
|
||||
out.push(a * (1.0 - frac) + b * frac);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
pub fn pad_or_trim(audio: &[f32], length: usize) -> Vec<f32> {
|
||||
if audio.len() >= length {
|
||||
audio[..length].to_vec()
|
||||
} else {
|
||||
let mut out = audio.to_vec();
|
||||
out.resize(length, 0.0);
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
fn hz_to_mel_slaney(f: f32) -> f32 {
|
||||
let f_sp = 200.0 / 3.0;
|
||||
let min_log_hz = 1000.0_f32;
|
||||
let min_log_mel = min_log_hz / f_sp;
|
||||
let logstep = (6.4_f32.ln()) / 27.0;
|
||||
if f >= min_log_hz {
|
||||
min_log_mel + (f / min_log_hz).ln() / logstep
|
||||
} else {
|
||||
f / f_sp
|
||||
}
|
||||
}
|
||||
|
||||
fn mel_to_hz_slaney(m: f32) -> f32 {
|
||||
let f_sp = 200.0 / 3.0;
|
||||
let min_log_hz = 1000.0_f32;
|
||||
let min_log_mel = min_log_hz / f_sp;
|
||||
let logstep = (6.4_f32.ln()) / 27.0;
|
||||
if m >= min_log_mel {
|
||||
min_log_hz * (logstep * (m - min_log_mel)).exp()
|
||||
} else {
|
||||
f_sp * m
|
||||
}
|
||||
}
|
||||
|
||||
/// Slaney-style mel filterbank that matches `librosa.filters.mel(sr, n_fft, n_mels)`.
|
||||
/// Returned shape: (n_mels, n_fft/2 + 1).
|
||||
pub fn mel_filters(sr: usize, n_fft: usize, n_mels: usize) -> Vec<Vec<f32>> {
|
||||
let n_freqs = n_fft / 2 + 1;
|
||||
let fmin = 0.0_f32;
|
||||
let fmax = sr as f32 / 2.0;
|
||||
|
||||
let fft_freqs: Vec<f32> = (0..n_freqs)
|
||||
.map(|i| i as f32 * (sr as f32 / 2.0) / (n_freqs as f32 - 1.0))
|
||||
.collect();
|
||||
|
||||
let mel_min = hz_to_mel_slaney(fmin);
|
||||
let mel_max = hz_to_mel_slaney(fmax);
|
||||
let mel_points: Vec<f32> = (0..n_mels + 2)
|
||||
.map(|i| {
|
||||
let m = mel_min + (mel_max - mel_min) * i as f32 / (n_mels + 1) as f32;
|
||||
mel_to_hz_slaney(m)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let fdiff: Vec<f32> = (0..n_mels + 1)
|
||||
.map(|i| mel_points[i + 1] - mel_points[i])
|
||||
.collect();
|
||||
|
||||
let mut weights = vec![vec![0.0_f32; n_freqs]; n_mels];
|
||||
for i in 0..n_mels {
|
||||
let enorm = 2.0 / (mel_points[i + 2] - mel_points[i]);
|
||||
for j in 0..n_freqs {
|
||||
let lower = (fft_freqs[j] - mel_points[i]) / fdiff[i];
|
||||
let upper = (mel_points[i + 2] - fft_freqs[j]) / fdiff[i + 1];
|
||||
let v = lower.min(upper).max(0.0);
|
||||
weights[i][j] = v * enorm;
|
||||
}
|
||||
}
|
||||
weights
|
||||
}
|
||||
|
||||
fn hann_window(n: usize) -> Vec<f32> {
|
||||
(0..n)
|
||||
.map(|i| 0.5 - 0.5 * (2.0 * std::f32::consts::PI * i as f32 / n as f32).cos())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute log-mel spectrogram with whisper's preprocessing:
|
||||
/// - reflect-pad input by N_FFT/2 on each side (matches torch.stft center=True)
|
||||
/// - hann window of size N_FFT
|
||||
/// - hop = HOP_LENGTH
|
||||
/// - drop the last frame (matches whisper's stft[..., :-1])
|
||||
/// - magnitudes squared, project through mel filterbank, log10
|
||||
/// - clamp at max - 8.0 then (x + 4) / 4
|
||||
///
|
||||
/// Output shape: (n_mels, n_frames) flattened row-major.
|
||||
pub fn log_mel_spectrogram(audio: &[f32], n_mels: usize) -> Vec<f32> {
|
||||
assert!(audio.len() == N_SAMPLES, "expected {} samples", N_SAMPLES);
|
||||
|
||||
let pad = N_FFT / 2;
|
||||
let mut padded = vec![0.0_f32; audio.len() + 2 * pad];
|
||||
// Reflect padding (without endpoint repetition, matching torch.nn.functional.pad reflect)
|
||||
for i in 0..pad {
|
||||
padded[pad - 1 - i] = audio[i + 1];
|
||||
}
|
||||
padded[pad..pad + audio.len()].copy_from_slice(audio);
|
||||
let n = audio.len();
|
||||
for i in 0..pad {
|
||||
padded[pad + n + i] = audio[n - 2 - i];
|
||||
}
|
||||
|
||||
let n_frames = (padded.len() - N_FFT) / HOP_LENGTH + 1;
|
||||
debug_assert!(n_frames > N_FRAMES);
|
||||
|
||||
let window = hann_window(N_FFT);
|
||||
|
||||
let mut planner = FftPlanner::<f32>::new();
|
||||
let fft = planner.plan_fft_forward(N_FFT);
|
||||
let n_freqs = N_FFT / 2 + 1;
|
||||
|
||||
// magnitudes^2: (n_freqs, n_frames - 1) — drop the trailing frame at the end
|
||||
let used_frames = n_frames - 1;
|
||||
let mut magnitudes = vec![0.0_f32; n_freqs * used_frames];
|
||||
let mut buffer: Vec<Complex32> = vec![Complex32::new(0.0, 0.0); N_FFT];
|
||||
|
||||
for f in 0..used_frames {
|
||||
let start = f * HOP_LENGTH;
|
||||
for i in 0..N_FFT {
|
||||
buffer[i] = Complex32::new(padded[start + i] * window[i], 0.0);
|
||||
}
|
||||
fft.process(&mut buffer);
|
||||
for k in 0..n_freqs {
|
||||
let c = buffer[k];
|
||||
magnitudes[k * used_frames + f] = c.norm_sqr();
|
||||
}
|
||||
}
|
||||
|
||||
// Apply mel filterbank: (n_mels, n_freqs) @ (n_freqs, used_frames) → (n_mels, used_frames)
|
||||
let filters = mel_filters(SAMPLE_RATE, N_FFT, n_mels);
|
||||
|
||||
let mut log_spec = vec![0.0_f32; n_mels * used_frames];
|
||||
for m in 0..n_mels {
|
||||
for f in 0..used_frames {
|
||||
let mut acc = 0.0_f32;
|
||||
for k in 0..n_freqs {
|
||||
acc += filters[m][k] * magnitudes[k * used_frames + f];
|
||||
}
|
||||
log_spec[m * used_frames + f] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
// log10 with floor
|
||||
for v in log_spec.iter_mut() {
|
||||
*v = v.max(1e-10).log10();
|
||||
}
|
||||
// clamp at max - 8.0
|
||||
let max_val = log_spec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let floor_val = max_val - 8.0;
|
||||
for v in log_spec.iter_mut() {
|
||||
if *v < floor_val {
|
||||
*v = floor_val;
|
||||
}
|
||||
}
|
||||
// (x + 4) / 4
|
||||
for v in log_spec.iter_mut() {
|
||||
*v = (*v + 4.0) / 4.0;
|
||||
}
|
||||
|
||||
log_spec
|
||||
}
|
||||
15
examples/whisper/src/hf.rs
Normal file
15
examples/whisper/src/hf.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
use hf_hub::api::sync::Api;
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Downloads whisper model files (tokenizer.json + model.safetensors) from HuggingFace.
|
||||
/// Returns the path of the cache directory containing both files.
|
||||
pub fn prepare_hf_model(repo_id: &str) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
let api = Api::new()?;
|
||||
let repo = api.model(repo_id.to_string());
|
||||
|
||||
let tokenizer_path = repo.get("tokenizer.json")?;
|
||||
let model_dir = tokenizer_path.parent().unwrap().to_path_buf();
|
||||
|
||||
repo.get("model.safetensors")?;
|
||||
Ok(model_dir)
|
||||
}
|
||||
202
examples/whisper/src/main.rs
Normal file
202
examples/whisper/src/main.rs
Normal file
@@ -0,0 +1,202 @@
|
||||
mod audio;
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
use audio::{
|
||||
load_wav, load_wav_bytes, log_mel_spectrogram, pad_or_trim, N_FRAMES, N_MELS, N_SAMPLES,
|
||||
};
|
||||
use hf::prepare_hf_model;
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use model::*;
|
||||
use std::{io::Write, time::Instant};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const REPO_ID: &str = "openai/whisper-tiny.en";
|
||||
|
||||
/// Bundled JFK sample (16 kHz mono PCM WAV, ~11 seconds) so the example runs out of the box
|
||||
/// without needing a local audio file.
|
||||
const DEFAULT_AUDIO_BYTES: &[u8] = include_bytes!("../assets/jfk.wav");
|
||||
|
||||
fn env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let max_target_pos = N_TEXT_CTX; // 448
|
||||
let gen_tokens = env_usize("GEN_TOKENS", 200);
|
||||
let search_graphs = env_usize("SEARCH_GRAPHS", 50);
|
||||
let audio_path = std::env::args().nth(1);
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
let model_dir = prepare_hf_model(REPO_ID).expect("Failed to prepare model");
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
|
||||
let audio = match audio_path.as_deref() {
|
||||
Some(path) => {
|
||||
println!("Loading audio: {path}");
|
||||
load_wav(path).expect("failed to load audio")
|
||||
}
|
||||
None => {
|
||||
println!("Using bundled JFK sample audio");
|
||||
load_wav_bytes(DEFAULT_AUDIO_BYTES).expect("failed to decode bundled audio")
|
||||
}
|
||||
};
|
||||
let audio = pad_or_trim(&audio, N_SAMPLES);
|
||||
println!("Computing log-mel spectrogram...");
|
||||
let mel_data = log_mel_spectrogram(&audio, N_MELS);
|
||||
assert_eq!(mel_data.len(), N_MELS * N_FRAMES);
|
||||
|
||||
// Build graph
|
||||
let mut cx = Graph::default();
|
||||
let mel_tensor = cx.named_tensor("mel", (N_MELS, N_FRAMES)).persist();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
let pos_ids = cx.named_tensor("pos_ids", 's').as_dtype(DType::Int);
|
||||
let kv_cache = KVCache::new(&mut cx, max_target_pos);
|
||||
|
||||
let whisper = Whisper::init(&mut cx);
|
||||
let xa = whisper.encoder.forward(mel_tensor);
|
||||
let (logits, cache_outputs) = whisper.decoder.forward(input, pos_ids, xa, &kv_cache);
|
||||
|
||||
let logits = logits.output();
|
||||
for (k_out, v_out) in &cache_outputs {
|
||||
k_out.output();
|
||||
v_out.output();
|
||||
}
|
||||
|
||||
println!("Building E-Graph...");
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
println!("Loading weights...");
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
let weights_path = model_dir.join("model.safetensors");
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
let cache_bytes_per_layer =
|
||||
N_TEXT_HEAD * max_target_pos * HEAD_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..N_TEXT_LAYER {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes_per_layer);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes_per_layer);
|
||||
}
|
||||
|
||||
// Set the mel spectrogram once.
|
||||
runtime.set_data(mel_tensor, mel_data.clone());
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', 1);
|
||||
runtime.set_data(input, vec![1i32]);
|
||||
runtime.set_data(pos_ids, vec![1i32]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
|
||||
// Reset the KV caches and re-set the mel after search (which executes test runs).
|
||||
for i in 0..N_TEXT_LAYER {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes_per_layer);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes_per_layer);
|
||||
}
|
||||
runtime.set_data(mel_tensor, mel_data);
|
||||
|
||||
// -- Decoding loop --
|
||||
// For tiny.en, decoder starts with [<|startoftranscript|>, <|notimestamps|>] and we process
|
||||
// these prefill tokens one at a time (as the other luminal examples do), then sample
|
||||
// greedily token-by-token.
|
||||
let prompt: Vec<u32> = vec![TOKEN_SOT, TOKEN_NO_TIMESTAMPS];
|
||||
let mut prev_seq = 0usize;
|
||||
let mut next_input: i32 = prompt[0] as i32;
|
||||
let mut prompt_idx = 1usize;
|
||||
let mut generated: Vec<u32> = Vec::new();
|
||||
let mut step = 0usize;
|
||||
|
||||
print!("Transcription:");
|
||||
std::io::stdout().flush().unwrap();
|
||||
|
||||
let start = Instant::now();
|
||||
loop {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
|
||||
runtime.set_data(input, vec![next_input]);
|
||||
runtime.set_data(pos_ids, vec![prev_seq as i32]);
|
||||
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
|
||||
// Round-trip the KV caches
|
||||
for (i, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[i], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[i], v_buf);
|
||||
}
|
||||
|
||||
prev_seq += 1;
|
||||
|
||||
if prompt_idx < prompt.len() {
|
||||
// Still feeding prefill tokens; ignore the logits.
|
||||
next_input = prompt[prompt_idx] as i32;
|
||||
prompt_idx += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
let last_row = logits_data[logits_data.len() - N_VOCAB..].to_vec();
|
||||
let next_token = greedy_decode(&last_row, step == 0);
|
||||
|
||||
if next_token == TOKEN_EOT {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Ok(decoded) = tokenizer.decode(&[next_token], false) {
|
||||
print!("{decoded}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
generated.push(next_token);
|
||||
next_input = next_token as i32;
|
||||
step += 1;
|
||||
|
||||
if step >= gen_tokens {
|
||||
break;
|
||||
}
|
||||
if prev_seq >= max_target_pos - 1 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let elapsed = start.elapsed();
|
||||
println!();
|
||||
println!(
|
||||
"Decoded {} tokens in {:.2}s ({:.1} tok/s)",
|
||||
generated.len(),
|
||||
elapsed.as_secs_f64(),
|
||||
generated.len() as f64 / elapsed.as_secs_f64().max(1e-6),
|
||||
);
|
||||
}
|
||||
|
||||
/// Greedy argmax with whisper-style suppression of special tokens.
|
||||
fn greedy_decode(logits: &[f32], is_first_step: bool) -> u32 {
|
||||
debug_assert_eq!(logits.len(), N_VOCAB);
|
||||
let mut best_idx = 0usize;
|
||||
let mut best_val = f32::NEG_INFINITY;
|
||||
for (i, &v) in logits.iter().enumerate() {
|
||||
// Suppress all special / language / timestamp tokens (except <|endoftext|>).
|
||||
// For tiny.en these live in the >= 50257 range. We allow only TOKEN_EOT.
|
||||
if i as u32 != TOKEN_EOT && i >= TOKEN_SOT as usize {
|
||||
continue;
|
||||
}
|
||||
// Suppress <|endoftext|> on the very first generated step to avoid empty output.
|
||||
if is_first_step && i as u32 == TOKEN_EOT {
|
||||
continue;
|
||||
}
|
||||
if v > best_val {
|
||||
best_val = v;
|
||||
best_idx = i;
|
||||
}
|
||||
}
|
||||
best_idx as u32
|
||||
}
|
||||
492
examples/whisper/src/model.rs
Normal file
492
examples/whisper/src/model.rs
Normal file
@@ -0,0 +1,492 @@
|
||||
use luminal::{dtype::DType, graph::Graph, prelude::GraphTensor, shape::Expression};
|
||||
use luminal_nn::LayerNorm;
|
||||
|
||||
// whisper-tiny.en hyperparameters
|
||||
pub const N_MELS: usize = 80;
|
||||
pub const N_AUDIO_CTX: usize = 1500;
|
||||
pub const N_AUDIO_STATE: usize = 384;
|
||||
pub const N_AUDIO_HEAD: usize = 6;
|
||||
pub const N_AUDIO_LAYER: usize = 4;
|
||||
|
||||
pub const N_TEXT_CTX: usize = 448;
|
||||
pub const N_TEXT_STATE: usize = 384;
|
||||
pub const N_TEXT_HEAD: usize = 6;
|
||||
pub const N_TEXT_LAYER: usize = 4;
|
||||
|
||||
pub const HEAD_DIM: usize = N_AUDIO_STATE / N_AUDIO_HEAD; // 64
|
||||
pub const FF_DIM: usize = 4 * N_AUDIO_STATE; // 1536
|
||||
|
||||
pub const N_VOCAB: usize = 51864;
|
||||
pub const LAYER_NORM_EPS: f32 = 1e-5;
|
||||
|
||||
pub const TOKEN_SOT: u32 = 50257; // <|startoftranscript|>
|
||||
pub const TOKEN_NO_TIMESTAMPS: u32 = 50362;
|
||||
pub const TOKEN_EOT: u32 = 50256;
|
||||
|
||||
fn linear_with_bias(x: GraphTensor, w: GraphTensor, b: GraphTensor) -> GraphTensor {
|
||||
let out = x.matmul(w.t());
|
||||
let prefix: Vec<Expression> = out.dims()[..out.dims().len() - 1].to_vec();
|
||||
out + b.expand_lhs(prefix)
|
||||
}
|
||||
|
||||
fn linear_no_bias(x: GraphTensor, w: GraphTensor) -> GraphTensor {
|
||||
x.matmul(w.t())
|
||||
}
|
||||
|
||||
/// 1D convolution with bias. Input: (ch_in, length). Weight: (ch_out, ch_in*kernel)
|
||||
/// (HF stores it as (ch_out, ch_in, kernel) which flat-loads identically). Output: (ch_out, out_length).
|
||||
fn conv1d_bias(
|
||||
x: GraphTensor,
|
||||
weight: GraphTensor,
|
||||
bias: GraphTensor,
|
||||
kernel: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
) -> GraphTensor {
|
||||
let padded = x.pad(
|
||||
vec![
|
||||
(Expression::from(0), Expression::from(0)),
|
||||
(Expression::from(padding), Expression::from(padding)),
|
||||
],
|
||||
0.0,
|
||||
);
|
||||
let unfolded = padded.unfold([1usize, kernel], [1usize, stride], [1usize, 1usize]);
|
||||
// unfolded: (ch_in, n_windows, 1, kernel)
|
||||
let unfolded = unfolded.squeeze(2);
|
||||
// (ch_in, n_windows, kernel) -> (n_windows, ch_in, kernel) -> (n_windows, ch_in*kernel)
|
||||
let permuted = unfolded.permute((1, 0, 2));
|
||||
let flat = permuted.merge_dims(1, 2);
|
||||
// (n_windows, ch_in*kernel) @ (ch_in*kernel, ch_out) -> (n_windows, ch_out)
|
||||
let out = flat.matmul(weight.t());
|
||||
let n_windows = out.dims()[0];
|
||||
let bias_expanded = bias.expand_dim(0, n_windows);
|
||||
let out = out + bias_expanded;
|
||||
// (n_windows, ch_out) -> (ch_out, n_windows)
|
||||
out.transpose(0, 1)
|
||||
}
|
||||
|
||||
/// Standard LayerNorm with mean-norm, std-norm, weight and bias (matches torch.nn.LayerNorm).
|
||||
fn standard_layernorm(name: &str, dim: usize, cx: &mut Graph) -> LayerNorm {
|
||||
LayerNorm::new(
|
||||
dim,
|
||||
Some(&format!("{name}.weight")),
|
||||
Some(&format!("{name}.bias")),
|
||||
true,
|
||||
LAYER_NORM_EPS,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
struct AttentionWeights {
|
||||
q_proj: GraphTensor,
|
||||
q_bias: GraphTensor,
|
||||
k_proj: GraphTensor,
|
||||
v_proj: GraphTensor,
|
||||
v_bias: GraphTensor,
|
||||
out_proj: GraphTensor,
|
||||
out_bias: GraphTensor,
|
||||
}
|
||||
|
||||
impl AttentionWeights {
|
||||
fn new(prefix: &str, dim: usize, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
q_proj: cx
|
||||
.named_tensor(format!("{prefix}.q_proj.weight"), (dim, dim))
|
||||
.persist(),
|
||||
q_bias: cx
|
||||
.named_tensor(format!("{prefix}.q_proj.bias"), dim)
|
||||
.persist(),
|
||||
k_proj: cx
|
||||
.named_tensor(format!("{prefix}.k_proj.weight"), (dim, dim))
|
||||
.persist(),
|
||||
v_proj: cx
|
||||
.named_tensor(format!("{prefix}.v_proj.weight"), (dim, dim))
|
||||
.persist(),
|
||||
v_bias: cx
|
||||
.named_tensor(format!("{prefix}.v_proj.bias"), dim)
|
||||
.persist(),
|
||||
out_proj: cx
|
||||
.named_tensor(format!("{prefix}.out_proj.weight"), (dim, dim))
|
||||
.persist(),
|
||||
out_bias: cx
|
||||
.named_tensor(format!("{prefix}.out_proj.bias"), dim)
|
||||
.persist(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn split_heads(x: GraphTensor) -> GraphTensor {
|
||||
// (seq, dim) -> (n_heads, seq, head_dim)
|
||||
x.split_dims(1, HEAD_DIM).transpose(0, 1)
|
||||
}
|
||||
|
||||
fn merge_heads(x: GraphTensor) -> GraphTensor {
|
||||
// (n_heads, seq, head_dim) -> (seq, n_heads, head_dim) -> (seq, dim)
|
||||
x.transpose(0, 1).merge_dims(1, 2)
|
||||
}
|
||||
|
||||
/// Encoder self-attention (full, non-causal). Input/output shape (seq, dim).
|
||||
fn encoder_self_attention(x: GraphTensor, w: &AttentionWeights) -> GraphTensor {
|
||||
let q = linear_with_bias(x, w.q_proj, w.q_bias);
|
||||
let k = linear_no_bias(x, w.k_proj);
|
||||
let v = linear_with_bias(x, w.v_proj, w.v_bias);
|
||||
|
||||
let q = split_heads(q);
|
||||
let k = split_heads(k);
|
||||
let v = split_heads(v);
|
||||
|
||||
let scale = (HEAD_DIM as f32).sqrt().recip();
|
||||
let scores = q.matmul(k.transpose(1, 2)) * scale;
|
||||
let weights = scores.softmax(2);
|
||||
let attn = weights.matmul(v);
|
||||
let merged = merge_heads(attn);
|
||||
linear_with_bias(merged, w.out_proj, w.out_bias)
|
||||
}
|
||||
|
||||
/// Decoder self-attention with KV cache. Returns (out, k_cache_out, v_cache_out).
|
||||
fn decoder_self_attention(
|
||||
x: GraphTensor,
|
||||
w: &AttentionWeights,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let cx = x.graph();
|
||||
let seq = x.dims()[0];
|
||||
let prev = Expression::from('p');
|
||||
let total = prev + seq;
|
||||
|
||||
let q = linear_with_bias(x, w.q_proj, w.q_bias);
|
||||
let k = linear_no_bias(x, w.k_proj);
|
||||
let v = linear_with_bias(x, w.v_proj, w.v_bias);
|
||||
|
||||
let k_new = split_heads(k); // (n_heads, seq, head_dim)
|
||||
let v_new = split_heads(v);
|
||||
|
||||
// Build flat scatter indices to write new K/V into the cache at positions [prev..prev+seq).
|
||||
let h_offset = cx.arange(N_TEXT_HEAD) * (max_seq * HEAD_DIM);
|
||||
let p_offset = (cx.arange(seq) + prev) * HEAD_DIM;
|
||||
let d_offset = cx.arange(HEAD_DIM);
|
||||
let scatter_idx = h_offset.expand_dim(1, seq).expand_dim(2, HEAD_DIM)
|
||||
+ p_offset.expand_dim(0, N_TEXT_HEAD).expand_dim(2, HEAD_DIM)
|
||||
+ d_offset.expand_dim(0, N_TEXT_HEAD).expand_dim(1, seq);
|
||||
|
||||
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
let k_full = k_cache_out.slice((.., ..total, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total, ..));
|
||||
|
||||
let q = split_heads(q);
|
||||
|
||||
let scale = (HEAD_DIM as f32).sqrt().recip();
|
||||
let scores = q.matmul(k_full.transpose(1, 2)) * scale;
|
||||
|
||||
// Causal mask
|
||||
let q_abs = cx.arange(seq).cast(DType::F32) + prev;
|
||||
let k_pos = cx.arange(total).cast(DType::F32);
|
||||
let mask = k_pos.expand_dim(0, seq).gt(q_abs.expand_dim(1, total));
|
||||
let mask_3d = mask.cast(DType::F32).expand_dim(0, N_TEXT_HEAD);
|
||||
let masked = scores + mask_3d * (-1e10f32);
|
||||
|
||||
let weights = masked.softmax(2);
|
||||
let attn = weights.matmul(v_full);
|
||||
let merged = merge_heads(attn);
|
||||
let out = linear_with_bias(merged, w.out_proj, w.out_bias);
|
||||
(out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
/// Cross-attention: query from decoder, key/value from encoder output `xa`.
|
||||
fn cross_attention(x: GraphTensor, xa: GraphTensor, w: &AttentionWeights) -> GraphTensor {
|
||||
let q = linear_with_bias(x, w.q_proj, w.q_bias);
|
||||
let k = linear_no_bias(xa, w.k_proj);
|
||||
let v = linear_with_bias(xa, w.v_proj, w.v_bias);
|
||||
|
||||
let q = split_heads(q);
|
||||
let k = split_heads(k);
|
||||
let v = split_heads(v);
|
||||
|
||||
let scale = (HEAD_DIM as f32).sqrt().recip();
|
||||
let scores = q.matmul(k.transpose(1, 2)) * scale;
|
||||
let weights = scores.softmax(2);
|
||||
let attn = weights.matmul(v);
|
||||
let merged = merge_heads(attn);
|
||||
linear_with_bias(merged, w.out_proj, w.out_bias)
|
||||
}
|
||||
|
||||
struct EncoderLayer {
|
||||
self_attn: AttentionWeights,
|
||||
self_attn_ln: LayerNorm,
|
||||
fc1: GraphTensor,
|
||||
fc1_b: GraphTensor,
|
||||
fc2: GraphTensor,
|
||||
fc2_b: GraphTensor,
|
||||
final_ln: LayerNorm,
|
||||
}
|
||||
|
||||
impl EncoderLayer {
|
||||
fn new(idx: usize, cx: &mut Graph) -> Self {
|
||||
let prefix = format!("model.encoder.layers.{idx}");
|
||||
Self {
|
||||
self_attn: AttentionWeights::new(&format!("{prefix}.self_attn"), N_AUDIO_STATE, cx),
|
||||
self_attn_ln: standard_layernorm(
|
||||
&format!("{prefix}.self_attn_layer_norm"),
|
||||
N_AUDIO_STATE,
|
||||
cx,
|
||||
),
|
||||
fc1: cx
|
||||
.named_tensor(format!("{prefix}.fc1.weight"), (FF_DIM, N_AUDIO_STATE))
|
||||
.persist(),
|
||||
fc1_b: cx
|
||||
.named_tensor(format!("{prefix}.fc1.bias"), FF_DIM)
|
||||
.persist(),
|
||||
fc2: cx
|
||||
.named_tensor(format!("{prefix}.fc2.weight"), (N_AUDIO_STATE, FF_DIM))
|
||||
.persist(),
|
||||
fc2_b: cx
|
||||
.named_tensor(format!("{prefix}.fc2.bias"), N_AUDIO_STATE)
|
||||
.persist(),
|
||||
final_ln: standard_layernorm(&format!("{prefix}.final_layer_norm"), N_AUDIO_STATE, cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, x: GraphTensor) -> GraphTensor {
|
||||
let h = self.self_attn_ln.forward(x);
|
||||
let h = encoder_self_attention(h, &self.self_attn);
|
||||
let x = x + h;
|
||||
|
||||
let h = self.final_ln.forward(x);
|
||||
let h = linear_with_bias(h, self.fc1, self.fc1_b).gelu();
|
||||
let h = linear_with_bias(h, self.fc2, self.fc2_b);
|
||||
x + h
|
||||
}
|
||||
}
|
||||
|
||||
struct DecoderLayer {
|
||||
self_attn: AttentionWeights,
|
||||
self_attn_ln: LayerNorm,
|
||||
cross_attn: AttentionWeights,
|
||||
cross_attn_ln: LayerNorm,
|
||||
fc1: GraphTensor,
|
||||
fc1_b: GraphTensor,
|
||||
fc2: GraphTensor,
|
||||
fc2_b: GraphTensor,
|
||||
final_ln: LayerNorm,
|
||||
}
|
||||
|
||||
impl DecoderLayer {
|
||||
fn new(idx: usize, cx: &mut Graph) -> Self {
|
||||
let prefix = format!("model.decoder.layers.{idx}");
|
||||
Self {
|
||||
self_attn: AttentionWeights::new(&format!("{prefix}.self_attn"), N_TEXT_STATE, cx),
|
||||
self_attn_ln: standard_layernorm(
|
||||
&format!("{prefix}.self_attn_layer_norm"),
|
||||
N_TEXT_STATE,
|
||||
cx,
|
||||
),
|
||||
cross_attn: AttentionWeights::new(&format!("{prefix}.encoder_attn"), N_TEXT_STATE, cx),
|
||||
cross_attn_ln: standard_layernorm(
|
||||
&format!("{prefix}.encoder_attn_layer_norm"),
|
||||
N_TEXT_STATE,
|
||||
cx,
|
||||
),
|
||||
fc1: cx
|
||||
.named_tensor(format!("{prefix}.fc1.weight"), (FF_DIM, N_TEXT_STATE))
|
||||
.persist(),
|
||||
fc1_b: cx
|
||||
.named_tensor(format!("{prefix}.fc1.bias"), FF_DIM)
|
||||
.persist(),
|
||||
fc2: cx
|
||||
.named_tensor(format!("{prefix}.fc2.weight"), (N_TEXT_STATE, FF_DIM))
|
||||
.persist(),
|
||||
fc2_b: cx
|
||||
.named_tensor(format!("{prefix}.fc2.bias"), N_TEXT_STATE)
|
||||
.persist(),
|
||||
final_ln: standard_layernorm(&format!("{prefix}.final_layer_norm"), N_TEXT_STATE, cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
x: GraphTensor,
|
||||
xa: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let h = self.self_attn_ln.forward(x);
|
||||
let (h, k_out, v_out) =
|
||||
decoder_self_attention(h, &self.self_attn, k_cache_in, v_cache_in, max_seq);
|
||||
let x = x + h;
|
||||
|
||||
let h = self.cross_attn_ln.forward(x);
|
||||
let h = cross_attention(h, xa, &self.cross_attn);
|
||||
let x = x + h;
|
||||
|
||||
let h = self.final_ln.forward(x);
|
||||
let h = linear_with_bias(h, self.fc1, self.fc1_b).gelu();
|
||||
let h = linear_with_bias(h, self.fc2, self.fc2_b);
|
||||
(x + h, k_out, v_out)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct KVCache {
|
||||
pub k_caches: Vec<GraphTensor>,
|
||||
pub v_caches: Vec<GraphTensor>,
|
||||
pub max_seq: usize,
|
||||
}
|
||||
|
||||
impl KVCache {
|
||||
pub fn new(cx: &mut Graph, max_seq: usize) -> Self {
|
||||
let mut k_caches = Vec::with_capacity(N_TEXT_LAYER);
|
||||
let mut v_caches = Vec::with_capacity(N_TEXT_LAYER);
|
||||
for l in 0..N_TEXT_LAYER {
|
||||
let k = cx
|
||||
.named_tensor(format!("kv_cache.{l}.k"), (N_TEXT_HEAD, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
let v = cx
|
||||
.named_tensor(format!("kv_cache.{l}.v"), (N_TEXT_HEAD, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
k_caches.push(k);
|
||||
v_caches.push(v);
|
||||
}
|
||||
Self {
|
||||
k_caches,
|
||||
v_caches,
|
||||
max_seq,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WhisperEncoder {
|
||||
conv1_w: GraphTensor,
|
||||
conv1_b: GraphTensor,
|
||||
conv2_w: GraphTensor,
|
||||
conv2_b: GraphTensor,
|
||||
positional_embedding: GraphTensor,
|
||||
layers: Vec<EncoderLayer>,
|
||||
layer_norm: LayerNorm,
|
||||
}
|
||||
|
||||
impl WhisperEncoder {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
conv1_w: cx
|
||||
.named_tensor("model.encoder.conv1.weight", (N_AUDIO_STATE, N_MELS * 3))
|
||||
.persist(),
|
||||
conv1_b: cx
|
||||
.named_tensor("model.encoder.conv1.bias", N_AUDIO_STATE)
|
||||
.persist(),
|
||||
conv2_w: cx
|
||||
.named_tensor(
|
||||
"model.encoder.conv2.weight",
|
||||
(N_AUDIO_STATE, N_AUDIO_STATE * 3),
|
||||
)
|
||||
.persist(),
|
||||
conv2_b: cx
|
||||
.named_tensor("model.encoder.conv2.bias", N_AUDIO_STATE)
|
||||
.persist(),
|
||||
positional_embedding: cx
|
||||
.named_tensor(
|
||||
"model.encoder.embed_positions.weight",
|
||||
(N_AUDIO_CTX, N_AUDIO_STATE),
|
||||
)
|
||||
.persist(),
|
||||
layers: (0..N_AUDIO_LAYER)
|
||||
.map(|i| EncoderLayer::new(i, cx))
|
||||
.collect(),
|
||||
layer_norm: standard_layernorm("model.encoder.layer_norm", N_AUDIO_STATE, cx),
|
||||
}
|
||||
}
|
||||
|
||||
/// Input mel spectrogram: (N_MELS, 3000). Output: (N_AUDIO_CTX=1500, N_AUDIO_STATE).
|
||||
pub fn forward(&self, mel: GraphTensor) -> GraphTensor {
|
||||
let h = conv1d_bias(mel, self.conv1_w, self.conv1_b, 3, 1, 1).gelu();
|
||||
let h = conv1d_bias(h, self.conv2_w, self.conv2_b, 3, 2, 1).gelu();
|
||||
// h: (N_AUDIO_STATE, N_AUDIO_CTX) -> (N_AUDIO_CTX, N_AUDIO_STATE)
|
||||
let mut x = h.transpose(0, 1) + self.positional_embedding;
|
||||
for layer in &self.layers {
|
||||
x = layer.forward(x);
|
||||
}
|
||||
self.layer_norm.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WhisperDecoder {
|
||||
embed_tokens: GraphTensor,
|
||||
embed_positions: GraphTensor,
|
||||
layers: Vec<DecoderLayer>,
|
||||
layer_norm: LayerNorm,
|
||||
}
|
||||
|
||||
impl WhisperDecoder {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
embed_tokens: cx
|
||||
.named_tensor("model.decoder.embed_tokens.weight", (N_VOCAB, N_TEXT_STATE))
|
||||
.persist(),
|
||||
embed_positions: cx
|
||||
.named_tensor(
|
||||
"model.decoder.embed_positions.weight",
|
||||
(N_TEXT_CTX, N_TEXT_STATE),
|
||||
)
|
||||
.persist(),
|
||||
layers: (0..N_TEXT_LAYER)
|
||||
.map(|i| DecoderLayer::new(i, cx))
|
||||
.collect(),
|
||||
layer_norm: standard_layernorm("model.decoder.layer_norm", N_TEXT_STATE, cx),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
token_ids: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
xa: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = token_ids.dims1();
|
||||
// Token embedding gather
|
||||
let mut x = self.embed_tokens.gather(
|
||||
(token_ids * N_TEXT_STATE).expand_dim(1, N_TEXT_STATE)
|
||||
+ token_ids.graph().arange(N_TEXT_STATE).expand_dim(0, seq),
|
||||
);
|
||||
// Positional embedding gather (using pos_ids)
|
||||
let pos_emb = self.embed_positions.gather(
|
||||
(pos_ids * N_TEXT_STATE).expand_dim(1, N_TEXT_STATE)
|
||||
+ pos_ids.graph().arange(N_TEXT_STATE).expand_dim(0, seq),
|
||||
);
|
||||
x += pos_emb;
|
||||
|
||||
let mut cache_outputs = Vec::with_capacity(N_TEXT_LAYER);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
x,
|
||||
xa,
|
||||
kv_cache.k_caches[i],
|
||||
kv_cache.v_caches[i],
|
||||
kv_cache.max_seq,
|
||||
);
|
||||
x = x_new;
|
||||
cache_outputs.push((k_out, v_out));
|
||||
}
|
||||
let x = self.layer_norm.forward(x);
|
||||
// Tied embeddings: projection to vocab
|
||||
let logits = x.matmul(self.embed_tokens.t());
|
||||
(logits, cache_outputs)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Whisper {
|
||||
pub encoder: WhisperEncoder,
|
||||
pub decoder: WhisperDecoder,
|
||||
}
|
||||
|
||||
impl Whisper {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
encoder: WhisperEncoder::init(cx),
|
||||
decoder: WhisperDecoder::init(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user