Compare commits

..

54 Commits

Author SHA1 Message Date
Austin Glover
7ee5b54438 idiomatic changes 2026-04-06 23:05:06 +00:00
Austin Glover
389c05abeb auto rebuild, silence onnx, cache images. 2026-04-06 22:43:24 +00:00
Austin Glover
dcc2c9cbb4 hf tests 2026-04-06 22:42:58 +00:00
Austin Glover
a9af4c3923 test deps 2026-04-06 22:42:27 +00:00
Austin Glover
3092d0d68b skill slop 2026-04-06 22:29:34 +00:00
Austin Glover
8a2bd714ac test classes wip 2026-04-02 01:32:37 +00:00
Austin Glover
54a26a044c save codex data on container restart 2026-04-02 01:31:23 +00:00
Austin Glover
5a0d3f87cc Merge remote-tracking branch 'origin/main' into pytest-classes
# Conflicts:
#	crates/luminal_python/pyproject.toml
#	crates/luminal_python/tests/conftest.py
#	crates/luminal_python/tests/generate_llama38b_artifacts.py
#	crates/luminal_python/tests/test_llama3.py
2026-04-02 01:18:40 +00:00
Joe Fioti
a28b755245 Merge pull request #259 from luminal-ai/tucker_shared_pytorch_memory 2026-04-01 12:51:15 -07:00
Tucker Morgan
fd83534e53 Remove dead logical.rs stub from luminal_cuda_lite
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-01 18:58:12 +00:00
Tucker Morgan
b5d984c3fa Move KernelExp/KernelSigmoid to other_ops.rs and remove logical intermediaries
hlir.rs should only contain 1:1 HLIR op analogues. KernelExp and KernelSigmoid
are fused kernels, so they belong in other_ops.rs. Also removed the redundant
logical::Exp and logical::Sigmoid intermediary ops since the kernel ops match
HLIR patterns directly via their direct-fusion egglog rules.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-01 18:46:17 +00:00
Tucker Morgan
64a5ca41b5 Merge remote-tracking branch 'origin/main' into tucker_shared_pytorch_memory 2026-04-01 16:45:16 +00:00
Joe Fioti
9bda47714a Merge pull request #256 from luminal-ai/asglover/modal_ci_ready 2026-04-01 05:21:02 -07:00
Austin Glover
9e513b6589 Fix git safe.directory for pre-commit in CUDA clippy container
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 19:32:40 -07:00
Austin Glover
a62d728bd7 Fix CUDA clippy container image to luminal-docker
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 19:21:36 -07:00
Austin Glover
4114714d3f Rename clippy workflow to cuda-clippy and fix container image
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 19:17:47 -07:00
Austin Glover
6191597571 Remove Modal CUDA clippy job, now handled by T4 runner
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 17:10:17 -07:00
Austin Glover
253cd95ab0 Run clippy on T4 runner with CUDA container for full lint coverage
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 17:05:05 -07:00
Austin Glover
d7e396ba5b Gate Modal CI on 'modal-ready' label and convert CUDA tests to Modal
- Gate test-cuda.yml and test-python-cuda.yml behind 'modal-ready' label
- Convert CUDA clippy and unit tests from self-hosted runner to Modal
- Add ci/modal_cargo_test.py and ci/modal_cargo_clippy.py runners

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 16:03:17 -07:00
Joe Fioti
1a53626716 Merge pull request #260 from luminal-ai/nvidia-devcontainer-args 2026-03-31 15:55:21 -07:00
Austin Glover
4329d68adc Merge main and resolve workflow conflicts
Resolve conflicts from main's pre-commit migration and Modal pytest runner.
Split new lint jobs (ruff, ruff-format, metal-clippy) into individual files
and update test-python-cuda to use Modal runner from main.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 15:44:29 -07:00
Tucker Morgan
989e7e2d44 Fixing native tests 2026-03-31 21:27:34 +00:00
Tucker Morgan
019972cdd4 Fixing ruff lint issue 2026-03-31 20:46:17 +00:00
Tucker Morgan
d7a3f468bd Ruff formatting 2026-03-31 20:44:23 +00:00
Tucker Morgan
c504fbf8a1 Merge cleanip 2026-03-31 20:41:40 +00:00
Tucker Morgan
625be7f4da Merge origin/main into tucker_shared_pytorch_memory
Resolved conflicts:
- other_ops.rs: kept kernel_rewrite import, dropped unused compile_kernel
- lib.rs: kept weight_device_ptrs param, added validate_backend call
- runtime.rs: accepted two-phase CUDA init helpers from main
- compiled_model.py: kept weight_refs/user_indices/is_cuda fields
- pt2.py: kept original_weights tracking for zero-copy
- test_llama3.py: kept xfail + device param for dynamic test

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-31 20:27:30 +00:00
Tucker Morgan
c2a17a4854 Removing uneeded qwen3 moe test file 2026-03-31 19:04:29 +00:00
Tucker Morgan
5c60f1d768 Fixing up small things for review 2026-03-31 18:24:16 +00:00
Tucker Morgan
4c51e3ea84 Cargo fmt: 2026-03-31 16:44:03 +00:00
Tucker Morgan
846551aa6f Cargo clippy 2026-03-31 16:42:30 +00:00
Tucker Morgan
c26076bc75 Cargo fmt 2026-03-31 16:38:09 +00:00
Tucker Morgan
871629b770 fmt and clippy 2026-03-31 16:35:13 +00:00
Tucker Morgan
c6dfa9c62f Unify ONNX/PT2 compilation paths and extract shared helpers
Restructure so both ONNX and PT2 paths follow the same call flow:
  lib.rs (thin PyO3 wrapper)
    → onnx_translator.rs / pt2_compiled_model.rs (format-specific translate + compile)
      → compiled_graph.rs::parse_graph (shared backend pipeline)

Rust changes:
- Create onnx_translator.rs with compile_onnx() and translate_onnx()
  (moved from compiled_graph.rs and lib.rs)
- compiled_graph.rs now only contains shared code (GraphTranslation,
  WeightData, CompiledGraph, parse_graph)
- Cache label_map in CompiledGraph for O(1) set_weight_* lookups
- Move weight_device_ptrs into WeightData.device_ptrs
- Add search_iters param to process_onnx (parity with PT2)
- Fix .unwrap() → ? error propagation in ONNX file loading
- lib.rs reduced to thin PyO3 registration layer

Python changes:
- Extract _collect_weight_pointers(), _detect_backend(),
  _load_cpu_weights() shared helpers in main.py
- Both ONNX and PT2 paths use the same helpers
- Centralize _register_cache_serialization() in __init__.py
- CompiledModel: add input_names override, keep user_indices for
  torch.compile lifted-param filtering

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-30 23:29:15 +00:00
Tucker Morgan
90e3a915d7 Cargo fmt 2026-03-30 22:20:41 +00:00
Tucker Morgan
56cb237aa2 removing uneeded prints 2026-03-30 22:20:32 +00:00
Tucker Morgan
a2c42b35c8 Cleaning up qwen tests 2026-03-30 21:36:30 +00:00
Tucker Morgan
898204b2dd setting test right 2026-03-30 17:51:32 +00:00
Tucker Morgan
2c1a7f087f removing uneeded logs 2026-03-30 17:36:26 +00:00
Austin Glover
112d064700 remove unnecessary ignore 2026-03-28 00:39:34 +00:00
Austin Glover
c51c36fbcb add node for mcp servers 2026-03-28 00:37:46 +00:00
Austin Glover
ee372d464e ignore codex 2026-03-28 00:37:34 +00:00
Austin Glover
1bef1344d1 pytest native approach to caching 2026-03-28 00:36:41 +00:00
Austin Glover
2e27c29b47 Gate Modal CI on 'modal-ready' label and split workflows into one-job-per-file
Modal examples now only run on PRs when the 'modal-ready' label is applied,
preventing expensive GPU runs on every push. Split test.yml and lint.yml
into individual workflow files for clearer CI organization.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-27 15:45:40 -07:00
Austin Glover
8d41c491fd test prototype 2026-03-27 01:36:37 +00:00
Austin Glover
64f390a833 silence maturin logs 2026-03-27 01:36:23 +00:00
Austin Glover
8d20581f38 testing infra 2026-03-27 01:35:56 +00:00
Austin Glover
bfd4ae9b27 temp 2026-03-27 01:35:49 +00:00
Tucker Morgan
92e4260f1e Fixing weight stripping issues 2026-03-26 21:31:48 +00:00
Tucker Morgan
662a564efc Cleaning up a set of changes 2026-03-26 18:39:57 +00:00
Tucker Morgan
1761dc6b66 Missed a directory 2026-03-26 18:05:28 +00:00
Tucker Morgan
da71273d7e Getting LLama tests closer to proper passing 2026-03-26 18:05:13 +00:00
Tucker Morgan
7c921d03a8 Working weight sharing in both onnx and pt 2026-03-25 21:27:14 +00:00
Tucker Morgan
679aa7e092 Fixing up the onnx and fx parsing layer to share more of their code paths 2026-03-25 17:25:00 +00:00
Tucker Morgan
3dd2be2fb2 First pass of the new memory model 2026-03-25 15:59:06 +00:00
59 changed files with 5620 additions and 3456 deletions

View File

@@ -0,0 +1,130 @@
---
name: aoti-debug
description: Debug AOTInductor (AOTI) errors including device mismatches, CUDA illegal memory access, segfaults, and wrong outputs when deploying compiled PyTorch models. Use when encountering errors with aoti_compile_and_package, aoti_load_package, or the deprecated aot_compile/aot_load APIs.
---
# AOTInductor Debugging
Debug errors when compiling and deploying PyTorch models with AOTInductor.
## First Step: Always Check Device and Shape Matching
**For ANY AOTI error (segfault, exception, crash, wrong output), check these first:**
1. **Compile device == Load device**: The model must be loaded on the same device type it was compiled on
2. **Input devices match**: Runtime inputs must be on the same device as the compiled model
3. **Input shapes match**: Runtime input shapes must match compilation shapes (or satisfy dynamic shape constraints)
```python
# Compilation -- note the device and shapes
model = MyModel().eval().cuda()
inp = torch.randn(2, 10, device="cuda")
pkg = torch._inductor.aoti_compile_and_package(model, (inp,))
# Loading -- device type MUST match compilation
loaded = torch._inductor.aoti_load_package(pkg) # auto-detects device from package
# Inference -- device and shapes MUST match
out = loaded(torch.randn(2, 10, device="cuda")) # same device, same shape
```
**AOTI requires compile and load to use the same device type.** Cross-device loading (compile on GPU, load on CPU) is NOT supported. Device index can differ (cuda:0 vs cuda:1).
## Current vs Deprecated API
### Current API (use this)
```python
torch._inductor.aoti_compile_and_package() # compile
torch._inductor.aoti_load_package() # load (auto-detects device)
```
### Deprecated API (migrate away)
```python
torch._export.aot_compile() # deprecated
torch._export.aot_load() # deprecated
```
The new API stores device metadata in the package, so `aoti_load_package()` automatically uses the correct device type.
## Common Error Patterns
### Device Mismatch Segfault
**Symptom**: Segfault, exception, or crash during load or execution.
**Example errors**:
- `The specified pointer resides on host memory and is not registered with any CUDA device`
- Crash during constant loading
- `Expected out tensor to have device cuda:0, but got cpu instead`
**Solution**: Ensure compile and load use the same device type.
### Input Device Mismatch at Runtime
**Symptom**: RuntimeError during model execution.
**Better debugging**: Run with `AOTI_RUNTIME_CHECK_INPUTS=1` for clear errors:
```bash
AOTI_RUNTIME_CHECK_INPUTS=1 python script.py
```
Produces actionable messages like:
```
Error: input_handles[0]: unmatched device type, expected: 0(cpu), but got: 1(cuda)
```
## Debugging CUDA Illegal Memory Access (IMA)
### Step 1: Sanity Checks
```bash
AOTI_RUNTIME_CHECK_INPUTS=1 python script.py # validate inputs match compilation guards
TORCHINDUCTOR_NAN_ASSERTS=1 python script.py # check for NaN before/after each kernel
```
Both flags take effect at **compile time** (codegen time).
### Step 2: Make IMA Deterministic
```bash
PYTORCH_NO_CUDA_MEMORY_CACHING=1 CUDA_LAUNCH_BLOCKING=1 python script.py
```
- `PYTORCH_NO_CUDA_MEMORY_CACHING=1` -- disables caching allocator (which allocates bigger buffers, masking IMA)
- `CUDA_LAUNCH_BLOCKING=1` -- forces synchronous kernel launches (pinpoints which kernel crashed)
Both take effect at **runtime**.
### Step 3: Identify the Problematic Kernel
```bash
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=3 python script.py
```
Prints kernels one by one at runtime. Combined with Step 2 flags, shows which kernel launched right before the error.
To inspect inputs to specific kernels:
```bash
AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT="kernel_name_1,kernel_name_2" \
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=2 python script.py
```
If inputs to a kernel are unexpected, trace back to the kernel that produced the bad input.
## Environment Variables Reference
| Variable | When | Purpose |
|---|---|---|
| `AOTI_RUNTIME_CHECK_INPUTS=1` | Compile time | Validate inputs match compilation guards |
| `TORCHINDUCTOR_NAN_ASSERTS=1` | Compile time | Check for NaN before/after kernels |
| `PYTORCH_NO_CUDA_MEMORY_CACHING=1` | Runtime | Make IMA errors deterministic |
| `CUDA_LAUNCH_BLOCKING=1` | Runtime | Force synchronous kernel launches |
| `AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=3` | Compile time | Print kernels at runtime |
| `AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT="..."` | Compile time | Filter which kernels to print |
| `TORCH_LOGS="+inductor,output_code"` | Runtime | See PT2 internal logs |
| `TORCH_SHOW_CPP_STACKTRACES=1` | Runtime | Show C++ stack traces |
## Common Sources of Issues
- **Dynamic shapes**: Historically a common source of IMA errors. Pay special attention when using dynamic shape constraints.
- **Custom ops**: Especially C++ custom ops with dynamic shapes. The meta function may need to handle SymInt properly.

View File

@@ -0,0 +1,195 @@
---
name: pt2-debug
description: Debug torch.compile failures, graph breaks, recompilation issues, accuracy mismatches, and Triton kernel errors. Use when encountering BackendCompilerFailed exceptions, torch.compile errors, recompilation warnings, or numerical accuracy issues with compiled PyTorch models.
---
# PyTorch 2 Compile Debugging
Debug `torch.compile`, Dynamo, Inductor, and AOTAutograd failures when using PyTorch as a library.
## Diagnostic Environment Variables
Pick the right diagnostic based on the error:
| Command | When to use |
|---|---|
| `TORCH_LOGS="+dynamo,graph_breaks,recompiles" python script.py` | Quick overview of what's going wrong |
| `TORCH_COMPILE_DEBUG=1 python script.py` | Full debug artifacts (FX graphs, Inductor IR, generated code) in `torch_compile_debug/` |
| `TORCH_LOGS="output_code" python script.py` | See the generated Triton/C++ kernel code |
| `TORCH_TRACE=/path/to/trace python script.py` | Structured trace (parse with `tlparse`) |
| `TORCHINDUCTOR_COMPILE_THREADS=1 python script.py` | Single-threaded compilation for pdb debugging |
## Error Triage
Classify the failure and jump to the right section:
| Error Pattern | Category |
|---|---|
| `Unsupported: ...` or `graph break` in logs | [Graph Breaks](#graph-breaks) |
| `BackendCompilerFailed` | [Backend Failures](#backend-compiler-failures) |
| `RecompileError` or `cache_size_limit` | [Recompilation](#recompilation-issues) |
| Accuracy mismatch / wrong numerical output | [Accuracy](#accuracy-issues) |
| `InternalTorchDynamoError` | [Internal Errors](#internal-dynamo-errors) |
| Segfault or CUDA IMA | [Runtime Crashes](#runtime-crashes) |
| Triton assertion / index out of bounds | [Triton Failures](#triton-kernel-failures) |
## Graph Breaks
Graph breaks split the compiled graph into smaller subgraphs, causing performance regressions.
**Diagnose:**
```bash
TORCH_LOGS="graph_breaks" python script.py
```
**Common causes:**
- Data-dependent control flow
- Unsupported Python builtins
- In-place ops on inputs, unsupported dtypes
- Calls to non-traceable functions
**Fix approaches:**
1. Read the graph break message to identify the unsupported operation
2. Check for a decomposition or supported alternative
3. Consider `torch._dynamo.allow_in_graph` or restructure user code
## Backend Compiler Failures
`BackendCompilerFailed` means Inductor crashed during compilation.
**Diagnose with the minifier:**
```bash
# Generate minifier launcher
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=2 python script.py
# Run the minifier to get minimal failing graph
python minifier_launcher.py minify
# Run the minimized reproduction
python minifier_launcher.py run
```
**Then inspect:**
```bash
TORCH_COMPILE_DEBUG=1 python script.py # FX graphs in torch_compile_debug/
```
## Recompilation Issues
Excessive recompilation from guards that are too specific, causing cache misses.
**Diagnose:**
```bash
TORCH_LOGS="recompiles,recompiles_verbose,guards" python script.py
```
**Key config:**
```python
torch._dynamo.config.recompile_limit # default: 8
torch._dynamo.config.fail_on_recompile_limit_hit = True # hard error on limit
```
**Common causes:**
- Changing tensor shapes without marking them dynamic
- Python scalar values that change between calls
- Global state mutations between calls
**Fix:** Read the recompilation reason from logs, identify the failing guard, then either:
- Mark dimensions as dynamic: `torch._dynamo.mark_dynamic(tensor, dim)`
- Fix the source of guard instability
## Accuracy Issues
Compiled model produces different numerical results than eager mode.
**Diagnose:**
```bash
# Compares compiled vs eager with fp64 reference, dumps repro on failure
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=4 python script.py
```
**Fix approach:**
1. Get minimal failing graph from the minifier
2. Compare eager vs compiled output at fp64 precision
3. Binary search through ops to find the diverging operation
4. Check for known issues: reduction order, fused kernels, dtype promotions
## Internal Dynamo Errors
`InternalTorchDynamoError` indicates a bug in Dynamo.
**Diagnose:**
```bash
TORCHDYNAMO_VERBOSE=1 python script.py
# or equivalently:
TORCH_LOGS="+dynamo" python script.py
```
**Debug interactively:**
```bash
TORCHINDUCTOR_COMPILE_THREADS=1 python script.py # then attach pdb
```
## Runtime Crashes
Segfaults and CUDA illegal memory access during execution of compiled code.
**Make crash deterministic:**
```bash
PYTORCH_NO_CUDA_MEMORY_CACHING=1 CUDA_LAUNCH_BLOCKING=1 python script.py
```
**Add NaN checks to find the first bad kernel:**
```bash
TORCHINDUCTOR_NAN_ASSERTS=1 python script.py
```
**Inductor sync debugging:**
```python
torch._inductor.config.triton.debug_sync_kernel = True # sync after every kernel
torch._inductor.config.triton.debug_sync_graph = True # sync before/after graph
```
**Fix approach:**
1. Make deterministic with `PYTORCH_NO_CUDA_MEMORY_CACHING=1 CUDA_LAUNCH_BLOCKING=1`
2. Check input shapes, devices, dtypes
3. Inspect generated kernel code with `TORCH_LOGS="output_code"`
4. Use `TORCHINDUCTOR_NAN_ASSERTS=1` to find the first kernel producing bad values
5. Dynamic shapes are historically a common source of IMA
## Triton Kernel Failures
Triton assertion failures or index-out-of-bounds in generated kernels.
**Diagnose:**
```bash
TORCH_LOGS="output_code,schedule" python script.py
```
**Fix approach:**
1. Get the generated Triton kernel from `output_code` logs
2. Check index computations for off-by-one or wrong stride calculations
3. Check IR with `TORCH_COMPILE_DEBUG=1` to trace back to the FX op
4. Check if fusion decisions created invalid index combinations
## Distinguish Trace-Time vs Runtime
Many bugs come from confusing these:
- **Trace-time**: Inside Dynamo's symbolic interpreter. Function calls may be constant-folded.
- **Runtime**: Real tensors, real Python calls.
When debugging, add `print()` directly in source files rather than monkey-patching -- dispatch chains make monkey-patching unreliable.
## Using the Minifier
The minifier reduces a failing graph to the smallest reproduction:
```bash
# For compilation failures (level 2)
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=2 python script.py
python minifier_launcher.py minify
python minifier_launcher.py run
# For accuracy failures (level 4)
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=4 python script.py
```

View File

@@ -0,0 +1,134 @@
---
name: ruff
description:
Guide for using ruff, the extremely fast Python linter and formatter. Use this
when linting, formatting, or fixing Python code.
---
# ruff
Ruff is an extremely fast Python linter and code formatter. It replaces Flake8,
isort, Black, pyupgrade, autoflake, and dozens of other tools.
## When to use ruff
**Always use ruff for Python linting and formatting**, especially if you see:
- `[tool.ruff]` section in `pyproject.toml`
- A `ruff.toml` or `.ruff.toml` configuration file
However, avoid making unnecessary changes:
- **Don't format unformatted code** - If `ruff format --diff` shows changes
throughout an entire file, the project likely isn't using ruff for formatting.
Skip formatting to avoid obscuring actual changes.
- **Scope fixes to code being edited** - Use `ruff check --diff` to see fixes
relevant to the code you're changing. Only apply fixes to files you're
modifying unless the user explicitly asks for broader fixes.
## How to invoke ruff
- `uv run ruff ...` - Use when ruff is in the project's dependencies to ensure
you use the pinned version
- `uvx ruff ...` - Use when ruff is not a project dependency, or for quick
one-off checks
- `ruff ...` - Use if ruff is installed globally
## Commands
### Linting
```bash
ruff check . # Check all files in current directory
ruff check path/to/file.py # Check specific file
ruff check --fix . # Auto-fix fixable violations
ruff check --fix --unsafe-fixes . # Include unsafe fixes (review changes!)
ruff check --watch . # Watch for changes and re-lint
ruff check --select E,F . # Only check specific rules
ruff check --ignore E501 . # Ignore specific rules
ruff rule E501 # Explain a specific rule
ruff linter # List available linters
```
### Formatting
```bash
ruff format . # Format all files
ruff format path/to/file.py # Format specific file
ruff format --check . # Check if files are formatted (no changes)
ruff format --diff . # Show formatting diff without applying
```
## Configuration
Ruff is configured in `pyproject.toml` or `ruff.toml`:
```toml
# pyproject.toml
[tool.ruff.lint]
select = ["E", "F", "I", "UP"] # Enable specific rule sets
ignore = ["E501"] # Ignore specific rules
[tool.ruff.lint.isort]
known-first-party = ["myproject"]
```
## Migrating from other tools
### Black → ruff format
```bash
black . → ruff format .
black --check . → ruff format --check .
black --diff . → ruff format --diff .
```
### Flake8 → ruff check
```bash
flake8 . → ruff check .
flake8 --select E,F . → ruff check --select E,F .
flake8 --ignore E501 . → ruff check --ignore E501 .
```
### isort → ruff check
```bash
isort . → ruff check --select I --fix .
isort --check . → ruff check --select I .
isort --diff . → ruff check --select I --diff .
```
## Common patterns
### Apply lint fixes before formatting
Run `ruff check --fix` before `ruff format`. Lint fixes can change code
structure (e.g., reordering imports), which formatting then cleans up.
```bash
ruff check --fix .
ruff format .
```
### Applying and reviewing unsafe fixes
Ruff categorizes some auto-fixes as "unsafe" because they may change code
behavior, not just style. For example, removing unused imports could break code
that relies on side effects.
```bash
ruff check --fix --unsafe-fixes --diff . # Preview changes first
ruff check --fix --unsafe-fixes . # Apply changes
```
**Always review changes before applying `--unsafe-fixes`:**
- Use `ruff rule <CODE>` to understand why the fix is considered unsafe
- Verify the fix doesn't violate those assumptions in your code
## Documentation
For detailed information, read the official documentation:
- https://docs.astral.sh/ruff/

135
.agents/skills/ty/SKILL.md Normal file
View File

@@ -0,0 +1,135 @@
---
name: ty
description:
Guide for using ty, the extremely fast Python type checker and language
server. Use this when type checking Python code or setting up type checking in
Python projects.
---
# ty
ty is an extremely fast Python type checker and language server. It replaces
mypy, Pyright, and other type checkers.
## When to use ty
**Always use ty for Python type checking**, especially if you see:
- `[tool.ty]` section in `pyproject.toml`
- A `ty.toml` configuration file
## How to invoke ty
- `uv run ty ...` - Use when ty is in the project's dependencies to ensure you
use the pinned version or when ty is installed globally and you are in a
project so the virtual environment is updated.
- `uvx ty ...` - Use when ty is not a project dependency, or for quick one-off
checks
## Commands
### Type checking
```bash
ty check # Check all files in current directory
ty check path/to/file.py # Check specific file
ty check src/ # Check specific directory
```
### Rule configuration
```bash
ty check --error possibly-unresolved-reference # Treat as error
ty check --warn division-by-zero # Treat as warning
ty check --ignore unresolved-import # Disable rule
```
### Python version targeting
```bash
ty check --python-version 3.12 # Check against Python 3.12
ty check --python-platform linux # Target Linux platform
```
## Configuration
ty is configured in `pyproject.toml` or `ty.toml`:
```toml
# pyproject.toml
[tool.ty.environment]
python-version = "3.12"
[tool.ty.rules]
possibly-unresolved-reference = "warn"
division-by-zero = "error"
[tool.ty.src]
include = ["src/**/*.py"]
exclude = ["**/migrations/**"]
[tool.ty.terminal]
output-format = "full"
error-on-warning = false
```
### Per-file overrides
Use overrides to apply different rules to specific files, such as relaxing rules
for tests or scripts that have different typing requirements than production
code:
```toml
[[tool.ty.overrides]]
include = ["tests/**", "**/test_*.py"]
[tool.ty.overrides.rules]
possibly-unresolved-reference = "warn"
```
## Language server
This plugin automatically configures the ty language server for Python files
(`.py` and `.pyi`).
## Migrating from other tools
### mypy → ty
```bash
mypy . → ty check
mypy --strict . → ty check --error-on-warning
mypy path/to/file.py → ty check path/to/file.py
```
### Pyright → ty
```bash
pyright . → ty check
pyright path/to/file.py → ty check path/to/file.py
```
## Common patterns
### Don't add ignore comments
Fix type errors instead of suppressing them. Only add ignore comments when
explicitly requested by the user. Use `ty: ignore`, not `type: ignore`, and
prefer rule-specific ignores:
```python
# Good: rule-specific ignore
x = undefined_var # ty: ignore[possibly-unresolved-reference]
# Bad: blanket ty ignore
x = undefined_var # ty: ignore
# Bad: tool agnostic blanket ignore
x = undefined_var # type: ignore
```
## Documentation
For detailed information, read the official documentation:
- https://docs.astral.sh/ty/

182
.agents/skills/uv/SKILL.md Normal file
View File

@@ -0,0 +1,182 @@
---
name: uv
description:
Guide for using uv, the Python package and project manager. Use this when
working with Python projects, scripts, packages, or tools.
---
# uv
uv is an extremely fast Python package and project manager. It replaces pip,
pip-tools, pipx, pyenv, virtualenv, poetry, etc.
## When to use uv
**Always use uv for Python work**, especially if you see:
- The `uv.lock` file
- uv headers in `requirements*` files, e.g., "This file was autogenerated by uv"
Don't use uv in projects managed by other tools:
- Poetry projects (identifiable by `poetry.lock` file)
- PDM projects (identifiable by `pdm.lock` file)
## Choosing the right workflow
### Scripts
**Use when:** Running single Python files and standalone scripts.
**Key commands:**
```bash
uv run script.py # Run a script
uv run --with requests script.py # Run with additional packages
uv add --script script.py requests # Add dependencies inline to the script
```
### Projects
**Use when:** There is a `pyproject.toml` or `uv.lock`
**Key commands:**
```bash
uv init # Create new project
uv add requests # Add dependency
uv remove requests # Remove dependency
uv sync # Install from lockfile
uv run <command> # Run commands in environment
uv run python -c "" # Run Python in project environment
uv run -p 3.12 <command> # Run with specific Python version
```
### Tools
**Use when:** Running command-line tools (e.g., ruff, ty, pytest) without
installation.
**Key commands:**
```bash
uvx <tool> <args> # Run a tool without installation
uvx <tool>@<version> <args> # Run a specific version of a tool
```
**Important:**
- `uvx` runs tools from PyPI by package name. This can be unsafe - only run
well-known tools.
- Only use `uv tool install` only when specifically requested by the user.
### Pip interface
**Use when:** Legacy workflows with `requirements.txt` or manual environment
management, no `uv.lock` present.
**Key commands:**
```bash
uv venv
uv pip install -r requirements.txt
uv pip compile requirements.in -o requirements.txt
uv pip sync requirements.txt
# Platform independent resolution
uv pip compile --universal requirements.in -o requirements.txt
```
**Important:**
- Don't use the pip interface unless clearly needed.
- Don't introduce new `requirements.txt` files.
- Prefer `uv init` for new projects.
## Migrating from other tools
### pyenv → uv python
```bash
pyenv install 3.12 → uv python install 3.12
pyenv versions → uv python list --only-installed
pyenv local 3.12 → uv python pin 3.12
pyenv global 3.12 → uv python install 3.12 --default
```
### pipx → uvx
```bash
pipx run ruff → uvx ruff
pipx install ruff → uv tool install ruff
pipx upgrade ruff → uv tool upgrade ruff
pipx list → uv tool list
```
### pip and pip-tools → uv pip
```bash
pip install package → uv pip install package
pip install -r req.txt → uv pip install -r req.txt
pip freeze → uv pip freeze
pip-compile req.in → uv pip compile req.in
pip-sync req.txt → uv pip sync req.txt
virtualenv .venv → uv venv
```
## Common patterns
### Don't use pip in uv projects
```bash
# Bad
pip install requests
# Good
uv add requests
```
### Don't run python directly
```bash
# Bad
python script.py
# Good
uv run script.py
```
```bash
# Bad
python -c "..."
# Good
uv run python -c "..."
```
```bash
# Bad
python3.12 -c "..."
# Good
uvx python@3.12 -c "..."
```
### Don't manually manage environments in uv projects
```bash
# Bad
python -m venv .venv
source .venv/bin/activate
# Good
uv run <command>
```
## Documentation
For detailed information, read the official documentation:
- https://docs.astral.sh/uv/llms.txt
The documentation links to specific pages for each of these workflows.

View File

@@ -17,11 +17,15 @@
"userUid": "1000",
"userGid": "1000",
"configureZshAsDefaultShell": false
},
"ghcr.io/devcontainers/features/node:1": {
"version": "lts"
}
},
"remoteUser": "ubuntu",
"remoteEnv": {
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo",
"CODEX_HOME": "${containerWorkspaceFolder}/.claude/codex"
},
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
"customizations": {

View File

@@ -21,11 +21,15 @@
"userUid": "1000",
"userGid": "1000",
"configureZshAsDefaultShell": false
},
"ghcr.io/devcontainers/features/node:1": {
"version": "lts"
}
},
"remoteUser": "ubuntu",
"remoteEnv": {
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo",
"CODEX_HOME": "${containerWorkspaceFolder}/.claude/codex"
},
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
"customizations": {
@@ -52,4 +56,4 @@
]
}
}
}
}

30
.github/workflows/cuda-clippy.yml vendored Normal file
View File

@@ -0,0 +1,30 @@
name: CUDA Clippy
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
jobs:
cuda_clippy:
name: CUDA Clippy
runs-on: cuda_t4_runner
container:
image: ghcr.io/luminal-ai/luminal-docker:cuda
options: --gpus all
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- name: Mark workspace as safe for git
run: git config --global --add safe.directory "$GITHUB_WORKSPACE"
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Update Rust toolchain
run: rustup update
- uses: pre-commit/action@v3.0.1
with:
extra_args: cargo-clippy --all-files

23
.github/workflows/fmt.yml vendored Normal file
View File

@@ -0,0 +1,23 @@
name: Fmt
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
jobs:
fmt:
name: Fmt
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.1
with:
extra_args: cargo-fmt --all-files

View File

@@ -1,86 +0,0 @@
name: Lint
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
env:
CARGO_TERM_COLOR: always
jobs:
ruff:
name: Ruff
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.1
with:
extra_args: ruff-check --all-files
ruff_format:
name: Ruff Format
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.1
with:
extra_args: ruff-format --all-files
clippy:
name: Clippy
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Update Rust toolchain
run: rustup update
- uses: pre-commit/action@v3.0.1
with:
extra_args: cargo-clippy --all-files
metal_clippy:
name: Metal Clippy
runs-on: macos-14
timeout-minutes: 30
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Update Rust toolchain
run: rustup update
- uses: pre-commit/action@v3.0.1
with:
extra_args: --hook-stage manual cargo-clippy-metal --all-files
fmt:
name: Fmt
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.1
with:
extra_args: cargo-fmt --all-files

25
.github/workflows/metal-clippy.yml vendored Normal file
View File

@@ -0,0 +1,25 @@
name: Metal Clippy
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
jobs:
metal_clippy:
name: Metal Clippy
runs-on: macos-14
timeout-minutes: 30
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Update Rust toolchain
run: rustup update
- uses: pre-commit/action@v3.0.1
with:
extra_args: --hook-stage manual cargo-clippy-metal --all-files

View File

@@ -5,13 +5,16 @@ on:
branches: ["main"]
pull_request:
branches: ["main"]
types: [opened, synchronize, reopened, ready_for_review]
types: [labeled, synchronize]
workflow_dispatch:
jobs:
modal_example:
# Keep the draft check PR-specific so push/manual runs still execute.
if: ${{ github.event_name != 'pull_request' || !github.event.pull_request.draft }}
if: >-
github.event_name == 'push'
|| github.event_name == 'workflow_dispatch'
|| (github.event_name == 'pull_request'
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
name: "${{ matrix.example }} (Modal ${{ matrix.gpu.type }})"
runs-on: ubuntu-latest
environment: Modal

23
.github/workflows/ruff-format.yml vendored Normal file
View File

@@ -0,0 +1,23 @@
name: Ruff Format
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
jobs:
ruff_format:
name: Ruff Format
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.1
with:
extra_args: ruff-format --all-files

23
.github/workflows/ruff.yml vendored Normal file
View File

@@ -0,0 +1,23 @@
name: Ruff
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
jobs:
ruff:
name: Ruff
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.1
with:
extra_args: ruff-check --all-files

24
.github/workflows/test-core.yml vendored Normal file
View File

@@ -0,0 +1,24 @@
name: Test Core
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
env:
CARGO_TERM_COLOR: always
jobs:
core_unit_test:
name: Core Unit Tests
runs-on: ubuntu-latest
container:
image: ghcr.io/luminal-ai/luminal-docker:cpu
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- name: Run tests
run: cargo test --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --verbose

View File

@@ -5,44 +5,31 @@ on:
branches: ["main"]
pull_request:
branches: ["main"]
types: [labeled, synchronize]
workflow_dispatch:
env:
CARGO_TERM_COLOR: always
jobs:
cuda_clippy:
name: Cuda Clippy
runs-on: cuda_t4_runner
container:
image: ghcr.io/luminal-ai/luminal-docker:cuda
options: --gpus all
cuda_unit_test:
if: >-
github.event_name == 'push'
|| github.event_name == 'workflow_dispatch'
|| (github.event_name == 'pull_request'
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
name: Cuda Unit Tests
runs-on: ubuntu-latest
environment: Modal
timeout-minutes: 30
steps:
- uses: actions/checkout@v6
- name: Mark workspace as a safe git directory
run: git config --global --add safe.directory "$GITHUB_WORKSPACE"
- uses: actions/setup-python@v5
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.1
with:
extra_args: --hook-stage manual cargo-clippy-cuda-lite --all-files
cuda_unit_test:
name: Cuda Unit Tests
runs-on: cuda_t4_runner
container:
image: ghcr.io/luminal-ai/luminal-docker:cuda
options: --gpus all
timeout-minutes: 30
steps:
- uses: actions/checkout@v6
- name: Detect GPU compute capability
run: |
CAP=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -1 | tr -d '.')
echo "CUDA_COMPUTE_CAP=${CAP}" >> "$GITHUB_ENV"
- name: Run CUDA crate tests
run: cargo test -p luminal_cuda_lite --verbose -- --test-threads=1
- name: Install Modal
run: pip install modal
- name: Run CUDA tests on Modal
env:
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
run: modal run ci/modal_cargo_test.py

19
.github/workflows/test-metal.yml vendored Normal file
View File

@@ -0,0 +1,19 @@
name: Test Metal
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
jobs:
metal_unit_test:
name: Metal Unit Tests
runs-on: macos-14
timeout-minutes: 30
steps:
- uses: actions/checkout@v6
- name: Run Metal crate tests
run: rustup update; cargo test -p luminal_metal --verbose -- --test-threads=1

View File

@@ -1,56 +1,20 @@
name: Test
name: Test Python CUDA
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
types: [labeled, synchronize]
workflow_dispatch:
env:
CARGO_TERM_COLOR: always
jobs:
core_unit_test:
name: Core Unit Tests
runs-on: ubuntu-latest
container:
image: ghcr.io/luminal-ai/luminal-docker:cpu
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- name: Run tests
run: cargo test --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --verbose
metal_unit_test:
name: Metal Unit Tests
runs-on: macos-14
timeout-minutes: 30
steps:
- uses: actions/checkout@v6
- name: Run Metal crate tests
run: rustup update; cargo test -p luminal_metal --verbose -- --test-threads=1
python_native_tests:
name: Python Native Tests
runs-on: ubuntu-latest
container:
image: ghcr.io/luminal-ai/luminal-docker:cpu
timeout-minutes: 45
defaults:
run:
working-directory: crates/luminal_python
steps:
- uses: actions/checkout@v6
- name: Update Rust toolchain
run: rustup update
- name: Build maturin extension
run: uv run maturin develop --manifest-path rust/Cargo.toml
- name: Run pytest
run: uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v -m "not slow"
python_cuda_tests:
if: >-
github.event_name == 'push'
|| github.event_name == 'workflow_dispatch'
|| (github.event_name == 'pull_request'
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
name: Python CUDA Tests
runs-on: ubuntu-latest
environment: Modal

View File

@@ -0,0 +1,28 @@
name: Test Python Native
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
workflow_dispatch:
jobs:
python_native_tests:
name: Python Native Tests
runs-on: ubuntu-latest
container:
image: ghcr.io/luminal-ai/luminal-docker:cpu
timeout-minutes: 45
defaults:
run:
working-directory: crates/luminal_python
steps:
- uses: actions/checkout@v6
- name: Update Rust toolchain
run: rustup update
- name: Build maturin extension
run: uv run maturin develop --manifest-path rust/Cargo.toml
- name: Run pytest
run: uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v -m "not slow"

3
.gitignore vendored
View File

@@ -1,6 +1,9 @@
/target
/crates/**/target
/examples/**/target
.claude-project
.claude-memory
.codex
*.env
.claude/

34
CLAUDE.md Normal file
View File

@@ -0,0 +1,34 @@
# Luminal
## Package Management
- Use `uv add`, `uv add --dev`, `uv remove` for Python dependencies (pyproject.toml is in `crates/luminal_python/`)
- Use `uv sync` to sync the Python environment
- Never use pip, pip-tools, poetry, or conda
- Never manually create or activate virtual environments — uv manages `.venv/` automatically
- Never generate requirements.txt
## Code Execution
- Always use `uv run` to execute Python tools: `uv run pytest`, `uv run pre-commit`, `uv run python`
- Use `cargo` directly for Rust: `cargo build`, `cargo test`, `cargo check`, `cargo clippy`
- Python project root is `crates/luminal_python/` — run `uv run` commands from there
## Building the Python Package (Maturin)
- After modifying `.rs` files that affect the Python bridge, rebuild with: `maturin develop --release`
- Maturin config is in `crates/luminal_python/pyproject.toml` under `[tool.maturin]`
## Pre-commit
- Run with: `uv run pre-commit run --all-files`
- Hooks configured: ruff-check, ruff-format (Python), cargo-fmt, cargo-clippy (Rust)
- Manual-stage hooks (cargo-clippy-metal, cargo-clippy-cuda-lite) run with `--hook-stage manual`
## Testing
- **Rust tests**: `cargo test -p <crate_name>`
- **Python tests**: `cd crates/luminal_python && uv run pytest`
- `./run_test.sh` — native backend
- `./run_tests_cuda.sh` — CUDA backend
- See `crates/luminal_python/CLAUDE.md` for Python test patterns and conventions

67
ci/modal_cargo_test.py Normal file
View File

@@ -0,0 +1,67 @@
import modal
import subprocess
import os
import sys
gpu_type = os.environ.get("GPU_TYPE", "T4")
CUDARC_CUDA_VERSION = "12080"
app = modal.App("luminal-ci-cargo-test")
WORKDIR = "/workspace/luminal"
cuda_image = (
modal.Image.from_registry("nvcr.io/nvidia/pytorch:25.03-py3")
.apt_install("protobuf-compiler")
.run_commands(
"curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y",
)
.env(
{
"PATH": "/root/.cargo/bin:$PATH",
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
}
)
.add_local_dir(".", remote_path=WORKDIR, copy=True)
)
@app.function(
image=cuda_image,
gpu=gpu_type,
timeout=1800, # 30 minutes
)
def run_cargo_test():
"""Run cargo test for luminal_cuda_lite on a Modal GPU."""
subprocess.run(["nvidia-smi"], check=True)
# Detect GPU compute capability
result = subprocess.run(
["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
capture_output=True,
text=True,
check=True,
)
compute_cap = result.stdout.strip().replace(".", "")
subprocess.run(
[
"cargo", "test",
"-p", "luminal_cuda_lite",
"--verbose",
"--",
"--test-threads=1",
],
cwd=WORKDIR,
env={
**os.environ,
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
"CUDA_COMPUTE_CAP": compute_cap,
},
check=True,
)
@app.local_entrypoint()
def main():
run_cargo_test.remote()

View File

@@ -69,7 +69,7 @@ pub type Ops = (
/// Build a rewrite that matches an HLIR op, reads dtype(s) from the given source fields,
/// and unions with a kernel op that has the same fields plus the dtype(s) appended.
fn kernel_rewrite<H: Default + EgglogOp, L: Default + EgglogOp>() -> Rule {
pub fn kernel_rewrite<H: Default + EgglogOp, L: Default + EgglogOp>() -> Rule {
let hlir = H::default().sort();
let llir = L::default().sort();
let (mut args, hlir_kind_term) = hlir.new_call();
@@ -415,8 +415,12 @@ extern \"C\" {{
long long iters = {iters};
{dtype} partial = 0;
{dtype} comp = 0; // Kahan compensation
for (long long i = tid; i < iters; i += THREADS_PER_BLOCK) {{
partial += in_data[in_start + {iter_stride_of_i}];
{dtype} y = in_data[in_start + {iter_stride_of_i}] - comp;
{dtype} t = partial + y;
comp = (t - partial) - y;
partial = t;
}}
#pragma unroll

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use crate::{
compile_module_image_for_current_device, cuda_dtype,
kernel::KernelOp,
kernel::hlir::{dtype_includes, generate_dyn_dims_defines},
kernel::hlir::{dtype_includes, generate_dyn_dims_defines, kernel_rewrite},
};
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
use itertools::Itertools;
@@ -22,6 +22,9 @@ pub type Ops = (
KernelBatchMatVec,
KernelBatchMatMul,
KernelScatterNoCopy,
KernelSoftmax,
KernelExp,
KernelSigmoid,
);
#[derive(Default, Debug, Clone)]
@@ -1151,6 +1154,7 @@ impl EgglogOp for KernelSoftmax {
("out_strides", ELIST),
("reduce_dim", EXPRESSION),
("reduce_stride", EXPRESSION),
("dtype", DTYPE),
],
)
}
@@ -1160,8 +1164,24 @@ impl EgglogOp for KernelSoftmax {
}
fn rewrites(&self) -> Vec<Rule> {
// No rewrite rules yet - this op is not in the Ops tuple.
vec![]
vec![
kernel_rewrite::<luminal::hlir::Softmax, Self>(),
// Also add a direct rewrite that assumes F32 dtype, in case dtype
// propagation hasn't reached the Softmax node yet.
Rule::raw(
"(rule
(
(= ?sm (Op (Softmax ?shape ?in_strides ?out_strides ?reduce_dim ?reduce_stride) ?inputs))
)
(
(let ?ksm (Op (KernelSoftmax ?shape ?in_strides ?out_strides ?reduce_dim ?reduce_stride (F32)) ?inputs))
(union ?sm ?ksm)
(set (dtype ?ksm) (F32))
)
:name \"softmax-to-kernel-f32\"
)",
),
]
}
fn cleanup(&self) -> bool {
@@ -1176,16 +1196,21 @@ impl EgglogOp for KernelSoftmax {
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
let out_shape =
extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap();
let in_stride =
extract_expr_list(egraph, kind_children[1], list_cache, expr_cache).unwrap();
let out_stride =
extract_expr_list(egraph, kind_children[2], list_cache, expr_cache).unwrap();
let reduce_dim = extract_expr(egraph, kind_children[3], expr_cache).unwrap();
let reduce_stride = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
(
LLIROp::new::<dyn KernelOp>(Box::new(Self {
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
.unwrap(),
in_stride: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
.unwrap(),
out_stride: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
.unwrap(),
reduce_dim: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
reduce_stride: extract_expr(egraph, kind_children[4], expr_cache).unwrap(),
out_shape,
in_stride,
out_stride,
reduce_dim,
reduce_stride,
})),
input_enodes,
)
@@ -1374,3 +1399,370 @@ extern \"C\" {{
"Softmax"
}
}
// KernelExp: native exp (uses expf instead of exp2f * constant)
// Single-kernel alternative to the 3-kernel Constant+Mul+Exp2 path.
// Improves numerical precision by avoiding the truncated log2(e) constant.
#[derive(Default, Debug, Clone)]
pub struct KernelExp {
shape: Vec<Expression>,
in_strides: Vec<Expression>,
out_strides: Vec<Expression>,
dtype: DType,
}
impl EgglogOp for KernelExp {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"KernelExp",
&[
("shape", ELIST),
("strides", ELIST),
("out_strides", ELIST),
("dtype", DTYPE),
],
)
}
fn n_inputs(&self) -> usize {
1
}
fn rewrites(&self) -> Vec<Rule> {
vec![
// Match Exp2(Mul(x, log2e_constant)) directly.
// This matches the pattern created by frontend exp() = (self * (1/ln(2))).exp2()
Rule::raw(
"(rule
(
(= ?mul (Op (Mul ?shape ?x_stride ?const_stride ?inter_stride) (ICons ?x (ICons ?exp_const (INil)))))
(= ?exp2 (Op (Exp2 ?shape ?inter_stride ?out_stride) (ICons ?mul (INil))))
(= ?dt (dtype ?x))
(= ?cv (Op (Constant ?val) (INil)))
(= ?exp_const ?cv)
(> ?val 1.44)
(< ?val 1.45)
)
(
(let ?kexp (Op (KernelExp ?shape ?x_stride ?out_stride ?dt) (ICons ?x (INil))))
(union ?exp2 ?kexp)
(set (dtype ?kexp) ?dt)
)
:name \"direct-exp-fusion\"
)",
),
]
}
fn cleanup(&self) -> bool {
false
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn KernelOp>(Box::new(Self {
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap(),
in_strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
.unwrap(),
out_strides: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
.unwrap(),
dtype: extract_dtype(egraph, kind_children[3]),
})),
input_enodes,
)
}
}
impl KernelOp for KernelExp {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
let vars = self
.shape
.iter()
.flat_map(|e| e.dyn_vars())
.chain(self.in_strides.iter().flat_map(|e| e.dyn_vars()))
.chain(self.out_strides.iter().flat_map(|e| e.dyn_vars()))
.collect::<FxHashSet<_>>();
let dtype = cuda_dtype(self.dtype);
let includes = dtype_includes(&[self.dtype]);
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let n_elements = self
.shape
.iter()
.copied()
.product::<Expression>()
.to_kernel();
let out_idx = flatten_strides(&self.shape, &self.out_strides).to_kernel();
let in_idx = flatten_strides(&self.shape, &self.in_strides).to_kernel();
let kernel = format!(
"{includes}
{dyn_defines}
extern \"C\" {{
__global__ void exp_k({dtype} *out, const {dtype} *in{dyn_dims_param}) {{
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (const_z >= {n_elements}) return;
out[{out_idx}] = expf(in[{in_idx}]);
}}
}}"
);
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("exp_k").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
let out_size = self.shape.iter().copied().product::<Expression>();
(
func,
module,
kernel,
(out_size.ceil_div(128), 1.into(), 1.into()),
(out_size.min(128), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
}
fn output_size(&self) -> Expression {
self.shape.iter().copied().product()
}
fn output_bytes(&self) -> Expression {
(self.output_size() * self.dtype.bits()).ceil_div(8)
}
fn bytes_loaded(&self) -> Expression {
self.output_bytes()
}
fn bytes_stored(&self) -> Expression {
self.output_bytes()
}
fn flops(&self) -> Expression {
self.shape.iter().copied().product()
}
fn output_dtype(&self) -> DType {
self.dtype
}
fn kernel_name(&self) -> &'static str {
"Exp"
}
}
// KernelSigmoid: fused sigmoid = 1/(1+exp(-x))
// Single-kernel alternative to the 5-kernel Neg+Exp+Const+Add+Recip path.
#[derive(Default, Debug, Clone)]
pub struct KernelSigmoid {
shape: Vec<Expression>,
in_strides: Vec<Expression>,
out_strides: Vec<Expression>,
dtype: DType,
}
impl EgglogOp for KernelSigmoid {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"KernelSigmoid",
&[
("shape", ELIST),
("strides", ELIST),
("out_strides", ELIST),
("dtype", DTYPE),
],
)
}
fn n_inputs(&self) -> usize {
1
}
fn rewrites(&self) -> Vec<Rule> {
vec![
// Match the HLIR pattern directly: Recip(Add(Exp2(Mul(Mul(x, -1), log2e)), 1))
Rule::raw(
"(rule
(
(= ?neg1 (Op (Constant ?nv) (INil)))
(< ?nv -0.99)
(> ?nv -1.01)
(= ?neg_x (Op (Mul ?shape ?x_stride ?neg_stride ?neg_out_stride) (ICons ?x (ICons ?neg1 (INil)))))
(= ?log2e (Op (Constant ?lv) (INil)))
(> ?lv 1.44)
(< ?lv 1.45)
(= ?scaled (Op (Mul ?shape ?neg_out_stride ?log2e_stride ?scaled_stride) (ICons ?neg_x (ICons ?log2e (INil)))))
(= ?exp2 (Op (Exp2 ?shape ?scaled_stride ?exp_stride) (ICons ?scaled (INil))))
(= ?one (Op (Constant ?ov) (INil)))
(> ?ov 0.99)
(< ?ov 1.01)
(= ?plus_one (Op (Add ?shape ?exp_stride ?one_stride ?add_stride) (ICons ?exp2 (ICons ?one (INil)))))
(= ?sig_out (Op (Recip ?shape ?add_stride ?out_stride) (ICons ?plus_one (INil))))
(= ?dt (dtype ?x))
)
(
(let ?ksig (Op (KernelSigmoid ?shape ?x_stride ?out_stride ?dt) (ICons ?x (INil))))
(union ?sig_out ?ksig)
(set (dtype ?ksig) ?dt)
)
:name \"direct-sigmoid-fusion\"
)",
),
]
}
fn cleanup(&self) -> bool {
false
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn KernelOp>(Box::new(Self {
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap(),
in_strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
.unwrap(),
out_strides: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
.unwrap(),
dtype: extract_dtype(egraph, kind_children[3]),
})),
input_enodes,
)
}
}
impl KernelOp for KernelSigmoid {
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
let vars = self
.shape
.iter()
.flat_map(|e| e.dyn_vars())
.chain(self.in_strides.iter().flat_map(|e| e.dyn_vars()))
.chain(self.out_strides.iter().flat_map(|e| e.dyn_vars()))
.collect::<FxHashSet<_>>();
let dtype = cuda_dtype(self.dtype);
let includes = dtype_includes(&[self.dtype]);
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let n_elements = self
.shape
.iter()
.copied()
.product::<Expression>()
.to_kernel();
let out_idx = flatten_strides(&self.shape, &self.out_strides).to_kernel();
let in_idx = flatten_strides(&self.shape, &self.in_strides).to_kernel();
let kernel = format!(
"{includes}
{dyn_defines}
extern \"C\" {{
__global__ void sigmoid_k({dtype} *out, const {dtype} *in{dyn_dims_param}) {{
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (const_z >= {n_elements}) return;
out[{out_idx}] = 1.0f / (1.0f + expf(-in[{in_idx}]));
}}
}}"
);
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("sigmoid_k").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
let out_size = self.shape.iter().copied().product::<Expression>();
(
func,
module,
kernel,
(out_size.ceil_div(128), 1.into(), 1.into()),
(out_size.min(128), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
}
fn output_size(&self) -> Expression {
self.shape.iter().copied().product()
}
fn output_bytes(&self) -> Expression {
(self.output_size() * self.dtype.bits()).ceil_div(8)
}
fn bytes_loaded(&self) -> Expression {
self.output_bytes()
}
fn bytes_stored(&self) -> Expression {
self.output_bytes()
}
fn flops(&self) -> Expression {
// neg + exp + add + recip = ~4 ops per element
self.shape.iter().copied().product::<Expression>() * 4
}
fn output_dtype(&self) -> DType {
self.dtype
}
fn kernel_name(&self) -> &'static str {
"Sigmoid"
}
}

View File

@@ -1,6 +1,5 @@
pub mod host;
pub mod kernel;
pub mod logical;
pub mod runtime;
use std::{
ffi::{CStr, CString},

View File

@@ -1,71 +0,0 @@
use std::fmt::Debug;
use luminal::{
egglog_utils::api::{Rule, SortDef},
hlir::unary_sort,
op::EgglogOp,
};
pub type Ops = (Exp, Sigmoid);
#[derive(Debug, Default)]
pub struct Exp;
impl EgglogOp for Exp {
fn sort(&self) -> SortDef {
unary_sort("Exp")
}
fn cleanup(&self) -> bool {
true
}
fn rewrites(&self) -> Vec<Rule> {
vec![Rule::raw(
"(rule
(
(= ?exp_const (Op (Constant 1.442695) (INil)))
(= ?mul (Op (Mul ?shape ?x_stride ?const_stride ?intermediate_stride) (ICons ?x (ICons ?exp_const (INil)))))
(= ?exp2 (Op (Exp2 ?shape ?intermediate_stride ?out_stride) (ICons ?mul (INil))))
(= ?dt (dtype ?x))
)
(
(let ?exp (Op (Exp ?shape ?x_stride ?out_stride) (ICons ?x (INil))))
(union ?exp2 ?exp)
(set (dtype ?exp) ?dt)
)
)",
)]
}
}
#[derive(Default, Debug, Clone)]
pub struct Sigmoid;
impl EgglogOp for Sigmoid {
fn sort(&self) -> SortDef {
unary_sort("Sigmoid")
}
fn cleanup(&self) -> bool {
true
}
fn rewrites(&self) -> Vec<Rule> {
vec![Rule::raw("(rule
(
(= ?neg1 (Op (Constant -1.0) (INil)))
(= ?neg_input (Op (Mul ?input_range ?input_stride ?const_stride ?intermediate_stride) (ICons ?input (ICons ?neg1 (INil)))))
(= ?exp (Op (Exp ?input_range ?intermediate_stride ?exp_stride) (ICons ?neg_input (INil))))
(= ?one (Op (Constant 1.0) (INil)))
(= ?plus_one (Op (Add ?input_range ?exp_stride ?const_stride ?plus_one_stride) (ICons ?exp (ICons ?one (INil)))))
(= ?sig_out (Op (Recip ?input_range ?plus_one_stride ?out_stride) (ICons ?plus_one (INil))))
(= ?dt (dtype ?input))
)
(
(let ?sig (Op (Sigmoid ?input_range ?input_stride ?out_stride) (ICons ?input (INil))))
(union ?sig_out ?sig)
(set (dtype ?sig) ?dt)
)
:name \"sigmoid\"
)")]
}
}

View File

@@ -119,6 +119,14 @@ pub struct CudaRuntime {
active_bucket: usize,
/// Bucket definitions per dimension (empty = single-bucket mode)
dim_buckets: FxHashMap<char, Vec<DimBucket>>,
/// HLIR nodes that should never be consumed after execute().
/// Used for weight tensors shared via external device pointers.
persistent_hlir_nodes: FxHashSet<NodeIndex>,
/// Non-owning CudaSlice wrappers for external device pointers.
/// ManuallyDrop prevents cuMemFree — the external allocator (e.g. PyTorch) owns the memory.
external_buffers: FxHashMap<NodeIndex, std::mem::ManuallyDrop<CudaSlice<u8>>>,
}
impl CudaRuntime {
@@ -199,6 +207,32 @@ impl CudaRuntime {
self.changed_hlir.insert(id);
}
/// Set an external CUDA device pointer as input data. Zero-copy.
/// The caller must ensure the pointer remains valid for the runtime's lifetime.
///
/// # Safety
/// The device pointer must point to a valid CUDA allocation on the same device
/// as this runtime's stream, with at least `n_bytes` bytes available.
pub unsafe fn set_device_ptr(&mut self, id: impl ToId, device_ptr: u64, n_bytes: usize) {
debug_assert!(device_ptr != 0, "set_device_ptr called with null pointer");
let id = id.to_id();
// Create CudaSlice view via cudarc's upgrade_device_ptr.
// ManuallyDrop prevents cuMemFree on drop (external allocator owns this memory).
let slice = unsafe {
self.cuda_stream
.upgrade_device_ptr::<u8>(device_ptr, n_bytes)
};
self.external_buffers
.insert(id, std::mem::ManuallyDrop::new(slice));
self.hlir_buffers.insert(id, CudaInput::Ptr(device_ptr));
self.changed_hlir.insert(id);
}
/// Mark an HLIR node as persistent — its buffer won't be consumed after execute().
pub fn persist_hlir_node(&mut self, id: impl ToId) {
self.persistent_hlir_nodes.insert(id.to_id());
}
/// Find the LLIR producing node for an output tensor.
fn find_producer_node(&self, id: impl ToId) -> NodeIndex {
let id = id.to_id();
@@ -281,12 +315,15 @@ impl CudaRuntime {
.expect("Cannot find input tensor in runtime!")
{
CudaInput::Buffer(buf) => self.cuda_stream.clone_dtoh(buf).unwrap(),
CudaInput::Ptr(p) => {
// Raw pointer — need size from cached_buffer_ptrs or error
panic!(
"Cannot read raw pointer input (ptr=0x{:x}) — use Buffer variant",
p
);
CudaInput::Ptr(_) => {
// External device pointer — use the CudaSlice view from external_buffers
if let Some(ext) = self.external_buffers.get(hlir_node) {
self.cuda_stream.clone_dtoh(&**ext).unwrap()
} else {
panic!(
"Cannot read raw pointer input — no external_buffers entry for node"
);
}
}
}
} else {
@@ -302,6 +339,57 @@ impl CudaRuntime {
}
}
/// Resolve the device-side CudaSlice for an output tensor without copying to host.
/// Used by copy_output_to_device_ptr for DtoD transfers.
fn resolve_output_slice(&self, id: impl ToId) -> &CudaSlice<u8> {
let data_id = self.resolve_data_node(id);
let bucket = self.active();
if let Some(hlir_node) = bucket.llir_to_hlir.get(&data_id) {
match self
.hlir_buffers
.get(hlir_node)
.expect("Cannot find input tensor in runtime!")
{
CudaInput::Buffer(buf) => buf,
CudaInput::Ptr(_) => self
.external_buffers
.get(hlir_node)
.map(|ext| &**ext)
.expect("Cannot read raw pointer input — no external_buffers entry for node"),
}
} else {
bucket
.buffers
.get(&data_id)
.expect("Cannot find tensor in runtime!")
}
}
/// Copy output tensor data to an external CUDA device pointer (DtoD).
/// Much faster than get_f32 + HtoD for CUDA-to-CUDA workflows.
///
/// # Safety
/// The dest_ptr must be a valid CUDA device allocation with at least n_bytes available.
pub unsafe fn copy_output_to_device_ptr(&self, id: impl ToId, dest_ptr: u64, n_bytes: usize) {
debug_assert!(
dest_ptr != 0,
"copy_output_to_device_ptr called with null pointer"
);
let src_slice = self.resolve_output_slice(id);
let src_ptr = src_slice.device_ptr(&self.cuda_stream).0;
let copy_bytes = n_bytes.min(src_slice.len());
unsafe {
cudarc::driver::result::memcpy_dtod_async(
dest_ptr,
src_ptr,
copy_bytes,
self.cuda_stream.cu_stream(),
)
.expect("cuMemcpyDtoDAsync failed");
}
self.cuda_stream.synchronize().unwrap();
}
pub fn get_f32(&self, id: impl ToId) -> Vec<f32> {
let bytes = self.get_output_data(id);
let bytes = bytes.leak();
@@ -684,7 +772,7 @@ fn format_duration_precise(d: &std::time::Duration) -> String {
}
impl Runtime for CudaRuntime {
type Ops = (crate::logical::Ops, crate::kernel::Ops, crate::host::Ops);
type Ops = (crate::kernel::Ops, crate::host::Ops);
type CompileArg = Arc<CudaStream>;
type ExecReturn = ();
type ProfileMetric = Duration;
@@ -702,6 +790,8 @@ impl Runtime for CudaRuntime {
compiled_buckets: vec![CompiledBucket::new()],
active_bucket: 0,
dim_buckets: FxHashMap::default(),
persistent_hlir_nodes: FxHashSet::default(),
external_buffers: FxHashMap::default(),
}
}
@@ -938,10 +1028,23 @@ impl Runtime for CudaRuntime {
}
// Add input buffers (prefer HLIR weight buffers over intermediate placeholders)
for inp in exec_op.inputs.iter() {
if let Some(hlir_node) = bucket.llir_to_hlir.get(inp)
&& let Some(CudaInput::Buffer(buf)) = self.hlir_buffers.get(hlir_node)
{
buffer_map.insert(*inp, buf);
if let Some(hlir_node) = bucket.llir_to_hlir.get(inp) {
match self.hlir_buffers.get(hlir_node) {
Some(CudaInput::Buffer(buf)) => {
buffer_map.insert(*inp, buf);
}
Some(CudaInput::Ptr(_)) => {
if let Some(ext) = self.external_buffers.get(hlir_node) {
buffer_map.insert(*inp, &**ext);
}
}
None => {}
}
if !buffer_map.contains_key(inp)
&& let Some(buf) = bucket.buffers.get(inp)
{
buffer_map.insert(*inp, buf);
}
} else if let Some(buf) = bucket.buffers.get(inp) {
buffer_map.insert(*inp, buf);
}
@@ -952,25 +1055,43 @@ impl Runtime for CudaRuntime {
if let Entry::Vacant(e) = buffer_map.entry(extra_node) {
if let Some(buf) = bucket.buffers.get(&extra_node) {
e.insert(buf);
} else if let Some(hlir_node) = bucket.llir_to_hlir.get(&extra_node)
&& let Some(CudaInput::Buffer(buf)) = self.hlir_buffers.get(hlir_node)
{
e.insert(buf);
} else if let Some(hlir_node) = bucket.llir_to_hlir.get(&extra_node) {
match self.hlir_buffers.get(hlir_node) {
Some(CudaInput::Buffer(buf)) => {
e.insert(buf);
}
Some(CudaInput::Ptr(_)) => {
if let Some(ext) = self.external_buffers.get(hlir_node) {
e.insert(&**ext);
}
}
None => {}
}
}
}
}
// Resolve output aliases
for (&alias_node, &alias_target) in &bucket.output_alias_map {
if let std::collections::hash_map::Entry::Occupied(mut e) =
buffer_map.entry(alias_node)
{
if let Some(hlir_node) = bucket.llir_to_hlir.get(&alias_target)
&& let Some(CudaInput::Buffer(buf)) = self.hlir_buffers.get(hlir_node)
{
e.insert(buf);
} else if let Some(buf) = bucket.buffers.get(&alias_target) {
e.insert(buf);
}
if !buffer_map.contains_key(&alias_node) {
continue;
}
// Try HLIR buffer first (includes external device pointers)
let resolved: Option<&CudaSlice<u8>> =
if let Some(hlir_node) = bucket.llir_to_hlir.get(&alias_target) {
match self.hlir_buffers.get(hlir_node) {
Some(CudaInput::Buffer(buf)) => Some(buf),
Some(CudaInput::Ptr(_)) => {
self.external_buffers.get(hlir_node).map(|ext| &**ext)
}
None => None,
}
} else {
None
};
if let Some(buf) = resolved {
buffer_map.insert(alias_node, buf);
} else if let Some(buf) = bucket.buffers.get(&alias_target) {
buffer_map.insert(alias_node, buf);
}
}
let _span = span!(
@@ -1069,11 +1190,13 @@ impl Runtime for CudaRuntime {
.hlir_buffers
.keys()
.filter(|hlir_node| !inputs_with_outputs.contains(hlir_node))
.filter(|hlir_node| !self.persistent_hlir_nodes.contains(hlir_node))
.copied()
.collect();
for hlir_node in to_consume {
self.hlir_buffers.remove(&hlir_node);
self.external_buffers.remove(&hlir_node);
let bucket = &mut self.compiled_buckets[self.active_bucket];
if let Some(llir_node) = bucket.hlir_to_llir.get(&hlir_node) {
bucket.cached_buffer_ptrs.remove(llir_node);

View File

@@ -1,5 +1,4 @@
*.onnx
tests/llama38b_ref_logits.pt
__pycache__/
*.pyc
uv.lock

View File

@@ -1,4 +1,8 @@
A couple of short things to keep in mind
## Python Environment
- Always use `uv run` to execute Python tools (pytest, pre-commit, python) — never bare `pytest` or `python`
- Use `uv add` / `uv add --dev` / `uv remove` for dependencies — never hand-edit pyproject.toml deps
- After modifying Rust source files, rebuild before running Python tests: `maturin develop --release`
## Lessons Learned

View File

@@ -340,7 +340,7 @@ with matching shape tracker dimensions.
---
## Bug: TopK values wrong on CUDA (gather_elements with sliced non-contiguous indices)
## 2026-03-05 — TopK Values Wrong on CUDA (gather_elements with sliced non-contiguous indices)
1. **Symptom**: `test_topk_values` failed on CUDA — rows 0-1 were correct but rows 2+ returned
the value at column 0 of each row (all three top-k positions got the same value).
@@ -748,3 +748,11 @@ method rather than string-matching on Debug output. Additionally, when diagnosin
candidates rejected" during search, check whether the rejection is from actual float NaN
or from dtype misinterpretation — the key diagnostic is whether the NaN pattern is
identical across all attempts (dtype issue) vs varying (actual numerical issue).
## 2026-03-25 — KernelExp/KernelSigmoid: Fused CUDA Kernels for Precision
1. **Symptom**: `test_hf_llama3_full` (16-layer Llama-3.2-1B) had ~1e-4 max diff vs PyTorch.
2. **Root cause**: `exp(x)` was computed as `exp2(x * 1.442695)` — the constant truncated by `{:.6}` format + extra multiply adds rounding. Sigmoid was 5 separate kernels. SumReduce had naive accumulation.
3. **Why hard**: Per-operation error was ~1e-7 but compounded over 16 layers Ă— ~25 extra materializations. The egglog `Exp` rewrite depends on exact constant format matching.
4. **Fix**: Added `KernelExp` (uses `expf()`), `KernelSigmoid` (uses `1/(1+expf(-x))`), and Kahan summation in SumReduce. Each uses both `kernel_rewrite` and a direct egglog pattern match with range checks (e.g., `(> ?val 1.44) (< ?val 1.45)`) to bypass constant format dependency.
5. **Principle**: When decomposed CUDA kernel chains cause precision loss, add fused kernels via `kernel_rewrite`. For robustness, add BOTH the logical-op rewrite path AND a direct HLIR pattern match — the constant format in egglog can be fragile.

View File

@@ -3,15 +3,19 @@ name = "luminal_python"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.12"
dependencies = [
"numpy>=2.0.2",
"torch>=2.10.0",
"onnx",
"onnxscript",
"safetensors",
"flash-attn-3>=3.0.0",
]
[tool.uv]
no-build-isolation-package = ["flash-attn"]
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
@@ -21,6 +25,7 @@ explicit = true
torch = [
{ index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
flash-attn-3 = { index = "pytorch-cu128" }
[build-system]
@@ -40,13 +45,21 @@ markers = [
[dependency-groups]
dev = [
"maturin>=1.0,<2.0",
"maturin-import-hook>=0.3.0",
"pytest>=9.0.2",
"pytest-profiling",
"snakeviz",
"maturin-import-hook>=0.3.0",
"pytest-randomly>=4.0.1",
"transformers>=4.40.0",
"transformers>=5.5.0,<6",
"diffusers>=0.35.0",
"onnxsim",
"tiktoken>=0.12.0",
"pydantic>=2.12.5",
"psutil>=7.2.2",
"modal>=1.3.5",
"pillow",
"flash-attn>=2.8.3",
]
flash-attention-4 = [
"nvidia-cutlass-dsl==4.1.0",
]

View File

@@ -14,7 +14,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
# Run pytest with CUDA backend
echo "Step 3: Running pytest with CUDA backend..."
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest tests/test_hlir_ops.py tests/test_unary.py tests/test_llama3.py -m "not slow" -v
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v
echo ""
echo "=== Tests Complete ==="

View File

@@ -14,6 +14,6 @@ uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
# Run pytest with CUDA backend and PT2 export mode
echo "Step 3: Running pytest with CUDA backend + PT2 export mode..."
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda LUMINAL_EXPORT_MODE=pt2 uv run pytest tests/test_hlir_ops.py tests/test_unary.py tests/test_llama3.py -m "not slow" -v
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda LUMINAL_EXPORT_MODE=pt2 uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v
echo ""
echo "=== Tests Complete ==="

View File

@@ -1,34 +1,45 @@
use luminal::{
prelude::{
tracing::{Level, span, trace},
*,
},
shape::Expression,
visualization::ToDot,
};
use onnx_protobuf::{GraphProto, ModelProto};
use pyo3::prelude::*;
use std::{
collections::{HashMap, HashSet},
path::Path,
};
#[cfg(feature = "cuda")]
use crate::util::transpose_weight_data;
use crate::{
dispatch::process_onnx_nodes,
runtime::*,
util::{
DimParamMap, get_shape_for_onnx_value, get_shape_for_onnx_value_expr,
load_all_tensor_floats, load_initializer_as_f32,
},
};
use luminal::prelude::tracing::{trace, warn};
use luminal::{prelude::*, shape::Expression, visualization::ToDot};
use pyo3::prelude::*;
use std::collections::HashMap;
#[cfg(feature = "cuda")]
use std::collections::HashSet;
use crate::{runtime::RuntimeBackend, util::DimParamMap};
/// Common intermediate result from translating a model graph (ONNX or FX).
pub struct GraphTranslation {
pub graph: Graph,
pub tensor_ids: HashMap<String, NodeIndex>,
pub input_names: Vec<String>,
pub output_names: Vec<String>,
pub output_shape_exprs: Vec<Vec<Expression>>,
pub input_shape_exprs: Vec<Vec<Expression>>,
pub dim_param_map: DimParamMap,
}
/// Pre-loaded weight data from any model format.
///
/// NOTE: Currently assumes all data is F32. When the type system branch lands
/// with proper multi-dtype support, this struct (and all callers) will need
/// updating to carry dtype metadata alongside the raw data.
pub struct WeightData {
/// (Input node label, f32 data) for weights and constants.
pub weights: Vec<(String, Vec<f32>)>,
/// label → element count for ALL Input nodes (for CUDA dummy data sizing).
pub tensor_sizes: HashMap<String, usize>,
/// label → (device_ptr, n_bytes) for zero-copy CUDA weight sharing.
pub device_ptrs: HashMap<String, (u64, usize)>,
}
#[pyclass(unsendable)]
pub struct CompiledGraph {
pub graph: Graph,
pub runtime: RuntimeBackend,
pub tensor_ids: HashMap<String, NodeIndex>,
/// Cached label → NodeIndex map for O(1) lookups in set_weight_* methods.
label_map: HashMap<String, NodeIndex>,
pub input_names: Vec<String>,
pub output_names: Vec<String>,
pub output_shapes: Vec<Vec<usize>>,
@@ -38,218 +49,35 @@ pub struct CompiledGraph {
}
impl CompiledGraph {
/// Shared compilation pipeline for both ONNX and FX/PT2 graphs.
///
/// Takes a format-neutral `GraphTranslation` (produced by `translate_onnx` or
/// `translate_pt2`) and `WeightData`, builds the backend, loads weights, and
/// returns a ready-to-execute `CompiledGraph`.
pub fn parse_graph(
model: ModelProto,
model_directory: &Path,
translation: GraphTranslation,
weight_data: WeightData,
backend: &str,
search_iters: usize,
) -> Result<CompiledGraph, String> {
let _span = span!(Level::TRACE, "Onnx Graphing Parsing").entered();
let onnx_graph = &model.graph;
let mut cx = Graph::new();
// We will need to track the tensors we allocate so we can match up inputs and outputs in the graph
let mut tensors: HashMap<String, GraphTensor> = HashMap::new();
// Dynamic dimension tracking
let mut dim_param_map: DimParamMap = HashMap::new();
let mut next_char = 'a';
// This is the name of all of the tensors we will need to fill in parameters for
let initializer_names: HashSet<&str> = onnx_graph
.initializer
.iter()
.map(|t| t.name.as_str())
.collect();
// Input is an overloaded term in Onnx, it both means the inputs into the model, like the next token
// and the parameters of the layers, for this we don't want any of the parameters
// Input here is in the straightforward meaning, those tensors you feed into the network for a
// forward passd
let input_names: Vec<String> = onnx_graph
.input
.iter()
.filter(|inp| !initializer_names.contains(inp.name.as_str()))
.map(|inp| inp.name.clone())
.collect();
// Create "holding" tensors for the input
// this way they can be considered in the graph computation, and later as we do mutiple runs we can target them and swap out the values
// in them and not need to recompile the network
for input in &onnx_graph.input {
// Use expression-aware shape parsing to detect DimParam (dynamic dims)
let shape_exprs =
get_shape_for_onnx_value_expr(input, &mut dim_param_map, &mut next_char);
if shape_exprs.is_empty() {
// Fall back to concrete parsing (initializer shapes don't have DimParam)
let shape = get_shape_for_onnx_value(input);
if shape.is_empty() {
trace!("Input {} skipped because it is empty", input.name.clone());
continue;
}
let tensor = cx.named_tensor(input.name.clone(), shape);
trace!("Input {} added to tensors", input.name.clone());
tensors.insert(input.name.clone(), tensor);
continue;
}
// Always F32: Python runtime always sends float32 data via .float().numpy()
let tensor = cx.named_tensor(input.name.clone(), shape_exprs);
trace!("Input {} added to tensors", input.name.clone());
tensors.insert(input.name.clone(), tensor);
}
for init in &onnx_graph.initializer {
if !tensors.contains_key(&init.name) {
let mut shape: Vec<usize> = init.dims.iter().map(|&d| d as usize).collect();
// Scalar (0-dim) tensors have empty dims; represent as [1] in luminal
if shape.is_empty() {
shape = vec![1];
}
let tensor = cx.named_tensor(init.name.clone(), shape);
tensors.insert(init.name.clone(), tensor);
}
}
let mut weight_data = Vec::new();
let mut known_values: HashMap<String, Vec<f32>> = HashMap::new();
for init in &onnx_graph.initializer {
let n_elements: usize = init
.dims
.iter()
.map(|&d| d as usize)
.product::<usize>()
.max(1);
// MAGIC_NUMBER:
if n_elements <= 32 {
if let Some(floats) = load_initializer_as_f32(init) {
known_values.insert(init.name.clone(), floats);
} else {
// Questions
// Should this be fatal
// Should this be a print or a log
panic!("Unable to initializer values for {:?}", init.name);
}
}
}
// Shape expressions map for propagating symbolic shape values through
// Shape→Gather→Unsqueeze→Concat chains in dynamic ONNX graphs
let mut shape_exprs: HashMap<String, Vec<Expression>> = HashMap::new();
// Process computation nodes (Constant nodes add to weight_data)
process_onnx_nodes(
&onnx_graph.node,
&mut tensors,
&mut cx,
&mut weight_data,
&mut known_values,
&mut shape_exprs,
)
.map_err(|e| format!("process_onnx_nodes failed: {}", e))?;
// Mark weight/constant tensors as persistent so their buffers survive
// execute()'s input consumption. User inputs (like input_ids) are NOT persisted
// since they are re-set via set_input() before each execution.
for (name, gt) in &tensors {
if !input_names.contains(name) {
gt.persist();
}
}
let has_dynamic = !dim_param_map.is_empty();
// Mark graph outputs (must happen before build_search_space)
let mut output_names = Vec::new();
let mut output_shapes = Vec::new();
let mut output_shape_exprs = Vec::new();
for output_vi in &onnx_graph.output {
if let Some(&gt) = tensors.get(&output_vi.name) {
// Force contiguous if the shape tracker is a non-contiguous view
// (e.g. a view-only slice that changed dims without a gather).
// Without this, get_f32 returns the full underlying buffer.
let gt = if gt.shape != gt.shape.contiguous() {
let contiguous = gt * 1.0;
tensors.insert(output_vi.name.clone(), contiguous);
contiguous
} else {
gt
};
gt.output();
let dims = gt.dims();
// Store Expression-based shapes for dynamic resolution
output_shape_exprs.push(dims.clone());
// For concrete output shapes, resolve now; for dynamic, use placeholder
let shape: Vec<usize> = dims.iter().map(|d| d.to_usize().unwrap_or(1)).collect();
if shape.is_empty() {
return Err(format!(
"Output tensor '{}' has no shape information in the ONNX model",
output_vi.name
));
}
output_names.push(output_vi.name.clone());
output_shapes.push(shape);
}
}
// If we have dynamic dims, set initial values in the graph's dyn_map
// based on the concrete shapes from the example input used during export
if has_dynamic {
for input in &onnx_graph.input {
if initializer_names.contains(input.name.as_str()) {
continue;
}
let concrete_shape = get_shape_for_onnx_value(input);
let expr_shape =
get_shape_for_onnx_value_expr(input, &mut dim_param_map, &mut next_char);
for (expr, concrete) in expr_shape.iter().zip(concrete_shape.iter()) {
if expr.to_usize().is_none() {
// This is a symbolic dim — set initial value in dyn_map
// Extract the char variable from the expression
if let Some(ch) = dim_param_map
.values()
.find(|&&ch| Expression::from(ch) == *expr)
{
cx.set_dim(*ch, *concrete);
}
}
}
}
}
// Extract weight data from initializers (handles inline + external storage)
// Batch load reads each external file only once instead of per-tensor
for (name, floats) in load_all_tensor_floats(&onnx_graph.initializer, model_directory) {
if let Some(f) = floats {
weight_data.push((name, f));
}
}
// Collect tensor name -> NodeIndex mapping
let tensor_ids: HashMap<String, NodeIndex> = tensors
.iter()
.map(|(name, gt)| (name.clone(), gt.id))
.collect();
// Track which tensor names are Input nodes (includes those created during process_onnx_nodes)
let input_tensor_names: HashSet<String> = tensors.keys().cloned().collect();
let GraphTranslation {
mut graph,
tensor_ids,
input_names,
output_names,
output_shape_exprs,
input_shape_exprs,
dim_param_map,
} = translation;
let rt = match backend {
#[cfg(feature = "cuda")]
"cuda" => CompiledGraph::build_cuda_backend(
onnx_graph,
model_directory,
&mut tensors,
&mut weight_data,
&mut cx,
&input_tensor_names,
)?,
"native" => CompiledGraph::build_native_backend(
onnx_graph,
model_directory,
&mut tensors,
&mut weight_data,
&mut cx,
&input_tensor_names,
)?,
"cuda" | "gpu" => {
CompiledGraph::build_cuda_backend(&mut graph, &weight_data, search_iters)?
}
"native" | "cpu" => {
CompiledGraph::build_native_backend(&mut graph, &weight_data, search_iters)?
}
_ => {
#[cfg(feature = "cuda")]
{
@@ -274,22 +102,19 @@ impl CompiledGraph {
}
};
// Build input_shape_exprs for user inputs (needed for auto-dim detection)
let input_shape_exprs: Vec<Vec<Expression>> = input_names
// Resolve concrete output shapes from expressions
let output_shapes: Vec<Vec<usize>> = output_shape_exprs
.iter()
.map(|name| {
if let Some(&gt) = tensors.get(name) {
gt.dims()
} else {
vec![]
}
})
.map(|exprs| exprs.iter().map(|e| e.to_usize().unwrap_or(1)).collect())
.collect();
let label_map = CompiledGraph::build_label_map(&graph);
Ok(CompiledGraph {
graph: cx,
graph,
runtime: rt,
tensor_ids,
label_map,
input_names,
output_names,
output_shapes,
@@ -299,124 +124,154 @@ impl CompiledGraph {
})
}
/// Build a label → NodeIndex map for all Input nodes in the graph.
/// Used for efficient weight loading by label matching.
fn build_label_map(graph: &Graph) -> HashMap<String, NodeIndex> {
graph
.graph
.node_indices()
.filter_map(|node_id| {
(*graph.graph[node_id])
.as_any()
.downcast_ref::<luminal::hlir::Input>()
.map(|input| (input.label.clone(), node_id))
})
.collect()
}
#[cfg(feature = "cuda")]
fn build_cuda_backend(
onnx_graph: &protobuf::MessageField<GraphProto>,
model_directory: &Path,
tensors: &mut HashMap<String, GraphTensor>,
weight_data: &mut Vec<(String, Vec<f32>)>,
context: &mut Graph,
input_tensor_names: &HashSet<String>,
graph: &mut Graph,
weight_data: &WeightData,
search_iters: usize,
) -> Result<RuntimeBackend, String> {
let compute_n_elements = |name: &str| -> usize {
if let Some(vi) = onnx_graph.input.iter().find(|i| i.name == name) {
let shape = get_shape_for_onnx_value(vi);
shape.iter().product::<usize>()
} else if let Some(init) = onnx_graph.initializer.iter().find(|i| i.name == name) {
init.dims.iter().map(|&d| d as usize).product::<usize>()
} else if let Some((_, data)) = weight_data.iter().find(|(n, _)| n == name) {
data.len()
let device_ptrs = &weight_data.device_ptrs;
use luminal_cuda_lite::cudarc::driver::CudaContext;
use luminal_cuda_lite::runtime::CudaRuntime;
let cuda_ctx = CudaContext::new(0).map_err(|e| format!("CUDA context init failed: {e}"))?;
let stream = cuda_ctx.default_stream();
graph.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
// Build label → NodeIndex map for device pointer matching.
let label_map = CompiledGraph::build_label_map(graph);
// For weights with device pointers: use them directly (zero-copy).
// This avoids allocating ~N GB of dummy data during search.
// The pointers survive search because profiling mode skips buffer consumption,
// and persist_hlir_node ensures they survive post-search execution too.
let mut device_ptr_nodes: HashSet<NodeIndex> = HashSet::new();
let mut matched_count = 0usize;
let mut missed_labels: Vec<String> = Vec::new();
for (label, &(ptr, n_bytes)) in device_ptrs {
if let Some(&node_id) = label_map.get(label) {
unsafe { rt.set_device_ptr(node_id, ptr, n_bytes) };
rt.persist_hlir_node(node_id);
device_ptr_nodes.insert(node_id);
matched_count += 1;
} else {
0
missed_labels.push(label.clone());
}
};
}
let total_device_bytes: usize = device_ptrs.values().map(|(_, n)| *n).sum();
trace!(
"[CUDA BUILD] Device pointers: {} matched, {} missed out of {} total ({:.3} GiB)",
matched_count,
missed_labels.len(),
device_ptrs.len(),
total_device_bytes as f64 / (1024.0 * 1024.0 * 1024.0),
);
if !missed_labels.is_empty() {
warn!(
"[CUDA BUILD] {} device-ptr labels did not match any Input node (first 10): {:?}",
missed_labels.len(),
&missed_labels[..missed_labels.len().min(10)]
);
let available: Vec<&String> = label_map.keys().take(10).collect();
warn!(
"[CUDA BUILD] Available label_map keys (first 10): {:?}",
available
);
}
// CUDA: Two-phase - set data BEFORE search for profiling
let (mut cuda_rt, _stream) = prepare_cuda(context)?;
// Set dummy data for ALL input tensors using small non-zero values (ones).
// Set dummy 1.0 data for remaining Input nodes (user inputs, constants without
// device pointers) for safe search profiling.
// IMPORTANT: Must use 1.0, NOT 0.0. Zero inputs cause NaN in many ops:
// - fmod(0, 0) = NaN (Mod)
// - recip(0) = inf → weight * inf = NaN (Div)
// - log(0) = -inf (Pow)
// - chain ops with zero produce NaN (Erf)
// The search's has_nan_outputs check then rejects ALL candidates, causing
// "Failed to find viable genome" errors. See LessonsLearned.md entry #1.
// Note: torch.compile passes model weights as additional ONNX inputs (not
// initializers), so these dummy values also cover weight tensors.
for (name, gt) in &mut *tensors {
if !input_tensor_names.contains(name) {
let mut dummy_total_elements = 0usize;
let mut dummy_count = 0usize;
for node_id in graph.graph.node_indices() {
if device_ptr_nodes.contains(&node_id) {
continue;
}
let n_elements = compute_n_elements(name);
if n_elements > 0 {
cuda_rt.set_data(gt.id, vec![1.0f32; n_elements]);
}
}
// Overwrite with real initializer data (for accurate profiling)
// Batch load reads each external file only once
let init_data = load_all_tensor_floats(&onnx_graph.initializer, model_directory);
for (i, (name, floats_opt)) in init_data.iter().enumerate() {
let floats = match floats_opt {
Some(f) => f,
None => continue,
};
if let Some(gt) = tensors.get(name) {
cuda_rt.set_data(gt.id, floats.clone());
}
let kn_name = format!("{}_kn", name);
if let Some(gt_kn) = tensors.get(&kn_name) {
let dims: Vec<usize> = onnx_graph.initializer[i]
.dims
.iter()
.map(|&d| d as usize)
.collect();
if dims.len() == 2 {
let transposed = transpose_weight_data(floats, dims[0], dims[1]);
cuda_rt.set_data(gt_kn.id, transposed);
if let Some(input) = (*graph.graph[node_id])
.as_any()
.downcast_ref::<luminal::hlir::Input>()
{
if let Some(&n) = weight_data.tensor_sizes.get(&input.label) {
if n > 0 {
dummy_total_elements += n;
dummy_count += 1;
rt.set_data(node_id, vec![1.0f32; n]);
}
}
}
}
trace!(
"[CUDA BUILD] Dummy data: {} nodes, {} elements ({:.3} GiB as f32)",
dummy_count,
dummy_total_elements,
(dummy_total_elements * 4) as f64 / (1024.0 * 1024.0 * 1024.0),
);
// Load constant node data
for (name, floats) in weight_data {
if let Some(gt) = tensors.get(name) {
cuda_rt.set_data(gt.id, floats.clone());
// Search (device-pointer weights are used directly; dummy data for the rest)
let mut rt = graph.search(rt, search_iters);
// Load real weight data for non-device-ptr weights (constants from PT2 archive, etc.)
let mut loaded_weight_elements = 0usize;
let mut loaded_weight_count = 0usize;
for (label, data) in &weight_data.weights {
if !device_ptrs.contains_key(label) {
if let Some(&node_id) = label_map.get(label) {
loaded_weight_elements += data.len();
loaded_weight_count += 1;
rt.set_data(node_id, data.clone());
}
}
}
trace!(
"[CUDA BUILD] Post-search weight load: {} weights, {} elements ({:.3} GiB as f32)",
loaded_weight_count,
loaded_weight_elements,
(loaded_weight_elements * 4) as f64 / (1024.0 * 1024.0 * 1024.0),
);
// Now finalize (search with profiling, data is available)
let cuda_rt = finalize_cuda(context, cuda_rt);
Ok(cuda_rt)
Ok(RuntimeBackend::Cuda(Box::new(rt)))
}
fn build_native_backend(
onnx_graph: &protobuf::MessageField<GraphProto>,
model_directory: &Path,
tensors: &mut HashMap<String, GraphTensor>,
weight_data: &mut Vec<(String, Vec<f32>)>,
context: &mut Graph,
_input_tensor_names: &HashSet<String>,
graph: &mut Graph,
weight_data: &WeightData,
search_iters: usize,
) -> Result<RuntimeBackend, String> {
let mut rt = initialize_native(context)?;
context.search(NativeRuntime::default(), 1);
graph.build_search_space::<NativeRuntime>();
let mut rt = graph.search(NativeRuntime::default(), search_iters);
// Set initializer data - these MUST exist after optimization (they're weights)
// Skip _kn variants - they might be optimized away
// Batch load reads each external file only once
for (name, floats_opt) in load_all_tensor_floats(&onnx_graph.initializer, model_directory) {
let floats = match floats_opt {
Some(f) => f,
None => continue,
};
if let Some(gt) = tensors.get(&name) {
rt.set_data(gt.id, floats);
// Load weight data after search
let label_map = CompiledGraph::build_label_map(graph);
for (label, data) in &weight_data.weights {
if let Some(&node_id) = label_map.get(label) {
rt.set_data(node_id, data.clone());
}
}
// Load constant node data, but skip _kn transposed variants
for (name, floats) in weight_data {
// Skip _kn transposed variants - might be optimized away
if name.ends_with("_kn") {
continue;
}
if let Some(gt) = tensors.get(name) {
rt.set_data(gt.id, floats.clone());
}
}
Ok(rt)
Ok(RuntimeBackend::Native(rt))
}
}
@@ -525,6 +380,94 @@ impl CompiledGraph {
Ok(())
}
/// Set input tensor data from a CPU host memory pointer (avoids Python list conversion).
/// The pointer must point to contiguous f32 data (from tensor.data_ptr() on a CPU float32 tensor).
fn set_input_from_ptr(&mut self, name: &str, ptr: u64, n_elements: usize) -> PyResult<()> {
debug_assert!(ptr != 0, "set_input_from_ptr called with null pointer");
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
})?;
let data: Vec<f32> =
unsafe { std::slice::from_raw_parts(ptr as *const f32, n_elements).to_vec() };
self.runtime.set_data(*node_id, data);
Ok(())
}
/// Set input from a CUDA device pointer. Zero-copy on device.
/// The pointer must be a valid CUDA device allocation with at least n_bytes bytes.
#[cfg(feature = "cuda")]
fn set_input_device_ptr(
&mut self,
name: &str,
device_ptr: u64,
n_bytes: usize,
) -> PyResult<()> {
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
})?;
match &mut self.runtime {
RuntimeBackend::Cuda(rt) => unsafe { rt.set_device_ptr(*node_id, device_ptr, n_bytes) },
_ => {
return Err(pyo3::exceptions::PyValueError::new_err(
"set_input_device_ptr requires CUDA backend",
));
}
}
Ok(())
}
/// Mark an input tensor as persistent (survives execute() calls).
/// Call this for weight tensors that should not be consumed after each execution.
fn persist_input(&mut self, name: &str) -> PyResult<()> {
let _node_id = *self.tensor_ids.get(name).ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
})?;
match &mut self.runtime {
#[cfg(feature = "cuda")]
RuntimeBackend::Cuda(rt) => rt.persist_hlir_node(_node_id),
RuntimeBackend::Native(_) => {} // Native: persist is handled at graph level
}
Ok(())
}
/// Set a weight tensor from a CUDA device pointer, matching by Input node label.
/// Also marks the weight as persistent. For PT2 weights (e.g. "fc1.weight").
#[cfg(feature = "cuda")]
fn set_weight_device_ptr(
&mut self,
label: &str,
device_ptr: u64,
n_bytes: usize,
) -> PyResult<()> {
let &node_id = self.label_map.get(label).ok_or_else(|| {
pyo3::exceptions::PyKeyError::new_err(format!("No Input node with label: {}", label))
})?;
match &mut self.runtime {
RuntimeBackend::Cuda(rt) => {
unsafe { rt.set_device_ptr(node_id, device_ptr, n_bytes) };
rt.persist_hlir_node(node_id);
}
_ => {
return Err(pyo3::exceptions::PyValueError::new_err(
"set_weight_device_ptr requires CUDA backend",
));
}
}
Ok(())
}
/// Set a weight tensor from a CPU host pointer, matching by Input node label.
fn set_weight_from_ptr(&mut self, label: &str, ptr: u64, n_elements: usize) -> PyResult<()> {
debug_assert!(ptr != 0, "set_weight_from_ptr called with null pointer");
let &node_id = self.label_map.get(label).ok_or_else(|| {
pyo3::exceptions::PyKeyError::new_err(format!("No Input node with label: {}", label))
})?;
let data: Vec<f32> =
unsafe { std::slice::from_raw_parts(ptr as *const f32, n_elements).to_vec() };
self.runtime.set_data(node_id, data);
Ok(())
}
/// Execute the graph.
fn run(&mut self) {
self.runtime.execute(&self.graph.dyn_map);
@@ -537,7 +480,7 @@ impl CompiledGraph {
})
}
/// Get output tensor data by name.
/// Get output tensor data by name (copies to host).
fn get_output(&self, name: &str) -> PyResult<Vec<f32>> {
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
@@ -547,4 +490,25 @@ impl CompiledGraph {
})?;
Ok(self.runtime.get_f32(*node_id))
}
/// Copy output tensor data directly to a CUDA device pointer (DtoD).
/// Avoids the DtoH + HtoD round-trip of get_output() + .to(device).
#[cfg(feature = "cuda")]
fn copy_output_to_device_ptr(&self, name: &str, dest_ptr: u64, n_bytes: usize) -> PyResult<()> {
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
"Unknown output tensor: {}",
name
))
})?;
match &self.runtime {
RuntimeBackend::Cuda(rt) => {
unsafe { rt.copy_output_to_device_ptr(*node_id, dest_ptr, n_bytes) };
Ok(())
}
_ => Err(pyo3::exceptions::PyValueError::new_err(
"copy_output_to_device_ptr requires CUDA backend",
)),
}
}
}

View File

@@ -1,5 +1,6 @@
mod compiled_graph;
mod dispatch;
mod onnx_translator;
mod ops_parse;
mod runtime;
mod util;
@@ -12,12 +13,9 @@ mod pt2_util;
mod translator;
use compiled_graph::CompiledGraph;
use onnx_protobuf::ModelProto;
use protobuf::Message;
use pt2_compiled_model::compile_pt2;
use pt2_compiled_model::process_pt2;
use pyo3::prelude::*;
use std::fs;
use std::path::Path;
use std::collections::HashMap;
fn validate_backend(backend: &str) -> PyResult<()> {
match backend {
@@ -48,46 +46,28 @@ fn validate_backend(backend: &str) -> PyResult<()> {
}
#[pyfunction]
#[pyo3(signature = (path, backend="native"))]
fn process_onnx(path: &str, backend: &str) -> PyResult<CompiledGraph> {
#[pyo3(signature = (path, backend="native", search_iters=10, weight_device_ptrs=None))]
fn process_onnx(
path: &str,
backend: &str,
search_iters: usize,
weight_device_ptrs: Option<HashMap<String, (u64, usize)>>,
) -> PyResult<CompiledGraph> {
validate_backend(backend)?;
parse_onnx(path, backend).map_err(pyo3::exceptions::PyRuntimeError::new_err)
}
fn parse_onnx(path: &str, backend: &str) -> Result<CompiledGraph, String> {
let data = fs::read(path).map_err(|e| format!("Failed to read file: {}", e))?;
let model_directory = Path::new(path).parent().unwrap_or(Path::new("."));
let model = ModelProto::parse_from_bytes(&data)
.map_err(|e| format!("Failed to parse Onnx Model: {}", e))?;
let opset_version = model
.opset_import
.iter()
.find(|entry| entry.domain.is_empty())
.map(|entry| entry.version);
match opset_version {
Some(20) => {}
Some(v) => {
return Err(format!(
"Unsupported ONNX opset version {v}. Only opset 20 is supported."
));
}
None => {
return Err(
"No ONNX opset version found in model. Only opset 20 is supported.".to_string(),
);
}
}
CompiledGraph::parse_graph(model, model_directory, backend)
onnx_translator::compile_onnx(
path,
backend,
weight_device_ptrs.unwrap_or_default(),
search_iters,
)
.map_err(pyo3::exceptions::PyRuntimeError::new_err)
}
#[pymodule]
fn luminal(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(process_onnx, m)?)?;
m.add_function(wrap_pyfunction!(compile_pt2, m)?)?;
m.add_function(wrap_pyfunction!(process_pt2, m)?)?;
m.add_class::<CompiledGraph>()?;
Ok(())
}

View File

@@ -0,0 +1,283 @@
use luminal::{
prelude::{
tracing::{Level, span, trace},
*,
},
shape::Expression,
};
use onnx_protobuf::ModelProto;
use protobuf::Message;
use std::{
collections::{HashMap, HashSet},
fs,
path::Path,
};
use crate::{
compiled_graph::{CompiledGraph, GraphTranslation, WeightData},
dispatch::process_onnx_nodes,
util::{
DimParamMap, get_shape_for_onnx_value, get_shape_for_onnx_value_expr,
load_all_tensor_floats, load_initializer_as_f32,
},
};
/// Load, validate, translate, and compile an ONNX model.
///
/// This is the ONNX counterpart of `pt2_compiled_model::compile_pt2()`.
pub fn compile_onnx(
path: &str,
backend: &str,
weight_device_ptrs: HashMap<String, (u64, usize)>,
search_iters: usize,
) -> Result<CompiledGraph, String> {
let data = fs::read(path).map_err(|e| format!("Failed to read file: {}", e))?;
let model_directory = Path::new(path).parent().unwrap_or(Path::new("."));
let model = ModelProto::parse_from_bytes(&data)
.map_err(|e| format!("Failed to parse ONNX model: {}", e))?;
let opset_version = model
.opset_import
.iter()
.find(|entry| entry.domain.is_empty())
.map(|entry| entry.version);
match opset_version {
Some(20) => {}
Some(v) => {
return Err(format!(
"Unsupported ONNX opset version {v}. Only opset 20 is supported."
));
}
None => {
return Err(
"No ONNX opset version found in model. Only opset 20 is supported.".to_string(),
);
}
}
let (translation, mut weights) = translate_onnx(model, model_directory)?;
weights.device_ptrs = weight_device_ptrs;
CompiledGraph::parse_graph(translation, weights, backend, search_iters)
}
/// Translate an ONNX model into a format-neutral GraphTranslation + WeightData.
pub fn translate_onnx(
model: ModelProto,
model_directory: &Path,
) -> Result<(GraphTranslation, WeightData), String> {
let _span = span!(Level::TRACE, "ONNX Graph Translation").entered();
let onnx_graph = &model.graph;
let mut cx = Graph::new();
let mut tensors: HashMap<String, GraphTensor> = HashMap::new();
// Dynamic dimension tracking
let mut dim_param_map: DimParamMap = HashMap::new();
let mut next_char = 'a';
// Separate initializers (weights) from true user inputs
let initializer_names: HashSet<&str> = onnx_graph
.initializer
.iter()
.map(|t| t.name.as_str())
.collect();
let input_names: Vec<String> = onnx_graph
.input
.iter()
.filter(|inp| !initializer_names.contains(inp.name.as_str()))
.map(|inp| inp.name.clone())
.collect();
// Create input tensors with dynamic dimension support
for input in &onnx_graph.input {
let shape_exprs = get_shape_for_onnx_value_expr(input, &mut dim_param_map, &mut next_char);
if shape_exprs.is_empty() {
let shape = get_shape_for_onnx_value(input);
if shape.is_empty() {
trace!("Input {} skipped because it is empty", input.name.clone());
continue;
}
let tensor = cx.named_tensor(input.name.clone(), shape);
trace!("Input {} added to tensors", input.name.clone());
tensors.insert(input.name.clone(), tensor);
continue;
}
let tensor = cx.named_tensor(input.name.clone(), shape_exprs);
trace!("Input {} added to tensors", input.name.clone());
tensors.insert(input.name.clone(), tensor);
}
// Create initializer (weight) tensors
for init in &onnx_graph.initializer {
if !tensors.contains_key(&init.name) {
let mut shape: Vec<usize> = init.dims.iter().map(|&d| d as usize).collect();
if shape.is_empty() {
shape = vec![1];
}
let tensor = cx.named_tensor(init.name.clone(), shape);
tensors.insert(init.name.clone(), tensor);
}
}
// Load small constants for constant folding
let mut known_values: HashMap<String, Vec<f32>> = HashMap::new();
for init in &onnx_graph.initializer {
let n_elements: usize = init
.dims
.iter()
.map(|&d| d as usize)
.product::<usize>()
.max(1);
if n_elements <= 32 {
if let Some(floats) = load_initializer_as_f32(init) {
known_values.insert(init.name.clone(), floats);
} else {
panic!("Unable to load initializer values for {:?}", init.name);
}
}
}
// Shape expressions for propagating symbolic shapes through ONNX graphs
let mut shape_exprs: HashMap<String, Vec<Expression>> = HashMap::new();
// Accumulates constant node data from process_onnx_nodes
let mut constant_data: Vec<(String, Vec<f32>)> = Vec::new();
// Process computation nodes
process_onnx_nodes(
&onnx_graph.node,
&mut tensors,
&mut cx,
&mut constant_data,
&mut known_values,
&mut shape_exprs,
)
.map_err(|e| format!("process_onnx_nodes failed: {}", e))?;
// Mark weight/constant tensors as persistent so their buffers survive execute()
for (name, gt) in &tensors {
if !input_names.contains(name) {
gt.persist();
}
}
// Mark graph outputs (must happen before build_search_space)
let mut output_names = Vec::new();
let mut output_shape_exprs = Vec::new();
for output_vi in &onnx_graph.output {
if let Some(&gt) = tensors.get(&output_vi.name) {
// Force contiguous if the shape tracker is a non-contiguous view
let gt = if gt.shape != gt.shape.contiguous() {
let contiguous = gt * 1.0;
tensors.insert(output_vi.name.clone(), contiguous);
contiguous
} else {
gt
};
gt.output();
let dims = gt.dims();
output_shape_exprs.push(dims.clone());
let shape: Vec<usize> = dims.iter().map(|d| d.to_usize().unwrap_or(1)).collect();
if shape.is_empty() {
return Err(format!(
"Output tensor '{}' has no shape information in the ONNX model",
output_vi.name
));
}
output_names.push(output_vi.name.clone());
}
}
// Set initial dynamic dimension values from example input shapes
let has_dynamic = !dim_param_map.is_empty();
if has_dynamic {
for input in &onnx_graph.input {
if initializer_names.contains(input.name.as_str()) {
continue;
}
let concrete_shape = get_shape_for_onnx_value(input);
let expr_shape =
get_shape_for_onnx_value_expr(input, &mut dim_param_map, &mut next_char);
for (expr, concrete) in expr_shape.iter().zip(concrete_shape.iter()) {
if expr.to_usize().is_none()
&& let Some(ch) = dim_param_map
.values()
.find(|&&ch| Expression::from(ch) == *expr)
{
cx.set_dim(*ch, *concrete);
}
}
}
}
// Build weight data: initializers + constants from process_onnx_nodes
let mut weights: Vec<(String, Vec<f32>)> = Vec::new();
for (name, floats) in load_all_tensor_floats(&onnx_graph.initializer, model_directory) {
if let Some(f) = floats {
weights.push((name, f));
}
}
weights.extend(constant_data);
// Build tensor sizes for CUDA dummy data allocation
let mut tensor_sizes: HashMap<String, usize> = HashMap::new();
for input in &onnx_graph.input {
if !initializer_names.contains(input.name.as_str()) {
let shape = get_shape_for_onnx_value(input);
let n: usize = shape.iter().product::<usize>().max(1);
tensor_sizes.insert(input.name.clone(), n);
}
}
for init in &onnx_graph.initializer {
let n: usize = init
.dims
.iter()
.map(|&d| d as usize)
.product::<usize>()
.max(1);
tensor_sizes.insert(init.name.clone(), n);
}
for (name, data) in &weights {
if !tensor_sizes.contains_key(name) {
tensor_sizes.insert(name.clone(), data.len());
}
}
// Collect tensor name → NodeIndex mapping
let tensor_ids: HashMap<String, NodeIndex> = tensors
.iter()
.map(|(name, gt)| (name.clone(), gt.id))
.collect();
// Build input_shape_exprs for user inputs (needed for auto-dim detection)
let input_shape_exprs: Vec<Vec<Expression>> = input_names
.iter()
.map(|name| {
if let Some(&gt) = tensors.get(name) {
gt.dims()
} else {
vec![]
}
})
.collect();
let translation = GraphTranslation {
graph: cx,
tensor_ids,
input_names,
output_names,
output_shape_exprs,
input_shape_exprs,
dim_param_map,
};
let weight_data = WeightData {
weights,
tensor_sizes,
device_ptrs: HashMap::new(),
};
Ok((translation, weight_data))
}

View File

@@ -1,20 +1,16 @@
use luminal::graph::Graph as LuminalGraph;
use luminal::prelude::*;
use pyo3::prelude::*;
use std::collections::HashMap;
#[cfg(feature = "cuda")]
use luminal_cuda_lite::cudarc::driver::CudaContext;
#[cfg(feature = "cuda")]
use luminal_cuda_lite::runtime::CudaRuntime;
use crate::compiled_graph::CompiledGraph;
use crate::compiled_graph::{CompiledGraph, GraphTranslation, WeightData};
use crate::pt2_parser;
use crate::pt2_schema;
use crate::runtime::RuntimeBackend;
use crate::translator;
use crate::util::DimParamMap;
/// Pre-loaded weight/constant data paired with tensor sizes.
type PreloadResult = (Vec<(String, Vec<f32>)>, HashMap<String, usize>);
fn resolve_dim_sizes(
sizes: &[pt2_schema::DimSize],
sym_to_char: &HashMap<String, char>,
@@ -39,32 +35,55 @@ fn resolve_dim_sizes(
}
#[pyfunction]
pub fn compile_pt2(
#[pyo3(signature = (pt2_path, weights_path, backend, search_iters, weight_device_ptrs=None))]
pub fn process_pt2(
pt2_path: &str,
weights_path: &str,
backend: &str,
search_iters: usize,
weight_device_ptrs: Option<HashMap<String, (u64, usize)>>,
) -> PyResult<CompiledGraph> {
compile_pt2_inner(pt2_path, weights_path, backend, search_iters)
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e:#}")))
compile_pt2(
pt2_path,
weights_path,
backend,
search_iters,
weight_device_ptrs.unwrap_or_default(),
)
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e:#}")))
}
fn compile_pt2_inner(
fn compile_pt2(
pt2_path: &str,
weights_path: &str,
backend: &str,
search_iters: usize,
weight_device_ptrs: HashMap<String, (u64, usize)>,
) -> anyhow::Result<CompiledGraph> {
let (translation, mut weights) = translate_pt2(pt2_path, weights_path)?;
weights.device_ptrs = weight_device_ptrs;
CompiledGraph::parse_graph(translation, weights, backend, search_iters)
.map_err(|e| anyhow::anyhow!(e))
}
/// Translate a PT2 exported model into a format-neutral GraphTranslation + WeightData.
pub fn translate_pt2(
pt2_path: &str,
weights_path: &str,
) -> anyhow::Result<(GraphTranslation, WeightData)> {
let parsed = pt2_parser::parse_pt2(pt2_path)?;
let translated = translator::translate(&parsed)?;
let mut graph = translated.graph;
// Set initial dynamic dim values from symbol ranges
for (sym_name, c) in &translated.sym_map.sym_to_char {
if let Some(rc) = translated.sym_map.ranges.get(sym_name) {
graph.set_dim(*c, rc.min_val as usize);
}
}
// Compute shape expressions from PT2 tensor metadata
let output_shape_exprs: Vec<Vec<Expression>> = translated
.output_ids
.iter()
@@ -98,45 +117,6 @@ fn compile_pt2_inner(
})
.collect();
let user_input_sizes: Vec<(NodeIndex, usize)> = translated
.user_input_ids
.iter()
.map(|(name, id)| {
let meta = parsed.tensor_meta(name);
let n_elements = meta
.map(|m| {
m.sizes
.iter()
.map(|s| s.hint().unwrap_or(1) as usize)
.product()
})
.unwrap_or(1);
(*id, n_elements)
})
.collect();
let runtime = match backend {
"cpu" | "native" => {
graph.build_search_space::<NativeRuntime>();
let mut rt = graph.search(NativeRuntime::default(), search_iters);
if !weights_path.is_empty() {
load_safetensors_native(&mut rt, &graph, weights_path)?;
}
load_constants_native(&mut rt, &graph, &parsed)?;
RuntimeBackend::Native(rt)
}
"cuda" | "gpu" => init_cuda_runtime(
&mut graph,
weights_path,
&parsed,
&user_input_sizes,
search_iters,
)?,
other => {
anyhow::bail!("Unknown backend: {other}. Use 'cpu' or 'cuda'.");
}
};
// Build tensor_ids from user inputs and outputs
let mut tensor_ids: HashMap<String, NodeIndex> = HashMap::new();
for (name, id) in &translated.user_input_ids {
@@ -146,80 +126,90 @@ fn compile_pt2_inner(
tensor_ids.insert(name.clone(), *id);
}
// Resolve concrete output shapes
let output_shapes: Vec<Vec<usize>> = output_shape_exprs
.iter()
.map(|exprs| exprs.iter().map(|e| e.to_usize().unwrap_or(1)).collect())
.collect();
// Pre-load weights and compute tensor sizes for CUDA dummy data
let mut weights: Vec<(String, Vec<f32>)> = Vec::new();
let mut tensor_sizes: HashMap<String, usize> = HashMap::new();
// Load safetensors weights
if !weights_path.is_empty() {
let (st_weights, st_sizes) = preload_safetensors(&graph, weights_path)?;
weights.extend(st_weights);
tensor_sizes.extend(st_sizes);
}
// Load PT2 constants from ZIP archive
let (const_weights, const_sizes) = preload_constants(&graph, &parsed)?;
weights.extend(const_weights);
tensor_sizes.extend(const_sizes);
// Add tensor sizes from PT2 metadata for parameters/buffers not in safetensors
// (covers case when weights are loaded via device pointers after compilation)
for input_kind in parsed.classify_inputs() {
let (graph_name, original_name) = match &input_kind {
pt2_parser::InputKind::Parameter {
graph_name,
original_name,
} => (graph_name.as_str(), original_name.as_str()),
pt2_parser::InputKind::Buffer {
graph_name,
original_name,
} => (graph_name.as_str(), original_name.as_str()),
pt2_parser::InputKind::UserInput { .. } => continue,
};
// Always use authoritative sizes from model.json tensor_meta,
// even if preload_constants inserted a different (possibly stripped) size.
if let Some(meta) = parsed.tensor_meta(graph_name) {
let n: usize = meta
.sizes
.iter()
.map(|s| s.hint().unwrap_or(1) as usize)
.product();
tensor_sizes.insert(original_name.to_string(), n);
}
}
// Add user input sizes
for (name, _id) in &translated.user_input_ids {
if !tensor_sizes.contains_key(name)
&& let Some(meta) = parsed.tensor_meta(name)
{
let n: usize = meta
.sizes
.iter()
.map(|s| s.hint().unwrap_or(1) as usize)
.product();
tensor_sizes.insert(name.clone(), n);
}
}
// Build dim_param_map from sym_map
let dim_param_map: DimParamMap = translated.sym_map.sym_to_char;
Ok(CompiledGraph {
let translation = GraphTranslation {
graph,
runtime,
tensor_ids,
input_names,
output_names,
output_shapes,
output_shape_exprs,
input_shape_exprs,
dim_param_map,
})
}
};
#[cfg(feature = "cuda")]
fn init_cuda_runtime(
graph: &mut LuminalGraph,
weights_path: &str,
parsed: &pt2_parser::ParsedPT2,
user_input_sizes: &[(NodeIndex, usize)],
search_iters: usize,
) -> anyhow::Result<RuntimeBackend> {
let cuda_ctx =
CudaContext::new(0).map_err(|e| anyhow::anyhow!("CUDA context init failed: {e}"))?;
let stream = cuda_ctx.default_stream();
let weight_data = WeightData {
weights,
tensor_sizes,
device_ptrs: HashMap::new(),
};
graph.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
// Phase 1: Set ALL input nodes to safe dummy data (1.0) for search profiling.
// Real weights/constants may contain -inf (e.g. causal attention mask) which
// produce NaN in intermediate computations (e.g. -inf - (-inf) = NaN in softmax
// decomposition), causing the search's has_nan_outputs check to reject ALL
// candidates. We load real data only AFTER the search completes.
set_all_inputs_dummy_cuda(&mut rt, graph, weights_path, parsed, user_input_sizes)?;
let mut rt = graph.search(rt, search_iters);
if !weights_path.is_empty() {
load_safetensors_cuda(&mut rt, graph, weights_path)?;
}
load_constants_cuda(&mut rt, graph, parsed)?;
Ok(RuntimeBackend::Cuda(Box::new(rt)))
}
#[cfg(not(feature = "cuda"))]
fn init_cuda_runtime(
_graph: &mut LuminalGraph,
_weights_path: &str,
_parsed: &pt2_parser::ParsedPT2,
_user_input_sizes: &[(NodeIndex, usize)],
_search_iters: usize,
) -> anyhow::Result<RuntimeBackend> {
anyhow::bail!("CUDA support not compiled. Rebuild with --features cuda")
Ok((translation, weight_data))
}
// ---------------------------------------------------------------------------
// Weight loading
// Weight pre-loading helpers
// ---------------------------------------------------------------------------
fn load_safetensors_impl(
cx: &LuminalGraph,
file_path: &str,
mut set_data: impl FnMut(NodeIndex, Vec<f32>),
) -> anyhow::Result<()> {
/// Pre-load all safetensors weights that match Input nodes in the graph.
/// Returns (weight data, tensor sizes for all tensors in the file).
fn preload_safetensors(graph: &Graph, file_path: &str) -> anyhow::Result<PreloadResult> {
use memmap2::MmapOptions;
use safetensors::SafeTensors;
use std::fs::File;
@@ -229,95 +219,78 @@ fn load_safetensors_impl(
let st = SafeTensors::deserialize(&mmap)
.map_err(|e| anyhow::anyhow!("SafeTensors deserialize error: {e}"))?;
for node in cx.graph.node_indices() {
if let Some(input) = (*cx.graph[node])
let mut weights = Vec::new();
let mut sizes = HashMap::new();
// Get sizes for ALL tensors in the file (for dummy data allocation)
for (name, info) in st.tensors() {
let n: usize = info.shape().iter().product();
sizes.insert(name.to_string(), n);
}
// Load weight data for Input nodes that match safetensors tensor names
for node_id in graph.graph.node_indices() {
if let Some(input) = (*graph.graph[node_id])
.as_any()
.downcast_ref::<luminal::hlir::Input>()
&& let Ok(tensor) = st.tensor(&input.label)
{
let f32s = bytes_to_f32(tensor.data(), safetensors_dtype_to_pt2(tensor.dtype()));
set_data(node, f32s);
weights.push((input.label.clone(), f32s));
}
}
Ok(())
Ok((weights, sizes))
}
fn load_safetensors_native(
rt: &mut NativeRuntime,
cx: &LuminalGraph,
file_path: &str,
) -> anyhow::Result<()> {
load_safetensors_impl(cx, file_path, |node, data| rt.set_data(node, data))
}
#[cfg(feature = "cuda")]
fn load_safetensors_cuda(
rt: &mut CudaRuntime,
cx: &LuminalGraph,
file_path: &str,
) -> anyhow::Result<()> {
load_safetensors_impl(cx, file_path, |node, data| rt.set_data(node, data))
}
/// Set ALL input nodes to dummy 1.0 data for safe CUDA search profiling.
#[cfg(feature = "cuda")]
fn set_all_inputs_dummy_cuda(
rt: &mut CudaRuntime,
cx: &LuminalGraph,
weights_path: &str,
/// Pre-load all PT2 constants from the ZIP archive.
/// Returns (constant data, tensor sizes for all constants).
fn preload_constants(
_graph: &Graph,
parsed: &pt2_parser::ParsedPT2,
user_input_sizes: &[(NodeIndex, usize)],
) -> anyhow::Result<()> {
use memmap2::MmapOptions;
use safetensors::SafeTensors;
use std::fs::File;
) -> anyhow::Result<PreloadResult> {
let constants_config = match &parsed.constants_config {
Some(c) => c,
None => return Ok((Vec::new(), HashMap::new())),
};
let mut label_sizes: HashMap<String, usize> = HashMap::new();
let mut weights = Vec::new();
let mut sizes = HashMap::new();
if !weights_path.is_empty() {
let f = File::open(weights_path)?;
let mmap = unsafe { MmapOptions::new().map(&f)? };
let st = SafeTensors::deserialize(&mmap)
.map_err(|e| anyhow::anyhow!("SafeTensors deserialize error: {e}"))?;
for (name, info) in st.tensors() {
let n: usize = info.shape().iter().product();
label_sizes.insert(name.to_string(), n);
}
}
for (name, entry) in &constants_config.config {
let n: usize = entry
.tensor_meta
.sizes
.iter()
.map(|s| s.hint().unwrap_or(1) as usize)
.product();
sizes.insert(name.clone(), n);
if let Some(cc) = &parsed.constants_config {
for (name, entry) in &cc.config {
let n: usize = entry
.tensor_meta
.sizes
.iter()
.map(|s| s.hint().unwrap_or(1) as usize)
.product();
label_sizes.insert(name.clone(), n);
}
}
for node_id in cx.graph.node_indices() {
if let Some(input) = (*cx.graph[node_id])
.as_any()
.downcast_ref::<luminal::hlir::Input>()
{
if let Some(&n) = label_sizes.get(&input.label) {
if n > 0 {
rt.set_data(node_id, vec![1.0f32; n]);
}
let raw_bytes = match pt2_parser::read_constant_bytes(
&parsed.pt2_path,
&parsed.archive_prefix,
entry,
) {
Ok(b) => b,
Err(e) => {
eprintln!(
"[luminal] Warning: failed to load constant '{}': {:#}",
name, e
);
continue;
}
}
};
let f32_data = bytes_to_f32(&raw_bytes, entry.tensor_meta.dtype);
weights.push((name.clone(), f32_data));
}
for &(id, n_elements) in user_input_sizes {
rt.set_data(id, vec![1.0f32; n_elements]);
}
Ok(())
Ok((weights, sizes))
}
// ---------------------------------------------------------------------------
// Byte conversion helpers
// ---------------------------------------------------------------------------
/// Convert safetensors Dtype to PT2 dtype number.
fn safetensors_dtype_to_pt2(dtype: safetensors::Dtype) -> u32 {
match dtype {
@@ -381,60 +354,3 @@ fn bytes_to_f32(bytes: &[u8], dtype: u32) -> Vec<f32> {
}
}
}
fn load_constants_impl(
cx: &LuminalGraph,
parsed: &pt2_parser::ParsedPT2,
mut set_data: impl FnMut(NodeIndex, Vec<f32>),
) -> anyhow::Result<()> {
let constants_config = match &parsed.constants_config {
Some(c) => c,
None => return Ok(()),
};
for (name, entry) in &constants_config.config {
let raw_bytes = match pt2_parser::read_constant_bytes(
&parsed.pt2_path,
&parsed.archive_prefix,
entry,
) {
Ok(b) => b,
Err(e) => {
eprintln!(
"[luminal] Warning: failed to load constant '{}': {:#}",
name, e
);
continue;
}
};
let f32_data = bytes_to_f32(&raw_bytes, entry.tensor_meta.dtype);
for node_id in cx.graph.node_indices() {
if let Some(input) = (*cx.graph[node_id])
.as_any()
.downcast_ref::<luminal::hlir::Input>()
&& input.label == *name
{
set_data(node_id, f32_data.clone());
}
}
}
Ok(())
}
fn load_constants_native(
rt: &mut NativeRuntime,
cx: &LuminalGraph,
parsed: &pt2_parser::ParsedPT2,
) -> anyhow::Result<()> {
load_constants_impl(cx, parsed, |node, data| rt.set_data(node, data))
}
#[cfg(feature = "cuda")]
fn load_constants_cuda(
rt: &mut CudaRuntime,
cx: &LuminalGraph,
parsed: &pt2_parser::ParsedPT2,
) -> anyhow::Result<()> {
load_constants_impl(cx, parsed, |node, data| rt.set_data(node, data))
}

View File

@@ -79,11 +79,3 @@ pub fn finalize_cuda(context: &mut Graph, rt: CudaRuntime) -> RuntimeBackend {
let optimized_rt = context.search(rt, 10);
RuntimeBackend::Cuda(Box::new(optimized_rt))
}
/// Initialize a native (CPU) runtime using single-phase approach.
/// NativeRuntime validates Input nodes, so we must search first, then set data.
pub fn initialize_native(context: &mut Graph) -> Result<RuntimeBackend, String> {
context.build_search_space::<NativeRuntime>();
let rt = context.search(NativeRuntime::default(), 10);
Ok(RuntimeBackend::Native(rt))
}

View File

@@ -434,18 +434,6 @@ pub fn load_initializer_as_f32(init: &onnx_protobuf::TensorProto) -> Option<Vec<
}
}
/// Transpose weight data from [rows, cols] to [cols, rows] row-major layout
#[cfg(feature = "cuda")]
pub fn transpose_weight_data(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
let mut transposed = vec![0.0f32; rows * cols];
for r in 0..rows {
for c in 0..cols {
transposed[c * rows + r] = data[r * cols + c];
}
}
transposed
}
/// Get an integer attribute from a node, with a default value
pub fn get_int_attr(node: &NodeProto, name: &str, default: i64) -> i64 {
for attr in &node.attribute {

View File

@@ -2,11 +2,16 @@
# Import Python components
from .compiled_model import CompiledModel
from .main import luminal_backend
# Import Rust extension components (built by maturin)
# These are available directly in the package namespace
from .luminal import process_onnx, CompiledGraph, compile_pt2
from .luminal import CompiledGraph, process_onnx, process_pt2
from .main import luminal_backend
# Register DynamicCache pytree serialization once at import time
from .cache_utils import _register_cache_serialization
_register_cache_serialization()
# Re-export everything for clean package interface
__all__ = [
@@ -14,5 +19,5 @@ __all__ = [
"luminal_backend",
"process_onnx",
"CompiledGraph",
"compile_pt2",
"process_pt2",
]

View File

@@ -8,17 +8,27 @@ import torch
class CompiledModel:
"""Wrapper around CompiledGraph that handles PyTorch tensor conversion."""
def __init__(self, graph_result):
def __init__(
self, graph_result, weight_refs=None, input_names=None, user_indices=None
):
"""Initialize with a compiled CompiledGraph from Rust.
Args:
graph_result: The CompiledGraph from luminal_python.process_onnx() or compile_pt2()
graph_result: The CompiledGraph from luminal_python.process_onnx() or process_pt2()
weight_refs: List of PyTorch tensors to keep alive (prevents GC of shared weights)
input_names: Override for user input names. If None, uses graph_result.input_names.
user_indices: When torch.compile lifts model parameters into extra args,
this tells __call__ which arg positions are actual user inputs.
None means all args are user inputs (PT2 path).
"""
self._graph = graph_result
self._input_names = graph_result.input_names
self._input_names = input_names or graph_result.input_names
self._output_names = graph_result.output_names
self._output_shapes = graph_result.output_shapes
self._has_dynamic_dims = getattr(graph_result, "has_dynamic_dims", False)
self._weight_refs = weight_refs or []
self._user_indices = user_indices
self._is_cuda = graph_result.backend == "cuda"
def set_dim(self, param_name: str, value: int) -> None:
"""Set a dynamic dimension value by its param name."""
@@ -36,29 +46,42 @@ class CompiledModel:
"""Execute the compiled model with PyTorch tensor inputs.
Args:
*inputs: PyTorch tensors matching the model's input signature
*inputs: PyTorch tensors. When torch.compile lifts model parameters,
this includes both weights and user inputs. user_indices filters
to just the user inputs.
Returns:
Tuple of PyTorch tensors containing the model outputs
"""
if len(inputs) != len(self._input_names):
raise ValueError(
f"Expected {len(self._input_names)} inputs, got {len(inputs)}"
)
# Extract user inputs (torch.compile may pass lifted weights as extra args)
if self._user_indices is not None:
user_inputs = [inputs[i] for i in self._user_indices]
else:
if len(inputs) != len(self._input_names):
raise ValueError(
f"Expected {len(self._input_names)} inputs, got {len(inputs)}"
)
user_inputs = inputs
input_device = inputs[0].device if inputs else torch.device("cpu")
# Auto-detect dynamic dims from input shapes
if self._has_dynamic_dims:
input_shapes = [list(t.shape) for t in inputs]
input_shapes = [list(t.shape) for t in user_inputs]
self._graph.auto_set_dims_from_input_shapes(input_shapes)
# Set input data
for name, tensor in zip(self._input_names, inputs):
# Convert to contiguous float32 numpy array (move to CPU first for CUDA tensors)
arr = tensor.detach().cpu().contiguous().float().numpy()
data = arr.flatten().tolist()
self._graph.set_input(name, data)
# Set user input data via pointer (avoids Python list conversion).
# For CUDA inputs, keep references alive so the caching allocator doesn't
# recycle GPU memory before run() reads the pointers.
_input_refs = []
for name, tensor in zip(self._input_names, user_inputs):
if self._is_cuda and tensor.is_cuda:
t = tensor.detach().contiguous().float()
self._graph.set_input_device_ptr(name, t.data_ptr(), t.numel() * 4)
_input_refs.append(t)
else:
t = tensor.detach().cpu().contiguous().float()
self._graph.set_input_from_ptr(name, t.data_ptr(), t.numel())
# Run the graph
self._graph.run()
@@ -69,16 +92,22 @@ class CompiledModel:
else:
output_shapes = self._output_shapes
# Get outputs and convert back to PyTorch tensors on the same device as inputs
# Get outputs and convert back to PyTorch tensors on the same device as inputs.
# For CUDA: DtoD copy avoids the DtoH + HtoD round-trip.
outputs = []
for name, shape in zip(self._output_names, output_shapes):
data = self._graph.get_output(name)
tensor = (
torch.tensor(data, dtype=torch.float32)
.reshape(tuple(shape))
.to(input_device)
)
outputs.append(tensor)
if self._is_cuda and hasattr(self._graph, "copy_output_to_device_ptr"):
out = torch.empty(shape, dtype=torch.float32, device=input_device)
self._graph.copy_output_to_device_ptr(
name, out.data_ptr(), out.numel() * 4
)
else:
data = self._graph.get_output(name)
out = (
torch.tensor(data, dtype=torch.float32)
.reshape(tuple(shape))
.to(input_device)
)
outputs.append(out)
# Return as a tuple (TorchDynamo expects tuple return from backend callables)
return tuple(outputs)

View File

@@ -6,10 +6,61 @@ import torch._dynamo
import luminal
from .cache_utils import _register_cache_serialization
from .compiled_model import CompiledModel
# ---------------------------------------------------------------------------
# Shared helpers (used by both ONNX and PT2 paths)
# ---------------------------------------------------------------------------
def _detect_backend(example_inputs):
"""Detect backend from input device. Returns 'cuda' or 'native'."""
device = example_inputs[0].device if example_inputs else torch.device("cpu")
return "cuda" if device.type == "cuda" else "native"
def _collect_weight_pointers(weights, backend):
"""Partition weight tensors into CUDA device pointers and CPU host pointers.
Args:
weights: dict of name -> torch.Tensor
backend: "cuda", "gpu", "cpu", or "native"
Returns:
(keep_alive, device_ptrs, cpu_ptrs) where:
- keep_alive: list[Tensor] to prevent GC of shared weight memory
- device_ptrs: {name: (device_ptr, n_bytes)}
- cpu_ptrs: {name: (host_ptr, n_elements)}
"""
keep_alive = []
device_ptrs = {}
cpu_ptrs = {}
for name, tensor in weights.items():
t = tensor.detach().contiguous()
if t.dtype != torch.float32:
t = t.float()
if backend in ("cuda", "gpu") and t.is_cuda:
keep_alive.append(t)
device_ptrs[name] = (t.data_ptr(), t.numel() * 4)
else:
t = t.cpu() if t.is_cuda else t
keep_alive.append(t)
cpu_ptrs[name] = (t.data_ptr(), t.numel())
return keep_alive, device_ptrs, cpu_ptrs
def _load_cpu_weights(compiled_graph, cpu_weights):
"""Load CPU weight data into a compiled graph after Rust compilation."""
for name, (ptr, n_elements) in cpu_weights.items():
compiled_graph.set_weight_from_ptr(name, ptr, n_elements)
# ---------------------------------------------------------------------------
# torch.compile backend entry point
# ---------------------------------------------------------------------------
def luminal_backend(gm, example_inputs, options=None):
"""Luminal torch.compile backend.
@@ -30,17 +81,39 @@ def luminal_backend(gm, example_inputs, options=None):
)
opset = options.get("opset", 20)
_register_cache_serialization()
device = example_inputs[0].device if example_inputs else torch.device("cpu")
backend = "cuda" if device.type == "cuda" else "native"
backend = _detect_backend(example_inputs)
if export_mode == "pt2":
return _compile_pt2(gm, example_inputs, backend)
return _compile_onnx(gm, example_inputs, backend, opset=opset)
# ---------------------------------------------------------------------------
# ONNX compilation path
# ---------------------------------------------------------------------------
def _compile_onnx(gm, example_inputs, backend, opset=20):
"""ONNX compilation path."""
# Identify weight vs user inputs from FX graph placeholders.
# torch.compile lifts model parameters into graph inputs — we detect them by name prefix.
weight_tensors = {} # onnx_name -> tensor
user_indices = []
ph_idx = 0
for node in gm.graph.nodes:
if node.op == "placeholder":
onnx_name = f"input_{ph_idx}"
if node.name.startswith(("l_self_", "l_model_", "l__self_")):
weight_tensors[onnx_name] = example_inputs[ph_idx]
else:
user_indices.append(ph_idx)
ph_idx += 1
# Collect weight pointers for Rust (avoids duplicate GPU buffer allocation)
weight_refs, weight_device_ptrs, cpu_weights = _collect_weight_pointers(
weight_tensors, backend
)
tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
tmp_path = tmp.name
tmp.close()
@@ -54,11 +127,29 @@ def _compile_onnx(gm, example_inputs, backend, opset=20):
input_names=[f"input_{i}" for i in range(len(example_inputs))],
)
result = luminal.process_onnx(tmp_path, backend)
result = luminal.process_onnx(
tmp_path, backend, weight_device_ptrs=weight_device_ptrs
)
finally:
os.unlink(tmp_path)
compiled = CompiledModel(result)
return compiled
# Load CPU weights after compilation
_load_cpu_weights(result, cpu_weights)
# Only expose user input names to CompiledModel (weights are pre-loaded).
# user_indices tells __call__ which args from torch.compile are real user inputs.
user_input_names = [f"input_{i}" for i in user_indices]
return CompiledModel(
result,
weight_refs=weight_refs,
input_names=user_input_names,
user_indices=user_indices,
)
# ---------------------------------------------------------------------------
# PT2 compilation path (delegates to pt2 module)
# ---------------------------------------------------------------------------
def _compile_pt2(gm, example_inputs, backend):

View File

@@ -11,12 +11,10 @@ import shutil
import tempfile
import torch
from safetensors.torch import save_file
from .cache_utils import _register_cache_serialization
from .compiled_model import CompiledModel
from .luminal import compile_pt2 as _compile_pt2_rust
from .luminal import process_pt2
from .main import _collect_weight_pointers, _detect_backend, _load_cpu_weights
# ---------------------------------------------------------------------------
# Helpers
@@ -34,37 +32,61 @@ def _export_kwargs():
return kwargs
def _save_and_compile(ep, backend, search_iterations):
"""Save ExportedProgram + weights to temp files, compile via Rust, return CompiledModel."""
tmpdir = tempfile.mkdtemp(prefix="luminal_")
def _save_and_compile(ep_or_path, backend, search_iterations, original_weights=None):
"""Compile a PT2 model via Rust, return CompiledModel.
Args:
ep_or_path: Either an ExportedProgram (will be saved to a temp file) or
a path to an already-saved .pt2 file.
original_weights: Optional dict mapping state_dict key -> original PyTorch tensor.
When provided, device pointers are taken from these tensors instead of
ep.state_dict (which torch.export may have cloned), enabling true zero-copy
sharing with the original model's GPU memory.
"""
owns_tmpdir = not isinstance(ep_or_path, str)
tmpdir = tempfile.mkdtemp(prefix="luminal_") if owns_tmpdir else None
try:
pt2_path = os.path.join(tmpdir, "model.pt2")
weights_path = os.path.join(tmpdir, "weights.safetensors")
torch.export.save(ep, pt2_path)
state_dict = {k: v.float().clone() for k, v in ep.state_dict.items()}
if state_dict:
save_file(state_dict, weights_path)
if owns_tmpdir:
pt2_path = os.path.join(tmpdir, "model.pt2")
torch.export.save(ep_or_path, pt2_path)
weight_source = (
original_weights if original_weights else ep_or_path.state_dict
)
else:
weights_path = ""
pt2_path = ep_or_path
weight_source = original_weights or {}
compiled = _compile_pt2_rust(pt2_path, weights_path, backend, search_iterations)
return CompiledModel(compiled)
# Collect weight pointers for Rust (avoids duplicate GPU buffer allocation)
keep_alive, weight_device_ptrs, cpu_weights = _collect_weight_pointers(
weight_source, backend
)
# Compile with device pointers — search uses actual weight memory (zero-copy)
compiled = process_pt2(
pt2_path, "", backend, search_iterations, weight_device_ptrs
)
# Load CPU weights after compilation
_load_cpu_weights(compiled, cpu_weights)
return CompiledModel(compiled, weight_refs=keep_alive)
finally:
shutil.rmtree(tmpdir, ignore_errors=True)
if owns_tmpdir and tmpdir:
shutil.rmtree(tmpdir, ignore_errors=True)
def _reinternalize_lifted_params(gm, example_inputs):
"""Re-internalize lifted params as buffers so torch.export sees them as model state.
torch.compile lifts model parameters out of the module and passes them as
extra elements in example_inputs. The Rust PT2 compiler expects weights in
extra elements in example_inputs. The Rust PT2 compiler may expect weights in
the .pt2 state dict, not as runtime inputs. This function reverses the
lifting by registering them as buffers and replacing the placeholder nodes
with get_attr nodes.
Returns (gm, user_inputs) where user_inputs contains only the real inputs.
Returns (gm, user_inputs, original_weights) where:
- user_inputs contains only the real inputs
- original_weights maps buffer name -> original tensor (for zero-copy device pointers)
"""
buffer_indices = []
user_indices = []
@@ -80,12 +102,15 @@ def _reinternalize_lifted_params(gm, example_inputs):
user_indices.append(placeholder_idx)
placeholder_idx += 1
original_weights = {}
if buffer_nodes:
for i, node in enumerate(buffer_nodes):
attr_name = f"_luminal_param_{i}"
gm.register_buffer(
attr_name, example_inputs[buffer_indices[i]].detach().clone()
)
# Keep a reference to the original tensor for zero-copy device pointers.
# torch.export.export may clone the registered buffer, so we bypass
# the EP's state_dict and use the originals directly.
original_weights[attr_name] = example_inputs[buffer_indices[i]]
gm.register_buffer(attr_name, example_inputs[buffer_indices[i]].detach())
with gm.graph.inserting_before(node):
new_node = gm.graph.create_node("get_attr", attr_name)
new_node.meta = node.meta.copy()
@@ -99,7 +124,7 @@ def _reinternalize_lifted_params(gm, example_inputs):
if user_indices
else list(example_inputs)
)
return gm, user_inputs
return gm, user_inputs, original_weights
# ---------------------------------------------------------------------------
@@ -121,22 +146,20 @@ def compile(
model: A PyTorch nn.Module.
example_input: Example input tensor(s) for tracing.
search_iterations: Number of optimization search iterations.
backend: "cpu" or "cuda". Auto-detected if None.
backend: "native" or "cuda". Auto-detected if None.
export_kwargs: Extra kwargs passed to torch.export.export.
dynamic_dim: Which input dimension to make dynamic.
Returns:
A CompiledModel callable.
"""
_register_cache_serialization()
if dynamic_dim is None:
dynamic_dim = "auto"
if backend is None:
backend = os.environ.get("LUMINAL_BACKEND", None)
if backend is None:
backend = "cuda" if torch.cuda.is_available() else "cpu"
backend = "cuda" if torch.cuda.is_available() else "native"
kwargs = export_kwargs or {}
extra = _export_kwargs()
@@ -191,11 +214,43 @@ def pt2_backend(gm, example_inputs, backend=None):
Usage: torch.compile(model, backend=luminal.pt2.pt2_backend)
"""
_register_cache_serialization()
import gc
if backend is None:
device = example_inputs[0].device if example_inputs else torch.device("cpu")
backend = "cuda" if device.type == "cuda" else "cpu"
backend = _detect_backend(example_inputs)
gm = gm.eval()
gm, user_inputs = _reinternalize_lifted_params(gm, example_inputs)
gm, user_inputs, original_weights = _reinternalize_lifted_params(gm, example_inputs)
ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs())
return _save_and_compile(ep, backend, 10)
# When using shared memory (original_weights), strip large weight buffers from
# the EP before saving. The Rust side uses device pointers for these weights,
# not the .pt2 file data, so serializing them is pure IO waste (~32 GB for 8B
# models). Replacing with tiny CPU scalars shrinks the .pt2 to < 1 MB.
if original_weights:
for key in list(ep._state_dict.keys()):
if key in original_weights:
orig = ep._state_dict[key]
ep._state_dict[key] = torch.zeros(1, dtype=orig.dtype, device="cpu")
del orig
# Save the exported program to disk, then free it and the traced graph module
# BEFORE Rust compilation. torch.export clones the state_dict internally, so
# holding ep alive during compilation would double the weight memory on GPU.
tmpdir = tempfile.mkdtemp(prefix="luminal_")
pt2_path = os.path.join(tmpdir, "model.pt2")
torch.export.save(ep, pt2_path)
del ep, gm
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
try:
result = _save_and_compile(
pt2_path, backend, 10, original_weights=original_weights
)
return result
finally:
shutil.rmtree(tmpdir, ignore_errors=True)

View File

@@ -0,0 +1,176 @@
"""Helpers for caching Llama 3.1-8B test artifacts under pytest cache."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import shutil
import torch
import transformers
from safetensors.torch import save_file
from transformers import AutoConfig, LlamaForCausalLM
# This code is designed to be deleted.
# We should not need to cache pt2 or onnx files to get reasonable compile performance.
MODEL_ID = "NousResearch/Meta-Llama-3.1-8B-Instruct"
INPUT_IDS_LIST = [1, 2, 3, 4]
INPUT_IDS = torch.tensor([INPUT_IDS_LIST], dtype=torch.long)
ARTIFACT_SCHEMA_VERSION = 1
ONNX_OPSET_VERSION = 20
PT2_STRICT = False
_REF_LOGITS_META_KEY = "luminal_python/llama38b_artifacts/ref_logits_v1"
_ONNX_META_KEY = "luminal_python/llama38b_artifacts/onnx_v1"
_PT2_META_KEY = "luminal_python/llama38b_artifacts/pt2_v1"
@dataclass(frozen=True)
class Llama38BArtifactBundle:
ref_logits_path: Path
onnx_path: Path | None = None
pt2_path: Path | None = None
weights_path: Path | None = None
def ensure_onnx_bundle(cache, cache_dir: Path) -> Llama38BArtifactBundle:
"""Ensure ONNX artifacts and shared reference logits exist in pytest cache."""
ref_logits_path = cache_dir / "ref_logits.pt"
onnx_dir = cache_dir / "onnx"
onnx_path = onnx_dir / "llama38b.onnx"
ref_metadata = _ref_logits_metadata()
onnx_metadata = _onnx_metadata()
needs_ref_logits = cache.get(_REF_LOGITS_META_KEY, None) != ref_metadata or not (
ref_logits_path.is_file()
)
needs_onnx = cache.get(_ONNX_META_KEY, None) != onnx_metadata or not (
onnx_path.is_file()
)
if needs_ref_logits or needs_onnx:
print(f"Generating cached ONNX artifacts for {MODEL_ID} in {cache_dir}")
if needs_ref_logits:
ref_logits_path.unlink(missing_ok=True)
if needs_onnx:
shutil.rmtree(onnx_dir, ignore_errors=True)
onnx_dir.mkdir(parents=True, exist_ok=True)
model = _load_model()
if needs_ref_logits:
ref_logits = _compute_ref_logits(model)
torch.save(ref_logits, ref_logits_path)
cache.set(_REF_LOGITS_META_KEY, ref_metadata)
if needs_onnx:
torch.onnx.export(
model,
(INPUT_IDS,),
str(onnx_path),
opset_version=ONNX_OPSET_VERSION,
input_names=["input_ids"],
output_names=["logits"],
)
cache.set(_ONNX_META_KEY, onnx_metadata)
return Llama38BArtifactBundle(ref_logits_path=ref_logits_path, onnx_path=onnx_path)
def ensure_pt2_bundle(cache, cache_dir: Path) -> Llama38BArtifactBundle:
"""Ensure PT2 artifacts and shared reference logits exist in pytest cache."""
ref_logits_path = cache_dir / "ref_logits.pt"
pt2_dir = cache_dir / "pt2"
pt2_path = pt2_dir / "llama38b.pt2"
weights_path = pt2_dir / "llama38b_weights.safetensors"
ref_metadata = _ref_logits_metadata()
pt2_metadata = _pt2_metadata()
needs_ref_logits = cache.get(_REF_LOGITS_META_KEY, None) != ref_metadata or not (
ref_logits_path.is_file()
)
needs_pt2 = cache.get(_PT2_META_KEY, None) != pt2_metadata or not (
pt2_path.is_file() and weights_path.is_file()
)
if needs_ref_logits or needs_pt2:
print(f"Generating cached PT2 artifacts for {MODEL_ID} in {cache_dir}")
if needs_ref_logits:
ref_logits_path.unlink(missing_ok=True)
if needs_pt2:
shutil.rmtree(pt2_dir, ignore_errors=True)
pt2_dir.mkdir(parents=True, exist_ok=True)
model = _load_model()
if needs_ref_logits:
ref_logits = _compute_ref_logits(model)
torch.save(ref_logits, ref_logits_path)
cache.set(_REF_LOGITS_META_KEY, ref_metadata)
if needs_pt2:
exported_program = torch.export.export(
model, (INPUT_IDS,), strict=PT2_STRICT
)
torch.export.save(exported_program, str(pt2_path))
state_dict = {
key: value.float().clone()
for key, value in exported_program.state_dict.items()
}
save_file(state_dict, str(weights_path))
cache.set(_PT2_META_KEY, pt2_metadata)
return Llama38BArtifactBundle(
ref_logits_path=ref_logits_path,
pt2_path=pt2_path,
weights_path=weights_path,
)
def _load_model() -> LlamaForCausalLM:
config = AutoConfig.from_pretrained(MODEL_ID)
config.use_cache = False
config._attn_implementation = "eager"
return LlamaForCausalLM.from_pretrained(
MODEL_ID,
config=config,
torch_dtype=torch.float32,
).eval()
def _compute_ref_logits(model: LlamaForCausalLM) -> torch.Tensor:
with torch.no_grad():
return model(INPUT_IDS).logits.clone()
def _ref_logits_metadata() -> dict[str, object]:
return {
"schema_version": ARTIFACT_SCHEMA_VERSION,
"model_id": MODEL_ID,
"input_ids": INPUT_IDS_LIST,
"device": "cpu",
"torch_dtype": "float32",
"use_cache": False,
"attn_implementation": "eager",
"torch_version": torch.__version__,
"transformers_version": transformers.__version__,
}
def _onnx_metadata() -> dict[str, object]:
return {
**_ref_logits_metadata(),
"artifact_type": "onnx",
"opset_version": ONNX_OPSET_VERSION,
}
def _pt2_metadata() -> dict[str, object]:
return {
**_ref_logits_metadata(),
"artifact_type": "pt2",
"strict": PT2_STRICT,
}

View File

@@ -10,7 +10,6 @@ import torch._dynamo
from luminal import luminal_backend
# ========== HuggingFace Qwen3ForCausalLM Tests ==========
@@ -56,12 +55,12 @@ def _run_hf_qwen3_test(config, device: torch.device, atol: float):
def test_hf_qwen3_tiny(device: torch.device):
"""HuggingFace Qwen3ForCausalLM -- tiny (64 hidden, 1 layer, ~70K params)."""
config = _make_qwen3_config(
hidden_size=64,
num_attention_heads=4,
num_key_value_heads=2,
hidden_size=32,
num_attention_heads=2,
num_key_value_heads=1,
num_hidden_layers=1,
intermediate_size=128,
vocab_size=256,
intermediate_size=64,
vocab_size=128,
)
_run_hf_qwen3_test(config, device, atol=1e-5)
@@ -161,167 +160,6 @@ def test_hf_qwen3_decode_loop_static(device: torch.device):
tokens.append(next_token)
def test_hf_qwen3_decode_loop_dynamic():
"""Decode loop with dynamic shapes -- compile once, run with varying seq_len.
Bypasses torch.compile to use luminal's dynamic dim support directly.
Exports ONNX once with dynamic_axes, then calls set_dim/set_input/run/get_output.
"""
import os
import tempfile
from transformers import Qwen3Config, Qwen3ForCausalLM
import luminal
config = Qwen3Config(
hidden_size=64,
num_attention_heads=4,
num_key_value_heads=2,
num_hidden_layers=1,
intermediate_size=128,
vocab_size=256,
max_position_embeddings=128,
use_cache=False,
attn_implementation="eager",
)
model = Qwen3ForCausalLM(config).eval()
# Export ONNX once with dynamic seq_len
dummy = torch.tensor([[1, 2, 3, 4]])
tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
tmp_path = tmp.name
tmp.close()
try:
torch.onnx.export(
model,
(dummy,),
tmp_path,
opset_version=20,
input_names=["input_ids"],
output_names=["logits"],
dynamic_axes={"input_ids": {1: "seq_len"}, "logits": {1: "seq_len"}},
)
graph = luminal.process_onnx(tmp_path, "native")
finally:
os.unlink(tmp_path)
assert graph.has_dynamic_dims, "Graph should have dynamic dims"
assert "seq_len" in graph.dim_params, f"Expected 'seq_len' in {graph.dim_params}"
tokens = [1, 2, 3, 4]
for step in range(3):
seq_len = len(tokens)
graph.set_dim("seq_len", seq_len)
# Set input as float (luminal works with f32 internally)
graph.set_input("input_ids", [float(t) for t in tokens])
graph.run()
# Get output and reshape using resolved shapes
output_shapes = graph.resolve_output_shapes()
logits_data = graph.get_output("logits")
logits = torch.tensor(logits_data, dtype=torch.float32).reshape(
output_shapes[0]
)
# Compare against PyTorch reference
input_ids = torch.tensor([tokens])
with torch.no_grad():
ref = model(input_ids)
assert torch.allclose(logits, ref.logits, atol=1e-4), (
f"step {step}: max_diff={torch.max(torch.abs(logits - ref.logits)).item():.2e}"
)
next_token = ref.logits[0, -1, :].argmax().item()
tokens.append(next_token)
def test_hf_qwen3_8b_decode_loop_dynamic():
"""Decode loop with dynamic shapes on real Qwen3-8B -- compile once, run with varying seq_len.
Full 8B model with pretrained weights, ONNX exported once with dynamic_axes
for seq_len, then decoded autoregressively without recompilation.
"""
import os
import tempfile
from transformers import AutoConfig, AutoTokenizer, Qwen3ForCausalLM
import luminal
backend = os.environ.get("LUMINAL_BACKEND", "cuda")
config = AutoConfig.from_pretrained("Qwen/Qwen3-8B")
config.use_cache = False
config._attn_implementation = "eager"
print("Loaded config")
model = Qwen3ForCausalLM.from_pretrained(
"Qwen/Qwen3-8B",
config=config,
torch_dtype=torch.float32,
).eval()
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
print("Loaded Model")
# Export ONNX once with dynamic seq_len
dummy = torch.tensor([[1, 2, 3, 4]])
tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
tmp_path = tmp.name
tmp.close()
try:
torch.onnx.export(
model,
(dummy,),
tmp_path,
opset_version=20,
input_names=["input_ids"],
output_names=["logits"],
dynamic_axes={"input_ids": {1: "seq_len"}, "logits": {1: "seq_len"}},
)
print("Exported onnx")
graph = luminal.process_onnx(tmp_path, backend)
finally:
os.unlink(tmp_path)
print("Exported Model")
assert graph.has_dynamic_dims, "Graph should have dynamic dims"
assert "seq_len" in graph.dim_params, f"Expected 'seq_len' in {graph.dim_params}"
prompt = "The capital of france is"
tokens = tokenizer.encode(prompt)
print(f"Prompt: '{prompt}' -> {len(tokens)} tokens: {tokens}")
num_generate = 3
for step in range(num_generate):
seq_len = len(tokens)
graph.set_dim("seq_len", seq_len)
graph.set_input("input_ids", [float(t) for t in tokens])
graph.run()
output_shapes = graph.resolve_output_shapes()
logits_data = graph.get_output("logits")
logits = torch.tensor(logits_data, dtype=torch.float32).reshape(
output_shapes[0]
)
# Compare against PyTorch reference
input_ids = torch.tensor([tokens])
with torch.no_grad():
ref = model(input_ids)
assert torch.allclose(logits, ref.logits, atol=1e-3), (
f"step {step}: max_diff={torch.max(torch.abs(logits - ref.logits)).item():.2e}"
)
next_token = ref.logits[0, -1, :].argmax().item()
tokens.append(next_token)
print(f"Step {step}: '{tokenizer.decode(tokens)}'")
def test_hf_qwen3_8b_full(device: torch.device):
"""HuggingFace Qwen3ForCausalLM -- full Qwen3-8B with real pretrained weights.

View File

@@ -1,21 +1,64 @@
"""Test configuration."""
# ruff: noqa: E402
import logging
import os
from pathlib import Path
import tempfile
from urllib.request import urlopen
import warnings
try:
import huggingface_hub
from transformers import logging as transformers_logging
except ImportError: # pragma: no cover - optional for non-HF test environments
huggingface_hub = None
transformers_logging = None
# Enable automatic Rust rebuilds during test development
try:
import maturin_import_hook
from maturin_import_hook.settings import MaturinSettings
import maturin_import_hook
from maturin_import_hook.settings import MaturinSettings
from maturin_import_hook.project_importer import DefaultProjectFileSearcher
backend = os.getenv("LUMINAL_BACKEND", "native").lower()
settings = MaturinSettings(features=["cuda"]) if backend == "cuda" else None
maturin_import_hook.install(settings=settings)
except ImportError:
pass # Hook not available, rebuilds will be manual
backend = os.getenv("LUMINAL_BACKEND", "native").lower()
settings = MaturinSettings(
release=(backend == "cuda"),
features=["cuda"] if backend == "cuda" else None,
skip_install=True,
)
searcher = DefaultProjectFileSearcher(
source_excluded_dir_names=(
DefaultProjectFileSearcher.DEFAULT_SOURCE_EXCLUDED_DIR_NAMES
| {".claude", "docs", ".github", "examples"}
),
)
maturin_import_hook.install(
settings=settings,
enable_automatic_installation=True,
file_searcher=searcher,
)
logging.getLogger("maturin_import_hook").disabled = True
logging.getLogger("maturin_import_hook.project_importer").disabled = True
# Silence noisy ONNX / onnxscript / httpx logging
for _logger_name in (
"onnxscript",
"onnx_ir",
"torch.onnx",
"httpx",
):
logging.getLogger(_logger_name).setLevel(logging.WARNING)
# Suppress torch.onnx diagnostics/progress output and torchvision warnings
os.environ.setdefault("TORCH_ONNX_VERBOSE", "0")
os.environ.setdefault("TORCH_ONNX_LOG_LEVEL", "ERROR")
warnings.filterwarnings("ignore", message=".*torchvision.*")
warnings.filterwarnings("ignore", module="torch.onnx")
import pytest
import torch
import torch._dynamo
from _llama38b_artifacts import ensure_onnx_bundle, ensure_pt2_bundle
torch.set_float32_matmul_precision("highest")
@@ -26,6 +69,139 @@ def device() -> torch.device:
return torch.device("cuda") if backend == "cuda" else torch.device("cpu")
@pytest.fixture(scope="session", autouse=True)
def configure_hf_test_output() -> None:
if transformers_logging is not None:
transformers_logging.disable_progress_bar()
if huggingface_hub is not None:
huggingface_hub.utils.disable_progress_bars()
@pytest.fixture
def configure_dynamo():
original_cache_size_limit = torch._dynamo.config.cache_size_limit
original_suppress_errors = torch._dynamo.config.suppress_errors
def _configure(
*, cache_size_limit: int | None = None, suppress_errors: bool | None = None
) -> None:
if cache_size_limit is not None:
torch._dynamo.config.cache_size_limit = cache_size_limit
if suppress_errors is not None:
torch._dynamo.config.suppress_errors = suppress_errors
yield _configure
torch._dynamo.config.cache_size_limit = original_cache_size_limit
torch._dynamo.config.suppress_errors = original_suppress_errors
@pytest.fixture(scope="session")
def _llama38b_cache_dir(pytestconfig: pytest.Config) -> Path:
return pytestconfig.cache.mkdir("luminal_llama38b_artifacts_v1")
@pytest.fixture(scope="session")
def _hf_multimodal_cache_dir(pytestconfig: pytest.Config) -> Path:
return pytestconfig.cache.mkdir("luminal_hf_multimodal_v1")
@pytest.fixture(scope="session")
def hf_multimodal_image_path(
pytestconfig: pytest.Config, _hf_multimodal_cache_dir: Path
) -> Path:
image_url = (
"https://huggingface.co/datasets/huggingface/documentation-images/"
"resolve/main/bee.jpg"
)
image_path = _hf_multimodal_cache_dir / "bee.jpg"
metadata_key = "luminal_python/hf_multimodal_image_v1"
metadata = {
"schema_version": 1,
"url": image_url,
"filename": image_path.name,
}
needs_download = pytestconfig.cache.get(metadata_key, None) != metadata or not (
image_path.is_file()
)
if not needs_download:
return image_path
image_path.parent.mkdir(parents=True, exist_ok=True)
tmp_path: Path | None = None
try:
with urlopen(image_url, timeout=60) as response:
with tempfile.NamedTemporaryFile(
dir=image_path.parent, delete=False
) as tmp_file:
tmp_path = Path(tmp_file.name)
while chunk := response.read(1024 * 1024):
tmp_file.write(chunk)
assert tmp_path is not None
tmp_path.replace(image_path)
except Exception:
if tmp_path is not None:
tmp_path.unlink(missing_ok=True)
raise
pytestconfig.cache.set(metadata_key, metadata)
return image_path
@pytest.fixture(scope="session")
def _llama38b_onnx_bundle(pytestconfig: pytest.Config, _llama38b_cache_dir: Path):
return ensure_onnx_bundle(pytestconfig.cache, _llama38b_cache_dir)
@pytest.fixture(scope="session")
def _llama38b_pt2_bundle(pytestconfig: pytest.Config, _llama38b_cache_dir: Path):
return ensure_pt2_bundle(pytestconfig.cache, _llama38b_cache_dir)
@pytest.fixture(scope="session")
def llama38b_ref_logits(request: pytest.FixtureRequest) -> torch.Tensor:
fixturenames = set(request.fixturenames)
uses_onnx = "llama38b_onnx_path" in fixturenames
uses_pt2 = bool({"llama38b_pt2_path", "llama38b_weights_path"} & fixturenames)
if uses_onnx and uses_pt2:
raise pytest.UsageError(
"llama38b_ref_logits cannot be requested with both ONNX and PT2 "
"artifact fixtures in the same test"
)
if uses_onnx:
bundle = request.getfixturevalue("_llama38b_onnx_bundle")
elif uses_pt2:
bundle = request.getfixturevalue("_llama38b_pt2_bundle")
else:
raise pytest.UsageError(
"llama38b_ref_logits must be requested alongside llama38b_onnx_path "
"or llama38b_pt2_path/llama38b_weights_path"
)
return torch.load(bundle.ref_logits_path, weights_only=True)
@pytest.fixture(scope="session")
def llama38b_onnx_path(_llama38b_onnx_bundle) -> Path:
assert _llama38b_onnx_bundle.onnx_path is not None
return _llama38b_onnx_bundle.onnx_path
@pytest.fixture(scope="session")
def llama38b_pt2_path(_llama38b_pt2_bundle) -> Path:
assert _llama38b_pt2_bundle.pt2_path is not None
return _llama38b_pt2_bundle.pt2_path
@pytest.fixture(scope="session")
def llama38b_weights_path(_llama38b_pt2_bundle) -> Path:
assert _llama38b_pt2_bundle.weights_path is not None
return _llama38b_pt2_bundle.weights_path
@pytest.fixture(autouse=True, scope="function")
def reset_torch_dynamo():
# We need this for two reasons

View File

@@ -1,57 +0,0 @@
"""Generate pre-computed artifacts for test_hf_llama38b_cached_onnx.
Run once:
uv run python tests/generate_llama38b_artifacts.py
Produces:
tests/llama38b.onnx — ONNX export of Llama 3.1-8B
tests/llama38b_ref_logits.pt — reference logits for input_ids=[1,2,3,4]
"""
from pathlib import Path
import torch
from transformers import AutoConfig, LlamaForCausalLM
SCRIPT_DIR = Path(__file__).resolve().parent
ONNX_PATH = SCRIPT_DIR / "llama38b.onnx"
LOGITS_PATH = SCRIPT_DIR / "llama38b_ref_logits.pt"
INPUT_IDS = torch.tensor([[1, 2, 3, 4]])
def main():
config = AutoConfig.from_pretrained("NousResearch/Meta-Llama-3.1-8B-Instruct")
config.use_cache = False
config._attn_implementation = "eager"
print("Loading model on CPU...")
model = LlamaForCausalLM.from_pretrained(
"NousResearch/Meta-Llama-3.1-8B-Instruct",
config=config,
torch_dtype=torch.float32,
).eval()
print("Computing reference logits...")
with torch.no_grad():
ref_logits = model(INPUT_IDS).logits.clone()
print(f"Reference logits shape: {ref_logits.shape}")
print(f"Saving reference logits to {LOGITS_PATH}")
torch.save(ref_logits, LOGITS_PATH)
print(f"Exporting ONNX to {ONNX_PATH}")
torch.onnx.export(
model,
(INPUT_IDS,),
str(ONNX_PATH),
opset_version=20,
input_names=["input_ids"],
output_names=["logits"],
)
print("Done.")
if __name__ == "__main__":
main()

View File

@@ -1,62 +0,0 @@
"""Generate pre-computed PT2 artifacts for test_hf_llama38b_cached.
Run once:
uv run python tests/generate_llama38b_pt2_artifacts.py
Produces:
tests/llama38b.pt2 — torch.export of Llama 3.1-8B
tests/llama38b_weights.safetensors — model weights
tests/llama38b_ref_logits.pt — reference logits for input_ids=[1,2,3,4]
(shared with ONNX artifact script)
"""
from pathlib import Path
import torch
from safetensors.torch import save_file
from transformers import AutoConfig, LlamaForCausalLM
SCRIPT_DIR = Path(__file__).resolve().parent
PT2_PATH = SCRIPT_DIR / "llama38b.pt2"
WEIGHTS_PATH = SCRIPT_DIR / "llama38b_weights.safetensors"
LOGITS_PATH = SCRIPT_DIR / "llama38b_ref_logits.pt"
INPUT_IDS = torch.tensor([[1, 2, 3, 4]])
def main():
config = AutoConfig.from_pretrained("NousResearch/Meta-Llama-3.1-8B-Instruct")
config.use_cache = False
config._attn_implementation = "eager"
print("Loading model on CPU...")
model = LlamaForCausalLM.from_pretrained(
"NousResearch/Meta-Llama-3.1-8B-Instruct",
config=config,
torch_dtype=torch.float32,
).eval()
# Generate reference logits (shared with ONNX artifact script)
if not LOGITS_PATH.exists():
print("Computing reference logits...")
with torch.no_grad():
ref_logits = model(INPUT_IDS).logits.clone()
print(f"Reference logits shape: {ref_logits.shape}")
print(f"Saving reference logits to {LOGITS_PATH}")
torch.save(ref_logits, LOGITS_PATH)
else:
print(f"Reference logits already exist at {LOGITS_PATH}, skipping")
print(f"Exporting PT2 to {PT2_PATH}")
ep = torch.export.export(model, (INPUT_IDS,), strict=False)
torch.export.save(ep, str(PT2_PATH))
print(f"Saving weights to {WEIGHTS_PATH}")
state_dict = {k: v.float().clone() for k, v in ep.state_dict.items()}
save_file(state_dict, str(WEIGHTS_PATH))
print("Done.")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,282 @@
"""Hugging Face causal-LM config option support tests.
These tests verify that luminal matches eager Hugging Face execution across
supported causal-LM config options using tiny public model definitions loaded
through AutoConfig.
"""
from __future__ import annotations
import importlib
import os
from dataclasses import dataclass
import pytest
import torch
import torch._dynamo
from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig
from transformers.generation.configuration_utils import ContinuousBatchingConfig
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from luminal import luminal_backend
# Attention implementations that require optional packages.
_ATTN_REQUIRES_PACKAGE: dict[str, str] = {
"flash_attention_2": "flash_attn",
"flash_attention_3": "flash_attn_3",
"flash_attention_4": "cutlass",
}
# Attention implementations known to be incompatible with tiny random models
# (e.g. head_dim < 16 or missing scaffolding).
_ATTN_SKIP_TINY_MODEL: set[str] = {"flex_attention", "paged_attention"}
_PAGED_ATTN_IMPLEMENTATIONS = tuple(
k for k in ALL_ATTENTION_FUNCTIONS.valid_keys() if k.startswith("paged|")
)
@dataclass(frozen=True)
class _CausalLMConfigCase:
case_id: str
model_id: str
input_ids: tuple[int, ...]
atol: float
rtol: float
_MODEL_CASES = [
_CausalLMConfigCase(
case_id="llama_3.2_1B",
model_id="meta-llama/Llama-3.2-1B",
input_ids=(1, 2, 3, 4),
atol=1e-5,
rtol=1e-5,
)
]
_ATTN_IMPLEMENTATIONS = tuple(
dict.fromkeys([None, "eager", *ALL_ATTENTION_FUNCTIONS.valid_keys()])
)
_CUDA_BACKEND_AVAILABLE = (
os.getenv("LUMINAL_BACKEND", "native").lower() == "cuda"
and torch.cuda.is_available()
)
def _attn_id(attn_impl: str | None) -> str:
return "default" if attn_impl is None else attn_impl
def _base_attn_impl(attn_impl: str | None) -> str | None:
if attn_impl is None:
return None
if attn_impl.startswith("paged|"):
return attn_impl.split("|", maxsplit=1)[1]
return attn_impl
def _attn_param(attn_impl: str | None, *, allow_paged: bool) -> pytest.ParameterSet:
marks = []
base_attn_impl = _base_attn_impl(attn_impl)
if base_attn_impl == "flash_attention_2":
marks.append(pytest.mark.skip(reason="flash_attention_2 is very slow"))
if attn_impl is not None and attn_impl.startswith("paged|") and not allow_paged:
marks.append(
pytest.mark.skip(reason=f"{attn_impl} requires continuous batching API")
)
if base_attn_impl in _ATTN_REQUIRES_PACKAGE:
pkg = _ATTN_REQUIRES_PACKAGE[base_attn_impl]
if importlib.util.find_spec(pkg) is None:
marks.append(
pytest.mark.skip(
reason=f"{attn_impl} requires package '{pkg}' which is not installed"
)
)
if base_attn_impl in _ATTN_SKIP_TINY_MODEL:
marks.append(
pytest.mark.skip(
reason=f"{attn_impl} is incompatible with tiny random test models"
)
)
kwargs = {"id": _attn_id(attn_impl)}
if marks:
kwargs["marks"] = marks
return pytest.param(attn_impl, **kwargs)
_ATTN_PREFILL_PARAMS = tuple(
_attn_param(attn_impl, allow_paged=False) for attn_impl in _ATTN_IMPLEMENTATIONS
)
_ATTN_GENERATE_BATCH_PARAMS = tuple(
_attn_param(attn_impl, allow_paged=True) for attn_impl in _ATTN_IMPLEMENTATIONS
)
def _compare_past_key_values(lhs, rhs, *, atol: float, rtol: float) -> None:
assert lhs is not None
assert rhs is not None
assert hasattr(lhs, "layers")
assert hasattr(rhs, "layers")
assert len(lhs.layers) == len(rhs.layers)
for lhs_layer, rhs_layer in zip(lhs.layers, rhs.layers):
torch.testing.assert_close(lhs_layer.keys, rhs_layer.keys, atol=atol, rtol=rtol)
torch.testing.assert_close(
lhs_layer.values, rhs_layer.values, atol=atol, rtol=rtol
)
def _instantiate_model(
model_id: str, *, use_cache: bool, attn_impl: str | None, device: torch.device
) -> AutoModelForCausalLM:
config = AutoConfig.from_pretrained(model_id)
config.use_cache = use_cache
if attn_impl is not None:
config._attn_implementation = attn_impl
return AutoModelForCausalLM.from_config(config).eval().to(device)
@pytest.mark.parametrize("model_case", _MODEL_CASES, ids=lambda case: case.case_id)
@pytest.mark.parametrize("use_cache", [False, True], ids=["no_cache", "cache"])
@pytest.mark.parametrize("attn_impl", _ATTN_PREFILL_PARAMS)
def test_hf_causal_lm_config_options_match_eager(
model_case: _CausalLMConfigCase,
use_cache: bool,
attn_impl: str | None,
device: torch.device,
configure_dynamo,
):
"""Compare luminal against eager HF across causal-LM config options."""
if use_cache:
configure_dynamo(cache_size_limit=2)
model = _instantiate_model(
model_case.model_id,
use_cache=use_cache,
attn_impl=attn_impl,
device=device,
)
input_ids = torch.tensor([model_case.input_ids], device=device)
with torch.no_grad():
eager_prefill = model(input_ids)
compiled_model = torch.compile(model, backend=luminal_backend)
with torch.no_grad():
compiled_prefill = compiled_model(input_ids)
torch.testing.assert_close(
compiled_prefill.logits,
eager_prefill.logits,
atol=model_case.atol,
rtol=model_case.rtol,
)
if not use_cache:
assert eager_prefill.past_key_values is None
assert compiled_prefill.past_key_values is None
return
assert eager_prefill.past_key_values is not None
assert compiled_prefill.past_key_values is not None
_compare_past_key_values(
compiled_prefill.past_key_values,
eager_prefill.past_key_values,
atol=model_case.atol,
rtol=model_case.rtol,
)
next_token = eager_prefill.logits[:, -1, :].argmax(dim=-1, keepdim=True)
with torch.no_grad():
eager_decode = model(next_token, past_key_values=eager_prefill.past_key_values)
compiled_decode = compiled_model(
next_token,
past_key_values=compiled_prefill.past_key_values,
)
torch.testing.assert_close(
compiled_decode.logits,
eager_decode.logits,
atol=model_case.atol,
rtol=model_case.rtol,
)
assert eager_decode.past_key_values is not None
assert compiled_decode.past_key_values is not None
_compare_past_key_values(
compiled_decode.past_key_values,
eager_decode.past_key_values,
atol=model_case.atol,
rtol=model_case.rtol,
)
@pytest.mark.parametrize("model_case", _MODEL_CASES, ids=lambda case: case.case_id)
@pytest.mark.parametrize("attn_impl", _ATTN_GENERATE_BATCH_PARAMS)
@pytest.mark.skipif(not _CUDA_BACKEND_AVAILABLE, reason="generate_batch requires CUDA")
def test_hf_generate_batch(
model_case: _CausalLMConfigCase,
attn_impl: str | None,
device: torch.device,
):
"""Compare generate_batch output for each attention variant against eager baseline."""
config = AutoConfig.from_pretrained(model_case.model_id)
config.use_cache = True
model = (
AutoModelForCausalLM.from_config(config)
.to(dtype=torch.bfloat16)
.eval()
.to(device)
)
gen_config = GenerationConfig(
do_sample=False,
max_new_tokens=5,
temperature=None,
top_p=None,
top_k=None,
)
cb_config = ContinuousBatchingConfig(
block_size=256,
use_cuda_graph=False,
)
inputs = [list(model_case.input_ids)]
# Baseline: eager generate_batch.
model.set_attn_implementation("eager")
eager_outputs = model.generate_batch(
inputs,
generation_config=gen_config,
continuous_batching_config=cb_config,
progress_bar=False,
warmup=False,
)
# Variant under test.
if attn_impl is not None:
model.set_attn_implementation(attn_impl)
variant_outputs = model.generate_batch(
inputs,
generation_config=gen_config,
continuous_batching_config=cb_config,
progress_bar=False,
warmup=False,
)
assert len(eager_outputs) == len(variant_outputs)
eager_out = next(iter(eager_outputs.values()))
variant_out = next(iter(variant_outputs.values()))
assert eager_out.error is None, f"Eager baseline failed: {eager_out.error}"
assert variant_out.error is None, f"Variant {attn_impl} failed: {variant_out.error}"
assert eager_out.generated_tokens == variant_out.generated_tokens, (
f"Token mismatch for {attn_impl}: eager={eager_out.generated_tokens} "
f"vs variant={variant_out.generated_tokens}"
)

View File

@@ -0,0 +1,168 @@
"""Hugging Face causal-LM experts backend smoke tests.
These tests load a real pretrained text-only MoE model and compare eager
PyTorch against torch.compile with the luminal backend across the standardized
`experts_implementation` backends.
"""
from __future__ import annotations
from dataclasses import dataclass
import os
import pytest
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from luminal import luminal_backend
@dataclass(frozen=True)
class _HFMoeCase:
case_id: str
model_id: str
input_ids: tuple[tuple[int, ...], ...]
@dataclass(frozen=True)
class _HFMoeBundle:
case: _HFMoeCase
model: AutoModelForCausalLM
device: torch.device
dtype: torch.dtype
_MODEL_CASES = [
_HFMoeCase(
case_id="qwen15_moe_a27b",
model_id="Qwen/Qwen1.5-MoE-A2.7B",
input_ids=(
(1, 2, 3, 4, 5, 6, 7, 8),
(8, 7, 6, 5, 4, 3, 2, 1),
),
),
]
_EXPERTS_IMPLEMENTATIONS = ("eager", "batched_mm", "grouped_mm")
_CUDA_BACKEND_AVAILABLE = (
os.getenv("LUMINAL_BACKEND", "native").lower() == "cuda"
and torch.cuda.is_available()
)
pytestmark = [
pytest.mark.slow,
pytest.mark.skipif(
not _CUDA_BACKEND_AVAILABLE,
reason="HF MoE experts backend tests require the CUDA backend",
),
]
def _model_dtype(device: torch.device) -> torch.dtype:
if device.type != "cuda":
return torch.float32
return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
def _output_tolerance(dtype: torch.dtype) -> float:
if dtype == torch.bfloat16:
return 5e-2
if dtype == torch.float16:
return 1e-2
return 1e-3
def _compare_router_logits(lhs, rhs, *, atol: float, rtol: float) -> None:
assert lhs is not None
assert rhs is not None
if isinstance(lhs, torch.Tensor):
torch.testing.assert_close(lhs, rhs, atol=atol, rtol=rtol)
return
assert len(lhs) == len(rhs)
for lhs_layer, rhs_layer in zip(lhs, rhs):
torch.testing.assert_close(lhs_layer, rhs_layer, atol=atol, rtol=rtol)
@pytest.fixture(scope="module", params=_MODEL_CASES, ids=lambda case: case.case_id)
def hf_moe_case(request: pytest.FixtureRequest) -> _HFMoeCase:
return request.param
@pytest.fixture
def hf_moe_bundle(hf_moe_case: _HFMoeCase) -> _HFMoeBundle:
case = hf_moe_case
device = torch.device("cuda")
dtype = _model_dtype(device)
config = AutoConfig.from_pretrained(case.model_id)
config.use_cache = False
config.output_router_logits = True
# Keep attention fixed so this test isolates experts backends.
config._attn_implementation = "eager"
model = (
AutoModelForCausalLM.from_pretrained(
case.model_id,
config=config,
torch_dtype=dtype,
)
.eval()
.to(device)
)
return _HFMoeBundle(case=case, model=model, device=device, dtype=dtype)
@pytest.mark.parametrize(
"experts_implementation",
_EXPERTS_IMPLEMENTATIONS,
ids=list(_EXPERTS_IMPLEMENTATIONS),
)
def test_hf_causal_lm_experts_implementation_matches_eager(
hf_moe_bundle: _HFMoeBundle, experts_implementation: str
):
model = hf_moe_bundle.model
model.set_experts_implementation(experts_implementation)
assert model.config._experts_implementation == experts_implementation
input_ids = torch.tensor(hf_moe_bundle.case.input_ids, device=hf_moe_bundle.device)
kwargs = {
"input_ids": input_ids,
"use_cache": False,
"output_router_logits": True,
"logits_to_keep": 1,
}
with torch.no_grad():
eager_output = model(**kwargs)
compiled_model = torch.compile(model, backend=luminal_backend)
with torch.no_grad():
compiled_output = compiled_model(**kwargs)
atol = _output_tolerance(hf_moe_bundle.dtype)
rtol = 1e-3
torch.testing.assert_close(
compiled_output.logits,
eager_output.logits,
atol=atol,
rtol=rtol,
)
_compare_router_logits(
compiled_output.router_logits,
eager_output.router_logits,
atol=atol,
rtol=rtol,
)
if eager_output.aux_loss is not None or compiled_output.aux_loss is not None:
assert eager_output.aux_loss is not None
assert compiled_output.aux_loss is not None
torch.testing.assert_close(
compiled_output.aux_loss,
eager_output.aux_loss,
atol=atol,
rtol=rtol,
)

View File

@@ -0,0 +1,242 @@
"""Hugging Face multimodal image-text-to-text smoke tests."""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
import os
from pathlib import Path
import pytest
import torch
from transformers import AutoConfig, AutoModelForImageTextToText, AutoProcessor
from luminal import luminal_backend
MODEL_ID = "google/gemma-3-4b-it"
_CUDA_BACKEND_AVAILABLE = (
os.getenv("LUMINAL_BACKEND", "native").lower() == "cuda"
and torch.cuda.is_available()
)
pytestmark = [
pytest.mark.slow,
pytest.mark.skipif(
not _CUDA_BACKEND_AVAILABLE,
reason="Gemma 3 multimodal tests require the CUDA backend",
),
]
@dataclass(frozen=True)
class HFMultimodalCase:
case_id: str
messages_builder: Callable[[Path], list[dict]]
max_new_tokens: int
expects_pixel_values: bool
@dataclass(frozen=True)
class Gemma3MultimodalBundle:
model: AutoModelForImageTextToText
processor: AutoProcessor
device: torch.device
dtype: torch.dtype
def _build_text_only_messages(_: Path) -> list[dict]:
return [
{
"role": "system",
"content": [{"type": "text", "text": "You are a concise assistant."}],
},
{
"role": "user",
"content": [
{
"type": "text",
"text": "In one short sentence, explain what a compiler does.",
}
],
},
]
def _build_image_to_text_messages(image_path: Path) -> list[dict]:
return [
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant."}],
},
{
"role": "user",
"content": [
{"type": "image", "path": str(image_path)},
{"type": "text", "text": "Describe this image in one short sentence."},
],
},
]
MULTIMODAL_CASES = [
HFMultimodalCase(
case_id="chat_text_only",
messages_builder=_build_text_only_messages,
max_new_tokens=12,
expects_pixel_values=False,
),
HFMultimodalCase(
case_id="image_to_text",
messages_builder=_build_image_to_text_messages,
max_new_tokens=16,
expects_pixel_values=True,
),
]
def _model_dtype(device: torch.device) -> torch.dtype:
if device.type != "cuda":
return torch.float32
return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
def _set_greedy_generation(model) -> None:
model.generation_config.temperature = None
model.generation_config.top_p = None
model.generation_config.top_k = None
def _move_to_device(
encoded: dict[str, torch.Tensor], device: torch.device, dtype: torch.dtype
) -> dict[str, torch.Tensor]:
result = {}
for key, value in encoded.items():
if not isinstance(value, torch.Tensor):
result[key] = value
continue
moved = value.to(device)
if moved.is_floating_point():
moved = moved.to(dtype=dtype)
result[key] = moved
return result
def _encode_case(
bundle: Gemma3MultimodalBundle,
case: HFMultimodalCase,
image_path: Path,
) -> dict[str, torch.Tensor]:
encoded = bundle.processor.apply_chat_template(
case.messages_builder(image_path),
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
encoded = _move_to_device(dict(encoded), bundle.device, bundle.dtype)
if case.expects_pixel_values:
assert "pixel_values" in encoded
assert "input_ids" in encoded
return encoded
def _generate_kwargs(
bundle: Gemma3MultimodalBundle,
encoded: dict[str, torch.Tensor],
max_new_tokens: int,
) -> dict:
tokenizer = bundle.processor.tokenizer
return dict(
**encoded,
max_new_tokens=max_new_tokens,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
def _logits_tolerance(dtype: torch.dtype) -> float:
if dtype == torch.bfloat16:
return 5e-2
if dtype == torch.float16:
return 1e-2
return 1e-3
@pytest.fixture(scope="module")
def gemma3_multimodal_bundle() -> Gemma3MultimodalBundle:
device = torch.device("cuda")
dtype = _model_dtype(device)
config = AutoConfig.from_pretrained(MODEL_ID)
processor = AutoProcessor.from_pretrained(MODEL_ID)
tokenizer = processor.tokenizer
if tokenizer.pad_token_id is None and tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
model = (
AutoModelForImageTextToText.from_pretrained(
MODEL_ID,
config=config,
torch_dtype=dtype,
)
.eval()
.to(device)
)
_set_greedy_generation(model)
return Gemma3MultimodalBundle(
model=model, processor=processor, device=device, dtype=dtype
)
class TestHFMultimodalGeneration:
@pytest.mark.parametrize("case", MULTIMODAL_CASES, ids=lambda case: case.case_id)
def test_generate_matches_eager(
self,
case: HFMultimodalCase,
gemma3_multimodal_bundle: Gemma3MultimodalBundle,
hf_multimodal_image_path: Path,
):
encoded = _encode_case(gemma3_multimodal_bundle, case, hf_multimodal_image_path)
kwargs = _generate_kwargs(
gemma3_multimodal_bundle, encoded, case.max_new_tokens
)
with torch.no_grad():
eager_output = gemma3_multimodal_bundle.model.generate(**kwargs)
compiled_model = torch.compile(
gemma3_multimodal_bundle.model, backend=luminal_backend
)
with torch.no_grad():
compiled_output = compiled_model.generate(**kwargs)
torch.testing.assert_close(compiled_output, eager_output)
@pytest.mark.parametrize("case", MULTIMODAL_CASES, ids=lambda case: case.case_id)
def test_forward_logits_match_eager(
self,
case: HFMultimodalCase,
gemma3_multimodal_bundle: Gemma3MultimodalBundle,
hf_multimodal_image_path: Path,
):
encoded = _encode_case(gemma3_multimodal_bundle, case, hf_multimodal_image_path)
with torch.no_grad():
eager_out = gemma3_multimodal_bundle.model(**encoded)
compiled_model = torch.compile(
gemma3_multimodal_bundle.model, backend=luminal_backend
)
with torch.no_grad():
compiled_out = compiled_model(**encoded)
atol = _logits_tolerance(gemma3_multimodal_bundle.dtype)
torch.testing.assert_close(
compiled_out.logits,
eager_out.logits,
atol=atol,
rtol=1e-3,
)

View File

@@ -0,0 +1,288 @@
"""Hugging Face text-generation smoke tests.
These tests intentionally download real Hugging Face checkpoints, configs,
and tokenizers. They compare eager PyTorch output against torch.compile
with the luminal backend to verify numerical equivalence.
"""
from __future__ import annotations
from dataclasses import dataclass
import pytest
import torch
import torch._dynamo
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from luminal import luminal_backend
SIMPLE_DENSE_MODELS = [
"meta-llama/Llama-3.2-1B",
"meta-llama/Llama-3.1-1B",
"Qwen/Qwen3-8B",
]
@dataclass(frozen=True)
class HFTextGenerationBundle:
model: AutoModelForCausalLM
tokenizer: AutoTokenizer
device: torch.device
def _load_model_and_tokenizer(
model_id: str, device: torch.device
) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
"""Load a pretrained HF causal LM and its tokenizer, ready for generation."""
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token_id is None and tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
dtype = torch.float16 if device.type == "cuda" else torch.float32
model = (
AutoModelForCausalLM.from_pretrained(
model_id,
config=config,
torch_dtype=dtype,
)
.eval()
.to(device)
)
model.generation_config.temperature = None
model.generation_config.top_p = None
model.generation_config.top_k = None
return model, tokenizer
def _encode(
tokenizer: AutoTokenizer, prompt: str, device: torch.device
) -> dict[str, torch.Tensor]:
"""Tokenize a prompt and move tensors to device."""
encoded = tokenizer(prompt, return_tensors="pt")
result = {"input_ids": encoded["input_ids"].to(device)}
if encoded.get("attention_mask") is not None:
result["attention_mask"] = encoded["attention_mask"].to(device)
return result
def _generate_kwargs(
tokenizer: AutoTokenizer,
encoded: dict[str, torch.Tensor],
max_new_tokens: int = 6,
) -> dict:
"""Build kwargs dict for model.generate()."""
return dict(
**encoded,
max_new_tokens=max_new_tokens,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
@pytest.fixture
def hf_text_bundle(model_id: str, device: torch.device) -> HFTextGenerationBundle:
model, tokenizer = _load_model_and_tokenizer(model_id, device)
return HFTextGenerationBundle(model=model, tokenizer=tokenizer, device=device)
@pytest.mark.slow
class TestHFGeneration:
"""End-to-end tests comparing eager PyTorch against torch.compile with luminal."""
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS)
def test_capital_of_france(self, hf_text_bundle: HFTextGenerationBundle):
"""Basic greedy generation -- the original smoke test."""
encoded = _encode(
hf_text_bundle.tokenizer,
"What is the capital of France ",
hf_text_bundle.device,
)
kwargs = _generate_kwargs(hf_text_bundle.tokenizer, encoded)
with torch.no_grad():
eager_output = hf_text_bundle.model.generate(**kwargs)
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
with torch.no_grad():
compiled_output = compiled_model.generate(**kwargs)
torch.testing.assert_close(compiled_output, eager_output)
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS)
def test_forward_logits(self, hf_text_bundle: HFTextGenerationBundle):
"""Forward pass only -- compare raw logits, not generated tokens."""
encoded = _encode(
hf_text_bundle.tokenizer, "The quick brown fox", hf_text_bundle.device
)
with torch.no_grad():
eager_out = hf_text_bundle.model(**encoded)
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
with torch.no_grad():
compiled_out = compiled_model(**encoded)
dtype = next(hf_text_bundle.model.parameters()).dtype
atol = 1e-2 if dtype == torch.float16 else 1e-3
torch.testing.assert_close(
compiled_out.logits, eager_out.logits, atol=atol, rtol=1e-3
)
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
@pytest.mark.parametrize(
"prompt",
[
"Hi",
"What is the capital of France",
"Explain the theory of general relativity in simple terms that a high school student could understand",
],
ids=["short", "medium", "long"],
)
def test_variable_length_prompts(
self,
prompt: str,
hf_text_bundle: HFTextGenerationBundle,
):
"""Generate with prompts of different lengths -- tests dynamic shape handling."""
encoded = _encode(hf_text_bundle.tokenizer, prompt, hf_text_bundle.device)
kwargs = _generate_kwargs(hf_text_bundle.tokenizer, encoded, max_new_tokens=4)
with torch.no_grad():
eager_output = hf_text_bundle.model.generate(**kwargs)
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
with torch.no_grad():
compiled_output = compiled_model.generate(**kwargs)
torch.testing.assert_close(compiled_output, eager_output)
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
def test_chat_template_generation(
self,
model_id: str,
hf_text_bundle: HFTextGenerationBundle,
):
"""Generate using chat-templated input with special tokens."""
if hf_text_bundle.tokenizer.chat_template is None:
pytest.skip(f"{model_id} has no chat template")
messages = [{"role": "user", "content": "What is 2+2?"}]
encoded = hf_text_bundle.tokenizer.apply_chat_template(
messages,
return_tensors="pt",
add_generation_prompt=True,
return_dict=True,
)
encoded = {k: v.to(hf_text_bundle.device) for k, v in encoded.items()}
kwargs = _generate_kwargs(hf_text_bundle.tokenizer, encoded)
with torch.no_grad():
eager_output = hf_text_bundle.model.generate(**kwargs)
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
with torch.no_grad():
compiled_output = compiled_model.generate(**kwargs)
torch.testing.assert_close(compiled_output, eager_output)
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
@pytest.mark.parametrize("max_new_tokens", [20, 50])
def test_longer_generation(
self,
max_new_tokens: int,
hf_text_bundle: HFTextGenerationBundle,
):
"""Generate many tokens to stress KV cache over extended decode loop."""
encoded = _encode(
hf_text_bundle.tokenizer, "Once upon a time", hf_text_bundle.device
)
kwargs = _generate_kwargs(
hf_text_bundle.tokenizer,
encoded,
max_new_tokens=max_new_tokens,
)
with torch.no_grad():
eager_output = hf_text_bundle.model.generate(**kwargs)
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
with torch.no_grad():
compiled_output = compiled_model.generate(**kwargs)
torch.testing.assert_close(compiled_output, eager_output)
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
def test_greedy_determinism(
self,
hf_text_bundle: HFTextGenerationBundle,
configure_dynamo,
):
"""Greedy generation produces identical results on repeated calls."""
configure_dynamo(cache_size_limit=4)
encoded = _encode(
hf_text_bundle.tokenizer,
"The meaning of life is",
hf_text_bundle.device,
)
kwargs = _generate_kwargs(hf_text_bundle.tokenizer, encoded, max_new_tokens=10)
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
with torch.no_grad():
output_1 = compiled_model.generate(**kwargs)
output_2 = compiled_model.generate(**kwargs)
torch.testing.assert_close(output_1, output_2)
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
def test_reuse_compiled_model(
self,
hf_text_bundle: HFTextGenerationBundle,
configure_dynamo,
):
"""Call the same compiled model multiple times with different prompts."""
configure_dynamo(cache_size_limit=8)
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
prompts = [
"The capital of France is",
"Water boils at",
"The largest planet in our solar system is",
]
for prompt in prompts:
encoded = _encode(hf_text_bundle.tokenizer, prompt, hf_text_bundle.device)
kwargs = _generate_kwargs(
hf_text_bundle.tokenizer, encoded, max_new_tokens=4
)
with torch.no_grad():
eager_output = hf_text_bundle.model.generate(**kwargs)
compiled_output = compiled_model.generate(**kwargs)
torch.testing.assert_close(compiled_output, eager_output)
@pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
def test_batched_inference(self, hf_text_bundle: HFTextGenerationBundle):
"""Batched generation with multiple prompts and left-padding."""
hf_text_bundle.tokenizer.padding_side = "left"
prompts = ["Hello", "What is the capital of France"]
encoded = hf_text_bundle.tokenizer(prompts, return_tensors="pt", padding=True)
encoded = {k: v.to(hf_text_bundle.device) for k, v in encoded.items()}
kwargs = _generate_kwargs(hf_text_bundle.tokenizer, encoded, max_new_tokens=4)
with torch.no_grad():
eager_output = hf_text_bundle.model.generate(**kwargs)
compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
with torch.no_grad():
compiled_output = compiled_model.generate(**kwargs)
torch.testing.assert_close(compiled_output, eager_output)

File diff suppressed because it is too large Load Diff

View File

@@ -224,127 +224,15 @@ def test_hf_llama_decode_loop_static(device: torch.device):
tokens.append(next_token)
@pytest.mark.slow
@pytest.mark.skip(reason="This is currently failing and in development")
def test_hf_llama3_1b_decode_loop_dynamic():
"""Decode loop with dynamic shapes on real Llama3.2-1B — compile once, run with varying seq_len.
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
def test_hf_llama3_1b_decode_loop_dynamic(device: torch.device):
"""Decode loop on real Llama3.2-1B with pretrained weights.
This is the end-goal test: full 1B model with pretrained weights, CUDA backend,
ONNX exported once with dynamic_axes for seq_len, then decoded autoregressively
without recompilation.
Supports both ONNX and PT2 export modes via LUMINAL_EXPORT_MODE env var.
Recompiles each step as sequence length grows, using the standard
torch.compile(model, backend=luminal_backend) pattern.
"""
import os
import tempfile
from transformers import AutoConfig, AutoTokenizer, LlamaForCausalLM
import luminal
backend = os.environ.get("LUMINAL_BACKEND", "cuda")
export_mode = os.getenv("LUMINAL_EXPORT_MODE", "onnx").lower()
config = AutoConfig.from_pretrained("NousResearch/Llama-3.2-1B")
config.use_cache = False
config._attn_implementation = "eager"
print("Loaded config")
model = LlamaForCausalLM.from_pretrained(
"NousResearch/Llama-3.2-1B",
config=config,
torch_dtype=torch.float32,
).eval()
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
print("Loaded Model")
prompt = "The capital of france is"
tokens = tokenizer.encode(prompt)
print(f"Prompt: '{prompt}' -> {len(tokens)} tokens: {tokens}")
num_generate = 3
if export_mode == "pt2":
from luminal.pt2 import compile as luminal_compile
dummy = torch.tensor([[1, 2, 3, 4]])
compiled = luminal_compile(model, dummy, search_iterations=0, dynamic_dim=1)
for step in range(num_generate):
input_ids = torch.tensor([tokens])
logits = compiled(input_ids)[0]
with torch.no_grad():
ref = model(input_ids)
assert torch.allclose(logits, ref.logits, atol=1e-3), (
f"step {step}: max_diff={torch.max(torch.abs(logits - ref.logits)).item():.2e}"
)
next_token = ref.logits[0, -1, :].argmax().item()
tokens.append(next_token)
print(f"Step {step}: '{tokenizer.decode(tokens)}'")
else:
# ONNX path — manual export with dynamic_axes
dummy = torch.tensor([[1, 2, 3, 4]])
tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
tmp_path = tmp.name
tmp.close()
try:
torch.onnx.export(
model,
(dummy,),
tmp_path,
opset_version=20,
input_names=["input_ids"],
output_names=["logits"],
dynamic_axes={"input_ids": {1: "seq_len"}, "logits": {1: "seq_len"}},
)
print("Exported onnx")
graph = luminal.process_onnx(tmp_path, backend)
finally:
os.unlink(tmp_path)
print("Exported Model")
assert graph.has_dynamic_dims, "Graph should have dynamic dims"
assert "seq_len" in graph.dim_params, (
f"Expected 'seq_len' in {graph.dim_params}"
)
for step in range(num_generate):
seq_len = len(tokens)
graph.set_dim("seq_len", seq_len)
graph.set_input("input_ids", [float(t) for t in tokens])
graph.run()
output_shapes = graph.resolve_output_shapes()
logits_data = graph.get_output("logits")
logits = torch.tensor(logits_data, dtype=torch.float32).reshape(
output_shapes[0]
)
# Compare against PyTorch reference
input_ids = torch.tensor([tokens])
with torch.no_grad():
ref = model(input_ids)
assert torch.allclose(logits, ref.logits, atol=1e-3), (
f"step {step}: max_diff={torch.max(torch.abs(logits - ref.logits)).item():.2e}"
)
next_token = ref.logits[0, -1, :].argmax().item()
tokens.append(next_token)
print(f"Step {step}: '{tokenizer.decode(tokens)}'")
@pytest.mark.slow
def test_hf_llama3_full(device: torch.device):
"""HuggingFace LlamaForCausalLM — full Llama3.2-1B with real pretrained weights.
No config alterations except use_cache=False and eager attention.
Loads actual weights from NousResearch/Llama-3.2-1B.
"""
from transformers import AutoConfig, LlamaForCausalLM
config = AutoConfig.from_pretrained("NousResearch/Llama-3.2-1B")
config.use_cache = False
config._attn_implementation = "eager"
@@ -358,18 +246,41 @@ def test_hf_llama3_full(device: torch.device):
.eval()
.to(device)
)
compiled = torch.compile(model, backend=luminal_backend)
input_ids = torch.tensor([[1, 2, 3, 4]], device=device)
with torch.no_grad():
ref = model(input_ids)
out = compiled(input_ids)
assert torch.allclose(out.logits, ref.logits, atol=1e-3), (
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
)
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
prompt = "The capital of france is"
tokens = tokenizer.encode(prompt)
print(f"Prompt: '{prompt}' -> {len(tokens)} tokens: {tokens}")
num_generate = 3
for step in range(num_generate):
input_ids = torch.tensor([tokens], device=device)
torch._dynamo.reset()
compiled = torch.compile(model, backend=luminal_backend)
with torch.no_grad():
ref = model(input_ids)
out = compiled(input_ids)
assert torch.allclose(out.logits, ref.logits, atol=1e-5), (
f"step {step}: max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
)
next_token = ref.logits[0, -1, :].argmax().item()
tokens.append(next_token)
print(f"Step {step}: '{tokenizer.decode(tokens)}'")
@pytest.mark.slow
def test_hf_llama38b_full(device: torch.device):
def _gpu_mem(label):
"""Print GPU memory stats at a given checkpoint."""
if torch.cuda.is_available():
alloc = torch.cuda.memory_allocated() / (1024**3)
reserved = torch.cuda.memory_reserved() / (1024**3)
peak = torch.cuda.max_memory_allocated() / (1024**3)
print(
f"[GPU MEM] {label}: allocated={alloc:.3f} GiB, reserved={reserved:.3f} GiB, peak={peak:.3f} GiB"
)
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
def test_hf_llama3_full(device: torch.device):
"""HuggingFace LlamaForCausalLM — full Llama3.2-1B with real pretrained weights.
No config alterations except use_cache=False and eager attention.
@@ -377,6 +288,57 @@ def test_hf_llama38b_full(device: torch.device):
"""
from transformers import AutoConfig, LlamaForCausalLM
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
_gpu_mem("before model load")
config = AutoConfig.from_pretrained("NousResearch/Llama-3.2-1B")
config.use_cache = False
config._attn_implementation = "eager"
model = (
LlamaForCausalLM.from_pretrained(
"NousResearch/Llama-3.2-1B",
config=config,
torch_dtype=torch.float32,
)
.eval()
.to(device)
)
n_params = sum(p.numel() for p in model.parameters())
print(
f"[MODEL] Total parameters: {n_params:,} ({n_params * 4 / 1024**3:.3f} GiB in f32)"
)
_gpu_mem("after model load")
compiled = torch.compile(model, backend=luminal_backend)
_gpu_mem("after torch.compile (lazy, no compilation yet)")
input_ids = torch.tensor([[1, 2, 3, 4]], device=device)
with torch.no_grad():
ref = model(input_ids)
_gpu_mem("after PyTorch reference forward")
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
_gpu_mem("before compiled forward (peak reset)")
out = compiled(input_ids)
_gpu_mem("after compiled forward (includes compilation)")
assert torch.allclose(out.logits, ref.logits, atol=1e-5), (
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
)
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
def test_hf_llama3_large_full(device: torch.device):
"""HuggingFace LlamaForCausalLM — full Llama-3.1-8B-Instruct with real pretrained weights.
No config alterations except use_cache=False and eager attention.
Loads actual weights from NousResearch/Meta-Llama-3.1-8B-Instruct.
"""
from transformers import AutoConfig, LlamaForCausalLM
config = AutoConfig.from_pretrained("NousResearch/Meta-Llama-3.1-8B-Instruct")
config.use_cache = False
config._attn_implementation = "eager"
@@ -395,79 +357,95 @@ def test_hf_llama38b_full(device: torch.device):
with torch.no_grad():
ref = model(input_ids)
out = compiled(input_ids)
assert torch.allclose(out.logits, ref.logits, atol=1e-3), (
assert torch.allclose(out.logits, ref.logits, atol=1e-5), (
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
)
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
def test_hf_llama38b_full(device: torch.device):
"""HuggingFace LlamaForCausalLM — full Llama-3.1-8B-Instruct with real pretrained weights.
No config alterations except use_cache=False and eager attention.
Loads actual weights from NousResearch/Meta-Llama-3.1-8B-Instruct.
"""
from transformers import AutoConfig, LlamaForCausalLM
config = AutoConfig.from_pretrained("NousResearch/Meta-Llama-3.1-8B-Instruct")
config.use_cache = False
config._attn_implementation = "eager"
model = (
LlamaForCausalLM.from_pretrained(
"NousResearch/Meta-Llama-3.1-8B-Instruct",
config=config,
torch_dtype=torch.float32,
)
.eval()
.to(device)
)
compiled = torch.compile(model, backend=luminal_backend)
input_ids = torch.tensor([[1, 2, 3, 4]], device=device)
with torch.no_grad():
ref = model(input_ids)
out = compiled(input_ids)
assert torch.allclose(out.logits, ref.logits, atol=1e-5), (
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
)
@pytest.mark.slow
def test_hf_llama38b_cached():
"""Llama 3.1-8B via pre-generated artifacts + reference logits.
Supports both ONNX and PT2 export modes via LUMINAL_EXPORT_MODE env var.
Requires artifacts generated by:
ONNX: uv run python tests/generate_llama38b_artifacts.py
PT2: uv run python tests/generate_llama38b_pt2_artifacts.py
"""
def test_hf_llama38b_cached_onnx(
llama38b_onnx_path, llama38b_ref_logits: torch.Tensor
):
import os
from pathlib import Path
import luminal
backend = os.environ.get("LUMINAL_BACKEND", "cuda")
export_mode = os.getenv("LUMINAL_EXPORT_MODE", "onnx").lower()
tests_dir = Path(__file__).resolve().parent
logits_path = tests_dir / "llama38b_ref_logits.pt"
graph = luminal.process_onnx(str(llama38b_onnx_path), backend)
print("Compiled luminal ONNX graph")
assert logits_path.exists(), (
f"{logits_path} not found. Run: uv run python tests/generate_llama38b_artifacts.py"
)
ref_logits = torch.load(logits_path, weights_only=True)
print(f"Loaded reference logits: {ref_logits.shape}")
graph.set_input("input_ids", [float(t) for t in [1, 2, 3, 4]])
graph.run()
if export_mode == "pt2":
from luminal import CompiledModel
pt2_path = tests_dir / "llama38b.pt2"
weights_path = tests_dir / "llama38b_weights.safetensors"
assert pt2_path.exists(), (
f"{pt2_path} not found. Run: uv run python tests/generate_llama38b_pt2_artifacts.py"
)
assert weights_path.exists(), (
f"{weights_path} not found. Run: uv run python tests/generate_llama38b_pt2_artifacts.py"
)
backend_name = "cuda" if backend == "cuda" else "cpu"
compiled_inner = luminal.compile_pt2(
str(pt2_path), str(weights_path), backend_name, 0
)
compiled = CompiledModel(compiled_inner)
print("Compiled luminal PT2 graph")
input_ids = torch.tensor([[1, 2, 3, 4]])
logits = compiled(input_ids)[0]
else:
onnx_path = tests_dir / "llama38b.onnx"
assert onnx_path.exists(), (
f"{onnx_path} not found. Run: uv run python tests/generate_llama38b_artifacts.py"
)
graph = luminal.process_onnx(str(onnx_path), backend)
print("Compiled luminal ONNX graph")
graph.set_input("input_ids", [float(t) for t in [1, 2, 3, 4]])
graph.run()
logits_data = graph.get_output("logits")
logits_shape = graph.output_shapes[0]
logits = torch.tensor(logits_data, dtype=torch.float32).reshape(logits_shape)
logits_data = graph.get_output("logits")
logits_shape = graph.output_shapes[0]
logits = torch.tensor(logits_data, dtype=torch.float32).reshape(logits_shape)
print(f"Loaded reference logits: {llama38b_ref_logits.shape}")
print(f"Output logits shape: {logits.shape}")
assert torch.allclose(logits, ref_logits, atol=1e-3), (
f"max_diff={torch.max(torch.abs(logits - ref_logits)).item():.2e}"
assert torch.allclose(logits, llama38b_ref_logits, atol=1e-3), (
f"max_diff={torch.max(torch.abs(logits - llama38b_ref_logits)).item():.2e}"
)
@pytest.mark.slow
def test_hf_llama38b_cached_pt2(
llama38b_pt2_path, llama38b_weights_path, llama38b_ref_logits: torch.Tensor
):
import os
import luminal
from luminal import CompiledModel
backend = os.environ.get("LUMINAL_BACKEND", "cuda")
backend_name = "cuda" if backend == "cuda" else "cpu"
compiled_inner = luminal.compile_pt2(
str(llama38b_pt2_path), str(llama38b_weights_path), backend_name, 0
)
compiled = CompiledModel(compiled_inner)
print("Compiled luminal PT2 graph")
input_ids = torch.tensor([[1, 2, 3, 4]])
logits = compiled(input_ids)[0]
print(f"Loaded reference logits: {llama38b_ref_logits.shape}")
print(f"Output logits shape: {logits.shape}")
assert torch.allclose(logits, llama38b_ref_logits, atol=1e-3), (
f"max_diff={torch.max(torch.abs(logits - llama38b_ref_logits)).item():.2e}"
)

View File

@@ -41,9 +41,6 @@ class AddAddTestModel(torch.nn.Module):
class AddConstantTestModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor):
return x + 10
@@ -59,25 +56,16 @@ class LinearLayerModel(torch.nn.Module):
class SqrtTestModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.sqrt()
class SinTestModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.sin(x)
class CosTestModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.cos(x)
@@ -94,9 +82,6 @@ class SubTestModel(torch.nn.Module):
class TransposeTestModel(torch.nn.Module):
"""Test basic 2D transpose (matrix transpose)."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.t() # 2D transpose
@@ -104,9 +89,6 @@ class TransposeTestModel(torch.nn.Module):
class Transpose3DTestModel(torch.nn.Module):
"""Test 3D transpose with explicit permutation."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.permute(2, 0, 1) # Rotate dimensions
@@ -114,9 +96,6 @@ class Transpose3DTestModel(torch.nn.Module):
class Transpose4DTestModel(torch.nn.Module):
"""Test 4D transpose (NCHW -> NHWC)."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.permute(0, 2, 3, 1) # Common in CNNs
@@ -124,9 +103,6 @@ class Transpose4DTestModel(torch.nn.Module):
class TransposeReverseTestModel(torch.nn.Module):
"""Test reverse permutation (default transpose behavior)."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
dims = list(range(x.ndim))
return x.permute(*reversed(dims))
@@ -151,9 +127,6 @@ class TransposeInExpressionModel(torch.nn.Module):
class ConstantScalarFloatModel(torch.nn.Module):
"""Test scalar constant (broadcasts to input shape)."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
constant = torch.tensor(10.5).to(x.device)
return x + constant
@@ -162,9 +135,6 @@ class ConstantScalarFloatModel(torch.nn.Module):
class Constant1DArrayFloatModel(torch.nn.Module):
"""Test 1D array constant."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
constant = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]).to(x.device)
return x * constant
@@ -173,9 +143,6 @@ class Constant1DArrayFloatModel(torch.nn.Module):
class Constant2DMatrixFloatModel(torch.nn.Module):
"""Test 2D matrix constant."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
constant = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]).to(x.device)
return x + constant
@@ -184,9 +151,6 @@ class Constant2DMatrixFloatModel(torch.nn.Module):
class ConstantRawDataFloatModel(torch.nn.Module):
"""Test constant with specific values (tests raw data format)."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
constant = torch.tensor([7.5, 8.5, 9.5]).to(x.device)
return x + constant
@@ -195,9 +159,6 @@ class ConstantRawDataFloatModel(torch.nn.Module):
class ConstantInt32ConversionModel(torch.nn.Module):
"""Test INT32 constant values (PyTorch exports as integers)."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
constant = torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32).to(x.device)
return x + constant.float()
@@ -206,9 +167,6 @@ class ConstantInt32ConversionModel(torch.nn.Module):
class ConstantInt64ConversionModel(torch.nn.Module):
"""Test INT64 constant values."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
constant = torch.tensor([100, 200, 300], dtype=torch.int64).to(x.device)
return x * constant.float()
@@ -217,9 +175,6 @@ class ConstantInt64ConversionModel(torch.nn.Module):
class ConstantFloat64ConversionModel(torch.nn.Module):
"""Test FLOAT64 (double) constant values."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
constant = torch.tensor([1.5, 2.5, 3.5], dtype=torch.float64).to(x.device)
return x * constant.float()
@@ -228,9 +183,6 @@ class ConstantFloat64ConversionModel(torch.nn.Module):
class ConstantBoolConversionModel(torch.nn.Module):
"""Test boolean constant values (converted to 0.0/1.0)."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
constant = torch.tensor([True, False, True, False, True], dtype=torch.bool).to(
x.device
@@ -241,9 +193,6 @@ class ConstantBoolConversionModel(torch.nn.Module):
class ConstantInt64RawDataModel(torch.nn.Module):
"""Test INT64 constant with large values (tests raw data path)."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
constant = torch.tensor([1000, 2000, 3000], dtype=torch.int64).to(x.device)
return x + constant.float()
@@ -252,9 +201,6 @@ class ConstantInt64RawDataModel(torch.nn.Module):
class ConstantNegativeValuesModel(torch.nn.Module):
"""Test negative constant values."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
constant = torch.tensor([-5.0, -10.0, -15.0]).to(x.device)
return x + constant
@@ -263,9 +209,6 @@ class ConstantNegativeValuesModel(torch.nn.Module):
class ConstantZeroValueModel(torch.nn.Module):
"""Test all-zero constant."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
constant = torch.tensor([0.0, 0.0, 0.0, 0.0]).to(x.device)
return x * constant
@@ -274,9 +217,6 @@ class ConstantZeroValueModel(torch.nn.Module):
class ConstantMultipleInGraphModel(torch.nn.Module):
"""Test multiple constants in one graph."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
const1 = torch.tensor([10.0, 20.0, 30.0]).to(x.device)
const2 = torch.tensor([1.0, 2.0, 3.0]).to(x.device)
@@ -290,9 +230,6 @@ class ConstantMultipleInGraphModel(torch.nn.Module):
class CastDoubleToFloatModel(torch.nn.Module):
"""Test downcast: Double (FLOAT64) -> Float."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Input will be float64, cast to float32
return x.to(torch.float32)
@@ -301,9 +238,6 @@ class CastDoubleToFloatModel(torch.nn.Module):
class CastInt32ToFloatModel(torch.nn.Module):
"""Test INT32 -> Float conversion."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.to(torch.float32)
@@ -311,9 +245,6 @@ class CastInt32ToFloatModel(torch.nn.Module):
class CastInt64ToFloatModel(torch.nn.Module):
"""Test INT64 -> Float conversion."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.to(torch.float32)
@@ -321,9 +252,6 @@ class CastInt64ToFloatModel(torch.nn.Module):
class CastBoolToFloatModel(torch.nn.Module):
"""Test BOOL -> Float conversion (non-zero -> 1.0, zero -> 0.0)."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.to(torch.float32)
@@ -331,9 +259,6 @@ class CastBoolToFloatModel(torch.nn.Module):
class CastInComputationGraphModel(torch.nn.Module):
"""Test Cast node followed by an operation (Cast + Add)."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
casted = x.to(torch.float32)
constant = torch.tensor([2.0, 2.0, 2.0]).to(x.device)
@@ -343,9 +268,6 @@ class CastInComputationGraphModel(torch.nn.Module):
class CastWith2DTensorModel(torch.nn.Module):
"""Test Cast with 2D tensor (matrix)."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.to(torch.float32)
@@ -353,9 +275,6 @@ class CastWith2DTensorModel(torch.nn.Module):
class CastNegativeValuesModel(torch.nn.Module):
"""Test Cast with negative integer values."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.to(torch.float32)
@@ -363,9 +282,6 @@ class CastNegativeValuesModel(torch.nn.Module):
class CastScalarValueModel(torch.nn.Module):
"""Test Cast with scalar (single element)."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.to(torch.float32)
@@ -389,9 +305,6 @@ class ModTestModel(torch.nn.Module):
class ModByConstantModel(torch.nn.Module):
"""Tests modulo with an inline constant tensor (ONNX Constant node)."""
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
constant = torch.tensor([3.0, 4.0, 5.0]).to(x.device)
return x.fmod(constant)

View File

@@ -1,154 +1,118 @@
from dataclasses import dataclass
from typing import Callable
import pytest
import torch
import torch._dynamo
from test_models import (
SigmoidTestModel,
SigmoidInExpressionModel,
TanhTestModel,
TanhInExpressionModel,
ReluTestModel,
ReluAllNegativeModel,
ReluInExpressionModel,
AbsTestModel,
AbsAllNegativeModel,
AbsInExpressionModel,
NegTestModel,
NegAllPositiveModel,
NegInExpressionModel,
ClipTestModel,
ClipMinOnlyTestModel,
ClipMaxOnlyTestModel,
)
import test_models as tm
from luminal import luminal_backend
# ── Sigmoid ──────────────────────────────────────────────────────────────────
Args = tuple[torch.Tensor, ...]
Kwargs = dict[str, torch.Tensor]
InputFactory = Callable[[torch.device], tuple[Args, Kwargs]]
def test_sigmoid(device):
model = SigmoidTestModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x = torch.rand((5, 5), device=device) * 2 - 1 # mixed positive/negative
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
@dataclass(frozen=True)
class UnaryCase:
id: str
model_factory: Callable[[], torch.nn.Module]
input_factory: InputFactory
def test_sigmoid_in_expression(device):
model = SigmoidInExpressionModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x = torch.rand((5, 5), device=device)
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
UNARY_CASES: list[UnaryCase] = [
UnaryCase(
id="sigmoid",
model_factory=tm.SigmoidTestModel,
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
),
UnaryCase(
id="sigmoid_in_expression",
model_factory=tm.SigmoidInExpressionModel,
input_factory=lambda device: ((torch.rand((5, 5), device=device),), {}),
),
UnaryCase(
id="tanh",
model_factory=tm.TanhTestModel,
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
),
UnaryCase(
id="tanh_in_expression",
model_factory=tm.TanhInExpressionModel,
input_factory=lambda device: ((torch.rand((5, 5), device=device),), {}),
),
UnaryCase(
id="relu",
model_factory=tm.ReluTestModel,
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
),
UnaryCase(
id="relu_all_negative",
model_factory=tm.ReluAllNegativeModel,
input_factory=lambda device: ((-torch.rand((5, 5), device=device),), {}),
),
UnaryCase(
id="relu_in_expression",
model_factory=tm.ReluInExpressionModel,
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
),
UnaryCase(
id="abs",
model_factory=tm.AbsTestModel,
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
),
UnaryCase(
id="abs_all_negative",
model_factory=tm.AbsAllNegativeModel,
input_factory=lambda device: ((-torch.rand((5, 5), device=device),), {}),
),
UnaryCase(
id="abs_in_expression",
model_factory=tm.AbsInExpressionModel,
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
),
UnaryCase(
id="neg",
model_factory=tm.NegTestModel,
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
),
UnaryCase(
id="neg_all_positive",
model_factory=tm.NegAllPositiveModel,
input_factory=lambda device: ((torch.rand((5, 5), device=device),), {}),
),
UnaryCase(
id="neg_in_expression",
model_factory=tm.NegInExpressionModel,
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
),
UnaryCase(
id="clip",
model_factory=tm.ClipTestModel,
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 4 - 2,), {}),
),
UnaryCase(
id="clip_min_only",
model_factory=tm.ClipMinOnlyTestModel,
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 4 - 2,), {}),
),
UnaryCase(
id="clip_max_only",
model_factory=tm.ClipMaxOnlyTestModel,
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 4 - 2,), {}),
),
]
# ── Tanh ─────────────────────────────────────────────────────────────────────
def test_tanh(device):
model = TanhTestModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x = torch.rand((5, 5), device=device) * 2 - 1
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
def test_tanh_in_expression(device):
model = TanhInExpressionModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x = torch.rand((5, 5), device=device)
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
# ── Relu ─────────────────────────────────────────────────────────────────────
def test_relu(device):
model = ReluTestModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x = torch.rand((5, 5), device=device) * 2 - 1 # mixed positive/negative
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
def test_relu_all_negative(device):
model = ReluAllNegativeModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x = -torch.rand((5, 5), device=device) # all negative -> output all zeros
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
def test_relu_in_expression(device):
model = ReluInExpressionModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x = torch.rand((5, 5), device=device) * 2 - 1
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
# ── Abs ──────────────────────────────────────────────────────────────────────
def test_abs(device):
model = AbsTestModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x = torch.rand((5, 5), device=device) * 2 - 1 # mixed positive/negative
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
def test_abs_all_negative(device):
model = AbsAllNegativeModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x = -torch.rand((5, 5), device=device) # all negative
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
def test_abs_in_expression(device):
model = AbsInExpressionModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x = torch.rand((5, 5), device=device) * 2 - 1
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
# ── Neg ──────────────────────────────────────────────────────────────────────
def test_neg(device):
model = NegTestModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x = torch.rand((5, 5), device=device) * 2 - 1 # mixed positive/negative
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
def test_neg_all_positive(device):
model = NegAllPositiveModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x = torch.rand((5, 5), device=device) # all positive -> output all negative
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
def test_neg_in_expression(device):
model = NegInExpressionModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x = torch.rand((5, 5), device=device) * 2 - 1
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
# ── Clip ──────────────────────────────────────────────────────────────────────
def test_clip(device):
"""Clip tensor values to [-0.5, 0.5]."""
model = ClipTestModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x = torch.rand((5, 5), device=device) * 4 - 2 # range [-2, 2]
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
def test_clip_min_only(device):
"""Clip tensor values to [0.0, +inf]."""
model = ClipMinOnlyTestModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x = torch.rand((5, 5), device=device) * 4 - 2
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
def test_clip_max_only(device):
"""Clip tensor values to [-inf, 0.5]."""
model = ClipMaxOnlyTestModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x = torch.rand((5, 5), device=device) * 4 - 2
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
class TestUnaryOps:
@pytest.mark.parametrize("case", UNARY_CASES, ids=lambda case: case.id)
def test_matches_eager(self, case: UnaryCase, device: torch.device) -> None:
model = case.model_factory().to(device)
compiled_model = torch.compile(model, backend=luminal_backend)
args, kwargs = case.input_factory(device)
torch.testing.assert_close(
compiled_model(*args, **kwargs),
model(*args, **kwargs),
atol=1e-5,
rtol=1e-5,
)

20
skills-lock.json Normal file
View File

@@ -0,0 +1,20 @@
{
"version": 1,
"skills": {
"ruff": {
"source": "astral-sh/claude-code-plugins",
"sourceType": "github",
"computedHash": "905f1c2b66e722ab534d7cbeceafb6ee37d32e3bfe51f3e68c661c51eb19a776"
},
"ty": {
"source": "astral-sh/claude-code-plugins",
"sourceType": "github",
"computedHash": "5d82b319a84d296fae47f2664228bfac1a36dd832cb487b485a65ec36b3df21f"
},
"uv": {
"source": "astral-sh/claude-code-plugins",
"sourceType": "github",
"computedHash": "b1ca9e9826d906f6ee6ea172da50018ec13bb622ac9ba7204bda8b10e7fc8d0a"
}
}
}

View File

@@ -254,6 +254,7 @@ impl Graph {
if subgraphs.len() <= 1 {
let (program, root) = hlir_to_egglog(self);
self.egraphs = vec![run_egglog(&program, &root, &ops, cleanup_hlir).unwrap()];
self.chunk_groups = vec![ChunkGroup {
representative: 0,
members: vec![0],
@@ -579,7 +580,6 @@ impl Graph {
for (group_idx, group) in self.chunk_groups.iter().enumerate() {
let egraph = &self.egraphs[group_idx];
let mut prev_selected: FxHashSet<u64> = FxHashSet::default();
let mut list_cache = FxHashMap::default();
let mut expr_cache = FxHashMap::default();

View File

@@ -149,6 +149,7 @@ pub type HLIROps = (
Scatter,
SumReduce,
MaxReduce,
Softmax,
);
#[derive(Default, Debug, Clone)]
@@ -1836,6 +1837,160 @@ impl NativeOp for MaxReduce {
}
}
// Fused Softmax: softmax(x, axis) = exp(x - max(x)) / sum(exp(x - max(x)))
// A single HLIR op that replaces the 6-op decomposed chain.
// On CUDA, KernelSoftmax provides a fused 3-pass kernel.
// On native, NativeOp implements softmax directly.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct Softmax {
pub axis: usize,
pub input_shape: ShapeTracker,
// Extracted fields (populated during egglog extraction, used by NativeOp)
pub shape: Vec<Expression>,
pub in_strides: Vec<Expression>,
pub reduce_dim: Expression,
pub reduce_stride: Expression,
}
impl Display for Softmax {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Softmax(axis={})", self.axis)
}
}
/// Sort for Softmax: (shape, in_strides, out_strides, reduce_dim, reduce_stride)
pub fn softmax_sort(name: &str) -> SortDef {
sort(
OP_KIND,
name,
&[
("shape", ELIST),
("in_strides", ELIST),
("out_strides", ELIST),
("reduce_dim", EXPRESSION),
("reduce_stride", EXPRESSION),
],
)
}
impl HLIROp for Softmax {
fn to_egglog(&self, inputs: &[(NodeIndex, String)]) -> String {
let reduce_dim = self.input_shape.dims[self.axis];
let reduce_stride = self.input_shape.strides[self.axis];
format!(
"(Op (Softmax {} {} {} {} {}) {})",
elist_to_egglog(&self.input_shape.dims),
elist_to_egglog(&self.input_shape.strides),
elist_to_egglog(&self.input_shape.contiguous().strides),
reduce_dim.to_egglog(),
reduce_stride.to_egglog(),
ilist_egglog(&[&inputs[0].1]),
)
}
}
impl EgglogOp for Softmax {
fn sort(&self) -> SortDef {
softmax_sort("Softmax")
}
fn cleanup(&self) -> bool {
true
}
fn n_inputs(&self) -> usize {
1
}
fn rewrites(&self) -> Vec<Rule> {
vec![dtype_propagation_op(&self.sort())]
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
let shape = extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap();
let in_strides =
extract_expr_list(egraph, kind_children[1], list_cache, expr_cache).unwrap();
let reduce_dim = extract_expr(egraph, kind_children[3], expr_cache).unwrap();
let reduce_stride = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
(
LLIROp::new::<dyn NativeOp>(Box::new(Self {
axis: 0,
input_shape: ShapeTracker::default(),
shape,
in_strides,
reduce_dim,
reduce_stride,
})),
input_enodes,
)
}
}
impl NativeOp for Softmax {
fn execute(&self, inputs: Vec<&NativeData>, dyn_map: &FxHashMap<char, usize>) -> NativeData {
match inputs[0] {
NativeData::F32(a) => {
// Use extracted fields (populated during egglog extraction)
let dims: Vec<usize> = self
.shape
.iter()
.map(|d| d.exec(dyn_map).unwrap())
.collect();
let n = self.reduce_dim.exec(dyn_map).unwrap();
let mut reduce_stride_expr = self.reduce_stride;
for (&var, &val) in dyn_map {
reduce_stride_expr =
reduce_stride_expr.substitute(var, Expression::from(val as i32));
}
// Compute row index strides (all dims except last, since softmax is always last-dim)
let ndim = dims.len();
let out_size: usize = dims.iter().product();
let mut out = vec![0.0f32; out_size];
// Use StridedIterator for the row dimensions
let row_ind = StridedIterator::new(
&self.shape[..ndim - 1],
&self.in_strides[..ndim - 1],
dyn_map,
);
for (row_idx, in_base) in row_ind.enumerate() {
// Pass 1: find max
let mut max_val = f32::NEG_INFINITY;
for i in 0..n {
let val = a[in_base + reduce_stride_expr.exec_single_var(i)];
if val > max_val {
max_val = val;
}
}
// Pass 2: exp(x - max) and sum
let mut sum = 0.0f32;
let out_base = row_idx * n;
for i in 0..n {
let val =
(a[in_base + reduce_stride_expr.exec_single_var(i)] - max_val).exp();
out[out_base + i] = val;
sum += val;
}
// Pass 3: normalize
let inv_sum = 1.0 / sum;
for i in 0..n {
out[out_base + i] *= inv_sum;
}
}
NativeData::F32(out)
}
_ => panic!("Softmax only supports F32"),
}
}
}
pub trait NativeOp: Debug + AsAny + Send + Sync {
fn execute(&self, inputs: Vec<&NativeData>, dyn_map: &FxHashMap<char, usize>) -> NativeData;
}