mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
54 Commits
nvidia-dev
...
pytest-cla
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ee5b54438 | ||
|
|
389c05abeb | ||
|
|
dcc2c9cbb4 | ||
|
|
a9af4c3923 | ||
|
|
3092d0d68b | ||
|
|
8a2bd714ac | ||
|
|
54a26a044c | ||
|
|
5a0d3f87cc | ||
|
|
a28b755245 | ||
|
|
fd83534e53 | ||
|
|
b5d984c3fa | ||
|
|
64a5ca41b5 | ||
|
|
9bda47714a | ||
|
|
9e513b6589 | ||
|
|
a62d728bd7 | ||
|
|
4114714d3f | ||
|
|
6191597571 | ||
|
|
253cd95ab0 | ||
|
|
d7e396ba5b | ||
|
|
1a53626716 | ||
|
|
4329d68adc | ||
|
|
989e7e2d44 | ||
|
|
019972cdd4 | ||
|
|
d7a3f468bd | ||
|
|
c504fbf8a1 | ||
|
|
625be7f4da | ||
|
|
c2a17a4854 | ||
|
|
5c60f1d768 | ||
|
|
4c51e3ea84 | ||
|
|
846551aa6f | ||
|
|
c26076bc75 | ||
|
|
871629b770 | ||
|
|
c6dfa9c62f | ||
|
|
90e3a915d7 | ||
|
|
56cb237aa2 | ||
|
|
a2c42b35c8 | ||
|
|
898204b2dd | ||
|
|
2c1a7f087f | ||
|
|
112d064700 | ||
|
|
c51c36fbcb | ||
|
|
ee372d464e | ||
|
|
1bef1344d1 | ||
|
|
2e27c29b47 | ||
|
|
8d41c491fd | ||
|
|
64f390a833 | ||
|
|
8d20581f38 | ||
|
|
bfd4ae9b27 | ||
|
|
92e4260f1e | ||
|
|
662a564efc | ||
|
|
1761dc6b66 | ||
|
|
da71273d7e | ||
|
|
7c921d03a8 | ||
|
|
679aa7e092 | ||
|
|
3dd2be2fb2 |
130
.agents/skills/aoti-debug/SKILL.md
Normal file
130
.agents/skills/aoti-debug/SKILL.md
Normal file
@@ -0,0 +1,130 @@
|
||||
---
|
||||
name: aoti-debug
|
||||
description: Debug AOTInductor (AOTI) errors including device mismatches, CUDA illegal memory access, segfaults, and wrong outputs when deploying compiled PyTorch models. Use when encountering errors with aoti_compile_and_package, aoti_load_package, or the deprecated aot_compile/aot_load APIs.
|
||||
---
|
||||
|
||||
# AOTInductor Debugging
|
||||
|
||||
Debug errors when compiling and deploying PyTorch models with AOTInductor.
|
||||
|
||||
## First Step: Always Check Device and Shape Matching
|
||||
|
||||
**For ANY AOTI error (segfault, exception, crash, wrong output), check these first:**
|
||||
|
||||
1. **Compile device == Load device**: The model must be loaded on the same device type it was compiled on
|
||||
2. **Input devices match**: Runtime inputs must be on the same device as the compiled model
|
||||
3. **Input shapes match**: Runtime input shapes must match compilation shapes (or satisfy dynamic shape constraints)
|
||||
|
||||
```python
|
||||
# Compilation -- note the device and shapes
|
||||
model = MyModel().eval().cuda()
|
||||
inp = torch.randn(2, 10, device="cuda")
|
||||
pkg = torch._inductor.aoti_compile_and_package(model, (inp,))
|
||||
|
||||
# Loading -- device type MUST match compilation
|
||||
loaded = torch._inductor.aoti_load_package(pkg) # auto-detects device from package
|
||||
|
||||
# Inference -- device and shapes MUST match
|
||||
out = loaded(torch.randn(2, 10, device="cuda")) # same device, same shape
|
||||
```
|
||||
|
||||
**AOTI requires compile and load to use the same device type.** Cross-device loading (compile on GPU, load on CPU) is NOT supported. Device index can differ (cuda:0 vs cuda:1).
|
||||
|
||||
## Current vs Deprecated API
|
||||
|
||||
### Current API (use this)
|
||||
```python
|
||||
torch._inductor.aoti_compile_and_package() # compile
|
||||
torch._inductor.aoti_load_package() # load (auto-detects device)
|
||||
```
|
||||
|
||||
### Deprecated API (migrate away)
|
||||
```python
|
||||
torch._export.aot_compile() # deprecated
|
||||
torch._export.aot_load() # deprecated
|
||||
```
|
||||
|
||||
The new API stores device metadata in the package, so `aoti_load_package()` automatically uses the correct device type.
|
||||
|
||||
## Common Error Patterns
|
||||
|
||||
### Device Mismatch Segfault
|
||||
|
||||
**Symptom**: Segfault, exception, or crash during load or execution.
|
||||
|
||||
**Example errors**:
|
||||
- `The specified pointer resides on host memory and is not registered with any CUDA device`
|
||||
- Crash during constant loading
|
||||
- `Expected out tensor to have device cuda:0, but got cpu instead`
|
||||
|
||||
**Solution**: Ensure compile and load use the same device type.
|
||||
|
||||
### Input Device Mismatch at Runtime
|
||||
|
||||
**Symptom**: RuntimeError during model execution.
|
||||
|
||||
**Better debugging**: Run with `AOTI_RUNTIME_CHECK_INPUTS=1` for clear errors:
|
||||
```bash
|
||||
AOTI_RUNTIME_CHECK_INPUTS=1 python script.py
|
||||
```
|
||||
|
||||
Produces actionable messages like:
|
||||
```
|
||||
Error: input_handles[0]: unmatched device type, expected: 0(cpu), but got: 1(cuda)
|
||||
```
|
||||
|
||||
## Debugging CUDA Illegal Memory Access (IMA)
|
||||
|
||||
### Step 1: Sanity Checks
|
||||
|
||||
```bash
|
||||
AOTI_RUNTIME_CHECK_INPUTS=1 python script.py # validate inputs match compilation guards
|
||||
TORCHINDUCTOR_NAN_ASSERTS=1 python script.py # check for NaN before/after each kernel
|
||||
```
|
||||
|
||||
Both flags take effect at **compile time** (codegen time).
|
||||
|
||||
### Step 2: Make IMA Deterministic
|
||||
|
||||
```bash
|
||||
PYTORCH_NO_CUDA_MEMORY_CACHING=1 CUDA_LAUNCH_BLOCKING=1 python script.py
|
||||
```
|
||||
|
||||
- `PYTORCH_NO_CUDA_MEMORY_CACHING=1` -- disables caching allocator (which allocates bigger buffers, masking IMA)
|
||||
- `CUDA_LAUNCH_BLOCKING=1` -- forces synchronous kernel launches (pinpoints which kernel crashed)
|
||||
|
||||
Both take effect at **runtime**.
|
||||
|
||||
### Step 3: Identify the Problematic Kernel
|
||||
|
||||
```bash
|
||||
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=3 python script.py
|
||||
```
|
||||
|
||||
Prints kernels one by one at runtime. Combined with Step 2 flags, shows which kernel launched right before the error.
|
||||
|
||||
To inspect inputs to specific kernels:
|
||||
```bash
|
||||
AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT="kernel_name_1,kernel_name_2" \
|
||||
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=2 python script.py
|
||||
```
|
||||
|
||||
If inputs to a kernel are unexpected, trace back to the kernel that produced the bad input.
|
||||
|
||||
## Environment Variables Reference
|
||||
|
||||
| Variable | When | Purpose |
|
||||
|---|---|---|
|
||||
| `AOTI_RUNTIME_CHECK_INPUTS=1` | Compile time | Validate inputs match compilation guards |
|
||||
| `TORCHINDUCTOR_NAN_ASSERTS=1` | Compile time | Check for NaN before/after kernels |
|
||||
| `PYTORCH_NO_CUDA_MEMORY_CACHING=1` | Runtime | Make IMA errors deterministic |
|
||||
| `CUDA_LAUNCH_BLOCKING=1` | Runtime | Force synchronous kernel launches |
|
||||
| `AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=3` | Compile time | Print kernels at runtime |
|
||||
| `AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT="..."` | Compile time | Filter which kernels to print |
|
||||
| `TORCH_LOGS="+inductor,output_code"` | Runtime | See PT2 internal logs |
|
||||
| `TORCH_SHOW_CPP_STACKTRACES=1` | Runtime | Show C++ stack traces |
|
||||
|
||||
## Common Sources of Issues
|
||||
|
||||
- **Dynamic shapes**: Historically a common source of IMA errors. Pay special attention when using dynamic shape constraints.
|
||||
- **Custom ops**: Especially C++ custom ops with dynamic shapes. The meta function may need to handle SymInt properly.
|
||||
195
.agents/skills/pt2-debug/SKILL.md
Normal file
195
.agents/skills/pt2-debug/SKILL.md
Normal file
@@ -0,0 +1,195 @@
|
||||
---
|
||||
name: pt2-debug
|
||||
description: Debug torch.compile failures, graph breaks, recompilation issues, accuracy mismatches, and Triton kernel errors. Use when encountering BackendCompilerFailed exceptions, torch.compile errors, recompilation warnings, or numerical accuracy issues with compiled PyTorch models.
|
||||
---
|
||||
|
||||
# PyTorch 2 Compile Debugging
|
||||
|
||||
Debug `torch.compile`, Dynamo, Inductor, and AOTAutograd failures when using PyTorch as a library.
|
||||
|
||||
## Diagnostic Environment Variables
|
||||
|
||||
Pick the right diagnostic based on the error:
|
||||
|
||||
| Command | When to use |
|
||||
|---|---|
|
||||
| `TORCH_LOGS="+dynamo,graph_breaks,recompiles" python script.py` | Quick overview of what's going wrong |
|
||||
| `TORCH_COMPILE_DEBUG=1 python script.py` | Full debug artifacts (FX graphs, Inductor IR, generated code) in `torch_compile_debug/` |
|
||||
| `TORCH_LOGS="output_code" python script.py` | See the generated Triton/C++ kernel code |
|
||||
| `TORCH_TRACE=/path/to/trace python script.py` | Structured trace (parse with `tlparse`) |
|
||||
| `TORCHINDUCTOR_COMPILE_THREADS=1 python script.py` | Single-threaded compilation for pdb debugging |
|
||||
|
||||
## Error Triage
|
||||
|
||||
Classify the failure and jump to the right section:
|
||||
|
||||
| Error Pattern | Category |
|
||||
|---|---|
|
||||
| `Unsupported: ...` or `graph break` in logs | [Graph Breaks](#graph-breaks) |
|
||||
| `BackendCompilerFailed` | [Backend Failures](#backend-compiler-failures) |
|
||||
| `RecompileError` or `cache_size_limit` | [Recompilation](#recompilation-issues) |
|
||||
| Accuracy mismatch / wrong numerical output | [Accuracy](#accuracy-issues) |
|
||||
| `InternalTorchDynamoError` | [Internal Errors](#internal-dynamo-errors) |
|
||||
| Segfault or CUDA IMA | [Runtime Crashes](#runtime-crashes) |
|
||||
| Triton assertion / index out of bounds | [Triton Failures](#triton-kernel-failures) |
|
||||
|
||||
## Graph Breaks
|
||||
|
||||
Graph breaks split the compiled graph into smaller subgraphs, causing performance regressions.
|
||||
|
||||
**Diagnose:**
|
||||
```bash
|
||||
TORCH_LOGS="graph_breaks" python script.py
|
||||
```
|
||||
|
||||
**Common causes:**
|
||||
- Data-dependent control flow
|
||||
- Unsupported Python builtins
|
||||
- In-place ops on inputs, unsupported dtypes
|
||||
- Calls to non-traceable functions
|
||||
|
||||
**Fix approaches:**
|
||||
1. Read the graph break message to identify the unsupported operation
|
||||
2. Check for a decomposition or supported alternative
|
||||
3. Consider `torch._dynamo.allow_in_graph` or restructure user code
|
||||
|
||||
## Backend Compiler Failures
|
||||
|
||||
`BackendCompilerFailed` means Inductor crashed during compilation.
|
||||
|
||||
**Diagnose with the minifier:**
|
||||
```bash
|
||||
# Generate minifier launcher
|
||||
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=2 python script.py
|
||||
|
||||
# Run the minifier to get minimal failing graph
|
||||
python minifier_launcher.py minify
|
||||
|
||||
# Run the minimized reproduction
|
||||
python minifier_launcher.py run
|
||||
```
|
||||
|
||||
**Then inspect:**
|
||||
```bash
|
||||
TORCH_COMPILE_DEBUG=1 python script.py # FX graphs in torch_compile_debug/
|
||||
```
|
||||
|
||||
## Recompilation Issues
|
||||
|
||||
Excessive recompilation from guards that are too specific, causing cache misses.
|
||||
|
||||
**Diagnose:**
|
||||
```bash
|
||||
TORCH_LOGS="recompiles,recompiles_verbose,guards" python script.py
|
||||
```
|
||||
|
||||
**Key config:**
|
||||
```python
|
||||
torch._dynamo.config.recompile_limit # default: 8
|
||||
torch._dynamo.config.fail_on_recompile_limit_hit = True # hard error on limit
|
||||
```
|
||||
|
||||
**Common causes:**
|
||||
- Changing tensor shapes without marking them dynamic
|
||||
- Python scalar values that change between calls
|
||||
- Global state mutations between calls
|
||||
|
||||
**Fix:** Read the recompilation reason from logs, identify the failing guard, then either:
|
||||
- Mark dimensions as dynamic: `torch._dynamo.mark_dynamic(tensor, dim)`
|
||||
- Fix the source of guard instability
|
||||
|
||||
## Accuracy Issues
|
||||
|
||||
Compiled model produces different numerical results than eager mode.
|
||||
|
||||
**Diagnose:**
|
||||
```bash
|
||||
# Compares compiled vs eager with fp64 reference, dumps repro on failure
|
||||
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=4 python script.py
|
||||
```
|
||||
|
||||
**Fix approach:**
|
||||
1. Get minimal failing graph from the minifier
|
||||
2. Compare eager vs compiled output at fp64 precision
|
||||
3. Binary search through ops to find the diverging operation
|
||||
4. Check for known issues: reduction order, fused kernels, dtype promotions
|
||||
|
||||
## Internal Dynamo Errors
|
||||
|
||||
`InternalTorchDynamoError` indicates a bug in Dynamo.
|
||||
|
||||
**Diagnose:**
|
||||
```bash
|
||||
TORCHDYNAMO_VERBOSE=1 python script.py
|
||||
# or equivalently:
|
||||
TORCH_LOGS="+dynamo" python script.py
|
||||
```
|
||||
|
||||
**Debug interactively:**
|
||||
```bash
|
||||
TORCHINDUCTOR_COMPILE_THREADS=1 python script.py # then attach pdb
|
||||
```
|
||||
|
||||
## Runtime Crashes
|
||||
|
||||
Segfaults and CUDA illegal memory access during execution of compiled code.
|
||||
|
||||
**Make crash deterministic:**
|
||||
```bash
|
||||
PYTORCH_NO_CUDA_MEMORY_CACHING=1 CUDA_LAUNCH_BLOCKING=1 python script.py
|
||||
```
|
||||
|
||||
**Add NaN checks to find the first bad kernel:**
|
||||
```bash
|
||||
TORCHINDUCTOR_NAN_ASSERTS=1 python script.py
|
||||
```
|
||||
|
||||
**Inductor sync debugging:**
|
||||
```python
|
||||
torch._inductor.config.triton.debug_sync_kernel = True # sync after every kernel
|
||||
torch._inductor.config.triton.debug_sync_graph = True # sync before/after graph
|
||||
```
|
||||
|
||||
**Fix approach:**
|
||||
1. Make deterministic with `PYTORCH_NO_CUDA_MEMORY_CACHING=1 CUDA_LAUNCH_BLOCKING=1`
|
||||
2. Check input shapes, devices, dtypes
|
||||
3. Inspect generated kernel code with `TORCH_LOGS="output_code"`
|
||||
4. Use `TORCHINDUCTOR_NAN_ASSERTS=1` to find the first kernel producing bad values
|
||||
5. Dynamic shapes are historically a common source of IMA
|
||||
|
||||
## Triton Kernel Failures
|
||||
|
||||
Triton assertion failures or index-out-of-bounds in generated kernels.
|
||||
|
||||
**Diagnose:**
|
||||
```bash
|
||||
TORCH_LOGS="output_code,schedule" python script.py
|
||||
```
|
||||
|
||||
**Fix approach:**
|
||||
1. Get the generated Triton kernel from `output_code` logs
|
||||
2. Check index computations for off-by-one or wrong stride calculations
|
||||
3. Check IR with `TORCH_COMPILE_DEBUG=1` to trace back to the FX op
|
||||
4. Check if fusion decisions created invalid index combinations
|
||||
|
||||
## Distinguish Trace-Time vs Runtime
|
||||
|
||||
Many bugs come from confusing these:
|
||||
- **Trace-time**: Inside Dynamo's symbolic interpreter. Function calls may be constant-folded.
|
||||
- **Runtime**: Real tensors, real Python calls.
|
||||
|
||||
When debugging, add `print()` directly in source files rather than monkey-patching -- dispatch chains make monkey-patching unreliable.
|
||||
|
||||
## Using the Minifier
|
||||
|
||||
The minifier reduces a failing graph to the smallest reproduction:
|
||||
|
||||
```bash
|
||||
# For compilation failures (level 2)
|
||||
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=2 python script.py
|
||||
python minifier_launcher.py minify
|
||||
python minifier_launcher.py run
|
||||
|
||||
# For accuracy failures (level 4)
|
||||
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=4 python script.py
|
||||
```
|
||||
134
.agents/skills/ruff/SKILL.md
Normal file
134
.agents/skills/ruff/SKILL.md
Normal file
@@ -0,0 +1,134 @@
|
||||
---
|
||||
name: ruff
|
||||
description:
|
||||
Guide for using ruff, the extremely fast Python linter and formatter. Use this
|
||||
when linting, formatting, or fixing Python code.
|
||||
---
|
||||
|
||||
# ruff
|
||||
|
||||
Ruff is an extremely fast Python linter and code formatter. It replaces Flake8,
|
||||
isort, Black, pyupgrade, autoflake, and dozens of other tools.
|
||||
|
||||
## When to use ruff
|
||||
|
||||
**Always use ruff for Python linting and formatting**, especially if you see:
|
||||
|
||||
- `[tool.ruff]` section in `pyproject.toml`
|
||||
- A `ruff.toml` or `.ruff.toml` configuration file
|
||||
|
||||
However, avoid making unnecessary changes:
|
||||
|
||||
- **Don't format unformatted code** - If `ruff format --diff` shows changes
|
||||
throughout an entire file, the project likely isn't using ruff for formatting.
|
||||
Skip formatting to avoid obscuring actual changes.
|
||||
- **Scope fixes to code being edited** - Use `ruff check --diff` to see fixes
|
||||
relevant to the code you're changing. Only apply fixes to files you're
|
||||
modifying unless the user explicitly asks for broader fixes.
|
||||
|
||||
## How to invoke ruff
|
||||
|
||||
- `uv run ruff ...` - Use when ruff is in the project's dependencies to ensure
|
||||
you use the pinned version
|
||||
- `uvx ruff ...` - Use when ruff is not a project dependency, or for quick
|
||||
one-off checks
|
||||
- `ruff ...` - Use if ruff is installed globally
|
||||
|
||||
## Commands
|
||||
|
||||
### Linting
|
||||
|
||||
```bash
|
||||
ruff check . # Check all files in current directory
|
||||
ruff check path/to/file.py # Check specific file
|
||||
ruff check --fix . # Auto-fix fixable violations
|
||||
ruff check --fix --unsafe-fixes . # Include unsafe fixes (review changes!)
|
||||
ruff check --watch . # Watch for changes and re-lint
|
||||
ruff check --select E,F . # Only check specific rules
|
||||
ruff check --ignore E501 . # Ignore specific rules
|
||||
ruff rule E501 # Explain a specific rule
|
||||
ruff linter # List available linters
|
||||
```
|
||||
|
||||
### Formatting
|
||||
|
||||
```bash
|
||||
ruff format . # Format all files
|
||||
ruff format path/to/file.py # Format specific file
|
||||
ruff format --check . # Check if files are formatted (no changes)
|
||||
ruff format --diff . # Show formatting diff without applying
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Ruff is configured in `pyproject.toml` or `ruff.toml`:
|
||||
|
||||
```toml
|
||||
# pyproject.toml
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "UP"] # Enable specific rule sets
|
||||
ignore = ["E501"] # Ignore specific rules
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["myproject"]
|
||||
```
|
||||
|
||||
## Migrating from other tools
|
||||
|
||||
### Black → ruff format
|
||||
|
||||
```bash
|
||||
black . → ruff format .
|
||||
black --check . → ruff format --check .
|
||||
black --diff . → ruff format --diff .
|
||||
```
|
||||
|
||||
### Flake8 → ruff check
|
||||
|
||||
```bash
|
||||
flake8 . → ruff check .
|
||||
flake8 --select E,F . → ruff check --select E,F .
|
||||
flake8 --ignore E501 . → ruff check --ignore E501 .
|
||||
```
|
||||
|
||||
### isort → ruff check
|
||||
|
||||
```bash
|
||||
isort . → ruff check --select I --fix .
|
||||
isort --check . → ruff check --select I .
|
||||
isort --diff . → ruff check --select I --diff .
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Apply lint fixes before formatting
|
||||
|
||||
Run `ruff check --fix` before `ruff format`. Lint fixes can change code
|
||||
structure (e.g., reordering imports), which formatting then cleans up.
|
||||
|
||||
```bash
|
||||
ruff check --fix .
|
||||
ruff format .
|
||||
```
|
||||
|
||||
### Applying and reviewing unsafe fixes
|
||||
|
||||
Ruff categorizes some auto-fixes as "unsafe" because they may change code
|
||||
behavior, not just style. For example, removing unused imports could break code
|
||||
that relies on side effects.
|
||||
|
||||
```bash
|
||||
ruff check --fix --unsafe-fixes --diff . # Preview changes first
|
||||
ruff check --fix --unsafe-fixes . # Apply changes
|
||||
```
|
||||
|
||||
**Always review changes before applying `--unsafe-fixes`:**
|
||||
|
||||
- Use `ruff rule <CODE>` to understand why the fix is considered unsafe
|
||||
- Verify the fix doesn't violate those assumptions in your code
|
||||
|
||||
## Documentation
|
||||
|
||||
For detailed information, read the official documentation:
|
||||
|
||||
- https://docs.astral.sh/ruff/
|
||||
135
.agents/skills/ty/SKILL.md
Normal file
135
.agents/skills/ty/SKILL.md
Normal file
@@ -0,0 +1,135 @@
|
||||
---
|
||||
name: ty
|
||||
description:
|
||||
Guide for using ty, the extremely fast Python type checker and language
|
||||
server. Use this when type checking Python code or setting up type checking in
|
||||
Python projects.
|
||||
---
|
||||
|
||||
# ty
|
||||
|
||||
ty is an extremely fast Python type checker and language server. It replaces
|
||||
mypy, Pyright, and other type checkers.
|
||||
|
||||
## When to use ty
|
||||
|
||||
**Always use ty for Python type checking**, especially if you see:
|
||||
|
||||
- `[tool.ty]` section in `pyproject.toml`
|
||||
- A `ty.toml` configuration file
|
||||
|
||||
## How to invoke ty
|
||||
|
||||
- `uv run ty ...` - Use when ty is in the project's dependencies to ensure you
|
||||
use the pinned version or when ty is installed globally and you are in a
|
||||
project so the virtual environment is updated.
|
||||
- `uvx ty ...` - Use when ty is not a project dependency, or for quick one-off
|
||||
checks
|
||||
|
||||
## Commands
|
||||
|
||||
### Type checking
|
||||
|
||||
```bash
|
||||
ty check # Check all files in current directory
|
||||
ty check path/to/file.py # Check specific file
|
||||
ty check src/ # Check specific directory
|
||||
```
|
||||
|
||||
### Rule configuration
|
||||
|
||||
```bash
|
||||
ty check --error possibly-unresolved-reference # Treat as error
|
||||
ty check --warn division-by-zero # Treat as warning
|
||||
ty check --ignore unresolved-import # Disable rule
|
||||
```
|
||||
|
||||
### Python version targeting
|
||||
|
||||
```bash
|
||||
ty check --python-version 3.12 # Check against Python 3.12
|
||||
ty check --python-platform linux # Target Linux platform
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
ty is configured in `pyproject.toml` or `ty.toml`:
|
||||
|
||||
```toml
|
||||
# pyproject.toml
|
||||
[tool.ty.environment]
|
||||
python-version = "3.12"
|
||||
|
||||
[tool.ty.rules]
|
||||
possibly-unresolved-reference = "warn"
|
||||
division-by-zero = "error"
|
||||
|
||||
[tool.ty.src]
|
||||
include = ["src/**/*.py"]
|
||||
exclude = ["**/migrations/**"]
|
||||
|
||||
[tool.ty.terminal]
|
||||
output-format = "full"
|
||||
error-on-warning = false
|
||||
```
|
||||
|
||||
### Per-file overrides
|
||||
|
||||
Use overrides to apply different rules to specific files, such as relaxing rules
|
||||
for tests or scripts that have different typing requirements than production
|
||||
code:
|
||||
|
||||
```toml
|
||||
[[tool.ty.overrides]]
|
||||
include = ["tests/**", "**/test_*.py"]
|
||||
|
||||
[tool.ty.overrides.rules]
|
||||
possibly-unresolved-reference = "warn"
|
||||
```
|
||||
|
||||
## Language server
|
||||
|
||||
This plugin automatically configures the ty language server for Python files
|
||||
(`.py` and `.pyi`).
|
||||
|
||||
## Migrating from other tools
|
||||
|
||||
### mypy → ty
|
||||
|
||||
```bash
|
||||
mypy . → ty check
|
||||
mypy --strict . → ty check --error-on-warning
|
||||
mypy path/to/file.py → ty check path/to/file.py
|
||||
```
|
||||
|
||||
### Pyright → ty
|
||||
|
||||
```bash
|
||||
pyright . → ty check
|
||||
pyright path/to/file.py → ty check path/to/file.py
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Don't add ignore comments
|
||||
|
||||
Fix type errors instead of suppressing them. Only add ignore comments when
|
||||
explicitly requested by the user. Use `ty: ignore`, not `type: ignore`, and
|
||||
prefer rule-specific ignores:
|
||||
|
||||
```python
|
||||
# Good: rule-specific ignore
|
||||
x = undefined_var # ty: ignore[possibly-unresolved-reference]
|
||||
|
||||
# Bad: blanket ty ignore
|
||||
x = undefined_var # ty: ignore
|
||||
|
||||
# Bad: tool agnostic blanket ignore
|
||||
x = undefined_var # type: ignore
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
For detailed information, read the official documentation:
|
||||
|
||||
- https://docs.astral.sh/ty/
|
||||
182
.agents/skills/uv/SKILL.md
Normal file
182
.agents/skills/uv/SKILL.md
Normal file
@@ -0,0 +1,182 @@
|
||||
---
|
||||
name: uv
|
||||
description:
|
||||
Guide for using uv, the Python package and project manager. Use this when
|
||||
working with Python projects, scripts, packages, or tools.
|
||||
---
|
||||
|
||||
# uv
|
||||
|
||||
uv is an extremely fast Python package and project manager. It replaces pip,
|
||||
pip-tools, pipx, pyenv, virtualenv, poetry, etc.
|
||||
|
||||
## When to use uv
|
||||
|
||||
**Always use uv for Python work**, especially if you see:
|
||||
|
||||
- The `uv.lock` file
|
||||
- uv headers in `requirements*` files, e.g., "This file was autogenerated by uv"
|
||||
|
||||
Don't use uv in projects managed by other tools:
|
||||
|
||||
- Poetry projects (identifiable by `poetry.lock` file)
|
||||
- PDM projects (identifiable by `pdm.lock` file)
|
||||
|
||||
## Choosing the right workflow
|
||||
|
||||
### Scripts
|
||||
|
||||
**Use when:** Running single Python files and standalone scripts.
|
||||
|
||||
**Key commands:**
|
||||
|
||||
```bash
|
||||
uv run script.py # Run a script
|
||||
uv run --with requests script.py # Run with additional packages
|
||||
uv add --script script.py requests # Add dependencies inline to the script
|
||||
```
|
||||
|
||||
### Projects
|
||||
|
||||
**Use when:** There is a `pyproject.toml` or `uv.lock`
|
||||
|
||||
**Key commands:**
|
||||
|
||||
```bash
|
||||
uv init # Create new project
|
||||
uv add requests # Add dependency
|
||||
uv remove requests # Remove dependency
|
||||
uv sync # Install from lockfile
|
||||
uv run <command> # Run commands in environment
|
||||
uv run python -c "" # Run Python in project environment
|
||||
uv run -p 3.12 <command> # Run with specific Python version
|
||||
```
|
||||
|
||||
### Tools
|
||||
|
||||
**Use when:** Running command-line tools (e.g., ruff, ty, pytest) without
|
||||
installation.
|
||||
|
||||
**Key commands:**
|
||||
|
||||
```bash
|
||||
uvx <tool> <args> # Run a tool without installation
|
||||
uvx <tool>@<version> <args> # Run a specific version of a tool
|
||||
```
|
||||
|
||||
**Important:**
|
||||
|
||||
- `uvx` runs tools from PyPI by package name. This can be unsafe - only run
|
||||
well-known tools.
|
||||
- Only use `uv tool install` only when specifically requested by the user.
|
||||
|
||||
### Pip interface
|
||||
|
||||
**Use when:** Legacy workflows with `requirements.txt` or manual environment
|
||||
management, no `uv.lock` present.
|
||||
|
||||
**Key commands:**
|
||||
|
||||
```bash
|
||||
uv venv
|
||||
uv pip install -r requirements.txt
|
||||
uv pip compile requirements.in -o requirements.txt
|
||||
uv pip sync requirements.txt
|
||||
|
||||
# Platform independent resolution
|
||||
uv pip compile --universal requirements.in -o requirements.txt
|
||||
```
|
||||
|
||||
**Important:**
|
||||
|
||||
- Don't use the pip interface unless clearly needed.
|
||||
- Don't introduce new `requirements.txt` files.
|
||||
- Prefer `uv init` for new projects.
|
||||
|
||||
## Migrating from other tools
|
||||
|
||||
### pyenv → uv python
|
||||
|
||||
```bash
|
||||
pyenv install 3.12 → uv python install 3.12
|
||||
pyenv versions → uv python list --only-installed
|
||||
pyenv local 3.12 → uv python pin 3.12
|
||||
pyenv global 3.12 → uv python install 3.12 --default
|
||||
```
|
||||
|
||||
### pipx → uvx
|
||||
|
||||
```bash
|
||||
pipx run ruff → uvx ruff
|
||||
pipx install ruff → uv tool install ruff
|
||||
pipx upgrade ruff → uv tool upgrade ruff
|
||||
pipx list → uv tool list
|
||||
```
|
||||
|
||||
### pip and pip-tools → uv pip
|
||||
|
||||
```bash
|
||||
pip install package → uv pip install package
|
||||
pip install -r req.txt → uv pip install -r req.txt
|
||||
pip freeze → uv pip freeze
|
||||
pip-compile req.in → uv pip compile req.in
|
||||
pip-sync req.txt → uv pip sync req.txt
|
||||
virtualenv .venv → uv venv
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Don't use pip in uv projects
|
||||
|
||||
```bash
|
||||
# Bad
|
||||
pip install requests
|
||||
|
||||
# Good
|
||||
uv add requests
|
||||
```
|
||||
|
||||
### Don't run python directly
|
||||
|
||||
```bash
|
||||
# Bad
|
||||
python script.py
|
||||
|
||||
# Good
|
||||
uv run script.py
|
||||
```
|
||||
|
||||
```bash
|
||||
# Bad
|
||||
python -c "..."
|
||||
|
||||
# Good
|
||||
uv run python -c "..."
|
||||
```
|
||||
|
||||
```bash
|
||||
# Bad
|
||||
python3.12 -c "..."
|
||||
|
||||
# Good
|
||||
uvx python@3.12 -c "..."
|
||||
```
|
||||
|
||||
### Don't manually manage environments in uv projects
|
||||
|
||||
```bash
|
||||
# Bad
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
|
||||
# Good
|
||||
uv run <command>
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
For detailed information, read the official documentation:
|
||||
|
||||
- https://docs.astral.sh/uv/llms.txt
|
||||
|
||||
The documentation links to specific pages for each of these workflows.
|
||||
@@ -17,11 +17,15 @@
|
||||
"userUid": "1000",
|
||||
"userGid": "1000",
|
||||
"configureZshAsDefaultShell": false
|
||||
},
|
||||
"ghcr.io/devcontainers/features/node:1": {
|
||||
"version": "lts"
|
||||
}
|
||||
},
|
||||
"remoteUser": "ubuntu",
|
||||
"remoteEnv": {
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo",
|
||||
"CODEX_HOME": "${containerWorkspaceFolder}/.claude/codex"
|
||||
},
|
||||
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
|
||||
"customizations": {
|
||||
|
||||
@@ -21,11 +21,15 @@
|
||||
"userUid": "1000",
|
||||
"userGid": "1000",
|
||||
"configureZshAsDefaultShell": false
|
||||
},
|
||||
"ghcr.io/devcontainers/features/node:1": {
|
||||
"version": "lts"
|
||||
}
|
||||
},
|
||||
"remoteUser": "ubuntu",
|
||||
"remoteEnv": {
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo",
|
||||
"CODEX_HOME": "${containerWorkspaceFolder}/.claude/codex"
|
||||
},
|
||||
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
|
||||
"customizations": {
|
||||
@@ -52,4 +56,4 @@
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
30
.github/workflows/cuda-clippy.yml
vendored
Normal file
30
.github/workflows/cuda-clippy.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
name: CUDA Clippy
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
cuda_clippy:
|
||||
name: CUDA Clippy
|
||||
runs-on: cuda_t4_runner
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cuda
|
||||
options: --gpus all
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Mark workspace as safe for git
|
||||
run: git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: cargo-clippy --all-files
|
||||
23
.github/workflows/fmt.yml
vendored
Normal file
23
.github/workflows/fmt.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: Fmt
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
fmt:
|
||||
name: Fmt
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: cargo-fmt --all-files
|
||||
86
.github/workflows/lint.yml
vendored
86
.github/workflows/lint.yml
vendored
@@ -1,86 +0,0 @@
|
||||
name: Lint
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
name: Ruff
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: ruff-check --all-files
|
||||
|
||||
ruff_format:
|
||||
name: Ruff Format
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: ruff-format --all-files
|
||||
|
||||
clippy:
|
||||
name: Clippy
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: cargo-clippy --all-files
|
||||
|
||||
metal_clippy:
|
||||
name: Metal Clippy
|
||||
runs-on: macos-14
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: --hook-stage manual cargo-clippy-metal --all-files
|
||||
|
||||
fmt:
|
||||
name: Fmt
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: cargo-fmt --all-files
|
||||
25
.github/workflows/metal-clippy.yml
vendored
Normal file
25
.github/workflows/metal-clippy.yml
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
name: Metal Clippy
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
metal_clippy:
|
||||
name: Metal Clippy
|
||||
runs-on: macos-14
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: --hook-stage manual cargo-clippy-metal --all-files
|
||||
9
.github/workflows/modal-examples.yml
vendored
9
.github/workflows/modal-examples.yml
vendored
@@ -5,13 +5,16 @@ on:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
modal_example:
|
||||
# Keep the draft check PR-specific so push/manual runs still execute.
|
||||
if: ${{ github.event_name != 'pull_request' || !github.event.pull_request.draft }}
|
||||
if: >-
|
||||
github.event_name == 'push'
|
||||
|| github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
|
||||
name: "${{ matrix.example }} (Modal ${{ matrix.gpu.type }})"
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
|
||||
23
.github/workflows/ruff-format.yml
vendored
Normal file
23
.github/workflows/ruff-format.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: Ruff Format
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
ruff_format:
|
||||
name: Ruff Format
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: ruff-format --all-files
|
||||
23
.github/workflows/ruff.yml
vendored
Normal file
23
.github/workflows/ruff.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: Ruff
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
name: Ruff
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: ruff-check --all-files
|
||||
24
.github/workflows/test-core.yml
vendored
Normal file
24
.github/workflows/test-core.yml
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
name: Test Core
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
core_unit_test:
|
||||
name: Core Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cpu
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run tests
|
||||
run: cargo test --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --verbose
|
||||
51
.github/workflows/test-cuda.yml
vendored
51
.github/workflows/test-cuda.yml
vendored
@@ -5,44 +5,31 @@ on:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
cuda_clippy:
|
||||
name: Cuda Clippy
|
||||
runs-on: cuda_t4_runner
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cuda
|
||||
options: --gpus all
|
||||
cuda_unit_test:
|
||||
if: >-
|
||||
github.event_name == 'push'
|
||||
|| github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
|
||||
name: Cuda Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Mark workspace as a safe git directory
|
||||
run: git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
- uses: actions/setup-python@v5
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: --hook-stage manual cargo-clippy-cuda-lite --all-files
|
||||
|
||||
cuda_unit_test:
|
||||
name: Cuda Unit Tests
|
||||
runs-on: cuda_t4_runner
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cuda
|
||||
options: --gpus all
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Detect GPU compute capability
|
||||
run: |
|
||||
CAP=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -1 | tr -d '.')
|
||||
echo "CUDA_COMPUTE_CAP=${CAP}" >> "$GITHUB_ENV"
|
||||
- name: Run CUDA crate tests
|
||||
run: cargo test -p luminal_cuda_lite --verbose -- --test-threads=1
|
||||
- name: Install Modal
|
||||
run: pip install modal
|
||||
- name: Run CUDA tests on Modal
|
||||
env:
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
run: modal run ci/modal_cargo_test.py
|
||||
|
||||
19
.github/workflows/test-metal.yml
vendored
Normal file
19
.github/workflows/test-metal.yml
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
name: Test Metal
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
metal_unit_test:
|
||||
name: Metal Unit Tests
|
||||
runs-on: macos-14
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run Metal crate tests
|
||||
run: rustup update; cargo test -p luminal_metal --verbose -- --test-threads=1
|
||||
@@ -1,56 +1,20 @@
|
||||
name: Test
|
||||
name: Test Python CUDA
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
core_unit_test:
|
||||
name: Core Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cpu
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run tests
|
||||
run: cargo test --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --verbose
|
||||
metal_unit_test:
|
||||
name: Metal Unit Tests
|
||||
runs-on: macos-14
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run Metal crate tests
|
||||
run: rustup update; cargo test -p luminal_metal --verbose -- --test-threads=1
|
||||
python_native_tests:
|
||||
name: Python Native Tests
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cpu
|
||||
timeout-minutes: 45
|
||||
defaults:
|
||||
run:
|
||||
working-directory: crates/luminal_python
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- name: Build maturin extension
|
||||
run: uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
- name: Run pytest
|
||||
run: uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v -m "not slow"
|
||||
|
||||
python_cuda_tests:
|
||||
if: >-
|
||||
github.event_name == 'push'
|
||||
|| github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
|
||||
name: Python CUDA Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
28
.github/workflows/test-python-native.yml
vendored
Normal file
28
.github/workflows/test-python-native.yml
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
name: Test Python Native
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
python_native_tests:
|
||||
name: Python Native Tests
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cpu
|
||||
timeout-minutes: 45
|
||||
defaults:
|
||||
run:
|
||||
working-directory: crates/luminal_python
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- name: Build maturin extension
|
||||
run: uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
- name: Run pytest
|
||||
run: uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v -m "not slow"
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,6 +1,9 @@
|
||||
/target
|
||||
/crates/**/target
|
||||
/examples/**/target
|
||||
.claude-project
|
||||
.claude-memory
|
||||
.codex
|
||||
|
||||
*.env
|
||||
.claude/
|
||||
|
||||
34
CLAUDE.md
Normal file
34
CLAUDE.md
Normal file
@@ -0,0 +1,34 @@
|
||||
# Luminal
|
||||
|
||||
## Package Management
|
||||
|
||||
- Use `uv add`, `uv add --dev`, `uv remove` for Python dependencies (pyproject.toml is in `crates/luminal_python/`)
|
||||
- Use `uv sync` to sync the Python environment
|
||||
- Never use pip, pip-tools, poetry, or conda
|
||||
- Never manually create or activate virtual environments — uv manages `.venv/` automatically
|
||||
- Never generate requirements.txt
|
||||
|
||||
## Code Execution
|
||||
|
||||
- Always use `uv run` to execute Python tools: `uv run pytest`, `uv run pre-commit`, `uv run python`
|
||||
- Use `cargo` directly for Rust: `cargo build`, `cargo test`, `cargo check`, `cargo clippy`
|
||||
- Python project root is `crates/luminal_python/` — run `uv run` commands from there
|
||||
|
||||
## Building the Python Package (Maturin)
|
||||
|
||||
- After modifying `.rs` files that affect the Python bridge, rebuild with: `maturin develop --release`
|
||||
- Maturin config is in `crates/luminal_python/pyproject.toml` under `[tool.maturin]`
|
||||
|
||||
## Pre-commit
|
||||
|
||||
- Run with: `uv run pre-commit run --all-files`
|
||||
- Hooks configured: ruff-check, ruff-format (Python), cargo-fmt, cargo-clippy (Rust)
|
||||
- Manual-stage hooks (cargo-clippy-metal, cargo-clippy-cuda-lite) run with `--hook-stage manual`
|
||||
|
||||
## Testing
|
||||
|
||||
- **Rust tests**: `cargo test -p <crate_name>`
|
||||
- **Python tests**: `cd crates/luminal_python && uv run pytest`
|
||||
- `./run_test.sh` — native backend
|
||||
- `./run_tests_cuda.sh` — CUDA backend
|
||||
- See `crates/luminal_python/CLAUDE.md` for Python test patterns and conventions
|
||||
67
ci/modal_cargo_test.py
Normal file
67
ci/modal_cargo_test.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import modal
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
|
||||
gpu_type = os.environ.get("GPU_TYPE", "T4")
|
||||
CUDARC_CUDA_VERSION = "12080"
|
||||
|
||||
app = modal.App("luminal-ci-cargo-test")
|
||||
|
||||
WORKDIR = "/workspace/luminal"
|
||||
|
||||
cuda_image = (
|
||||
modal.Image.from_registry("nvcr.io/nvidia/pytorch:25.03-py3")
|
||||
.apt_install("protobuf-compiler")
|
||||
.run_commands(
|
||||
"curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y",
|
||||
)
|
||||
.env(
|
||||
{
|
||||
"PATH": "/root/.cargo/bin:$PATH",
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
}
|
||||
)
|
||||
.add_local_dir(".", remote_path=WORKDIR, copy=True)
|
||||
)
|
||||
|
||||
|
||||
@app.function(
|
||||
image=cuda_image,
|
||||
gpu=gpu_type,
|
||||
timeout=1800, # 30 minutes
|
||||
)
|
||||
def run_cargo_test():
|
||||
"""Run cargo test for luminal_cuda_lite on a Modal GPU."""
|
||||
subprocess.run(["nvidia-smi"], check=True)
|
||||
|
||||
# Detect GPU compute capability
|
||||
result = subprocess.run(
|
||||
["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
compute_cap = result.stdout.strip().replace(".", "")
|
||||
|
||||
subprocess.run(
|
||||
[
|
||||
"cargo", "test",
|
||||
"-p", "luminal_cuda_lite",
|
||||
"--verbose",
|
||||
"--",
|
||||
"--test-threads=1",
|
||||
],
|
||||
cwd=WORKDIR,
|
||||
env={
|
||||
**os.environ,
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
"CUDA_COMPUTE_CAP": compute_cap,
|
||||
},
|
||||
check=True,
|
||||
)
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
def main():
|
||||
run_cargo_test.remote()
|
||||
@@ -69,7 +69,7 @@ pub type Ops = (
|
||||
|
||||
/// Build a rewrite that matches an HLIR op, reads dtype(s) from the given source fields,
|
||||
/// and unions with a kernel op that has the same fields plus the dtype(s) appended.
|
||||
fn kernel_rewrite<H: Default + EgglogOp, L: Default + EgglogOp>() -> Rule {
|
||||
pub fn kernel_rewrite<H: Default + EgglogOp, L: Default + EgglogOp>() -> Rule {
|
||||
let hlir = H::default().sort();
|
||||
let llir = L::default().sort();
|
||||
let (mut args, hlir_kind_term) = hlir.new_call();
|
||||
@@ -415,8 +415,12 @@ extern \"C\" {{
|
||||
long long iters = {iters};
|
||||
|
||||
{dtype} partial = 0;
|
||||
{dtype} comp = 0; // Kahan compensation
|
||||
for (long long i = tid; i < iters; i += THREADS_PER_BLOCK) {{
|
||||
partial += in_data[in_start + {iter_stride_of_i}];
|
||||
{dtype} y = in_data[in_start + {iter_stride_of_i}] - comp;
|
||||
{dtype} t = partial + y;
|
||||
comp = (t - partial) - y;
|
||||
partial = t;
|
||||
}}
|
||||
|
||||
#pragma unroll
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use crate::{
|
||||
compile_module_image_for_current_device, cuda_dtype,
|
||||
kernel::KernelOp,
|
||||
kernel::hlir::{dtype_includes, generate_dyn_dims_defines},
|
||||
kernel::hlir::{dtype_includes, generate_dyn_dims_defines, kernel_rewrite},
|
||||
};
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use itertools::Itertools;
|
||||
@@ -22,6 +22,9 @@ pub type Ops = (
|
||||
KernelBatchMatVec,
|
||||
KernelBatchMatMul,
|
||||
KernelScatterNoCopy,
|
||||
KernelSoftmax,
|
||||
KernelExp,
|
||||
KernelSigmoid,
|
||||
);
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -1151,6 +1154,7 @@ impl EgglogOp for KernelSoftmax {
|
||||
("out_strides", ELIST),
|
||||
("reduce_dim", EXPRESSION),
|
||||
("reduce_stride", EXPRESSION),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
@@ -1160,8 +1164,24 @@ impl EgglogOp for KernelSoftmax {
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// No rewrite rules yet - this op is not in the Ops tuple.
|
||||
vec![]
|
||||
vec![
|
||||
kernel_rewrite::<luminal::hlir::Softmax, Self>(),
|
||||
// Also add a direct rewrite that assumes F32 dtype, in case dtype
|
||||
// propagation hasn't reached the Softmax node yet.
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?sm (Op (Softmax ?shape ?in_strides ?out_strides ?reduce_dim ?reduce_stride) ?inputs))
|
||||
)
|
||||
(
|
||||
(let ?ksm (Op (KernelSoftmax ?shape ?in_strides ?out_strides ?reduce_dim ?reduce_stride (F32)) ?inputs))
|
||||
(union ?sm ?ksm)
|
||||
(set (dtype ?ksm) (F32))
|
||||
)
|
||||
:name \"softmax-to-kernel-f32\"
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -1176,16 +1196,21 @@ impl EgglogOp for KernelSoftmax {
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let out_shape =
|
||||
extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap();
|
||||
let in_stride =
|
||||
extract_expr_list(egraph, kind_children[1], list_cache, expr_cache).unwrap();
|
||||
let out_stride =
|
||||
extract_expr_list(egraph, kind_children[2], list_cache, expr_cache).unwrap();
|
||||
let reduce_dim = extract_expr(egraph, kind_children[3], expr_cache).unwrap();
|
||||
let reduce_stride = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
in_stride: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
out_stride: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
reduce_dim: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
|
||||
reduce_stride: extract_expr(egraph, kind_children[4], expr_cache).unwrap(),
|
||||
out_shape,
|
||||
in_stride,
|
||||
out_stride,
|
||||
reduce_dim,
|
||||
reduce_stride,
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
@@ -1374,3 +1399,370 @@ extern \"C\" {{
|
||||
"Softmax"
|
||||
}
|
||||
}
|
||||
|
||||
// KernelExp: native exp (uses expf instead of exp2f * constant)
|
||||
// Single-kernel alternative to the 3-kernel Constant+Mul+Exp2 path.
|
||||
// Improves numerical precision by avoiding the truncated log2(e) constant.
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelExp {
|
||||
shape: Vec<Expression>,
|
||||
in_strides: Vec<Expression>,
|
||||
out_strides: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelExp {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelExp",
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
// Match Exp2(Mul(x, log2e_constant)) directly.
|
||||
// This matches the pattern created by frontend exp() = (self * (1/ln(2))).exp2()
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?shape ?x_stride ?const_stride ?inter_stride) (ICons ?x (ICons ?exp_const (INil)))))
|
||||
(= ?exp2 (Op (Exp2 ?shape ?inter_stride ?out_stride) (ICons ?mul (INil))))
|
||||
(= ?dt (dtype ?x))
|
||||
(= ?cv (Op (Constant ?val) (INil)))
|
||||
(= ?exp_const ?cv)
|
||||
(> ?val 1.44)
|
||||
(< ?val 1.45)
|
||||
)
|
||||
(
|
||||
(let ?kexp (Op (KernelExp ?shape ?x_stride ?out_stride ?dt) (ICons ?x (INil))))
|
||||
(union ?exp2 ?kexp)
|
||||
(set (dtype ?kexp) ?dt)
|
||||
)
|
||||
:name \"direct-exp-fusion\"
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap(),
|
||||
in_strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
out_strides: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[3]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelExp {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars = self
|
||||
.shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.in_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.out_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect::<FxHashSet<_>>();
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let n_elements = self
|
||||
.shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.to_kernel();
|
||||
let out_idx = flatten_strides(&self.shape, &self.out_strides).to_kernel();
|
||||
let in_idx = flatten_strides(&self.shape, &self.in_strides).to_kernel();
|
||||
let kernel = format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void exp_k({dtype} *out, const {dtype} *in{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {n_elements}) return;
|
||||
out[{out_idx}] = expf(in[{in_idx}]);
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("exp_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let out_size = self.shape.iter().copied().product::<Expression>();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Exp"
|
||||
}
|
||||
}
|
||||
|
||||
// KernelSigmoid: fused sigmoid = 1/(1+exp(-x))
|
||||
// Single-kernel alternative to the 5-kernel Neg+Exp+Const+Add+Recip path.
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelSigmoid {
|
||||
shape: Vec<Expression>,
|
||||
in_strides: Vec<Expression>,
|
||||
out_strides: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelSigmoid {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelSigmoid",
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
// Match the HLIR pattern directly: Recip(Add(Exp2(Mul(Mul(x, -1), log2e)), 1))
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?neg1 (Op (Constant ?nv) (INil)))
|
||||
(< ?nv -0.99)
|
||||
(> ?nv -1.01)
|
||||
(= ?neg_x (Op (Mul ?shape ?x_stride ?neg_stride ?neg_out_stride) (ICons ?x (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant ?lv) (INil)))
|
||||
(> ?lv 1.44)
|
||||
(< ?lv 1.45)
|
||||
(= ?scaled (Op (Mul ?shape ?neg_out_stride ?log2e_stride ?scaled_stride) (ICons ?neg_x (ICons ?log2e (INil)))))
|
||||
(= ?exp2 (Op (Exp2 ?shape ?scaled_stride ?exp_stride) (ICons ?scaled (INil))))
|
||||
(= ?one (Op (Constant ?ov) (INil)))
|
||||
(> ?ov 0.99)
|
||||
(< ?ov 1.01)
|
||||
(= ?plus_one (Op (Add ?shape ?exp_stride ?one_stride ?add_stride) (ICons ?exp2 (ICons ?one (INil)))))
|
||||
(= ?sig_out (Op (Recip ?shape ?add_stride ?out_stride) (ICons ?plus_one (INil))))
|
||||
(= ?dt (dtype ?x))
|
||||
)
|
||||
(
|
||||
(let ?ksig (Op (KernelSigmoid ?shape ?x_stride ?out_stride ?dt) (ICons ?x (INil))))
|
||||
(union ?sig_out ?ksig)
|
||||
(set (dtype ?ksig) ?dt)
|
||||
)
|
||||
:name \"direct-sigmoid-fusion\"
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap(),
|
||||
in_strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
out_strides: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[3]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelSigmoid {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars = self
|
||||
.shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.in_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.out_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect::<FxHashSet<_>>();
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let n_elements = self
|
||||
.shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.to_kernel();
|
||||
let out_idx = flatten_strides(&self.shape, &self.out_strides).to_kernel();
|
||||
let in_idx = flatten_strides(&self.shape, &self.in_strides).to_kernel();
|
||||
let kernel = format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void sigmoid_k({dtype} *out, const {dtype} *in{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {n_elements}) return;
|
||||
out[{out_idx}] = 1.0f / (1.0f + expf(-in[{in_idx}]));
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("sigmoid_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let out_size = self.shape.iter().copied().product::<Expression>();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// neg + exp + add + recip = ~4 ops per element
|
||||
self.shape.iter().copied().product::<Expression>() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Sigmoid"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
pub mod host;
|
||||
pub mod kernel;
|
||||
pub mod logical;
|
||||
pub mod runtime;
|
||||
use std::{
|
||||
ffi::{CStr, CString},
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use luminal::{
|
||||
egglog_utils::api::{Rule, SortDef},
|
||||
hlir::unary_sort,
|
||||
op::EgglogOp,
|
||||
};
|
||||
|
||||
pub type Ops = (Exp, Sigmoid);
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Exp;
|
||||
impl EgglogOp for Exp {
|
||||
fn sort(&self) -> SortDef {
|
||||
unary_sort("Exp")
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?exp_const (Op (Constant 1.442695) (INil)))
|
||||
(= ?mul (Op (Mul ?shape ?x_stride ?const_stride ?intermediate_stride) (ICons ?x (ICons ?exp_const (INil)))))
|
||||
(= ?exp2 (Op (Exp2 ?shape ?intermediate_stride ?out_stride) (ICons ?mul (INil))))
|
||||
(= ?dt (dtype ?x))
|
||||
)
|
||||
(
|
||||
(let ?exp (Op (Exp ?shape ?x_stride ?out_stride) (ICons ?x (INil))))
|
||||
(union ?exp2 ?exp)
|
||||
(set (dtype ?exp) ?dt)
|
||||
)
|
||||
)",
|
||||
)]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct Sigmoid;
|
||||
impl EgglogOp for Sigmoid {
|
||||
fn sort(&self) -> SortDef {
|
||||
unary_sort("Sigmoid")
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw("(rule
|
||||
(
|
||||
(= ?neg1 (Op (Constant -1.0) (INil)))
|
||||
(= ?neg_input (Op (Mul ?input_range ?input_stride ?const_stride ?intermediate_stride) (ICons ?input (ICons ?neg1 (INil)))))
|
||||
(= ?exp (Op (Exp ?input_range ?intermediate_stride ?exp_stride) (ICons ?neg_input (INil))))
|
||||
(= ?one (Op (Constant 1.0) (INil)))
|
||||
(= ?plus_one (Op (Add ?input_range ?exp_stride ?const_stride ?plus_one_stride) (ICons ?exp (ICons ?one (INil)))))
|
||||
(= ?sig_out (Op (Recip ?input_range ?plus_one_stride ?out_stride) (ICons ?plus_one (INil))))
|
||||
(= ?dt (dtype ?input))
|
||||
)
|
||||
(
|
||||
(let ?sig (Op (Sigmoid ?input_range ?input_stride ?out_stride) (ICons ?input (INil))))
|
||||
(union ?sig_out ?sig)
|
||||
(set (dtype ?sig) ?dt)
|
||||
)
|
||||
:name \"sigmoid\"
|
||||
)")]
|
||||
}
|
||||
}
|
||||
@@ -119,6 +119,14 @@ pub struct CudaRuntime {
|
||||
active_bucket: usize,
|
||||
/// Bucket definitions per dimension (empty = single-bucket mode)
|
||||
dim_buckets: FxHashMap<char, Vec<DimBucket>>,
|
||||
|
||||
/// HLIR nodes that should never be consumed after execute().
|
||||
/// Used for weight tensors shared via external device pointers.
|
||||
persistent_hlir_nodes: FxHashSet<NodeIndex>,
|
||||
|
||||
/// Non-owning CudaSlice wrappers for external device pointers.
|
||||
/// ManuallyDrop prevents cuMemFree — the external allocator (e.g. PyTorch) owns the memory.
|
||||
external_buffers: FxHashMap<NodeIndex, std::mem::ManuallyDrop<CudaSlice<u8>>>,
|
||||
}
|
||||
|
||||
impl CudaRuntime {
|
||||
@@ -199,6 +207,32 @@ impl CudaRuntime {
|
||||
self.changed_hlir.insert(id);
|
||||
}
|
||||
|
||||
/// Set an external CUDA device pointer as input data. Zero-copy.
|
||||
/// The caller must ensure the pointer remains valid for the runtime's lifetime.
|
||||
///
|
||||
/// # Safety
|
||||
/// The device pointer must point to a valid CUDA allocation on the same device
|
||||
/// as this runtime's stream, with at least `n_bytes` bytes available.
|
||||
pub unsafe fn set_device_ptr(&mut self, id: impl ToId, device_ptr: u64, n_bytes: usize) {
|
||||
debug_assert!(device_ptr != 0, "set_device_ptr called with null pointer");
|
||||
let id = id.to_id();
|
||||
// Create CudaSlice view via cudarc's upgrade_device_ptr.
|
||||
// ManuallyDrop prevents cuMemFree on drop (external allocator owns this memory).
|
||||
let slice = unsafe {
|
||||
self.cuda_stream
|
||||
.upgrade_device_ptr::<u8>(device_ptr, n_bytes)
|
||||
};
|
||||
self.external_buffers
|
||||
.insert(id, std::mem::ManuallyDrop::new(slice));
|
||||
self.hlir_buffers.insert(id, CudaInput::Ptr(device_ptr));
|
||||
self.changed_hlir.insert(id);
|
||||
}
|
||||
|
||||
/// Mark an HLIR node as persistent — its buffer won't be consumed after execute().
|
||||
pub fn persist_hlir_node(&mut self, id: impl ToId) {
|
||||
self.persistent_hlir_nodes.insert(id.to_id());
|
||||
}
|
||||
|
||||
/// Find the LLIR producing node for an output tensor.
|
||||
fn find_producer_node(&self, id: impl ToId) -> NodeIndex {
|
||||
let id = id.to_id();
|
||||
@@ -281,12 +315,15 @@ impl CudaRuntime {
|
||||
.expect("Cannot find input tensor in runtime!")
|
||||
{
|
||||
CudaInput::Buffer(buf) => self.cuda_stream.clone_dtoh(buf).unwrap(),
|
||||
CudaInput::Ptr(p) => {
|
||||
// Raw pointer — need size from cached_buffer_ptrs or error
|
||||
panic!(
|
||||
"Cannot read raw pointer input (ptr=0x{:x}) — use Buffer variant",
|
||||
p
|
||||
);
|
||||
CudaInput::Ptr(_) => {
|
||||
// External device pointer — use the CudaSlice view from external_buffers
|
||||
if let Some(ext) = self.external_buffers.get(hlir_node) {
|
||||
self.cuda_stream.clone_dtoh(&**ext).unwrap()
|
||||
} else {
|
||||
panic!(
|
||||
"Cannot read raw pointer input — no external_buffers entry for node"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -302,6 +339,57 @@ impl CudaRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve the device-side CudaSlice for an output tensor without copying to host.
|
||||
/// Used by copy_output_to_device_ptr for DtoD transfers.
|
||||
fn resolve_output_slice(&self, id: impl ToId) -> &CudaSlice<u8> {
|
||||
let data_id = self.resolve_data_node(id);
|
||||
let bucket = self.active();
|
||||
if let Some(hlir_node) = bucket.llir_to_hlir.get(&data_id) {
|
||||
match self
|
||||
.hlir_buffers
|
||||
.get(hlir_node)
|
||||
.expect("Cannot find input tensor in runtime!")
|
||||
{
|
||||
CudaInput::Buffer(buf) => buf,
|
||||
CudaInput::Ptr(_) => self
|
||||
.external_buffers
|
||||
.get(hlir_node)
|
||||
.map(|ext| &**ext)
|
||||
.expect("Cannot read raw pointer input — no external_buffers entry for node"),
|
||||
}
|
||||
} else {
|
||||
bucket
|
||||
.buffers
|
||||
.get(&data_id)
|
||||
.expect("Cannot find tensor in runtime!")
|
||||
}
|
||||
}
|
||||
|
||||
/// Copy output tensor data to an external CUDA device pointer (DtoD).
|
||||
/// Much faster than get_f32 + HtoD for CUDA-to-CUDA workflows.
|
||||
///
|
||||
/// # Safety
|
||||
/// The dest_ptr must be a valid CUDA device allocation with at least n_bytes available.
|
||||
pub unsafe fn copy_output_to_device_ptr(&self, id: impl ToId, dest_ptr: u64, n_bytes: usize) {
|
||||
debug_assert!(
|
||||
dest_ptr != 0,
|
||||
"copy_output_to_device_ptr called with null pointer"
|
||||
);
|
||||
let src_slice = self.resolve_output_slice(id);
|
||||
let src_ptr = src_slice.device_ptr(&self.cuda_stream).0;
|
||||
let copy_bytes = n_bytes.min(src_slice.len());
|
||||
unsafe {
|
||||
cudarc::driver::result::memcpy_dtod_async(
|
||||
dest_ptr,
|
||||
src_ptr,
|
||||
copy_bytes,
|
||||
self.cuda_stream.cu_stream(),
|
||||
)
|
||||
.expect("cuMemcpyDtoDAsync failed");
|
||||
}
|
||||
self.cuda_stream.synchronize().unwrap();
|
||||
}
|
||||
|
||||
pub fn get_f32(&self, id: impl ToId) -> Vec<f32> {
|
||||
let bytes = self.get_output_data(id);
|
||||
let bytes = bytes.leak();
|
||||
@@ -684,7 +772,7 @@ fn format_duration_precise(d: &std::time::Duration) -> String {
|
||||
}
|
||||
|
||||
impl Runtime for CudaRuntime {
|
||||
type Ops = (crate::logical::Ops, crate::kernel::Ops, crate::host::Ops);
|
||||
type Ops = (crate::kernel::Ops, crate::host::Ops);
|
||||
type CompileArg = Arc<CudaStream>;
|
||||
type ExecReturn = ();
|
||||
type ProfileMetric = Duration;
|
||||
@@ -702,6 +790,8 @@ impl Runtime for CudaRuntime {
|
||||
compiled_buckets: vec![CompiledBucket::new()],
|
||||
active_bucket: 0,
|
||||
dim_buckets: FxHashMap::default(),
|
||||
persistent_hlir_nodes: FxHashSet::default(),
|
||||
external_buffers: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -938,10 +1028,23 @@ impl Runtime for CudaRuntime {
|
||||
}
|
||||
// Add input buffers (prefer HLIR weight buffers over intermediate placeholders)
|
||||
for inp in exec_op.inputs.iter() {
|
||||
if let Some(hlir_node) = bucket.llir_to_hlir.get(inp)
|
||||
&& let Some(CudaInput::Buffer(buf)) = self.hlir_buffers.get(hlir_node)
|
||||
{
|
||||
buffer_map.insert(*inp, buf);
|
||||
if let Some(hlir_node) = bucket.llir_to_hlir.get(inp) {
|
||||
match self.hlir_buffers.get(hlir_node) {
|
||||
Some(CudaInput::Buffer(buf)) => {
|
||||
buffer_map.insert(*inp, buf);
|
||||
}
|
||||
Some(CudaInput::Ptr(_)) => {
|
||||
if let Some(ext) = self.external_buffers.get(hlir_node) {
|
||||
buffer_map.insert(*inp, &**ext);
|
||||
}
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
if !buffer_map.contains_key(inp)
|
||||
&& let Some(buf) = bucket.buffers.get(inp)
|
||||
{
|
||||
buffer_map.insert(*inp, buf);
|
||||
}
|
||||
} else if let Some(buf) = bucket.buffers.get(inp) {
|
||||
buffer_map.insert(*inp, buf);
|
||||
}
|
||||
@@ -952,25 +1055,43 @@ impl Runtime for CudaRuntime {
|
||||
if let Entry::Vacant(e) = buffer_map.entry(extra_node) {
|
||||
if let Some(buf) = bucket.buffers.get(&extra_node) {
|
||||
e.insert(buf);
|
||||
} else if let Some(hlir_node) = bucket.llir_to_hlir.get(&extra_node)
|
||||
&& let Some(CudaInput::Buffer(buf)) = self.hlir_buffers.get(hlir_node)
|
||||
{
|
||||
e.insert(buf);
|
||||
} else if let Some(hlir_node) = bucket.llir_to_hlir.get(&extra_node) {
|
||||
match self.hlir_buffers.get(hlir_node) {
|
||||
Some(CudaInput::Buffer(buf)) => {
|
||||
e.insert(buf);
|
||||
}
|
||||
Some(CudaInput::Ptr(_)) => {
|
||||
if let Some(ext) = self.external_buffers.get(hlir_node) {
|
||||
e.insert(&**ext);
|
||||
}
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Resolve output aliases
|
||||
for (&alias_node, &alias_target) in &bucket.output_alias_map {
|
||||
if let std::collections::hash_map::Entry::Occupied(mut e) =
|
||||
buffer_map.entry(alias_node)
|
||||
{
|
||||
if let Some(hlir_node) = bucket.llir_to_hlir.get(&alias_target)
|
||||
&& let Some(CudaInput::Buffer(buf)) = self.hlir_buffers.get(hlir_node)
|
||||
{
|
||||
e.insert(buf);
|
||||
} else if let Some(buf) = bucket.buffers.get(&alias_target) {
|
||||
e.insert(buf);
|
||||
}
|
||||
if !buffer_map.contains_key(&alias_node) {
|
||||
continue;
|
||||
}
|
||||
// Try HLIR buffer first (includes external device pointers)
|
||||
let resolved: Option<&CudaSlice<u8>> =
|
||||
if let Some(hlir_node) = bucket.llir_to_hlir.get(&alias_target) {
|
||||
match self.hlir_buffers.get(hlir_node) {
|
||||
Some(CudaInput::Buffer(buf)) => Some(buf),
|
||||
Some(CudaInput::Ptr(_)) => {
|
||||
self.external_buffers.get(hlir_node).map(|ext| &**ext)
|
||||
}
|
||||
None => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
if let Some(buf) = resolved {
|
||||
buffer_map.insert(alias_node, buf);
|
||||
} else if let Some(buf) = bucket.buffers.get(&alias_target) {
|
||||
buffer_map.insert(alias_node, buf);
|
||||
}
|
||||
}
|
||||
let _span = span!(
|
||||
@@ -1069,11 +1190,13 @@ impl Runtime for CudaRuntime {
|
||||
.hlir_buffers
|
||||
.keys()
|
||||
.filter(|hlir_node| !inputs_with_outputs.contains(hlir_node))
|
||||
.filter(|hlir_node| !self.persistent_hlir_nodes.contains(hlir_node))
|
||||
.copied()
|
||||
.collect();
|
||||
|
||||
for hlir_node in to_consume {
|
||||
self.hlir_buffers.remove(&hlir_node);
|
||||
self.external_buffers.remove(&hlir_node);
|
||||
let bucket = &mut self.compiled_buckets[self.active_bucket];
|
||||
if let Some(llir_node) = bucket.hlir_to_llir.get(&hlir_node) {
|
||||
bucket.cached_buffer_ptrs.remove(llir_node);
|
||||
|
||||
1
crates/luminal_python/.gitignore
vendored
1
crates/luminal_python/.gitignore
vendored
@@ -1,5 +1,4 @@
|
||||
*.onnx
|
||||
tests/llama38b_ref_logits.pt
|
||||
__pycache__/
|
||||
*.pyc
|
||||
uv.lock
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
A couple of short things to keep in mind
|
||||
## Python Environment
|
||||
|
||||
- Always use `uv run` to execute Python tools (pytest, pre-commit, python) — never bare `pytest` or `python`
|
||||
- Use `uv add` / `uv add --dev` / `uv remove` for dependencies — never hand-edit pyproject.toml deps
|
||||
- After modifying Rust source files, rebuild before running Python tests: `maturin develop --release`
|
||||
|
||||
## Lessons Learned
|
||||
|
||||
|
||||
@@ -340,7 +340,7 @@ with matching shape tracker dimensions.
|
||||
|
||||
---
|
||||
|
||||
## Bug: TopK values wrong on CUDA (gather_elements with sliced non-contiguous indices)
|
||||
## 2026-03-05 — TopK Values Wrong on CUDA (gather_elements with sliced non-contiguous indices)
|
||||
|
||||
1. **Symptom**: `test_topk_values` failed on CUDA — rows 0-1 were correct but rows 2+ returned
|
||||
the value at column 0 of each row (all three top-k positions got the same value).
|
||||
@@ -748,3 +748,11 @@ method rather than string-matching on Debug output. Additionally, when diagnosin
|
||||
candidates rejected" during search, check whether the rejection is from actual float NaN
|
||||
or from dtype misinterpretation — the key diagnostic is whether the NaN pattern is
|
||||
identical across all attempts (dtype issue) vs varying (actual numerical issue).
|
||||
|
||||
## 2026-03-25 — KernelExp/KernelSigmoid: Fused CUDA Kernels for Precision
|
||||
|
||||
1. **Symptom**: `test_hf_llama3_full` (16-layer Llama-3.2-1B) had ~1e-4 max diff vs PyTorch.
|
||||
2. **Root cause**: `exp(x)` was computed as `exp2(x * 1.442695)` — the constant truncated by `{:.6}` format + extra multiply adds rounding. Sigmoid was 5 separate kernels. SumReduce had naive accumulation.
|
||||
3. **Why hard**: Per-operation error was ~1e-7 but compounded over 16 layers Ă— ~25 extra materializations. The egglog `Exp` rewrite depends on exact constant format matching.
|
||||
4. **Fix**: Added `KernelExp` (uses `expf()`), `KernelSigmoid` (uses `1/(1+expf(-x))`), and Kahan summation in SumReduce. Each uses both `kernel_rewrite` and a direct egglog pattern match with range checks (e.g., `(> ?val 1.44) (< ?val 1.45)`) to bypass constant format dependency.
|
||||
5. **Principle**: When decomposed CUDA kernel chains cause precision loss, add fused kernels via `kernel_rewrite`. For robustness, add BOTH the logical-op rewrite path AND a direct HLIR pattern match — the constant format in egglog can be fragile.
|
||||
|
||||
@@ -3,15 +3,19 @@ name = "luminal_python"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"numpy>=2.0.2",
|
||||
"torch>=2.10.0",
|
||||
"onnx",
|
||||
"onnxscript",
|
||||
"safetensors",
|
||||
"flash-attn-3>=3.0.0",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
no-build-isolation-package = ["flash-attn"]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
@@ -21,6 +25,7 @@ explicit = true
|
||||
torch = [
|
||||
{ index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
]
|
||||
flash-attn-3 = { index = "pytorch-cu128" }
|
||||
|
||||
|
||||
[build-system]
|
||||
@@ -40,13 +45,21 @@ markers = [
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"maturin>=1.0,<2.0",
|
||||
"maturin-import-hook>=0.3.0",
|
||||
"pytest>=9.0.2",
|
||||
"pytest-profiling",
|
||||
"snakeviz",
|
||||
"maturin-import-hook>=0.3.0",
|
||||
"pytest-randomly>=4.0.1",
|
||||
"transformers>=4.40.0",
|
||||
"transformers>=5.5.0,<6",
|
||||
"diffusers>=0.35.0",
|
||||
"onnxsim",
|
||||
"tiktoken>=0.12.0",
|
||||
"pydantic>=2.12.5",
|
||||
"psutil>=7.2.2",
|
||||
"modal>=1.3.5",
|
||||
"pillow",
|
||||
"flash-attn>=2.8.3",
|
||||
]
|
||||
flash-attention-4 = [
|
||||
"nvidia-cutlass-dsl==4.1.0",
|
||||
]
|
||||
|
||||
@@ -14,7 +14,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
# Run pytest with CUDA backend
|
||||
echo "Step 3: Running pytest with CUDA backend..."
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest tests/test_hlir_ops.py tests/test_unary.py tests/test_llama3.py -m "not slow" -v
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
|
||||
@@ -14,6 +14,6 @@ uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
# Run pytest with CUDA backend and PT2 export mode
|
||||
echo "Step 3: Running pytest with CUDA backend + PT2 export mode..."
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda LUMINAL_EXPORT_MODE=pt2 uv run pytest tests/test_hlir_ops.py tests/test_unary.py tests/test_llama3.py -m "not slow" -v
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda LUMINAL_EXPORT_MODE=pt2 uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
|
||||
@@ -1,34 +1,45 @@
|
||||
use luminal::{
|
||||
prelude::{
|
||||
tracing::{Level, span, trace},
|
||||
*,
|
||||
},
|
||||
shape::Expression,
|
||||
visualization::ToDot,
|
||||
};
|
||||
use onnx_protobuf::{GraphProto, ModelProto};
|
||||
use pyo3::prelude::*;
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
path::Path,
|
||||
};
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
use crate::util::transpose_weight_data;
|
||||
use crate::{
|
||||
dispatch::process_onnx_nodes,
|
||||
runtime::*,
|
||||
util::{
|
||||
DimParamMap, get_shape_for_onnx_value, get_shape_for_onnx_value_expr,
|
||||
load_all_tensor_floats, load_initializer_as_f32,
|
||||
},
|
||||
};
|
||||
use luminal::prelude::tracing::{trace, warn};
|
||||
use luminal::{prelude::*, shape::Expression, visualization::ToDot};
|
||||
use pyo3::prelude::*;
|
||||
use std::collections::HashMap;
|
||||
#[cfg(feature = "cuda")]
|
||||
use std::collections::HashSet;
|
||||
|
||||
use crate::{runtime::RuntimeBackend, util::DimParamMap};
|
||||
|
||||
/// Common intermediate result from translating a model graph (ONNX or FX).
|
||||
pub struct GraphTranslation {
|
||||
pub graph: Graph,
|
||||
pub tensor_ids: HashMap<String, NodeIndex>,
|
||||
pub input_names: Vec<String>,
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
|
||||
/// Pre-loaded weight data from any model format.
|
||||
///
|
||||
/// NOTE: Currently assumes all data is F32. When the type system branch lands
|
||||
/// with proper multi-dtype support, this struct (and all callers) will need
|
||||
/// updating to carry dtype metadata alongside the raw data.
|
||||
pub struct WeightData {
|
||||
/// (Input node label, f32 data) for weights and constants.
|
||||
pub weights: Vec<(String, Vec<f32>)>,
|
||||
/// label → element count for ALL Input nodes (for CUDA dummy data sizing).
|
||||
pub tensor_sizes: HashMap<String, usize>,
|
||||
/// label → (device_ptr, n_bytes) for zero-copy CUDA weight sharing.
|
||||
pub device_ptrs: HashMap<String, (u64, usize)>,
|
||||
}
|
||||
|
||||
#[pyclass(unsendable)]
|
||||
pub struct CompiledGraph {
|
||||
pub graph: Graph,
|
||||
pub runtime: RuntimeBackend,
|
||||
pub tensor_ids: HashMap<String, NodeIndex>,
|
||||
/// Cached label → NodeIndex map for O(1) lookups in set_weight_* methods.
|
||||
label_map: HashMap<String, NodeIndex>,
|
||||
pub input_names: Vec<String>,
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shapes: Vec<Vec<usize>>,
|
||||
@@ -38,218 +49,35 @@ pub struct CompiledGraph {
|
||||
}
|
||||
|
||||
impl CompiledGraph {
|
||||
/// Shared compilation pipeline for both ONNX and FX/PT2 graphs.
|
||||
///
|
||||
/// Takes a format-neutral `GraphTranslation` (produced by `translate_onnx` or
|
||||
/// `translate_pt2`) and `WeightData`, builds the backend, loads weights, and
|
||||
/// returns a ready-to-execute `CompiledGraph`.
|
||||
pub fn parse_graph(
|
||||
model: ModelProto,
|
||||
model_directory: &Path,
|
||||
translation: GraphTranslation,
|
||||
weight_data: WeightData,
|
||||
backend: &str,
|
||||
search_iters: usize,
|
||||
) -> Result<CompiledGraph, String> {
|
||||
let _span = span!(Level::TRACE, "Onnx Graphing Parsing").entered();
|
||||
let onnx_graph = &model.graph;
|
||||
let mut cx = Graph::new();
|
||||
// We will need to track the tensors we allocate so we can match up inputs and outputs in the graph
|
||||
let mut tensors: HashMap<String, GraphTensor> = HashMap::new();
|
||||
|
||||
// Dynamic dimension tracking
|
||||
let mut dim_param_map: DimParamMap = HashMap::new();
|
||||
let mut next_char = 'a';
|
||||
|
||||
// This is the name of all of the tensors we will need to fill in parameters for
|
||||
let initializer_names: HashSet<&str> = onnx_graph
|
||||
.initializer
|
||||
.iter()
|
||||
.map(|t| t.name.as_str())
|
||||
.collect();
|
||||
|
||||
// Input is an overloaded term in Onnx, it both means the inputs into the model, like the next token
|
||||
// and the parameters of the layers, for this we don't want any of the parameters
|
||||
// Input here is in the straightforward meaning, those tensors you feed into the network for a
|
||||
// forward passd
|
||||
let input_names: Vec<String> = onnx_graph
|
||||
.input
|
||||
.iter()
|
||||
.filter(|inp| !initializer_names.contains(inp.name.as_str()))
|
||||
.map(|inp| inp.name.clone())
|
||||
.collect();
|
||||
|
||||
// Create "holding" tensors for the input
|
||||
// this way they can be considered in the graph computation, and later as we do mutiple runs we can target them and swap out the values
|
||||
// in them and not need to recompile the network
|
||||
for input in &onnx_graph.input {
|
||||
// Use expression-aware shape parsing to detect DimParam (dynamic dims)
|
||||
let shape_exprs =
|
||||
get_shape_for_onnx_value_expr(input, &mut dim_param_map, &mut next_char);
|
||||
if shape_exprs.is_empty() {
|
||||
// Fall back to concrete parsing (initializer shapes don't have DimParam)
|
||||
let shape = get_shape_for_onnx_value(input);
|
||||
if shape.is_empty() {
|
||||
trace!("Input {} skipped because it is empty", input.name.clone());
|
||||
continue;
|
||||
}
|
||||
let tensor = cx.named_tensor(input.name.clone(), shape);
|
||||
trace!("Input {} added to tensors", input.name.clone());
|
||||
tensors.insert(input.name.clone(), tensor);
|
||||
continue;
|
||||
}
|
||||
// Always F32: Python runtime always sends float32 data via .float().numpy()
|
||||
let tensor = cx.named_tensor(input.name.clone(), shape_exprs);
|
||||
trace!("Input {} added to tensors", input.name.clone());
|
||||
tensors.insert(input.name.clone(), tensor);
|
||||
}
|
||||
|
||||
for init in &onnx_graph.initializer {
|
||||
if !tensors.contains_key(&init.name) {
|
||||
let mut shape: Vec<usize> = init.dims.iter().map(|&d| d as usize).collect();
|
||||
// Scalar (0-dim) tensors have empty dims; represent as [1] in luminal
|
||||
if shape.is_empty() {
|
||||
shape = vec![1];
|
||||
}
|
||||
let tensor = cx.named_tensor(init.name.clone(), shape);
|
||||
tensors.insert(init.name.clone(), tensor);
|
||||
}
|
||||
}
|
||||
|
||||
let mut weight_data = Vec::new();
|
||||
|
||||
let mut known_values: HashMap<String, Vec<f32>> = HashMap::new();
|
||||
|
||||
for init in &onnx_graph.initializer {
|
||||
let n_elements: usize = init
|
||||
.dims
|
||||
.iter()
|
||||
.map(|&d| d as usize)
|
||||
.product::<usize>()
|
||||
.max(1);
|
||||
// MAGIC_NUMBER:
|
||||
if n_elements <= 32 {
|
||||
if let Some(floats) = load_initializer_as_f32(init) {
|
||||
known_values.insert(init.name.clone(), floats);
|
||||
} else {
|
||||
// Questions
|
||||
// Should this be fatal
|
||||
// Should this be a print or a log
|
||||
panic!("Unable to initializer values for {:?}", init.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Shape expressions map for propagating symbolic shape values through
|
||||
// Shape→Gather→Unsqueeze→Concat chains in dynamic ONNX graphs
|
||||
let mut shape_exprs: HashMap<String, Vec<Expression>> = HashMap::new();
|
||||
|
||||
// Process computation nodes (Constant nodes add to weight_data)
|
||||
process_onnx_nodes(
|
||||
&onnx_graph.node,
|
||||
&mut tensors,
|
||||
&mut cx,
|
||||
&mut weight_data,
|
||||
&mut known_values,
|
||||
&mut shape_exprs,
|
||||
)
|
||||
.map_err(|e| format!("process_onnx_nodes failed: {}", e))?;
|
||||
|
||||
// Mark weight/constant tensors as persistent so their buffers survive
|
||||
// execute()'s input consumption. User inputs (like input_ids) are NOT persisted
|
||||
// since they are re-set via set_input() before each execution.
|
||||
for (name, gt) in &tensors {
|
||||
if !input_names.contains(name) {
|
||||
gt.persist();
|
||||
}
|
||||
}
|
||||
|
||||
let has_dynamic = !dim_param_map.is_empty();
|
||||
|
||||
// Mark graph outputs (must happen before build_search_space)
|
||||
let mut output_names = Vec::new();
|
||||
let mut output_shapes = Vec::new();
|
||||
let mut output_shape_exprs = Vec::new();
|
||||
for output_vi in &onnx_graph.output {
|
||||
if let Some(>) = tensors.get(&output_vi.name) {
|
||||
// Force contiguous if the shape tracker is a non-contiguous view
|
||||
// (e.g. a view-only slice that changed dims without a gather).
|
||||
// Without this, get_f32 returns the full underlying buffer.
|
||||
let gt = if gt.shape != gt.shape.contiguous() {
|
||||
let contiguous = gt * 1.0;
|
||||
tensors.insert(output_vi.name.clone(), contiguous);
|
||||
contiguous
|
||||
} else {
|
||||
gt
|
||||
};
|
||||
gt.output();
|
||||
let dims = gt.dims();
|
||||
|
||||
// Store Expression-based shapes for dynamic resolution
|
||||
output_shape_exprs.push(dims.clone());
|
||||
|
||||
// For concrete output shapes, resolve now; for dynamic, use placeholder
|
||||
let shape: Vec<usize> = dims.iter().map(|d| d.to_usize().unwrap_or(1)).collect();
|
||||
if shape.is_empty() {
|
||||
return Err(format!(
|
||||
"Output tensor '{}' has no shape information in the ONNX model",
|
||||
output_vi.name
|
||||
));
|
||||
}
|
||||
output_names.push(output_vi.name.clone());
|
||||
output_shapes.push(shape);
|
||||
}
|
||||
}
|
||||
// If we have dynamic dims, set initial values in the graph's dyn_map
|
||||
// based on the concrete shapes from the example input used during export
|
||||
if has_dynamic {
|
||||
for input in &onnx_graph.input {
|
||||
if initializer_names.contains(input.name.as_str()) {
|
||||
continue;
|
||||
}
|
||||
let concrete_shape = get_shape_for_onnx_value(input);
|
||||
let expr_shape =
|
||||
get_shape_for_onnx_value_expr(input, &mut dim_param_map, &mut next_char);
|
||||
for (expr, concrete) in expr_shape.iter().zip(concrete_shape.iter()) {
|
||||
if expr.to_usize().is_none() {
|
||||
// This is a symbolic dim — set initial value in dyn_map
|
||||
// Extract the char variable from the expression
|
||||
if let Some(ch) = dim_param_map
|
||||
.values()
|
||||
.find(|&&ch| Expression::from(ch) == *expr)
|
||||
{
|
||||
cx.set_dim(*ch, *concrete);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Extract weight data from initializers (handles inline + external storage)
|
||||
// Batch load reads each external file only once instead of per-tensor
|
||||
for (name, floats) in load_all_tensor_floats(&onnx_graph.initializer, model_directory) {
|
||||
if let Some(f) = floats {
|
||||
weight_data.push((name, f));
|
||||
}
|
||||
}
|
||||
|
||||
// Collect tensor name -> NodeIndex mapping
|
||||
let tensor_ids: HashMap<String, NodeIndex> = tensors
|
||||
.iter()
|
||||
.map(|(name, gt)| (name.clone(), gt.id))
|
||||
.collect();
|
||||
|
||||
// Track which tensor names are Input nodes (includes those created during process_onnx_nodes)
|
||||
let input_tensor_names: HashSet<String> = tensors.keys().cloned().collect();
|
||||
let GraphTranslation {
|
||||
mut graph,
|
||||
tensor_ids,
|
||||
input_names,
|
||||
output_names,
|
||||
output_shape_exprs,
|
||||
input_shape_exprs,
|
||||
dim_param_map,
|
||||
} = translation;
|
||||
|
||||
let rt = match backend {
|
||||
#[cfg(feature = "cuda")]
|
||||
"cuda" => CompiledGraph::build_cuda_backend(
|
||||
onnx_graph,
|
||||
model_directory,
|
||||
&mut tensors,
|
||||
&mut weight_data,
|
||||
&mut cx,
|
||||
&input_tensor_names,
|
||||
)?,
|
||||
"native" => CompiledGraph::build_native_backend(
|
||||
onnx_graph,
|
||||
model_directory,
|
||||
&mut tensors,
|
||||
&mut weight_data,
|
||||
&mut cx,
|
||||
&input_tensor_names,
|
||||
)?,
|
||||
"cuda" | "gpu" => {
|
||||
CompiledGraph::build_cuda_backend(&mut graph, &weight_data, search_iters)?
|
||||
}
|
||||
"native" | "cpu" => {
|
||||
CompiledGraph::build_native_backend(&mut graph, &weight_data, search_iters)?
|
||||
}
|
||||
_ => {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
@@ -274,22 +102,19 @@ impl CompiledGraph {
|
||||
}
|
||||
};
|
||||
|
||||
// Build input_shape_exprs for user inputs (needed for auto-dim detection)
|
||||
let input_shape_exprs: Vec<Vec<Expression>> = input_names
|
||||
// Resolve concrete output shapes from expressions
|
||||
let output_shapes: Vec<Vec<usize>> = output_shape_exprs
|
||||
.iter()
|
||||
.map(|name| {
|
||||
if let Some(>) = tensors.get(name) {
|
||||
gt.dims()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
})
|
||||
.map(|exprs| exprs.iter().map(|e| e.to_usize().unwrap_or(1)).collect())
|
||||
.collect();
|
||||
|
||||
let label_map = CompiledGraph::build_label_map(&graph);
|
||||
|
||||
Ok(CompiledGraph {
|
||||
graph: cx,
|
||||
graph,
|
||||
runtime: rt,
|
||||
tensor_ids,
|
||||
label_map,
|
||||
input_names,
|
||||
output_names,
|
||||
output_shapes,
|
||||
@@ -299,124 +124,154 @@ impl CompiledGraph {
|
||||
})
|
||||
}
|
||||
|
||||
/// Build a label → NodeIndex map for all Input nodes in the graph.
|
||||
/// Used for efficient weight loading by label matching.
|
||||
fn build_label_map(graph: &Graph) -> HashMap<String, NodeIndex> {
|
||||
graph
|
||||
.graph
|
||||
.node_indices()
|
||||
.filter_map(|node_id| {
|
||||
(*graph.graph[node_id])
|
||||
.as_any()
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
.map(|input| (input.label.clone(), node_id))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn build_cuda_backend(
|
||||
onnx_graph: &protobuf::MessageField<GraphProto>,
|
||||
model_directory: &Path,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
context: &mut Graph,
|
||||
input_tensor_names: &HashSet<String>,
|
||||
graph: &mut Graph,
|
||||
weight_data: &WeightData,
|
||||
search_iters: usize,
|
||||
) -> Result<RuntimeBackend, String> {
|
||||
let compute_n_elements = |name: &str| -> usize {
|
||||
if let Some(vi) = onnx_graph.input.iter().find(|i| i.name == name) {
|
||||
let shape = get_shape_for_onnx_value(vi);
|
||||
shape.iter().product::<usize>()
|
||||
} else if let Some(init) = onnx_graph.initializer.iter().find(|i| i.name == name) {
|
||||
init.dims.iter().map(|&d| d as usize).product::<usize>()
|
||||
} else if let Some((_, data)) = weight_data.iter().find(|(n, _)| n == name) {
|
||||
data.len()
|
||||
let device_ptrs = &weight_data.device_ptrs;
|
||||
use luminal_cuda_lite::cudarc::driver::CudaContext;
|
||||
use luminal_cuda_lite::runtime::CudaRuntime;
|
||||
|
||||
let cuda_ctx = CudaContext::new(0).map_err(|e| format!("CUDA context init failed: {e}"))?;
|
||||
let stream = cuda_ctx.default_stream();
|
||||
|
||||
graph.build_search_space::<CudaRuntime>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
// Build label → NodeIndex map for device pointer matching.
|
||||
let label_map = CompiledGraph::build_label_map(graph);
|
||||
|
||||
// For weights with device pointers: use them directly (zero-copy).
|
||||
// This avoids allocating ~N GB of dummy data during search.
|
||||
// The pointers survive search because profiling mode skips buffer consumption,
|
||||
// and persist_hlir_node ensures they survive post-search execution too.
|
||||
let mut device_ptr_nodes: HashSet<NodeIndex> = HashSet::new();
|
||||
let mut matched_count = 0usize;
|
||||
let mut missed_labels: Vec<String> = Vec::new();
|
||||
for (label, &(ptr, n_bytes)) in device_ptrs {
|
||||
if let Some(&node_id) = label_map.get(label) {
|
||||
unsafe { rt.set_device_ptr(node_id, ptr, n_bytes) };
|
||||
rt.persist_hlir_node(node_id);
|
||||
device_ptr_nodes.insert(node_id);
|
||||
matched_count += 1;
|
||||
} else {
|
||||
0
|
||||
missed_labels.push(label.clone());
|
||||
}
|
||||
};
|
||||
}
|
||||
let total_device_bytes: usize = device_ptrs.values().map(|(_, n)| *n).sum();
|
||||
trace!(
|
||||
"[CUDA BUILD] Device pointers: {} matched, {} missed out of {} total ({:.3} GiB)",
|
||||
matched_count,
|
||||
missed_labels.len(),
|
||||
device_ptrs.len(),
|
||||
total_device_bytes as f64 / (1024.0 * 1024.0 * 1024.0),
|
||||
);
|
||||
if !missed_labels.is_empty() {
|
||||
warn!(
|
||||
"[CUDA BUILD] {} device-ptr labels did not match any Input node (first 10): {:?}",
|
||||
missed_labels.len(),
|
||||
&missed_labels[..missed_labels.len().min(10)]
|
||||
);
|
||||
let available: Vec<&String> = label_map.keys().take(10).collect();
|
||||
warn!(
|
||||
"[CUDA BUILD] Available label_map keys (first 10): {:?}",
|
||||
available
|
||||
);
|
||||
}
|
||||
|
||||
// CUDA: Two-phase - set data BEFORE search for profiling
|
||||
let (mut cuda_rt, _stream) = prepare_cuda(context)?;
|
||||
|
||||
// Set dummy data for ALL input tensors using small non-zero values (ones).
|
||||
// Set dummy 1.0 data for remaining Input nodes (user inputs, constants without
|
||||
// device pointers) for safe search profiling.
|
||||
// IMPORTANT: Must use 1.0, NOT 0.0. Zero inputs cause NaN in many ops:
|
||||
// - fmod(0, 0) = NaN (Mod)
|
||||
// - recip(0) = inf → weight * inf = NaN (Div)
|
||||
// - log(0) = -inf (Pow)
|
||||
// - chain ops with zero produce NaN (Erf)
|
||||
// The search's has_nan_outputs check then rejects ALL candidates, causing
|
||||
// "Failed to find viable genome" errors. See LessonsLearned.md entry #1.
|
||||
// Note: torch.compile passes model weights as additional ONNX inputs (not
|
||||
// initializers), so these dummy values also cover weight tensors.
|
||||
for (name, gt) in &mut *tensors {
|
||||
if !input_tensor_names.contains(name) {
|
||||
let mut dummy_total_elements = 0usize;
|
||||
let mut dummy_count = 0usize;
|
||||
for node_id in graph.graph.node_indices() {
|
||||
if device_ptr_nodes.contains(&node_id) {
|
||||
continue;
|
||||
}
|
||||
let n_elements = compute_n_elements(name);
|
||||
if n_elements > 0 {
|
||||
cuda_rt.set_data(gt.id, vec![1.0f32; n_elements]);
|
||||
}
|
||||
}
|
||||
|
||||
// Overwrite with real initializer data (for accurate profiling)
|
||||
// Batch load reads each external file only once
|
||||
let init_data = load_all_tensor_floats(&onnx_graph.initializer, model_directory);
|
||||
for (i, (name, floats_opt)) in init_data.iter().enumerate() {
|
||||
let floats = match floats_opt {
|
||||
Some(f) => f,
|
||||
None => continue,
|
||||
};
|
||||
if let Some(gt) = tensors.get(name) {
|
||||
cuda_rt.set_data(gt.id, floats.clone());
|
||||
}
|
||||
let kn_name = format!("{}_kn", name);
|
||||
if let Some(gt_kn) = tensors.get(&kn_name) {
|
||||
let dims: Vec<usize> = onnx_graph.initializer[i]
|
||||
.dims
|
||||
.iter()
|
||||
.map(|&d| d as usize)
|
||||
.collect();
|
||||
if dims.len() == 2 {
|
||||
let transposed = transpose_weight_data(floats, dims[0], dims[1]);
|
||||
cuda_rt.set_data(gt_kn.id, transposed);
|
||||
if let Some(input) = (*graph.graph[node_id])
|
||||
.as_any()
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
{
|
||||
if let Some(&n) = weight_data.tensor_sizes.get(&input.label) {
|
||||
if n > 0 {
|
||||
dummy_total_elements += n;
|
||||
dummy_count += 1;
|
||||
rt.set_data(node_id, vec![1.0f32; n]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
trace!(
|
||||
"[CUDA BUILD] Dummy data: {} nodes, {} elements ({:.3} GiB as f32)",
|
||||
dummy_count,
|
||||
dummy_total_elements,
|
||||
(dummy_total_elements * 4) as f64 / (1024.0 * 1024.0 * 1024.0),
|
||||
);
|
||||
|
||||
// Load constant node data
|
||||
for (name, floats) in weight_data {
|
||||
if let Some(gt) = tensors.get(name) {
|
||||
cuda_rt.set_data(gt.id, floats.clone());
|
||||
// Search (device-pointer weights are used directly; dummy data for the rest)
|
||||
let mut rt = graph.search(rt, search_iters);
|
||||
|
||||
// Load real weight data for non-device-ptr weights (constants from PT2 archive, etc.)
|
||||
let mut loaded_weight_elements = 0usize;
|
||||
let mut loaded_weight_count = 0usize;
|
||||
for (label, data) in &weight_data.weights {
|
||||
if !device_ptrs.contains_key(label) {
|
||||
if let Some(&node_id) = label_map.get(label) {
|
||||
loaded_weight_elements += data.len();
|
||||
loaded_weight_count += 1;
|
||||
rt.set_data(node_id, data.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
trace!(
|
||||
"[CUDA BUILD] Post-search weight load: {} weights, {} elements ({:.3} GiB as f32)",
|
||||
loaded_weight_count,
|
||||
loaded_weight_elements,
|
||||
(loaded_weight_elements * 4) as f64 / (1024.0 * 1024.0 * 1024.0),
|
||||
);
|
||||
|
||||
// Now finalize (search with profiling, data is available)
|
||||
let cuda_rt = finalize_cuda(context, cuda_rt);
|
||||
|
||||
Ok(cuda_rt)
|
||||
Ok(RuntimeBackend::Cuda(Box::new(rt)))
|
||||
}
|
||||
|
||||
fn build_native_backend(
|
||||
onnx_graph: &protobuf::MessageField<GraphProto>,
|
||||
model_directory: &Path,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
context: &mut Graph,
|
||||
_input_tensor_names: &HashSet<String>,
|
||||
graph: &mut Graph,
|
||||
weight_data: &WeightData,
|
||||
search_iters: usize,
|
||||
) -> Result<RuntimeBackend, String> {
|
||||
let mut rt = initialize_native(context)?;
|
||||
context.search(NativeRuntime::default(), 1);
|
||||
graph.build_search_space::<NativeRuntime>();
|
||||
let mut rt = graph.search(NativeRuntime::default(), search_iters);
|
||||
|
||||
// Set initializer data - these MUST exist after optimization (they're weights)
|
||||
// Skip _kn variants - they might be optimized away
|
||||
// Batch load reads each external file only once
|
||||
for (name, floats_opt) in load_all_tensor_floats(&onnx_graph.initializer, model_directory) {
|
||||
let floats = match floats_opt {
|
||||
Some(f) => f,
|
||||
None => continue,
|
||||
};
|
||||
if let Some(gt) = tensors.get(&name) {
|
||||
rt.set_data(gt.id, floats);
|
||||
// Load weight data after search
|
||||
let label_map = CompiledGraph::build_label_map(graph);
|
||||
for (label, data) in &weight_data.weights {
|
||||
if let Some(&node_id) = label_map.get(label) {
|
||||
rt.set_data(node_id, data.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Load constant node data, but skip _kn transposed variants
|
||||
for (name, floats) in weight_data {
|
||||
// Skip _kn transposed variants - might be optimized away
|
||||
if name.ends_with("_kn") {
|
||||
continue;
|
||||
}
|
||||
if let Some(gt) = tensors.get(name) {
|
||||
rt.set_data(gt.id, floats.clone());
|
||||
}
|
||||
}
|
||||
Ok(rt)
|
||||
Ok(RuntimeBackend::Native(rt))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -525,6 +380,94 @@ impl CompiledGraph {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set input tensor data from a CPU host memory pointer (avoids Python list conversion).
|
||||
/// The pointer must point to contiguous f32 data (from tensor.data_ptr() on a CPU float32 tensor).
|
||||
fn set_input_from_ptr(&mut self, name: &str, ptr: u64, n_elements: usize) -> PyResult<()> {
|
||||
debug_assert!(ptr != 0, "set_input_from_ptr called with null pointer");
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
|
||||
})?;
|
||||
let data: Vec<f32> =
|
||||
unsafe { std::slice::from_raw_parts(ptr as *const f32, n_elements).to_vec() };
|
||||
self.runtime.set_data(*node_id, data);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set input from a CUDA device pointer. Zero-copy on device.
|
||||
/// The pointer must be a valid CUDA device allocation with at least n_bytes bytes.
|
||||
#[cfg(feature = "cuda")]
|
||||
fn set_input_device_ptr(
|
||||
&mut self,
|
||||
name: &str,
|
||||
device_ptr: u64,
|
||||
n_bytes: usize,
|
||||
) -> PyResult<()> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
|
||||
})?;
|
||||
match &mut self.runtime {
|
||||
RuntimeBackend::Cuda(rt) => unsafe { rt.set_device_ptr(*node_id, device_ptr, n_bytes) },
|
||||
_ => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"set_input_device_ptr requires CUDA backend",
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Mark an input tensor as persistent (survives execute() calls).
|
||||
/// Call this for weight tensors that should not be consumed after each execution.
|
||||
fn persist_input(&mut self, name: &str) -> PyResult<()> {
|
||||
let _node_id = *self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
|
||||
})?;
|
||||
match &mut self.runtime {
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(rt) => rt.persist_hlir_node(_node_id),
|
||||
RuntimeBackend::Native(_) => {} // Native: persist is handled at graph level
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set a weight tensor from a CUDA device pointer, matching by Input node label.
|
||||
/// Also marks the weight as persistent. For PT2 weights (e.g. "fc1.weight").
|
||||
#[cfg(feature = "cuda")]
|
||||
fn set_weight_device_ptr(
|
||||
&mut self,
|
||||
label: &str,
|
||||
device_ptr: u64,
|
||||
n_bytes: usize,
|
||||
) -> PyResult<()> {
|
||||
let &node_id = self.label_map.get(label).ok_or_else(|| {
|
||||
pyo3::exceptions::PyKeyError::new_err(format!("No Input node with label: {}", label))
|
||||
})?;
|
||||
match &mut self.runtime {
|
||||
RuntimeBackend::Cuda(rt) => {
|
||||
unsafe { rt.set_device_ptr(node_id, device_ptr, n_bytes) };
|
||||
rt.persist_hlir_node(node_id);
|
||||
}
|
||||
_ => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"set_weight_device_ptr requires CUDA backend",
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set a weight tensor from a CPU host pointer, matching by Input node label.
|
||||
fn set_weight_from_ptr(&mut self, label: &str, ptr: u64, n_elements: usize) -> PyResult<()> {
|
||||
debug_assert!(ptr != 0, "set_weight_from_ptr called with null pointer");
|
||||
let &node_id = self.label_map.get(label).ok_or_else(|| {
|
||||
pyo3::exceptions::PyKeyError::new_err(format!("No Input node with label: {}", label))
|
||||
})?;
|
||||
let data: Vec<f32> =
|
||||
unsafe { std::slice::from_raw_parts(ptr as *const f32, n_elements).to_vec() };
|
||||
self.runtime.set_data(node_id, data);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Execute the graph.
|
||||
fn run(&mut self) {
|
||||
self.runtime.execute(&self.graph.dyn_map);
|
||||
@@ -537,7 +480,7 @@ impl CompiledGraph {
|
||||
})
|
||||
}
|
||||
|
||||
/// Get output tensor data by name.
|
||||
/// Get output tensor data by name (copies to host).
|
||||
fn get_output(&self, name: &str) -> PyResult<Vec<f32>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
@@ -547,4 +490,25 @@ impl CompiledGraph {
|
||||
})?;
|
||||
Ok(self.runtime.get_f32(*node_id))
|
||||
}
|
||||
|
||||
/// Copy output tensor data directly to a CUDA device pointer (DtoD).
|
||||
/// Avoids the DtoH + HtoD round-trip of get_output() + .to(device).
|
||||
#[cfg(feature = "cuda")]
|
||||
fn copy_output_to_device_ptr(&self, name: &str, dest_ptr: u64, n_bytes: usize) -> PyResult<()> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
match &self.runtime {
|
||||
RuntimeBackend::Cuda(rt) => {
|
||||
unsafe { rt.copy_output_to_device_ptr(*node_id, dest_ptr, n_bytes) };
|
||||
Ok(())
|
||||
}
|
||||
_ => Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"copy_output_to_device_ptr requires CUDA backend",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
mod compiled_graph;
|
||||
mod dispatch;
|
||||
mod onnx_translator;
|
||||
mod ops_parse;
|
||||
mod runtime;
|
||||
mod util;
|
||||
@@ -12,12 +13,9 @@ mod pt2_util;
|
||||
mod translator;
|
||||
|
||||
use compiled_graph::CompiledGraph;
|
||||
use onnx_protobuf::ModelProto;
|
||||
use protobuf::Message;
|
||||
use pt2_compiled_model::compile_pt2;
|
||||
use pt2_compiled_model::process_pt2;
|
||||
use pyo3::prelude::*;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn validate_backend(backend: &str) -> PyResult<()> {
|
||||
match backend {
|
||||
@@ -48,46 +46,28 @@ fn validate_backend(backend: &str) -> PyResult<()> {
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (path, backend="native"))]
|
||||
fn process_onnx(path: &str, backend: &str) -> PyResult<CompiledGraph> {
|
||||
#[pyo3(signature = (path, backend="native", search_iters=10, weight_device_ptrs=None))]
|
||||
fn process_onnx(
|
||||
path: &str,
|
||||
backend: &str,
|
||||
search_iters: usize,
|
||||
weight_device_ptrs: Option<HashMap<String, (u64, usize)>>,
|
||||
) -> PyResult<CompiledGraph> {
|
||||
validate_backend(backend)?;
|
||||
|
||||
parse_onnx(path, backend).map_err(pyo3::exceptions::PyRuntimeError::new_err)
|
||||
}
|
||||
|
||||
fn parse_onnx(path: &str, backend: &str) -> Result<CompiledGraph, String> {
|
||||
let data = fs::read(path).map_err(|e| format!("Failed to read file: {}", e))?;
|
||||
let model_directory = Path::new(path).parent().unwrap_or(Path::new("."));
|
||||
let model = ModelProto::parse_from_bytes(&data)
|
||||
.map_err(|e| format!("Failed to parse Onnx Model: {}", e))?;
|
||||
|
||||
let opset_version = model
|
||||
.opset_import
|
||||
.iter()
|
||||
.find(|entry| entry.domain.is_empty())
|
||||
.map(|entry| entry.version);
|
||||
|
||||
match opset_version {
|
||||
Some(20) => {}
|
||||
Some(v) => {
|
||||
return Err(format!(
|
||||
"Unsupported ONNX opset version {v}. Only opset 20 is supported."
|
||||
));
|
||||
}
|
||||
None => {
|
||||
return Err(
|
||||
"No ONNX opset version found in model. Only opset 20 is supported.".to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
CompiledGraph::parse_graph(model, model_directory, backend)
|
||||
onnx_translator::compile_onnx(
|
||||
path,
|
||||
backend,
|
||||
weight_device_ptrs.unwrap_or_default(),
|
||||
search_iters,
|
||||
)
|
||||
.map_err(pyo3::exceptions::PyRuntimeError::new_err)
|
||||
}
|
||||
|
||||
#[pymodule]
|
||||
fn luminal(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(process_onnx, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(compile_pt2, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(process_pt2, m)?)?;
|
||||
m.add_class::<CompiledGraph>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
283
crates/luminal_python/rust/src/onnx_translator.rs
Normal file
283
crates/luminal_python/rust/src/onnx_translator.rs
Normal file
@@ -0,0 +1,283 @@
|
||||
use luminal::{
|
||||
prelude::{
|
||||
tracing::{Level, span, trace},
|
||||
*,
|
||||
},
|
||||
shape::Expression,
|
||||
};
|
||||
use onnx_protobuf::ModelProto;
|
||||
use protobuf::Message;
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
fs,
|
||||
path::Path,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
compiled_graph::{CompiledGraph, GraphTranslation, WeightData},
|
||||
dispatch::process_onnx_nodes,
|
||||
util::{
|
||||
DimParamMap, get_shape_for_onnx_value, get_shape_for_onnx_value_expr,
|
||||
load_all_tensor_floats, load_initializer_as_f32,
|
||||
},
|
||||
};
|
||||
|
||||
/// Load, validate, translate, and compile an ONNX model.
|
||||
///
|
||||
/// This is the ONNX counterpart of `pt2_compiled_model::compile_pt2()`.
|
||||
pub fn compile_onnx(
|
||||
path: &str,
|
||||
backend: &str,
|
||||
weight_device_ptrs: HashMap<String, (u64, usize)>,
|
||||
search_iters: usize,
|
||||
) -> Result<CompiledGraph, String> {
|
||||
let data = fs::read(path).map_err(|e| format!("Failed to read file: {}", e))?;
|
||||
let model_directory = Path::new(path).parent().unwrap_or(Path::new("."));
|
||||
let model = ModelProto::parse_from_bytes(&data)
|
||||
.map_err(|e| format!("Failed to parse ONNX model: {}", e))?;
|
||||
|
||||
let opset_version = model
|
||||
.opset_import
|
||||
.iter()
|
||||
.find(|entry| entry.domain.is_empty())
|
||||
.map(|entry| entry.version);
|
||||
|
||||
match opset_version {
|
||||
Some(20) => {}
|
||||
Some(v) => {
|
||||
return Err(format!(
|
||||
"Unsupported ONNX opset version {v}. Only opset 20 is supported."
|
||||
));
|
||||
}
|
||||
None => {
|
||||
return Err(
|
||||
"No ONNX opset version found in model. Only opset 20 is supported.".to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let (translation, mut weights) = translate_onnx(model, model_directory)?;
|
||||
weights.device_ptrs = weight_device_ptrs;
|
||||
CompiledGraph::parse_graph(translation, weights, backend, search_iters)
|
||||
}
|
||||
|
||||
/// Translate an ONNX model into a format-neutral GraphTranslation + WeightData.
|
||||
pub fn translate_onnx(
|
||||
model: ModelProto,
|
||||
model_directory: &Path,
|
||||
) -> Result<(GraphTranslation, WeightData), String> {
|
||||
let _span = span!(Level::TRACE, "ONNX Graph Translation").entered();
|
||||
let onnx_graph = &model.graph;
|
||||
let mut cx = Graph::new();
|
||||
let mut tensors: HashMap<String, GraphTensor> = HashMap::new();
|
||||
|
||||
// Dynamic dimension tracking
|
||||
let mut dim_param_map: DimParamMap = HashMap::new();
|
||||
let mut next_char = 'a';
|
||||
|
||||
// Separate initializers (weights) from true user inputs
|
||||
let initializer_names: HashSet<&str> = onnx_graph
|
||||
.initializer
|
||||
.iter()
|
||||
.map(|t| t.name.as_str())
|
||||
.collect();
|
||||
|
||||
let input_names: Vec<String> = onnx_graph
|
||||
.input
|
||||
.iter()
|
||||
.filter(|inp| !initializer_names.contains(inp.name.as_str()))
|
||||
.map(|inp| inp.name.clone())
|
||||
.collect();
|
||||
|
||||
// Create input tensors with dynamic dimension support
|
||||
for input in &onnx_graph.input {
|
||||
let shape_exprs = get_shape_for_onnx_value_expr(input, &mut dim_param_map, &mut next_char);
|
||||
if shape_exprs.is_empty() {
|
||||
let shape = get_shape_for_onnx_value(input);
|
||||
if shape.is_empty() {
|
||||
trace!("Input {} skipped because it is empty", input.name.clone());
|
||||
continue;
|
||||
}
|
||||
let tensor = cx.named_tensor(input.name.clone(), shape);
|
||||
trace!("Input {} added to tensors", input.name.clone());
|
||||
tensors.insert(input.name.clone(), tensor);
|
||||
continue;
|
||||
}
|
||||
let tensor = cx.named_tensor(input.name.clone(), shape_exprs);
|
||||
trace!("Input {} added to tensors", input.name.clone());
|
||||
tensors.insert(input.name.clone(), tensor);
|
||||
}
|
||||
|
||||
// Create initializer (weight) tensors
|
||||
for init in &onnx_graph.initializer {
|
||||
if !tensors.contains_key(&init.name) {
|
||||
let mut shape: Vec<usize> = init.dims.iter().map(|&d| d as usize).collect();
|
||||
if shape.is_empty() {
|
||||
shape = vec![1];
|
||||
}
|
||||
let tensor = cx.named_tensor(init.name.clone(), shape);
|
||||
tensors.insert(init.name.clone(), tensor);
|
||||
}
|
||||
}
|
||||
|
||||
// Load small constants for constant folding
|
||||
let mut known_values: HashMap<String, Vec<f32>> = HashMap::new();
|
||||
for init in &onnx_graph.initializer {
|
||||
let n_elements: usize = init
|
||||
.dims
|
||||
.iter()
|
||||
.map(|&d| d as usize)
|
||||
.product::<usize>()
|
||||
.max(1);
|
||||
if n_elements <= 32 {
|
||||
if let Some(floats) = load_initializer_as_f32(init) {
|
||||
known_values.insert(init.name.clone(), floats);
|
||||
} else {
|
||||
panic!("Unable to load initializer values for {:?}", init.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shape expressions for propagating symbolic shapes through ONNX graphs
|
||||
let mut shape_exprs: HashMap<String, Vec<Expression>> = HashMap::new();
|
||||
|
||||
// Accumulates constant node data from process_onnx_nodes
|
||||
let mut constant_data: Vec<(String, Vec<f32>)> = Vec::new();
|
||||
|
||||
// Process computation nodes
|
||||
process_onnx_nodes(
|
||||
&onnx_graph.node,
|
||||
&mut tensors,
|
||||
&mut cx,
|
||||
&mut constant_data,
|
||||
&mut known_values,
|
||||
&mut shape_exprs,
|
||||
)
|
||||
.map_err(|e| format!("process_onnx_nodes failed: {}", e))?;
|
||||
|
||||
// Mark weight/constant tensors as persistent so their buffers survive execute()
|
||||
for (name, gt) in &tensors {
|
||||
if !input_names.contains(name) {
|
||||
gt.persist();
|
||||
}
|
||||
}
|
||||
|
||||
// Mark graph outputs (must happen before build_search_space)
|
||||
let mut output_names = Vec::new();
|
||||
let mut output_shape_exprs = Vec::new();
|
||||
for output_vi in &onnx_graph.output {
|
||||
if let Some(>) = tensors.get(&output_vi.name) {
|
||||
// Force contiguous if the shape tracker is a non-contiguous view
|
||||
let gt = if gt.shape != gt.shape.contiguous() {
|
||||
let contiguous = gt * 1.0;
|
||||
tensors.insert(output_vi.name.clone(), contiguous);
|
||||
contiguous
|
||||
} else {
|
||||
gt
|
||||
};
|
||||
gt.output();
|
||||
let dims = gt.dims();
|
||||
output_shape_exprs.push(dims.clone());
|
||||
|
||||
let shape: Vec<usize> = dims.iter().map(|d| d.to_usize().unwrap_or(1)).collect();
|
||||
if shape.is_empty() {
|
||||
return Err(format!(
|
||||
"Output tensor '{}' has no shape information in the ONNX model",
|
||||
output_vi.name
|
||||
));
|
||||
}
|
||||
output_names.push(output_vi.name.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Set initial dynamic dimension values from example input shapes
|
||||
let has_dynamic = !dim_param_map.is_empty();
|
||||
if has_dynamic {
|
||||
for input in &onnx_graph.input {
|
||||
if initializer_names.contains(input.name.as_str()) {
|
||||
continue;
|
||||
}
|
||||
let concrete_shape = get_shape_for_onnx_value(input);
|
||||
let expr_shape =
|
||||
get_shape_for_onnx_value_expr(input, &mut dim_param_map, &mut next_char);
|
||||
for (expr, concrete) in expr_shape.iter().zip(concrete_shape.iter()) {
|
||||
if expr.to_usize().is_none()
|
||||
&& let Some(ch) = dim_param_map
|
||||
.values()
|
||||
.find(|&&ch| Expression::from(ch) == *expr)
|
||||
{
|
||||
cx.set_dim(*ch, *concrete);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build weight data: initializers + constants from process_onnx_nodes
|
||||
let mut weights: Vec<(String, Vec<f32>)> = Vec::new();
|
||||
for (name, floats) in load_all_tensor_floats(&onnx_graph.initializer, model_directory) {
|
||||
if let Some(f) = floats {
|
||||
weights.push((name, f));
|
||||
}
|
||||
}
|
||||
weights.extend(constant_data);
|
||||
|
||||
// Build tensor sizes for CUDA dummy data allocation
|
||||
let mut tensor_sizes: HashMap<String, usize> = HashMap::new();
|
||||
for input in &onnx_graph.input {
|
||||
if !initializer_names.contains(input.name.as_str()) {
|
||||
let shape = get_shape_for_onnx_value(input);
|
||||
let n: usize = shape.iter().product::<usize>().max(1);
|
||||
tensor_sizes.insert(input.name.clone(), n);
|
||||
}
|
||||
}
|
||||
for init in &onnx_graph.initializer {
|
||||
let n: usize = init
|
||||
.dims
|
||||
.iter()
|
||||
.map(|&d| d as usize)
|
||||
.product::<usize>()
|
||||
.max(1);
|
||||
tensor_sizes.insert(init.name.clone(), n);
|
||||
}
|
||||
for (name, data) in &weights {
|
||||
if !tensor_sizes.contains_key(name) {
|
||||
tensor_sizes.insert(name.clone(), data.len());
|
||||
}
|
||||
}
|
||||
|
||||
// Collect tensor name → NodeIndex mapping
|
||||
let tensor_ids: HashMap<String, NodeIndex> = tensors
|
||||
.iter()
|
||||
.map(|(name, gt)| (name.clone(), gt.id))
|
||||
.collect();
|
||||
|
||||
// Build input_shape_exprs for user inputs (needed for auto-dim detection)
|
||||
let input_shape_exprs: Vec<Vec<Expression>> = input_names
|
||||
.iter()
|
||||
.map(|name| {
|
||||
if let Some(>) = tensors.get(name) {
|
||||
gt.dims()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let translation = GraphTranslation {
|
||||
graph: cx,
|
||||
tensor_ids,
|
||||
input_names,
|
||||
output_names,
|
||||
output_shape_exprs,
|
||||
input_shape_exprs,
|
||||
dim_param_map,
|
||||
};
|
||||
|
||||
let weight_data = WeightData {
|
||||
weights,
|
||||
tensor_sizes,
|
||||
device_ptrs: HashMap::new(),
|
||||
};
|
||||
|
||||
Ok((translation, weight_data))
|
||||
}
|
||||
@@ -1,20 +1,16 @@
|
||||
use luminal::graph::Graph as LuminalGraph;
|
||||
use luminal::prelude::*;
|
||||
use pyo3::prelude::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
use luminal_cuda_lite::cudarc::driver::CudaContext;
|
||||
#[cfg(feature = "cuda")]
|
||||
use luminal_cuda_lite::runtime::CudaRuntime;
|
||||
|
||||
use crate::compiled_graph::CompiledGraph;
|
||||
use crate::compiled_graph::{CompiledGraph, GraphTranslation, WeightData};
|
||||
use crate::pt2_parser;
|
||||
use crate::pt2_schema;
|
||||
use crate::runtime::RuntimeBackend;
|
||||
use crate::translator;
|
||||
use crate::util::DimParamMap;
|
||||
|
||||
/// Pre-loaded weight/constant data paired with tensor sizes.
|
||||
type PreloadResult = (Vec<(String, Vec<f32>)>, HashMap<String, usize>);
|
||||
|
||||
fn resolve_dim_sizes(
|
||||
sizes: &[pt2_schema::DimSize],
|
||||
sym_to_char: &HashMap<String, char>,
|
||||
@@ -39,32 +35,55 @@ fn resolve_dim_sizes(
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
pub fn compile_pt2(
|
||||
#[pyo3(signature = (pt2_path, weights_path, backend, search_iters, weight_device_ptrs=None))]
|
||||
pub fn process_pt2(
|
||||
pt2_path: &str,
|
||||
weights_path: &str,
|
||||
backend: &str,
|
||||
search_iters: usize,
|
||||
weight_device_ptrs: Option<HashMap<String, (u64, usize)>>,
|
||||
) -> PyResult<CompiledGraph> {
|
||||
compile_pt2_inner(pt2_path, weights_path, backend, search_iters)
|
||||
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e:#}")))
|
||||
compile_pt2(
|
||||
pt2_path,
|
||||
weights_path,
|
||||
backend,
|
||||
search_iters,
|
||||
weight_device_ptrs.unwrap_or_default(),
|
||||
)
|
||||
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e:#}")))
|
||||
}
|
||||
|
||||
fn compile_pt2_inner(
|
||||
fn compile_pt2(
|
||||
pt2_path: &str,
|
||||
weights_path: &str,
|
||||
backend: &str,
|
||||
search_iters: usize,
|
||||
weight_device_ptrs: HashMap<String, (u64, usize)>,
|
||||
) -> anyhow::Result<CompiledGraph> {
|
||||
let (translation, mut weights) = translate_pt2(pt2_path, weights_path)?;
|
||||
weights.device_ptrs = weight_device_ptrs;
|
||||
|
||||
CompiledGraph::parse_graph(translation, weights, backend, search_iters)
|
||||
.map_err(|e| anyhow::anyhow!(e))
|
||||
}
|
||||
|
||||
/// Translate a PT2 exported model into a format-neutral GraphTranslation + WeightData.
|
||||
pub fn translate_pt2(
|
||||
pt2_path: &str,
|
||||
weights_path: &str,
|
||||
) -> anyhow::Result<(GraphTranslation, WeightData)> {
|
||||
let parsed = pt2_parser::parse_pt2(pt2_path)?;
|
||||
let translated = translator::translate(&parsed)?;
|
||||
let mut graph = translated.graph;
|
||||
|
||||
// Set initial dynamic dim values from symbol ranges
|
||||
for (sym_name, c) in &translated.sym_map.sym_to_char {
|
||||
if let Some(rc) = translated.sym_map.ranges.get(sym_name) {
|
||||
graph.set_dim(*c, rc.min_val as usize);
|
||||
}
|
||||
}
|
||||
|
||||
// Compute shape expressions from PT2 tensor metadata
|
||||
let output_shape_exprs: Vec<Vec<Expression>> = translated
|
||||
.output_ids
|
||||
.iter()
|
||||
@@ -98,45 +117,6 @@ fn compile_pt2_inner(
|
||||
})
|
||||
.collect();
|
||||
|
||||
let user_input_sizes: Vec<(NodeIndex, usize)> = translated
|
||||
.user_input_ids
|
||||
.iter()
|
||||
.map(|(name, id)| {
|
||||
let meta = parsed.tensor_meta(name);
|
||||
let n_elements = meta
|
||||
.map(|m| {
|
||||
m.sizes
|
||||
.iter()
|
||||
.map(|s| s.hint().unwrap_or(1) as usize)
|
||||
.product()
|
||||
})
|
||||
.unwrap_or(1);
|
||||
(*id, n_elements)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let runtime = match backend {
|
||||
"cpu" | "native" => {
|
||||
graph.build_search_space::<NativeRuntime>();
|
||||
let mut rt = graph.search(NativeRuntime::default(), search_iters);
|
||||
if !weights_path.is_empty() {
|
||||
load_safetensors_native(&mut rt, &graph, weights_path)?;
|
||||
}
|
||||
load_constants_native(&mut rt, &graph, &parsed)?;
|
||||
RuntimeBackend::Native(rt)
|
||||
}
|
||||
"cuda" | "gpu" => init_cuda_runtime(
|
||||
&mut graph,
|
||||
weights_path,
|
||||
&parsed,
|
||||
&user_input_sizes,
|
||||
search_iters,
|
||||
)?,
|
||||
other => {
|
||||
anyhow::bail!("Unknown backend: {other}. Use 'cpu' or 'cuda'.");
|
||||
}
|
||||
};
|
||||
|
||||
// Build tensor_ids from user inputs and outputs
|
||||
let mut tensor_ids: HashMap<String, NodeIndex> = HashMap::new();
|
||||
for (name, id) in &translated.user_input_ids {
|
||||
@@ -146,80 +126,90 @@ fn compile_pt2_inner(
|
||||
tensor_ids.insert(name.clone(), *id);
|
||||
}
|
||||
|
||||
// Resolve concrete output shapes
|
||||
let output_shapes: Vec<Vec<usize>> = output_shape_exprs
|
||||
.iter()
|
||||
.map(|exprs| exprs.iter().map(|e| e.to_usize().unwrap_or(1)).collect())
|
||||
.collect();
|
||||
// Pre-load weights and compute tensor sizes for CUDA dummy data
|
||||
let mut weights: Vec<(String, Vec<f32>)> = Vec::new();
|
||||
let mut tensor_sizes: HashMap<String, usize> = HashMap::new();
|
||||
|
||||
// Load safetensors weights
|
||||
if !weights_path.is_empty() {
|
||||
let (st_weights, st_sizes) = preload_safetensors(&graph, weights_path)?;
|
||||
weights.extend(st_weights);
|
||||
tensor_sizes.extend(st_sizes);
|
||||
}
|
||||
|
||||
// Load PT2 constants from ZIP archive
|
||||
let (const_weights, const_sizes) = preload_constants(&graph, &parsed)?;
|
||||
weights.extend(const_weights);
|
||||
tensor_sizes.extend(const_sizes);
|
||||
|
||||
// Add tensor sizes from PT2 metadata for parameters/buffers not in safetensors
|
||||
// (covers case when weights are loaded via device pointers after compilation)
|
||||
for input_kind in parsed.classify_inputs() {
|
||||
let (graph_name, original_name) = match &input_kind {
|
||||
pt2_parser::InputKind::Parameter {
|
||||
graph_name,
|
||||
original_name,
|
||||
} => (graph_name.as_str(), original_name.as_str()),
|
||||
pt2_parser::InputKind::Buffer {
|
||||
graph_name,
|
||||
original_name,
|
||||
} => (graph_name.as_str(), original_name.as_str()),
|
||||
pt2_parser::InputKind::UserInput { .. } => continue,
|
||||
};
|
||||
// Always use authoritative sizes from model.json tensor_meta,
|
||||
// even if preload_constants inserted a different (possibly stripped) size.
|
||||
if let Some(meta) = parsed.tensor_meta(graph_name) {
|
||||
let n: usize = meta
|
||||
.sizes
|
||||
.iter()
|
||||
.map(|s| s.hint().unwrap_or(1) as usize)
|
||||
.product();
|
||||
tensor_sizes.insert(original_name.to_string(), n);
|
||||
}
|
||||
}
|
||||
|
||||
// Add user input sizes
|
||||
for (name, _id) in &translated.user_input_ids {
|
||||
if !tensor_sizes.contains_key(name)
|
||||
&& let Some(meta) = parsed.tensor_meta(name)
|
||||
{
|
||||
let n: usize = meta
|
||||
.sizes
|
||||
.iter()
|
||||
.map(|s| s.hint().unwrap_or(1) as usize)
|
||||
.product();
|
||||
tensor_sizes.insert(name.clone(), n);
|
||||
}
|
||||
}
|
||||
|
||||
// Build dim_param_map from sym_map
|
||||
let dim_param_map: DimParamMap = translated.sym_map.sym_to_char;
|
||||
|
||||
Ok(CompiledGraph {
|
||||
let translation = GraphTranslation {
|
||||
graph,
|
||||
runtime,
|
||||
tensor_ids,
|
||||
input_names,
|
||||
output_names,
|
||||
output_shapes,
|
||||
output_shape_exprs,
|
||||
input_shape_exprs,
|
||||
dim_param_map,
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn init_cuda_runtime(
|
||||
graph: &mut LuminalGraph,
|
||||
weights_path: &str,
|
||||
parsed: &pt2_parser::ParsedPT2,
|
||||
user_input_sizes: &[(NodeIndex, usize)],
|
||||
search_iters: usize,
|
||||
) -> anyhow::Result<RuntimeBackend> {
|
||||
let cuda_ctx =
|
||||
CudaContext::new(0).map_err(|e| anyhow::anyhow!("CUDA context init failed: {e}"))?;
|
||||
let stream = cuda_ctx.default_stream();
|
||||
let weight_data = WeightData {
|
||||
weights,
|
||||
tensor_sizes,
|
||||
device_ptrs: HashMap::new(),
|
||||
};
|
||||
|
||||
graph.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
// Phase 1: Set ALL input nodes to safe dummy data (1.0) for search profiling.
|
||||
// Real weights/constants may contain -inf (e.g. causal attention mask) which
|
||||
// produce NaN in intermediate computations (e.g. -inf - (-inf) = NaN in softmax
|
||||
// decomposition), causing the search's has_nan_outputs check to reject ALL
|
||||
// candidates. We load real data only AFTER the search completes.
|
||||
set_all_inputs_dummy_cuda(&mut rt, graph, weights_path, parsed, user_input_sizes)?;
|
||||
|
||||
let mut rt = graph.search(rt, search_iters);
|
||||
|
||||
if !weights_path.is_empty() {
|
||||
load_safetensors_cuda(&mut rt, graph, weights_path)?;
|
||||
}
|
||||
load_constants_cuda(&mut rt, graph, parsed)?;
|
||||
|
||||
Ok(RuntimeBackend::Cuda(Box::new(rt)))
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
fn init_cuda_runtime(
|
||||
_graph: &mut LuminalGraph,
|
||||
_weights_path: &str,
|
||||
_parsed: &pt2_parser::ParsedPT2,
|
||||
_user_input_sizes: &[(NodeIndex, usize)],
|
||||
_search_iters: usize,
|
||||
) -> anyhow::Result<RuntimeBackend> {
|
||||
anyhow::bail!("CUDA support not compiled. Rebuild with --features cuda")
|
||||
Ok((translation, weight_data))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Weight loading
|
||||
// Weight pre-loading helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn load_safetensors_impl(
|
||||
cx: &LuminalGraph,
|
||||
file_path: &str,
|
||||
mut set_data: impl FnMut(NodeIndex, Vec<f32>),
|
||||
) -> anyhow::Result<()> {
|
||||
/// Pre-load all safetensors weights that match Input nodes in the graph.
|
||||
/// Returns (weight data, tensor sizes for all tensors in the file).
|
||||
fn preload_safetensors(graph: &Graph, file_path: &str) -> anyhow::Result<PreloadResult> {
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::SafeTensors;
|
||||
use std::fs::File;
|
||||
@@ -229,95 +219,78 @@ fn load_safetensors_impl(
|
||||
let st = SafeTensors::deserialize(&mmap)
|
||||
.map_err(|e| anyhow::anyhow!("SafeTensors deserialize error: {e}"))?;
|
||||
|
||||
for node in cx.graph.node_indices() {
|
||||
if let Some(input) = (*cx.graph[node])
|
||||
let mut weights = Vec::new();
|
||||
let mut sizes = HashMap::new();
|
||||
|
||||
// Get sizes for ALL tensors in the file (for dummy data allocation)
|
||||
for (name, info) in st.tensors() {
|
||||
let n: usize = info.shape().iter().product();
|
||||
sizes.insert(name.to_string(), n);
|
||||
}
|
||||
|
||||
// Load weight data for Input nodes that match safetensors tensor names
|
||||
for node_id in graph.graph.node_indices() {
|
||||
if let Some(input) = (*graph.graph[node_id])
|
||||
.as_any()
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
&& let Ok(tensor) = st.tensor(&input.label)
|
||||
{
|
||||
let f32s = bytes_to_f32(tensor.data(), safetensors_dtype_to_pt2(tensor.dtype()));
|
||||
set_data(node, f32s);
|
||||
weights.push((input.label.clone(), f32s));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok((weights, sizes))
|
||||
}
|
||||
|
||||
fn load_safetensors_native(
|
||||
rt: &mut NativeRuntime,
|
||||
cx: &LuminalGraph,
|
||||
file_path: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
load_safetensors_impl(cx, file_path, |node, data| rt.set_data(node, data))
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn load_safetensors_cuda(
|
||||
rt: &mut CudaRuntime,
|
||||
cx: &LuminalGraph,
|
||||
file_path: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
load_safetensors_impl(cx, file_path, |node, data| rt.set_data(node, data))
|
||||
}
|
||||
|
||||
/// Set ALL input nodes to dummy 1.0 data for safe CUDA search profiling.
|
||||
#[cfg(feature = "cuda")]
|
||||
fn set_all_inputs_dummy_cuda(
|
||||
rt: &mut CudaRuntime,
|
||||
cx: &LuminalGraph,
|
||||
weights_path: &str,
|
||||
/// Pre-load all PT2 constants from the ZIP archive.
|
||||
/// Returns (constant data, tensor sizes for all constants).
|
||||
fn preload_constants(
|
||||
_graph: &Graph,
|
||||
parsed: &pt2_parser::ParsedPT2,
|
||||
user_input_sizes: &[(NodeIndex, usize)],
|
||||
) -> anyhow::Result<()> {
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::SafeTensors;
|
||||
use std::fs::File;
|
||||
) -> anyhow::Result<PreloadResult> {
|
||||
let constants_config = match &parsed.constants_config {
|
||||
Some(c) => c,
|
||||
None => return Ok((Vec::new(), HashMap::new())),
|
||||
};
|
||||
|
||||
let mut label_sizes: HashMap<String, usize> = HashMap::new();
|
||||
let mut weights = Vec::new();
|
||||
let mut sizes = HashMap::new();
|
||||
|
||||
if !weights_path.is_empty() {
|
||||
let f = File::open(weights_path)?;
|
||||
let mmap = unsafe { MmapOptions::new().map(&f)? };
|
||||
let st = SafeTensors::deserialize(&mmap)
|
||||
.map_err(|e| anyhow::anyhow!("SafeTensors deserialize error: {e}"))?;
|
||||
for (name, info) in st.tensors() {
|
||||
let n: usize = info.shape().iter().product();
|
||||
label_sizes.insert(name.to_string(), n);
|
||||
}
|
||||
}
|
||||
for (name, entry) in &constants_config.config {
|
||||
let n: usize = entry
|
||||
.tensor_meta
|
||||
.sizes
|
||||
.iter()
|
||||
.map(|s| s.hint().unwrap_or(1) as usize)
|
||||
.product();
|
||||
sizes.insert(name.clone(), n);
|
||||
|
||||
if let Some(cc) = &parsed.constants_config {
|
||||
for (name, entry) in &cc.config {
|
||||
let n: usize = entry
|
||||
.tensor_meta
|
||||
.sizes
|
||||
.iter()
|
||||
.map(|s| s.hint().unwrap_or(1) as usize)
|
||||
.product();
|
||||
label_sizes.insert(name.clone(), n);
|
||||
}
|
||||
}
|
||||
|
||||
for node_id in cx.graph.node_indices() {
|
||||
if let Some(input) = (*cx.graph[node_id])
|
||||
.as_any()
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
{
|
||||
if let Some(&n) = label_sizes.get(&input.label) {
|
||||
if n > 0 {
|
||||
rt.set_data(node_id, vec![1.0f32; n]);
|
||||
}
|
||||
let raw_bytes = match pt2_parser::read_constant_bytes(
|
||||
&parsed.pt2_path,
|
||||
&parsed.archive_prefix,
|
||||
entry,
|
||||
) {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
eprintln!(
|
||||
"[luminal] Warning: failed to load constant '{}': {:#}",
|
||||
name, e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
};
|
||||
let f32_data = bytes_to_f32(&raw_bytes, entry.tensor_meta.dtype);
|
||||
weights.push((name.clone(), f32_data));
|
||||
}
|
||||
|
||||
for &(id, n_elements) in user_input_sizes {
|
||||
rt.set_data(id, vec![1.0f32; n_elements]);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok((weights, sizes))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Byte conversion helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Convert safetensors Dtype to PT2 dtype number.
|
||||
fn safetensors_dtype_to_pt2(dtype: safetensors::Dtype) -> u32 {
|
||||
match dtype {
|
||||
@@ -381,60 +354,3 @@ fn bytes_to_f32(bytes: &[u8], dtype: u32) -> Vec<f32> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_constants_impl(
|
||||
cx: &LuminalGraph,
|
||||
parsed: &pt2_parser::ParsedPT2,
|
||||
mut set_data: impl FnMut(NodeIndex, Vec<f32>),
|
||||
) -> anyhow::Result<()> {
|
||||
let constants_config = match &parsed.constants_config {
|
||||
Some(c) => c,
|
||||
None => return Ok(()),
|
||||
};
|
||||
|
||||
for (name, entry) in &constants_config.config {
|
||||
let raw_bytes = match pt2_parser::read_constant_bytes(
|
||||
&parsed.pt2_path,
|
||||
&parsed.archive_prefix,
|
||||
entry,
|
||||
) {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
eprintln!(
|
||||
"[luminal] Warning: failed to load constant '{}': {:#}",
|
||||
name, e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let f32_data = bytes_to_f32(&raw_bytes, entry.tensor_meta.dtype);
|
||||
|
||||
for node_id in cx.graph.node_indices() {
|
||||
if let Some(input) = (*cx.graph[node_id])
|
||||
.as_any()
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
&& input.label == *name
|
||||
{
|
||||
set_data(node_id, f32_data.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_constants_native(
|
||||
rt: &mut NativeRuntime,
|
||||
cx: &LuminalGraph,
|
||||
parsed: &pt2_parser::ParsedPT2,
|
||||
) -> anyhow::Result<()> {
|
||||
load_constants_impl(cx, parsed, |node, data| rt.set_data(node, data))
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn load_constants_cuda(
|
||||
rt: &mut CudaRuntime,
|
||||
cx: &LuminalGraph,
|
||||
parsed: &pt2_parser::ParsedPT2,
|
||||
) -> anyhow::Result<()> {
|
||||
load_constants_impl(cx, parsed, |node, data| rt.set_data(node, data))
|
||||
}
|
||||
|
||||
@@ -79,11 +79,3 @@ pub fn finalize_cuda(context: &mut Graph, rt: CudaRuntime) -> RuntimeBackend {
|
||||
let optimized_rt = context.search(rt, 10);
|
||||
RuntimeBackend::Cuda(Box::new(optimized_rt))
|
||||
}
|
||||
|
||||
/// Initialize a native (CPU) runtime using single-phase approach.
|
||||
/// NativeRuntime validates Input nodes, so we must search first, then set data.
|
||||
pub fn initialize_native(context: &mut Graph) -> Result<RuntimeBackend, String> {
|
||||
context.build_search_space::<NativeRuntime>();
|
||||
let rt = context.search(NativeRuntime::default(), 10);
|
||||
Ok(RuntimeBackend::Native(rt))
|
||||
}
|
||||
|
||||
@@ -434,18 +434,6 @@ pub fn load_initializer_as_f32(init: &onnx_protobuf::TensorProto) -> Option<Vec<
|
||||
}
|
||||
}
|
||||
|
||||
/// Transpose weight data from [rows, cols] to [cols, rows] row-major layout
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn transpose_weight_data(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
|
||||
let mut transposed = vec![0.0f32; rows * cols];
|
||||
for r in 0..rows {
|
||||
for c in 0..cols {
|
||||
transposed[c * rows + r] = data[r * cols + c];
|
||||
}
|
||||
}
|
||||
transposed
|
||||
}
|
||||
|
||||
/// Get an integer attribute from a node, with a default value
|
||||
pub fn get_int_attr(node: &NodeProto, name: &str, default: i64) -> i64 {
|
||||
for attr in &node.attribute {
|
||||
|
||||
@@ -2,11 +2,16 @@
|
||||
|
||||
# Import Python components
|
||||
from .compiled_model import CompiledModel
|
||||
from .main import luminal_backend
|
||||
|
||||
# Import Rust extension components (built by maturin)
|
||||
# These are available directly in the package namespace
|
||||
from .luminal import process_onnx, CompiledGraph, compile_pt2
|
||||
from .luminal import CompiledGraph, process_onnx, process_pt2
|
||||
from .main import luminal_backend
|
||||
|
||||
# Register DynamicCache pytree serialization once at import time
|
||||
from .cache_utils import _register_cache_serialization
|
||||
|
||||
_register_cache_serialization()
|
||||
|
||||
# Re-export everything for clean package interface
|
||||
__all__ = [
|
||||
@@ -14,5 +19,5 @@ __all__ = [
|
||||
"luminal_backend",
|
||||
"process_onnx",
|
||||
"CompiledGraph",
|
||||
"compile_pt2",
|
||||
"process_pt2",
|
||||
]
|
||||
|
||||
@@ -8,17 +8,27 @@ import torch
|
||||
class CompiledModel:
|
||||
"""Wrapper around CompiledGraph that handles PyTorch tensor conversion."""
|
||||
|
||||
def __init__(self, graph_result):
|
||||
def __init__(
|
||||
self, graph_result, weight_refs=None, input_names=None, user_indices=None
|
||||
):
|
||||
"""Initialize with a compiled CompiledGraph from Rust.
|
||||
|
||||
Args:
|
||||
graph_result: The CompiledGraph from luminal_python.process_onnx() or compile_pt2()
|
||||
graph_result: The CompiledGraph from luminal_python.process_onnx() or process_pt2()
|
||||
weight_refs: List of PyTorch tensors to keep alive (prevents GC of shared weights)
|
||||
input_names: Override for user input names. If None, uses graph_result.input_names.
|
||||
user_indices: When torch.compile lifts model parameters into extra args,
|
||||
this tells __call__ which arg positions are actual user inputs.
|
||||
None means all args are user inputs (PT2 path).
|
||||
"""
|
||||
self._graph = graph_result
|
||||
self._input_names = graph_result.input_names
|
||||
self._input_names = input_names or graph_result.input_names
|
||||
self._output_names = graph_result.output_names
|
||||
self._output_shapes = graph_result.output_shapes
|
||||
self._has_dynamic_dims = getattr(graph_result, "has_dynamic_dims", False)
|
||||
self._weight_refs = weight_refs or []
|
||||
self._user_indices = user_indices
|
||||
self._is_cuda = graph_result.backend == "cuda"
|
||||
|
||||
def set_dim(self, param_name: str, value: int) -> None:
|
||||
"""Set a dynamic dimension value by its param name."""
|
||||
@@ -36,29 +46,42 @@ class CompiledModel:
|
||||
"""Execute the compiled model with PyTorch tensor inputs.
|
||||
|
||||
Args:
|
||||
*inputs: PyTorch tensors matching the model's input signature
|
||||
*inputs: PyTorch tensors. When torch.compile lifts model parameters,
|
||||
this includes both weights and user inputs. user_indices filters
|
||||
to just the user inputs.
|
||||
|
||||
Returns:
|
||||
Tuple of PyTorch tensors containing the model outputs
|
||||
"""
|
||||
if len(inputs) != len(self._input_names):
|
||||
raise ValueError(
|
||||
f"Expected {len(self._input_names)} inputs, got {len(inputs)}"
|
||||
)
|
||||
# Extract user inputs (torch.compile may pass lifted weights as extra args)
|
||||
if self._user_indices is not None:
|
||||
user_inputs = [inputs[i] for i in self._user_indices]
|
||||
else:
|
||||
if len(inputs) != len(self._input_names):
|
||||
raise ValueError(
|
||||
f"Expected {len(self._input_names)} inputs, got {len(inputs)}"
|
||||
)
|
||||
user_inputs = inputs
|
||||
|
||||
input_device = inputs[0].device if inputs else torch.device("cpu")
|
||||
|
||||
# Auto-detect dynamic dims from input shapes
|
||||
if self._has_dynamic_dims:
|
||||
input_shapes = [list(t.shape) for t in inputs]
|
||||
input_shapes = [list(t.shape) for t in user_inputs]
|
||||
self._graph.auto_set_dims_from_input_shapes(input_shapes)
|
||||
|
||||
# Set input data
|
||||
for name, tensor in zip(self._input_names, inputs):
|
||||
# Convert to contiguous float32 numpy array (move to CPU first for CUDA tensors)
|
||||
arr = tensor.detach().cpu().contiguous().float().numpy()
|
||||
data = arr.flatten().tolist()
|
||||
self._graph.set_input(name, data)
|
||||
# Set user input data via pointer (avoids Python list conversion).
|
||||
# For CUDA inputs, keep references alive so the caching allocator doesn't
|
||||
# recycle GPU memory before run() reads the pointers.
|
||||
_input_refs = []
|
||||
for name, tensor in zip(self._input_names, user_inputs):
|
||||
if self._is_cuda and tensor.is_cuda:
|
||||
t = tensor.detach().contiguous().float()
|
||||
self._graph.set_input_device_ptr(name, t.data_ptr(), t.numel() * 4)
|
||||
_input_refs.append(t)
|
||||
else:
|
||||
t = tensor.detach().cpu().contiguous().float()
|
||||
self._graph.set_input_from_ptr(name, t.data_ptr(), t.numel())
|
||||
|
||||
# Run the graph
|
||||
self._graph.run()
|
||||
@@ -69,16 +92,22 @@ class CompiledModel:
|
||||
else:
|
||||
output_shapes = self._output_shapes
|
||||
|
||||
# Get outputs and convert back to PyTorch tensors on the same device as inputs
|
||||
# Get outputs and convert back to PyTorch tensors on the same device as inputs.
|
||||
# For CUDA: DtoD copy avoids the DtoH + HtoD round-trip.
|
||||
outputs = []
|
||||
for name, shape in zip(self._output_names, output_shapes):
|
||||
data = self._graph.get_output(name)
|
||||
tensor = (
|
||||
torch.tensor(data, dtype=torch.float32)
|
||||
.reshape(tuple(shape))
|
||||
.to(input_device)
|
||||
)
|
||||
outputs.append(tensor)
|
||||
if self._is_cuda and hasattr(self._graph, "copy_output_to_device_ptr"):
|
||||
out = torch.empty(shape, dtype=torch.float32, device=input_device)
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * 4
|
||||
)
|
||||
else:
|
||||
data = self._graph.get_output(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.float32)
|
||||
.reshape(tuple(shape))
|
||||
.to(input_device)
|
||||
)
|
||||
outputs.append(out)
|
||||
|
||||
# Return as a tuple (TorchDynamo expects tuple return from backend callables)
|
||||
return tuple(outputs)
|
||||
|
||||
@@ -6,10 +6,61 @@ import torch._dynamo
|
||||
|
||||
import luminal
|
||||
|
||||
from .cache_utils import _register_cache_serialization
|
||||
from .compiled_model import CompiledModel
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared helpers (used by both ONNX and PT2 paths)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _detect_backend(example_inputs):
|
||||
"""Detect backend from input device. Returns 'cuda' or 'native'."""
|
||||
device = example_inputs[0].device if example_inputs else torch.device("cpu")
|
||||
return "cuda" if device.type == "cuda" else "native"
|
||||
|
||||
|
||||
def _collect_weight_pointers(weights, backend):
|
||||
"""Partition weight tensors into CUDA device pointers and CPU host pointers.
|
||||
|
||||
Args:
|
||||
weights: dict of name -> torch.Tensor
|
||||
backend: "cuda", "gpu", "cpu", or "native"
|
||||
|
||||
Returns:
|
||||
(keep_alive, device_ptrs, cpu_ptrs) where:
|
||||
- keep_alive: list[Tensor] to prevent GC of shared weight memory
|
||||
- device_ptrs: {name: (device_ptr, n_bytes)}
|
||||
- cpu_ptrs: {name: (host_ptr, n_elements)}
|
||||
"""
|
||||
keep_alive = []
|
||||
device_ptrs = {}
|
||||
cpu_ptrs = {}
|
||||
for name, tensor in weights.items():
|
||||
t = tensor.detach().contiguous()
|
||||
if t.dtype != torch.float32:
|
||||
t = t.float()
|
||||
if backend in ("cuda", "gpu") and t.is_cuda:
|
||||
keep_alive.append(t)
|
||||
device_ptrs[name] = (t.data_ptr(), t.numel() * 4)
|
||||
else:
|
||||
t = t.cpu() if t.is_cuda else t
|
||||
keep_alive.append(t)
|
||||
cpu_ptrs[name] = (t.data_ptr(), t.numel())
|
||||
return keep_alive, device_ptrs, cpu_ptrs
|
||||
|
||||
|
||||
def _load_cpu_weights(compiled_graph, cpu_weights):
|
||||
"""Load CPU weight data into a compiled graph after Rust compilation."""
|
||||
for name, (ptr, n_elements) in cpu_weights.items():
|
||||
compiled_graph.set_weight_from_ptr(name, ptr, n_elements)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# torch.compile backend entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def luminal_backend(gm, example_inputs, options=None):
|
||||
"""Luminal torch.compile backend.
|
||||
|
||||
@@ -30,17 +81,39 @@ def luminal_backend(gm, example_inputs, options=None):
|
||||
)
|
||||
opset = options.get("opset", 20)
|
||||
|
||||
_register_cache_serialization()
|
||||
device = example_inputs[0].device if example_inputs else torch.device("cpu")
|
||||
backend = "cuda" if device.type == "cuda" else "native"
|
||||
backend = _detect_backend(example_inputs)
|
||||
|
||||
if export_mode == "pt2":
|
||||
return _compile_pt2(gm, example_inputs, backend)
|
||||
return _compile_onnx(gm, example_inputs, backend, opset=opset)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ONNX compilation path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compile_onnx(gm, example_inputs, backend, opset=20):
|
||||
"""ONNX compilation path."""
|
||||
# Identify weight vs user inputs from FX graph placeholders.
|
||||
# torch.compile lifts model parameters into graph inputs — we detect them by name prefix.
|
||||
weight_tensors = {} # onnx_name -> tensor
|
||||
user_indices = []
|
||||
ph_idx = 0
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "placeholder":
|
||||
onnx_name = f"input_{ph_idx}"
|
||||
if node.name.startswith(("l_self_", "l_model_", "l__self_")):
|
||||
weight_tensors[onnx_name] = example_inputs[ph_idx]
|
||||
else:
|
||||
user_indices.append(ph_idx)
|
||||
ph_idx += 1
|
||||
|
||||
# Collect weight pointers for Rust (avoids duplicate GPU buffer allocation)
|
||||
weight_refs, weight_device_ptrs, cpu_weights = _collect_weight_pointers(
|
||||
weight_tensors, backend
|
||||
)
|
||||
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
@@ -54,11 +127,29 @@ def _compile_onnx(gm, example_inputs, backend, opset=20):
|
||||
input_names=[f"input_{i}" for i in range(len(example_inputs))],
|
||||
)
|
||||
|
||||
result = luminal.process_onnx(tmp_path, backend)
|
||||
result = luminal.process_onnx(
|
||||
tmp_path, backend, weight_device_ptrs=weight_device_ptrs
|
||||
)
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
compiled = CompiledModel(result)
|
||||
return compiled
|
||||
|
||||
# Load CPU weights after compilation
|
||||
_load_cpu_weights(result, cpu_weights)
|
||||
|
||||
# Only expose user input names to CompiledModel (weights are pre-loaded).
|
||||
# user_indices tells __call__ which args from torch.compile are real user inputs.
|
||||
user_input_names = [f"input_{i}" for i in user_indices]
|
||||
return CompiledModel(
|
||||
result,
|
||||
weight_refs=weight_refs,
|
||||
input_names=user_input_names,
|
||||
user_indices=user_indices,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PT2 compilation path (delegates to pt2 module)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compile_pt2(gm, example_inputs, backend):
|
||||
|
||||
@@ -11,12 +11,10 @@ import shutil
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from .cache_utils import _register_cache_serialization
|
||||
from .compiled_model import CompiledModel
|
||||
from .luminal import compile_pt2 as _compile_pt2_rust
|
||||
|
||||
from .luminal import process_pt2
|
||||
from .main import _collect_weight_pointers, _detect_backend, _load_cpu_weights
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -34,37 +32,61 @@ def _export_kwargs():
|
||||
return kwargs
|
||||
|
||||
|
||||
def _save_and_compile(ep, backend, search_iterations):
|
||||
"""Save ExportedProgram + weights to temp files, compile via Rust, return CompiledModel."""
|
||||
tmpdir = tempfile.mkdtemp(prefix="luminal_")
|
||||
def _save_and_compile(ep_or_path, backend, search_iterations, original_weights=None):
|
||||
"""Compile a PT2 model via Rust, return CompiledModel.
|
||||
|
||||
Args:
|
||||
ep_or_path: Either an ExportedProgram (will be saved to a temp file) or
|
||||
a path to an already-saved .pt2 file.
|
||||
original_weights: Optional dict mapping state_dict key -> original PyTorch tensor.
|
||||
When provided, device pointers are taken from these tensors instead of
|
||||
ep.state_dict (which torch.export may have cloned), enabling true zero-copy
|
||||
sharing with the original model's GPU memory.
|
||||
"""
|
||||
owns_tmpdir = not isinstance(ep_or_path, str)
|
||||
tmpdir = tempfile.mkdtemp(prefix="luminal_") if owns_tmpdir else None
|
||||
try:
|
||||
pt2_path = os.path.join(tmpdir, "model.pt2")
|
||||
weights_path = os.path.join(tmpdir, "weights.safetensors")
|
||||
|
||||
torch.export.save(ep, pt2_path)
|
||||
|
||||
state_dict = {k: v.float().clone() for k, v in ep.state_dict.items()}
|
||||
if state_dict:
|
||||
save_file(state_dict, weights_path)
|
||||
if owns_tmpdir:
|
||||
pt2_path = os.path.join(tmpdir, "model.pt2")
|
||||
torch.export.save(ep_or_path, pt2_path)
|
||||
weight_source = (
|
||||
original_weights if original_weights else ep_or_path.state_dict
|
||||
)
|
||||
else:
|
||||
weights_path = ""
|
||||
pt2_path = ep_or_path
|
||||
weight_source = original_weights or {}
|
||||
|
||||
compiled = _compile_pt2_rust(pt2_path, weights_path, backend, search_iterations)
|
||||
return CompiledModel(compiled)
|
||||
# Collect weight pointers for Rust (avoids duplicate GPU buffer allocation)
|
||||
keep_alive, weight_device_ptrs, cpu_weights = _collect_weight_pointers(
|
||||
weight_source, backend
|
||||
)
|
||||
|
||||
# Compile with device pointers — search uses actual weight memory (zero-copy)
|
||||
compiled = process_pt2(
|
||||
pt2_path, "", backend, search_iterations, weight_device_ptrs
|
||||
)
|
||||
|
||||
# Load CPU weights after compilation
|
||||
_load_cpu_weights(compiled, cpu_weights)
|
||||
|
||||
return CompiledModel(compiled, weight_refs=keep_alive)
|
||||
finally:
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
if owns_tmpdir and tmpdir:
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
|
||||
|
||||
def _reinternalize_lifted_params(gm, example_inputs):
|
||||
"""Re-internalize lifted params as buffers so torch.export sees them as model state.
|
||||
|
||||
torch.compile lifts model parameters out of the module and passes them as
|
||||
extra elements in example_inputs. The Rust PT2 compiler expects weights in
|
||||
extra elements in example_inputs. The Rust PT2 compiler may expect weights in
|
||||
the .pt2 state dict, not as runtime inputs. This function reverses the
|
||||
lifting by registering them as buffers and replacing the placeholder nodes
|
||||
with get_attr nodes.
|
||||
|
||||
Returns (gm, user_inputs) where user_inputs contains only the real inputs.
|
||||
Returns (gm, user_inputs, original_weights) where:
|
||||
- user_inputs contains only the real inputs
|
||||
- original_weights maps buffer name -> original tensor (for zero-copy device pointers)
|
||||
"""
|
||||
buffer_indices = []
|
||||
user_indices = []
|
||||
@@ -80,12 +102,15 @@ def _reinternalize_lifted_params(gm, example_inputs):
|
||||
user_indices.append(placeholder_idx)
|
||||
placeholder_idx += 1
|
||||
|
||||
original_weights = {}
|
||||
if buffer_nodes:
|
||||
for i, node in enumerate(buffer_nodes):
|
||||
attr_name = f"_luminal_param_{i}"
|
||||
gm.register_buffer(
|
||||
attr_name, example_inputs[buffer_indices[i]].detach().clone()
|
||||
)
|
||||
# Keep a reference to the original tensor for zero-copy device pointers.
|
||||
# torch.export.export may clone the registered buffer, so we bypass
|
||||
# the EP's state_dict and use the originals directly.
|
||||
original_weights[attr_name] = example_inputs[buffer_indices[i]]
|
||||
gm.register_buffer(attr_name, example_inputs[buffer_indices[i]].detach())
|
||||
with gm.graph.inserting_before(node):
|
||||
new_node = gm.graph.create_node("get_attr", attr_name)
|
||||
new_node.meta = node.meta.copy()
|
||||
@@ -99,7 +124,7 @@ def _reinternalize_lifted_params(gm, example_inputs):
|
||||
if user_indices
|
||||
else list(example_inputs)
|
||||
)
|
||||
return gm, user_inputs
|
||||
return gm, user_inputs, original_weights
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -121,22 +146,20 @@ def compile(
|
||||
model: A PyTorch nn.Module.
|
||||
example_input: Example input tensor(s) for tracing.
|
||||
search_iterations: Number of optimization search iterations.
|
||||
backend: "cpu" or "cuda". Auto-detected if None.
|
||||
backend: "native" or "cuda". Auto-detected if None.
|
||||
export_kwargs: Extra kwargs passed to torch.export.export.
|
||||
dynamic_dim: Which input dimension to make dynamic.
|
||||
|
||||
Returns:
|
||||
A CompiledModel callable.
|
||||
"""
|
||||
_register_cache_serialization()
|
||||
|
||||
if dynamic_dim is None:
|
||||
dynamic_dim = "auto"
|
||||
|
||||
if backend is None:
|
||||
backend = os.environ.get("LUMINAL_BACKEND", None)
|
||||
if backend is None:
|
||||
backend = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
backend = "cuda" if torch.cuda.is_available() else "native"
|
||||
|
||||
kwargs = export_kwargs or {}
|
||||
extra = _export_kwargs()
|
||||
@@ -191,11 +214,43 @@ def pt2_backend(gm, example_inputs, backend=None):
|
||||
|
||||
Usage: torch.compile(model, backend=luminal.pt2.pt2_backend)
|
||||
"""
|
||||
_register_cache_serialization()
|
||||
import gc
|
||||
|
||||
if backend is None:
|
||||
device = example_inputs[0].device if example_inputs else torch.device("cpu")
|
||||
backend = "cuda" if device.type == "cuda" else "cpu"
|
||||
backend = _detect_backend(example_inputs)
|
||||
|
||||
gm = gm.eval()
|
||||
gm, user_inputs = _reinternalize_lifted_params(gm, example_inputs)
|
||||
gm, user_inputs, original_weights = _reinternalize_lifted_params(gm, example_inputs)
|
||||
|
||||
ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs())
|
||||
return _save_and_compile(ep, backend, 10)
|
||||
|
||||
# When using shared memory (original_weights), strip large weight buffers from
|
||||
# the EP before saving. The Rust side uses device pointers for these weights,
|
||||
# not the .pt2 file data, so serializing them is pure IO waste (~32 GB for 8B
|
||||
# models). Replacing with tiny CPU scalars shrinks the .pt2 to < 1 MB.
|
||||
if original_weights:
|
||||
for key in list(ep._state_dict.keys()):
|
||||
if key in original_weights:
|
||||
orig = ep._state_dict[key]
|
||||
ep._state_dict[key] = torch.zeros(1, dtype=orig.dtype, device="cpu")
|
||||
del orig
|
||||
|
||||
# Save the exported program to disk, then free it and the traced graph module
|
||||
# BEFORE Rust compilation. torch.export clones the state_dict internally, so
|
||||
# holding ep alive during compilation would double the weight memory on GPU.
|
||||
tmpdir = tempfile.mkdtemp(prefix="luminal_")
|
||||
pt2_path = os.path.join(tmpdir, "model.pt2")
|
||||
torch.export.save(ep, pt2_path)
|
||||
|
||||
del ep, gm
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
try:
|
||||
result = _save_and_compile(
|
||||
pt2_path, backend, 10, original_weights=original_weights
|
||||
)
|
||||
return result
|
||||
finally:
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
|
||||
176
crates/luminal_python/tests/_llama38b_artifacts.py
Normal file
176
crates/luminal_python/tests/_llama38b_artifacts.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""Helpers for caching Llama 3.1-8B test artifacts under pytest cache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from safetensors.torch import save_file
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
# This code is designed to be deleted.
|
||||
# We should not need to cache pt2 or onnx files to get reasonable compile performance.
|
||||
|
||||
MODEL_ID = "NousResearch/Meta-Llama-3.1-8B-Instruct"
|
||||
INPUT_IDS_LIST = [1, 2, 3, 4]
|
||||
INPUT_IDS = torch.tensor([INPUT_IDS_LIST], dtype=torch.long)
|
||||
ARTIFACT_SCHEMA_VERSION = 1
|
||||
ONNX_OPSET_VERSION = 20
|
||||
PT2_STRICT = False
|
||||
|
||||
_REF_LOGITS_META_KEY = "luminal_python/llama38b_artifacts/ref_logits_v1"
|
||||
_ONNX_META_KEY = "luminal_python/llama38b_artifacts/onnx_v1"
|
||||
_PT2_META_KEY = "luminal_python/llama38b_artifacts/pt2_v1"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Llama38BArtifactBundle:
|
||||
ref_logits_path: Path
|
||||
onnx_path: Path | None = None
|
||||
pt2_path: Path | None = None
|
||||
weights_path: Path | None = None
|
||||
|
||||
|
||||
def ensure_onnx_bundle(cache, cache_dir: Path) -> Llama38BArtifactBundle:
|
||||
"""Ensure ONNX artifacts and shared reference logits exist in pytest cache."""
|
||||
ref_logits_path = cache_dir / "ref_logits.pt"
|
||||
onnx_dir = cache_dir / "onnx"
|
||||
onnx_path = onnx_dir / "llama38b.onnx"
|
||||
|
||||
ref_metadata = _ref_logits_metadata()
|
||||
onnx_metadata = _onnx_metadata()
|
||||
needs_ref_logits = cache.get(_REF_LOGITS_META_KEY, None) != ref_metadata or not (
|
||||
ref_logits_path.is_file()
|
||||
)
|
||||
needs_onnx = cache.get(_ONNX_META_KEY, None) != onnx_metadata or not (
|
||||
onnx_path.is_file()
|
||||
)
|
||||
|
||||
if needs_ref_logits or needs_onnx:
|
||||
print(f"Generating cached ONNX artifacts for {MODEL_ID} in {cache_dir}")
|
||||
if needs_ref_logits:
|
||||
ref_logits_path.unlink(missing_ok=True)
|
||||
if needs_onnx:
|
||||
shutil.rmtree(onnx_dir, ignore_errors=True)
|
||||
onnx_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model = _load_model()
|
||||
|
||||
if needs_ref_logits:
|
||||
ref_logits = _compute_ref_logits(model)
|
||||
torch.save(ref_logits, ref_logits_path)
|
||||
cache.set(_REF_LOGITS_META_KEY, ref_metadata)
|
||||
|
||||
if needs_onnx:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(INPUT_IDS,),
|
||||
str(onnx_path),
|
||||
opset_version=ONNX_OPSET_VERSION,
|
||||
input_names=["input_ids"],
|
||||
output_names=["logits"],
|
||||
)
|
||||
cache.set(_ONNX_META_KEY, onnx_metadata)
|
||||
|
||||
return Llama38BArtifactBundle(ref_logits_path=ref_logits_path, onnx_path=onnx_path)
|
||||
|
||||
|
||||
def ensure_pt2_bundle(cache, cache_dir: Path) -> Llama38BArtifactBundle:
|
||||
"""Ensure PT2 artifacts and shared reference logits exist in pytest cache."""
|
||||
ref_logits_path = cache_dir / "ref_logits.pt"
|
||||
pt2_dir = cache_dir / "pt2"
|
||||
pt2_path = pt2_dir / "llama38b.pt2"
|
||||
weights_path = pt2_dir / "llama38b_weights.safetensors"
|
||||
|
||||
ref_metadata = _ref_logits_metadata()
|
||||
pt2_metadata = _pt2_metadata()
|
||||
needs_ref_logits = cache.get(_REF_LOGITS_META_KEY, None) != ref_metadata or not (
|
||||
ref_logits_path.is_file()
|
||||
)
|
||||
needs_pt2 = cache.get(_PT2_META_KEY, None) != pt2_metadata or not (
|
||||
pt2_path.is_file() and weights_path.is_file()
|
||||
)
|
||||
|
||||
if needs_ref_logits or needs_pt2:
|
||||
print(f"Generating cached PT2 artifacts for {MODEL_ID} in {cache_dir}")
|
||||
if needs_ref_logits:
|
||||
ref_logits_path.unlink(missing_ok=True)
|
||||
if needs_pt2:
|
||||
shutil.rmtree(pt2_dir, ignore_errors=True)
|
||||
pt2_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model = _load_model()
|
||||
|
||||
if needs_ref_logits:
|
||||
ref_logits = _compute_ref_logits(model)
|
||||
torch.save(ref_logits, ref_logits_path)
|
||||
cache.set(_REF_LOGITS_META_KEY, ref_metadata)
|
||||
|
||||
if needs_pt2:
|
||||
exported_program = torch.export.export(
|
||||
model, (INPUT_IDS,), strict=PT2_STRICT
|
||||
)
|
||||
torch.export.save(exported_program, str(pt2_path))
|
||||
|
||||
state_dict = {
|
||||
key: value.float().clone()
|
||||
for key, value in exported_program.state_dict.items()
|
||||
}
|
||||
save_file(state_dict, str(weights_path))
|
||||
cache.set(_PT2_META_KEY, pt2_metadata)
|
||||
|
||||
return Llama38BArtifactBundle(
|
||||
ref_logits_path=ref_logits_path,
|
||||
pt2_path=pt2_path,
|
||||
weights_path=weights_path,
|
||||
)
|
||||
|
||||
|
||||
def _load_model() -> LlamaForCausalLM:
|
||||
config = AutoConfig.from_pretrained(MODEL_ID)
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
return LlamaForCausalLM.from_pretrained(
|
||||
MODEL_ID,
|
||||
config=config,
|
||||
torch_dtype=torch.float32,
|
||||
).eval()
|
||||
|
||||
|
||||
def _compute_ref_logits(model: LlamaForCausalLM) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
return model(INPUT_IDS).logits.clone()
|
||||
|
||||
|
||||
def _ref_logits_metadata() -> dict[str, object]:
|
||||
return {
|
||||
"schema_version": ARTIFACT_SCHEMA_VERSION,
|
||||
"model_id": MODEL_ID,
|
||||
"input_ids": INPUT_IDS_LIST,
|
||||
"device": "cpu",
|
||||
"torch_dtype": "float32",
|
||||
"use_cache": False,
|
||||
"attn_implementation": "eager",
|
||||
"torch_version": torch.__version__,
|
||||
"transformers_version": transformers.__version__,
|
||||
}
|
||||
|
||||
|
||||
def _onnx_metadata() -> dict[str, object]:
|
||||
return {
|
||||
**_ref_logits_metadata(),
|
||||
"artifact_type": "onnx",
|
||||
"opset_version": ONNX_OPSET_VERSION,
|
||||
}
|
||||
|
||||
|
||||
def _pt2_metadata() -> dict[str, object]:
|
||||
return {
|
||||
**_ref_logits_metadata(),
|
||||
"artifact_type": "pt2",
|
||||
"strict": PT2_STRICT,
|
||||
}
|
||||
@@ -10,7 +10,6 @@ import torch._dynamo
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
# ========== HuggingFace Qwen3ForCausalLM Tests ==========
|
||||
|
||||
|
||||
@@ -56,12 +55,12 @@ def _run_hf_qwen3_test(config, device: torch.device, atol: float):
|
||||
def test_hf_qwen3_tiny(device: torch.device):
|
||||
"""HuggingFace Qwen3ForCausalLM -- tiny (64 hidden, 1 layer, ~70K params)."""
|
||||
config = _make_qwen3_config(
|
||||
hidden_size=64,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
hidden_size=32,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=1,
|
||||
num_hidden_layers=1,
|
||||
intermediate_size=128,
|
||||
vocab_size=256,
|
||||
intermediate_size=64,
|
||||
vocab_size=128,
|
||||
)
|
||||
_run_hf_qwen3_test(config, device, atol=1e-5)
|
||||
|
||||
@@ -161,167 +160,6 @@ def test_hf_qwen3_decode_loop_static(device: torch.device):
|
||||
tokens.append(next_token)
|
||||
|
||||
|
||||
def test_hf_qwen3_decode_loop_dynamic():
|
||||
"""Decode loop with dynamic shapes -- compile once, run with varying seq_len.
|
||||
|
||||
Bypasses torch.compile to use luminal's dynamic dim support directly.
|
||||
Exports ONNX once with dynamic_axes, then calls set_dim/set_input/run/get_output.
|
||||
"""
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from transformers import Qwen3Config, Qwen3ForCausalLM
|
||||
|
||||
import luminal
|
||||
|
||||
config = Qwen3Config(
|
||||
hidden_size=64,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
num_hidden_layers=1,
|
||||
intermediate_size=128,
|
||||
vocab_size=256,
|
||||
max_position_embeddings=128,
|
||||
use_cache=False,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
model = Qwen3ForCausalLM(config).eval()
|
||||
|
||||
# Export ONNX once with dynamic seq_len
|
||||
dummy = torch.tensor([[1, 2, 3, 4]])
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
try:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(dummy,),
|
||||
tmp_path,
|
||||
opset_version=20,
|
||||
input_names=["input_ids"],
|
||||
output_names=["logits"],
|
||||
dynamic_axes={"input_ids": {1: "seq_len"}, "logits": {1: "seq_len"}},
|
||||
)
|
||||
|
||||
graph = luminal.process_onnx(tmp_path, "native")
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
assert graph.has_dynamic_dims, "Graph should have dynamic dims"
|
||||
assert "seq_len" in graph.dim_params, f"Expected 'seq_len' in {graph.dim_params}"
|
||||
|
||||
tokens = [1, 2, 3, 4]
|
||||
for step in range(3):
|
||||
seq_len = len(tokens)
|
||||
graph.set_dim("seq_len", seq_len)
|
||||
|
||||
# Set input as float (luminal works with f32 internally)
|
||||
graph.set_input("input_ids", [float(t) for t in tokens])
|
||||
graph.run()
|
||||
|
||||
# Get output and reshape using resolved shapes
|
||||
output_shapes = graph.resolve_output_shapes()
|
||||
logits_data = graph.get_output("logits")
|
||||
logits = torch.tensor(logits_data, dtype=torch.float32).reshape(
|
||||
output_shapes[0]
|
||||
)
|
||||
|
||||
# Compare against PyTorch reference
|
||||
input_ids = torch.tensor([tokens])
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
|
||||
assert torch.allclose(logits, ref.logits, atol=1e-4), (
|
||||
f"step {step}: max_diff={torch.max(torch.abs(logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
next_token = ref.logits[0, -1, :].argmax().item()
|
||||
tokens.append(next_token)
|
||||
|
||||
|
||||
def test_hf_qwen3_8b_decode_loop_dynamic():
|
||||
"""Decode loop with dynamic shapes on real Qwen3-8B -- compile once, run with varying seq_len.
|
||||
|
||||
Full 8B model with pretrained weights, ONNX exported once with dynamic_axes
|
||||
for seq_len, then decoded autoregressively without recompilation.
|
||||
"""
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from transformers import AutoConfig, AutoTokenizer, Qwen3ForCausalLM
|
||||
|
||||
import luminal
|
||||
|
||||
backend = os.environ.get("LUMINAL_BACKEND", "cuda")
|
||||
|
||||
config = AutoConfig.from_pretrained("Qwen/Qwen3-8B")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
print("Loaded config")
|
||||
model = Qwen3ForCausalLM.from_pretrained(
|
||||
"Qwen/Qwen3-8B",
|
||||
config=config,
|
||||
torch_dtype=torch.float32,
|
||||
).eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
|
||||
print("Loaded Model")
|
||||
|
||||
# Export ONNX once with dynamic seq_len
|
||||
dummy = torch.tensor([[1, 2, 3, 4]])
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
|
||||
try:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(dummy,),
|
||||
tmp_path,
|
||||
opset_version=20,
|
||||
input_names=["input_ids"],
|
||||
output_names=["logits"],
|
||||
dynamic_axes={"input_ids": {1: "seq_len"}, "logits": {1: "seq_len"}},
|
||||
)
|
||||
print("Exported onnx")
|
||||
graph = luminal.process_onnx(tmp_path, backend)
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
print("Exported Model")
|
||||
assert graph.has_dynamic_dims, "Graph should have dynamic dims"
|
||||
assert "seq_len" in graph.dim_params, f"Expected 'seq_len' in {graph.dim_params}"
|
||||
|
||||
prompt = "The capital of france is"
|
||||
tokens = tokenizer.encode(prompt)
|
||||
print(f"Prompt: '{prompt}' -> {len(tokens)} tokens: {tokens}")
|
||||
|
||||
num_generate = 3
|
||||
for step in range(num_generate):
|
||||
seq_len = len(tokens)
|
||||
graph.set_dim("seq_len", seq_len)
|
||||
|
||||
graph.set_input("input_ids", [float(t) for t in tokens])
|
||||
graph.run()
|
||||
|
||||
output_shapes = graph.resolve_output_shapes()
|
||||
logits_data = graph.get_output("logits")
|
||||
logits = torch.tensor(logits_data, dtype=torch.float32).reshape(
|
||||
output_shapes[0]
|
||||
)
|
||||
|
||||
# Compare against PyTorch reference
|
||||
input_ids = torch.tensor([tokens])
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
|
||||
assert torch.allclose(logits, ref.logits, atol=1e-3), (
|
||||
f"step {step}: max_diff={torch.max(torch.abs(logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
next_token = ref.logits[0, -1, :].argmax().item()
|
||||
tokens.append(next_token)
|
||||
print(f"Step {step}: '{tokenizer.decode(tokens)}'")
|
||||
|
||||
|
||||
def test_hf_qwen3_8b_full(device: torch.device):
|
||||
"""HuggingFace Qwen3ForCausalLM -- full Qwen3-8B with real pretrained weights.
|
||||
|
||||
|
||||
@@ -1,21 +1,64 @@
|
||||
"""Test configuration."""
|
||||
# ruff: noqa: E402
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from urllib.request import urlopen
|
||||
import warnings
|
||||
|
||||
try:
|
||||
import huggingface_hub
|
||||
from transformers import logging as transformers_logging
|
||||
except ImportError: # pragma: no cover - optional for non-HF test environments
|
||||
huggingface_hub = None
|
||||
transformers_logging = None
|
||||
|
||||
# Enable automatic Rust rebuilds during test development
|
||||
try:
|
||||
import maturin_import_hook
|
||||
from maturin_import_hook.settings import MaturinSettings
|
||||
import maturin_import_hook
|
||||
from maturin_import_hook.settings import MaturinSettings
|
||||
from maturin_import_hook.project_importer import DefaultProjectFileSearcher
|
||||
|
||||
backend = os.getenv("LUMINAL_BACKEND", "native").lower()
|
||||
settings = MaturinSettings(features=["cuda"]) if backend == "cuda" else None
|
||||
maturin_import_hook.install(settings=settings)
|
||||
except ImportError:
|
||||
pass # Hook not available, rebuilds will be manual
|
||||
backend = os.getenv("LUMINAL_BACKEND", "native").lower()
|
||||
settings = MaturinSettings(
|
||||
release=(backend == "cuda"),
|
||||
features=["cuda"] if backend == "cuda" else None,
|
||||
skip_install=True,
|
||||
)
|
||||
searcher = DefaultProjectFileSearcher(
|
||||
source_excluded_dir_names=(
|
||||
DefaultProjectFileSearcher.DEFAULT_SOURCE_EXCLUDED_DIR_NAMES
|
||||
| {".claude", "docs", ".github", "examples"}
|
||||
),
|
||||
)
|
||||
maturin_import_hook.install(
|
||||
settings=settings,
|
||||
enable_automatic_installation=True,
|
||||
file_searcher=searcher,
|
||||
)
|
||||
logging.getLogger("maturin_import_hook").disabled = True
|
||||
logging.getLogger("maturin_import_hook.project_importer").disabled = True
|
||||
|
||||
# Silence noisy ONNX / onnxscript / httpx logging
|
||||
for _logger_name in (
|
||||
"onnxscript",
|
||||
"onnx_ir",
|
||||
"torch.onnx",
|
||||
"httpx",
|
||||
):
|
||||
logging.getLogger(_logger_name).setLevel(logging.WARNING)
|
||||
|
||||
# Suppress torch.onnx diagnostics/progress output and torchvision warnings
|
||||
os.environ.setdefault("TORCH_ONNX_VERBOSE", "0")
|
||||
os.environ.setdefault("TORCH_ONNX_LOG_LEVEL", "ERROR")
|
||||
warnings.filterwarnings("ignore", message=".*torchvision.*")
|
||||
warnings.filterwarnings("ignore", module="torch.onnx")
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from _llama38b_artifacts import ensure_onnx_bundle, ensure_pt2_bundle
|
||||
|
||||
torch.set_float32_matmul_precision("highest")
|
||||
|
||||
@@ -26,6 +69,139 @@ def device() -> torch.device:
|
||||
return torch.device("cuda") if backend == "cuda" else torch.device("cpu")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def configure_hf_test_output() -> None:
|
||||
if transformers_logging is not None:
|
||||
transformers_logging.disable_progress_bar()
|
||||
if huggingface_hub is not None:
|
||||
huggingface_hub.utils.disable_progress_bars()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def configure_dynamo():
|
||||
original_cache_size_limit = torch._dynamo.config.cache_size_limit
|
||||
original_suppress_errors = torch._dynamo.config.suppress_errors
|
||||
|
||||
def _configure(
|
||||
*, cache_size_limit: int | None = None, suppress_errors: bool | None = None
|
||||
) -> None:
|
||||
if cache_size_limit is not None:
|
||||
torch._dynamo.config.cache_size_limit = cache_size_limit
|
||||
if suppress_errors is not None:
|
||||
torch._dynamo.config.suppress_errors = suppress_errors
|
||||
|
||||
yield _configure
|
||||
|
||||
torch._dynamo.config.cache_size_limit = original_cache_size_limit
|
||||
torch._dynamo.config.suppress_errors = original_suppress_errors
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def _llama38b_cache_dir(pytestconfig: pytest.Config) -> Path:
|
||||
return pytestconfig.cache.mkdir("luminal_llama38b_artifacts_v1")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def _hf_multimodal_cache_dir(pytestconfig: pytest.Config) -> Path:
|
||||
return pytestconfig.cache.mkdir("luminal_hf_multimodal_v1")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def hf_multimodal_image_path(
|
||||
pytestconfig: pytest.Config, _hf_multimodal_cache_dir: Path
|
||||
) -> Path:
|
||||
image_url = (
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/"
|
||||
"resolve/main/bee.jpg"
|
||||
)
|
||||
image_path = _hf_multimodal_cache_dir / "bee.jpg"
|
||||
metadata_key = "luminal_python/hf_multimodal_image_v1"
|
||||
metadata = {
|
||||
"schema_version": 1,
|
||||
"url": image_url,
|
||||
"filename": image_path.name,
|
||||
}
|
||||
|
||||
needs_download = pytestconfig.cache.get(metadata_key, None) != metadata or not (
|
||||
image_path.is_file()
|
||||
)
|
||||
if not needs_download:
|
||||
return image_path
|
||||
|
||||
image_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp_path: Path | None = None
|
||||
try:
|
||||
with urlopen(image_url, timeout=60) as response:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
dir=image_path.parent, delete=False
|
||||
) as tmp_file:
|
||||
tmp_path = Path(tmp_file.name)
|
||||
while chunk := response.read(1024 * 1024):
|
||||
tmp_file.write(chunk)
|
||||
|
||||
assert tmp_path is not None
|
||||
tmp_path.replace(image_path)
|
||||
except Exception:
|
||||
if tmp_path is not None:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
raise
|
||||
|
||||
pytestconfig.cache.set(metadata_key, metadata)
|
||||
return image_path
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def _llama38b_onnx_bundle(pytestconfig: pytest.Config, _llama38b_cache_dir: Path):
|
||||
return ensure_onnx_bundle(pytestconfig.cache, _llama38b_cache_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def _llama38b_pt2_bundle(pytestconfig: pytest.Config, _llama38b_cache_dir: Path):
|
||||
return ensure_pt2_bundle(pytestconfig.cache, _llama38b_cache_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llama38b_ref_logits(request: pytest.FixtureRequest) -> torch.Tensor:
|
||||
fixturenames = set(request.fixturenames)
|
||||
uses_onnx = "llama38b_onnx_path" in fixturenames
|
||||
uses_pt2 = bool({"llama38b_pt2_path", "llama38b_weights_path"} & fixturenames)
|
||||
|
||||
if uses_onnx and uses_pt2:
|
||||
raise pytest.UsageError(
|
||||
"llama38b_ref_logits cannot be requested with both ONNX and PT2 "
|
||||
"artifact fixtures in the same test"
|
||||
)
|
||||
if uses_onnx:
|
||||
bundle = request.getfixturevalue("_llama38b_onnx_bundle")
|
||||
elif uses_pt2:
|
||||
bundle = request.getfixturevalue("_llama38b_pt2_bundle")
|
||||
else:
|
||||
raise pytest.UsageError(
|
||||
"llama38b_ref_logits must be requested alongside llama38b_onnx_path "
|
||||
"or llama38b_pt2_path/llama38b_weights_path"
|
||||
)
|
||||
|
||||
return torch.load(bundle.ref_logits_path, weights_only=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llama38b_onnx_path(_llama38b_onnx_bundle) -> Path:
|
||||
assert _llama38b_onnx_bundle.onnx_path is not None
|
||||
return _llama38b_onnx_bundle.onnx_path
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llama38b_pt2_path(_llama38b_pt2_bundle) -> Path:
|
||||
assert _llama38b_pt2_bundle.pt2_path is not None
|
||||
return _llama38b_pt2_bundle.pt2_path
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llama38b_weights_path(_llama38b_pt2_bundle) -> Path:
|
||||
assert _llama38b_pt2_bundle.weights_path is not None
|
||||
return _llama38b_pt2_bundle.weights_path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="function")
|
||||
def reset_torch_dynamo():
|
||||
# We need this for two reasons
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
"""Generate pre-computed artifacts for test_hf_llama38b_cached_onnx.
|
||||
|
||||
Run once:
|
||||
uv run python tests/generate_llama38b_artifacts.py
|
||||
|
||||
Produces:
|
||||
tests/llama38b.onnx — ONNX export of Llama 3.1-8B
|
||||
tests/llama38b_ref_logits.pt — reference logits for input_ids=[1,2,3,4]
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
SCRIPT_DIR = Path(__file__).resolve().parent
|
||||
ONNX_PATH = SCRIPT_DIR / "llama38b.onnx"
|
||||
LOGITS_PATH = SCRIPT_DIR / "llama38b_ref_logits.pt"
|
||||
|
||||
INPUT_IDS = torch.tensor([[1, 2, 3, 4]])
|
||||
|
||||
|
||||
def main():
|
||||
config = AutoConfig.from_pretrained("NousResearch/Meta-Llama-3.1-8B-Instruct")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
print("Loading model on CPU...")
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
"NousResearch/Meta-Llama-3.1-8B-Instruct",
|
||||
config=config,
|
||||
torch_dtype=torch.float32,
|
||||
).eval()
|
||||
|
||||
print("Computing reference logits...")
|
||||
with torch.no_grad():
|
||||
ref_logits = model(INPUT_IDS).logits.clone()
|
||||
print(f"Reference logits shape: {ref_logits.shape}")
|
||||
|
||||
print(f"Saving reference logits to {LOGITS_PATH}")
|
||||
torch.save(ref_logits, LOGITS_PATH)
|
||||
|
||||
print(f"Exporting ONNX to {ONNX_PATH}")
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(INPUT_IDS,),
|
||||
str(ONNX_PATH),
|
||||
opset_version=20,
|
||||
input_names=["input_ids"],
|
||||
output_names=["logits"],
|
||||
)
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,62 +0,0 @@
|
||||
"""Generate pre-computed PT2 artifacts for test_hf_llama38b_cached.
|
||||
|
||||
Run once:
|
||||
uv run python tests/generate_llama38b_pt2_artifacts.py
|
||||
|
||||
Produces:
|
||||
tests/llama38b.pt2 — torch.export of Llama 3.1-8B
|
||||
tests/llama38b_weights.safetensors — model weights
|
||||
tests/llama38b_ref_logits.pt — reference logits for input_ids=[1,2,3,4]
|
||||
(shared with ONNX artifact script)
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
SCRIPT_DIR = Path(__file__).resolve().parent
|
||||
PT2_PATH = SCRIPT_DIR / "llama38b.pt2"
|
||||
WEIGHTS_PATH = SCRIPT_DIR / "llama38b_weights.safetensors"
|
||||
LOGITS_PATH = SCRIPT_DIR / "llama38b_ref_logits.pt"
|
||||
|
||||
INPUT_IDS = torch.tensor([[1, 2, 3, 4]])
|
||||
|
||||
|
||||
def main():
|
||||
config = AutoConfig.from_pretrained("NousResearch/Meta-Llama-3.1-8B-Instruct")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
print("Loading model on CPU...")
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
"NousResearch/Meta-Llama-3.1-8B-Instruct",
|
||||
config=config,
|
||||
torch_dtype=torch.float32,
|
||||
).eval()
|
||||
|
||||
# Generate reference logits (shared with ONNX artifact script)
|
||||
if not LOGITS_PATH.exists():
|
||||
print("Computing reference logits...")
|
||||
with torch.no_grad():
|
||||
ref_logits = model(INPUT_IDS).logits.clone()
|
||||
print(f"Reference logits shape: {ref_logits.shape}")
|
||||
print(f"Saving reference logits to {LOGITS_PATH}")
|
||||
torch.save(ref_logits, LOGITS_PATH)
|
||||
else:
|
||||
print(f"Reference logits already exist at {LOGITS_PATH}, skipping")
|
||||
|
||||
print(f"Exporting PT2 to {PT2_PATH}")
|
||||
ep = torch.export.export(model, (INPUT_IDS,), strict=False)
|
||||
torch.export.save(ep, str(PT2_PATH))
|
||||
|
||||
print(f"Saving weights to {WEIGHTS_PATH}")
|
||||
state_dict = {k: v.float().clone() for k, v in ep.state_dict.items()}
|
||||
save_file(state_dict, str(WEIGHTS_PATH))
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
282
crates/luminal_python/tests/test_hf_causal_lm_config_options.py
Normal file
282
crates/luminal_python/tests/test_hf_causal_lm_config_options.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""Hugging Face causal-LM config option support tests.
|
||||
|
||||
These tests verify that luminal matches eager Hugging Face execution across
|
||||
supported causal-LM config options using tiny public model definitions loaded
|
||||
through AutoConfig.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig
|
||||
from transformers.generation.configuration_utils import ContinuousBatchingConfig
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
# Attention implementations that require optional packages.
|
||||
_ATTN_REQUIRES_PACKAGE: dict[str, str] = {
|
||||
"flash_attention_2": "flash_attn",
|
||||
"flash_attention_3": "flash_attn_3",
|
||||
"flash_attention_4": "cutlass",
|
||||
}
|
||||
|
||||
# Attention implementations known to be incompatible with tiny random models
|
||||
# (e.g. head_dim < 16 or missing scaffolding).
|
||||
_ATTN_SKIP_TINY_MODEL: set[str] = {"flex_attention", "paged_attention"}
|
||||
|
||||
_PAGED_ATTN_IMPLEMENTATIONS = tuple(
|
||||
k for k in ALL_ATTENTION_FUNCTIONS.valid_keys() if k.startswith("paged|")
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _CausalLMConfigCase:
|
||||
case_id: str
|
||||
model_id: str
|
||||
input_ids: tuple[int, ...]
|
||||
atol: float
|
||||
rtol: float
|
||||
|
||||
|
||||
_MODEL_CASES = [
|
||||
_CausalLMConfigCase(
|
||||
case_id="llama_3.2_1B",
|
||||
model_id="meta-llama/Llama-3.2-1B",
|
||||
input_ids=(1, 2, 3, 4),
|
||||
atol=1e-5,
|
||||
rtol=1e-5,
|
||||
)
|
||||
]
|
||||
|
||||
_ATTN_IMPLEMENTATIONS = tuple(
|
||||
dict.fromkeys([None, "eager", *ALL_ATTENTION_FUNCTIONS.valid_keys()])
|
||||
)
|
||||
_CUDA_BACKEND_AVAILABLE = (
|
||||
os.getenv("LUMINAL_BACKEND", "native").lower() == "cuda"
|
||||
and torch.cuda.is_available()
|
||||
)
|
||||
|
||||
|
||||
def _attn_id(attn_impl: str | None) -> str:
|
||||
return "default" if attn_impl is None else attn_impl
|
||||
|
||||
|
||||
def _base_attn_impl(attn_impl: str | None) -> str | None:
|
||||
if attn_impl is None:
|
||||
return None
|
||||
if attn_impl.startswith("paged|"):
|
||||
return attn_impl.split("|", maxsplit=1)[1]
|
||||
return attn_impl
|
||||
|
||||
|
||||
def _attn_param(attn_impl: str | None, *, allow_paged: bool) -> pytest.ParameterSet:
|
||||
marks = []
|
||||
base_attn_impl = _base_attn_impl(attn_impl)
|
||||
|
||||
if base_attn_impl == "flash_attention_2":
|
||||
marks.append(pytest.mark.skip(reason="flash_attention_2 is very slow"))
|
||||
|
||||
if attn_impl is not None and attn_impl.startswith("paged|") and not allow_paged:
|
||||
marks.append(
|
||||
pytest.mark.skip(reason=f"{attn_impl} requires continuous batching API")
|
||||
)
|
||||
|
||||
if base_attn_impl in _ATTN_REQUIRES_PACKAGE:
|
||||
pkg = _ATTN_REQUIRES_PACKAGE[base_attn_impl]
|
||||
if importlib.util.find_spec(pkg) is None:
|
||||
marks.append(
|
||||
pytest.mark.skip(
|
||||
reason=f"{attn_impl} requires package '{pkg}' which is not installed"
|
||||
)
|
||||
)
|
||||
|
||||
if base_attn_impl in _ATTN_SKIP_TINY_MODEL:
|
||||
marks.append(
|
||||
pytest.mark.skip(
|
||||
reason=f"{attn_impl} is incompatible with tiny random test models"
|
||||
)
|
||||
)
|
||||
|
||||
kwargs = {"id": _attn_id(attn_impl)}
|
||||
if marks:
|
||||
kwargs["marks"] = marks
|
||||
return pytest.param(attn_impl, **kwargs)
|
||||
|
||||
|
||||
_ATTN_PREFILL_PARAMS = tuple(
|
||||
_attn_param(attn_impl, allow_paged=False) for attn_impl in _ATTN_IMPLEMENTATIONS
|
||||
)
|
||||
_ATTN_GENERATE_BATCH_PARAMS = tuple(
|
||||
_attn_param(attn_impl, allow_paged=True) for attn_impl in _ATTN_IMPLEMENTATIONS
|
||||
)
|
||||
|
||||
|
||||
def _compare_past_key_values(lhs, rhs, *, atol: float, rtol: float) -> None:
|
||||
assert lhs is not None
|
||||
assert rhs is not None
|
||||
assert hasattr(lhs, "layers")
|
||||
assert hasattr(rhs, "layers")
|
||||
assert len(lhs.layers) == len(rhs.layers)
|
||||
|
||||
for lhs_layer, rhs_layer in zip(lhs.layers, rhs.layers):
|
||||
torch.testing.assert_close(lhs_layer.keys, rhs_layer.keys, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(
|
||||
lhs_layer.values, rhs_layer.values, atol=atol, rtol=rtol
|
||||
)
|
||||
|
||||
|
||||
def _instantiate_model(
|
||||
model_id: str, *, use_cache: bool, attn_impl: str | None, device: torch.device
|
||||
) -> AutoModelForCausalLM:
|
||||
config = AutoConfig.from_pretrained(model_id)
|
||||
config.use_cache = use_cache
|
||||
if attn_impl is not None:
|
||||
config._attn_implementation = attn_impl
|
||||
return AutoModelForCausalLM.from_config(config).eval().to(device)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_case", _MODEL_CASES, ids=lambda case: case.case_id)
|
||||
@pytest.mark.parametrize("use_cache", [False, True], ids=["no_cache", "cache"])
|
||||
@pytest.mark.parametrize("attn_impl", _ATTN_PREFILL_PARAMS)
|
||||
def test_hf_causal_lm_config_options_match_eager(
|
||||
model_case: _CausalLMConfigCase,
|
||||
use_cache: bool,
|
||||
attn_impl: str | None,
|
||||
device: torch.device,
|
||||
configure_dynamo,
|
||||
):
|
||||
"""Compare luminal against eager HF across causal-LM config options."""
|
||||
if use_cache:
|
||||
configure_dynamo(cache_size_limit=2)
|
||||
|
||||
model = _instantiate_model(
|
||||
model_case.model_id,
|
||||
use_cache=use_cache,
|
||||
attn_impl=attn_impl,
|
||||
device=device,
|
||||
)
|
||||
input_ids = torch.tensor([model_case.input_ids], device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_prefill = model(input_ids)
|
||||
|
||||
compiled_model = torch.compile(model, backend=luminal_backend)
|
||||
|
||||
with torch.no_grad():
|
||||
compiled_prefill = compiled_model(input_ids)
|
||||
|
||||
torch.testing.assert_close(
|
||||
compiled_prefill.logits,
|
||||
eager_prefill.logits,
|
||||
atol=model_case.atol,
|
||||
rtol=model_case.rtol,
|
||||
)
|
||||
|
||||
if not use_cache:
|
||||
assert eager_prefill.past_key_values is None
|
||||
assert compiled_prefill.past_key_values is None
|
||||
return
|
||||
|
||||
assert eager_prefill.past_key_values is not None
|
||||
assert compiled_prefill.past_key_values is not None
|
||||
_compare_past_key_values(
|
||||
compiled_prefill.past_key_values,
|
||||
eager_prefill.past_key_values,
|
||||
atol=model_case.atol,
|
||||
rtol=model_case.rtol,
|
||||
)
|
||||
|
||||
next_token = eager_prefill.logits[:, -1, :].argmax(dim=-1, keepdim=True)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_decode = model(next_token, past_key_values=eager_prefill.past_key_values)
|
||||
compiled_decode = compiled_model(
|
||||
next_token,
|
||||
past_key_values=compiled_prefill.past_key_values,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
compiled_decode.logits,
|
||||
eager_decode.logits,
|
||||
atol=model_case.atol,
|
||||
rtol=model_case.rtol,
|
||||
)
|
||||
assert eager_decode.past_key_values is not None
|
||||
assert compiled_decode.past_key_values is not None
|
||||
_compare_past_key_values(
|
||||
compiled_decode.past_key_values,
|
||||
eager_decode.past_key_values,
|
||||
atol=model_case.atol,
|
||||
rtol=model_case.rtol,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_case", _MODEL_CASES, ids=lambda case: case.case_id)
|
||||
@pytest.mark.parametrize("attn_impl", _ATTN_GENERATE_BATCH_PARAMS)
|
||||
@pytest.mark.skipif(not _CUDA_BACKEND_AVAILABLE, reason="generate_batch requires CUDA")
|
||||
def test_hf_generate_batch(
|
||||
model_case: _CausalLMConfigCase,
|
||||
attn_impl: str | None,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Compare generate_batch output for each attention variant against eager baseline."""
|
||||
config = AutoConfig.from_pretrained(model_case.model_id)
|
||||
config.use_cache = True
|
||||
model = (
|
||||
AutoModelForCausalLM.from_config(config)
|
||||
.to(dtype=torch.bfloat16)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
|
||||
gen_config = GenerationConfig(
|
||||
do_sample=False,
|
||||
max_new_tokens=5,
|
||||
temperature=None,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
)
|
||||
cb_config = ContinuousBatchingConfig(
|
||||
block_size=256,
|
||||
use_cuda_graph=False,
|
||||
)
|
||||
|
||||
inputs = [list(model_case.input_ids)]
|
||||
|
||||
# Baseline: eager generate_batch.
|
||||
model.set_attn_implementation("eager")
|
||||
eager_outputs = model.generate_batch(
|
||||
inputs,
|
||||
generation_config=gen_config,
|
||||
continuous_batching_config=cb_config,
|
||||
progress_bar=False,
|
||||
warmup=False,
|
||||
)
|
||||
|
||||
# Variant under test.
|
||||
if attn_impl is not None:
|
||||
model.set_attn_implementation(attn_impl)
|
||||
variant_outputs = model.generate_batch(
|
||||
inputs,
|
||||
generation_config=gen_config,
|
||||
continuous_batching_config=cb_config,
|
||||
progress_bar=False,
|
||||
warmup=False,
|
||||
)
|
||||
|
||||
assert len(eager_outputs) == len(variant_outputs)
|
||||
eager_out = next(iter(eager_outputs.values()))
|
||||
variant_out = next(iter(variant_outputs.values()))
|
||||
assert eager_out.error is None, f"Eager baseline failed: {eager_out.error}"
|
||||
assert variant_out.error is None, f"Variant {attn_impl} failed: {variant_out.error}"
|
||||
assert eager_out.generated_tokens == variant_out.generated_tokens, (
|
||||
f"Token mismatch for {attn_impl}: eager={eager_out.generated_tokens} "
|
||||
f"vs variant={variant_out.generated_tokens}"
|
||||
)
|
||||
168
crates/luminal_python/tests/test_hf_causal_lm_experts_options.py
Normal file
168
crates/luminal_python/tests/test_hf_causal_lm_experts_options.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Hugging Face causal-LM experts backend smoke tests.
|
||||
|
||||
These tests load a real pretrained text-only MoE model and compare eager
|
||||
PyTorch against torch.compile with the luminal backend across the standardized
|
||||
`experts_implementation` backends.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _HFMoeCase:
|
||||
case_id: str
|
||||
model_id: str
|
||||
input_ids: tuple[tuple[int, ...], ...]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _HFMoeBundle:
|
||||
case: _HFMoeCase
|
||||
model: AutoModelForCausalLM
|
||||
device: torch.device
|
||||
dtype: torch.dtype
|
||||
|
||||
|
||||
_MODEL_CASES = [
|
||||
_HFMoeCase(
|
||||
case_id="qwen15_moe_a27b",
|
||||
model_id="Qwen/Qwen1.5-MoE-A2.7B",
|
||||
input_ids=(
|
||||
(1, 2, 3, 4, 5, 6, 7, 8),
|
||||
(8, 7, 6, 5, 4, 3, 2, 1),
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
_EXPERTS_IMPLEMENTATIONS = ("eager", "batched_mm", "grouped_mm")
|
||||
_CUDA_BACKEND_AVAILABLE = (
|
||||
os.getenv("LUMINAL_BACKEND", "native").lower() == "cuda"
|
||||
and torch.cuda.is_available()
|
||||
)
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.slow,
|
||||
pytest.mark.skipif(
|
||||
not _CUDA_BACKEND_AVAILABLE,
|
||||
reason="HF MoE experts backend tests require the CUDA backend",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _model_dtype(device: torch.device) -> torch.dtype:
|
||||
if device.type != "cuda":
|
||||
return torch.float32
|
||||
return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||
|
||||
|
||||
def _output_tolerance(dtype: torch.dtype) -> float:
|
||||
if dtype == torch.bfloat16:
|
||||
return 5e-2
|
||||
if dtype == torch.float16:
|
||||
return 1e-2
|
||||
return 1e-3
|
||||
|
||||
|
||||
def _compare_router_logits(lhs, rhs, *, atol: float, rtol: float) -> None:
|
||||
assert lhs is not None
|
||||
assert rhs is not None
|
||||
|
||||
if isinstance(lhs, torch.Tensor):
|
||||
torch.testing.assert_close(lhs, rhs, atol=atol, rtol=rtol)
|
||||
return
|
||||
|
||||
assert len(lhs) == len(rhs)
|
||||
for lhs_layer, rhs_layer in zip(lhs, rhs):
|
||||
torch.testing.assert_close(lhs_layer, rhs_layer, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=_MODEL_CASES, ids=lambda case: case.case_id)
|
||||
def hf_moe_case(request: pytest.FixtureRequest) -> _HFMoeCase:
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hf_moe_bundle(hf_moe_case: _HFMoeCase) -> _HFMoeBundle:
|
||||
case = hf_moe_case
|
||||
device = torch.device("cuda")
|
||||
dtype = _model_dtype(device)
|
||||
|
||||
config = AutoConfig.from_pretrained(case.model_id)
|
||||
config.use_cache = False
|
||||
config.output_router_logits = True
|
||||
# Keep attention fixed so this test isolates experts backends.
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
model = (
|
||||
AutoModelForCausalLM.from_pretrained(
|
||||
case.model_id,
|
||||
config=config,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
return _HFMoeBundle(case=case, model=model, device=device, dtype=dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"experts_implementation",
|
||||
_EXPERTS_IMPLEMENTATIONS,
|
||||
ids=list(_EXPERTS_IMPLEMENTATIONS),
|
||||
)
|
||||
def test_hf_causal_lm_experts_implementation_matches_eager(
|
||||
hf_moe_bundle: _HFMoeBundle, experts_implementation: str
|
||||
):
|
||||
model = hf_moe_bundle.model
|
||||
model.set_experts_implementation(experts_implementation)
|
||||
assert model.config._experts_implementation == experts_implementation
|
||||
|
||||
input_ids = torch.tensor(hf_moe_bundle.case.input_ids, device=hf_moe_bundle.device)
|
||||
kwargs = {
|
||||
"input_ids": input_ids,
|
||||
"use_cache": False,
|
||||
"output_router_logits": True,
|
||||
"logits_to_keep": 1,
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
eager_output = model(**kwargs)
|
||||
|
||||
compiled_model = torch.compile(model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
compiled_output = compiled_model(**kwargs)
|
||||
|
||||
atol = _output_tolerance(hf_moe_bundle.dtype)
|
||||
rtol = 1e-3
|
||||
|
||||
torch.testing.assert_close(
|
||||
compiled_output.logits,
|
||||
eager_output.logits,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
)
|
||||
_compare_router_logits(
|
||||
compiled_output.router_logits,
|
||||
eager_output.router_logits,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
)
|
||||
|
||||
if eager_output.aux_loss is not None or compiled_output.aux_loss is not None:
|
||||
assert eager_output.aux_loss is not None
|
||||
assert compiled_output.aux_loss is not None
|
||||
torch.testing.assert_close(
|
||||
compiled_output.aux_loss,
|
||||
eager_output.aux_loss,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
)
|
||||
242
crates/luminal_python/tests/test_hf_multimodal_generation.py
Normal file
242
crates/luminal_python/tests/test_hf_multimodal_generation.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""Hugging Face multimodal image-text-to-text smoke tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModelForImageTextToText, AutoProcessor
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
MODEL_ID = "google/gemma-3-4b-it"
|
||||
_CUDA_BACKEND_AVAILABLE = (
|
||||
os.getenv("LUMINAL_BACKEND", "native").lower() == "cuda"
|
||||
and torch.cuda.is_available()
|
||||
)
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.slow,
|
||||
pytest.mark.skipif(
|
||||
not _CUDA_BACKEND_AVAILABLE,
|
||||
reason="Gemma 3 multimodal tests require the CUDA backend",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HFMultimodalCase:
|
||||
case_id: str
|
||||
messages_builder: Callable[[Path], list[dict]]
|
||||
max_new_tokens: int
|
||||
expects_pixel_values: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Gemma3MultimodalBundle:
|
||||
model: AutoModelForImageTextToText
|
||||
processor: AutoProcessor
|
||||
device: torch.device
|
||||
dtype: torch.dtype
|
||||
|
||||
|
||||
def _build_text_only_messages(_: Path) -> list[dict]:
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": "You are a concise assistant."}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "In one short sentence, explain what a compiler does.",
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _build_image_to_text_messages(image_path: Path) -> list[dict]:
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": "You are a helpful assistant."}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "path": str(image_path)},
|
||||
{"type": "text", "text": "Describe this image in one short sentence."},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
MULTIMODAL_CASES = [
|
||||
HFMultimodalCase(
|
||||
case_id="chat_text_only",
|
||||
messages_builder=_build_text_only_messages,
|
||||
max_new_tokens=12,
|
||||
expects_pixel_values=False,
|
||||
),
|
||||
HFMultimodalCase(
|
||||
case_id="image_to_text",
|
||||
messages_builder=_build_image_to_text_messages,
|
||||
max_new_tokens=16,
|
||||
expects_pixel_values=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _model_dtype(device: torch.device) -> torch.dtype:
|
||||
if device.type != "cuda":
|
||||
return torch.float32
|
||||
return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||
|
||||
|
||||
def _set_greedy_generation(model) -> None:
|
||||
model.generation_config.temperature = None
|
||||
model.generation_config.top_p = None
|
||||
model.generation_config.top_k = None
|
||||
|
||||
|
||||
def _move_to_device(
|
||||
encoded: dict[str, torch.Tensor], device: torch.device, dtype: torch.dtype
|
||||
) -> dict[str, torch.Tensor]:
|
||||
result = {}
|
||||
for key, value in encoded.items():
|
||||
if not isinstance(value, torch.Tensor):
|
||||
result[key] = value
|
||||
continue
|
||||
|
||||
moved = value.to(device)
|
||||
if moved.is_floating_point():
|
||||
moved = moved.to(dtype=dtype)
|
||||
result[key] = moved
|
||||
return result
|
||||
|
||||
|
||||
def _encode_case(
|
||||
bundle: Gemma3MultimodalBundle,
|
||||
case: HFMultimodalCase,
|
||||
image_path: Path,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
encoded = bundle.processor.apply_chat_template(
|
||||
case.messages_builder(image_path),
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
encoded = _move_to_device(dict(encoded), bundle.device, bundle.dtype)
|
||||
|
||||
if case.expects_pixel_values:
|
||||
assert "pixel_values" in encoded
|
||||
assert "input_ids" in encoded
|
||||
return encoded
|
||||
|
||||
|
||||
def _generate_kwargs(
|
||||
bundle: Gemma3MultimodalBundle,
|
||||
encoded: dict[str, torch.Tensor],
|
||||
max_new_tokens: int,
|
||||
) -> dict:
|
||||
tokenizer = bundle.processor.tokenizer
|
||||
return dict(
|
||||
**encoded,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=False,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
|
||||
|
||||
def _logits_tolerance(dtype: torch.dtype) -> float:
|
||||
if dtype == torch.bfloat16:
|
||||
return 5e-2
|
||||
if dtype == torch.float16:
|
||||
return 1e-2
|
||||
return 1e-3
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def gemma3_multimodal_bundle() -> Gemma3MultimodalBundle:
|
||||
device = torch.device("cuda")
|
||||
dtype = _model_dtype(device)
|
||||
|
||||
config = AutoConfig.from_pretrained(MODEL_ID)
|
||||
processor = AutoProcessor.from_pretrained(MODEL_ID)
|
||||
tokenizer = processor.tokenizer
|
||||
if tokenizer.pad_token_id is None and tokenizer.eos_token is not None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = (
|
||||
AutoModelForImageTextToText.from_pretrained(
|
||||
MODEL_ID,
|
||||
config=config,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
_set_greedy_generation(model)
|
||||
return Gemma3MultimodalBundle(
|
||||
model=model, processor=processor, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
|
||||
class TestHFMultimodalGeneration:
|
||||
@pytest.mark.parametrize("case", MULTIMODAL_CASES, ids=lambda case: case.case_id)
|
||||
def test_generate_matches_eager(
|
||||
self,
|
||||
case: HFMultimodalCase,
|
||||
gemma3_multimodal_bundle: Gemma3MultimodalBundle,
|
||||
hf_multimodal_image_path: Path,
|
||||
):
|
||||
encoded = _encode_case(gemma3_multimodal_bundle, case, hf_multimodal_image_path)
|
||||
kwargs = _generate_kwargs(
|
||||
gemma3_multimodal_bundle, encoded, case.max_new_tokens
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_output = gemma3_multimodal_bundle.model.generate(**kwargs)
|
||||
|
||||
compiled_model = torch.compile(
|
||||
gemma3_multimodal_bundle.model, backend=luminal_backend
|
||||
)
|
||||
with torch.no_grad():
|
||||
compiled_output = compiled_model.generate(**kwargs)
|
||||
|
||||
torch.testing.assert_close(compiled_output, eager_output)
|
||||
|
||||
@pytest.mark.parametrize("case", MULTIMODAL_CASES, ids=lambda case: case.case_id)
|
||||
def test_forward_logits_match_eager(
|
||||
self,
|
||||
case: HFMultimodalCase,
|
||||
gemma3_multimodal_bundle: Gemma3MultimodalBundle,
|
||||
hf_multimodal_image_path: Path,
|
||||
):
|
||||
encoded = _encode_case(gemma3_multimodal_bundle, case, hf_multimodal_image_path)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_out = gemma3_multimodal_bundle.model(**encoded)
|
||||
|
||||
compiled_model = torch.compile(
|
||||
gemma3_multimodal_bundle.model, backend=luminal_backend
|
||||
)
|
||||
with torch.no_grad():
|
||||
compiled_out = compiled_model(**encoded)
|
||||
|
||||
atol = _logits_tolerance(gemma3_multimodal_bundle.dtype)
|
||||
torch.testing.assert_close(
|
||||
compiled_out.logits,
|
||||
eager_out.logits,
|
||||
atol=atol,
|
||||
rtol=1e-3,
|
||||
)
|
||||
288
crates/luminal_python/tests/test_hf_text_generation.py
Normal file
288
crates/luminal_python/tests/test_hf_text_generation.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""Hugging Face text-generation smoke tests.
|
||||
|
||||
These tests intentionally download real Hugging Face checkpoints, configs,
|
||||
and tokenizers. They compare eager PyTorch output against torch.compile
|
||||
with the luminal backend to verify numerical equivalence.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
SIMPLE_DENSE_MODELS = [
|
||||
"meta-llama/Llama-3.2-1B",
|
||||
"meta-llama/Llama-3.1-1B",
|
||||
"Qwen/Qwen3-8B",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HFTextGenerationBundle:
|
||||
model: AutoModelForCausalLM
|
||||
tokenizer: AutoTokenizer
|
||||
device: torch.device
|
||||
|
||||
|
||||
def _load_model_and_tokenizer(
|
||||
model_id: str, device: torch.device
|
||||
) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
|
||||
"""Load a pretrained HF causal LM and its tokenizer, ready for generation."""
|
||||
config = AutoConfig.from_pretrained(model_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
if tokenizer.pad_token_id is None and tokenizer.eos_token is not None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
dtype = torch.float16 if device.type == "cuda" else torch.float32
|
||||
model = (
|
||||
AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
config=config,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
|
||||
model.generation_config.temperature = None
|
||||
model.generation_config.top_p = None
|
||||
model.generation_config.top_k = None
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def _encode(
|
||||
tokenizer: AutoTokenizer, prompt: str, device: torch.device
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Tokenize a prompt and move tensors to device."""
|
||||
encoded = tokenizer(prompt, return_tensors="pt")
|
||||
result = {"input_ids": encoded["input_ids"].to(device)}
|
||||
if encoded.get("attention_mask") is not None:
|
||||
result["attention_mask"] = encoded["attention_mask"].to(device)
|
||||
return result
|
||||
|
||||
|
||||
def _generate_kwargs(
|
||||
tokenizer: AutoTokenizer,
|
||||
encoded: dict[str, torch.Tensor],
|
||||
max_new_tokens: int = 6,
|
||||
) -> dict:
|
||||
"""Build kwargs dict for model.generate()."""
|
||||
return dict(
|
||||
**encoded,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=False,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hf_text_bundle(model_id: str, device: torch.device) -> HFTextGenerationBundle:
|
||||
model, tokenizer = _load_model_and_tokenizer(model_id, device)
|
||||
return HFTextGenerationBundle(model=model, tokenizer=tokenizer, device=device)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestHFGeneration:
|
||||
"""End-to-end tests comparing eager PyTorch against torch.compile with luminal."""
|
||||
|
||||
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS)
|
||||
def test_capital_of_france(self, hf_text_bundle: HFTextGenerationBundle):
|
||||
"""Basic greedy generation -- the original smoke test."""
|
||||
encoded = _encode(
|
||||
hf_text_bundle.tokenizer,
|
||||
"What is the capital of France ",
|
||||
hf_text_bundle.device,
|
||||
)
|
||||
kwargs = _generate_kwargs(hf_text_bundle.tokenizer, encoded)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_output = hf_text_bundle.model.generate(**kwargs)
|
||||
|
||||
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
compiled_output = compiled_model.generate(**kwargs)
|
||||
|
||||
torch.testing.assert_close(compiled_output, eager_output)
|
||||
|
||||
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS)
|
||||
def test_forward_logits(self, hf_text_bundle: HFTextGenerationBundle):
|
||||
"""Forward pass only -- compare raw logits, not generated tokens."""
|
||||
encoded = _encode(
|
||||
hf_text_bundle.tokenizer, "The quick brown fox", hf_text_bundle.device
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_out = hf_text_bundle.model(**encoded)
|
||||
|
||||
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
compiled_out = compiled_model(**encoded)
|
||||
|
||||
dtype = next(hf_text_bundle.model.parameters()).dtype
|
||||
atol = 1e-2 if dtype == torch.float16 else 1e-3
|
||||
torch.testing.assert_close(
|
||||
compiled_out.logits, eager_out.logits, atol=atol, rtol=1e-3
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
|
||||
@pytest.mark.parametrize(
|
||||
"prompt",
|
||||
[
|
||||
"Hi",
|
||||
"What is the capital of France",
|
||||
"Explain the theory of general relativity in simple terms that a high school student could understand",
|
||||
],
|
||||
ids=["short", "medium", "long"],
|
||||
)
|
||||
def test_variable_length_prompts(
|
||||
self,
|
||||
prompt: str,
|
||||
hf_text_bundle: HFTextGenerationBundle,
|
||||
):
|
||||
"""Generate with prompts of different lengths -- tests dynamic shape handling."""
|
||||
encoded = _encode(hf_text_bundle.tokenizer, prompt, hf_text_bundle.device)
|
||||
kwargs = _generate_kwargs(hf_text_bundle.tokenizer, encoded, max_new_tokens=4)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_output = hf_text_bundle.model.generate(**kwargs)
|
||||
|
||||
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
compiled_output = compiled_model.generate(**kwargs)
|
||||
|
||||
torch.testing.assert_close(compiled_output, eager_output)
|
||||
|
||||
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
|
||||
def test_chat_template_generation(
|
||||
self,
|
||||
model_id: str,
|
||||
hf_text_bundle: HFTextGenerationBundle,
|
||||
):
|
||||
"""Generate using chat-templated input with special tokens."""
|
||||
if hf_text_bundle.tokenizer.chat_template is None:
|
||||
pytest.skip(f"{model_id} has no chat template")
|
||||
|
||||
messages = [{"role": "user", "content": "What is 2+2?"}]
|
||||
encoded = hf_text_bundle.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
)
|
||||
encoded = {k: v.to(hf_text_bundle.device) for k, v in encoded.items()}
|
||||
kwargs = _generate_kwargs(hf_text_bundle.tokenizer, encoded)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_output = hf_text_bundle.model.generate(**kwargs)
|
||||
|
||||
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
compiled_output = compiled_model.generate(**kwargs)
|
||||
|
||||
torch.testing.assert_close(compiled_output, eager_output)
|
||||
|
||||
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
|
||||
@pytest.mark.parametrize("max_new_tokens", [20, 50])
|
||||
def test_longer_generation(
|
||||
self,
|
||||
max_new_tokens: int,
|
||||
hf_text_bundle: HFTextGenerationBundle,
|
||||
):
|
||||
"""Generate many tokens to stress KV cache over extended decode loop."""
|
||||
encoded = _encode(
|
||||
hf_text_bundle.tokenizer, "Once upon a time", hf_text_bundle.device
|
||||
)
|
||||
kwargs = _generate_kwargs(
|
||||
hf_text_bundle.tokenizer,
|
||||
encoded,
|
||||
max_new_tokens=max_new_tokens,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_output = hf_text_bundle.model.generate(**kwargs)
|
||||
|
||||
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
compiled_output = compiled_model.generate(**kwargs)
|
||||
|
||||
torch.testing.assert_close(compiled_output, eager_output)
|
||||
|
||||
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
|
||||
def test_greedy_determinism(
|
||||
self,
|
||||
hf_text_bundle: HFTextGenerationBundle,
|
||||
configure_dynamo,
|
||||
):
|
||||
"""Greedy generation produces identical results on repeated calls."""
|
||||
configure_dynamo(cache_size_limit=4)
|
||||
|
||||
encoded = _encode(
|
||||
hf_text_bundle.tokenizer,
|
||||
"The meaning of life is",
|
||||
hf_text_bundle.device,
|
||||
)
|
||||
kwargs = _generate_kwargs(hf_text_bundle.tokenizer, encoded, max_new_tokens=10)
|
||||
|
||||
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
output_1 = compiled_model.generate(**kwargs)
|
||||
output_2 = compiled_model.generate(**kwargs)
|
||||
|
||||
torch.testing.assert_close(output_1, output_2)
|
||||
|
||||
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
|
||||
def test_reuse_compiled_model(
|
||||
self,
|
||||
hf_text_bundle: HFTextGenerationBundle,
|
||||
configure_dynamo,
|
||||
):
|
||||
"""Call the same compiled model multiple times with different prompts."""
|
||||
configure_dynamo(cache_size_limit=8)
|
||||
|
||||
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
|
||||
|
||||
prompts = [
|
||||
"The capital of France is",
|
||||
"Water boils at",
|
||||
"The largest planet in our solar system is",
|
||||
]
|
||||
|
||||
for prompt in prompts:
|
||||
encoded = _encode(hf_text_bundle.tokenizer, prompt, hf_text_bundle.device)
|
||||
kwargs = _generate_kwargs(
|
||||
hf_text_bundle.tokenizer, encoded, max_new_tokens=4
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_output = hf_text_bundle.model.generate(**kwargs)
|
||||
compiled_output = compiled_model.generate(**kwargs)
|
||||
|
||||
torch.testing.assert_close(compiled_output, eager_output)
|
||||
|
||||
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
|
||||
def test_batched_inference(self, hf_text_bundle: HFTextGenerationBundle):
|
||||
"""Batched generation with multiple prompts and left-padding."""
|
||||
hf_text_bundle.tokenizer.padding_side = "left"
|
||||
|
||||
prompts = ["Hello", "What is the capital of France"]
|
||||
encoded = hf_text_bundle.tokenizer(prompts, return_tensors="pt", padding=True)
|
||||
encoded = {k: v.to(hf_text_bundle.device) for k, v in encoded.items()}
|
||||
kwargs = _generate_kwargs(hf_text_bundle.tokenizer, encoded, max_new_tokens=4)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_output = hf_text_bundle.model.generate(**kwargs)
|
||||
|
||||
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
compiled_output = compiled_model.generate(**kwargs)
|
||||
|
||||
torch.testing.assert_close(compiled_output, eager_output)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -224,127 +224,15 @@ def test_hf_llama_decode_loop_static(device: torch.device):
|
||||
tokens.append(next_token)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skip(reason="This is currently failing and in development")
|
||||
def test_hf_llama3_1b_decode_loop_dynamic():
|
||||
"""Decode loop with dynamic shapes on real Llama3.2-1B — compile once, run with varying seq_len.
|
||||
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
|
||||
def test_hf_llama3_1b_decode_loop_dynamic(device: torch.device):
|
||||
"""Decode loop on real Llama3.2-1B with pretrained weights.
|
||||
|
||||
This is the end-goal test: full 1B model with pretrained weights, CUDA backend,
|
||||
ONNX exported once with dynamic_axes for seq_len, then decoded autoregressively
|
||||
without recompilation.
|
||||
|
||||
Supports both ONNX and PT2 export modes via LUMINAL_EXPORT_MODE env var.
|
||||
Recompiles each step as sequence length grows, using the standard
|
||||
torch.compile(model, backend=luminal_backend) pattern.
|
||||
"""
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from transformers import AutoConfig, AutoTokenizer, LlamaForCausalLM
|
||||
|
||||
import luminal
|
||||
|
||||
backend = os.environ.get("LUMINAL_BACKEND", "cuda")
|
||||
export_mode = os.getenv("LUMINAL_EXPORT_MODE", "onnx").lower()
|
||||
|
||||
config = AutoConfig.from_pretrained("NousResearch/Llama-3.2-1B")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
print("Loaded config")
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
"NousResearch/Llama-3.2-1B",
|
||||
config=config,
|
||||
torch_dtype=torch.float32,
|
||||
).eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
|
||||
print("Loaded Model")
|
||||
|
||||
prompt = "The capital of france is"
|
||||
tokens = tokenizer.encode(prompt)
|
||||
print(f"Prompt: '{prompt}' -> {len(tokens)} tokens: {tokens}")
|
||||
num_generate = 3
|
||||
|
||||
if export_mode == "pt2":
|
||||
from luminal.pt2 import compile as luminal_compile
|
||||
|
||||
dummy = torch.tensor([[1, 2, 3, 4]])
|
||||
compiled = luminal_compile(model, dummy, search_iterations=0, dynamic_dim=1)
|
||||
|
||||
for step in range(num_generate):
|
||||
input_ids = torch.tensor([tokens])
|
||||
logits = compiled(input_ids)[0]
|
||||
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
|
||||
assert torch.allclose(logits, ref.logits, atol=1e-3), (
|
||||
f"step {step}: max_diff={torch.max(torch.abs(logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
next_token = ref.logits[0, -1, :].argmax().item()
|
||||
tokens.append(next_token)
|
||||
print(f"Step {step}: '{tokenizer.decode(tokens)}'")
|
||||
else:
|
||||
# ONNX path — manual export with dynamic_axes
|
||||
dummy = torch.tensor([[1, 2, 3, 4]])
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
|
||||
try:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(dummy,),
|
||||
tmp_path,
|
||||
opset_version=20,
|
||||
input_names=["input_ids"],
|
||||
output_names=["logits"],
|
||||
dynamic_axes={"input_ids": {1: "seq_len"}, "logits": {1: "seq_len"}},
|
||||
)
|
||||
print("Exported onnx")
|
||||
graph = luminal.process_onnx(tmp_path, backend)
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
print("Exported Model")
|
||||
assert graph.has_dynamic_dims, "Graph should have dynamic dims"
|
||||
assert "seq_len" in graph.dim_params, (
|
||||
f"Expected 'seq_len' in {graph.dim_params}"
|
||||
)
|
||||
|
||||
for step in range(num_generate):
|
||||
seq_len = len(tokens)
|
||||
graph.set_dim("seq_len", seq_len)
|
||||
|
||||
graph.set_input("input_ids", [float(t) for t in tokens])
|
||||
graph.run()
|
||||
|
||||
output_shapes = graph.resolve_output_shapes()
|
||||
logits_data = graph.get_output("logits")
|
||||
logits = torch.tensor(logits_data, dtype=torch.float32).reshape(
|
||||
output_shapes[0]
|
||||
)
|
||||
|
||||
# Compare against PyTorch reference
|
||||
input_ids = torch.tensor([tokens])
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
|
||||
assert torch.allclose(logits, ref.logits, atol=1e-3), (
|
||||
f"step {step}: max_diff={torch.max(torch.abs(logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
next_token = ref.logits[0, -1, :].argmax().item()
|
||||
tokens.append(next_token)
|
||||
print(f"Step {step}: '{tokenizer.decode(tokens)}'")
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_hf_llama3_full(device: torch.device):
|
||||
"""HuggingFace LlamaForCausalLM — full Llama3.2-1B with real pretrained weights.
|
||||
|
||||
No config alterations except use_cache=False and eager attention.
|
||||
Loads actual weights from NousResearch/Llama-3.2-1B.
|
||||
"""
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
config = AutoConfig.from_pretrained("NousResearch/Llama-3.2-1B")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
@@ -358,18 +246,41 @@ def test_hf_llama3_full(device: torch.device):
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]], device=device)
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
out = compiled(input_ids)
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-3), (
|
||||
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
|
||||
|
||||
prompt = "The capital of france is"
|
||||
tokens = tokenizer.encode(prompt)
|
||||
print(f"Prompt: '{prompt}' -> {len(tokens)} tokens: {tokens}")
|
||||
num_generate = 3
|
||||
|
||||
for step in range(num_generate):
|
||||
input_ids = torch.tensor([tokens], device=device)
|
||||
torch._dynamo.reset()
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
out = compiled(input_ids)
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-5), (
|
||||
f"step {step}: max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
next_token = ref.logits[0, -1, :].argmax().item()
|
||||
tokens.append(next_token)
|
||||
print(f"Step {step}: '{tokenizer.decode(tokens)}'")
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_hf_llama38b_full(device: torch.device):
|
||||
def _gpu_mem(label):
|
||||
"""Print GPU memory stats at a given checkpoint."""
|
||||
if torch.cuda.is_available():
|
||||
alloc = torch.cuda.memory_allocated() / (1024**3)
|
||||
reserved = torch.cuda.memory_reserved() / (1024**3)
|
||||
peak = torch.cuda.max_memory_allocated() / (1024**3)
|
||||
print(
|
||||
f"[GPU MEM] {label}: allocated={alloc:.3f} GiB, reserved={reserved:.3f} GiB, peak={peak:.3f} GiB"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
|
||||
def test_hf_llama3_full(device: torch.device):
|
||||
"""HuggingFace LlamaForCausalLM — full Llama3.2-1B with real pretrained weights.
|
||||
|
||||
No config alterations except use_cache=False and eager attention.
|
||||
@@ -377,6 +288,57 @@ def test_hf_llama38b_full(device: torch.device):
|
||||
"""
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
_gpu_mem("before model load")
|
||||
|
||||
config = AutoConfig.from_pretrained("NousResearch/Llama-3.2-1B")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
model = (
|
||||
LlamaForCausalLM.from_pretrained(
|
||||
"NousResearch/Llama-3.2-1B",
|
||||
config=config,
|
||||
torch_dtype=torch.float32,
|
||||
)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
n_params = sum(p.numel() for p in model.parameters())
|
||||
print(
|
||||
f"[MODEL] Total parameters: {n_params:,} ({n_params * 4 / 1024**3:.3f} GiB in f32)"
|
||||
)
|
||||
_gpu_mem("after model load")
|
||||
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
_gpu_mem("after torch.compile (lazy, no compilation yet)")
|
||||
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]], device=device)
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
_gpu_mem("after PyTorch reference forward")
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
_gpu_mem("before compiled forward (peak reset)")
|
||||
out = compiled(input_ids)
|
||||
_gpu_mem("after compiled forward (includes compilation)")
|
||||
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-5), (
|
||||
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
|
||||
def test_hf_llama3_large_full(device: torch.device):
|
||||
"""HuggingFace LlamaForCausalLM — full Llama-3.1-8B-Instruct with real pretrained weights.
|
||||
|
||||
No config alterations except use_cache=False and eager attention.
|
||||
Loads actual weights from NousResearch/Meta-Llama-3.1-8B-Instruct.
|
||||
"""
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
config = AutoConfig.from_pretrained("NousResearch/Meta-Llama-3.1-8B-Instruct")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
@@ -395,79 +357,95 @@ def test_hf_llama38b_full(device: torch.device):
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
out = compiled(input_ids)
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-3), (
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-5), (
|
||||
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
|
||||
def test_hf_llama38b_full(device: torch.device):
|
||||
"""HuggingFace LlamaForCausalLM — full Llama-3.1-8B-Instruct with real pretrained weights.
|
||||
|
||||
No config alterations except use_cache=False and eager attention.
|
||||
Loads actual weights from NousResearch/Meta-Llama-3.1-8B-Instruct.
|
||||
"""
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
config = AutoConfig.from_pretrained("NousResearch/Meta-Llama-3.1-8B-Instruct")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
model = (
|
||||
LlamaForCausalLM.from_pretrained(
|
||||
"NousResearch/Meta-Llama-3.1-8B-Instruct",
|
||||
config=config,
|
||||
torch_dtype=torch.float32,
|
||||
)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]], device=device)
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
out = compiled(input_ids)
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-5), (
|
||||
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_hf_llama38b_cached():
|
||||
"""Llama 3.1-8B via pre-generated artifacts + reference logits.
|
||||
|
||||
Supports both ONNX and PT2 export modes via LUMINAL_EXPORT_MODE env var.
|
||||
|
||||
Requires artifacts generated by:
|
||||
ONNX: uv run python tests/generate_llama38b_artifacts.py
|
||||
PT2: uv run python tests/generate_llama38b_pt2_artifacts.py
|
||||
"""
|
||||
def test_hf_llama38b_cached_onnx(
|
||||
llama38b_onnx_path, llama38b_ref_logits: torch.Tensor
|
||||
):
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import luminal
|
||||
|
||||
backend = os.environ.get("LUMINAL_BACKEND", "cuda")
|
||||
export_mode = os.getenv("LUMINAL_EXPORT_MODE", "onnx").lower()
|
||||
|
||||
tests_dir = Path(__file__).resolve().parent
|
||||
logits_path = tests_dir / "llama38b_ref_logits.pt"
|
||||
graph = luminal.process_onnx(str(llama38b_onnx_path), backend)
|
||||
print("Compiled luminal ONNX graph")
|
||||
|
||||
assert logits_path.exists(), (
|
||||
f"{logits_path} not found. Run: uv run python tests/generate_llama38b_artifacts.py"
|
||||
)
|
||||
ref_logits = torch.load(logits_path, weights_only=True)
|
||||
print(f"Loaded reference logits: {ref_logits.shape}")
|
||||
graph.set_input("input_ids", [float(t) for t in [1, 2, 3, 4]])
|
||||
graph.run()
|
||||
|
||||
if export_mode == "pt2":
|
||||
from luminal import CompiledModel
|
||||
|
||||
pt2_path = tests_dir / "llama38b.pt2"
|
||||
weights_path = tests_dir / "llama38b_weights.safetensors"
|
||||
|
||||
assert pt2_path.exists(), (
|
||||
f"{pt2_path} not found. Run: uv run python tests/generate_llama38b_pt2_artifacts.py"
|
||||
)
|
||||
assert weights_path.exists(), (
|
||||
f"{weights_path} not found. Run: uv run python tests/generate_llama38b_pt2_artifacts.py"
|
||||
)
|
||||
|
||||
backend_name = "cuda" if backend == "cuda" else "cpu"
|
||||
compiled_inner = luminal.compile_pt2(
|
||||
str(pt2_path), str(weights_path), backend_name, 0
|
||||
)
|
||||
compiled = CompiledModel(compiled_inner)
|
||||
print("Compiled luminal PT2 graph")
|
||||
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]])
|
||||
logits = compiled(input_ids)[0]
|
||||
else:
|
||||
onnx_path = tests_dir / "llama38b.onnx"
|
||||
|
||||
assert onnx_path.exists(), (
|
||||
f"{onnx_path} not found. Run: uv run python tests/generate_llama38b_artifacts.py"
|
||||
)
|
||||
|
||||
graph = luminal.process_onnx(str(onnx_path), backend)
|
||||
print("Compiled luminal ONNX graph")
|
||||
|
||||
graph.set_input("input_ids", [float(t) for t in [1, 2, 3, 4]])
|
||||
graph.run()
|
||||
|
||||
logits_data = graph.get_output("logits")
|
||||
logits_shape = graph.output_shapes[0]
|
||||
logits = torch.tensor(logits_data, dtype=torch.float32).reshape(logits_shape)
|
||||
logits_data = graph.get_output("logits")
|
||||
logits_shape = graph.output_shapes[0]
|
||||
logits = torch.tensor(logits_data, dtype=torch.float32).reshape(logits_shape)
|
||||
|
||||
print(f"Loaded reference logits: {llama38b_ref_logits.shape}")
|
||||
print(f"Output logits shape: {logits.shape}")
|
||||
|
||||
assert torch.allclose(logits, ref_logits, atol=1e-3), (
|
||||
f"max_diff={torch.max(torch.abs(logits - ref_logits)).item():.2e}"
|
||||
assert torch.allclose(logits, llama38b_ref_logits, atol=1e-3), (
|
||||
f"max_diff={torch.max(torch.abs(logits - llama38b_ref_logits)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_hf_llama38b_cached_pt2(
|
||||
llama38b_pt2_path, llama38b_weights_path, llama38b_ref_logits: torch.Tensor
|
||||
):
|
||||
import os
|
||||
|
||||
import luminal
|
||||
from luminal import CompiledModel
|
||||
|
||||
backend = os.environ.get("LUMINAL_BACKEND", "cuda")
|
||||
backend_name = "cuda" if backend == "cuda" else "cpu"
|
||||
|
||||
compiled_inner = luminal.compile_pt2(
|
||||
str(llama38b_pt2_path), str(llama38b_weights_path), backend_name, 0
|
||||
)
|
||||
compiled = CompiledModel(compiled_inner)
|
||||
print("Compiled luminal PT2 graph")
|
||||
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]])
|
||||
logits = compiled(input_ids)[0]
|
||||
|
||||
print(f"Loaded reference logits: {llama38b_ref_logits.shape}")
|
||||
print(f"Output logits shape: {logits.shape}")
|
||||
|
||||
assert torch.allclose(logits, llama38b_ref_logits, atol=1e-3), (
|
||||
f"max_diff={torch.max(torch.abs(logits - llama38b_ref_logits)).item():.2e}"
|
||||
)
|
||||
|
||||
@@ -41,9 +41,6 @@ class AddAddTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class AddConstantTestModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x + 10
|
||||
|
||||
@@ -59,25 +56,16 @@ class LinearLayerModel(torch.nn.Module):
|
||||
|
||||
|
||||
class SqrtTestModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.sqrt()
|
||||
|
||||
|
||||
class SinTestModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.sin(x)
|
||||
|
||||
|
||||
class CosTestModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.cos(x)
|
||||
|
||||
@@ -94,9 +82,6 @@ class SubTestModel(torch.nn.Module):
|
||||
class TransposeTestModel(torch.nn.Module):
|
||||
"""Test basic 2D transpose (matrix transpose)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.t() # 2D transpose
|
||||
|
||||
@@ -104,9 +89,6 @@ class TransposeTestModel(torch.nn.Module):
|
||||
class Transpose3DTestModel(torch.nn.Module):
|
||||
"""Test 3D transpose with explicit permutation."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.permute(2, 0, 1) # Rotate dimensions
|
||||
|
||||
@@ -114,9 +96,6 @@ class Transpose3DTestModel(torch.nn.Module):
|
||||
class Transpose4DTestModel(torch.nn.Module):
|
||||
"""Test 4D transpose (NCHW -> NHWC)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.permute(0, 2, 3, 1) # Common in CNNs
|
||||
|
||||
@@ -124,9 +103,6 @@ class Transpose4DTestModel(torch.nn.Module):
|
||||
class TransposeReverseTestModel(torch.nn.Module):
|
||||
"""Test reverse permutation (default transpose behavior)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
dims = list(range(x.ndim))
|
||||
return x.permute(*reversed(dims))
|
||||
@@ -151,9 +127,6 @@ class TransposeInExpressionModel(torch.nn.Module):
|
||||
class ConstantScalarFloatModel(torch.nn.Module):
|
||||
"""Test scalar constant (broadcasts to input shape)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor(10.5).to(x.device)
|
||||
return x + constant
|
||||
@@ -162,9 +135,6 @@ class ConstantScalarFloatModel(torch.nn.Module):
|
||||
class Constant1DArrayFloatModel(torch.nn.Module):
|
||||
"""Test 1D array constant."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]).to(x.device)
|
||||
return x * constant
|
||||
@@ -173,9 +143,6 @@ class Constant1DArrayFloatModel(torch.nn.Module):
|
||||
class Constant2DMatrixFloatModel(torch.nn.Module):
|
||||
"""Test 2D matrix constant."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]).to(x.device)
|
||||
return x + constant
|
||||
@@ -184,9 +151,6 @@ class Constant2DMatrixFloatModel(torch.nn.Module):
|
||||
class ConstantRawDataFloatModel(torch.nn.Module):
|
||||
"""Test constant with specific values (tests raw data format)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([7.5, 8.5, 9.5]).to(x.device)
|
||||
return x + constant
|
||||
@@ -195,9 +159,6 @@ class ConstantRawDataFloatModel(torch.nn.Module):
|
||||
class ConstantInt32ConversionModel(torch.nn.Module):
|
||||
"""Test INT32 constant values (PyTorch exports as integers)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32).to(x.device)
|
||||
return x + constant.float()
|
||||
@@ -206,9 +167,6 @@ class ConstantInt32ConversionModel(torch.nn.Module):
|
||||
class ConstantInt64ConversionModel(torch.nn.Module):
|
||||
"""Test INT64 constant values."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([100, 200, 300], dtype=torch.int64).to(x.device)
|
||||
return x * constant.float()
|
||||
@@ -217,9 +175,6 @@ class ConstantInt64ConversionModel(torch.nn.Module):
|
||||
class ConstantFloat64ConversionModel(torch.nn.Module):
|
||||
"""Test FLOAT64 (double) constant values."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([1.5, 2.5, 3.5], dtype=torch.float64).to(x.device)
|
||||
return x * constant.float()
|
||||
@@ -228,9 +183,6 @@ class ConstantFloat64ConversionModel(torch.nn.Module):
|
||||
class ConstantBoolConversionModel(torch.nn.Module):
|
||||
"""Test boolean constant values (converted to 0.0/1.0)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([True, False, True, False, True], dtype=torch.bool).to(
|
||||
x.device
|
||||
@@ -241,9 +193,6 @@ class ConstantBoolConversionModel(torch.nn.Module):
|
||||
class ConstantInt64RawDataModel(torch.nn.Module):
|
||||
"""Test INT64 constant with large values (tests raw data path)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([1000, 2000, 3000], dtype=torch.int64).to(x.device)
|
||||
return x + constant.float()
|
||||
@@ -252,9 +201,6 @@ class ConstantInt64RawDataModel(torch.nn.Module):
|
||||
class ConstantNegativeValuesModel(torch.nn.Module):
|
||||
"""Test negative constant values."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([-5.0, -10.0, -15.0]).to(x.device)
|
||||
return x + constant
|
||||
@@ -263,9 +209,6 @@ class ConstantNegativeValuesModel(torch.nn.Module):
|
||||
class ConstantZeroValueModel(torch.nn.Module):
|
||||
"""Test all-zero constant."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([0.0, 0.0, 0.0, 0.0]).to(x.device)
|
||||
return x * constant
|
||||
@@ -274,9 +217,6 @@ class ConstantZeroValueModel(torch.nn.Module):
|
||||
class ConstantMultipleInGraphModel(torch.nn.Module):
|
||||
"""Test multiple constants in one graph."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
const1 = torch.tensor([10.0, 20.0, 30.0]).to(x.device)
|
||||
const2 = torch.tensor([1.0, 2.0, 3.0]).to(x.device)
|
||||
@@ -290,9 +230,6 @@ class ConstantMultipleInGraphModel(torch.nn.Module):
|
||||
class CastDoubleToFloatModel(torch.nn.Module):
|
||||
"""Test downcast: Double (FLOAT64) -> Float."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# Input will be float64, cast to float32
|
||||
return x.to(torch.float32)
|
||||
@@ -301,9 +238,6 @@ class CastDoubleToFloatModel(torch.nn.Module):
|
||||
class CastInt32ToFloatModel(torch.nn.Module):
|
||||
"""Test INT32 -> Float conversion."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
|
||||
@@ -311,9 +245,6 @@ class CastInt32ToFloatModel(torch.nn.Module):
|
||||
class CastInt64ToFloatModel(torch.nn.Module):
|
||||
"""Test INT64 -> Float conversion."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
|
||||
@@ -321,9 +252,6 @@ class CastInt64ToFloatModel(torch.nn.Module):
|
||||
class CastBoolToFloatModel(torch.nn.Module):
|
||||
"""Test BOOL -> Float conversion (non-zero -> 1.0, zero -> 0.0)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
|
||||
@@ -331,9 +259,6 @@ class CastBoolToFloatModel(torch.nn.Module):
|
||||
class CastInComputationGraphModel(torch.nn.Module):
|
||||
"""Test Cast node followed by an operation (Cast + Add)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
casted = x.to(torch.float32)
|
||||
constant = torch.tensor([2.0, 2.0, 2.0]).to(x.device)
|
||||
@@ -343,9 +268,6 @@ class CastInComputationGraphModel(torch.nn.Module):
|
||||
class CastWith2DTensorModel(torch.nn.Module):
|
||||
"""Test Cast with 2D tensor (matrix)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
|
||||
@@ -353,9 +275,6 @@ class CastWith2DTensorModel(torch.nn.Module):
|
||||
class CastNegativeValuesModel(torch.nn.Module):
|
||||
"""Test Cast with negative integer values."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
|
||||
@@ -363,9 +282,6 @@ class CastNegativeValuesModel(torch.nn.Module):
|
||||
class CastScalarValueModel(torch.nn.Module):
|
||||
"""Test Cast with scalar (single element)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
|
||||
@@ -389,9 +305,6 @@ class ModTestModel(torch.nn.Module):
|
||||
class ModByConstantModel(torch.nn.Module):
|
||||
"""Tests modulo with an inline constant tensor (ONNX Constant node)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([3.0, 4.0, 5.0]).to(x.device)
|
||||
return x.fmod(constant)
|
||||
|
||||
@@ -1,154 +1,118 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from test_models import (
|
||||
SigmoidTestModel,
|
||||
SigmoidInExpressionModel,
|
||||
TanhTestModel,
|
||||
TanhInExpressionModel,
|
||||
ReluTestModel,
|
||||
ReluAllNegativeModel,
|
||||
ReluInExpressionModel,
|
||||
AbsTestModel,
|
||||
AbsAllNegativeModel,
|
||||
AbsInExpressionModel,
|
||||
NegTestModel,
|
||||
NegAllPositiveModel,
|
||||
NegInExpressionModel,
|
||||
ClipTestModel,
|
||||
ClipMinOnlyTestModel,
|
||||
ClipMaxOnlyTestModel,
|
||||
)
|
||||
import test_models as tm
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
# ── Sigmoid ──────────────────────────────────────────────────────────────────
|
||||
|
||||
Args = tuple[torch.Tensor, ...]
|
||||
Kwargs = dict[str, torch.Tensor]
|
||||
InputFactory = Callable[[torch.device], tuple[Args, Kwargs]]
|
||||
|
||||
|
||||
def test_sigmoid(device):
|
||||
model = SigmoidTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1 # mixed positive/negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
@dataclass(frozen=True)
|
||||
class UnaryCase:
|
||||
id: str
|
||||
model_factory: Callable[[], torch.nn.Module]
|
||||
input_factory: InputFactory
|
||||
|
||||
|
||||
def test_sigmoid_in_expression(device):
|
||||
model = SigmoidInExpressionModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device)
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
UNARY_CASES: list[UnaryCase] = [
|
||||
UnaryCase(
|
||||
id="sigmoid",
|
||||
model_factory=tm.SigmoidTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="sigmoid_in_expression",
|
||||
model_factory=tm.SigmoidInExpressionModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device),), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="tanh",
|
||||
model_factory=tm.TanhTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="tanh_in_expression",
|
||||
model_factory=tm.TanhInExpressionModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device),), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="relu",
|
||||
model_factory=tm.ReluTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="relu_all_negative",
|
||||
model_factory=tm.ReluAllNegativeModel,
|
||||
input_factory=lambda device: ((-torch.rand((5, 5), device=device),), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="relu_in_expression",
|
||||
model_factory=tm.ReluInExpressionModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="abs",
|
||||
model_factory=tm.AbsTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="abs_all_negative",
|
||||
model_factory=tm.AbsAllNegativeModel,
|
||||
input_factory=lambda device: ((-torch.rand((5, 5), device=device),), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="abs_in_expression",
|
||||
model_factory=tm.AbsInExpressionModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="neg",
|
||||
model_factory=tm.NegTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="neg_all_positive",
|
||||
model_factory=tm.NegAllPositiveModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device),), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="neg_in_expression",
|
||||
model_factory=tm.NegInExpressionModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="clip",
|
||||
model_factory=tm.ClipTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 4 - 2,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="clip_min_only",
|
||||
model_factory=tm.ClipMinOnlyTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 4 - 2,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="clip_max_only",
|
||||
model_factory=tm.ClipMaxOnlyTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 4 - 2,), {}),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# ── Tanh ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_tanh(device):
|
||||
model = TanhTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_tanh_in_expression(device):
|
||||
model = TanhInExpressionModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device)
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ── Relu ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_relu(device):
|
||||
model = ReluTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1 # mixed positive/negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_relu_all_negative(device):
|
||||
model = ReluAllNegativeModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = -torch.rand((5, 5), device=device) # all negative -> output all zeros
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_relu_in_expression(device):
|
||||
model = ReluInExpressionModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ── Abs ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_abs(device):
|
||||
model = AbsTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1 # mixed positive/negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_abs_all_negative(device):
|
||||
model = AbsAllNegativeModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = -torch.rand((5, 5), device=device) # all negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_abs_in_expression(device):
|
||||
model = AbsInExpressionModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ── Neg ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_neg(device):
|
||||
model = NegTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1 # mixed positive/negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_neg_all_positive(device):
|
||||
model = NegAllPositiveModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) # all positive -> output all negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_neg_in_expression(device):
|
||||
model = NegInExpressionModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ── Clip ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_clip(device):
|
||||
"""Clip tensor values to [-0.5, 0.5]."""
|
||||
model = ClipTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 4 - 2 # range [-2, 2]
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_clip_min_only(device):
|
||||
"""Clip tensor values to [0.0, +inf]."""
|
||||
model = ClipMinOnlyTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 4 - 2
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_clip_max_only(device):
|
||||
"""Clip tensor values to [-inf, 0.5]."""
|
||||
model = ClipMaxOnlyTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 4 - 2
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
class TestUnaryOps:
|
||||
@pytest.mark.parametrize("case", UNARY_CASES, ids=lambda case: case.id)
|
||||
def test_matches_eager(self, case: UnaryCase, device: torch.device) -> None:
|
||||
model = case.model_factory().to(device)
|
||||
compiled_model = torch.compile(model, backend=luminal_backend)
|
||||
args, kwargs = case.input_factory(device)
|
||||
torch.testing.assert_close(
|
||||
compiled_model(*args, **kwargs),
|
||||
model(*args, **kwargs),
|
||||
atol=1e-5,
|
||||
rtol=1e-5,
|
||||
)
|
||||
|
||||
20
skills-lock.json
Normal file
20
skills-lock.json
Normal file
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"version": 1,
|
||||
"skills": {
|
||||
"ruff": {
|
||||
"source": "astral-sh/claude-code-plugins",
|
||||
"sourceType": "github",
|
||||
"computedHash": "905f1c2b66e722ab534d7cbeceafb6ee37d32e3bfe51f3e68c661c51eb19a776"
|
||||
},
|
||||
"ty": {
|
||||
"source": "astral-sh/claude-code-plugins",
|
||||
"sourceType": "github",
|
||||
"computedHash": "5d82b319a84d296fae47f2664228bfac1a36dd832cb487b485a65ec36b3df21f"
|
||||
},
|
||||
"uv": {
|
||||
"source": "astral-sh/claude-code-plugins",
|
||||
"sourceType": "github",
|
||||
"computedHash": "b1ca9e9826d906f6ee6ea172da50018ec13bb622ac9ba7204bda8b10e7fc8d0a"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -254,6 +254,7 @@ impl Graph {
|
||||
if subgraphs.len() <= 1 {
|
||||
let (program, root) = hlir_to_egglog(self);
|
||||
self.egraphs = vec![run_egglog(&program, &root, &ops, cleanup_hlir).unwrap()];
|
||||
|
||||
self.chunk_groups = vec![ChunkGroup {
|
||||
representative: 0,
|
||||
members: vec![0],
|
||||
@@ -579,7 +580,6 @@ impl Graph {
|
||||
|
||||
for (group_idx, group) in self.chunk_groups.iter().enumerate() {
|
||||
let egraph = &self.egraphs[group_idx];
|
||||
|
||||
let mut prev_selected: FxHashSet<u64> = FxHashSet::default();
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
|
||||
155
src/hlir.rs
155
src/hlir.rs
@@ -149,6 +149,7 @@ pub type HLIROps = (
|
||||
Scatter,
|
||||
SumReduce,
|
||||
MaxReduce,
|
||||
Softmax,
|
||||
);
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -1836,6 +1837,160 @@ impl NativeOp for MaxReduce {
|
||||
}
|
||||
}
|
||||
|
||||
// Fused Softmax: softmax(x, axis) = exp(x - max(x)) / sum(exp(x - max(x)))
|
||||
// A single HLIR op that replaces the 6-op decomposed chain.
|
||||
// On CUDA, KernelSoftmax provides a fused 3-pass kernel.
|
||||
// On native, NativeOp implements softmax directly.
|
||||
#[derive(Debug, Clone, Default, PartialEq)]
|
||||
pub struct Softmax {
|
||||
pub axis: usize,
|
||||
pub input_shape: ShapeTracker,
|
||||
// Extracted fields (populated during egglog extraction, used by NativeOp)
|
||||
pub shape: Vec<Expression>,
|
||||
pub in_strides: Vec<Expression>,
|
||||
pub reduce_dim: Expression,
|
||||
pub reduce_stride: Expression,
|
||||
}
|
||||
impl Display for Softmax {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Softmax(axis={})", self.axis)
|
||||
}
|
||||
}
|
||||
|
||||
/// Sort for Softmax: (shape, in_strides, out_strides, reduce_dim, reduce_stride)
|
||||
pub fn softmax_sort(name: &str) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
name,
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("in_strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("reduce_dim", EXPRESSION),
|
||||
("reduce_stride", EXPRESSION),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
impl HLIROp for Softmax {
|
||||
fn to_egglog(&self, inputs: &[(NodeIndex, String)]) -> String {
|
||||
let reduce_dim = self.input_shape.dims[self.axis];
|
||||
let reduce_stride = self.input_shape.strides[self.axis];
|
||||
format!(
|
||||
"(Op (Softmax {} {} {} {} {}) {})",
|
||||
elist_to_egglog(&self.input_shape.dims),
|
||||
elist_to_egglog(&self.input_shape.strides),
|
||||
elist_to_egglog(&self.input_shape.contiguous().strides),
|
||||
reduce_dim.to_egglog(),
|
||||
reduce_stride.to_egglog(),
|
||||
ilist_egglog(&[&inputs[0].1]),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for Softmax {
|
||||
fn sort(&self) -> SortDef {
|
||||
softmax_sort("Softmax")
|
||||
}
|
||||
fn cleanup(&self) -> bool {
|
||||
true
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_propagation_op(&self.sort())]
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let shape = extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap();
|
||||
let in_strides =
|
||||
extract_expr_list(egraph, kind_children[1], list_cache, expr_cache).unwrap();
|
||||
let reduce_dim = extract_expr(egraph, kind_children[3], expr_cache).unwrap();
|
||||
let reduce_stride = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
|
||||
(
|
||||
LLIROp::new::<dyn NativeOp>(Box::new(Self {
|
||||
axis: 0,
|
||||
input_shape: ShapeTracker::default(),
|
||||
shape,
|
||||
in_strides,
|
||||
reduce_dim,
|
||||
reduce_stride,
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeOp for Softmax {
|
||||
fn execute(&self, inputs: Vec<&NativeData>, dyn_map: &FxHashMap<char, usize>) -> NativeData {
|
||||
match inputs[0] {
|
||||
NativeData::F32(a) => {
|
||||
// Use extracted fields (populated during egglog extraction)
|
||||
let dims: Vec<usize> = self
|
||||
.shape
|
||||
.iter()
|
||||
.map(|d| d.exec(dyn_map).unwrap())
|
||||
.collect();
|
||||
let n = self.reduce_dim.exec(dyn_map).unwrap();
|
||||
let mut reduce_stride_expr = self.reduce_stride;
|
||||
for (&var, &val) in dyn_map {
|
||||
reduce_stride_expr =
|
||||
reduce_stride_expr.substitute(var, Expression::from(val as i32));
|
||||
}
|
||||
|
||||
// Compute row index strides (all dims except last, since softmax is always last-dim)
|
||||
let ndim = dims.len();
|
||||
let out_size: usize = dims.iter().product();
|
||||
let mut out = vec![0.0f32; out_size];
|
||||
|
||||
// Use StridedIterator for the row dimensions
|
||||
let row_ind = StridedIterator::new(
|
||||
&self.shape[..ndim - 1],
|
||||
&self.in_strides[..ndim - 1],
|
||||
dyn_map,
|
||||
);
|
||||
|
||||
for (row_idx, in_base) in row_ind.enumerate() {
|
||||
// Pass 1: find max
|
||||
let mut max_val = f32::NEG_INFINITY;
|
||||
for i in 0..n {
|
||||
let val = a[in_base + reduce_stride_expr.exec_single_var(i)];
|
||||
if val > max_val {
|
||||
max_val = val;
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 2: exp(x - max) and sum
|
||||
let mut sum = 0.0f32;
|
||||
let out_base = row_idx * n;
|
||||
for i in 0..n {
|
||||
let val =
|
||||
(a[in_base + reduce_stride_expr.exec_single_var(i)] - max_val).exp();
|
||||
out[out_base + i] = val;
|
||||
sum += val;
|
||||
}
|
||||
|
||||
// Pass 3: normalize
|
||||
let inv_sum = 1.0 / sum;
|
||||
for i in 0..n {
|
||||
out[out_base + i] *= inv_sum;
|
||||
}
|
||||
}
|
||||
|
||||
NativeData::F32(out)
|
||||
}
|
||||
_ => panic!("Softmax only supports F32"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait NativeOp: Debug + AsAny + Send + Sync {
|
||||
fn execute(&self, inputs: Vec<&NativeData>, dyn_map: &FxHashMap<char, usize>) -> NativeData;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user