mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
16 Commits
bf16-gemma
...
pytest-cla
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ee5b54438 | ||
|
|
389c05abeb | ||
|
|
dcc2c9cbb4 | ||
|
|
a9af4c3923 | ||
|
|
3092d0d68b | ||
|
|
8a2bd714ac | ||
|
|
54a26a044c | ||
|
|
5a0d3f87cc | ||
|
|
112d064700 | ||
|
|
c51c36fbcb | ||
|
|
ee372d464e | ||
|
|
1bef1344d1 | ||
|
|
8d41c491fd | ||
|
|
64f390a833 | ||
|
|
8d20581f38 | ||
|
|
bfd4ae9b27 |
130
.agents/skills/aoti-debug/SKILL.md
Normal file
130
.agents/skills/aoti-debug/SKILL.md
Normal file
@@ -0,0 +1,130 @@
|
||||
---
|
||||
name: aoti-debug
|
||||
description: Debug AOTInductor (AOTI) errors including device mismatches, CUDA illegal memory access, segfaults, and wrong outputs when deploying compiled PyTorch models. Use when encountering errors with aoti_compile_and_package, aoti_load_package, or the deprecated aot_compile/aot_load APIs.
|
||||
---
|
||||
|
||||
# AOTInductor Debugging
|
||||
|
||||
Debug errors when compiling and deploying PyTorch models with AOTInductor.
|
||||
|
||||
## First Step: Always Check Device and Shape Matching
|
||||
|
||||
**For ANY AOTI error (segfault, exception, crash, wrong output), check these first:**
|
||||
|
||||
1. **Compile device == Load device**: The model must be loaded on the same device type it was compiled on
|
||||
2. **Input devices match**: Runtime inputs must be on the same device as the compiled model
|
||||
3. **Input shapes match**: Runtime input shapes must match compilation shapes (or satisfy dynamic shape constraints)
|
||||
|
||||
```python
|
||||
# Compilation -- note the device and shapes
|
||||
model = MyModel().eval().cuda()
|
||||
inp = torch.randn(2, 10, device="cuda")
|
||||
pkg = torch._inductor.aoti_compile_and_package(model, (inp,))
|
||||
|
||||
# Loading -- device type MUST match compilation
|
||||
loaded = torch._inductor.aoti_load_package(pkg) # auto-detects device from package
|
||||
|
||||
# Inference -- device and shapes MUST match
|
||||
out = loaded(torch.randn(2, 10, device="cuda")) # same device, same shape
|
||||
```
|
||||
|
||||
**AOTI requires compile and load to use the same device type.** Cross-device loading (compile on GPU, load on CPU) is NOT supported. Device index can differ (cuda:0 vs cuda:1).
|
||||
|
||||
## Current vs Deprecated API
|
||||
|
||||
### Current API (use this)
|
||||
```python
|
||||
torch._inductor.aoti_compile_and_package() # compile
|
||||
torch._inductor.aoti_load_package() # load (auto-detects device)
|
||||
```
|
||||
|
||||
### Deprecated API (migrate away)
|
||||
```python
|
||||
torch._export.aot_compile() # deprecated
|
||||
torch._export.aot_load() # deprecated
|
||||
```
|
||||
|
||||
The new API stores device metadata in the package, so `aoti_load_package()` automatically uses the correct device type.
|
||||
|
||||
## Common Error Patterns
|
||||
|
||||
### Device Mismatch Segfault
|
||||
|
||||
**Symptom**: Segfault, exception, or crash during load or execution.
|
||||
|
||||
**Example errors**:
|
||||
- `The specified pointer resides on host memory and is not registered with any CUDA device`
|
||||
- Crash during constant loading
|
||||
- `Expected out tensor to have device cuda:0, but got cpu instead`
|
||||
|
||||
**Solution**: Ensure compile and load use the same device type.
|
||||
|
||||
### Input Device Mismatch at Runtime
|
||||
|
||||
**Symptom**: RuntimeError during model execution.
|
||||
|
||||
**Better debugging**: Run with `AOTI_RUNTIME_CHECK_INPUTS=1` for clear errors:
|
||||
```bash
|
||||
AOTI_RUNTIME_CHECK_INPUTS=1 python script.py
|
||||
```
|
||||
|
||||
Produces actionable messages like:
|
||||
```
|
||||
Error: input_handles[0]: unmatched device type, expected: 0(cpu), but got: 1(cuda)
|
||||
```
|
||||
|
||||
## Debugging CUDA Illegal Memory Access (IMA)
|
||||
|
||||
### Step 1: Sanity Checks
|
||||
|
||||
```bash
|
||||
AOTI_RUNTIME_CHECK_INPUTS=1 python script.py # validate inputs match compilation guards
|
||||
TORCHINDUCTOR_NAN_ASSERTS=1 python script.py # check for NaN before/after each kernel
|
||||
```
|
||||
|
||||
Both flags take effect at **compile time** (codegen time).
|
||||
|
||||
### Step 2: Make IMA Deterministic
|
||||
|
||||
```bash
|
||||
PYTORCH_NO_CUDA_MEMORY_CACHING=1 CUDA_LAUNCH_BLOCKING=1 python script.py
|
||||
```
|
||||
|
||||
- `PYTORCH_NO_CUDA_MEMORY_CACHING=1` -- disables caching allocator (which allocates bigger buffers, masking IMA)
|
||||
- `CUDA_LAUNCH_BLOCKING=1` -- forces synchronous kernel launches (pinpoints which kernel crashed)
|
||||
|
||||
Both take effect at **runtime**.
|
||||
|
||||
### Step 3: Identify the Problematic Kernel
|
||||
|
||||
```bash
|
||||
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=3 python script.py
|
||||
```
|
||||
|
||||
Prints kernels one by one at runtime. Combined with Step 2 flags, shows which kernel launched right before the error.
|
||||
|
||||
To inspect inputs to specific kernels:
|
||||
```bash
|
||||
AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT="kernel_name_1,kernel_name_2" \
|
||||
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=2 python script.py
|
||||
```
|
||||
|
||||
If inputs to a kernel are unexpected, trace back to the kernel that produced the bad input.
|
||||
|
||||
## Environment Variables Reference
|
||||
|
||||
| Variable | When | Purpose |
|
||||
|---|---|---|
|
||||
| `AOTI_RUNTIME_CHECK_INPUTS=1` | Compile time | Validate inputs match compilation guards |
|
||||
| `TORCHINDUCTOR_NAN_ASSERTS=1` | Compile time | Check for NaN before/after kernels |
|
||||
| `PYTORCH_NO_CUDA_MEMORY_CACHING=1` | Runtime | Make IMA errors deterministic |
|
||||
| `CUDA_LAUNCH_BLOCKING=1` | Runtime | Force synchronous kernel launches |
|
||||
| `AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=3` | Compile time | Print kernels at runtime |
|
||||
| `AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT="..."` | Compile time | Filter which kernels to print |
|
||||
| `TORCH_LOGS="+inductor,output_code"` | Runtime | See PT2 internal logs |
|
||||
| `TORCH_SHOW_CPP_STACKTRACES=1` | Runtime | Show C++ stack traces |
|
||||
|
||||
## Common Sources of Issues
|
||||
|
||||
- **Dynamic shapes**: Historically a common source of IMA errors. Pay special attention when using dynamic shape constraints.
|
||||
- **Custom ops**: Especially C++ custom ops with dynamic shapes. The meta function may need to handle SymInt properly.
|
||||
195
.agents/skills/pt2-debug/SKILL.md
Normal file
195
.agents/skills/pt2-debug/SKILL.md
Normal file
@@ -0,0 +1,195 @@
|
||||
---
|
||||
name: pt2-debug
|
||||
description: Debug torch.compile failures, graph breaks, recompilation issues, accuracy mismatches, and Triton kernel errors. Use when encountering BackendCompilerFailed exceptions, torch.compile errors, recompilation warnings, or numerical accuracy issues with compiled PyTorch models.
|
||||
---
|
||||
|
||||
# PyTorch 2 Compile Debugging
|
||||
|
||||
Debug `torch.compile`, Dynamo, Inductor, and AOTAutograd failures when using PyTorch as a library.
|
||||
|
||||
## Diagnostic Environment Variables
|
||||
|
||||
Pick the right diagnostic based on the error:
|
||||
|
||||
| Command | When to use |
|
||||
|---|---|
|
||||
| `TORCH_LOGS="+dynamo,graph_breaks,recompiles" python script.py` | Quick overview of what's going wrong |
|
||||
| `TORCH_COMPILE_DEBUG=1 python script.py` | Full debug artifacts (FX graphs, Inductor IR, generated code) in `torch_compile_debug/` |
|
||||
| `TORCH_LOGS="output_code" python script.py` | See the generated Triton/C++ kernel code |
|
||||
| `TORCH_TRACE=/path/to/trace python script.py` | Structured trace (parse with `tlparse`) |
|
||||
| `TORCHINDUCTOR_COMPILE_THREADS=1 python script.py` | Single-threaded compilation for pdb debugging |
|
||||
|
||||
## Error Triage
|
||||
|
||||
Classify the failure and jump to the right section:
|
||||
|
||||
| Error Pattern | Category |
|
||||
|---|---|
|
||||
| `Unsupported: ...` or `graph break` in logs | [Graph Breaks](#graph-breaks) |
|
||||
| `BackendCompilerFailed` | [Backend Failures](#backend-compiler-failures) |
|
||||
| `RecompileError` or `cache_size_limit` | [Recompilation](#recompilation-issues) |
|
||||
| Accuracy mismatch / wrong numerical output | [Accuracy](#accuracy-issues) |
|
||||
| `InternalTorchDynamoError` | [Internal Errors](#internal-dynamo-errors) |
|
||||
| Segfault or CUDA IMA | [Runtime Crashes](#runtime-crashes) |
|
||||
| Triton assertion / index out of bounds | [Triton Failures](#triton-kernel-failures) |
|
||||
|
||||
## Graph Breaks
|
||||
|
||||
Graph breaks split the compiled graph into smaller subgraphs, causing performance regressions.
|
||||
|
||||
**Diagnose:**
|
||||
```bash
|
||||
TORCH_LOGS="graph_breaks" python script.py
|
||||
```
|
||||
|
||||
**Common causes:**
|
||||
- Data-dependent control flow
|
||||
- Unsupported Python builtins
|
||||
- In-place ops on inputs, unsupported dtypes
|
||||
- Calls to non-traceable functions
|
||||
|
||||
**Fix approaches:**
|
||||
1. Read the graph break message to identify the unsupported operation
|
||||
2. Check for a decomposition or supported alternative
|
||||
3. Consider `torch._dynamo.allow_in_graph` or restructure user code
|
||||
|
||||
## Backend Compiler Failures
|
||||
|
||||
`BackendCompilerFailed` means Inductor crashed during compilation.
|
||||
|
||||
**Diagnose with the minifier:**
|
||||
```bash
|
||||
# Generate minifier launcher
|
||||
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=2 python script.py
|
||||
|
||||
# Run the minifier to get minimal failing graph
|
||||
python minifier_launcher.py minify
|
||||
|
||||
# Run the minimized reproduction
|
||||
python minifier_launcher.py run
|
||||
```
|
||||
|
||||
**Then inspect:**
|
||||
```bash
|
||||
TORCH_COMPILE_DEBUG=1 python script.py # FX graphs in torch_compile_debug/
|
||||
```
|
||||
|
||||
## Recompilation Issues
|
||||
|
||||
Excessive recompilation from guards that are too specific, causing cache misses.
|
||||
|
||||
**Diagnose:**
|
||||
```bash
|
||||
TORCH_LOGS="recompiles,recompiles_verbose,guards" python script.py
|
||||
```
|
||||
|
||||
**Key config:**
|
||||
```python
|
||||
torch._dynamo.config.recompile_limit # default: 8
|
||||
torch._dynamo.config.fail_on_recompile_limit_hit = True # hard error on limit
|
||||
```
|
||||
|
||||
**Common causes:**
|
||||
- Changing tensor shapes without marking them dynamic
|
||||
- Python scalar values that change between calls
|
||||
- Global state mutations between calls
|
||||
|
||||
**Fix:** Read the recompilation reason from logs, identify the failing guard, then either:
|
||||
- Mark dimensions as dynamic: `torch._dynamo.mark_dynamic(tensor, dim)`
|
||||
- Fix the source of guard instability
|
||||
|
||||
## Accuracy Issues
|
||||
|
||||
Compiled model produces different numerical results than eager mode.
|
||||
|
||||
**Diagnose:**
|
||||
```bash
|
||||
# Compares compiled vs eager with fp64 reference, dumps repro on failure
|
||||
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=4 python script.py
|
||||
```
|
||||
|
||||
**Fix approach:**
|
||||
1. Get minimal failing graph from the minifier
|
||||
2. Compare eager vs compiled output at fp64 precision
|
||||
3. Binary search through ops to find the diverging operation
|
||||
4. Check for known issues: reduction order, fused kernels, dtype promotions
|
||||
|
||||
## Internal Dynamo Errors
|
||||
|
||||
`InternalTorchDynamoError` indicates a bug in Dynamo.
|
||||
|
||||
**Diagnose:**
|
||||
```bash
|
||||
TORCHDYNAMO_VERBOSE=1 python script.py
|
||||
# or equivalently:
|
||||
TORCH_LOGS="+dynamo" python script.py
|
||||
```
|
||||
|
||||
**Debug interactively:**
|
||||
```bash
|
||||
TORCHINDUCTOR_COMPILE_THREADS=1 python script.py # then attach pdb
|
||||
```
|
||||
|
||||
## Runtime Crashes
|
||||
|
||||
Segfaults and CUDA illegal memory access during execution of compiled code.
|
||||
|
||||
**Make crash deterministic:**
|
||||
```bash
|
||||
PYTORCH_NO_CUDA_MEMORY_CACHING=1 CUDA_LAUNCH_BLOCKING=1 python script.py
|
||||
```
|
||||
|
||||
**Add NaN checks to find the first bad kernel:**
|
||||
```bash
|
||||
TORCHINDUCTOR_NAN_ASSERTS=1 python script.py
|
||||
```
|
||||
|
||||
**Inductor sync debugging:**
|
||||
```python
|
||||
torch._inductor.config.triton.debug_sync_kernel = True # sync after every kernel
|
||||
torch._inductor.config.triton.debug_sync_graph = True # sync before/after graph
|
||||
```
|
||||
|
||||
**Fix approach:**
|
||||
1. Make deterministic with `PYTORCH_NO_CUDA_MEMORY_CACHING=1 CUDA_LAUNCH_BLOCKING=1`
|
||||
2. Check input shapes, devices, dtypes
|
||||
3. Inspect generated kernel code with `TORCH_LOGS="output_code"`
|
||||
4. Use `TORCHINDUCTOR_NAN_ASSERTS=1` to find the first kernel producing bad values
|
||||
5. Dynamic shapes are historically a common source of IMA
|
||||
|
||||
## Triton Kernel Failures
|
||||
|
||||
Triton assertion failures or index-out-of-bounds in generated kernels.
|
||||
|
||||
**Diagnose:**
|
||||
```bash
|
||||
TORCH_LOGS="output_code,schedule" python script.py
|
||||
```
|
||||
|
||||
**Fix approach:**
|
||||
1. Get the generated Triton kernel from `output_code` logs
|
||||
2. Check index computations for off-by-one or wrong stride calculations
|
||||
3. Check IR with `TORCH_COMPILE_DEBUG=1` to trace back to the FX op
|
||||
4. Check if fusion decisions created invalid index combinations
|
||||
|
||||
## Distinguish Trace-Time vs Runtime
|
||||
|
||||
Many bugs come from confusing these:
|
||||
- **Trace-time**: Inside Dynamo's symbolic interpreter. Function calls may be constant-folded.
|
||||
- **Runtime**: Real tensors, real Python calls.
|
||||
|
||||
When debugging, add `print()` directly in source files rather than monkey-patching -- dispatch chains make monkey-patching unreliable.
|
||||
|
||||
## Using the Minifier
|
||||
|
||||
The minifier reduces a failing graph to the smallest reproduction:
|
||||
|
||||
```bash
|
||||
# For compilation failures (level 2)
|
||||
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=2 python script.py
|
||||
python minifier_launcher.py minify
|
||||
python minifier_launcher.py run
|
||||
|
||||
# For accuracy failures (level 4)
|
||||
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=4 python script.py
|
||||
```
|
||||
134
.agents/skills/ruff/SKILL.md
Normal file
134
.agents/skills/ruff/SKILL.md
Normal file
@@ -0,0 +1,134 @@
|
||||
---
|
||||
name: ruff
|
||||
description:
|
||||
Guide for using ruff, the extremely fast Python linter and formatter. Use this
|
||||
when linting, formatting, or fixing Python code.
|
||||
---
|
||||
|
||||
# ruff
|
||||
|
||||
Ruff is an extremely fast Python linter and code formatter. It replaces Flake8,
|
||||
isort, Black, pyupgrade, autoflake, and dozens of other tools.
|
||||
|
||||
## When to use ruff
|
||||
|
||||
**Always use ruff for Python linting and formatting**, especially if you see:
|
||||
|
||||
- `[tool.ruff]` section in `pyproject.toml`
|
||||
- A `ruff.toml` or `.ruff.toml` configuration file
|
||||
|
||||
However, avoid making unnecessary changes:
|
||||
|
||||
- **Don't format unformatted code** - If `ruff format --diff` shows changes
|
||||
throughout an entire file, the project likely isn't using ruff for formatting.
|
||||
Skip formatting to avoid obscuring actual changes.
|
||||
- **Scope fixes to code being edited** - Use `ruff check --diff` to see fixes
|
||||
relevant to the code you're changing. Only apply fixes to files you're
|
||||
modifying unless the user explicitly asks for broader fixes.
|
||||
|
||||
## How to invoke ruff
|
||||
|
||||
- `uv run ruff ...` - Use when ruff is in the project's dependencies to ensure
|
||||
you use the pinned version
|
||||
- `uvx ruff ...` - Use when ruff is not a project dependency, or for quick
|
||||
one-off checks
|
||||
- `ruff ...` - Use if ruff is installed globally
|
||||
|
||||
## Commands
|
||||
|
||||
### Linting
|
||||
|
||||
```bash
|
||||
ruff check . # Check all files in current directory
|
||||
ruff check path/to/file.py # Check specific file
|
||||
ruff check --fix . # Auto-fix fixable violations
|
||||
ruff check --fix --unsafe-fixes . # Include unsafe fixes (review changes!)
|
||||
ruff check --watch . # Watch for changes and re-lint
|
||||
ruff check --select E,F . # Only check specific rules
|
||||
ruff check --ignore E501 . # Ignore specific rules
|
||||
ruff rule E501 # Explain a specific rule
|
||||
ruff linter # List available linters
|
||||
```
|
||||
|
||||
### Formatting
|
||||
|
||||
```bash
|
||||
ruff format . # Format all files
|
||||
ruff format path/to/file.py # Format specific file
|
||||
ruff format --check . # Check if files are formatted (no changes)
|
||||
ruff format --diff . # Show formatting diff without applying
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Ruff is configured in `pyproject.toml` or `ruff.toml`:
|
||||
|
||||
```toml
|
||||
# pyproject.toml
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "UP"] # Enable specific rule sets
|
||||
ignore = ["E501"] # Ignore specific rules
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["myproject"]
|
||||
```
|
||||
|
||||
## Migrating from other tools
|
||||
|
||||
### Black → ruff format
|
||||
|
||||
```bash
|
||||
black . → ruff format .
|
||||
black --check . → ruff format --check .
|
||||
black --diff . → ruff format --diff .
|
||||
```
|
||||
|
||||
### Flake8 → ruff check
|
||||
|
||||
```bash
|
||||
flake8 . → ruff check .
|
||||
flake8 --select E,F . → ruff check --select E,F .
|
||||
flake8 --ignore E501 . → ruff check --ignore E501 .
|
||||
```
|
||||
|
||||
### isort → ruff check
|
||||
|
||||
```bash
|
||||
isort . → ruff check --select I --fix .
|
||||
isort --check . → ruff check --select I .
|
||||
isort --diff . → ruff check --select I --diff .
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Apply lint fixes before formatting
|
||||
|
||||
Run `ruff check --fix` before `ruff format`. Lint fixes can change code
|
||||
structure (e.g., reordering imports), which formatting then cleans up.
|
||||
|
||||
```bash
|
||||
ruff check --fix .
|
||||
ruff format .
|
||||
```
|
||||
|
||||
### Applying and reviewing unsafe fixes
|
||||
|
||||
Ruff categorizes some auto-fixes as "unsafe" because they may change code
|
||||
behavior, not just style. For example, removing unused imports could break code
|
||||
that relies on side effects.
|
||||
|
||||
```bash
|
||||
ruff check --fix --unsafe-fixes --diff . # Preview changes first
|
||||
ruff check --fix --unsafe-fixes . # Apply changes
|
||||
```
|
||||
|
||||
**Always review changes before applying `--unsafe-fixes`:**
|
||||
|
||||
- Use `ruff rule <CODE>` to understand why the fix is considered unsafe
|
||||
- Verify the fix doesn't violate those assumptions in your code
|
||||
|
||||
## Documentation
|
||||
|
||||
For detailed information, read the official documentation:
|
||||
|
||||
- https://docs.astral.sh/ruff/
|
||||
135
.agents/skills/ty/SKILL.md
Normal file
135
.agents/skills/ty/SKILL.md
Normal file
@@ -0,0 +1,135 @@
|
||||
---
|
||||
name: ty
|
||||
description:
|
||||
Guide for using ty, the extremely fast Python type checker and language
|
||||
server. Use this when type checking Python code or setting up type checking in
|
||||
Python projects.
|
||||
---
|
||||
|
||||
# ty
|
||||
|
||||
ty is an extremely fast Python type checker and language server. It replaces
|
||||
mypy, Pyright, and other type checkers.
|
||||
|
||||
## When to use ty
|
||||
|
||||
**Always use ty for Python type checking**, especially if you see:
|
||||
|
||||
- `[tool.ty]` section in `pyproject.toml`
|
||||
- A `ty.toml` configuration file
|
||||
|
||||
## How to invoke ty
|
||||
|
||||
- `uv run ty ...` - Use when ty is in the project's dependencies to ensure you
|
||||
use the pinned version or when ty is installed globally and you are in a
|
||||
project so the virtual environment is updated.
|
||||
- `uvx ty ...` - Use when ty is not a project dependency, or for quick one-off
|
||||
checks
|
||||
|
||||
## Commands
|
||||
|
||||
### Type checking
|
||||
|
||||
```bash
|
||||
ty check # Check all files in current directory
|
||||
ty check path/to/file.py # Check specific file
|
||||
ty check src/ # Check specific directory
|
||||
```
|
||||
|
||||
### Rule configuration
|
||||
|
||||
```bash
|
||||
ty check --error possibly-unresolved-reference # Treat as error
|
||||
ty check --warn division-by-zero # Treat as warning
|
||||
ty check --ignore unresolved-import # Disable rule
|
||||
```
|
||||
|
||||
### Python version targeting
|
||||
|
||||
```bash
|
||||
ty check --python-version 3.12 # Check against Python 3.12
|
||||
ty check --python-platform linux # Target Linux platform
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
ty is configured in `pyproject.toml` or `ty.toml`:
|
||||
|
||||
```toml
|
||||
# pyproject.toml
|
||||
[tool.ty.environment]
|
||||
python-version = "3.12"
|
||||
|
||||
[tool.ty.rules]
|
||||
possibly-unresolved-reference = "warn"
|
||||
division-by-zero = "error"
|
||||
|
||||
[tool.ty.src]
|
||||
include = ["src/**/*.py"]
|
||||
exclude = ["**/migrations/**"]
|
||||
|
||||
[tool.ty.terminal]
|
||||
output-format = "full"
|
||||
error-on-warning = false
|
||||
```
|
||||
|
||||
### Per-file overrides
|
||||
|
||||
Use overrides to apply different rules to specific files, such as relaxing rules
|
||||
for tests or scripts that have different typing requirements than production
|
||||
code:
|
||||
|
||||
```toml
|
||||
[[tool.ty.overrides]]
|
||||
include = ["tests/**", "**/test_*.py"]
|
||||
|
||||
[tool.ty.overrides.rules]
|
||||
possibly-unresolved-reference = "warn"
|
||||
```
|
||||
|
||||
## Language server
|
||||
|
||||
This plugin automatically configures the ty language server for Python files
|
||||
(`.py` and `.pyi`).
|
||||
|
||||
## Migrating from other tools
|
||||
|
||||
### mypy → ty
|
||||
|
||||
```bash
|
||||
mypy . → ty check
|
||||
mypy --strict . → ty check --error-on-warning
|
||||
mypy path/to/file.py → ty check path/to/file.py
|
||||
```
|
||||
|
||||
### Pyright → ty
|
||||
|
||||
```bash
|
||||
pyright . → ty check
|
||||
pyright path/to/file.py → ty check path/to/file.py
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Don't add ignore comments
|
||||
|
||||
Fix type errors instead of suppressing them. Only add ignore comments when
|
||||
explicitly requested by the user. Use `ty: ignore`, not `type: ignore`, and
|
||||
prefer rule-specific ignores:
|
||||
|
||||
```python
|
||||
# Good: rule-specific ignore
|
||||
x = undefined_var # ty: ignore[possibly-unresolved-reference]
|
||||
|
||||
# Bad: blanket ty ignore
|
||||
x = undefined_var # ty: ignore
|
||||
|
||||
# Bad: tool agnostic blanket ignore
|
||||
x = undefined_var # type: ignore
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
For detailed information, read the official documentation:
|
||||
|
||||
- https://docs.astral.sh/ty/
|
||||
182
.agents/skills/uv/SKILL.md
Normal file
182
.agents/skills/uv/SKILL.md
Normal file
@@ -0,0 +1,182 @@
|
||||
---
|
||||
name: uv
|
||||
description:
|
||||
Guide for using uv, the Python package and project manager. Use this when
|
||||
working with Python projects, scripts, packages, or tools.
|
||||
---
|
||||
|
||||
# uv
|
||||
|
||||
uv is an extremely fast Python package and project manager. It replaces pip,
|
||||
pip-tools, pipx, pyenv, virtualenv, poetry, etc.
|
||||
|
||||
## When to use uv
|
||||
|
||||
**Always use uv for Python work**, especially if you see:
|
||||
|
||||
- The `uv.lock` file
|
||||
- uv headers in `requirements*` files, e.g., "This file was autogenerated by uv"
|
||||
|
||||
Don't use uv in projects managed by other tools:
|
||||
|
||||
- Poetry projects (identifiable by `poetry.lock` file)
|
||||
- PDM projects (identifiable by `pdm.lock` file)
|
||||
|
||||
## Choosing the right workflow
|
||||
|
||||
### Scripts
|
||||
|
||||
**Use when:** Running single Python files and standalone scripts.
|
||||
|
||||
**Key commands:**
|
||||
|
||||
```bash
|
||||
uv run script.py # Run a script
|
||||
uv run --with requests script.py # Run with additional packages
|
||||
uv add --script script.py requests # Add dependencies inline to the script
|
||||
```
|
||||
|
||||
### Projects
|
||||
|
||||
**Use when:** There is a `pyproject.toml` or `uv.lock`
|
||||
|
||||
**Key commands:**
|
||||
|
||||
```bash
|
||||
uv init # Create new project
|
||||
uv add requests # Add dependency
|
||||
uv remove requests # Remove dependency
|
||||
uv sync # Install from lockfile
|
||||
uv run <command> # Run commands in environment
|
||||
uv run python -c "" # Run Python in project environment
|
||||
uv run -p 3.12 <command> # Run with specific Python version
|
||||
```
|
||||
|
||||
### Tools
|
||||
|
||||
**Use when:** Running command-line tools (e.g., ruff, ty, pytest) without
|
||||
installation.
|
||||
|
||||
**Key commands:**
|
||||
|
||||
```bash
|
||||
uvx <tool> <args> # Run a tool without installation
|
||||
uvx <tool>@<version> <args> # Run a specific version of a tool
|
||||
```
|
||||
|
||||
**Important:**
|
||||
|
||||
- `uvx` runs tools from PyPI by package name. This can be unsafe - only run
|
||||
well-known tools.
|
||||
- Only use `uv tool install` only when specifically requested by the user.
|
||||
|
||||
### Pip interface
|
||||
|
||||
**Use when:** Legacy workflows with `requirements.txt` or manual environment
|
||||
management, no `uv.lock` present.
|
||||
|
||||
**Key commands:**
|
||||
|
||||
```bash
|
||||
uv venv
|
||||
uv pip install -r requirements.txt
|
||||
uv pip compile requirements.in -o requirements.txt
|
||||
uv pip sync requirements.txt
|
||||
|
||||
# Platform independent resolution
|
||||
uv pip compile --universal requirements.in -o requirements.txt
|
||||
```
|
||||
|
||||
**Important:**
|
||||
|
||||
- Don't use the pip interface unless clearly needed.
|
||||
- Don't introduce new `requirements.txt` files.
|
||||
- Prefer `uv init` for new projects.
|
||||
|
||||
## Migrating from other tools
|
||||
|
||||
### pyenv → uv python
|
||||
|
||||
```bash
|
||||
pyenv install 3.12 → uv python install 3.12
|
||||
pyenv versions → uv python list --only-installed
|
||||
pyenv local 3.12 → uv python pin 3.12
|
||||
pyenv global 3.12 → uv python install 3.12 --default
|
||||
```
|
||||
|
||||
### pipx → uvx
|
||||
|
||||
```bash
|
||||
pipx run ruff → uvx ruff
|
||||
pipx install ruff → uv tool install ruff
|
||||
pipx upgrade ruff → uv tool upgrade ruff
|
||||
pipx list → uv tool list
|
||||
```
|
||||
|
||||
### pip and pip-tools → uv pip
|
||||
|
||||
```bash
|
||||
pip install package → uv pip install package
|
||||
pip install -r req.txt → uv pip install -r req.txt
|
||||
pip freeze → uv pip freeze
|
||||
pip-compile req.in → uv pip compile req.in
|
||||
pip-sync req.txt → uv pip sync req.txt
|
||||
virtualenv .venv → uv venv
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Don't use pip in uv projects
|
||||
|
||||
```bash
|
||||
# Bad
|
||||
pip install requests
|
||||
|
||||
# Good
|
||||
uv add requests
|
||||
```
|
||||
|
||||
### Don't run python directly
|
||||
|
||||
```bash
|
||||
# Bad
|
||||
python script.py
|
||||
|
||||
# Good
|
||||
uv run script.py
|
||||
```
|
||||
|
||||
```bash
|
||||
# Bad
|
||||
python -c "..."
|
||||
|
||||
# Good
|
||||
uv run python -c "..."
|
||||
```
|
||||
|
||||
```bash
|
||||
# Bad
|
||||
python3.12 -c "..."
|
||||
|
||||
# Good
|
||||
uvx python@3.12 -c "..."
|
||||
```
|
||||
|
||||
### Don't manually manage environments in uv projects
|
||||
|
||||
```bash
|
||||
# Bad
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
|
||||
# Good
|
||||
uv run <command>
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
For detailed information, read the official documentation:
|
||||
|
||||
- https://docs.astral.sh/uv/llms.txt
|
||||
|
||||
The documentation links to specific pages for each of these workflows.
|
||||
@@ -17,11 +17,15 @@
|
||||
"userUid": "1000",
|
||||
"userGid": "1000",
|
||||
"configureZshAsDefaultShell": false
|
||||
},
|
||||
"ghcr.io/devcontainers/features/node:1": {
|
||||
"version": "lts"
|
||||
}
|
||||
},
|
||||
"remoteUser": "ubuntu",
|
||||
"remoteEnv": {
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo",
|
||||
"CODEX_HOME": "${containerWorkspaceFolder}/.claude/codex"
|
||||
},
|
||||
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
|
||||
"customizations": {
|
||||
|
||||
@@ -21,11 +21,15 @@
|
||||
"userUid": "1000",
|
||||
"userGid": "1000",
|
||||
"configureZshAsDefaultShell": false
|
||||
},
|
||||
"ghcr.io/devcontainers/features/node:1": {
|
||||
"version": "lts"
|
||||
}
|
||||
},
|
||||
"remoteUser": "ubuntu",
|
||||
"remoteEnv": {
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo",
|
||||
"CODEX_HOME": "${containerWorkspaceFolder}/.claude/codex"
|
||||
},
|
||||
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
|
||||
"customizations": {
|
||||
@@ -52,4 +56,4 @@
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,6 +1,9 @@
|
||||
/target
|
||||
/crates/**/target
|
||||
/examples/**/target
|
||||
.claude-project
|
||||
.claude-memory
|
||||
.codex
|
||||
|
||||
*.env
|
||||
.claude/
|
||||
|
||||
34
CLAUDE.md
Normal file
34
CLAUDE.md
Normal file
@@ -0,0 +1,34 @@
|
||||
# Luminal
|
||||
|
||||
## Package Management
|
||||
|
||||
- Use `uv add`, `uv add --dev`, `uv remove` for Python dependencies (pyproject.toml is in `crates/luminal_python/`)
|
||||
- Use `uv sync` to sync the Python environment
|
||||
- Never use pip, pip-tools, poetry, or conda
|
||||
- Never manually create or activate virtual environments — uv manages `.venv/` automatically
|
||||
- Never generate requirements.txt
|
||||
|
||||
## Code Execution
|
||||
|
||||
- Always use `uv run` to execute Python tools: `uv run pytest`, `uv run pre-commit`, `uv run python`
|
||||
- Use `cargo` directly for Rust: `cargo build`, `cargo test`, `cargo check`, `cargo clippy`
|
||||
- Python project root is `crates/luminal_python/` — run `uv run` commands from there
|
||||
|
||||
## Building the Python Package (Maturin)
|
||||
|
||||
- After modifying `.rs` files that affect the Python bridge, rebuild with: `maturin develop --release`
|
||||
- Maturin config is in `crates/luminal_python/pyproject.toml` under `[tool.maturin]`
|
||||
|
||||
## Pre-commit
|
||||
|
||||
- Run with: `uv run pre-commit run --all-files`
|
||||
- Hooks configured: ruff-check, ruff-format (Python), cargo-fmt, cargo-clippy (Rust)
|
||||
- Manual-stage hooks (cargo-clippy-metal, cargo-clippy-cuda-lite) run with `--hook-stage manual`
|
||||
|
||||
## Testing
|
||||
|
||||
- **Rust tests**: `cargo test -p <crate_name>`
|
||||
- **Python tests**: `cd crates/luminal_python && uv run pytest`
|
||||
- `./run_test.sh` — native backend
|
||||
- `./run_tests_cuda.sh` — CUDA backend
|
||||
- See `crates/luminal_python/CLAUDE.md` for Python test patterns and conventions
|
||||
1
crates/luminal_python/.gitignore
vendored
1
crates/luminal_python/.gitignore
vendored
@@ -1,5 +1,4 @@
|
||||
*.onnx
|
||||
tests/llama38b_ref_logits.pt
|
||||
__pycache__/
|
||||
*.pyc
|
||||
uv.lock
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
A couple of short things to keep in mind
|
||||
## Python Environment
|
||||
|
||||
- Always use `uv run` to execute Python tools (pytest, pre-commit, python) — never bare `pytest` or `python`
|
||||
- Use `uv add` / `uv add --dev` / `uv remove` for dependencies — never hand-edit pyproject.toml deps
|
||||
- After modifying Rust source files, rebuild before running Python tests: `maturin develop --release`
|
||||
|
||||
## Lessons Learned
|
||||
|
||||
|
||||
@@ -3,15 +3,19 @@ name = "luminal_python"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"numpy>=2.0.2",
|
||||
"torch>=2.10.0",
|
||||
"onnx",
|
||||
"onnxscript",
|
||||
"safetensors",
|
||||
"flash-attn-3>=3.0.0",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
no-build-isolation-package = ["flash-attn"]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
@@ -21,6 +25,7 @@ explicit = true
|
||||
torch = [
|
||||
{ index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
]
|
||||
flash-attn-3 = { index = "pytorch-cu128" }
|
||||
|
||||
|
||||
[build-system]
|
||||
@@ -40,13 +45,21 @@ markers = [
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"maturin>=1.0,<2.0",
|
||||
"maturin-import-hook>=0.3.0",
|
||||
"pytest>=9.0.2",
|
||||
"pytest-profiling",
|
||||
"snakeviz",
|
||||
"maturin-import-hook>=0.3.0",
|
||||
"pytest-randomly>=4.0.1",
|
||||
"transformers>=4.40.0",
|
||||
"transformers>=5.5.0,<6",
|
||||
"diffusers>=0.35.0",
|
||||
"onnxsim",
|
||||
"tiktoken>=0.12.0",
|
||||
"pydantic>=2.12.5",
|
||||
"psutil>=7.2.2",
|
||||
"modal>=1.3.5",
|
||||
"pillow",
|
||||
"flash-attn>=2.8.3",
|
||||
]
|
||||
flash-attention-4 = [
|
||||
"nvidia-cutlass-dsl==4.1.0",
|
||||
]
|
||||
|
||||
176
crates/luminal_python/tests/_llama38b_artifacts.py
Normal file
176
crates/luminal_python/tests/_llama38b_artifacts.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""Helpers for caching Llama 3.1-8B test artifacts under pytest cache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from safetensors.torch import save_file
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
# This code is designed to be deleted.
|
||||
# We should not need to cache pt2 or onnx files to get reasonable compile performance.
|
||||
|
||||
MODEL_ID = "NousResearch/Meta-Llama-3.1-8B-Instruct"
|
||||
INPUT_IDS_LIST = [1, 2, 3, 4]
|
||||
INPUT_IDS = torch.tensor([INPUT_IDS_LIST], dtype=torch.long)
|
||||
ARTIFACT_SCHEMA_VERSION = 1
|
||||
ONNX_OPSET_VERSION = 20
|
||||
PT2_STRICT = False
|
||||
|
||||
_REF_LOGITS_META_KEY = "luminal_python/llama38b_artifacts/ref_logits_v1"
|
||||
_ONNX_META_KEY = "luminal_python/llama38b_artifacts/onnx_v1"
|
||||
_PT2_META_KEY = "luminal_python/llama38b_artifacts/pt2_v1"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Llama38BArtifactBundle:
|
||||
ref_logits_path: Path
|
||||
onnx_path: Path | None = None
|
||||
pt2_path: Path | None = None
|
||||
weights_path: Path | None = None
|
||||
|
||||
|
||||
def ensure_onnx_bundle(cache, cache_dir: Path) -> Llama38BArtifactBundle:
|
||||
"""Ensure ONNX artifacts and shared reference logits exist in pytest cache."""
|
||||
ref_logits_path = cache_dir / "ref_logits.pt"
|
||||
onnx_dir = cache_dir / "onnx"
|
||||
onnx_path = onnx_dir / "llama38b.onnx"
|
||||
|
||||
ref_metadata = _ref_logits_metadata()
|
||||
onnx_metadata = _onnx_metadata()
|
||||
needs_ref_logits = cache.get(_REF_LOGITS_META_KEY, None) != ref_metadata or not (
|
||||
ref_logits_path.is_file()
|
||||
)
|
||||
needs_onnx = cache.get(_ONNX_META_KEY, None) != onnx_metadata or not (
|
||||
onnx_path.is_file()
|
||||
)
|
||||
|
||||
if needs_ref_logits or needs_onnx:
|
||||
print(f"Generating cached ONNX artifacts for {MODEL_ID} in {cache_dir}")
|
||||
if needs_ref_logits:
|
||||
ref_logits_path.unlink(missing_ok=True)
|
||||
if needs_onnx:
|
||||
shutil.rmtree(onnx_dir, ignore_errors=True)
|
||||
onnx_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model = _load_model()
|
||||
|
||||
if needs_ref_logits:
|
||||
ref_logits = _compute_ref_logits(model)
|
||||
torch.save(ref_logits, ref_logits_path)
|
||||
cache.set(_REF_LOGITS_META_KEY, ref_metadata)
|
||||
|
||||
if needs_onnx:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(INPUT_IDS,),
|
||||
str(onnx_path),
|
||||
opset_version=ONNX_OPSET_VERSION,
|
||||
input_names=["input_ids"],
|
||||
output_names=["logits"],
|
||||
)
|
||||
cache.set(_ONNX_META_KEY, onnx_metadata)
|
||||
|
||||
return Llama38BArtifactBundle(ref_logits_path=ref_logits_path, onnx_path=onnx_path)
|
||||
|
||||
|
||||
def ensure_pt2_bundle(cache, cache_dir: Path) -> Llama38BArtifactBundle:
|
||||
"""Ensure PT2 artifacts and shared reference logits exist in pytest cache."""
|
||||
ref_logits_path = cache_dir / "ref_logits.pt"
|
||||
pt2_dir = cache_dir / "pt2"
|
||||
pt2_path = pt2_dir / "llama38b.pt2"
|
||||
weights_path = pt2_dir / "llama38b_weights.safetensors"
|
||||
|
||||
ref_metadata = _ref_logits_metadata()
|
||||
pt2_metadata = _pt2_metadata()
|
||||
needs_ref_logits = cache.get(_REF_LOGITS_META_KEY, None) != ref_metadata or not (
|
||||
ref_logits_path.is_file()
|
||||
)
|
||||
needs_pt2 = cache.get(_PT2_META_KEY, None) != pt2_metadata or not (
|
||||
pt2_path.is_file() and weights_path.is_file()
|
||||
)
|
||||
|
||||
if needs_ref_logits or needs_pt2:
|
||||
print(f"Generating cached PT2 artifacts for {MODEL_ID} in {cache_dir}")
|
||||
if needs_ref_logits:
|
||||
ref_logits_path.unlink(missing_ok=True)
|
||||
if needs_pt2:
|
||||
shutil.rmtree(pt2_dir, ignore_errors=True)
|
||||
pt2_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model = _load_model()
|
||||
|
||||
if needs_ref_logits:
|
||||
ref_logits = _compute_ref_logits(model)
|
||||
torch.save(ref_logits, ref_logits_path)
|
||||
cache.set(_REF_LOGITS_META_KEY, ref_metadata)
|
||||
|
||||
if needs_pt2:
|
||||
exported_program = torch.export.export(
|
||||
model, (INPUT_IDS,), strict=PT2_STRICT
|
||||
)
|
||||
torch.export.save(exported_program, str(pt2_path))
|
||||
|
||||
state_dict = {
|
||||
key: value.float().clone()
|
||||
for key, value in exported_program.state_dict.items()
|
||||
}
|
||||
save_file(state_dict, str(weights_path))
|
||||
cache.set(_PT2_META_KEY, pt2_metadata)
|
||||
|
||||
return Llama38BArtifactBundle(
|
||||
ref_logits_path=ref_logits_path,
|
||||
pt2_path=pt2_path,
|
||||
weights_path=weights_path,
|
||||
)
|
||||
|
||||
|
||||
def _load_model() -> LlamaForCausalLM:
|
||||
config = AutoConfig.from_pretrained(MODEL_ID)
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
return LlamaForCausalLM.from_pretrained(
|
||||
MODEL_ID,
|
||||
config=config,
|
||||
torch_dtype=torch.float32,
|
||||
).eval()
|
||||
|
||||
|
||||
def _compute_ref_logits(model: LlamaForCausalLM) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
return model(INPUT_IDS).logits.clone()
|
||||
|
||||
|
||||
def _ref_logits_metadata() -> dict[str, object]:
|
||||
return {
|
||||
"schema_version": ARTIFACT_SCHEMA_VERSION,
|
||||
"model_id": MODEL_ID,
|
||||
"input_ids": INPUT_IDS_LIST,
|
||||
"device": "cpu",
|
||||
"torch_dtype": "float32",
|
||||
"use_cache": False,
|
||||
"attn_implementation": "eager",
|
||||
"torch_version": torch.__version__,
|
||||
"transformers_version": transformers.__version__,
|
||||
}
|
||||
|
||||
|
||||
def _onnx_metadata() -> dict[str, object]:
|
||||
return {
|
||||
**_ref_logits_metadata(),
|
||||
"artifact_type": "onnx",
|
||||
"opset_version": ONNX_OPSET_VERSION,
|
||||
}
|
||||
|
||||
|
||||
def _pt2_metadata() -> dict[str, object]:
|
||||
return {
|
||||
**_ref_logits_metadata(),
|
||||
"artifact_type": "pt2",
|
||||
"strict": PT2_STRICT,
|
||||
}
|
||||
@@ -1,21 +1,64 @@
|
||||
"""Test configuration."""
|
||||
# ruff: noqa: E402
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from urllib.request import urlopen
|
||||
import warnings
|
||||
|
||||
try:
|
||||
import huggingface_hub
|
||||
from transformers import logging as transformers_logging
|
||||
except ImportError: # pragma: no cover - optional for non-HF test environments
|
||||
huggingface_hub = None
|
||||
transformers_logging = None
|
||||
|
||||
# Enable automatic Rust rebuilds during test development
|
||||
try:
|
||||
import maturin_import_hook
|
||||
from maturin_import_hook.settings import MaturinSettings
|
||||
import maturin_import_hook
|
||||
from maturin_import_hook.settings import MaturinSettings
|
||||
from maturin_import_hook.project_importer import DefaultProjectFileSearcher
|
||||
|
||||
backend = os.getenv("LUMINAL_BACKEND", "native").lower()
|
||||
settings = MaturinSettings(features=["cuda"]) if backend == "cuda" else None
|
||||
maturin_import_hook.install(settings=settings)
|
||||
except ImportError:
|
||||
pass # Hook not available, rebuilds will be manual
|
||||
backend = os.getenv("LUMINAL_BACKEND", "native").lower()
|
||||
settings = MaturinSettings(
|
||||
release=(backend == "cuda"),
|
||||
features=["cuda"] if backend == "cuda" else None,
|
||||
skip_install=True,
|
||||
)
|
||||
searcher = DefaultProjectFileSearcher(
|
||||
source_excluded_dir_names=(
|
||||
DefaultProjectFileSearcher.DEFAULT_SOURCE_EXCLUDED_DIR_NAMES
|
||||
| {".claude", "docs", ".github", "examples"}
|
||||
),
|
||||
)
|
||||
maturin_import_hook.install(
|
||||
settings=settings,
|
||||
enable_automatic_installation=True,
|
||||
file_searcher=searcher,
|
||||
)
|
||||
logging.getLogger("maturin_import_hook").disabled = True
|
||||
logging.getLogger("maturin_import_hook.project_importer").disabled = True
|
||||
|
||||
# Silence noisy ONNX / onnxscript / httpx logging
|
||||
for _logger_name in (
|
||||
"onnxscript",
|
||||
"onnx_ir",
|
||||
"torch.onnx",
|
||||
"httpx",
|
||||
):
|
||||
logging.getLogger(_logger_name).setLevel(logging.WARNING)
|
||||
|
||||
# Suppress torch.onnx diagnostics/progress output and torchvision warnings
|
||||
os.environ.setdefault("TORCH_ONNX_VERBOSE", "0")
|
||||
os.environ.setdefault("TORCH_ONNX_LOG_LEVEL", "ERROR")
|
||||
warnings.filterwarnings("ignore", message=".*torchvision.*")
|
||||
warnings.filterwarnings("ignore", module="torch.onnx")
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from _llama38b_artifacts import ensure_onnx_bundle, ensure_pt2_bundle
|
||||
|
||||
torch.set_float32_matmul_precision("highest")
|
||||
|
||||
@@ -26,6 +69,139 @@ def device() -> torch.device:
|
||||
return torch.device("cuda") if backend == "cuda" else torch.device("cpu")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def configure_hf_test_output() -> None:
|
||||
if transformers_logging is not None:
|
||||
transformers_logging.disable_progress_bar()
|
||||
if huggingface_hub is not None:
|
||||
huggingface_hub.utils.disable_progress_bars()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def configure_dynamo():
|
||||
original_cache_size_limit = torch._dynamo.config.cache_size_limit
|
||||
original_suppress_errors = torch._dynamo.config.suppress_errors
|
||||
|
||||
def _configure(
|
||||
*, cache_size_limit: int | None = None, suppress_errors: bool | None = None
|
||||
) -> None:
|
||||
if cache_size_limit is not None:
|
||||
torch._dynamo.config.cache_size_limit = cache_size_limit
|
||||
if suppress_errors is not None:
|
||||
torch._dynamo.config.suppress_errors = suppress_errors
|
||||
|
||||
yield _configure
|
||||
|
||||
torch._dynamo.config.cache_size_limit = original_cache_size_limit
|
||||
torch._dynamo.config.suppress_errors = original_suppress_errors
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def _llama38b_cache_dir(pytestconfig: pytest.Config) -> Path:
|
||||
return pytestconfig.cache.mkdir("luminal_llama38b_artifacts_v1")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def _hf_multimodal_cache_dir(pytestconfig: pytest.Config) -> Path:
|
||||
return pytestconfig.cache.mkdir("luminal_hf_multimodal_v1")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def hf_multimodal_image_path(
|
||||
pytestconfig: pytest.Config, _hf_multimodal_cache_dir: Path
|
||||
) -> Path:
|
||||
image_url = (
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/"
|
||||
"resolve/main/bee.jpg"
|
||||
)
|
||||
image_path = _hf_multimodal_cache_dir / "bee.jpg"
|
||||
metadata_key = "luminal_python/hf_multimodal_image_v1"
|
||||
metadata = {
|
||||
"schema_version": 1,
|
||||
"url": image_url,
|
||||
"filename": image_path.name,
|
||||
}
|
||||
|
||||
needs_download = pytestconfig.cache.get(metadata_key, None) != metadata or not (
|
||||
image_path.is_file()
|
||||
)
|
||||
if not needs_download:
|
||||
return image_path
|
||||
|
||||
image_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp_path: Path | None = None
|
||||
try:
|
||||
with urlopen(image_url, timeout=60) as response:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
dir=image_path.parent, delete=False
|
||||
) as tmp_file:
|
||||
tmp_path = Path(tmp_file.name)
|
||||
while chunk := response.read(1024 * 1024):
|
||||
tmp_file.write(chunk)
|
||||
|
||||
assert tmp_path is not None
|
||||
tmp_path.replace(image_path)
|
||||
except Exception:
|
||||
if tmp_path is not None:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
raise
|
||||
|
||||
pytestconfig.cache.set(metadata_key, metadata)
|
||||
return image_path
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def _llama38b_onnx_bundle(pytestconfig: pytest.Config, _llama38b_cache_dir: Path):
|
||||
return ensure_onnx_bundle(pytestconfig.cache, _llama38b_cache_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def _llama38b_pt2_bundle(pytestconfig: pytest.Config, _llama38b_cache_dir: Path):
|
||||
return ensure_pt2_bundle(pytestconfig.cache, _llama38b_cache_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llama38b_ref_logits(request: pytest.FixtureRequest) -> torch.Tensor:
|
||||
fixturenames = set(request.fixturenames)
|
||||
uses_onnx = "llama38b_onnx_path" in fixturenames
|
||||
uses_pt2 = bool({"llama38b_pt2_path", "llama38b_weights_path"} & fixturenames)
|
||||
|
||||
if uses_onnx and uses_pt2:
|
||||
raise pytest.UsageError(
|
||||
"llama38b_ref_logits cannot be requested with both ONNX and PT2 "
|
||||
"artifact fixtures in the same test"
|
||||
)
|
||||
if uses_onnx:
|
||||
bundle = request.getfixturevalue("_llama38b_onnx_bundle")
|
||||
elif uses_pt2:
|
||||
bundle = request.getfixturevalue("_llama38b_pt2_bundle")
|
||||
else:
|
||||
raise pytest.UsageError(
|
||||
"llama38b_ref_logits must be requested alongside llama38b_onnx_path "
|
||||
"or llama38b_pt2_path/llama38b_weights_path"
|
||||
)
|
||||
|
||||
return torch.load(bundle.ref_logits_path, weights_only=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llama38b_onnx_path(_llama38b_onnx_bundle) -> Path:
|
||||
assert _llama38b_onnx_bundle.onnx_path is not None
|
||||
return _llama38b_onnx_bundle.onnx_path
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llama38b_pt2_path(_llama38b_pt2_bundle) -> Path:
|
||||
assert _llama38b_pt2_bundle.pt2_path is not None
|
||||
return _llama38b_pt2_bundle.pt2_path
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llama38b_weights_path(_llama38b_pt2_bundle) -> Path:
|
||||
assert _llama38b_pt2_bundle.weights_path is not None
|
||||
return _llama38b_pt2_bundle.weights_path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="function")
|
||||
def reset_torch_dynamo():
|
||||
# We need this for two reasons
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
"""Generate pre-computed artifacts for test_hf_llama38b_cached_onnx.
|
||||
|
||||
Run once:
|
||||
uv run python tests/generate_llama38b_artifacts.py
|
||||
|
||||
Produces:
|
||||
tests/llama38b.onnx — ONNX export of Llama 3.1-8B
|
||||
tests/llama38b_ref_logits.pt — reference logits for input_ids=[1,2,3,4]
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
SCRIPT_DIR = Path(__file__).resolve().parent
|
||||
ONNX_PATH = SCRIPT_DIR / "llama38b.onnx"
|
||||
LOGITS_PATH = SCRIPT_DIR / "llama38b_ref_logits.pt"
|
||||
|
||||
INPUT_IDS = torch.tensor([[1, 2, 3, 4]])
|
||||
|
||||
|
||||
def main():
|
||||
config = AutoConfig.from_pretrained("NousResearch/Meta-Llama-3.1-8B-Instruct")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
print("Loading model on CPU...")
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
"NousResearch/Meta-Llama-3.1-8B-Instruct",
|
||||
config=config,
|
||||
torch_dtype=torch.float32,
|
||||
).eval()
|
||||
|
||||
print("Computing reference logits...")
|
||||
with torch.no_grad():
|
||||
ref_logits = model(INPUT_IDS).logits.clone()
|
||||
print(f"Reference logits shape: {ref_logits.shape}")
|
||||
|
||||
print(f"Saving reference logits to {LOGITS_PATH}")
|
||||
torch.save(ref_logits, LOGITS_PATH)
|
||||
|
||||
print(f"Exporting ONNX to {ONNX_PATH}")
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(INPUT_IDS,),
|
||||
str(ONNX_PATH),
|
||||
opset_version=20,
|
||||
input_names=["input_ids"],
|
||||
output_names=["logits"],
|
||||
)
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,62 +0,0 @@
|
||||
"""Generate pre-computed PT2 artifacts for test_hf_llama38b_cached.
|
||||
|
||||
Run once:
|
||||
uv run python tests/generate_llama38b_pt2_artifacts.py
|
||||
|
||||
Produces:
|
||||
tests/llama38b.pt2 — torch.export of Llama 3.1-8B
|
||||
tests/llama38b_weights.safetensors — model weights
|
||||
tests/llama38b_ref_logits.pt — reference logits for input_ids=[1,2,3,4]
|
||||
(shared with ONNX artifact script)
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
SCRIPT_DIR = Path(__file__).resolve().parent
|
||||
PT2_PATH = SCRIPT_DIR / "llama38b.pt2"
|
||||
WEIGHTS_PATH = SCRIPT_DIR / "llama38b_weights.safetensors"
|
||||
LOGITS_PATH = SCRIPT_DIR / "llama38b_ref_logits.pt"
|
||||
|
||||
INPUT_IDS = torch.tensor([[1, 2, 3, 4]])
|
||||
|
||||
|
||||
def main():
|
||||
config = AutoConfig.from_pretrained("NousResearch/Meta-Llama-3.1-8B-Instruct")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
print("Loading model on CPU...")
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
"NousResearch/Meta-Llama-3.1-8B-Instruct",
|
||||
config=config,
|
||||
torch_dtype=torch.float32,
|
||||
).eval()
|
||||
|
||||
# Generate reference logits (shared with ONNX artifact script)
|
||||
if not LOGITS_PATH.exists():
|
||||
print("Computing reference logits...")
|
||||
with torch.no_grad():
|
||||
ref_logits = model(INPUT_IDS).logits.clone()
|
||||
print(f"Reference logits shape: {ref_logits.shape}")
|
||||
print(f"Saving reference logits to {LOGITS_PATH}")
|
||||
torch.save(ref_logits, LOGITS_PATH)
|
||||
else:
|
||||
print(f"Reference logits already exist at {LOGITS_PATH}, skipping")
|
||||
|
||||
print(f"Exporting PT2 to {PT2_PATH}")
|
||||
ep = torch.export.export(model, (INPUT_IDS,), strict=False)
|
||||
torch.export.save(ep, str(PT2_PATH))
|
||||
|
||||
print(f"Saving weights to {WEIGHTS_PATH}")
|
||||
state_dict = {k: v.float().clone() for k, v in ep.state_dict.items()}
|
||||
save_file(state_dict, str(WEIGHTS_PATH))
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
282
crates/luminal_python/tests/test_hf_causal_lm_config_options.py
Normal file
282
crates/luminal_python/tests/test_hf_causal_lm_config_options.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""Hugging Face causal-LM config option support tests.
|
||||
|
||||
These tests verify that luminal matches eager Hugging Face execution across
|
||||
supported causal-LM config options using tiny public model definitions loaded
|
||||
through AutoConfig.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig
|
||||
from transformers.generation.configuration_utils import ContinuousBatchingConfig
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
# Attention implementations that require optional packages.
|
||||
_ATTN_REQUIRES_PACKAGE: dict[str, str] = {
|
||||
"flash_attention_2": "flash_attn",
|
||||
"flash_attention_3": "flash_attn_3",
|
||||
"flash_attention_4": "cutlass",
|
||||
}
|
||||
|
||||
# Attention implementations known to be incompatible with tiny random models
|
||||
# (e.g. head_dim < 16 or missing scaffolding).
|
||||
_ATTN_SKIP_TINY_MODEL: set[str] = {"flex_attention", "paged_attention"}
|
||||
|
||||
_PAGED_ATTN_IMPLEMENTATIONS = tuple(
|
||||
k for k in ALL_ATTENTION_FUNCTIONS.valid_keys() if k.startswith("paged|")
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _CausalLMConfigCase:
|
||||
case_id: str
|
||||
model_id: str
|
||||
input_ids: tuple[int, ...]
|
||||
atol: float
|
||||
rtol: float
|
||||
|
||||
|
||||
_MODEL_CASES = [
|
||||
_CausalLMConfigCase(
|
||||
case_id="llama_3.2_1B",
|
||||
model_id="meta-llama/Llama-3.2-1B",
|
||||
input_ids=(1, 2, 3, 4),
|
||||
atol=1e-5,
|
||||
rtol=1e-5,
|
||||
)
|
||||
]
|
||||
|
||||
_ATTN_IMPLEMENTATIONS = tuple(
|
||||
dict.fromkeys([None, "eager", *ALL_ATTENTION_FUNCTIONS.valid_keys()])
|
||||
)
|
||||
_CUDA_BACKEND_AVAILABLE = (
|
||||
os.getenv("LUMINAL_BACKEND", "native").lower() == "cuda"
|
||||
and torch.cuda.is_available()
|
||||
)
|
||||
|
||||
|
||||
def _attn_id(attn_impl: str | None) -> str:
|
||||
return "default" if attn_impl is None else attn_impl
|
||||
|
||||
|
||||
def _base_attn_impl(attn_impl: str | None) -> str | None:
|
||||
if attn_impl is None:
|
||||
return None
|
||||
if attn_impl.startswith("paged|"):
|
||||
return attn_impl.split("|", maxsplit=1)[1]
|
||||
return attn_impl
|
||||
|
||||
|
||||
def _attn_param(attn_impl: str | None, *, allow_paged: bool) -> pytest.ParameterSet:
|
||||
marks = []
|
||||
base_attn_impl = _base_attn_impl(attn_impl)
|
||||
|
||||
if base_attn_impl == "flash_attention_2":
|
||||
marks.append(pytest.mark.skip(reason="flash_attention_2 is very slow"))
|
||||
|
||||
if attn_impl is not None and attn_impl.startswith("paged|") and not allow_paged:
|
||||
marks.append(
|
||||
pytest.mark.skip(reason=f"{attn_impl} requires continuous batching API")
|
||||
)
|
||||
|
||||
if base_attn_impl in _ATTN_REQUIRES_PACKAGE:
|
||||
pkg = _ATTN_REQUIRES_PACKAGE[base_attn_impl]
|
||||
if importlib.util.find_spec(pkg) is None:
|
||||
marks.append(
|
||||
pytest.mark.skip(
|
||||
reason=f"{attn_impl} requires package '{pkg}' which is not installed"
|
||||
)
|
||||
)
|
||||
|
||||
if base_attn_impl in _ATTN_SKIP_TINY_MODEL:
|
||||
marks.append(
|
||||
pytest.mark.skip(
|
||||
reason=f"{attn_impl} is incompatible with tiny random test models"
|
||||
)
|
||||
)
|
||||
|
||||
kwargs = {"id": _attn_id(attn_impl)}
|
||||
if marks:
|
||||
kwargs["marks"] = marks
|
||||
return pytest.param(attn_impl, **kwargs)
|
||||
|
||||
|
||||
_ATTN_PREFILL_PARAMS = tuple(
|
||||
_attn_param(attn_impl, allow_paged=False) for attn_impl in _ATTN_IMPLEMENTATIONS
|
||||
)
|
||||
_ATTN_GENERATE_BATCH_PARAMS = tuple(
|
||||
_attn_param(attn_impl, allow_paged=True) for attn_impl in _ATTN_IMPLEMENTATIONS
|
||||
)
|
||||
|
||||
|
||||
def _compare_past_key_values(lhs, rhs, *, atol: float, rtol: float) -> None:
|
||||
assert lhs is not None
|
||||
assert rhs is not None
|
||||
assert hasattr(lhs, "layers")
|
||||
assert hasattr(rhs, "layers")
|
||||
assert len(lhs.layers) == len(rhs.layers)
|
||||
|
||||
for lhs_layer, rhs_layer in zip(lhs.layers, rhs.layers):
|
||||
torch.testing.assert_close(lhs_layer.keys, rhs_layer.keys, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(
|
||||
lhs_layer.values, rhs_layer.values, atol=atol, rtol=rtol
|
||||
)
|
||||
|
||||
|
||||
def _instantiate_model(
|
||||
model_id: str, *, use_cache: bool, attn_impl: str | None, device: torch.device
|
||||
) -> AutoModelForCausalLM:
|
||||
config = AutoConfig.from_pretrained(model_id)
|
||||
config.use_cache = use_cache
|
||||
if attn_impl is not None:
|
||||
config._attn_implementation = attn_impl
|
||||
return AutoModelForCausalLM.from_config(config).eval().to(device)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_case", _MODEL_CASES, ids=lambda case: case.case_id)
|
||||
@pytest.mark.parametrize("use_cache", [False, True], ids=["no_cache", "cache"])
|
||||
@pytest.mark.parametrize("attn_impl", _ATTN_PREFILL_PARAMS)
|
||||
def test_hf_causal_lm_config_options_match_eager(
|
||||
model_case: _CausalLMConfigCase,
|
||||
use_cache: bool,
|
||||
attn_impl: str | None,
|
||||
device: torch.device,
|
||||
configure_dynamo,
|
||||
):
|
||||
"""Compare luminal against eager HF across causal-LM config options."""
|
||||
if use_cache:
|
||||
configure_dynamo(cache_size_limit=2)
|
||||
|
||||
model = _instantiate_model(
|
||||
model_case.model_id,
|
||||
use_cache=use_cache,
|
||||
attn_impl=attn_impl,
|
||||
device=device,
|
||||
)
|
||||
input_ids = torch.tensor([model_case.input_ids], device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_prefill = model(input_ids)
|
||||
|
||||
compiled_model = torch.compile(model, backend=luminal_backend)
|
||||
|
||||
with torch.no_grad():
|
||||
compiled_prefill = compiled_model(input_ids)
|
||||
|
||||
torch.testing.assert_close(
|
||||
compiled_prefill.logits,
|
||||
eager_prefill.logits,
|
||||
atol=model_case.atol,
|
||||
rtol=model_case.rtol,
|
||||
)
|
||||
|
||||
if not use_cache:
|
||||
assert eager_prefill.past_key_values is None
|
||||
assert compiled_prefill.past_key_values is None
|
||||
return
|
||||
|
||||
assert eager_prefill.past_key_values is not None
|
||||
assert compiled_prefill.past_key_values is not None
|
||||
_compare_past_key_values(
|
||||
compiled_prefill.past_key_values,
|
||||
eager_prefill.past_key_values,
|
||||
atol=model_case.atol,
|
||||
rtol=model_case.rtol,
|
||||
)
|
||||
|
||||
next_token = eager_prefill.logits[:, -1, :].argmax(dim=-1, keepdim=True)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_decode = model(next_token, past_key_values=eager_prefill.past_key_values)
|
||||
compiled_decode = compiled_model(
|
||||
next_token,
|
||||
past_key_values=compiled_prefill.past_key_values,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
compiled_decode.logits,
|
||||
eager_decode.logits,
|
||||
atol=model_case.atol,
|
||||
rtol=model_case.rtol,
|
||||
)
|
||||
assert eager_decode.past_key_values is not None
|
||||
assert compiled_decode.past_key_values is not None
|
||||
_compare_past_key_values(
|
||||
compiled_decode.past_key_values,
|
||||
eager_decode.past_key_values,
|
||||
atol=model_case.atol,
|
||||
rtol=model_case.rtol,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_case", _MODEL_CASES, ids=lambda case: case.case_id)
|
||||
@pytest.mark.parametrize("attn_impl", _ATTN_GENERATE_BATCH_PARAMS)
|
||||
@pytest.mark.skipif(not _CUDA_BACKEND_AVAILABLE, reason="generate_batch requires CUDA")
|
||||
def test_hf_generate_batch(
|
||||
model_case: _CausalLMConfigCase,
|
||||
attn_impl: str | None,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Compare generate_batch output for each attention variant against eager baseline."""
|
||||
config = AutoConfig.from_pretrained(model_case.model_id)
|
||||
config.use_cache = True
|
||||
model = (
|
||||
AutoModelForCausalLM.from_config(config)
|
||||
.to(dtype=torch.bfloat16)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
|
||||
gen_config = GenerationConfig(
|
||||
do_sample=False,
|
||||
max_new_tokens=5,
|
||||
temperature=None,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
)
|
||||
cb_config = ContinuousBatchingConfig(
|
||||
block_size=256,
|
||||
use_cuda_graph=False,
|
||||
)
|
||||
|
||||
inputs = [list(model_case.input_ids)]
|
||||
|
||||
# Baseline: eager generate_batch.
|
||||
model.set_attn_implementation("eager")
|
||||
eager_outputs = model.generate_batch(
|
||||
inputs,
|
||||
generation_config=gen_config,
|
||||
continuous_batching_config=cb_config,
|
||||
progress_bar=False,
|
||||
warmup=False,
|
||||
)
|
||||
|
||||
# Variant under test.
|
||||
if attn_impl is not None:
|
||||
model.set_attn_implementation(attn_impl)
|
||||
variant_outputs = model.generate_batch(
|
||||
inputs,
|
||||
generation_config=gen_config,
|
||||
continuous_batching_config=cb_config,
|
||||
progress_bar=False,
|
||||
warmup=False,
|
||||
)
|
||||
|
||||
assert len(eager_outputs) == len(variant_outputs)
|
||||
eager_out = next(iter(eager_outputs.values()))
|
||||
variant_out = next(iter(variant_outputs.values()))
|
||||
assert eager_out.error is None, f"Eager baseline failed: {eager_out.error}"
|
||||
assert variant_out.error is None, f"Variant {attn_impl} failed: {variant_out.error}"
|
||||
assert eager_out.generated_tokens == variant_out.generated_tokens, (
|
||||
f"Token mismatch for {attn_impl}: eager={eager_out.generated_tokens} "
|
||||
f"vs variant={variant_out.generated_tokens}"
|
||||
)
|
||||
168
crates/luminal_python/tests/test_hf_causal_lm_experts_options.py
Normal file
168
crates/luminal_python/tests/test_hf_causal_lm_experts_options.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Hugging Face causal-LM experts backend smoke tests.
|
||||
|
||||
These tests load a real pretrained text-only MoE model and compare eager
|
||||
PyTorch against torch.compile with the luminal backend across the standardized
|
||||
`experts_implementation` backends.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _HFMoeCase:
|
||||
case_id: str
|
||||
model_id: str
|
||||
input_ids: tuple[tuple[int, ...], ...]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _HFMoeBundle:
|
||||
case: _HFMoeCase
|
||||
model: AutoModelForCausalLM
|
||||
device: torch.device
|
||||
dtype: torch.dtype
|
||||
|
||||
|
||||
_MODEL_CASES = [
|
||||
_HFMoeCase(
|
||||
case_id="qwen15_moe_a27b",
|
||||
model_id="Qwen/Qwen1.5-MoE-A2.7B",
|
||||
input_ids=(
|
||||
(1, 2, 3, 4, 5, 6, 7, 8),
|
||||
(8, 7, 6, 5, 4, 3, 2, 1),
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
_EXPERTS_IMPLEMENTATIONS = ("eager", "batched_mm", "grouped_mm")
|
||||
_CUDA_BACKEND_AVAILABLE = (
|
||||
os.getenv("LUMINAL_BACKEND", "native").lower() == "cuda"
|
||||
and torch.cuda.is_available()
|
||||
)
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.slow,
|
||||
pytest.mark.skipif(
|
||||
not _CUDA_BACKEND_AVAILABLE,
|
||||
reason="HF MoE experts backend tests require the CUDA backend",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _model_dtype(device: torch.device) -> torch.dtype:
|
||||
if device.type != "cuda":
|
||||
return torch.float32
|
||||
return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||
|
||||
|
||||
def _output_tolerance(dtype: torch.dtype) -> float:
|
||||
if dtype == torch.bfloat16:
|
||||
return 5e-2
|
||||
if dtype == torch.float16:
|
||||
return 1e-2
|
||||
return 1e-3
|
||||
|
||||
|
||||
def _compare_router_logits(lhs, rhs, *, atol: float, rtol: float) -> None:
|
||||
assert lhs is not None
|
||||
assert rhs is not None
|
||||
|
||||
if isinstance(lhs, torch.Tensor):
|
||||
torch.testing.assert_close(lhs, rhs, atol=atol, rtol=rtol)
|
||||
return
|
||||
|
||||
assert len(lhs) == len(rhs)
|
||||
for lhs_layer, rhs_layer in zip(lhs, rhs):
|
||||
torch.testing.assert_close(lhs_layer, rhs_layer, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=_MODEL_CASES, ids=lambda case: case.case_id)
|
||||
def hf_moe_case(request: pytest.FixtureRequest) -> _HFMoeCase:
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hf_moe_bundle(hf_moe_case: _HFMoeCase) -> _HFMoeBundle:
|
||||
case = hf_moe_case
|
||||
device = torch.device("cuda")
|
||||
dtype = _model_dtype(device)
|
||||
|
||||
config = AutoConfig.from_pretrained(case.model_id)
|
||||
config.use_cache = False
|
||||
config.output_router_logits = True
|
||||
# Keep attention fixed so this test isolates experts backends.
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
model = (
|
||||
AutoModelForCausalLM.from_pretrained(
|
||||
case.model_id,
|
||||
config=config,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
return _HFMoeBundle(case=case, model=model, device=device, dtype=dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"experts_implementation",
|
||||
_EXPERTS_IMPLEMENTATIONS,
|
||||
ids=list(_EXPERTS_IMPLEMENTATIONS),
|
||||
)
|
||||
def test_hf_causal_lm_experts_implementation_matches_eager(
|
||||
hf_moe_bundle: _HFMoeBundle, experts_implementation: str
|
||||
):
|
||||
model = hf_moe_bundle.model
|
||||
model.set_experts_implementation(experts_implementation)
|
||||
assert model.config._experts_implementation == experts_implementation
|
||||
|
||||
input_ids = torch.tensor(hf_moe_bundle.case.input_ids, device=hf_moe_bundle.device)
|
||||
kwargs = {
|
||||
"input_ids": input_ids,
|
||||
"use_cache": False,
|
||||
"output_router_logits": True,
|
||||
"logits_to_keep": 1,
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
eager_output = model(**kwargs)
|
||||
|
||||
compiled_model = torch.compile(model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
compiled_output = compiled_model(**kwargs)
|
||||
|
||||
atol = _output_tolerance(hf_moe_bundle.dtype)
|
||||
rtol = 1e-3
|
||||
|
||||
torch.testing.assert_close(
|
||||
compiled_output.logits,
|
||||
eager_output.logits,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
)
|
||||
_compare_router_logits(
|
||||
compiled_output.router_logits,
|
||||
eager_output.router_logits,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
)
|
||||
|
||||
if eager_output.aux_loss is not None or compiled_output.aux_loss is not None:
|
||||
assert eager_output.aux_loss is not None
|
||||
assert compiled_output.aux_loss is not None
|
||||
torch.testing.assert_close(
|
||||
compiled_output.aux_loss,
|
||||
eager_output.aux_loss,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
)
|
||||
242
crates/luminal_python/tests/test_hf_multimodal_generation.py
Normal file
242
crates/luminal_python/tests/test_hf_multimodal_generation.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""Hugging Face multimodal image-text-to-text smoke tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModelForImageTextToText, AutoProcessor
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
MODEL_ID = "google/gemma-3-4b-it"
|
||||
_CUDA_BACKEND_AVAILABLE = (
|
||||
os.getenv("LUMINAL_BACKEND", "native").lower() == "cuda"
|
||||
and torch.cuda.is_available()
|
||||
)
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.slow,
|
||||
pytest.mark.skipif(
|
||||
not _CUDA_BACKEND_AVAILABLE,
|
||||
reason="Gemma 3 multimodal tests require the CUDA backend",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HFMultimodalCase:
|
||||
case_id: str
|
||||
messages_builder: Callable[[Path], list[dict]]
|
||||
max_new_tokens: int
|
||||
expects_pixel_values: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Gemma3MultimodalBundle:
|
||||
model: AutoModelForImageTextToText
|
||||
processor: AutoProcessor
|
||||
device: torch.device
|
||||
dtype: torch.dtype
|
||||
|
||||
|
||||
def _build_text_only_messages(_: Path) -> list[dict]:
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": "You are a concise assistant."}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "In one short sentence, explain what a compiler does.",
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _build_image_to_text_messages(image_path: Path) -> list[dict]:
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": "You are a helpful assistant."}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "path": str(image_path)},
|
||||
{"type": "text", "text": "Describe this image in one short sentence."},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
MULTIMODAL_CASES = [
|
||||
HFMultimodalCase(
|
||||
case_id="chat_text_only",
|
||||
messages_builder=_build_text_only_messages,
|
||||
max_new_tokens=12,
|
||||
expects_pixel_values=False,
|
||||
),
|
||||
HFMultimodalCase(
|
||||
case_id="image_to_text",
|
||||
messages_builder=_build_image_to_text_messages,
|
||||
max_new_tokens=16,
|
||||
expects_pixel_values=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _model_dtype(device: torch.device) -> torch.dtype:
|
||||
if device.type != "cuda":
|
||||
return torch.float32
|
||||
return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||
|
||||
|
||||
def _set_greedy_generation(model) -> None:
|
||||
model.generation_config.temperature = None
|
||||
model.generation_config.top_p = None
|
||||
model.generation_config.top_k = None
|
||||
|
||||
|
||||
def _move_to_device(
|
||||
encoded: dict[str, torch.Tensor], device: torch.device, dtype: torch.dtype
|
||||
) -> dict[str, torch.Tensor]:
|
||||
result = {}
|
||||
for key, value in encoded.items():
|
||||
if not isinstance(value, torch.Tensor):
|
||||
result[key] = value
|
||||
continue
|
||||
|
||||
moved = value.to(device)
|
||||
if moved.is_floating_point():
|
||||
moved = moved.to(dtype=dtype)
|
||||
result[key] = moved
|
||||
return result
|
||||
|
||||
|
||||
def _encode_case(
|
||||
bundle: Gemma3MultimodalBundle,
|
||||
case: HFMultimodalCase,
|
||||
image_path: Path,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
encoded = bundle.processor.apply_chat_template(
|
||||
case.messages_builder(image_path),
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
encoded = _move_to_device(dict(encoded), bundle.device, bundle.dtype)
|
||||
|
||||
if case.expects_pixel_values:
|
||||
assert "pixel_values" in encoded
|
||||
assert "input_ids" in encoded
|
||||
return encoded
|
||||
|
||||
|
||||
def _generate_kwargs(
|
||||
bundle: Gemma3MultimodalBundle,
|
||||
encoded: dict[str, torch.Tensor],
|
||||
max_new_tokens: int,
|
||||
) -> dict:
|
||||
tokenizer = bundle.processor.tokenizer
|
||||
return dict(
|
||||
**encoded,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=False,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
|
||||
|
||||
def _logits_tolerance(dtype: torch.dtype) -> float:
|
||||
if dtype == torch.bfloat16:
|
||||
return 5e-2
|
||||
if dtype == torch.float16:
|
||||
return 1e-2
|
||||
return 1e-3
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def gemma3_multimodal_bundle() -> Gemma3MultimodalBundle:
|
||||
device = torch.device("cuda")
|
||||
dtype = _model_dtype(device)
|
||||
|
||||
config = AutoConfig.from_pretrained(MODEL_ID)
|
||||
processor = AutoProcessor.from_pretrained(MODEL_ID)
|
||||
tokenizer = processor.tokenizer
|
||||
if tokenizer.pad_token_id is None and tokenizer.eos_token is not None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = (
|
||||
AutoModelForImageTextToText.from_pretrained(
|
||||
MODEL_ID,
|
||||
config=config,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
_set_greedy_generation(model)
|
||||
return Gemma3MultimodalBundle(
|
||||
model=model, processor=processor, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
|
||||
class TestHFMultimodalGeneration:
|
||||
@pytest.mark.parametrize("case", MULTIMODAL_CASES, ids=lambda case: case.case_id)
|
||||
def test_generate_matches_eager(
|
||||
self,
|
||||
case: HFMultimodalCase,
|
||||
gemma3_multimodal_bundle: Gemma3MultimodalBundle,
|
||||
hf_multimodal_image_path: Path,
|
||||
):
|
||||
encoded = _encode_case(gemma3_multimodal_bundle, case, hf_multimodal_image_path)
|
||||
kwargs = _generate_kwargs(
|
||||
gemma3_multimodal_bundle, encoded, case.max_new_tokens
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_output = gemma3_multimodal_bundle.model.generate(**kwargs)
|
||||
|
||||
compiled_model = torch.compile(
|
||||
gemma3_multimodal_bundle.model, backend=luminal_backend
|
||||
)
|
||||
with torch.no_grad():
|
||||
compiled_output = compiled_model.generate(**kwargs)
|
||||
|
||||
torch.testing.assert_close(compiled_output, eager_output)
|
||||
|
||||
@pytest.mark.parametrize("case", MULTIMODAL_CASES, ids=lambda case: case.case_id)
|
||||
def test_forward_logits_match_eager(
|
||||
self,
|
||||
case: HFMultimodalCase,
|
||||
gemma3_multimodal_bundle: Gemma3MultimodalBundle,
|
||||
hf_multimodal_image_path: Path,
|
||||
):
|
||||
encoded = _encode_case(gemma3_multimodal_bundle, case, hf_multimodal_image_path)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_out = gemma3_multimodal_bundle.model(**encoded)
|
||||
|
||||
compiled_model = torch.compile(
|
||||
gemma3_multimodal_bundle.model, backend=luminal_backend
|
||||
)
|
||||
with torch.no_grad():
|
||||
compiled_out = compiled_model(**encoded)
|
||||
|
||||
atol = _logits_tolerance(gemma3_multimodal_bundle.dtype)
|
||||
torch.testing.assert_close(
|
||||
compiled_out.logits,
|
||||
eager_out.logits,
|
||||
atol=atol,
|
||||
rtol=1e-3,
|
||||
)
|
||||
288
crates/luminal_python/tests/test_hf_text_generation.py
Normal file
288
crates/luminal_python/tests/test_hf_text_generation.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""Hugging Face text-generation smoke tests.
|
||||
|
||||
These tests intentionally download real Hugging Face checkpoints, configs,
|
||||
and tokenizers. They compare eager PyTorch output against torch.compile
|
||||
with the luminal backend to verify numerical equivalence.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
SIMPLE_DENSE_MODELS = [
|
||||
"meta-llama/Llama-3.2-1B",
|
||||
"meta-llama/Llama-3.1-1B",
|
||||
"Qwen/Qwen3-8B",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HFTextGenerationBundle:
|
||||
model: AutoModelForCausalLM
|
||||
tokenizer: AutoTokenizer
|
||||
device: torch.device
|
||||
|
||||
|
||||
def _load_model_and_tokenizer(
|
||||
model_id: str, device: torch.device
|
||||
) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
|
||||
"""Load a pretrained HF causal LM and its tokenizer, ready for generation."""
|
||||
config = AutoConfig.from_pretrained(model_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
if tokenizer.pad_token_id is None and tokenizer.eos_token is not None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
dtype = torch.float16 if device.type == "cuda" else torch.float32
|
||||
model = (
|
||||
AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
config=config,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
|
||||
model.generation_config.temperature = None
|
||||
model.generation_config.top_p = None
|
||||
model.generation_config.top_k = None
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def _encode(
|
||||
tokenizer: AutoTokenizer, prompt: str, device: torch.device
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Tokenize a prompt and move tensors to device."""
|
||||
encoded = tokenizer(prompt, return_tensors="pt")
|
||||
result = {"input_ids": encoded["input_ids"].to(device)}
|
||||
if encoded.get("attention_mask") is not None:
|
||||
result["attention_mask"] = encoded["attention_mask"].to(device)
|
||||
return result
|
||||
|
||||
|
||||
def _generate_kwargs(
|
||||
tokenizer: AutoTokenizer,
|
||||
encoded: dict[str, torch.Tensor],
|
||||
max_new_tokens: int = 6,
|
||||
) -> dict:
|
||||
"""Build kwargs dict for model.generate()."""
|
||||
return dict(
|
||||
**encoded,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=False,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hf_text_bundle(model_id: str, device: torch.device) -> HFTextGenerationBundle:
|
||||
model, tokenizer = _load_model_and_tokenizer(model_id, device)
|
||||
return HFTextGenerationBundle(model=model, tokenizer=tokenizer, device=device)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestHFGeneration:
|
||||
"""End-to-end tests comparing eager PyTorch against torch.compile with luminal."""
|
||||
|
||||
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS)
|
||||
def test_capital_of_france(self, hf_text_bundle: HFTextGenerationBundle):
|
||||
"""Basic greedy generation -- the original smoke test."""
|
||||
encoded = _encode(
|
||||
hf_text_bundle.tokenizer,
|
||||
"What is the capital of France ",
|
||||
hf_text_bundle.device,
|
||||
)
|
||||
kwargs = _generate_kwargs(hf_text_bundle.tokenizer, encoded)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_output = hf_text_bundle.model.generate(**kwargs)
|
||||
|
||||
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
compiled_output = compiled_model.generate(**kwargs)
|
||||
|
||||
torch.testing.assert_close(compiled_output, eager_output)
|
||||
|
||||
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS)
|
||||
def test_forward_logits(self, hf_text_bundle: HFTextGenerationBundle):
|
||||
"""Forward pass only -- compare raw logits, not generated tokens."""
|
||||
encoded = _encode(
|
||||
hf_text_bundle.tokenizer, "The quick brown fox", hf_text_bundle.device
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_out = hf_text_bundle.model(**encoded)
|
||||
|
||||
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
compiled_out = compiled_model(**encoded)
|
||||
|
||||
dtype = next(hf_text_bundle.model.parameters()).dtype
|
||||
atol = 1e-2 if dtype == torch.float16 else 1e-3
|
||||
torch.testing.assert_close(
|
||||
compiled_out.logits, eager_out.logits, atol=atol, rtol=1e-3
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
|
||||
@pytest.mark.parametrize(
|
||||
"prompt",
|
||||
[
|
||||
"Hi",
|
||||
"What is the capital of France",
|
||||
"Explain the theory of general relativity in simple terms that a high school student could understand",
|
||||
],
|
||||
ids=["short", "medium", "long"],
|
||||
)
|
||||
def test_variable_length_prompts(
|
||||
self,
|
||||
prompt: str,
|
||||
hf_text_bundle: HFTextGenerationBundle,
|
||||
):
|
||||
"""Generate with prompts of different lengths -- tests dynamic shape handling."""
|
||||
encoded = _encode(hf_text_bundle.tokenizer, prompt, hf_text_bundle.device)
|
||||
kwargs = _generate_kwargs(hf_text_bundle.tokenizer, encoded, max_new_tokens=4)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_output = hf_text_bundle.model.generate(**kwargs)
|
||||
|
||||
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
compiled_output = compiled_model.generate(**kwargs)
|
||||
|
||||
torch.testing.assert_close(compiled_output, eager_output)
|
||||
|
||||
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
|
||||
def test_chat_template_generation(
|
||||
self,
|
||||
model_id: str,
|
||||
hf_text_bundle: HFTextGenerationBundle,
|
||||
):
|
||||
"""Generate using chat-templated input with special tokens."""
|
||||
if hf_text_bundle.tokenizer.chat_template is None:
|
||||
pytest.skip(f"{model_id} has no chat template")
|
||||
|
||||
messages = [{"role": "user", "content": "What is 2+2?"}]
|
||||
encoded = hf_text_bundle.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
)
|
||||
encoded = {k: v.to(hf_text_bundle.device) for k, v in encoded.items()}
|
||||
kwargs = _generate_kwargs(hf_text_bundle.tokenizer, encoded)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_output = hf_text_bundle.model.generate(**kwargs)
|
||||
|
||||
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
compiled_output = compiled_model.generate(**kwargs)
|
||||
|
||||
torch.testing.assert_close(compiled_output, eager_output)
|
||||
|
||||
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
|
||||
@pytest.mark.parametrize("max_new_tokens", [20, 50])
|
||||
def test_longer_generation(
|
||||
self,
|
||||
max_new_tokens: int,
|
||||
hf_text_bundle: HFTextGenerationBundle,
|
||||
):
|
||||
"""Generate many tokens to stress KV cache over extended decode loop."""
|
||||
encoded = _encode(
|
||||
hf_text_bundle.tokenizer, "Once upon a time", hf_text_bundle.device
|
||||
)
|
||||
kwargs = _generate_kwargs(
|
||||
hf_text_bundle.tokenizer,
|
||||
encoded,
|
||||
max_new_tokens=max_new_tokens,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_output = hf_text_bundle.model.generate(**kwargs)
|
||||
|
||||
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
compiled_output = compiled_model.generate(**kwargs)
|
||||
|
||||
torch.testing.assert_close(compiled_output, eager_output)
|
||||
|
||||
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
|
||||
def test_greedy_determinism(
|
||||
self,
|
||||
hf_text_bundle: HFTextGenerationBundle,
|
||||
configure_dynamo,
|
||||
):
|
||||
"""Greedy generation produces identical results on repeated calls."""
|
||||
configure_dynamo(cache_size_limit=4)
|
||||
|
||||
encoded = _encode(
|
||||
hf_text_bundle.tokenizer,
|
||||
"The meaning of life is",
|
||||
hf_text_bundle.device,
|
||||
)
|
||||
kwargs = _generate_kwargs(hf_text_bundle.tokenizer, encoded, max_new_tokens=10)
|
||||
|
||||
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
output_1 = compiled_model.generate(**kwargs)
|
||||
output_2 = compiled_model.generate(**kwargs)
|
||||
|
||||
torch.testing.assert_close(output_1, output_2)
|
||||
|
||||
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
|
||||
def test_reuse_compiled_model(
|
||||
self,
|
||||
hf_text_bundle: HFTextGenerationBundle,
|
||||
configure_dynamo,
|
||||
):
|
||||
"""Call the same compiled model multiple times with different prompts."""
|
||||
configure_dynamo(cache_size_limit=8)
|
||||
|
||||
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
|
||||
|
||||
prompts = [
|
||||
"The capital of France is",
|
||||
"Water boils at",
|
||||
"The largest planet in our solar system is",
|
||||
]
|
||||
|
||||
for prompt in prompts:
|
||||
encoded = _encode(hf_text_bundle.tokenizer, prompt, hf_text_bundle.device)
|
||||
kwargs = _generate_kwargs(
|
||||
hf_text_bundle.tokenizer, encoded, max_new_tokens=4
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_output = hf_text_bundle.model.generate(**kwargs)
|
||||
compiled_output = compiled_model.generate(**kwargs)
|
||||
|
||||
torch.testing.assert_close(compiled_output, eager_output)
|
||||
|
||||
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
|
||||
def test_batched_inference(self, hf_text_bundle: HFTextGenerationBundle):
|
||||
"""Batched generation with multiple prompts and left-padding."""
|
||||
hf_text_bundle.tokenizer.padding_side = "left"
|
||||
|
||||
prompts = ["Hello", "What is the capital of France"]
|
||||
encoded = hf_text_bundle.tokenizer(prompts, return_tensors="pt", padding=True)
|
||||
encoded = {k: v.to(hf_text_bundle.device) for k, v in encoded.items()}
|
||||
kwargs = _generate_kwargs(hf_text_bundle.tokenizer, encoded, max_new_tokens=4)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_output = hf_text_bundle.model.generate(**kwargs)
|
||||
|
||||
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
compiled_output = compiled_model.generate(**kwargs)
|
||||
|
||||
torch.testing.assert_close(compiled_output, eager_output)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -392,3 +392,60 @@ def test_hf_llama38b_full(device: torch.device):
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-5), (
|
||||
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_hf_llama38b_cached_onnx(
|
||||
llama38b_onnx_path, llama38b_ref_logits: torch.Tensor
|
||||
):
|
||||
import os
|
||||
|
||||
import luminal
|
||||
|
||||
backend = os.environ.get("LUMINAL_BACKEND", "cuda")
|
||||
|
||||
graph = luminal.process_onnx(str(llama38b_onnx_path), backend)
|
||||
print("Compiled luminal ONNX graph")
|
||||
|
||||
graph.set_input("input_ids", [float(t) for t in [1, 2, 3, 4]])
|
||||
graph.run()
|
||||
|
||||
logits_data = graph.get_output("logits")
|
||||
logits_shape = graph.output_shapes[0]
|
||||
logits = torch.tensor(logits_data, dtype=torch.float32).reshape(logits_shape)
|
||||
|
||||
print(f"Loaded reference logits: {llama38b_ref_logits.shape}")
|
||||
print(f"Output logits shape: {logits.shape}")
|
||||
|
||||
assert torch.allclose(logits, llama38b_ref_logits, atol=1e-3), (
|
||||
f"max_diff={torch.max(torch.abs(logits - llama38b_ref_logits)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_hf_llama38b_cached_pt2(
|
||||
llama38b_pt2_path, llama38b_weights_path, llama38b_ref_logits: torch.Tensor
|
||||
):
|
||||
import os
|
||||
|
||||
import luminal
|
||||
from luminal import CompiledModel
|
||||
|
||||
backend = os.environ.get("LUMINAL_BACKEND", "cuda")
|
||||
backend_name = "cuda" if backend == "cuda" else "cpu"
|
||||
|
||||
compiled_inner = luminal.compile_pt2(
|
||||
str(llama38b_pt2_path), str(llama38b_weights_path), backend_name, 0
|
||||
)
|
||||
compiled = CompiledModel(compiled_inner)
|
||||
print("Compiled luminal PT2 graph")
|
||||
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]])
|
||||
logits = compiled(input_ids)[0]
|
||||
|
||||
print(f"Loaded reference logits: {llama38b_ref_logits.shape}")
|
||||
print(f"Output logits shape: {logits.shape}")
|
||||
|
||||
assert torch.allclose(logits, llama38b_ref_logits, atol=1e-3), (
|
||||
f"max_diff={torch.max(torch.abs(logits - llama38b_ref_logits)).item():.2e}"
|
||||
)
|
||||
|
||||
@@ -41,9 +41,6 @@ class AddAddTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class AddConstantTestModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x + 10
|
||||
|
||||
@@ -59,25 +56,16 @@ class LinearLayerModel(torch.nn.Module):
|
||||
|
||||
|
||||
class SqrtTestModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.sqrt()
|
||||
|
||||
|
||||
class SinTestModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.sin(x)
|
||||
|
||||
|
||||
class CosTestModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.cos(x)
|
||||
|
||||
@@ -94,9 +82,6 @@ class SubTestModel(torch.nn.Module):
|
||||
class TransposeTestModel(torch.nn.Module):
|
||||
"""Test basic 2D transpose (matrix transpose)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.t() # 2D transpose
|
||||
|
||||
@@ -104,9 +89,6 @@ class TransposeTestModel(torch.nn.Module):
|
||||
class Transpose3DTestModel(torch.nn.Module):
|
||||
"""Test 3D transpose with explicit permutation."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.permute(2, 0, 1) # Rotate dimensions
|
||||
|
||||
@@ -114,9 +96,6 @@ class Transpose3DTestModel(torch.nn.Module):
|
||||
class Transpose4DTestModel(torch.nn.Module):
|
||||
"""Test 4D transpose (NCHW -> NHWC)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.permute(0, 2, 3, 1) # Common in CNNs
|
||||
|
||||
@@ -124,9 +103,6 @@ class Transpose4DTestModel(torch.nn.Module):
|
||||
class TransposeReverseTestModel(torch.nn.Module):
|
||||
"""Test reverse permutation (default transpose behavior)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
dims = list(range(x.ndim))
|
||||
return x.permute(*reversed(dims))
|
||||
@@ -151,9 +127,6 @@ class TransposeInExpressionModel(torch.nn.Module):
|
||||
class ConstantScalarFloatModel(torch.nn.Module):
|
||||
"""Test scalar constant (broadcasts to input shape)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor(10.5).to(x.device)
|
||||
return x + constant
|
||||
@@ -162,9 +135,6 @@ class ConstantScalarFloatModel(torch.nn.Module):
|
||||
class Constant1DArrayFloatModel(torch.nn.Module):
|
||||
"""Test 1D array constant."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]).to(x.device)
|
||||
return x * constant
|
||||
@@ -173,9 +143,6 @@ class Constant1DArrayFloatModel(torch.nn.Module):
|
||||
class Constant2DMatrixFloatModel(torch.nn.Module):
|
||||
"""Test 2D matrix constant."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]).to(x.device)
|
||||
return x + constant
|
||||
@@ -184,9 +151,6 @@ class Constant2DMatrixFloatModel(torch.nn.Module):
|
||||
class ConstantRawDataFloatModel(torch.nn.Module):
|
||||
"""Test constant with specific values (tests raw data format)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([7.5, 8.5, 9.5]).to(x.device)
|
||||
return x + constant
|
||||
@@ -195,9 +159,6 @@ class ConstantRawDataFloatModel(torch.nn.Module):
|
||||
class ConstantInt32ConversionModel(torch.nn.Module):
|
||||
"""Test INT32 constant values (PyTorch exports as integers)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32).to(x.device)
|
||||
return x + constant.float()
|
||||
@@ -206,9 +167,6 @@ class ConstantInt32ConversionModel(torch.nn.Module):
|
||||
class ConstantInt64ConversionModel(torch.nn.Module):
|
||||
"""Test INT64 constant values."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([100, 200, 300], dtype=torch.int64).to(x.device)
|
||||
return x * constant.float()
|
||||
@@ -217,9 +175,6 @@ class ConstantInt64ConversionModel(torch.nn.Module):
|
||||
class ConstantFloat64ConversionModel(torch.nn.Module):
|
||||
"""Test FLOAT64 (double) constant values."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([1.5, 2.5, 3.5], dtype=torch.float64).to(x.device)
|
||||
return x * constant.float()
|
||||
@@ -228,9 +183,6 @@ class ConstantFloat64ConversionModel(torch.nn.Module):
|
||||
class ConstantBoolConversionModel(torch.nn.Module):
|
||||
"""Test boolean constant values (converted to 0.0/1.0)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([True, False, True, False, True], dtype=torch.bool).to(
|
||||
x.device
|
||||
@@ -241,9 +193,6 @@ class ConstantBoolConversionModel(torch.nn.Module):
|
||||
class ConstantInt64RawDataModel(torch.nn.Module):
|
||||
"""Test INT64 constant with large values (tests raw data path)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([1000, 2000, 3000], dtype=torch.int64).to(x.device)
|
||||
return x + constant.float()
|
||||
@@ -252,9 +201,6 @@ class ConstantInt64RawDataModel(torch.nn.Module):
|
||||
class ConstantNegativeValuesModel(torch.nn.Module):
|
||||
"""Test negative constant values."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([-5.0, -10.0, -15.0]).to(x.device)
|
||||
return x + constant
|
||||
@@ -263,9 +209,6 @@ class ConstantNegativeValuesModel(torch.nn.Module):
|
||||
class ConstantZeroValueModel(torch.nn.Module):
|
||||
"""Test all-zero constant."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([0.0, 0.0, 0.0, 0.0]).to(x.device)
|
||||
return x * constant
|
||||
@@ -274,9 +217,6 @@ class ConstantZeroValueModel(torch.nn.Module):
|
||||
class ConstantMultipleInGraphModel(torch.nn.Module):
|
||||
"""Test multiple constants in one graph."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
const1 = torch.tensor([10.0, 20.0, 30.0]).to(x.device)
|
||||
const2 = torch.tensor([1.0, 2.0, 3.0]).to(x.device)
|
||||
@@ -290,9 +230,6 @@ class ConstantMultipleInGraphModel(torch.nn.Module):
|
||||
class CastDoubleToFloatModel(torch.nn.Module):
|
||||
"""Test downcast: Double (FLOAT64) -> Float."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# Input will be float64, cast to float32
|
||||
return x.to(torch.float32)
|
||||
@@ -301,9 +238,6 @@ class CastDoubleToFloatModel(torch.nn.Module):
|
||||
class CastInt32ToFloatModel(torch.nn.Module):
|
||||
"""Test INT32 -> Float conversion."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
|
||||
@@ -311,9 +245,6 @@ class CastInt32ToFloatModel(torch.nn.Module):
|
||||
class CastInt64ToFloatModel(torch.nn.Module):
|
||||
"""Test INT64 -> Float conversion."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
|
||||
@@ -321,9 +252,6 @@ class CastInt64ToFloatModel(torch.nn.Module):
|
||||
class CastBoolToFloatModel(torch.nn.Module):
|
||||
"""Test BOOL -> Float conversion (non-zero -> 1.0, zero -> 0.0)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
|
||||
@@ -331,9 +259,6 @@ class CastBoolToFloatModel(torch.nn.Module):
|
||||
class CastInComputationGraphModel(torch.nn.Module):
|
||||
"""Test Cast node followed by an operation (Cast + Add)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
casted = x.to(torch.float32)
|
||||
constant = torch.tensor([2.0, 2.0, 2.0]).to(x.device)
|
||||
@@ -343,9 +268,6 @@ class CastInComputationGraphModel(torch.nn.Module):
|
||||
class CastWith2DTensorModel(torch.nn.Module):
|
||||
"""Test Cast with 2D tensor (matrix)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
|
||||
@@ -353,9 +275,6 @@ class CastWith2DTensorModel(torch.nn.Module):
|
||||
class CastNegativeValuesModel(torch.nn.Module):
|
||||
"""Test Cast with negative integer values."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
|
||||
@@ -363,9 +282,6 @@ class CastNegativeValuesModel(torch.nn.Module):
|
||||
class CastScalarValueModel(torch.nn.Module):
|
||||
"""Test Cast with scalar (single element)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
|
||||
@@ -389,9 +305,6 @@ class ModTestModel(torch.nn.Module):
|
||||
class ModByConstantModel(torch.nn.Module):
|
||||
"""Tests modulo with an inline constant tensor (ONNX Constant node)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([3.0, 4.0, 5.0]).to(x.device)
|
||||
return x.fmod(constant)
|
||||
|
||||
@@ -1,154 +1,118 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from test_models import (
|
||||
SigmoidTestModel,
|
||||
SigmoidInExpressionModel,
|
||||
TanhTestModel,
|
||||
TanhInExpressionModel,
|
||||
ReluTestModel,
|
||||
ReluAllNegativeModel,
|
||||
ReluInExpressionModel,
|
||||
AbsTestModel,
|
||||
AbsAllNegativeModel,
|
||||
AbsInExpressionModel,
|
||||
NegTestModel,
|
||||
NegAllPositiveModel,
|
||||
NegInExpressionModel,
|
||||
ClipTestModel,
|
||||
ClipMinOnlyTestModel,
|
||||
ClipMaxOnlyTestModel,
|
||||
)
|
||||
import test_models as tm
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
# ── Sigmoid ──────────────────────────────────────────────────────────────────
|
||||
|
||||
Args = tuple[torch.Tensor, ...]
|
||||
Kwargs = dict[str, torch.Tensor]
|
||||
InputFactory = Callable[[torch.device], tuple[Args, Kwargs]]
|
||||
|
||||
|
||||
def test_sigmoid(device):
|
||||
model = SigmoidTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1 # mixed positive/negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
@dataclass(frozen=True)
|
||||
class UnaryCase:
|
||||
id: str
|
||||
model_factory: Callable[[], torch.nn.Module]
|
||||
input_factory: InputFactory
|
||||
|
||||
|
||||
def test_sigmoid_in_expression(device):
|
||||
model = SigmoidInExpressionModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device)
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
UNARY_CASES: list[UnaryCase] = [
|
||||
UnaryCase(
|
||||
id="sigmoid",
|
||||
model_factory=tm.SigmoidTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="sigmoid_in_expression",
|
||||
model_factory=tm.SigmoidInExpressionModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device),), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="tanh",
|
||||
model_factory=tm.TanhTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="tanh_in_expression",
|
||||
model_factory=tm.TanhInExpressionModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device),), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="relu",
|
||||
model_factory=tm.ReluTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="relu_all_negative",
|
||||
model_factory=tm.ReluAllNegativeModel,
|
||||
input_factory=lambda device: ((-torch.rand((5, 5), device=device),), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="relu_in_expression",
|
||||
model_factory=tm.ReluInExpressionModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="abs",
|
||||
model_factory=tm.AbsTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="abs_all_negative",
|
||||
model_factory=tm.AbsAllNegativeModel,
|
||||
input_factory=lambda device: ((-torch.rand((5, 5), device=device),), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="abs_in_expression",
|
||||
model_factory=tm.AbsInExpressionModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="neg",
|
||||
model_factory=tm.NegTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="neg_all_positive",
|
||||
model_factory=tm.NegAllPositiveModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device),), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="neg_in_expression",
|
||||
model_factory=tm.NegInExpressionModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="clip",
|
||||
model_factory=tm.ClipTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 4 - 2,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="clip_min_only",
|
||||
model_factory=tm.ClipMinOnlyTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 4 - 2,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="clip_max_only",
|
||||
model_factory=tm.ClipMaxOnlyTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 4 - 2,), {}),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# ── Tanh ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_tanh(device):
|
||||
model = TanhTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_tanh_in_expression(device):
|
||||
model = TanhInExpressionModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device)
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ── Relu ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_relu(device):
|
||||
model = ReluTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1 # mixed positive/negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_relu_all_negative(device):
|
||||
model = ReluAllNegativeModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = -torch.rand((5, 5), device=device) # all negative -> output all zeros
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_relu_in_expression(device):
|
||||
model = ReluInExpressionModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ── Abs ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_abs(device):
|
||||
model = AbsTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1 # mixed positive/negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_abs_all_negative(device):
|
||||
model = AbsAllNegativeModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = -torch.rand((5, 5), device=device) # all negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_abs_in_expression(device):
|
||||
model = AbsInExpressionModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ── Neg ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_neg(device):
|
||||
model = NegTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1 # mixed positive/negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_neg_all_positive(device):
|
||||
model = NegAllPositiveModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) # all positive -> output all negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_neg_in_expression(device):
|
||||
model = NegInExpressionModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ── Clip ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_clip(device):
|
||||
"""Clip tensor values to [-0.5, 0.5]."""
|
||||
model = ClipTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 4 - 2 # range [-2, 2]
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_clip_min_only(device):
|
||||
"""Clip tensor values to [0.0, +inf]."""
|
||||
model = ClipMinOnlyTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 4 - 2
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_clip_max_only(device):
|
||||
"""Clip tensor values to [-inf, 0.5]."""
|
||||
model = ClipMaxOnlyTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 4 - 2
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
class TestUnaryOps:
|
||||
@pytest.mark.parametrize("case", UNARY_CASES, ids=lambda case: case.id)
|
||||
def test_matches_eager(self, case: UnaryCase, device: torch.device) -> None:
|
||||
model = case.model_factory().to(device)
|
||||
compiled_model = torch.compile(model, backend=luminal_backend)
|
||||
args, kwargs = case.input_factory(device)
|
||||
torch.testing.assert_close(
|
||||
compiled_model(*args, **kwargs),
|
||||
model(*args, **kwargs),
|
||||
atol=1e-5,
|
||||
rtol=1e-5,
|
||||
)
|
||||
|
||||
20
skills-lock.json
Normal file
20
skills-lock.json
Normal file
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"version": 1,
|
||||
"skills": {
|
||||
"ruff": {
|
||||
"source": "astral-sh/claude-code-plugins",
|
||||
"sourceType": "github",
|
||||
"computedHash": "905f1c2b66e722ab534d7cbeceafb6ee37d32e3bfe51f3e68c661c51eb19a776"
|
||||
},
|
||||
"ty": {
|
||||
"source": "astral-sh/claude-code-plugins",
|
||||
"sourceType": "github",
|
||||
"computedHash": "5d82b319a84d296fae47f2664228bfac1a36dd832cb487b485a65ec36b3df21f"
|
||||
},
|
||||
"uv": {
|
||||
"source": "astral-sh/claude-code-plugins",
|
||||
"sourceType": "github",
|
||||
"computedHash": "b1ca9e9826d906f6ee6ea172da50018ec13bb622ac9ba7204bda8b10e7fc8d0a"
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user