mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits: worktree-f ... pytest-cla (16 Commits)
| Author | SHA1 | Date |
|---|---|---|
| | 7ee5b54438 | |
| | 389c05abeb | |
| | dcc2c9cbb4 | |
| | a9af4c3923 | |
| | 3092d0d68b | |
| | 8a2bd714ac | |
| | 54a26a044c | |
| | 5a0d3f87cc | |
| | 112d064700 | |
| | c51c36fbcb | |
| | ee372d464e | |
| | 1bef1344d1 | |
| | 8d41c491fd | |
| | 64f390a833 | |
| | 8d20581f38 | |
| | bfd4ae9b27 | |
130
.agents/skills/aoti-debug/SKILL.md
Normal file
@@ -0,0 +1,130 @@
|
||||
---
|
||||
name: aoti-debug
|
||||
description: Debug AOTInductor (AOTI) errors including device mismatches, CUDA illegal memory access, segfaults, and wrong outputs when deploying compiled PyTorch models. Use when encountering errors with aoti_compile_and_package, aoti_load_package, or the deprecated aot_compile/aot_load APIs.
|
||||
---
|
||||
|
||||
# AOTInductor Debugging
|
||||
|
||||
Debug errors when compiling and deploying PyTorch models with AOTInductor.
|
||||
|
||||
## First Step: Always Check Device and Shape Matching
|
||||
|
||||
**For ANY AOTI error (segfault, exception, crash, wrong output), check these first:**
|
||||
|
||||
1. **Compile device == Load device**: The model must be loaded on the same device type it was compiled on
|
||||
2. **Input devices match**: Runtime inputs must be on the same device as the compiled model
|
||||
3. **Input shapes match**: Runtime input shapes must match compilation shapes (or satisfy dynamic shape constraints)
|
||||
|
||||
```python
|
||||
# Compilation -- note the device and shapes
|
||||
model = MyModel().eval().cuda()
|
||||
inp = torch.randn(2, 10, device="cuda")
|
||||
pkg = torch._inductor.aoti_compile_and_package(model, (inp,))
|
||||
|
||||
# Loading -- device type MUST match compilation
|
||||
loaded = torch._inductor.aoti_load_package(pkg) # auto-detects device from package
|
||||
|
||||
# Inference -- device and shapes MUST match
|
||||
out = loaded(torch.randn(2, 10, device="cuda")) # same device, same shape
|
||||
```
|
||||
|
||||
**AOTI requires compile and load to use the same device type.** Cross-device loading (compile on GPU, load on CPU) is NOT supported. Device index can differ (cuda:0 vs cuda:1).
|
||||
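The three checks above can be written down as a small pre-flight helper. This is a minimal sketch using only the standard library: `TensorSpec`, `device_type`, and `preflight` are hypothetical names, and plain device/shape records stand in for real tensors. Modeling a dynamic dimension as `None` is an assumption made for illustration, not how AOTI encodes its constraints.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple


@dataclass(frozen=True)
class TensorSpec:
    """Hypothetical stand-in for a tensor: a device string plus a shape."""

    device: str  # e.g. "cuda:0" or "cpu"
    shape: Tuple[Optional[int], ...]  # None marks a dynamic dimension


def device_type(device: str) -> str:
    """Drop the device index, so 'cuda:0' and 'cuda:1' both become 'cuda'."""
    return device.split(":", 1)[0]


def preflight(compiled: Sequence[TensorSpec], runtime: Sequence[TensorSpec]) -> List[str]:
    """Return a list of mismatches; an empty list means all checks passed."""
    if len(compiled) != len(runtime):
        return [f"expected {len(compiled)} inputs, got {len(runtime)}"]
    problems: List[str] = []
    for i, (want, got) in enumerate(zip(compiled, runtime)):
        # The device index may differ, the device type may not.
        if device_type(want.device) != device_type(got.device):
            problems.append(f"input {i}: compiled for {want.device}, got {got.device}")
        if len(want.shape) != len(got.shape):
            problems.append(f"input {i}: rank {len(got.shape)}, compiled with rank {len(want.shape)}")
            continue
        for dim, (w, g) in enumerate(zip(want.shape, got.shape)):
            # A dynamic (None) dimension accepts any size in this sketch.
            if w is not None and w != g:
                problems.append(f"input {i}: dim {dim} is {g}, compiled with {w}")
    return problems


compiled = [TensorSpec("cuda:0", (2, 10))]
print(preflight(compiled, [TensorSpec("cuda:1", (2, 10))]))  # differs by index only
print(preflight(compiled, [TensorSpec("cpu", (3, 10))]))     # wrong device type and shape
```

In a real script the specs would be filled in from the example inputs used at compile time and from the tensors about to be passed to the loaded model.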
|
||||
## Current vs Deprecated API
|
||||
|
||||
### Current API (use this)
|
||||
```python
|
||||
torch._inductor.aoti_compile_and_package() # compile
|
||||
torch._inductor.aoti_load_package() # load (auto-detects device)
|
||||
```
|
||||
|
||||
### Deprecated API (migrate away)
|
||||
```python
|
||||
torch._export.aot_compile() # deprecated
|
||||
torch._export.aot_load() # deprecated
|
||||
```
|
||||
|
||||
The new API stores device metadata in the package, so `aoti_load_package()` automatically uses the correct device type.
|
||||
|
||||
## Common Error Patterns
|
||||
|
||||
### Device Mismatch Segfault
|
||||
|
||||
**Symptom**: Segfault, exception, or crash during load or execution.
|
||||
|
||||
**Example errors**:
|
||||
- `The specified pointer resides on host memory and is not registered with any CUDA device`
|
||||
- Crash during constant loading
|
||||
- `Expected out tensor to have device cuda:0, but got cpu instead`
|
||||
|
||||
**Solution**: Ensure compile and load use the same device type.
|
||||
|
||||
### Input Device Mismatch at Runtime
|
||||
|
||||
**Symptom**: RuntimeError during model execution.
|
||||
|
||||
**Better debugging**: Run with `AOTI_RUNTIME_CHECK_INPUTS=1` for clear errors:
|
||||
```bash
|
||||
AOTI_RUNTIME_CHECK_INPUTS=1 python script.py
|
||||
```
|
||||
|
||||
Produces actionable messages like:
|
||||
```
|
||||
Error: input_handles[0]: unmatched device type, expected: 0(cpu), but got: 1(cuda)
|
||||
```
|
||||
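When these messages end up in logs, it can help to pull out the input index and the two device types automatically. The sketch below is written against the single sample message shown above; `DeviceMismatch` and `parse_device_mismatch` are hypothetical names, and other PyTorch versions may word the error differently.

```python
import re
from typing import NamedTuple, Optional


class DeviceMismatch(NamedTuple):
    input_index: int
    expected: str
    actual: str


# Pattern derived from the sample message only (an assumption about the format).
_PATTERN = re.compile(
    r"input_handles\[(\d+)\]: unmatched device type, "
    r"expected: \d+\((\w+)\), but got: \d+\((\w+)\)"
)


def parse_device_mismatch(line: str) -> Optional[DeviceMismatch]:
    """Return the parsed mismatch, or None if the line has a different form."""
    match = _PATTERN.search(line)
    if match is None:
        return None
    index, expected, actual = match.groups()
    return DeviceMismatch(int(index), expected, actual)


sample = "Error: input_handles[0]: unmatched device type, expected: 0(cpu), but got: 1(cuda)"
print(parse_device_mismatch(sample))
# DeviceMismatch(input_index=0, expected='cpu', actual='cuda')
```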
|
||||
## Debugging CUDA Illegal Memory Access (IMA)
|
||||
|
||||
### Step 1: Sanity Checks
|
||||
|
||||
```bash
|
||||
AOTI_RUNTIME_CHECK_INPUTS=1 python script.py # validate inputs match compilation guards
|
||||
TORCHINDUCTOR_NAN_ASSERTS=1 python script.py # check for NaN before/after each kernel
|
||||
```
|
||||
|
||||
Both flags take effect at **compile time** (codegen time).
|
||||
|
||||
### Step 2: Make IMA Deterministic
|
||||
|
||||
```bash
|
||||
PYTORCH_NO_CUDA_MEMORY_CACHING=1 CUDA_LAUNCH_BLOCKING=1 python script.py
|
||||
```
|
||||
|
||||
- `PYTORCH_NO_CUDA_MEMORY_CACHING=1` -- disables caching allocator (which allocates bigger buffers, masking IMA)
|
||||
- `CUDA_LAUNCH_BLOCKING=1` -- forces synchronous kernel launches (pinpoints which kernel crashed)
|
||||
|
||||
Both take effect at **runtime**.
|
||||
|
||||
### Step 3: Identify the Problematic Kernel
|
||||
|
||||
```bash
|
||||
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=3 python script.py
|
||||
```
|
||||
|
||||
Prints kernels one by one at runtime. Combined with Step 2 flags, shows which kernel launched right before the error.
|
||||
|
||||
To inspect inputs to specific kernels:
|
||||
```bash
|
||||
AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT="kernel_name_1,kernel_name_2" \
|
||||
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=2 python script.py
|
||||
```
|
||||
|
||||
If inputs to a kernel are unexpected, trace back to the kernel that produced the bad input.
|
||||
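Steps 2 and 3 can be combined into one environment that is handed to a child process. This is a sketch under stated assumptions: `ima_debug_env` is a hypothetical helper, and it only sets the variables named in this guide, using printer level 2 when a kernel filter is given and level 3 otherwise.

```python
import os
from typing import Dict, Iterable, Mapping, Optional


def ima_debug_env(
    kernels: Iterable[str] = (),
    base: Optional[Mapping[str, str]] = None,
) -> Dict[str, str]:
    """Build the environment for one IMA debugging run."""
    env = dict(os.environ if base is None else base)
    # Step 2: runtime flags that make the failure deterministic.
    env["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"
    env["CUDA_LAUNCH_BLOCKING"] = "1"
    names = [name.strip() for name in kernels if name.strip()]
    if names:
        # Step 3, filtered: inspect the inputs of the named kernels.
        env["AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT"] = ",".join(names)
        env["AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER"] = "2"
    else:
        # Step 3, unfiltered: print kernels one by one.
        env["AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER"] = "3"
    return env


env = ima_debug_env(["kernel_name_1", "kernel_name_2"], base={})
for key in sorted(env):
    print(f"{key}={env[key]}")
```

The result can be passed as `env=` to `subprocess.run([sys.executable, "script.py"], ...)`. Because the printer variables take effect at compile time, the script has to recompile the model under this environment for them to matter.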
|
||||
## Environment Variables Reference
|
||||
|
||||
| Variable | When | Purpose |
|
||||
|---|---|---|
|
||||
| `AOTI_RUNTIME_CHECK_INPUTS=1` | Compile time | Validate inputs match compilation guards |
|
||||
| `TORCHINDUCTOR_NAN_ASSERTS=1` | Compile time | Check for NaN before/after kernels |
|
||||
| `PYTORCH_NO_CUDA_MEMORY_CACHING=1` | Runtime | Make IMA errors deterministic |
|
||||
| `CUDA_LAUNCH_BLOCKING=1` | Runtime | Force synchronous kernel launches |
|
||||
| `AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=3` | Compile time | Print kernels at runtime |
|
||||
| `AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT="..."` | Compile time | Filter which kernels to print |
|
||||
| `TORCH_LOGS="+inductor,output_code"` | Runtime | See PT2 internal logs |
|
||||
| `TORCH_SHOW_CPP_STACKTRACES=1` | Runtime | Show C++ stack traces |
|
||||
|
||||
## Common Sources of Issues
|
||||
|
||||
- **Dynamic shapes**: Historically a common source of IMA errors. Pay special attention when using dynamic shape constraints.
|
||||
- **Custom ops**: Especially C++ custom ops with dynamic shapes. The meta function may need to handle SymInt properly.
|
||||
195
.agents/skills/pt2-debug/SKILL.md
Normal file
195
.agents/skills/pt2-debug/SKILL.md
Normal file
@@ -0,0 +1,195 @@
|
||||
---
|
||||
name: pt2-debug
|
||||
description: Debug torch.compile failures, graph breaks, recompilation issues, accuracy mismatches, and Triton kernel errors. Use when encountering BackendCompilerFailed exceptions, torch.compile errors, recompilation warnings, or numerical accuracy issues with compiled PyTorch models.
|
||||
---
|
||||
|
||||
# PyTorch 2 Compile Debugging
|
||||
|
||||
Debug `torch.compile`, Dynamo, Inductor, and AOTAutograd failures when using PyTorch as a library.
|
||||
|
||||
## Diagnostic Environment Variables
|
||||
|
||||
Pick the right diagnostic based on the error:
|
||||
|
||||
| Command | When to use |
|
||||
|---|---|
|
||||
| `TORCH_LOGS="+dynamo,graph_breaks,recompiles" python script.py` | Quick overview of what's going wrong |
|
||||
| `TORCH_COMPILE_DEBUG=1 python script.py` | Full debug artifacts (FX graphs, Inductor IR, generated code) in `torch_compile_debug/` |
|
||||
| `TORCH_LOGS="output_code" python script.py` | See the generated Triton/C++ kernel code |
|
||||
| `TORCH_TRACE=/path/to/trace python script.py` | Structured trace (parse with `tlparse`) |
|
||||
| `TORCHINDUCTOR_COMPILE_THREADS=1 python script.py` | Single-threaded compilation for pdb debugging |
|
||||
|
||||
## Error Triage
|
||||
|
||||
Classify the failure and jump to the right section:
|
||||
|
||||
| Error Pattern | Category |
|
||||
|---|---|
|
||||
| `Unsupported: ...` or `graph break` in logs | [Graph Breaks](#graph-breaks) |
|
||||
| `BackendCompilerFailed` | [Backend Failures](#backend-compiler-failures) |
|
||||
| `RecompileError` or `cache_size_limit` | [Recompilation](#recompilation-issues) |
|
||||
| Accuracy mismatch / wrong numerical output | [Accuracy](#accuracy-issues) |
|
||||
| `InternalTorchDynamoError` | [Internal Errors](#internal-dynamo-errors) |
|
||||
| Segfault or CUDA IMA | [Runtime Crashes](#runtime-crashes) |
|
||||
| Triton assertion / index out of bounds | [Triton Failures](#triton-kernel-failures) |
|
||||
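The triage table can be turned into a first-pass classifier for log text. This is only a sketch: `TRIAGE_RULES` and `triage` are hypothetical names, the exception names come from the table, and the substrings used for crashes and Triton failures are assumptions about how such errors usually appear. Accuracy mismatches have no textual signature, so they fall through to `None`.

```python
from typing import List, Optional, Tuple

# (substring, category) pairs; the first match wins, so more specific
# exception names are listed before generic phrases.
TRIAGE_RULES: List[Tuple[str, str]] = [
    ("InternalTorchDynamoError", "Internal Errors"),
    ("BackendCompilerFailed", "Backend Failures"),
    ("RecompileError", "Recompilation"),
    ("cache_size_limit", "Recompilation"),
    ("Unsupported:", "Graph Breaks"),
    ("graph break", "Graph Breaks"),
    # Assumed wording for the last two table rows:
    ("index out of bounds", "Triton Failures"),
    ("illegal memory access", "Runtime Crashes"),
    ("Segmentation fault", "Runtime Crashes"),
]


def triage(log_text: str) -> Optional[str]:
    """Return the table category for the first matching pattern, else None."""
    lowered = log_text.lower()
    for needle, category in TRIAGE_RULES:
        if needle.lower() in lowered:
            return category
    return None


print(triage("torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:"))
print(triage("outputs differ from eager"))  # no pattern: classify by hand
```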
|
||||
## Graph Breaks
|
||||
|
||||
Graph breaks split the compiled graph into smaller subgraphs, causing performance regressions.
|
||||
|
||||
**Diagnose:**
|
||||
```bash
|
||||
TORCH_LOGS="graph_breaks" python script.py
|
||||
```
|
||||
|
||||
**Common causes:**
|
||||
- Data-dependent control flow
|
||||
- Unsupported Python builtins
|
||||
- In-place ops on inputs, unsupported dtypes
|
||||
- Calls to non-traceable functions
|
||||
|
||||
**Fix approaches:**
|
||||
1. Read the graph break message to identify the unsupported operation
|
||||
2. Check for a decomposition or supported alternative
|
||||
3. Consider `torch._dynamo.allow_in_graph` or restructure user code
|
||||
|
||||
## Backend Compiler Failures
|
||||
|
||||
`BackendCompilerFailed` means the backend compiler (Inductor by default) raised an exception during compilation.
|
||||
|
||||
**Diagnose with the minifier:**
|
||||
```bash
|
||||
# Generate minifier launcher
|
||||
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=2 python script.py
|
||||
|
||||
# Run the minifier to get minimal failing graph
|
||||
python minifier_launcher.py minify
|
||||
|
||||
# Run the minimized reproduction
|
||||
python minifier_launcher.py run
|
||||
```
|
||||
|
||||
**Then inspect:**
|
||||
```bash
|
||||
TORCH_COMPILE_DEBUG=1 python script.py # FX graphs in torch_compile_debug/
|
||||
```
|
||||
|
||||
## Recompilation Issues
|
||||
|
||||
Excessive recompilation happens when guards are too specific, so calls keep missing the compile cache.
|
||||
|
||||
**Diagnose:**
|
||||
```bash
|
||||
TORCH_LOGS="recompiles,recompiles_verbose,guards" python script.py
|
||||
```
|
||||
|
||||
**Key config:**
|
||||
```python
|
||||
torch._dynamo.config.recompile_limit # default: 8
|
||||
torch._dynamo.config.fail_on_recompile_limit_hit = True # hard error on limit
|
||||
```
|
||||
|
||||
**Common causes:**
|
||||
- Changing tensor shapes without marking them dynamic
|
||||
- Python scalar values that change between calls
|
||||
- Global state mutations between calls
|
||||
|
||||
**Fix:** Read the recompilation reason from logs, identify the failing guard, then either:
|
||||
- Mark dimensions as dynamic: `torch._dynamo.mark_dynamic(tensor, dim)`
|
||||
- Fix the source of guard instability
|
||||
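Why marking a dimension dynamic helps can be seen in a toy model of a guard-keyed compile cache. The sketch below is purely illustrative: `ToyCompileCache` is a hypothetical class, it does not reproduce Dynamo's real guard logic (for example, it ignores automatic dynamic shapes), and it only borrows the limit of 8 quoted above.

```python
from typing import Dict, Optional, Sequence, Tuple

RECOMPILE_LIMIT = 8  # mirrors the default quoted above


class ToyCompileCache:
    """Toy model: one cache entry per distinct guarded shape."""

    def __init__(self, dynamic_dims: Sequence[int] = ()) -> None:
        self.dynamic_dims = set(dynamic_dims)
        self.entries: Dict[Tuple[Optional[int], ...], str] = {}
        self.compiles = 0

    def _guard_key(self, shape: Sequence[int]) -> Tuple[Optional[int], ...]:
        # A dynamic dimension is not specialized, so it drops out of the guard.
        return tuple(None if i in self.dynamic_dims else n for i, n in enumerate(shape))

    def call(self, shape: Sequence[int]) -> str:
        key = self._guard_key(shape)
        if key in self.entries:
            return "cache hit"
        if self.compiles >= RECOMPILE_LIMIT:
            return "limit hit"  # the toy stops compiling new entries here
        self.compiles += 1
        self.entries[key] = "compiled"
        return "compiled"


static = ToyCompileCache()
dynamic = ToyCompileCache(dynamic_dims=(0,))
for batch in range(1, 11):
    static.call((batch, 10))   # every new batch size fails the shape guard
    dynamic.call((batch, 10))  # batch size is not part of the guard
print(static.compiles, dynamic.compiles)  # 8 1
```

With the batch dimension left static, each new batch size creates a new entry until the limit is reached. With it marked dynamic, a single entry serves all ten calls.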
|
||||
## Accuracy Issues
|
||||
|
||||
Compiled model produces different numerical results than eager mode.
|
||||
|
||||
**Diagnose:**
|
||||
```bash
|
||||
# Compares compiled vs eager with fp64 reference, dumps repro on failure
|
||||
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=4 python script.py
|
||||
```
|
||||
|
||||
**Fix approach:**
|
||||
1. Get minimal failing graph from the minifier
|
||||
2. Compare eager vs compiled output at fp64 precision
|
||||
3. Binary search through ops to find the diverging operation
|
||||
4. Check for known issues: reduction order, fused kernels, dtype promotions
|
||||
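Step 3 of the fix approach, the binary search through ops, can be sketched as follows. Plain Python functions on floats stand in for graph ops, and `run_prefix` and `first_diverging_op` are hypothetical names. The search assumes both pipelines have the same length and that once the outputs diverge they stay diverged for every longer prefix.

```python
from typing import Callable, Optional, Sequence

Op = Callable[[float], float]


def run_prefix(ops: Sequence[Op], x: float, n: int) -> float:
    """Apply the first n ops to x."""
    for op in ops[:n]:
        x = op(x)
    return x


def first_diverging_op(
    reference: Sequence[Op], candidate: Sequence[Op], x: float, tol: float = 1e-9
) -> Optional[int]:
    """Binary search for the index of the op that introduces a mismatch."""

    def diverged(n: int) -> bool:
        return abs(run_prefix(reference, x, n) - run_prefix(candidate, x, n)) > tol

    if not diverged(len(reference)):
        return None  # full pipelines agree within tolerance
    lo, hi = 0, len(reference)  # invariant: prefix lo agrees, prefix hi diverges
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if diverged(mid):
            hi = mid
        else:
            lo = mid
    return hi - 1  # the op that completes the first diverging prefix


reference = [lambda v: v + 1.0, lambda v: v * 2.0, lambda v: v - 3.0, lambda v: v * 0.5]
candidate = list(reference)
candidate[2] = lambda v: v - 3.001  # simulated numerical drift in op 2
print(first_diverging_op(reference, candidate, 1.0))  # 2
```

In practice the reference would be the eager run at fp64 and the candidate the compiled run, with intermediate outputs compared at matching points in the graph.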
|
||||
## Internal Dynamo Errors
|
||||
|
||||
`InternalTorchDynamoError` indicates a bug in Dynamo.
|
||||
|
||||
**Diagnose:**
|
||||
```bash
|
||||
TORCHDYNAMO_VERBOSE=1 python script.py
|
||||
# or equivalently:
|
||||
TORCH_LOGS="+dynamo" python script.py
|
||||
```
|
||||
|
||||
**Debug interactively:**
|
||||
```bash
|
||||
TORCHINDUCTOR_COMPILE_THREADS=1 python script.py # then attach pdb
|
||||
```
|
||||
|
||||
## Runtime Crashes
|
||||
|
||||
Segfaults and CUDA illegal memory access during execution of compiled code.
|
||||
|
||||
**Make crash deterministic:**
|
||||
```bash
|
||||
PYTORCH_NO_CUDA_MEMORY_CACHING=1 CUDA_LAUNCH_BLOCKING=1 python script.py
|
||||
```
|
||||
|
||||
**Add NaN checks to find the first bad kernel:**
|
||||
```bash
|
||||
TORCHINDUCTOR_NAN_ASSERTS=1 python script.py
|
||||
```
|
||||
|
||||
**Inductor sync debugging:**
|
||||
```python
|
||||
torch._inductor.config.triton.debug_sync_kernel = True # sync after every kernel
|
||||
torch._inductor.config.triton.debug_sync_graph = True # sync before/after graph
|
||||
```
|
||||
|
||||
**Fix approach:**
|
||||
1. Make deterministic with `PYTORCH_NO_CUDA_MEMORY_CACHING=1 CUDA_LAUNCH_BLOCKING=1`
|
||||
2. Check input shapes, devices, dtypes
|
||||
3. Inspect generated kernel code with `TORCH_LOGS="output_code"`
|
||||
4. Use `TORCHINDUCTOR_NAN_ASSERTS=1` to find the first kernel producing bad values
|
||||
5. If dynamic shapes are involved, check them closely: they are historically a common source of IMA
|
||||
|
||||
## Triton Kernel Failures
|
||||
|
||||
Triton assertion failures or index-out-of-bounds in generated kernels.
|
||||
|
||||
**Diagnose:**
|
||||
```bash
|
||||
TORCH_LOGS="output_code,schedule" python script.py
|
||||
```
|
||||
|
||||
**Fix approach:**
|
||||
1. Get the generated Triton kernel from `output_code` logs
|
||||
2. Check index computations for off-by-one or wrong stride calculations
|
||||
3. Check IR with `TORCH_COMPILE_DEBUG=1` to trace back to the FX op
|
||||
4. Check if fusion decisions created invalid index combinations
|
||||
|
||||
## Distinguish Trace-Time vs Runtime
|
||||
|
||||
Many bugs come from confusing these:
|
||||
- **Trace-time**: Inside Dynamo's symbolic interpreter. Function calls may be constant-folded.
|
||||
- **Runtime**: Real tensors, real Python calls.
|
||||
|
||||
When debugging, add `print()` directly in source files rather than monkey-patching -- dispatch chains make monkey-patching unreliable.
|
||||
|
||||
## Using the Minifier
|
||||
|
||||
The minifier reduces a failing graph to the smallest reproduction:
|
||||
|
||||
```bash
|
||||
# For compilation failures (level 2)
|
||||
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=2 python script.py
|
||||
python minifier_launcher.py minify
|
||||
python minifier_launcher.py run
|
||||
|
||||
# For accuracy failures (level 4)
|
||||
TORCHDYNAMO_REPRO_AFTER=aot TORCHDYNAMO_REPRO_LEVEL=4 python script.py
|
||||
```
|
||||
134
.agents/skills/ruff/SKILL.md
Normal file
@@ -0,0 +1,134 @@
|
||||
---
|
||||
name: ruff
|
||||
description:
|
||||
Guide for using ruff, the extremely fast Python linter and formatter. Use this
|
||||
when linting, formatting, or fixing Python code.
|
||||
---
|
||||
|
||||
# ruff
|
||||
|
||||
Ruff is an extremely fast Python linter and code formatter. It replaces Flake8,
|
||||
isort, Black, pyupgrade, autoflake, and dozens of other tools.
|
||||
|
||||
## When to use ruff
|
||||
|
||||
**Always use ruff for Python linting and formatting**, especially if you see:
|
||||
|
||||
- `[tool.ruff]` section in `pyproject.toml`
|
||||
- A `ruff.toml` or `.ruff.toml` configuration file
|
||||
|
||||
However, avoid making unnecessary changes:
|
||||
|
||||
- **Don't format unformatted code** - If `ruff format --diff` shows changes
|
||||
throughout an entire file, the project likely isn't using ruff for formatting.
|
||||
Skip formatting to avoid obscuring actual changes.
|
||||
- **Scope fixes to code being edited** - Use `ruff check --diff` to see fixes
|
||||
relevant to the code you're changing. Only apply fixes to files you're
|
||||
modifying unless the user explicitly asks for broader fixes.
|
||||
|
||||
## How to invoke ruff
|
||||
|
||||
- `uv run ruff ...` - Use when ruff is in the project's dependencies to ensure
|
||||
you use the pinned version
|
||||
- `uvx ruff ...` - Use when ruff is not a project dependency, or for quick
|
||||
one-off checks
|
||||
- `ruff ...` - Use if ruff is installed globally
|
||||
|
||||
## Commands
|
||||
|
||||
### Linting
|
||||
|
||||
```bash
|
||||
ruff check . # Check all files in current directory
|
||||
ruff check path/to/file.py # Check specific file
|
||||
ruff check --fix . # Auto-fix fixable violations
|
||||
ruff check --fix --unsafe-fixes . # Include unsafe fixes (review changes!)
|
||||
ruff check --watch . # Watch for changes and re-lint
|
||||
ruff check --select E,F . # Only check specific rules
|
||||
ruff check --ignore E501 . # Ignore specific rules
|
||||
ruff rule E501 # Explain a specific rule
|
||||
ruff linter # List available linters
|
||||
```
|
||||
|
||||
### Formatting
|
||||
|
||||
```bash
|
||||
ruff format . # Format all files
|
||||
ruff format path/to/file.py # Format specific file
|
||||
ruff format --check . # Check if files are formatted (no changes)
|
||||
ruff format --diff . # Show formatting diff without applying
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Ruff is configured in `pyproject.toml` or `ruff.toml`:
|
||||
|
||||
```toml
|
||||
# pyproject.toml
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "UP"] # Enable specific rule sets
|
||||
ignore = ["E501"] # Ignore specific rules
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["myproject"]
|
||||
```
|
||||
|
||||
## Migrating from other tools
|
||||
|
||||
### Black → ruff format
|
||||
|
||||
```bash
|
||||
black . → ruff format .
|
||||
black --check . → ruff format --check .
|
||||
black --diff . → ruff format --diff .
|
||||
```
|
||||
|
||||
### Flake8 → ruff check
|
||||
|
||||
```bash
|
||||
flake8 . → ruff check .
|
||||
flake8 --select E,F . → ruff check --select E,F .
|
||||
flake8 --ignore E501 . → ruff check --ignore E501 .
|
||||
```
|
||||
|
||||
### isort → ruff check
|
||||
|
||||
```bash
|
||||
isort . → ruff check --select I --fix .
|
||||
isort --check . → ruff check --select I .
|
||||
isort --diff . → ruff check --select I --diff .
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Apply lint fixes before formatting
|
||||
|
||||
Run `ruff check --fix` before `ruff format`. Lint fixes can change code
|
||||
structure (e.g., reordering imports), which formatting then cleans up.
|
||||
|
||||
```bash
|
||||
ruff check --fix .
|
||||
ruff format .
|
||||
```
|
||||
|
||||
### Applying and reviewing unsafe fixes
|
||||
|
||||
Ruff categorizes some auto-fixes as "unsafe" because they may change code
|
||||
behavior, not just style. For example, removing unused imports could break code
|
||||
that relies on side effects.
|
||||
|
||||
```bash
|
||||
ruff check --fix --unsafe-fixes --diff . # Preview changes first
|
||||
ruff check --fix --unsafe-fixes . # Apply changes
|
||||
```
|
||||
|
||||
**Always review changes before applying `--unsafe-fixes`:**
|
||||
|
||||
- Use `ruff rule <CODE>` to understand why the fix is considered unsafe
|
||||
- Verify the fix doesn't violate those assumptions in your code
|
||||
|
||||
## Documentation
|
||||
|
||||
For detailed information, read the official documentation:
|
||||
|
||||
- https://docs.astral.sh/ruff/
|
||||
135
.agents/skills/ty/SKILL.md
Normal file
@@ -0,0 +1,135 @@
|
||||
---
|
||||
name: ty
|
||||
description:
|
||||
Guide for using ty, the extremely fast Python type checker and language
|
||||
server. Use this when type checking Python code or setting up type checking in
|
||||
Python projects.
|
||||
---
|
||||
|
||||
# ty
|
||||
|
||||
ty is an extremely fast Python type checker and language server. It replaces
|
||||
mypy, Pyright, and other type checkers.
|
||||
|
||||
## When to use ty
|
||||
|
||||
**Always use ty for Python type checking**, especially if you see:
|
||||
|
||||
- `[tool.ty]` section in `pyproject.toml`
|
||||
- A `ty.toml` configuration file
|
||||
|
||||
## How to invoke ty
|
||||
|
||||
- `uv run ty ...` - Use when ty is in the project's dependencies, to ensure you
|
||||
use the pinned version, or when ty is installed globally and you are inside a
|
||||
project, so that the virtual environment is updated.
|
||||
- `uvx ty ...` - Use when ty is not a project dependency, or for quick one-off
|
||||
checks
|
||||
|
||||
## Commands
|
||||
|
||||
### Type checking
|
||||
|
||||
```bash
|
||||
ty check # Check all files in current directory
|
||||
ty check path/to/file.py # Check specific file
|
||||
ty check src/ # Check specific directory
|
||||
```
|
||||
|
||||
### Rule configuration
|
||||
|
||||
```bash
|
||||
ty check --error possibly-unresolved-reference # Treat as error
|
||||
ty check --warn division-by-zero # Treat as warning
|
||||
ty check --ignore unresolved-import # Disable rule
|
||||
```
|
||||
|
||||
### Python version targeting
|
||||
|
||||
```bash
|
||||
ty check --python-version 3.12 # Check against Python 3.12
|
||||
ty check --python-platform linux # Target Linux platform
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
ty is configured in `pyproject.toml` or `ty.toml`:
|
||||
|
||||
```toml
|
||||
# pyproject.toml
|
||||
[tool.ty.environment]
|
||||
python-version = "3.12"
|
||||
|
||||
[tool.ty.rules]
|
||||
possibly-unresolved-reference = "warn"
|
||||
division-by-zero = "error"
|
||||
|
||||
[tool.ty.src]
|
||||
include = ["src/**/*.py"]
|
||||
exclude = ["**/migrations/**"]
|
||||
|
||||
[tool.ty.terminal]
|
||||
output-format = "full"
|
||||
error-on-warning = false
|
||||
```
|
||||
|
||||
### Per-file overrides
|
||||
|
||||
Use overrides to apply different rules to specific files, such as relaxing rules
|
||||
for tests or scripts that have different typing requirements than production
|
||||
code:
|
||||
|
||||
```toml
|
||||
[[tool.ty.overrides]]
|
||||
include = ["tests/**", "**/test_*.py"]
|
||||
|
||||
[tool.ty.overrides.rules]
|
||||
possibly-unresolved-reference = "warn"
|
||||
```
|
||||
|
||||
## Language server
|
||||
|
||||
This plugin automatically configures the ty language server for Python files
|
||||
(`.py` and `.pyi`).
|
||||
|
||||
## Migrating from other tools
|
||||
|
||||
### mypy → ty
|
||||
|
||||
```bash
|
||||
mypy . → ty check
|
||||
mypy --strict . → ty check --error-on-warning
|
||||
mypy path/to/file.py → ty check path/to/file.py
|
||||
```
|
||||
|
||||
### Pyright → ty
|
||||
|
||||
```bash
|
||||
pyright . → ty check
|
||||
pyright path/to/file.py → ty check path/to/file.py
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Don't add ignore comments
|
||||
|
||||
Fix type errors instead of suppressing them. Only add ignore comments when
|
||||
explicitly requested by the user. Use `ty: ignore`, not `type: ignore`, and
|
||||
prefer rule-specific ignores:
|
||||
|
||||
```python
|
||||
# Good: rule-specific ignore
|
||||
x = undefined_var # ty: ignore[possibly-unresolved-reference]
|
||||
|
||||
# Bad: blanket ty ignore
|
||||
x = undefined_var # ty: ignore
|
||||
|
||||
# Bad: tool agnostic blanket ignore
|
||||
x = undefined_var # type: ignore
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
For detailed information, read the official documentation:
|
||||
|
||||
- https://docs.astral.sh/ty/
|
||||
182
.agents/skills/uv/SKILL.md
Normal file
@@ -0,0 +1,182 @@
|
||||
---
|
||||
name: uv
|
||||
description:
|
||||
Guide for using uv, the Python package and project manager. Use this when
|
||||
working with Python projects, scripts, packages, or tools.
|
||||
---
|
||||
|
||||
# uv
|
||||
|
||||
uv is an extremely fast Python package and project manager. It replaces pip,
|
||||
pip-tools, pipx, pyenv, virtualenv, poetry, etc.
|
||||
|
||||
## When to use uv
|
||||
|
||||
**Always use uv for Python work**, especially if you see:
|
||||
|
||||
- The `uv.lock` file
|
||||
- uv headers in `requirements*` files, e.g., "This file was autogenerated by uv"
|
||||
|
||||
Don't use uv in projects managed by other tools:
|
||||
|
||||
- Poetry projects (identifiable by `poetry.lock` file)
|
||||
- PDM projects (identifiable by `pdm.lock` file)
|
||||
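The lockfile signals listed above can be checked mechanically before choosing a tool. This is a minimal sketch: `detect_manager` is a hypothetical helper, it looks only at the three lockfiles named in the lists (not at headers in `requirements*` files), and the priority order when several lockfiles coexist is an assumption.

```python
from pathlib import Path
from typing import List, Optional, Tuple, Union

# Marker files from the lists above, checked in this (assumed) priority order.
LOCKFILES: List[Tuple[str, str]] = [
    ("uv.lock", "uv"),
    ("poetry.lock", "poetry"),
    ("pdm.lock", "pdm"),
]


def detect_manager(project_dir: Union[str, Path]) -> Optional[str]:
    """Return the manager whose lockfile is present, or None if none is found."""
    root = Path(project_dir)
    for filename, manager in LOCKFILES:
        if (root / filename).is_file():
            return manager
    return None


print(detect_manager("."))
```

A result of `None` means no lockfile was found, in which case the other signals in this section still apply.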
|
||||
## Choosing the right workflow
|
||||
|
||||
### Scripts
|
||||
|
||||
**Use when:** Running single Python files and standalone scripts.
|
||||
|
||||
**Key commands:**
|
||||
|
||||
```bash
|
||||
uv run script.py # Run a script
|
||||
uv run --with requests script.py # Run with additional packages
|
||||
uv add --script script.py requests # Add dependencies inline to the script
|
||||
```
|
||||
|
||||
### Projects
|
||||
|
||||
**Use when:** The directory contains a `pyproject.toml` or `uv.lock`.
|
||||
|
||||
**Key commands:**
|
||||
|
||||
```bash
|
||||
uv init # Create new project
|
||||
uv add requests # Add dependency
|
||||
uv remove requests # Remove dependency
|
||||
uv sync # Install from lockfile
|
||||
uv run <command> # Run commands in environment
|
||||
uv run python -c "" # Run Python in project environment
|
||||
uv run -p 3.12 <command> # Run with specific Python version
|
||||
```
|
||||
|
||||
### Tools
|
||||
|
||||
**Use when:** Running command-line tools (e.g., ruff, ty, pytest) without
|
||||
installation.
|
||||
|
||||
**Key commands:**
|
||||
|
||||
```bash
|
||||
uvx <tool> <args> # Run a tool without installation
|
||||
uvx <tool>@<version> <args> # Run a specific version of a tool
|
||||
```
|
||||
|
||||
**Important:**
|
||||
|
||||
- `uvx` runs tools from PyPI by package name. This can be unsafe - only run
|
||||
well-known tools.
|
||||
- Use `uv tool install` only when specifically requested by the user.
|
||||
|
||||
### Pip interface
|
||||
|
||||
**Use when:** Legacy workflows with `requirements.txt` or manual environment
|
||||
management, no `uv.lock` present.
|
||||
|
||||
**Key commands:**
|
||||
|
||||
```bash
|
||||
uv venv
|
||||
uv pip install -r requirements.txt
|
||||
uv pip compile requirements.in -o requirements.txt
|
||||
uv pip sync requirements.txt
|
||||
|
||||
# Platform independent resolution
|
||||
uv pip compile --universal requirements.in -o requirements.txt
|
||||
```
|
||||
|
||||
**Important:**
|
||||
|
||||
- Don't use the pip interface unless clearly needed.
|
||||
- Don't introduce new `requirements.txt` files.
|
||||
- Prefer `uv init` for new projects.
|
||||
|
||||
## Migrating from other tools
|
||||
|
||||
### pyenv → uv python
|
||||
|
||||
```bash
|
||||
pyenv install 3.12 → uv python install 3.12
|
||||
pyenv versions → uv python list --only-installed
|
||||
pyenv local 3.12 → uv python pin 3.12
|
||||
pyenv global 3.12 → uv python install 3.12 --default
|
||||
```
|
||||
|
||||
### pipx → uvx
|
||||
|
||||
```bash
|
||||
pipx run ruff → uvx ruff
|
||||
pipx install ruff → uv tool install ruff
|
||||
pipx upgrade ruff → uv tool upgrade ruff
|
||||
pipx list → uv tool list
|
||||
```
|
||||
|
||||
### pip and pip-tools → uv pip
|
||||
|
||||
```bash
|
||||
pip install package → uv pip install package
|
||||
pip install -r req.txt → uv pip install -r req.txt
|
||||
pip freeze → uv pip freeze
|
||||
pip-compile req.in → uv pip compile req.in
|
||||
pip-sync req.txt → uv pip sync req.txt
|
||||
virtualenv .venv → uv venv
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Don't use pip in uv projects
|
||||
|
||||
```bash
|
||||
# Bad
|
||||
pip install requests
|
||||
|
||||
# Good
|
||||
uv add requests
|
||||
```
|
||||
|
||||
### Don't run python directly
|
||||
|
||||
```bash
|
||||
# Bad
|
||||
python script.py
|
||||
|
||||
# Good
|
||||
uv run script.py
|
||||
```
|
||||
|
||||
```bash
|
||||
# Bad
|
||||
python -c "..."
|
||||
|
||||
# Good
|
||||
uv run python -c "..."
|
||||
```
|
||||
|
||||
```bash
|
||||
# Bad
|
||||
python3.12 -c "..."
|
||||
|
||||
# Good
|
||||
uvx python@3.12 -c "..."
|
||||
```
|
||||
|
||||
### Don't manually manage environments in uv projects
|
||||
|
||||
```bash
|
||||
# Bad
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
|
||||
# Good
|
||||
uv run <command>
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
For detailed information, read the official documentation:
|
||||
|
||||
- https://docs.astral.sh/uv/llms.txt
|
||||
|
||||
The documentation links to specific pages for each of these workflows.
|
||||
@@ -17,11 +17,15 @@
|
||||
"userUid": "1000",
|
||||
"userGid": "1000",
|
||||
"configureZshAsDefaultShell": false
|
||||
},
|
||||
"ghcr.io/devcontainers/features/node:1": {
|
||||
"version": "lts"
|
||||
}
|
||||
},
|
||||
"remoteUser": "ubuntu",
|
||||
"remoteEnv": {
|
||||
-"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
|
||||
+"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo",
|
||||
+"CODEX_HOME": "${containerWorkspaceFolder}/.claude/codex"
|
||||
},
|
||||
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
|
||||
"customizations": {
|
||||
|
||||
@@ -21,11 +21,15 @@
|
||||
"userUid": "1000",
|
||||
"userGid": "1000",
|
||||
"configureZshAsDefaultShell": false
|
||||
},
|
||||
"ghcr.io/devcontainers/features/node:1": {
|
||||
"version": "lts"
|
||||
}
|
||||
},
|
||||
"remoteUser": "ubuntu",
|
||||
"remoteEnv": {
|
||||
-"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
|
||||
+"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo",
|
||||
+"CODEX_HOME": "${containerWorkspaceFolder}/.claude/codex"
|
||||
},
|
||||
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
|
||||
"customizations": {
|
||||
@@ -52,4 +56,4 @@
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
6
.github/workflows/modal-examples.yml
vendored
@@ -3,7 +3,7 @@ name: Modal Examples
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
-pull_request_target:
|
||||
+pull_request:
|
||||
branches: ["main"]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
if: >-
|
||||
github.event_name == 'push'
|
||||
|| github.event_name == 'workflow_dispatch'
|
||||
-|| (github.event_name == 'pull_request_target'
|
||||
+|| (github.event_name == 'pull_request'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
|
||||
name: "${{ matrix.example }} (Modal ${{ matrix.gpu.type }})"
|
||||
runs-on: ubuntu-latest
|
||||
@@ -30,8 +30,6 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
-with:
|
||||
-ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
|
||||
6
.github/workflows/test-cuda.yml
vendored
@@ -3,7 +3,7 @@ name: Test CUDA
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
-pull_request_target:
|
||||
+pull_request:
|
||||
branches: ["main"]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
if: >-
|
||||
github.event_name == 'push'
|
||||
|| github.event_name == 'workflow_dispatch'
|
||||
-|| (github.event_name == 'pull_request_target'
|
||||
+|| (github.event_name == 'pull_request'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
|
||||
name: Cuda Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
@@ -22,8 +22,6 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
-with:
|
||||
-ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
|
||||
6
.github/workflows/test-python-cuda.yml
vendored
@@ -3,7 +3,7 @@ name: Test Python CUDA
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
-pull_request_target:
|
||||
+pull_request:
|
||||
branches: ["main"]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
if: >-
|
||||
github.event_name == 'push'
|
||||
|| github.event_name == 'workflow_dispatch'
|
||||
-|| (github.event_name == 'pull_request_target'
|
||||
+|| (github.event_name == 'pull_request'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
|
||||
name: Python CUDA Tests
|
||||
runs-on: ubuntu-latest
|
||||
@@ -25,8 +25,6 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
|
||||
3
.gitignore
vendored
@@ -1,6 +1,9 @@
|
||||
/target
|
||||
/crates/**/target
|
||||
/examples/**/target
|
||||
.claude-project
|
||||
.claude-memory
|
||||
.codex
|
||||
|
||||
*.env
|
||||
.claude/
|
||||
|
||||
34
CLAUDE.md
Normal file
@@ -0,0 +1,34 @@
|
||||
# Luminal
|
||||
|
||||
## Package Management
|
||||
|
||||
- Use `uv add`, `uv add --dev`, `uv remove` for Python dependencies (pyproject.toml is in `crates/luminal_python/`)
|
||||
- Use `uv sync` to sync the Python environment
|
||||
- Never use pip, pip-tools, poetry, or conda
|
||||
- Never manually create or activate virtual environments — uv manages `.venv/` automatically
|
||||
- Never generate requirements.txt
|
||||
|
||||
## Code Execution
|
||||
|
||||
- Always use `uv run` to execute Python tools: `uv run pytest`, `uv run pre-commit`, `uv run python`
|
||||
- Use `cargo` directly for Rust: `cargo build`, `cargo test`, `cargo check`, `cargo clippy`
|
||||
- Python project root is `crates/luminal_python/` — run `uv run` commands from there
|
||||
|
||||
## Building the Python Package (Maturin)
|
||||
|
||||
- After modifying `.rs` files that affect the Python bridge, rebuild with: `maturin develop --release`
|
||||
- Maturin config is in `crates/luminal_python/pyproject.toml` under `[tool.maturin]`
|
||||
|
||||
## Pre-commit
|
||||
|
||||
- Run with: `uv run pre-commit run --all-files`
|
||||
- Hooks configured: ruff-check, ruff-format (Python), cargo-fmt, cargo-clippy (Rust)
|
||||
- Manual-stage hooks (cargo-clippy-metal, cargo-clippy-cuda-lite) run with `--hook-stage manual`
|
||||
|
||||
## Testing
|
||||
|
||||
- **Rust tests**: `cargo test -p <crate_name>`
|
||||
- **Python tests**: `cd crates/luminal_python && uv run pytest`
|
||||
- `./run_test.sh` — native backend
|
||||
- `./run_tests_cuda.sh` — CUDA backend
|
||||
- See `crates/luminal_python/CLAUDE.md` for Python test patterns and conventions
|
||||
@@ -4,7 +4,7 @@
|
||||
Luminal is a high-performance general-purpose inference compiler.
|
||||
</h3>
|
||||
|
||||
[](https://github.com/luminal-ai/luminal/actions)
|
||||
[](https://github.com/jafioti/luminal/actions)
|
||||
[](https://docs.luminalai.com)
|
||||
[](https://crates.io/crates/luminal)
|
||||
[](https://discord.gg/APjuwHAbGy)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import modal
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
|
||||
gpu_type = os.environ.get("GPU_TYPE", "T4")
|
||||
CUDARC_CUDA_VERSION = "12080"
|
||||
@@ -45,10 +46,8 @@ def run_cargo_test():
|
||||
|
||||
subprocess.run(
|
||||
[
|
||||
"cargo",
|
||||
"test",
|
||||
"-p",
|
||||
"luminal_cuda_lite",
|
||||
"cargo", "test",
|
||||
"-p", "luminal_cuda_lite",
|
||||
"--verbose",
|
||||
"--",
|
||||
"--test-threads=1",
|
||||
|
||||
@@ -461,8 +461,7 @@ impl HostOp for CuBlasLt {
|
||||
cublasLtMatmulDescDestroy(matmul_desc);
|
||||
}
|
||||
|
||||
// No stream.synchronize() here — CUDA stream ordering guarantees
|
||||
// sequential execution. The runtime syncs once at the end of execute().
|
||||
stream.synchronize()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -653,53 +653,4 @@ mod tests {
|
||||
}
|
||||
assert_close(&rt.get_f32(output), &expected, 1e-2, 1e-2);
|
||||
}
|
||||
|
||||
/// Test that CUDA graphs produce correct results when dynamic dimensions
|
||||
/// change incrementally across many executions (simulating a decode loop
|
||||
/// where position offset increments each step).
|
||||
#[test]
|
||||
fn test_cuda_graph_incremental_dim_changes() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor('s');
|
||||
let b = cx.tensor('s');
|
||||
let c = ((a + b) * a).output();
|
||||
|
||||
let initial_size = 128;
|
||||
cx.set_dim('s', initial_size);
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let data_a = random_f32_vec(initial_size, 42, -0.5, 0.5);
|
||||
let data_b = random_f32_vec(initial_size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
|
||||
// Initial execution
|
||||
rt.execute(&cx.dyn_map);
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
let tol = eps * TOLERANCE_SAFETY_FACTOR;
|
||||
let expected: Vec<f32> = data_a
|
||||
.iter()
|
||||
.zip(&data_b)
|
||||
.map(|(a, b)| (a + b) * a)
|
||||
.collect();
|
||||
assert_close(&rt.get_f32(c), &expected, tol, tol);
|
||||
|
||||
// Incrementally change the dynamic dimension 10 times,
|
||||
// simulating decode steps where position offset grows.
|
||||
for step in 1..=10usize {
|
||||
let size = initial_size + step;
|
||||
cx.set_dim('s', size);
|
||||
let da = random_f32_vec(size, 100 + step as u64, -0.5, 0.5);
|
||||
let db = random_f32_vec(size, 200 + step as u64, -0.5, 0.5);
|
||||
rt.set_data(a, da.clone());
|
||||
rt.set_data(b, db.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let expected: Vec<f32> = da.iter().zip(&db).map(|(a, b)| (a + b) * a).collect();
|
||||
assert_close(&rt.get_f32(c), &expected, tol, tol);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,774 +0,0 @@
|
||||
//! # Elementwise Kernel Fusion
|
||||
//!
|
||||
//! Fuses chains of adjacent pointwise KernelOps into a single CUDA kernel,
|
||||
//! eliminating intermediate global-memory round-trips and kernel-launch overhead.
|
||||
//!
|
||||
//! ## Where this sits in the pipeline
|
||||
//!
|
||||
//! ```text
|
||||
//! HLIR Graph
|
||||
//! | (egglog rewrite rules)
|
||||
//! v
|
||||
//! E-Graph ────── genetic search
|
||||
//! |
|
||||
//! v
|
||||
//! LLIR Graph
|
||||
//! |
|
||||
//! v
|
||||
//! kernel_to_host()
|
||||
//! ├── identify_chains() <── THIS MODULE
|
||||
//! ├── build_chain_fused() <── THIS MODULE
|
||||
//! └── compile kernels → CudaGraphOps
|
||||
//! |
|
||||
//! v
|
||||
//! CUDA execution
|
||||
//! ```
|
||||
//!
|
||||
//! ## What it does
|
||||
//!
|
||||
//! Before fusion, each elementwise op is a separate kernel launch that reads
|
||||
//! from and writes to global memory:
|
||||
//!
|
||||
//! ```text
|
||||
//! x * 2.0 + 1.0 (4 kernel launches, 2 intermediate buffers)
|
||||
//!
|
||||
//! ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐
|
||||
//! │ Constant(2.0)│ │ Mul(x, 2.0) │ │ Constant(1.0)│ │ Add(mul, 1.0)│
|
||||
//! │ launch #1 │────>│ launch #2 │ │ launch #3 │────>│ launch #4 │
|
||||
//! │ write tmp0 │ │ read x, tmp0 │ │ write tmp1 │ │read tmp0,tmp1│
|
||||
//! └──────────────┘ │ write tmp0 │ └──────────────┘ │ write output │
|
||||
//! └──────────────┘ └──────────────┘
|
||||
//! ```
|
||||
//!
|
||||
//! After fusion, one kernel does everything in registers:
|
||||
//!
|
||||
//! ```text
|
||||
//! x * 2.0 + 1.0 (1 kernel launch, 0 intermediate buffers)
|
||||
//!
|
||||
//! ┌─────────────────────────────────────────┐
|
||||
//! │ FusedElementwise │
|
||||
//! │ t0 = x[idx] // load from GMEM │
|
||||
//! │ t1 = 2.0 // register const │
|
||||
//! │ t2 = t0 * t1 // register mul │
|
||||
//! │ t3 = 1.0 // register const │
|
||||
//! │ t4 = t2 + t3 // register add │
|
||||
//! │ out[idx] = t4 // store to GMEM │
|
||||
//! └─────────────────────────────────────────┘
|
||||
//! ```
|
||||
//!
|
||||
//! ## Design: egglog rewrite rules
|
||||
//!
|
||||
//! Fusion happens inside egglog via rewrite rules, so fused kernels compete with
|
||||
//! unfused alternatives during the genetic search. A tree-structured `FusedInstr`
|
||||
//! sort encodes the computation DAG directly in the e-graph.
|
||||
//!
|
||||
//! **Seed rules** (one per fusible kernel op) convert individual ops into trivial
|
||||
//! `KernelFused` nodes. **Extension rules** absorb a `KernelFused` producer into
|
||||
//! a `KernelFused` consumer by splicing the producer's program tree in place of
|
||||
//! the consumer's `FIInput` node. The schedule runs 10 iterations, so chains
|
||||
//! naturally grow up to ~8 ops.
|
||||
//!
|
||||
//! ## Fusible ops
|
||||
//!
|
||||
//! - **Unary**: Exp2, Log2, Sin, Recip, Sqrt
|
||||
//! - **Binary**: Add, Mul, Mod, LessThan
|
||||
//! - **Nullary**: Constant (inlined as a register literal)
|
||||
//!
|
||||
//! Non-fusible ops (reductions, gather, scatter, matmul, etc.) break the chain.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{sort, Rule, SortDef},
|
||||
base::{DTYPE, ELIST, FUSED_INSTR, OP_KIND},
|
||||
extract_dtype, extract_expr_list, SerializedEGraph,
|
||||
},
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::*,
|
||||
shape::flatten_strides,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
compile_module_image_for_current_device, cuda_dtype,
|
||||
kernel::{
|
||||
hlir::{dtype_includes, generate_dyn_dims_defines},
|
||||
KernelOp,
|
||||
},
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Data structures
|
||||
// ============================================================================
|
||||
|
||||
/// Maximum number of fused nodes before we stop fusing (to limit register pressure).
|
||||
const MAX_FUSED_NODES: usize = 32;
|
||||
|
||||
/// A unary operation in the fused kernel.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum UnaryFusedOp {
|
||||
Exp2,
|
||||
Log2,
|
||||
Sin,
|
||||
Recip,
|
||||
Sqrt,
|
||||
}
|
||||
|
||||
/// A binary operation in the fused kernel.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum BinaryFusedOp {
|
||||
Add,
|
||||
Mul,
|
||||
Mod,
|
||||
LessThan,
|
||||
}
|
||||
|
||||
/// A node in the fused computation DAG.
|
||||
///
|
||||
/// Nodes are stored in topological order in a `Vec<FusedNode>`. Each Unary/Binary
|
||||
/// node references earlier nodes by index, forming a DAG. The last node is the
|
||||
/// output that gets written to global memory.
|
||||
///
|
||||
/// ## Example: `rsqrt(y + x * recip(z))`
|
||||
///
|
||||
/// ```text
|
||||
/// nodes vec graph edges (LLIR)
|
||||
/// ────────── ──────────────────
|
||||
/// [0] Input { strides_z } <───── edge 0: z_buffer
|
||||
/// [1] Unary { Recip, 0 } (no edge — computed in registers)
|
||||
/// [2] Input { strides_x } <───── edge 1: x_buffer
|
||||
/// [3] Binary { Mul, 2, 1 } (no edge — register * register)
|
||||
/// [4] Input { strides_y } <───── edge 2: y_buffer
|
||||
/// [5] Binary { Add, 4, 3 } (no edge)
|
||||
/// [6] Unary { Sqrt, 5 } (no edge)
|
||||
/// [7] Unary { Recip, 6 } (no edge)
|
||||
/// ^
|
||||
/// └── last node = output written to out[idx]
|
||||
/// ```
|
||||
///
|
||||
/// The i-th `Input` node (counting only Input variants) corresponds to the i-th
|
||||
/// external input in the `CompiledKernel`'s `inputs` vec.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum FusedNode {
|
||||
/// Read from an external input buffer. The i-th Input (counting only Inputs)
|
||||
/// maps to the i-th entry in the external inputs list and thus the i-th kernel
|
||||
/// parameter after `out`.
|
||||
Input {
|
||||
/// Strides for converting the linear thread index to a memory offset.
|
||||
/// Uses `flatten_strides(out_shape, strides)` with variable `z` = `const_z`.
|
||||
strides: Vec<Expression>,
|
||||
},
|
||||
/// An inline constant value — becomes a float literal in the kernel, no buffer.
|
||||
Constant { value: f32 },
|
||||
/// A unary operation on a previous node (referenced by index into the vec).
|
||||
Unary { op: UnaryFusedOp, input: usize },
|
||||
/// A binary operation on two previous nodes (referenced by index into the vec).
|
||||
Binary {
|
||||
op: BinaryFusedOp,
|
||||
lhs: usize,
|
||||
rhs: usize,
|
||||
},
|
||||
}
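A minimal CPU-side sketch of how a topologically ordered node vec like the one documented above can be evaluated for a single output element. It is not part of the removed module: the `Node`, `UnaryOp`, `BinaryOp`, `eval`, and `rsqrt_example` names are hypothetical, strides are omitted, and the i-th `Input` simply reads the i-th entry of a flat slice.

```rust
// Hypothetical, simplified stand-ins for UnaryFusedOp / BinaryFusedOp / FusedNode.
#[derive(Debug, Clone, Copy)]
enum UnaryOp {
    Recip,
    Sqrt,
}

#[derive(Debug, Clone, Copy)]
enum BinaryOp {
    Add,
    Mul,
}

#[derive(Debug, Clone)]
enum Node {
    /// Reads the next external input (strides omitted in this sketch).
    Input,
    Constant { value: f32 },
    Unary { op: UnaryOp, input: usize },
    Binary { op: BinaryOp, lhs: usize, rhs: usize },
}

/// Evaluate the node list for one element. `inputs[i]` feeds the i-th `Input`
/// node, counting only `Input` variants. The last node's value is the result.
fn eval(nodes: &[Node], inputs: &[f32]) -> f32 {
    let mut t: Vec<f32> = Vec::with_capacity(nodes.len());
    let mut input_count = 0usize;
    for node in nodes {
        let v = match node {
            Node::Input => {
                let v = inputs[input_count];
                input_count += 1;
                v
            }
            Node::Constant { value } => *value,
            Node::Unary { op, input } => match op {
                UnaryOp::Recip => 1.0 / t[*input],
                UnaryOp::Sqrt => t[*input].sqrt(),
            },
            Node::Binary { op, lhs, rhs } => match op {
                BinaryOp::Add => t[*lhs] + t[*rhs],
                BinaryOp::Mul => t[*lhs] * t[*rhs],
            },
        };
        t.push(v);
    }
    *t.last().expect("node list must not be empty")
}

/// The `rsqrt(y + x * recip(z))` example, with inputs ordered z, x, y.
fn rsqrt_example() -> Vec<Node> {
    vec![
        Node::Input,                                          // [0] z
        Node::Unary { op: UnaryOp::Recip, input: 0 },         // [1] 1/z
        Node::Input,                                          // [2] x
        Node::Binary { op: BinaryOp::Mul, lhs: 2, rhs: 1 },   // [3] x * (1/z)
        Node::Input,                                          // [4] y
        Node::Binary { op: BinaryOp::Add, lhs: 4, rhs: 3 },   // [5] y + [3]
        Node::Unary { op: UnaryOp::Sqrt, input: 5 },          // [6] sqrt
        Node::Unary { op: UnaryOp::Recip, input: 6 },         // [7] 1/sqrt
    ]
}
```

Because every `Unary`/`Binary` node only refers to earlier indices, a single forward pass is enough; this mirrors the one-statement-per-node emission in the CUDA code generator.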
|
||||
|
||||
/// A fused elementwise kernel that evaluates a DAG of pointwise operations
|
||||
/// in a single CUDA kernel launch.
|
||||
///
|
||||
/// ## Relationship to the generated CUDA kernel
|
||||
///
|
||||
/// ```text
|
||||
/// struct fields generated kernel
|
||||
/// ───────────── ────────────────
|
||||
/// out_shape ──────────────> n_elements = product(out_shape)
|
||||
/// out_strides ─────────────> out[flatten(out_shape, out_strides)] = t_last
|
||||
/// nodes[i] = Input {strides} > float t_i = in_k[flatten(out_shape, strides)]
|
||||
/// nodes[j] = Unary {Exp2, i} > float t_j = exp2f(t_i)
|
||||
/// nodes[k] = Binary{Add,i,j} > float t_k = t_i + t_j
|
||||
/// dtype ───────────────────> all pointers typed as `float*` / `half*` / etc.
|
||||
/// ```
|
||||
///
|
||||
/// One thread per output element. Each Input node reads one element via
|
||||
/// `flatten_strides`, which converts the linear thread index (`const_z`) to a
|
||||
/// strided memory offset — handling broadcasting (zero strides), slicing, etc.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KernelFusedElementwise {
|
||||
/// The output iteration shape — all fused ops must share this shape.
|
||||
/// One CUDA thread is launched per element in this shape.
|
||||
pub out_shape: Vec<Expression>,
|
||||
/// Strides for writing the final result to global memory.
|
||||
pub out_strides: Vec<Expression>,
|
||||
/// The computation DAG in topological order. The last node is the output.
|
||||
pub nodes: Vec<FusedNode>,
|
||||
/// Output data type (determines pointer types and element size).
|
||||
pub dtype: DType,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// EgglogOp implementation (stub — this op is never created by egglog)
|
||||
// ============================================================================
|
||||
|
||||
impl Default for KernelFusedElementwise {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
out_shape: vec![],
|
||||
out_strides: vec![],
|
||||
nodes: vec![],
|
||||
dtype: DType::F32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelFusedElementwise {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelFused",
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("program", FUSED_INSTR),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
self.nodes
|
||||
.iter()
|
||||
.filter(|n| matches!(n, FusedNode::Input { .. }))
|
||||
.count()
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
let mut rules = Vec::new();
|
||||
rules.extend(seed_rules());
|
||||
rules.extend(extension_rules());
|
||||
rules
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let out_shape =
|
||||
extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap();
|
||||
let out_strides =
|
||||
extract_expr_list(egraph, kind_children[1], list_cache, expr_cache).unwrap();
|
||||
let nodes = extract_fused_instr(egraph, kind_children[2], list_cache, expr_cache);
|
||||
let dtype = extract_dtype(egraph, kind_children[3]);
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape,
|
||||
out_strides,
|
||||
nodes,
|
||||
dtype,
|
||||
}) as Box<dyn KernelOp>),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// KernelOp implementation — CUDA code generation
|
||||
// ============================================================================
|
||||
|
||||
impl KernelOp for KernelFusedElementwise {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let mut vars = FxHashSet::default();
|
||||
for e in &self.out_shape {
|
||||
vars.extend(e.dyn_vars());
|
||||
}
|
||||
for e in &self.out_strides {
|
||||
vars.extend(e.dyn_vars());
|
||||
}
|
||||
for node in &self.nodes {
|
||||
if let FusedNode::Input { strides } = node {
|
||||
for e in strides {
|
||||
vars.extend(e.dyn_vars());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
let n_elements = self
|
||||
.out_shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.to_kernel();
|
||||
let out_idx = flatten_strides(&self.out_shape, &self.out_strides).to_kernel();
|
||||
|
||||
// Walk the FusedNode DAG and emit one CUDA statement per node.
|
||||
//
|
||||
// Example: nodes = [Input(s0), Input(s1), Binary(Mul,0,1), Constant(1.0), Binary(Add,2,3)]
|
||||
// generates:
|
||||
// float t0 = in0[<flatten(shape, s0)>]; // Input -> load from buffer
|
||||
// float t1 = in1[<flatten(shape, s1)>]; // Input -> load from buffer
|
||||
// float t2 = t0 * t1; // Binary -> register arithmetic
|
||||
// float t3 = (float)1.0; // Const -> literal
|
||||
// float t4 = t2 + t3; // Binary -> register arithmetic
|
||||
// out[<flatten(shape, out_strides)>] = t4; // last node -> store
|
||||
let mut input_params = String::new();
|
||||
let mut body = String::new();
|
||||
let mut input_count = 0usize;
|
||||
|
||||
for (i, node) in self.nodes.iter().enumerate() {
|
||||
match node {
|
||||
FusedNode::Input { strides } => {
|
||||
let idx = input_count;
|
||||
input_count += 1;
|
||||
input_params.push_str(&format!(", const {dtype} *in{idx}"));
|
||||
let in_idx = flatten_strides(&self.out_shape, strides).to_kernel();
|
||||
body.push_str(&format!(" {dtype} t{i} = in{idx}[{in_idx}];\n"));
|
||||
}
|
||||
FusedNode::Constant { value } => {
|
||||
let value_str = if value.is_nan() {
|
||||
"__int_as_float(0x7fc00000)".to_string()
|
||||
} else if value.is_infinite() {
|
||||
if *value > 0.0 {
|
||||
"__int_as_float(0x7f800000)".to_string()
|
||||
} else {
|
||||
"__int_as_float(0xff800000)".to_string()
|
||||
}
|
||||
} else {
|
||||
format!("{:.10}f", value)
|
||||
};
|
||||
body.push_str(&format!(" {dtype} t{i} = ({dtype}){value_str};\n"));
|
||||
}
|
||||
FusedNode::Unary { op, input } => {
|
||||
let expr = match op {
|
||||
UnaryFusedOp::Exp2 => format!("exp2f(t{input})"),
|
||||
UnaryFusedOp::Log2 => format!("log2f(t{input})"),
|
||||
UnaryFusedOp::Sin => format!("sinf(t{input})"),
|
||||
UnaryFusedOp::Recip => format!("1.0f / t{input}"),
|
||||
UnaryFusedOp::Sqrt => format!("sqrtf(t{input})"),
|
||||
};
|
||||
body.push_str(&format!(" {dtype} t{i} = {expr};\n"));
|
||||
}
|
||||
FusedNode::Binary { op, lhs, rhs } => {
|
||||
let expr = match op {
|
||||
BinaryFusedOp::Add => format!("t{lhs} + t{rhs}"),
|
||||
BinaryFusedOp::Mul => format!("t{lhs} * t{rhs}"),
|
||||
BinaryFusedOp::Mod => format!("fmodf(t{lhs}, t{rhs})"),
|
||||
BinaryFusedOp::LessThan => format!("(t{lhs} < t{rhs}) ? ({dtype})1.0f : ({dtype})0.0f"),
|
||||
};
|
||||
body.push_str(&format!(" {dtype} t{i} = {expr};\n"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let last_idx = self.nodes.len() - 1;
|
||||
let kernel = format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void fused_ew_k({dtype} *out{input_params}{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {n_elements}) return;
|
||||
{body} out[{out_idx}] = t{last_idx};
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("fused_ew_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
let out_size = self.out_shape.iter().copied().product::<Expression>();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
let n_inputs = self
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|n| matches!(n, FusedNode::Input { .. }))
|
||||
.count();
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8) * n_inputs as i64
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
let ops_per_element: i64 = self
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|n| matches!(n, FusedNode::Unary { .. } | FusedNode::Binary { .. }))
|
||||
.count() as i64;
|
||||
self.output_size() * ops_per_element
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"FusedElementwise"
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: Fusibility is now handled entirely by egglog rewrite rules (seed_rules
|
||||
// and extension_rules above). The old extract_fusible_info() function that
|
||||
// downcasted KernelOp structs is no longer needed.
|
||||
|
||||
// ============================================================================
|
||||
// Egglog rule generation
|
||||
// ============================================================================
|
||||
|
||||
/// Seed rules: convert each fusible KernelOp into a trivial KernelFused node.
|
||||
/// This places fused variants into the e-graph alongside unfused ones.
|
||||
fn seed_rules() -> Vec<Rule> {
|
||||
let mut rules = Vec::new();
|
||||
|
||||
// Unary ops: KernelFoo(shape, in_strides, out_strides, dtype)
|
||||
for (kernel_name, fi_name) in [
|
||||
("KernelExp2", "FIExp2"),
|
||||
("KernelLog2", "FILog2"),
|
||||
("KernelSin", "FISin"),
|
||||
("KernelRecip", "FIRecip"),
|
||||
("KernelSqrt", "FISqrt"),
|
||||
] {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule
|
||||
((= ?out (Op ({kernel_name} ?shape ?in_s ?out_s ?dt) (ICons ?x (INil)))))
|
||||
((let ?prog ({fi_name} (FIInput 0 ?in_s)))
|
||||
(let ?f (Op (KernelFused ?shape ?out_s ?prog ?dt) (ICons ?x (INil))))
|
||||
(union ?out ?f)
|
||||
(set (dtype ?f) ?dt))
|
||||
:name \"seed-{kernel_name}\"
|
||||
)"
|
||||
)));
|
||||
}
|
||||
|
||||
// Binary ops: KernelFoo(shape, a_strides, b_strides, out_strides, dtype)
|
||||
for (kernel_name, fi_name) in [
|
||||
("KernelAdd", "FIAdd"),
|
||||
("KernelMul", "FIMul"),
|
||||
("KernelMod", "FIMod"),
|
||||
] {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule
|
||||
((= ?out (Op ({kernel_name} ?shape ?a_s ?b_s ?out_s ?dt) (ICons ?a (ICons ?b (INil))))))
|
||||
((let ?prog ({fi_name} (FIInput 0 ?a_s) (FIInput 1 ?b_s)))
|
||||
(let ?f (Op (KernelFused ?shape ?out_s ?prog ?dt) (ICons ?a (ICons ?b (INil)))))
|
||||
(union ?out ?f)
|
||||
(set (dtype ?f) ?dt))
|
||||
:name \"seed-{kernel_name}\"
|
||||
)"
|
||||
)));
|
||||
}
|
||||
|
||||
// Constant: KernelConstant(value) — no inputs, single element at out[0]
|
||||
rules.push(Rule::raw(
|
||||
"(rule
|
||||
((= ?out (Op (KernelConstant ?val) (INil))))
|
||||
((let ?prog (FIConstant ?val))
|
||||
(let ?f (Op (KernelFused (ECons (MNum 1) (ENil)) (ECons (MNum 0) (ENil)) ?prog (F32)) (INil)))
|
||||
(union ?out ?f)
|
||||
(set (dtype ?f) (F32)))
|
||||
:name \"seed-KernelConstant\"
|
||||
)"
|
||||
.to_string(),
|
||||
));
|
||||
|
||||
rules
|
||||
}
|
||||
|
||||
/// Extension rules: absorb a KernelFused producer into a KernelFused consumer.
|
||||
///
|
||||
/// For unary consumers, the consumer's `FIInput(0, ?)` is replaced by the
|
||||
/// producer's entire program tree. The IList passes through directly.
|
||||
///
|
||||
/// For binary consumers, one input is the producer (replaced by its program),
|
||||
/// the other stays as an `FIInput`. The ILists are merged and the external
|
||||
/// input's index is shifted by the producer's input count.
|
||||
fn extension_rules() -> Vec<Rule> {
|
||||
let mut rules = Vec::new();
|
||||
|
||||
// Unary extension: consumer is FI_Op(FIInput(0, ?s)), producer is any KernelFused.
|
||||
// The producer's IList becomes the fused IList directly.
|
||||
for fi_name in ["FIExp2", "FILog2", "FISin", "FIRecip", "FISqrt"] {
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule
|
||||
((= ?out (Op (KernelFused ?shape ?out_s ({fi_name} (FIInput 0 ?s)) ?dt) (ICons ?prod (INil))))
|
||||
(= ?prod (Op (KernelFused ?shape ?_os ?prog ?_dt) ?inputs)))
|
||||
((let ?f (Op (KernelFused ?shape ?out_s ({fi_name} ?prog) ?dt) ?inputs))
|
||||
(union ?out ?f)
|
||||
(set (dtype ?f) ?dt))
|
||||
:name \"extend-{fi_name}\"
|
||||
)"
|
||||
)));
|
||||
}
|
||||
|
||||
for fi_name in ["FIAdd", "FIMul", "FIMod"] {
|
||||
// LHS absorb: producer feeds into input 0, input 1 stays external.
|
||||
// Producer has N inputs → consumer's input 1 shifts to index N.
|
||||
for n_prod_inputs in 1..=4usize {
|
||||
let prod_ilist_pattern = build_ilist_pattern("?px", n_prod_inputs);
|
||||
let prod_ilist_build = build_ilist_cons("?px", n_prod_inputs, "ICons ?rhs (INil)");
|
||||
rules.push(Rule::raw(format!(
|
||||
"(rule
|
||||
((= ?out (Op (KernelFused ?sh ?os ({fi_name} (FIInput 0 ?s0) (FIInput 1 ?s1)) ?dt)
|
||||
(ICons ?prod (ICons ?rhs (INil)))))
|
||||
(= ?prod (Op (KernelFused ?sh ?_os ?prog ?_dt) {prod_ilist_pattern})))
|
||||
((let ?f (Op (KernelFused ?sh ?os ({fi_name} ?prog (FIInput {n_prod_inputs} ?s1)) ?dt)
|
||||
{prod_ilist_build}))
|
||||
(union ?out ?f)
|
||||
(set (dtype ?f) ?dt))
|
||||
:name \"extend-{fi_name}-lhs-{n_prod_inputs}\"
|
||||
)"
|
||||
)));
|
||||
}
|
||||
|
||||
// RHS absorb: producer feeds into input 1, input 0 stays external.
|
||||
// NOTE: RHS absorption is omitted. LHS absorption handles the common case
|
||||
// where the producer feeds into input 0 (the 'a' operand) of a binary op.
|
||||
// Full RHS support requires fixing the FIInput index → Input occurrence
|
||||
// mapping during extraction, which is tracked as future work.
|
||||
}
|
||||
|
||||
rules
|
||||
}
|
||||
|
||||
/// Build an IList pattern like `(ICons ?px0 (ICons ?px1 (INil)))` for matching.
|
||||
fn build_ilist_pattern(prefix: &str, n: usize) -> String {
|
||||
let mut s = "(INil)".to_string();
|
||||
for i in (0..n).rev() {
|
||||
s = format!("(ICons {prefix}{i} {s})");
|
||||
}
|
||||
s
|
||||
}
|
||||
|
||||
/// Build an IList construction like `(ICons ?px0 (ICons ?px1 <tail>))`.
|
||||
fn build_ilist_cons(prefix: &str, n: usize, tail: &str) -> String {
|
||||
let mut s = format!("({tail})");
|
||||
for i in (0..n).rev() {
|
||||
s = format!("(ICons {prefix}{i} {s})");
|
||||
}
|
||||
s
|
||||
}
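To make the two IList helpers concrete, the sketch below copies them as standalone functions and adds a hypothetical `lhs_absorb_lists` wrapper (not in the original file) that returns the pattern and construction strings an LHS-absorb rule would use for a producer with `n` inputs.

```rust
/// Build an IList pattern like `(ICons ?px0 (ICons ?px1 (INil)))` for matching.
fn build_ilist_pattern(prefix: &str, n: usize) -> String {
    let mut s = "(INil)".to_string();
    for i in (0..n).rev() {
        s = format!("(ICons {prefix}{i} {s})");
    }
    s
}

/// Build an IList construction like `(ICons ?px0 (ICons ?px1 <tail>))`.
/// `tail` is passed without its outer parentheses; they are added here.
fn build_ilist_cons(prefix: &str, n: usize, tail: &str) -> String {
    let mut s = format!("({tail})");
    for i in (0..n).rev() {
        s = format!("(ICons {prefix}{i} {s})");
    }
    s
}

/// Hypothetical helper for illustration: the (match pattern, merged list)
/// pair for a producer with `n` inputs followed by the external `?rhs`.
fn lhs_absorb_lists(n: usize) -> (String, String) {
    (
        build_ilist_pattern("?px", n),
        build_ilist_cons("?px", n, "ICons ?rhs (INil)"),
    )
}
```

For `n = 2` the merged list places `?rhs` at position 2, which is why the rewritten program refers to it as `FIInput 2`.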
|
||||
|
||||
// ============================================================================
|
||||
// Extraction: FusedInstr tree → Vec<FusedNode>
|
||||
// ============================================================================
|
||||
|
||||
/// Recursively walk a `FusedInstr` tree in the serialized e-graph and convert
|
||||
/// it to a flat `Vec<FusedNode>` via post-order traversal.
|
||||
fn extract_fused_instr<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
node_id: &'a ENodeId,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> Vec<FusedNode> {
|
||||
// Phase 1: Walk the tree, collecting nodes. FIInput nodes record their
|
||||
// index (IList position) as a placeholder. Non-Input nodes use vec indices.
|
||||
let mut raw_nodes = Vec::new();
|
||||
let mut input_records: Vec<(usize, Vec<Expression>)> = Vec::new(); // (fi_index, strides)
|
||||
extract_fused_instr_rec(egraph, node_id, &mut raw_nodes, &mut input_records, list_cache, expr_cache);
|
||||
|
||||
// Phase 2: Reorder so FusedNode::Input entries appear sorted by FIInput index.
|
||||
// This ensures the i-th Input maps to input_enodes[i] (IList position i).
|
||||
input_records.sort_by_key(|(fi_idx, _)| *fi_idx);
|
||||
|
||||
// Build the final node vec: sorted Inputs first, then non-Input nodes.
|
||||
let mut final_nodes = Vec::with_capacity(raw_nodes.len());
|
||||
let mut old_to_new: FxHashMap<usize, usize> = FxHashMap::default();
|
||||
|
||||
// Map each FIInput's original raw_nodes position to its new sorted position.
|
||||
// Find which raw_nodes indices were Inputs.
|
||||
let input_raw_indices: Vec<usize> = raw_nodes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, n)| matches!(n, FusedNode::Input { .. }))
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
|
||||
// Create sorted Input nodes.
|
||||
for (new_pos, (_fi_idx, strides)) in input_records.iter().enumerate() {
|
||||
final_nodes.push(FusedNode::Input { strides: strides.clone() });
|
||||
// Find the raw index that had this FIInput.
|
||||
// Since input_records were collected in tree-walk order and we now
|
||||
// sort by fi_idx, we need the mapping from sorted position to raw index.
|
||||
// But we don't have a direct link. Instead, match by fi_idx:
|
||||
// The input_raw_indices[k] corresponds to input_records[k] (before sort).
|
||||
// After sorting, input_records[new_pos] was originally at some position k.
|
||||
// We need to find k.
|
||||
}
|
||||
|
||||
// Actually, let's use a simpler approach: collect all Inputs with their raw
|
||||
// position, sort by fi_idx, then build the mapping.
|
||||
drop(final_nodes);
|
||||
drop(old_to_new);
|
||||
|
||||
let mut input_entries: Vec<(usize, usize, Vec<Expression>)> = Vec::new(); // (raw_idx, fi_idx, strides)
|
||||
let mut input_counter = 0;
|
||||
for (raw_idx, node) in raw_nodes.iter().enumerate() {
|
||||
if matches!(node, FusedNode::Input { .. }) {
|
||||
let (fi_idx, strides) = &input_records[input_counter];
|
||||
input_entries.push((raw_idx, *fi_idx, strides.clone()));
|
||||
input_counter += 1;
|
||||
}
|
||||
}
|
||||
// Sort by fi_idx.
|
||||
input_entries.sort_by_key(|(_, fi_idx, _)| *fi_idx);
|
||||
|
||||
// Build mapping: old raw_idx → new position.
|
||||
let mut old_to_new = FxHashMap::default();
|
||||
let mut final_nodes = Vec::with_capacity(raw_nodes.len());
|
||||
|
||||
// Sorted Inputs first.
|
||||
for (raw_idx, _fi_idx, strides) in &input_entries {
|
||||
old_to_new.insert(*raw_idx, final_nodes.len());
|
||||
final_nodes.push(FusedNode::Input { strides: strides.clone() });
|
||||
}
|
||||
// Non-Input nodes in original order.
|
||||
for (raw_idx, node) in raw_nodes.iter().enumerate() {
|
||||
if !matches!(node, FusedNode::Input { .. }) {
|
||||
old_to_new.insert(raw_idx, final_nodes.len());
|
||||
final_nodes.push(node.clone());
|
||||
}
|
||||
}
|
||||
// Remap references.
|
||||
for node in &mut final_nodes {
|
||||
match node {
|
||||
FusedNode::Unary { input, .. } => *input = old_to_new[input],
|
||||
FusedNode::Binary { lhs, rhs, .. } => {
|
||||
*lhs = old_to_new[lhs];
|
||||
*rhs = old_to_new[rhs];
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
final_nodes
|
||||
}
|
||||
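Reviewer sketch of the reorder-and-remap step that ends above: inputs move to the front in key order, every other node keeps its relative order, and operand indices are rewritten through an old-to-new map. `Node`, `key`, and `inputs_first` are hypothetical names used only for illustration; the real `FusedNode` carries strides and op kinds that are left out here.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for the diff's `FusedNode`; only the index shape matters.
#[derive(Debug, Clone, PartialEq)]
enum Node {
    Input { key: usize },
    Unary { input: usize },
    Binary { lhs: usize, rhs: usize },
}

/// Move all `Input` nodes to the front ordered by `key`, keep the other nodes
/// in their original relative order, and rewrite operand indices to match.
fn inputs_first(raw: &[Node]) -> Vec<Node> {
    // (raw position, key) for every input, ordered by key.
    let mut inputs: Vec<(usize, usize)> = raw
        .iter()
        .enumerate()
        .filter_map(|(i, n)| match n {
            Node::Input { key } => Some((i, *key)),
            _ => None,
        })
        .collect();
    inputs.sort_by_key(|&(_, key)| key);

    let mut old_to_new: HashMap<usize, usize> = HashMap::new();
    let mut out: Vec<Node> = Vec::with_capacity(raw.len());

    // Sorted inputs first.
    for &(raw_idx, _) in &inputs {
        old_to_new.insert(raw_idx, out.len());
        out.push(raw[raw_idx].clone());
    }
    // Then every non-input node, in original order.
    for (raw_idx, node) in raw.iter().enumerate() {
        if !matches!(node, Node::Input { .. }) {
            old_to_new.insert(raw_idx, out.len());
            out.push(node.clone());
        }
    }
    // Operands still hold raw indices; translate them through the map.
    for node in &mut out {
        match node {
            Node::Unary { input } => {
                let old = *input;
                *input = old_to_new[&old];
            }
            Node::Binary { lhs, rhs } => {
                let (old_l, old_r) = (*lhs, *rhs);
                *lhs = old_to_new[&old_l];
                *rhs = old_to_new[&old_r];
            }
            Node::Input { .. } => {}
        }
    }
    out
}
```

Building the complete map before touching any operand is what makes the remap safe: an operand may point at a node that ends up either before or after its user in the new order.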
|
||||
/// Resolve a child (ClassId) of an enode to a concrete NodeId.
|
||||
/// Picks the first node in the e-class.
|
||||
fn resolve_child<'a>(egraph: &'a SerializedEGraph, enode: &'a ENodeId, child_idx: usize) -> &'a ENodeId {
|
||||
let class_id = &egraph.enodes[enode].1[child_idx];
|
||||
&egraph.eclasses[class_id].1[0]
|
||||
}
|
||||
|
||||
/// Recursive helper. Returns the index of the node it added (the last one).
|
||||
fn extract_fused_instr_rec<'a>(
|
||||
egraph: &'a SerializedEGraph,
|
||||
node_id: &'a ENodeId,
|
||||
nodes: &mut Vec<FusedNode>,
|
||||
input_records: &mut Vec<(usize, Vec<Expression>)>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> usize {
|
||||
let op_name = &egraph.enodes[node_id].0;
|
||||
|
||||
match op_name.as_str() {
|
||||
"FIInput" => {
|
||||
let fi_index: usize = egraph.enodes[resolve_child(egraph, node_id, 0)]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse()
|
||||
.unwrap();
|
||||
let strides_node = resolve_child(egraph, node_id, 1);
|
||||
let strides =
|
||||
extract_expr_list(egraph, strides_node, list_cache, expr_cache).unwrap();
|
||||
let idx = nodes.len();
|
||||
input_records.push((fi_index, strides.clone()));
|
||||
nodes.push(FusedNode::Input { strides });
|
||||
idx
|
||||
}
|
||||
"FIConstant" => {
|
||||
let value: f32 = egraph.enodes[resolve_child(egraph, node_id, 0)]
|
||||
.0
|
||||
.replace("\"", "")
|
||||
.parse()
|
||||
.unwrap();
|
||||
let idx = nodes.len();
|
||||
nodes.push(FusedNode::Constant { value });
|
||||
idx
|
||||
}
|
||||
// Unary ops — child[0] is a FusedInstr class
|
||||
"FIExp2" | "FILog2" | "FISin" | "FIRecip" | "FISqrt" => {
|
||||
let child_node = resolve_child(egraph, node_id, 0);
|
||||
let input = extract_fused_instr_rec(egraph, child_node, nodes, input_records, list_cache, expr_cache);
|
||||
let op = match op_name.as_str() {
|
||||
"FIExp2" => UnaryFusedOp::Exp2,
|
||||
"FILog2" => UnaryFusedOp::Log2,
|
||||
"FISin" => UnaryFusedOp::Sin,
|
||||
"FIRecip" => UnaryFusedOp::Recip,
|
||||
"FISqrt" => UnaryFusedOp::Sqrt,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let idx = nodes.len();
|
||||
nodes.push(FusedNode::Unary { op, input });
|
||||
idx
|
||||
}
|
||||
// Binary ops — child[0] and child[1] are FusedInstr classes
|
||||
"FIAdd" | "FIMul" | "FIMod" => {
|
||||
let lhs_node = resolve_child(egraph, node_id, 0);
|
||||
let rhs_node = resolve_child(egraph, node_id, 1);
|
||||
let lhs = extract_fused_instr_rec(egraph, lhs_node, nodes, input_records, list_cache, expr_cache);
|
||||
let rhs = extract_fused_instr_rec(egraph, rhs_node, nodes, input_records, list_cache, expr_cache);
|
||||
let op = match op_name.as_str() {
|
||||
"FIAdd" => BinaryFusedOp::Add,
|
||||
"FIMul" => BinaryFusedOp::Mul,
|
||||
"FIMod" => BinaryFusedOp::Mod,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let idx = nodes.len();
|
||||
nodes.push(FusedNode::Binary { op, lhs, rhs });
|
||||
idx
|
||||
}
|
||||
other => panic!("Unknown FusedInstr variant: {other}"),
|
||||
}
|
||||
}
|
||||
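Reviewer sketch of the recursion pattern used by `extract_fused_instr_rec`: children are emitted first, the parent is pushed last, and the function returns the index of the node it just pushed. The `Expr` tree and `Flat` list below are hypothetical simplifications; the real code walks a serialized e-graph and resolves each child through `resolve_child`.

```rust
/// Hypothetical expression tree standing in for the e-graph walk.
enum Expr {
    Input(usize),
    Const(f32),
    Sqrt(Box<Expr>),
    Add(Box<Expr>, Box<Expr>),
}

/// Flat instruction list; operands are indices into the same list.
#[derive(Debug, PartialEq)]
enum Flat {
    Input(usize),
    Const(f32),
    Sqrt { input: usize },
    Add { lhs: usize, rhs: usize },
}

/// Post-order flattening: returns the index of the node added for `expr`.
fn flatten(expr: &Expr, out: &mut Vec<Flat>) -> usize {
    let node = match expr {
        Expr::Input(i) => Flat::Input(*i),
        Expr::Const(v) => Flat::Const(*v),
        Expr::Sqrt(inner) => {
            let input = flatten(inner, out);
            Flat::Sqrt { input }
        }
        Expr::Add(l, r) => {
            // Left subtree is emitted before the right one.
            let lhs = flatten(l, out);
            let rhs = flatten(r, out);
            Flat::Add { lhs, rhs }
        }
    };
    out.push(node);
    out.len() - 1
}
```

Because every operand is pushed before its user, the resulting list is already in a valid evaluation order. This sketch, like the code in the diff, emits a fresh node for each visit and does not deduplicate shared subexpressions.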
@@ -508,11 +508,11 @@ extern \"C\" {{
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelAdd {
|
||||
pub out_shape: Vec<Expression>,
|
||||
pub a_stride: Vec<Expression>,
|
||||
pub b_stride: Vec<Expression>,
|
||||
pub out_stride: Vec<Expression>,
|
||||
pub dtype: DType,
|
||||
out_shape: Vec<Expression>,
|
||||
a_stride: Vec<Expression>,
|
||||
b_stride: Vec<Expression>,
|
||||
out_stride: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelAdd {
|
||||
@@ -634,8 +634,8 @@ extern \"C\" {{
     func,
     module,
     kernel,
-    (out_size.ceil_div(256), 1.into(), 1.into()),
-    (out_size.min(256), 1.into(), 1.into()),
+    (out_size.ceil_div(128), 1.into(), 1.into()),
+    (out_size.min(128), 1.into(), 1.into()),
     0.into(),
     FxHashMap::default(), // No per-module constants needed
 )
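Reviewer sketch of the 1-D launch arithmetic in this hunk: the grid is `ceil(out_size / block)` and the per-block thread count is `min(out_size, block)`, with the block constant switching between 256 and 128. `launch_dims` is a hypothetical helper over plain integers; the real code builds symbolic `Expression` values instead.

```rust
/// Returns `(grid_x, block_x)` for a 1-D launch over `n` elements.
/// Hypothetical helper; mirrors `ceil_div(block)` and `min(block)` from the diff.
fn launch_dims(n: u32, block: u32) -> (u32, u32) {
    assert!(block > 0, "block size must be non-zero");
    // Ceiling division without risking overflow in `n + block - 1`.
    let grid = n / block + u32::from(n % block != 0);
    // Never launch more threads per block than there are elements.
    let threads = n.min(block);
    (grid, threads)
}
```

When `n` is not a multiple of `block`, the last block covers indices past `n`, so a kernel launched this way is assumed to guard its body with an `idx < n` check. `n == 0` yields a zero-sized grid, which callers would need to skip.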
@@ -673,11 +673,11 @@ extern \"C\" {{
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelMul {
|
||||
pub out_shape: Vec<Expression>,
|
||||
pub a_stride: Vec<Expression>,
|
||||
pub b_stride: Vec<Expression>,
|
||||
pub out_stride: Vec<Expression>,
|
||||
pub dtype: DType,
|
||||
out_shape: Vec<Expression>,
|
||||
a_stride: Vec<Expression>,
|
||||
b_stride: Vec<Expression>,
|
||||
out_stride: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelMul {
|
||||
@@ -797,8 +797,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -990,13 +990,12 @@ extern \"C\" {{
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let out_size = self.out_shape.iter().copied().product::<Expression>();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
(self.out_shape.iter().copied().product(), 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -1500,10 +1499,10 @@ extern \"C\" {{
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelExp2 {
|
||||
pub shape: Vec<Expression>,
|
||||
pub in_strides: Vec<Expression>,
|
||||
pub out_strides: Vec<Expression>,
|
||||
pub dtype: DType,
|
||||
shape: Vec<Expression>,
|
||||
in_strides: Vec<Expression>,
|
||||
out_strides: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelExp2 {
|
||||
@@ -1616,8 +1615,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -1654,10 +1653,10 @@ extern \"C\" {{
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelLog2 {
|
||||
pub shape: Vec<Expression>,
|
||||
pub in_strides: Vec<Expression>,
|
||||
pub out_strides: Vec<Expression>,
|
||||
pub dtype: DType,
|
||||
shape: Vec<Expression>,
|
||||
in_strides: Vec<Expression>,
|
||||
out_strides: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelLog2 {
|
||||
@@ -1770,8 +1769,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -1808,10 +1807,10 @@ extern \"C\" {{
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelSin {
|
||||
pub shape: Vec<Expression>,
|
||||
pub in_strides: Vec<Expression>,
|
||||
pub out_strides: Vec<Expression>,
|
||||
pub dtype: DType,
|
||||
shape: Vec<Expression>,
|
||||
in_strides: Vec<Expression>,
|
||||
out_strides: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelSin {
|
||||
@@ -1924,8 +1923,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -1962,10 +1961,10 @@ extern \"C\" {{
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelRecip {
|
||||
pub shape: Vec<Expression>,
|
||||
pub in_strides: Vec<Expression>,
|
||||
pub out_strides: Vec<Expression>,
|
||||
pub dtype: DType,
|
||||
shape: Vec<Expression>,
|
||||
in_strides: Vec<Expression>,
|
||||
out_strides: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelRecip {
|
||||
@@ -2078,8 +2077,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -2116,10 +2115,10 @@ extern \"C\" {{
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelSqrt {
|
||||
pub shape: Vec<Expression>,
|
||||
pub in_strides: Vec<Expression>,
|
||||
pub out_strides: Vec<Expression>,
|
||||
pub dtype: DType,
|
||||
shape: Vec<Expression>,
|
||||
in_strides: Vec<Expression>,
|
||||
out_strides: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelSqrt {
|
||||
@@ -2232,8 +2231,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -2270,11 +2269,11 @@ extern \"C\" {{
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelMod {
|
||||
pub out_shape: Vec<Expression>,
|
||||
pub a_stride: Vec<Expression>,
|
||||
pub b_stride: Vec<Expression>,
|
||||
pub out_stride: Vec<Expression>,
|
||||
pub dtype: DType,
|
||||
out_shape: Vec<Expression>,
|
||||
a_stride: Vec<Expression>,
|
||||
b_stride: Vec<Expression>,
|
||||
out_stride: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelMod {
|
||||
@@ -2393,8 +2392,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -2432,11 +2431,11 @@ extern \"C\" {{
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelLessThan {
|
||||
pub out_shape: Vec<Expression>,
|
||||
pub a_stride: Vec<Expression>,
|
||||
pub b_stride: Vec<Expression>,
|
||||
pub out_stride: Vec<Expression>,
|
||||
pub dtype: DType,
|
||||
out_shape: Vec<Expression>,
|
||||
a_stride: Vec<Expression>,
|
||||
b_stride: Vec<Expression>,
|
||||
out_stride: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelLessThan {
|
||||
@@ -2568,8 +2567,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -2608,7 +2607,7 @@ extern \"C\" {{
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelConstant {
|
||||
pub value: f32,
|
||||
value: f32,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelConstant {
|
||||
|
||||
@@ -10,13 +10,12 @@ use luminal_tracing::schema::{
 use uuid::Uuid;
 
 pub mod cuda_graph;
-pub mod fused_elementwise;
 pub mod hlir;
 pub mod other_ops;
 
 pub use cuda_graph::*;
 
-pub type Ops = (hlir::Ops, other_ops::Ops, fused_elementwise::KernelFusedElementwise);
+pub type Ops = (hlir::Ops, other_ops::Ops);
 
 /// Build a mapping from interned string IDs to their string values for a given sequence.
 fn build_interned_strings(trace: &schema::Trace) -> std::collections::HashMap<(u32, u64), String> {
|
||||
@@ -1544,8 +1544,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -1730,8 +1730,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
|
||||
@@ -302,10 +302,8 @@ impl CudaGraphOp {
             kernel.internal_bufs = kernel.kernel_op.allocate_internal_buffers(stream, dyn_map);
         }
     }
-    // Only force full rebuild when internal buffer sizes change.
-    // Dim-only changes (e.g. position offset `p` incrementing each decode step) are
-    // handled by updating the dyn_dims device buffer + kernel node params in-place.
-    if needs_internal_realloc {
+    // Force full rebuild when dims change (debug: testing if update_kernel_node is the issue)
+    if dyn_map_changed || needs_internal_realloc {
         state.cuda_graph = None;
         state.cuda_graph_exec = None;
         state.node_to_graph_node.clear();
@@ -700,7 +698,7 @@ pub fn kernel_to_host(
|
||||
kernel_op_ref.compile(cuda_stream, kernel_cache);
|
||||
|
||||
// Collect inputs from graph edges
|
||||
let inputs: Vec<NodeIndex> = llir_graph
|
||||
let mut inputs: Vec<NodeIndex> = llir_graph
|
||||
.edges_directed(*kernel_node_idx, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
|
||||
@@ -120,17 +120,13 @@ pub struct CudaRuntime {
|
||||
/// Bucket definitions per dimension (empty = single-bucket mode)
|
||||
dim_buckets: FxHashMap<char, Vec<DimBucket>>,
|
||||
|
||||
/// HLIR nodes that should never be consumed after execute().
|
||||
/// Used for weight tensors shared via external device pointers.
|
||||
persistent_hlir_nodes: FxHashSet<NodeIndex>,
|
||||
|
||||
/// Non-owning CudaSlice wrappers for external device pointers.
|
||||
/// ManuallyDrop prevents cuMemFree — the external allocator (e.g. PyTorch) owns the memory.
|
||||
external_buffers: FxHashMap<NodeIndex, std::mem::ManuallyDrop<CudaSlice<u8>>>,
|
||||
|
||||
/// Pending output pointer registrations: HLIR output id -> (device_ptr, n_bytes)
|
||||
/// Set by python before execute(), consumed at start of execute()
|
||||
output_ptr_registrations: FxHashMap<NodeIndex, (u64, usize)>,
|
||||
|
||||
/// Non-owning CudaSlice views of external output pointers, keyed by LLIR data node
|
||||
/// ManuallyDrop prevents cuMemFree -- Pytorch owns the memory
|
||||
external_output_buffers: FxHashMap<NodeIndex, std::mem::ManuallyDrop<CudaSlice<u8>>>,
|
||||
}
|
||||
|
||||
impl CudaRuntime {
|
||||
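Reviewer sketch of why the struct fields above wrap their `CudaSlice<u8>` in `std::mem::ManuallyDrop`: a value inside `ManuallyDrop` never runs its destructor when it goes out of scope, so memory owned by another allocator is not freed twice. `OwningBuffer` is a hypothetical stand-in whose `Drop` bumps a counter instead of calling into CUDA.

```rust
use std::cell::Cell;
use std::mem::ManuallyDrop;
use std::rc::Rc;

/// Hypothetical stand-in for an owning device buffer: dropping it "frees"
/// the allocation, which is recorded here by bumping a shared counter.
struct OwningBuffer {
    frees: Rc<Cell<u32>>,
}

impl Drop for OwningBuffer {
    fn drop(&mut self) {
        self.frees.set(self.frees.get() + 1);
    }
}

/// Create a buffer, let it go out of scope, and report how many frees ran.
fn frees_after_scope(borrowed: bool) -> u32 {
    let frees = Rc::new(Cell::new(0));
    {
        let buf = OwningBuffer { frees: Rc::clone(&frees) };
        if borrowed {
            // Someone else owns the memory: suppress the destructor.
            let _view = ManuallyDrop::new(buf);
        } else {
            // Normal ownership: the destructor runs at the end of this block.
            let _owned = buf;
        }
    }
    frees.get()
}
```

The trade-off is that nothing inside the wrapper is ever cleaned up automatically, so the external owner has to keep the allocation alive for as long as the view is in use.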
@@ -232,25 +228,9 @@ impl CudaRuntime {
|
||||
self.changed_hlir.insert(id);
|
||||
}
|
||||
|
||||
/// Register an external device pointer for an output tensor (zero-copy output).
|
||||
/// The pointer is stored lazily — resolution to LLIR nodes happens in execute().
|
||||
///
|
||||
/// # Safety
|
||||
/// The device pointer must point to a valid CUDA allocation with at least `n_bytes` bytes,
|
||||
/// and must remain valid through the next execute() call.
|
||||
pub unsafe fn set_output_device_ptr(&mut self, id: impl ToId, device_ptr: u64, n_bytes: usize) {
|
||||
debug_assert!(
|
||||
device_ptr != 0,
|
||||
"set_output_device_ptr called with null pointer"
|
||||
);
|
||||
self.output_ptr_registrations
|
||||
.insert(id.to_id(), (device_ptr, n_bytes));
|
||||
}
|
||||
|
||||
pub fn output_is_zero_copy(&self, id: impl ToId) -> bool {
|
||||
let producer = self.find_producer_node(id);
|
||||
let data_node = self.follow_aliases(producer);
|
||||
self.external_output_buffers.contains_key(&data_node)
|
||||
/// Mark an HLIR node as persistent — its buffer won't be consumed after execute().
|
||||
pub fn persist_hlir_node(&mut self, id: impl ToId) {
|
||||
self.persistent_hlir_nodes.insert(id.to_id());
|
||||
}
|
||||
|
||||
/// Find the LLIR producing node for an output tensor.
|
||||
@@ -410,50 +390,6 @@ impl CudaRuntime {
|
||||
self.cuda_stream.synchronize().unwrap();
|
||||
}
|
||||
|
||||
/// Resolve pending output pointer registrations into external_output_buffers.
|
||||
/// Called at the start of execute(), after buffer allocation and HLIR sync.
|
||||
fn apply_output_ptr_registrations(&mut self) {
|
||||
// clear stale external output buffers from previous execution
|
||||
self.external_output_buffers.clear();
|
||||
|
||||
if self.output_ptr_registrations.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Collect registrations to avoid borrow conflict (drain borrows self mutably,
|
||||
// but find_producer_node/follow_aliases need &self).
|
||||
|
||||
let registrations: Vec<_> = self.output_ptr_registrations.drain().collect();
|
||||
|
||||
for (hlir_id, (device_ptr, n_bytes)) in registrations {
|
||||
// Resolve HLIR output id -> LLIR producer -> follow aliases -> data node
|
||||
let producer = self.find_producer_node(hlir_id);
|
||||
let data_node = self.follow_aliases(producer);
|
||||
|
||||
// If data_node is an HLIR input (aliased output), skip — can't substitute
|
||||
if self.compiled_buckets[self.active_bucket]
|
||||
.llir_to_hlir
|
||||
.contains_key(&data_node)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Create non-owning CudaSlice view of PyTorch's buffer
|
||||
let slice = unsafe {
|
||||
self.cuda_stream
|
||||
.upgrade_device_ptr::<u8>(device_ptr, n_bytes)
|
||||
};
|
||||
|
||||
self.external_output_buffers
|
||||
.insert(data_node, std::mem::ManuallyDrop::new(slice));
|
||||
|
||||
// Update cached_buffer_ptrs so CudaGraphOp picks up the new pointer
|
||||
self.compiled_buckets[self.active_bucket]
|
||||
.cached_buffer_ptrs
|
||||
.insert(data_node, device_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_f32(&self, id: impl ToId) -> Vec<f32> {
|
||||
let bytes = self.get_output_data(id);
|
||||
let bytes = bytes.leak();
|
||||
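Reviewer sketch of the borrow pattern noted in `apply_output_ptr_registrations`: iterating `self.pending.drain()` directly would hold a mutable borrow of one field while the loop body calls `&self` methods, so the entries are collected into a `Vec` first. `Registry`, `follow_alias`, and the field names are hypothetical; only the drain-then-collect shape is taken from the diff.

```rust
use std::collections::HashMap;

/// Hypothetical registry: pending (id -> pointer) requests are resolved
/// against an alias table owned by the same struct.
struct Registry {
    pending: HashMap<u32, u64>,
    aliases: HashMap<u32, u32>,
    resolved: HashMap<u32, u64>,
}

impl Registry {
    /// Takes `&self`, so it cannot be called while `pending` is being drained.
    fn follow_alias(&self, id: u32) -> u32 {
        *self.aliases.get(&id).unwrap_or(&id)
    }

    fn apply_pending(&mut self) {
        // Results from the previous run are stale; drop them first.
        self.resolved.clear();
        // Collecting ends the mutable borrow of `self.pending` before the loop.
        let pending: Vec<(u32, u64)> = self.pending.drain().collect();
        for (id, ptr) in pending {
            let target = self.follow_alias(id);
            self.resolved.insert(target, ptr);
        }
    }
}
```

The temporary `Vec` costs one small allocation per call, which is usually negligible next to the work done per registration.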
@@ -854,8 +790,7 @@ impl Runtime for CudaRuntime {
|
||||
compiled_buckets: vec![CompiledBucket::new()],
|
||||
active_bucket: 0,
|
||||
dim_buckets: FxHashMap::default(),
|
||||
output_ptr_registrations: FxHashMap::default(),
|
||||
external_output_buffers: FxHashMap::default(),
|
||||
persistent_hlir_nodes: FxHashSet::default(),
|
||||
external_buffers: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
@@ -1078,9 +1013,6 @@ impl Runtime for CudaRuntime {
|
||||
// Ensure all CUDA graphs are built (handles first execute and any missing graphs)
|
||||
self.prebuild_graphs(dyn_map);
|
||||
|
||||
// Resolve external output pointer registrations (zero-copy output path)
|
||||
self.apply_output_ptr_registrations();
|
||||
|
||||
let total_start = std::time::Instant::now();
|
||||
let bucket = &self.compiled_buckets[self.active_bucket];
|
||||
|
||||
@@ -1090,11 +1022,8 @@ impl Runtime for CudaRuntime {
|
||||
|
||||
// Build buffer map for the HostOp interface
|
||||
let mut buffer_map: FxHashMap<NodeIndex, &CudaSlice<u8>> = FxHashMap::default();
|
||||
|
||||
// Add output buffer -- prefer external output pointer if registered (zero copy)
|
||||
if let Some(ext) = self.external_output_buffers.get(&exec_op.output) {
|
||||
buffer_map.insert(exec_op.output, &**ext);
|
||||
} else if let Some(buf) = bucket.buffers.get(&exec_op.output) {
|
||||
// Add output buffer
|
||||
if let Some(buf) = bucket.buffers.get(&exec_op.output) {
|
||||
buffer_map.insert(exec_op.output, buf);
|
||||
}
|
||||
// Add input buffers (prefer HLIR weight buffers over intermediate placeholders)
|
||||
@@ -1124,9 +1053,7 @@ impl Runtime for CudaRuntime {
|
||||
let extra_nodes = exec_op.internal.extra_buffer_nodes();
|
||||
for extra_node in extra_nodes {
|
||||
if let Entry::Vacant(e) = buffer_map.entry(extra_node) {
|
||||
if let Some(ext) = self.external_output_buffers.get(&extra_node) {
|
||||
e.insert(&**ext);
|
||||
} else if let Some(buf) = bucket.buffers.get(&extra_node) {
|
||||
if let Some(buf) = bucket.buffers.get(&extra_node) {
|
||||
e.insert(buf);
|
||||
} else if let Some(hlir_node) = bucket.llir_to_hlir.get(&extra_node) {
|
||||
match self.hlir_buffers.get(hlir_node) {
|
||||
@@ -1211,6 +1138,11 @@ impl Runtime for CudaRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
// Final sync to ensure all operations completed successfully
|
||||
self.cuda_stream
|
||||
.synchronize()
|
||||
.expect("Final sync failed in execute");
|
||||
|
||||
// Consume input buffers
|
||||
if self.profiling {
|
||||
return;
|
||||
@@ -1258,6 +1190,7 @@ impl Runtime for CudaRuntime {
|
||||
.hlir_buffers
|
||||
.keys()
|
||||
.filter(|hlir_node| !inputs_with_outputs.contains(hlir_node))
|
||||
.filter(|hlir_node| !self.persistent_hlir_nodes.contains(hlir_node))
|
||||
.copied()
|
||||
.collect();
|
||||
|
||||
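Reviewer sketch of the buffer-consumption filter in this hunk: a buffer id is kept for cleanup only if it is neither aliased to an output nor marked persistent. `consumable` and its plain `u32` ids are hypothetical; the real code filters `NodeIndex` keys of `hlir_buffers`.

```rust
use std::collections::{HashMap, HashSet};

/// Ids of buffers that may be consumed after a run: everything that is
/// neither aliased to an output nor marked persistent. The result is sorted
/// only to make it deterministic for this example.
fn consumable(
    buffers: &HashMap<u32, Vec<u8>>,
    aliased: &HashSet<u32>,
    persistent: &HashSet<u32>,
) -> Vec<u32> {
    let mut ids: Vec<u32> = buffers
        .keys()
        .filter(|id| !aliased.contains(*id))
        .filter(|id| !persistent.contains(*id))
        .copied()
        .collect();
    ids.sort_unstable();
    ids
}
```

Keeping the two exclusions as separate `filter` calls makes each rule easy to add or remove on its own, which matches how the persistent-node check appears as a single line in the diff.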
@@ -1312,8 +1245,7 @@ impl CudaRuntime {
|
||||
// Clone llir_graph so we can modify it
|
||||
let mut llir_graph = llir_graph.clone();
|
||||
|
||||
// Compile kernel subgraphs into CudaGraphOps (which implement HostOp).
|
||||
// Elementwise fusion happens inside kernel_to_host via identify_chains/build_chain_fused.
|
||||
// Compile kernel subgraphs into CudaGraphOps (which implement HostOp)
|
||||
crate::kernel::kernel_to_host(&mut llir_graph, &self.cuda_stream, &mut self.kernel_cache);
|
||||
|
||||
// Build output alias map
|
||||
|
||||
@@ -41,7 +41,7 @@ fn test_bucket_dispatch_simple() {
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
rt = cx.search_rng(rt, 5, &mut rng);
|
||||
|
||||
// Test bucket 1: s=1
|
||||
cx.set_dim('s', 1);
|
||||
@@ -85,7 +85,7 @@ fn test_bucket_matmul_dynamic() {
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
rt = cx.search_rng(rt, 5, &mut rng);
|
||||
|
||||
// Execute at s=1
|
||||
cx.set_dim('s', 1);
|
||||
@@ -140,7 +140,7 @@ fn test_bucket_results_match_unbucketed() {
|
||||
let input_data = random_f32_vec(12, seed, -1.0, 1.0);
|
||||
rt1.set_data(a1, input_data.clone());
|
||||
let mut rng1 = SmallRng::seed_from_u64(seed);
|
||||
rt1 = cx1.search_options(rt1, SearchOptions::new(5), &mut rng1);
|
||||
rt1 = cx1.search_rng(rt1, 5, &mut rng1);
|
||||
rt1.set_data(a1, input_data.clone());
|
||||
rt1.execute(&cx1.dyn_map);
|
||||
let result_unbucketed = rt1.get_f32(b1);
|
||||
@@ -153,7 +153,7 @@ fn test_bucket_results_match_unbucketed() {
|
||||
let mut rt2 = CudaRuntime::initialize(stream.clone());
|
||||
rt2.set_data(a2, input_data.clone());
|
||||
let mut rng2 = SmallRng::seed_from_u64(seed);
|
||||
rt2 = cx2.search_options(rt2, SearchOptions::new(5), &mut rng2);
|
||||
rt2 = cx2.search_rng(rt2, 5, &mut rng2);
|
||||
rt2.set_data(a2, input_data.clone());
|
||||
rt2.execute(&cx2.dyn_map);
|
||||
let result_bucketed = rt2.get_f32(b2);
|
||||
@@ -179,7 +179,7 @@ fn test_bucket_out_of_range_panics() {
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_options(rt, SearchOptions::new(3), &mut rng);
|
||||
rt = cx.search_rng(rt, 3, &mut rng);
|
||||
|
||||
// s=10 is outside all buckets — should panic
|
||||
cx.set_dim('s', 10);
|
||||
@@ -204,7 +204,7 @@ fn test_bucket_no_buckets_backward_compat() {
|
||||
let input_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
||||
rt.set_data(a, input_data.clone());
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_options(rt, SearchOptions::new(3), &mut rng);
|
||||
rt = cx.search_rng(rt, 3, &mut rng);
|
||||
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -249,7 +249,7 @@ fn test_bucket_switch_preserves_weights() {
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
rt = cx.search_rng(rt, 5, &mut rng);
|
||||
|
||||
// Execute with bucket 1 (s=1)
|
||||
cx.set_dim('s', 1);
|
||||
@@ -305,7 +305,7 @@ fn test_bucket_multiple_executions_same_bucket() {
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_options(rt, SearchOptions::new(3), &mut rng);
|
||||
rt = cx.search_rng(rt, 3, &mut rng);
|
||||
|
||||
// Execute at different sizes within the same bucket
|
||||
for s in [1, 2, 4, 8] {
|
||||
|
||||
@@ -348,7 +348,7 @@ fn test_scatter_dual_cache_with_graph_break() {
|
||||
// Use seeded search for deterministic scatter variant selection.
|
||||
// Seed 0 reliably selects Scatter (not ScatterNoCopy) for both caches.
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
rt = cx.search_rng(rt, 5, &mut rng);
|
||||
|
||||
// Print selected variants
|
||||
for node in rt.llir_graph().node_weights() {
|
||||
|
||||
1
crates/luminal_python/.gitignore
vendored
@@ -1,5 +1,4 @@
|
||||
*.onnx
|
||||
tests/llama38b_ref_logits.pt
|
||||
__pycache__/
|
||||
*.pyc
|
||||
uv.lock
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
A couple of short things to keep in mind
|
||||
## Python Environment
|
||||
|
||||
- Always use `uv run` to execute Python tools (pytest, pre-commit, python) — never bare `pytest` or `python`
|
||||
- Use `uv add` / `uv add --dev` / `uv remove` for dependencies — never hand-edit pyproject.toml deps
|
||||
- After modifying Rust source files, rebuild before running Python tests: `maturin develop --release`
|
||||
|
||||
## Lessons Learned
|
||||
|
||||
@@ -24,7 +28,7 @@ consult before writing new egglog rules, CUDA kernels, or optimizer passes.
 ## Testing Best Practices
 
 ### Overview
-The luminal_python crate provides a bridge between PyTorch models and the luminal library via the PT2 Export pipeline. Tests should verify this integration end-to-end by testing the actual user workflow: PyTorch model → torch.compile → luminal backend.
+The luminal_python crate provides a bridge between PyTorch models and the luminal library via ONNX. Tests should verify this integration end-to-end by testing the actual user workflow: PyTorch model → torch.compile → luminal backend.
 
 ### Test Pattern (CORRECT)
 
@@ -67,11 +71,11 @@ class AddTestModel(torch.nn.Module):
 
 ### What NOT to Do
 
-**❌ DO NOT create pt2 files directly in tests:**
+**❌ DO NOT create ONNX files directly in tests:**
 ```python
 # WRONG - bypasses the PyTorch integration
-model_path = create_pt2_model(...)
-graph_result = luminal.process_pt(model_path, backend='native')
+model_path = create_onnx_model(...)
+graph_result = luminal.process_onnx(model_path, backend='native')
 ```
 
 **✓ DO create PyTorch models and use torch.compile:**
@@ -83,16 +87,16 @@ model_compiled = torch.compile(model, backend=luminal_backend)
 
 ### Rationale
 
-- **End-to-end testing**: Tests verify the complete PyTorch → Pt2 → luminal pipeline
+- **End-to-end testing**: Tests verify the complete PyTorch → ONNX → luminal pipeline
 - **User-facing API**: Tests use the same API that users will use (torch.compile)
 - **Correctness**: Comparing compiled vs original PyTorch output ensures correctness
 - **Maintainability**: Consistent pattern across all tests makes the codebase easier to understand
-- **Simplicity**: No manual Pt2 file creation, no tempfile cleanup, no numpy comparisons
+- **Simplicity**: No manual ONNX file creation, no tempfile cleanup, no numpy comparisons
 
 ### Special Cases
 
 **Testing constants:**
-Use inline tensor literals in the forward method - these are exported as constant tensors:
+Use inline tensor literals in the forward method - PyTorch exports these as ONNX Constant nodes:
 ```python
 def forward(self, x: torch.Tensor) -> torch.Tensor:
     constant = torch.tensor([1.0, 2.0, 3.0])
@@ -100,14 +104,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 ```
 
 **Testing type casts:**
-Use `.to(dtype)` method - these are exported as type cast operations:
+Use `.to(dtype)` method - PyTorch exports these as ONNX Cast nodes:
 ```python
 def forward(self, x: torch.Tensor) -> torch.Tensor:
     return x.to(torch.float32)
 ```
 
 **Testing complex operations:**
-Chain operations naturally in PyTorch - the export pipeline handles the conversion:
+Chain operations naturally in PyTorch - ONNX export handles the conversion:
 ```python
 def forward(self, x: torch.Tensor) -> torch.Tensor:
     transposed = x.transpose(0, 1)
|
||||
@@ -3,13 +3,19 @@ name = "luminal_python"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"numpy>=2.0.2",
|
||||
"torch>=2.10.0",
|
||||
"onnx",
|
||||
"onnxscript",
|
||||
"safetensors",
|
||||
"flash-attn-3>=3.0.0",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
no-build-isolation-package = ["flash-attn"]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
@@ -19,6 +25,7 @@ explicit = true
|
||||
torch = [
|
||||
{ index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
]
|
||||
flash-attn-3 = { index = "pytorch-cu128" }
|
||||
|
||||
|
||||
[build-system]
|
||||
@@ -38,12 +45,21 @@ markers = [
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"maturin>=1.0,<2.0",
|
||||
"maturin-import-hook>=0.3.0",
|
||||
"pytest>=9.0.2",
|
||||
"pytest-profiling",
|
||||
"snakeviz",
|
||||
"maturin-import-hook>=0.3.0",
|
||||
"pytest-randomly>=4.0.1",
|
||||
"transformers>=4.40.0",
|
||||
"transformers>=5.5.0,<6",
|
||||
"diffusers>=0.35.0",
|
||||
"onnxsim",
|
||||
"tiktoken>=0.12.0",
|
||||
"pydantic>=2.12.5",
|
||||
"psutil>=7.2.2",
|
||||
"modal>=1.3.5",
|
||||
"pillow",
|
||||
"flash-attn>=2.8.3",
|
||||
]
|
||||
flash-attention-4 = [
|
||||
"nvidia-cutlass-dsl==4.1.0",
|
||||
]
|
||||
|
||||
@@ -16,9 +16,13 @@ rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
|
||||
echo ""
|
||||
echo "--- 1a: Native backend tests ---"
|
||||
echo "--- 1a: Native + ONNX ---"
|
||||
uv run pytest $NATIVE_TESTS -v
|
||||
|
||||
echo ""
|
||||
echo "--- 1b: Native + PT2 ---"
|
||||
LUMINAL_EXPORT_MODE=pt2 uv run pytest $NATIVE_TESTS -v
|
||||
|
||||
# ── Phase 2: CUDA Backend ───────────────────────────────────
|
||||
|
||||
echo ""
|
||||
@@ -27,9 +31,13 @@ rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
echo ""
|
||||
echo "--- 2a: CUDA ---"
|
||||
echo "--- 2a: CUDA + ONNX ---"
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest $CUDA_TESTS -m "not slow" -v
|
||||
|
||||
echo ""
|
||||
echo "--- 2b: CUDA + PT2 ---"
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda LUMINAL_EXPORT_MODE=pt2 uv run pytest $CUDA_TESTS -m "not slow" -v
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo " All tests passed!"
|
||||
|
||||
20
crates/luminal_python/run_test_fx.sh
Executable file
@@ -0,0 +1,20 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
echo "=== Luminal Python Test Runner (PT2 Export Mode) ==="
|
||||
echo ""
|
||||
|
||||
# Force clean rebuild of Rust extension
|
||||
echo "Step 1: Cleaning previous builds..."
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
|
||||
# Rebuild in development mode (faster compilation)
|
||||
echo "Step 2: Building Rust extension..."
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
|
||||
# Run pytest with PT2 export mode
|
||||
echo "Step 3: Running pytest with PT2 export mode..."
|
||||
LUMINAL_EXPORT_MODE=pt2 uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
19
crates/luminal_python/run_tests_cuda_fx.sh
Executable file
@@ -0,0 +1,19 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
echo "=== Luminal Python Test Runner (CUDA + PT2 Export Mode) ==="
|
||||
echo ""
|
||||
|
||||
# Force clean rebuild of Rust extension
|
||||
echo "Step 1: Cleaning previous builds..."
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
|
||||
# Rebuild in development mode (faster compilation)
|
||||
echo "Step 2: Building Rust extension..."
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
# Run pytest with CUDA backend and PT2 export mode
|
||||
echo "Step 3: Running pytest with CUDA backend + PT2 export mode..."
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda LUMINAL_EXPORT_MODE=pt2 uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
@@ -12,6 +12,8 @@ path = "src/lib.rs"
|
||||
cuda = ["dep:luminal_cuda_lite"]
|
||||
|
||||
[dependencies]
|
||||
onnx-protobuf = "0.2"
|
||||
protobuf = "~3.4"
|
||||
rustc-hash = "2.1.1"
|
||||
luminal = {path= "../../.."}
|
||||
luminal_cuda_lite = {path="../../luminal_cuda_lite", optional = true}
|
||||
|
||||
@@ -1,55 +1,32 @@
|
||||
#[cfg(feature = "cuda")]
|
||||
use luminal::prelude::tracing::{trace, warn};
|
||||
use luminal::{
|
||||
hlir::{NativeData, Output},
|
||||
prelude::*,
|
||||
shape::Expression,
|
||||
visualization::ToDot,
|
||||
};
|
||||
use luminal::{prelude::*, shape::Expression, visualization::ToDot};
|
||||
use pyo3::prelude::*;
|
||||
use std::collections::HashMap;
|
||||
#[cfg(feature = "cuda")]
|
||||
use std::collections::HashSet;
|
||||
|
||||
use crate::{runtime::RuntimeBackend, typed_data::TypedData};
|
||||
use crate::{runtime::RuntimeBackend, util::DimParamMap};
|
||||
|
||||
/// Maps symbolic dimension parameter names (e.g. "seq_len") to luminal Expression variable chars.
|
||||
pub type DimParamMap = HashMap<String, char>;
|
||||
|
||||
/// Convert luminal DType to PT2 dtype integer code (for python interop)
|
||||
/// Types without a direct PyTorch equivalent map to the closest safe representation
|
||||
fn luminal_dtype_to_pt2_code(dtype: DType) -> u32 {
|
||||
match dtype {
|
||||
DType::U8 => 1,
|
||||
DType::I8 => 2,
|
||||
DType::I16 => 3,
|
||||
DType::Int => 4, // i32
|
||||
DType::U16 => 4, // u16 -> i32 (PyTorch has no u16 in older versions)
|
||||
DType::F16 => 6,
|
||||
DType::F32 | DType::TF32 => 7,
|
||||
DType::F64 => 8,
|
||||
DType::Bool => 12,
|
||||
DType::Bf16 => 13,
|
||||
_ => panic!("luminal_dtype_to_pt2_code: unsupported dtype {:?}", dtype),
|
||||
}
|
||||
}
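The mapping above sends both `DType::Int` and `DType::U16` to code 4, so it cannot be inverted exactly. What a caller can recover from a code is the element size. The following is a minimal sketch under that assumption: the helper names `pt2_code_element_size` and `pt2_element_count` are hypothetical and not part of luminal, and the sizes are the usual PyTorch element widths for the codes listed in this diff. Codes that do not appear in the diff return `None` rather than a guess.

```rust
/// Bytes per element for the PT2 dtype codes listed in the diff.
/// Hypothetical helper; any other code yields `None`.
pub fn pt2_code_element_size(code: u32) -> Option<usize> {
    match code {
        1 | 2 | 12 => Some(1), // u8, i8, bool
        3 | 6 | 13 => Some(2), // i16, f16, bf16
        4 | 7 => Some(4),      // i32, f32
        8 => Some(8),          // f64
        _ => None,
    }
}

/// Element count implied by a byte length, or an error when the length is
/// not a whole number of elements for the given code.
pub fn pt2_element_count(n_bytes: usize, code: u32) -> Result<usize, String> {
    let size = pt2_code_element_size(code)
        .ok_or_else(|| format!("unsupported PT2 dtype code {code}"))?;
    if n_bytes % size != 0 {
        return Err(format!(
            "{n_bytes} bytes is not a multiple of element size {size}"
        ));
    }
    Ok(n_bytes / size)
}
```

A byte-count API such as the `n_bytes` plus `dtype_code` signatures in this diff could use a check like this to reject buffers whose length does not match the declared dtype.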
|
||||
|
||||
/// Common intermediate result from translating a model graph.
|
||||
/// Common intermediate result from translating a model graph (ONNX or FX).
|
||||
pub struct GraphTranslation {
|
||||
pub graph: Graph,
|
||||
pub tensor_ids: HashMap<String, NodeIndex>,
|
||||
pub input_names: Vec<String>,
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub output_dtypes: Vec<DType>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
|
||||
/// Pre-loaded weight data from any model format (dtype-aware).
|
||||
/// Pre-loaded weight data from any model format.
|
||||
///
|
||||
/// NOTE: Currently assumes all data is F32. When the type system branch lands
|
||||
/// with proper multi-dtype support, this struct (and all callers) will need
|
||||
/// updating to carry dtype metadata alongside the raw data.
|
||||
pub struct WeightData {
|
||||
/// (Input node label, typed data) for weights and constants.
|
||||
pub weights: Vec<(String, TypedData)>,
|
||||
/// (Input node label, f32 data) for weights and constants.
|
||||
pub weights: Vec<(String, Vec<f32>)>,
|
||||
/// label → element count for ALL Input nodes (for CUDA dummy data sizing).
|
||||
pub tensor_sizes: HashMap<String, usize>,
|
||||
/// label → (device_ptr, n_bytes) for zero-copy CUDA weight sharing.
|
||||
@@ -67,16 +44,15 @@ pub struct CompiledGraph {
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shapes: Vec<Vec<usize>>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub output_dtypes: Vec<DType>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
|
||||
impl CompiledGraph {
|
||||
/// Compilation pipeline for PT2/FX graphs.
|
||||
/// Shared compilation pipeline for both ONNX and FX/PT2 graphs.
|
||||
///
|
||||
/// Takes a `GraphTranslation` (produced by `translate_pt2`) and `WeightData`,
|
||||
/// builds the backend, loads weights, and
|
||||
/// Takes a format-neutral `GraphTranslation` (produced by `translate_onnx` or
|
||||
/// `translate_pt2`) and `WeightData`, builds the backend, loads weights, and
|
||||
/// returns a ready-to-execute `CompiledGraph`.
|
||||
pub fn parse_graph(
|
||||
translation: GraphTranslation,
|
||||
@@ -90,7 +66,6 @@ impl CompiledGraph {
|
||||
input_names,
|
||||
output_names,
|
||||
output_shape_exprs,
|
||||
output_dtypes,
|
||||
input_shape_exprs,
|
||||
dim_param_map,
|
||||
} = translation;
|
||||
@@ -144,7 +119,6 @@ impl CompiledGraph {
|
||||
output_names,
|
||||
output_shapes,
|
||||
output_shape_exprs,
|
||||
output_dtypes,
|
||||
input_shape_exprs,
|
||||
dim_param_map,
|
||||
})
|
||||
@@ -188,13 +162,14 @@ impl CompiledGraph {
|
||||
// For weights with device pointers: use them directly (zero-copy).
|
||||
// This avoids allocating ~N GB of dummy data during search.
|
||||
// The pointers survive search because profiling mode skips buffer consumption,
|
||||
// and graph-level .persist() ensures they survive post-search execution too.
|
||||
// and persist_hlir_node ensures they survive post-search execution too.
|
||||
let mut device_ptr_nodes: HashSet<NodeIndex> = HashSet::new();
|
||||
let mut matched_count = 0usize;
|
||||
let mut missed_labels: Vec<String> = Vec::new();
|
||||
for (label, &(ptr, n_bytes)) in device_ptrs {
|
||||
if let Some(&node_id) = label_map.get(label) {
|
||||
unsafe { rt.set_device_ptr(node_id, ptr, n_bytes) };
|
||||
rt.persist_hlir_node(node_id);
|
||||
device_ptr_nodes.insert(node_id);
|
||||
matched_count += 1;
|
||||
} else {
|
||||
@@ -243,10 +218,7 @@ impl CompiledGraph {
|
||||
if n > 0 {
|
||||
dummy_total_elements += n;
|
||||
dummy_count += 1;
|
||||
// Use dtype-aware dummy data: TypedData::ones produces correct
|
||||
// byte patterns for every dtype (f32, f16, bf16, i32, bool, f8, etc.).
|
||||
// Must use 1, not 0 — zero inputs cause NaN in many ops.
|
||||
rt.set_data(node_id, TypedData::ones(n, input.dtype).bytes);
|
||||
rt.set_data(node_id, vec![1.0f32; n]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -262,21 +234,22 @@ impl CompiledGraph {
|
||||
let mut rt = graph.search(rt, search_iters);
|
||||
|
||||
// Load real weight data for non-device-ptr weights (constants from PT2 archive, etc.)
|
||||
let mut loaded_weight_bytes = 0usize;
|
||||
let mut loaded_weight_elements = 0usize;
|
||||
let mut loaded_weight_count = 0usize;
|
||||
for (label, data) in &weight_data.weights {
|
||||
if !device_ptrs.contains_key(label) {
|
||||
if let Some(&node_id) = label_map.get(label) {
|
||||
loaded_weight_bytes += data.n_bytes();
|
||||
loaded_weight_elements += data.len();
|
||||
loaded_weight_count += 1;
|
||||
rt.set_data(node_id, data.bytes.clone());
|
||||
rt.set_data(node_id, data.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
trace!(
|
||||
"[CUDA BUILD] Post-search weight load: {} weights, {:.3} GiB",
|
||||
"[CUDA BUILD] Post-search weight load: {} weights, {} elements ({:.3} GiB as f32)",
|
||||
loaded_weight_count,
|
||||
loaded_weight_bytes as f64 / (1024.0 * 1024.0 * 1024.0),
|
||||
loaded_weight_elements,
|
||||
(loaded_weight_elements * 4) as f64 / (1024.0 * 1024.0 * 1024.0),
|
||||
);
|
||||
|
||||
Ok(RuntimeBackend::Cuda(Box::new(rt)))
|
||||
@@ -290,14 +263,11 @@ impl CompiledGraph {
|
||||
graph.build_search_space::<NativeRuntime>();
|
||||
let mut rt = graph.search(NativeRuntime::default(), search_iters);
|
||||
|
||||
// Load weight data after search, preserving native dtype.
|
||||
// TypedData -> NativeData conversion (From<TypedData>) handles mapping to the
|
||||
// correct NativeData variant (F32, F16, Bf16, Int, Bool).
|
||||
// Load weight data after search
|
||||
let label_map = CompiledGraph::build_label_map(graph);
|
||||
for (label, data) in &weight_data.weights {
|
||||
if let Some(&node_id) = label_map.get(label) {
|
||||
let native: NativeData = data.into();
|
||||
rt.set_data(node_id, native);
|
||||
rt.set_data(node_id, data.clone());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -313,24 +283,6 @@ impl CompiledGraph {
|
||||
self.input_names.clone()
|
||||
}
|
||||
|
||||
/// Get the PT2 dtype codes for all inputs (in order of input_names).
|
||||
#[getter]
|
||||
fn input_dtypes(&self) -> Vec<u32> {
|
||||
self.input_names
|
||||
.iter()
|
||||
.map(|name| {
|
||||
if let Some(&node_id) = self.tensor_ids.get(name)
|
||||
&& let Some(input) = (*self.graph.graph[node_id])
|
||||
.as_any()
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
{
|
||||
return luminal_dtype_to_pt2_code(input.dtype);
|
||||
}
|
||||
7 // default to f32
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get the list of output tensor names.
|
||||
#[getter]
|
||||
fn output_names(&self) -> Vec<String> {
|
||||
@@ -419,33 +371,25 @@ impl CompiledGraph {
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Set input tensor data by name (f32, for backward compatibility).
|
||||
/// Set input tensor data by name.
|
||||
fn set_input(&mut self, name: &str, data: Vec<f32>) -> PyResult<()> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
|
||||
})?;
|
||||
self.runtime.set_data_f32(*node_id, data);
|
||||
self.runtime.set_data(*node_id, data);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set input tensor data from a CPU host memory pointer (dtype-aware).
|
||||
/// The pointer must point to contiguous data. `n_bytes` is the total byte count.
|
||||
/// `dtype_code` uses PT2 numbering (7=f32, 6=f16, 13=bf16, etc.).
|
||||
/// Converts source format to luminal's native format (e.g., i64→i32, f64→f32).
|
||||
fn set_input_from_ptr(
|
||||
&mut self,
|
||||
name: &str,
|
||||
ptr: u64,
|
||||
n_bytes: usize,
|
||||
dtype_code: u32,
|
||||
) -> PyResult<()> {
|
||||
/// Set input tensor data from a CPU host memory pointer (avoids Python list conversion).
|
||||
/// The pointer must point to contiguous f32 data (from tensor.data_ptr() on a CPU float32 tensor).
|
||||
fn set_input_from_ptr(&mut self, name: &str, ptr: u64, n_elements: usize) -> PyResult<()> {
|
||||
debug_assert!(ptr != 0, "set_input_from_ptr called with null pointer");
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
|
||||
})?;
|
||||
let raw_bytes = unsafe { std::slice::from_raw_parts(ptr as *const u8, n_bytes).to_vec() };
|
||||
let typed = TypedData::from_pytorch_bytes(raw_bytes, dtype_code);
|
||||
self.runtime.set_data(*node_id, typed);
|
||||
let data: Vec<f32> =
|
||||
unsafe { std::slice::from_raw_parts(ptr as *const f32, n_elements).to_vec() };
|
||||
self.runtime.set_data(*node_id, data);
|
||||
Ok(())
|
||||
}
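The dtype-aware doc comment for `set_input_from_ptr` mentions converting source formats to luminal's native ones (for example i64 to i32 and f64 to f32). The sketch below shows what such a narrowing can look like on a plain byte slice, without any `unsafe`. It is not the implementation of `TypedData::from_pytorch_bytes`: the function names are hypothetical, native-endian layout is assumed, and rejecting out-of-range i64 values is a choice made for this sketch only.

```rust
/// Narrow native-endian f64 bytes to f32 values.
pub fn f64_bytes_to_f32(bytes: &[u8]) -> Result<Vec<f32>, String> {
    if bytes.len() % 8 != 0 {
        return Err(format!(
            "{} bytes is not a whole number of f64 values",
            bytes.len()
        ));
    }
    Ok(bytes
        .chunks_exact(8)
        .map(|chunk| {
            let mut buf = [0u8; 8];
            buf.copy_from_slice(chunk);
            f64::from_ne_bytes(buf) as f32
        })
        .collect())
}

/// Narrow native-endian i64 bytes to i32 values.
/// Out-of-range values are reported as errors (an assumption of this sketch).
pub fn i64_bytes_to_i32(bytes: &[u8]) -> Result<Vec<i32>, String> {
    if bytes.len() % 8 != 0 {
        return Err(format!(
            "{} bytes is not a whole number of i64 values",
            bytes.len()
        ));
    }
    let mut out = Vec::with_capacity(bytes.len() / 8);
    for chunk in bytes.chunks_exact(8) {
        let mut buf = [0u8; 8];
        buf.copy_from_slice(chunk);
        let v = i64::from_ne_bytes(buf);
        if v < i64::from(i32::MIN) || v > i64::from(i32::MAX) {
            return Err(format!("value {v} does not fit in i32"));
        }
        out.push(v as i32);
    }
    Ok(out)
}
```

Copying each chunk into a fixed-size array avoids alignment requirements on the source buffer, which matters when the bytes come from an arbitrary host pointer.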
|
||||
|
||||
@@ -472,7 +416,22 @@ impl CompiledGraph {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// For PT2 weights (e.g. "fc1.weight"). Persistence is handled at graph level via .persist().
|
||||
/// Mark an input tensor as persistent (survives execute() calls).
|
||||
/// Call this for weight tensors that should not be consumed after each execution.
|
||||
fn persist_input(&mut self, name: &str) -> PyResult<()> {
|
||||
let _node_id = *self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
|
||||
})?;
|
||||
match &mut self.runtime {
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(rt) => rt.persist_hlir_node(_node_id),
|
||||
RuntimeBackend::Native(_) => {} // Native: persist is handled at graph level
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set a weight tensor from a CUDA device pointer, matching by Input node label.
|
||||
/// Also marks the weight as persistent. For PT2 weights (e.g. "fc1.weight").
|
||||
#[cfg(feature = "cuda")]
|
||||
fn set_weight_device_ptr(
|
||||
&mut self,
|
||||
@@ -486,6 +445,7 @@ impl CompiledGraph {
|
||||
match &mut self.runtime {
|
||||
RuntimeBackend::Cuda(rt) => {
|
||||
unsafe { rt.set_device_ptr(node_id, device_ptr, n_bytes) };
|
||||
rt.persist_hlir_node(node_id);
|
||||
}
|
||||
_ => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
@@ -496,70 +456,15 @@ impl CompiledGraph {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Register an external device pointer for an output tensor (zero-copy output).
|
||||
/// Call before run() — the runtime will write kernel results directly into this buffer.
|
||||
/// For aliased outputs (in-place ops), falls back to DtoD copy; check output_is_zero_copy() after run().
|
||||
#[cfg(feature = "cuda")]
|
||||
fn set_output_device_ptr(
|
||||
&mut self,
|
||||
name: &str,
|
||||
device_ptr: u64,
|
||||
n_bytes: usize,
|
||||
) -> PyResult<()> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
|
||||
match &mut self.runtime {
|
||||
RuntimeBackend::Cuda(rt) => {
|
||||
unsafe { rt.set_output_device_ptr(*node_id, device_ptr, n_bytes) };
|
||||
}
|
||||
_ => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"set_output_device_ptr requires CUDA backend",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check whether an output tensor was zero-copied (written directly to the registered pointer).
|
||||
/// Returns false for aliased outputs that need a fallback DtoD copy. Must be called after run().
|
||||
#[cfg(feature = "cuda")]
|
||||
fn output_is_zero_copy(&self, name: &str) -> PyResult<bool> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
|
||||
match &self.runtime {
|
||||
RuntimeBackend::Cuda(rt) => Ok(rt.output_is_zero_copy(*node_id)),
|
||||
_ => Ok(false),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set a weight tensor from a CPU host pointer, matching by Input node label (dtype-aware).
|
||||
/// `n_bytes` is the total byte count. `dtype_code` uses PT2 numbering (7=f32, 6=f16, 13=bf16, etc.).
|
||||
fn set_weight_from_ptr(
|
||||
&mut self,
|
||||
label: &str,
|
||||
ptr: u64,
|
||||
n_bytes: usize,
|
||||
dtype_code: u32,
|
||||
) -> PyResult<()> {
|
||||
/// Set a weight tensor from a CPU host pointer, matching by Input node label.
|
||||
fn set_weight_from_ptr(&mut self, label: &str, ptr: u64, n_elements: usize) -> PyResult<()> {
|
||||
debug_assert!(ptr != 0, "set_weight_from_ptr called with null pointer");
|
||||
let &node_id = self.label_map.get(label).ok_or_else(|| {
|
||||
pyo3::exceptions::PyKeyError::new_err(format!("No Input node with label: {}", label))
|
||||
})?;
|
||||
let bytes = unsafe { std::slice::from_raw_parts(ptr as *const u8, n_bytes).to_vec() };
|
||||
let typed = TypedData::from_pytorch_bytes(bytes, dtype_code);
|
||||
self.runtime.set_data(node_id, typed);
|
||||
let data: Vec<f32> =
|
||||
unsafe { std::slice::from_raw_parts(ptr as *const f32, n_elements).to_vec() };
|
||||
self.runtime.set_data(node_id, data);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -575,19 +480,7 @@ impl CompiledGraph {
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the PT2 dtype codes for all outputs (in order).
|
||||
#[getter]
|
||||
fn output_dtypes(&self) -> Vec<u32> {
|
||||
self.output_dtypes
|
||||
.iter()
|
||||
.map(|d| luminal_dtype_to_pt2_code(*d))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get output tensor data by name as f32 (copies to host).
|
||||
/// For native backend: handles any NativeData variant by converting to f32.
|
||||
/// The native runtime may produce NativeData::Int or NativeData::Bool for some ops
|
||||
/// (e.g., Cast chains), so we can't assume NativeData::F32.
|
||||
/// Get output tensor data by name (copies to host).
|
||||
fn get_output(&self, name: &str) -> PyResult<Vec<f32>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
@@ -595,37 +488,7 @@ impl CompiledGraph {
|
||||
name
|
||||
))
|
||||
})?;
|
||||
match &self.runtime {
|
||||
RuntimeBackend::Native(rt) => {
|
||||
let id = *node_id;
|
||||
let output_id = rt
|
||||
.graph
|
||||
.node_indices()
|
||||
.find(|n| {
|
||||
if let Some(out) = (**rt.graph[*n]).as_any().downcast_ref::<Output>() {
|
||||
out.node == id.index()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
pyo3::exceptions::PyRuntimeError::new_err(format!(
|
||||
"No output node found for tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
let data = rt.buffers.get(&output_id).ok_or_else(|| {
|
||||
pyo3::exceptions::PyRuntimeError::new_err(format!(
|
||||
"No buffer data for output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
// Convert any NativeData variant to f32
|
||||
Ok((0..data.len()).map(|i| data.f32(i)).collect())
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(rt) => Ok(rt.get_f32(*node_id)),
|
||||
}
|
||||
Ok(self.runtime.get_f32(*node_id))
|
||||
}
|
||||
|
||||
/// Copy output tensor data directly to a CUDA device pointer (DtoD).
|
||||
|
||||
248
crates/luminal_python/rust/src/dispatch.rs
Normal file
@@ -0,0 +1,248 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::{prelude::*, shape::Expression};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::ops_parse::*;
|
||||
|
||||
pub fn process_onnx_nodes(
|
||||
nodes: &[NodeProto],
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
for node in nodes {
|
||||
match node.op_type.as_str() {
|
||||
"Add" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Add",
|
||||
|a, b| a + b,
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Mod" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Mod",
|
||||
|a, b| a % b,
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Sub" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Sub",
|
||||
|a, b| a - b,
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Mul" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Mul",
|
||||
|a, b| a * b,
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Div" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Div",
|
||||
|a, b| a / b,
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Sqrt" => parse_unary_op(node, tensors, "Sqrt", |a| a.sqrt())?,
|
||||
"Transpose" => parse_transpose_node(node, tensors)?,
|
||||
"Concat" => parse_concat_node(node, tensors, shape_exprs, known_values)?,
|
||||
"Floor" => parse_floor_node(node, tensors)?,
|
||||
"Ceil" => parse_ceil_node(node, tensors)?,
|
||||
"Sin" => parse_unary_op(node, tensors, "Sin", |a| a.sin())?,
|
||||
"Neg" => parse_unary_op(node, tensors, "Neg", |a| -a)?,
|
||||
"Cos" => parse_unary_op(node, tensors, "Cos", |a| a.cos())?,
|
||||
"Pow" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Pow",
|
||||
|a, b| a.pow(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Sigmoid" => parse_unary_op(node, tensors, "Sigmoid", |a| a.sigmoid())?,
|
||||
"Tanh" => parse_unary_op(node, tensors, "Tanh", |a| a.tanh())?,
|
||||
"Relu" => parse_unary_op(node, tensors, "Relu", |a| a.relu())?,
|
||||
"Softmax" => parse_softmax_node(node, tensors)?,
|
||||
"Abs" => parse_unary_op(node, tensors, "Abs", |a| a.abs())?,
|
||||
"Reciprocal" => parse_unary_op(node, tensors, "Reciprocal", |a| a.reciprocal())?,
|
||||
"Clip" => parse_clip_node(node, tensors, known_values)?,
|
||||
"Equal" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Equal",
|
||||
|a, b| a.eq(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Where" => parse_where_node(node, tensors)?,
|
||||
"Constant" => {
|
||||
parse_constant_node(node, tensors, cx, weight_data, known_values, shape_exprs)?
|
||||
}
|
||||
"ConstantOfShape" => {
|
||||
parse_constant_of_shape(node, tensors, cx, weight_data, known_values, shape_exprs)?
|
||||
}
|
||||
"Cast" => parse_cast_node(node, tensors, weight_data, known_values, shape_exprs)?,
|
||||
"MatMul" => parse_matmul_node(node, tensors)?,
|
||||
"Reshape" => parse_reshape_node(node, tensors, known_values, shape_exprs)?,
|
||||
"Shape" => parse_shape_node(node, tensors, cx, weight_data, known_values, shape_exprs)?,
|
||||
"Gather" => {
|
||||
parse_gather_node(node, tensors, cx, weight_data, known_values, shape_exprs)?
|
||||
}
|
||||
"GatherND" => parse_gathernd_node(node, tensors, cx, weight_data, known_values)?,
|
||||
"Less" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Less",
|
||||
|a, b| a.lt(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Greater" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Greater",
|
||||
|a, b| b.lt(a),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"LessOrEqual" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"LessOrEqual",
|
||||
|a, b| a.le(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"GreaterOrEqual" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"GreaterOrEqual",
|
||||
|a, b| a.ge(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Not" => parse_not_node(node, tensors)?,
|
||||
"And" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"And",
|
||||
|a, b| a.cast(DType::F32) * b.cast(DType::F32),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Or" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Or",
|
||||
|a, b| (a.cast(DType::F32) + b.cast(DType::F32)).minimum_f32(1.0),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Xor" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Xor",
|
||||
|a, b| a.ne(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Min" => parse_variadic_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Min",
|
||||
|a, b| a.minimum(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Max" => parse_variadic_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Max",
|
||||
|a, b| a.maximum(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Identity" => parse_identity(node, tensors, known_values, shape_exprs)?,
|
||||
"Unsqueeze" => parse_unsqueeze_node(node, tensors, known_values, shape_exprs)?,
|
||||
"Squeeze" => parse_squeeze_node(node, tensors, known_values, shape_exprs)?,
|
||||
"ReduceSum" => parse_reduce_op(
|
||||
node,
|
||||
tensors,
|
||||
known_values,
|
||||
"ReduceSum",
|
||||
|t, axes| t.sum(axes),
|
||||
|flat, _n| flat.sum(1),
|
||||
)?,
|
||||
"ReduceMax" => parse_reduce_op(
|
||||
node,
|
||||
tensors,
|
||||
known_values,
|
||||
"ReduceMax",
|
||||
|t, axes| t.max(axes),
|
||||
|flat, _n| flat.max(1),
|
||||
)?,
|
||||
"ReduceMin" => parse_reduce_op(
|
||||
node,
|
||||
tensors,
|
||||
known_values,
|
||||
"ReduceMin",
|
||||
|t, axes| t.min(axes),
|
||||
|flat, _n| flat.min(1),
|
||||
)?,
|
||||
"ReduceMean" => parse_reduce_op(
|
||||
node,
|
||||
tensors,
|
||||
known_values,
|
||||
"ReduceMean",
|
||||
|t, axes| t.mean(axes),
|
||||
|flat, n| flat.sum(1) / n as f32,
|
||||
)?,
|
||||
"Trilu" => parse_trilu_node(node, tensors, cx, known_values)?,
|
||||
"GatherElements" => parse_gather_elements_node(node, tensors)?,
|
||||
"ScatterElements" => parse_scatter_elements_node(node, tensors)?,
|
||||
"ScatterND" => parse_scatter_nd_node(node, tensors)?,
|
||||
"Expand" => parse_expand_node(node, tensors, known_values, shape_exprs)?,
|
||||
"IsNaN" => parse_unary_op(node, tensors, "IsNaN", |a| a.ne(a))?,
|
||||
"LayerNormalization" => parse_layernorm_node(node, tensors)?,
|
||||
"Gemm" => parse_gemm_node(node, tensors)?,
|
||||
"Erf" => parse_erf_node(node, tensors)?,
|
||||
"Slice" => parse_slice_node(node, tensors, known_values, shape_exprs)?,
|
||||
"Split" => parse_split_node(node, tensors, known_values)?,
|
||||
"TopK" => parse_topk_node(node, tensors, known_values)?,
|
||||
"OneHot" => parse_onehot_node(node, tensors, known_values)?,
|
||||
"Range" => parse_range_node(node, tensors, cx, weight_data, known_values, shape_exprs)?,
|
||||
"CumSum" => parse_cumsum_node(node, tensors, known_values)?,
|
||||
"Gelu" => parse_unary_op(node, tensors, "Gelu", |a| a.gelu())?,
|
||||
"Conv" => parse_conv_node(node, tensors)?,
|
||||
"Pad" => parse_pad_node(node, tensors, known_values)?,
|
||||
"Resize" => parse_resize_node(node, tensors, known_values)?,
|
||||
"Tile" => parse_tile_node(node, tensors, known_values)?,
|
||||
"ReduceL2" => parse_reduce_op(
|
||||
node,
|
||||
tensors,
|
||||
known_values,
|
||||
"ReduceL2",
|
||||
|t, axes| (t * t).sum(axes).sqrt(),
|
||||
|flat, _n| (flat * flat).sum(1).sqrt(),
|
||||
)?,
|
||||
"GroupNormalization" => parse_group_norm_node(node, tensors)?,
|
||||
_ => {
|
||||
panic!("Missing Node {}", node.op_type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
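Several arms in `process_onnx_nodes` express logic and comparison ops through other primitives: `And` as a product, `Or` as a sum clamped to 1, `Xor` as inequality, `Greater` as a swapped `lt`, and `IsNaN` as `a != a`. The scalar sketch below mirrors those closures on plain `f32` values to show why the identities hold. The function names are hypothetical, and encoding boolean results as 0.0 or 1.0 is an assumption of this sketch, not a statement about how luminal tensors store them.

```rust
/// And: the product of two 0/1 values is 1 only when both are 1.
pub fn and_f32(a: f32, b: f32) -> f32 {
    a * b
}

/// Or: the sum of two 0/1 values, clamped to at most 1.
pub fn or_f32(a: f32, b: f32) -> f32 {
    (a + b).min(1.0)
}

/// Xor: true exactly when the two inputs differ.
pub fn xor_f32(a: f32, b: f32) -> f32 {
    if a != b { 1.0 } else { 0.0 }
}

/// Greater: `a > b` written as `b < a`, matching the swapped `lt` in the diff.
pub fn greater_f32(a: f32, b: f32) -> f32 {
    if b < a { 1.0 } else { 0.0 }
}

/// IsNaN: NaN is the only float value that is not equal to itself.
pub fn is_nan_f32(a: f32) -> f32 {
    if a != a { 1.0 } else { 0.0 }
}
```

These identities rely on the inputs to `And` and `Or` being exactly 0 or 1, which is presumably why the corresponding arms cast both operands first.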
|
||||
@@ -1,6 +1,9 @@
|
||||
mod compiled_graph;
|
||||
mod dispatch;
|
||||
mod onnx_translator;
|
||||
mod ops_parse;
|
||||
mod runtime;
|
||||
pub mod typed_data;
|
||||
mod util;
|
||||
|
||||
// PT2 modules
|
||||
mod pt2_compiled_model;
|
||||
@@ -12,9 +15,58 @@ mod translator;
|
||||
use compiled_graph::CompiledGraph;
|
||||
use pt2_compiled_model::process_pt2;
|
||||
use pyo3::prelude::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn validate_backend(backend: &str) -> PyResult<()> {
|
||||
match backend {
|
||||
"native" => Ok(()),
|
||||
#[cfg(feature = "cuda")]
|
||||
"cuda" => Ok(()),
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
"cuda" => Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"CUDA backend requested, but this luminal extension was built without the `cuda` feature. Rebuild with `maturin develop --features cuda -r` or use backend='native'.",
|
||||
)),
|
||||
_ => {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
Err(pyo3::exceptions::PyValueError::new_err(format!(
|
||||
"Invalid backend '{}'. Must be 'native' or 'cuda'",
|
||||
backend
|
||||
)))
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
{
|
||||
Err(pyo3::exceptions::PyValueError::new_err(format!(
|
||||
"Invalid backend '{}'. This build only supports 'native'. Rebuild with the `cuda` feature to enable 'cuda'.",
|
||||
backend
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (path, backend="native", search_iters=10, weight_device_ptrs=None))]
|
||||
fn process_onnx(
|
||||
path: &str,
|
||||
backend: &str,
|
||||
search_iters: usize,
|
||||
weight_device_ptrs: Option<HashMap<String, (u64, usize)>>,
|
||||
) -> PyResult<CompiledGraph> {
|
||||
validate_backend(backend)?;
|
||||
|
||||
onnx_translator::compile_onnx(
|
||||
path,
|
||||
backend,
|
||||
weight_device_ptrs.unwrap_or_default(),
|
||||
search_iters,
|
||||
)
|
||||
.map_err(pyo3::exceptions::PyRuntimeError::new_err)
|
||||
}
|
||||
|
||||
#[pymodule]
|
||||
fn luminal(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(process_onnx, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(process_pt2, m)?)?;
|
||||
m.add_class::<CompiledGraph>()?;
|
||||
Ok(())
|
||||
|
||||
283
crates/luminal_python/rust/src/onnx_translator.rs
Normal file
@@ -0,0 +1,283 @@
|
||||
use luminal::{
|
||||
prelude::{
|
||||
tracing::{Level, span, trace},
|
||||
*,
|
||||
},
|
||||
shape::Expression,
|
||||
};
|
||||
use onnx_protobuf::ModelProto;
|
||||
use protobuf::Message;
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
fs,
|
||||
path::Path,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
compiled_graph::{CompiledGraph, GraphTranslation, WeightData},
|
||||
dispatch::process_onnx_nodes,
|
||||
util::{
|
||||
DimParamMap, get_shape_for_onnx_value, get_shape_for_onnx_value_expr,
|
||||
load_all_tensor_floats, load_initializer_as_f32,
|
||||
},
|
||||
};
|
||||
|
||||
/// Load, validate, translate, and compile an ONNX model.
|
||||
///
|
||||
/// This is the ONNX counterpart of `pt2_compiled_model::compile_pt2()`.
|
||||
pub fn compile_onnx(
|
||||
path: &str,
|
||||
backend: &str,
|
||||
weight_device_ptrs: HashMap<String, (u64, usize)>,
|
||||
search_iters: usize,
|
||||
) -> Result<CompiledGraph, String> {
|
||||
let data = fs::read(path).map_err(|e| format!("Failed to read file: {}", e))?;
|
||||
let model_directory = Path::new(path).parent().unwrap_or(Path::new("."));
|
||||
let model = ModelProto::parse_from_bytes(&data)
|
||||
.map_err(|e| format!("Failed to parse ONNX model: {}", e))?;
|
||||
|
||||
let opset_version = model
|
||||
.opset_import
|
||||
.iter()
|
||||
.find(|entry| entry.domain.is_empty())
|
||||
.map(|entry| entry.version);
|
||||
|
||||
match opset_version {
|
||||
Some(20) => {}
|
||||
Some(v) => {
|
||||
return Err(format!(
|
||||
"Unsupported ONNX opset version {v}. Only opset 20 is supported."
|
||||
));
|
||||
}
|
||||
None => {
|
||||
return Err(
|
||||
"No ONNX opset version found in model. Only opset 20 is supported.".to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let (translation, mut weights) = translate_onnx(model, model_directory)?;
|
||||
weights.device_ptrs = weight_device_ptrs;
|
||||
CompiledGraph::parse_graph(translation, weights, backend, search_iters)
|
||||
}
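The opset gate in `compile_onnx` looks only at the import whose domain is empty and accepts exactly version 20. The standalone sketch below reproduces that decision on plain `(domain, version)` tuples so it can be exercised without an ONNX file. The tuple representation and the name `check_opset` are assumptions made for illustration; the error strings are copied from the diff.

```rust
/// Accept a model only if its default-domain (empty string) opset is 20.
/// `imports` stands in for the model's `opset_import` entries.
pub fn check_opset(imports: &[(&str, i64)]) -> Result<(), String> {
    let version = imports
        .iter()
        .find(|(domain, _)| domain.is_empty())
        .map(|(_, v)| *v);

    match version {
        Some(20) => Ok(()),
        Some(v) => Err(format!(
            "Unsupported ONNX opset version {v}. Only opset 20 is supported."
        )),
        None => Err(
            "No ONNX opset version found in model. Only opset 20 is supported.".to_string(),
        ),
    }
}
```

Under this logic, imports for other domains are ignored: a model that declares version 20 only under a non-empty domain is rejected with the "No ONNX opset version found" message.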
|
||||
|
||||
/// Translate an ONNX model into a format-neutral GraphTranslation + WeightData.
|
||||
pub fn translate_onnx(
|
||||
model: ModelProto,
|
||||
model_directory: &Path,
|
||||
) -> Result<(GraphTranslation, WeightData), String> {
|
||||
let _span = span!(Level::TRACE, "ONNX Graph Translation").entered();
|
||||
let onnx_graph = &model.graph;
|
||||
let mut cx = Graph::new();
|
||||
let mut tensors: HashMap<String, GraphTensor> = HashMap::new();
|
||||
|
||||
// Dynamic dimension tracking
|
||||
let mut dim_param_map: DimParamMap = HashMap::new();
|
||||
let mut next_char = 'a';
|
||||
|
||||
// Separate initializers (weights) from true user inputs
|
||||
let initializer_names: HashSet<&str> = onnx_graph
|
||||
.initializer
|
||||
.iter()
|
||||
.map(|t| t.name.as_str())
|
||||
.collect();
|
||||
|
||||
let input_names: Vec<String> = onnx_graph
|
||||
.input
|
||||
.iter()
|
||||
.filter(|inp| !initializer_names.contains(inp.name.as_str()))
|
||||
.map(|inp| inp.name.clone())
|
||||
.collect();
|
||||
|
||||
// Create input tensors with dynamic dimension support
|
||||
for input in &onnx_graph.input {
|
||||
let shape_exprs = get_shape_for_onnx_value_expr(input, &mut dim_param_map, &mut next_char);
|
||||
if shape_exprs.is_empty() {
|
||||
let shape = get_shape_for_onnx_value(input);
|
||||
if shape.is_empty() {
|
||||
trace!("Input {} skipped because it is empty", input.name.clone());
|
||||
continue;
|
||||
}
|
||||
let tensor = cx.named_tensor(input.name.clone(), shape);
|
||||
trace!("Input {} added to tensors", input.name.clone());
|
||||
tensors.insert(input.name.clone(), tensor);
|
||||
continue;
|
||||
}
|
||||
let tensor = cx.named_tensor(input.name.clone(), shape_exprs);
|
||||
trace!("Input {} added to tensors", input.name.clone());
|
||||
tensors.insert(input.name.clone(), tensor);
|
||||
}
|
||||
|
||||
// Create initializer (weight) tensors
|
||||
for init in &onnx_graph.initializer {
|
||||
if !tensors.contains_key(&init.name) {
|
||||
let mut shape: Vec<usize> = init.dims.iter().map(|&d| d as usize).collect();
|
||||
if shape.is_empty() {
|
||||
shape = vec![1];
|
||||
}
|
||||
let tensor = cx.named_tensor(init.name.clone(), shape);
|
||||
tensors.insert(init.name.clone(), tensor);
|
||||
}
|
||||
}
|
||||
|
||||
// Load small constants for constant folding
|
||||
let mut known_values: HashMap<String, Vec<f32>> = HashMap::new();
|
||||
for init in &onnx_graph.initializer {
|
||||
let n_elements: usize = init
|
||||
.dims
|
||||
.iter()
|
||||
.map(|&d| d as usize)
|
||||
.product::<usize>()
|
||||
.max(1);
|
||||
if n_elements <= 32 {
|
||||
if let Some(floats) = load_initializer_as_f32(init) {
|
||||
known_values.insert(init.name.clone(), floats);
|
||||
} else {
|
||||
panic!("Unable to load initializer values for {:?}", init.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shape expressions for propagating symbolic shapes through ONNX graphs
|
||||
let mut shape_exprs: HashMap<String, Vec<Expression>> = HashMap::new();
|
||||
|
||||
// Accumulates constant node data from process_onnx_nodes
|
||||
let mut constant_data: Vec<(String, Vec<f32>)> = Vec::new();
|
||||
|
||||
// Process computation nodes
|
||||
process_onnx_nodes(
|
||||
&onnx_graph.node,
|
||||
&mut tensors,
|
||||
&mut cx,
|
||||
&mut constant_data,
|
||||
&mut known_values,
|
||||
&mut shape_exprs,
|
||||
)
|
||||
.map_err(|e| format!("process_onnx_nodes failed: {}", e))?;
|
||||
|
||||
// Mark weight/constant tensors as persistent so their buffers survive execute()
|
||||
for (name, gt) in &tensors {
|
||||
if !input_names.contains(name) {
|
||||
gt.persist();
|
||||
}
|
||||
}
|
||||
|
||||
// Mark graph outputs (must happen before build_search_space)
|
||||
let mut output_names = Vec::new();
|
||||
let mut output_shape_exprs = Vec::new();
|
||||
for output_vi in &onnx_graph.output {
|
||||
if let Some(&gt) = tensors.get(&output_vi.name) {
|
||||
// Force contiguous if the shape tracker is a non-contiguous view
|
||||
let gt = if gt.shape != gt.shape.contiguous() {
|
||||
let contiguous = gt * 1.0;
|
||||
tensors.insert(output_vi.name.clone(), contiguous);
|
||||
contiguous
|
||||
} else {
|
||||
gt
|
||||
};
|
||||
gt.output();
|
||||
let dims = gt.dims();
|
||||
output_shape_exprs.push(dims.clone());
|
||||
|
||||
let shape: Vec<usize> = dims.iter().map(|d| d.to_usize().unwrap_or(1)).collect();
|
||||
if shape.is_empty() {
|
||||
return Err(format!(
|
||||
"Output tensor '{}' has no shape information in the ONNX model",
|
||||
output_vi.name
|
||||
));
|
||||
}
|
||||
output_names.push(output_vi.name.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Set initial dynamic dimension values from example input shapes
|
||||
let has_dynamic = !dim_param_map.is_empty();
|
||||
if has_dynamic {
|
||||
for input in &onnx_graph.input {
|
||||
if initializer_names.contains(input.name.as_str()) {
|
||||
continue;
|
||||
}
|
||||
let concrete_shape = get_shape_for_onnx_value(input);
|
||||
let expr_shape =
|
||||
get_shape_for_onnx_value_expr(input, &mut dim_param_map, &mut next_char);
|
||||
for (expr, concrete) in expr_shape.iter().zip(concrete_shape.iter()) {
|
||||
if expr.to_usize().is_none()
|
||||
&& let Some(ch) = dim_param_map
|
||||
.values()
|
||||
.find(|&&ch| Expression::from(ch) == *expr)
|
||||
{
|
||||
cx.set_dim(*ch, *concrete);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build weight data: initializers + constants from process_onnx_nodes
|
||||
let mut weights: Vec<(String, Vec<f32>)> = Vec::new();
|
||||
for (name, floats) in load_all_tensor_floats(&onnx_graph.initializer, model_directory) {
|
||||
if let Some(f) = floats {
|
||||
weights.push((name, f));
|
||||
}
|
||||
}
|
||||
weights.extend(constant_data);
|
||||
|
||||
// Build tensor sizes for CUDA dummy data allocation
|
||||
let mut tensor_sizes: HashMap<String, usize> = HashMap::new();
|
||||
for input in &onnx_graph.input {
|
||||
if !initializer_names.contains(input.name.as_str()) {
|
||||
let shape = get_shape_for_onnx_value(input);
|
||||
let n: usize = shape.iter().product::<usize>().max(1);
|
||||
tensor_sizes.insert(input.name.clone(), n);
|
||||
}
|
||||
}
|
||||
for init in &onnx_graph.initializer {
|
||||
let n: usize = init
|
||||
.dims
|
||||
.iter()
|
||||
.map(|&d| d as usize)
|
||||
.product::<usize>()
|
||||
.max(1);
|
||||
tensor_sizes.insert(init.name.clone(), n);
|
||||
}
|
||||
for (name, data) in &weights {
|
||||
if !tensor_sizes.contains_key(name) {
|
||||
tensor_sizes.insert(name.clone(), data.len());
|
||||
}
|
||||
}
|
||||
|
||||
// Collect tensor name → NodeIndex mapping
|
||||
let tensor_ids: HashMap<String, NodeIndex> = tensors
|
||||
.iter()
|
||||
.map(|(name, gt)| (name.clone(), gt.id))
|
||||
.collect();
|
||||
|
||||
// Build input_shape_exprs for user inputs (needed for auto-dim detection)
|
||||
let input_shape_exprs: Vec<Vec<Expression>> = input_names
|
||||
.iter()
|
||||
.map(|name| {
|
||||
if let Some(&gt) = tensors.get(name) {
|
||||
gt.dims()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let translation = GraphTranslation {
|
||||
graph: cx,
|
||||
tensor_ids,
|
||||
input_names,
|
||||
output_names,
|
||||
output_shape_exprs,
|
||||
input_shape_exprs,
|
||||
dim_param_map,
|
||||
};
|
||||
|
||||
let weight_data = WeightData {
|
||||
weights,
|
||||
tensor_sizes,
|
||||
device_ptrs: HashMap::new(),
|
||||
};
|
||||
|
||||
Ok((translation, weight_data))
|
||||
}
|
||||
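`translate_onnx` above threads a `dim_param_map` and a `next_char` counter through `get_shape_for_onnx_value_expr`, whose body is not part of this diff. The following is a minimal sketch of how named ONNX dimensions could be mapped to single-character symbols, assuming one fresh character per distinct `dim_param` name. The types `OnnxDim` and `Dim` and the function `assign_dims` are hypothetical and exist only for illustration; they are not luminal APIs.

```rust
use std::collections::HashMap;

/// Hypothetical input: an ONNX dimension is either a concrete value or a named parameter.
#[derive(Debug, Clone, Copy, PartialEq)]
enum OnnxDim<'a> {
    Value(usize),
    Param(&'a str),
}

/// Hypothetical output: a fixed size or a one-character symbol.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Dim {
    Fixed(usize),
    Sym(char),
}

/// Assign a symbol to every named dimension, reusing the symbol when a name repeats.
/// Assumption: symbols are handed out in order starting from `*next`.
fn assign_dims(dims: &[OnnxDim], map: &mut HashMap<String, char>, next: &mut char) -> Vec<Dim> {
    dims.iter()
        .map(|d| match *d {
            OnnxDim::Value(n) => Dim::Fixed(n),
            OnnxDim::Param(name) => {
                let ch = *map.entry(name.to_string()).or_insert_with(|| {
                    let ch = *next;
                    // Advance the counter so the next new name gets a new symbol.
                    *next = (ch as u8 + 1) as char;
                    ch
                });
                Dim::Sym(ch)
            }
        })
        .collect()
}
```

Sharing one map across all inputs is what lets two inputs that both declare a `batch` dimension resolve to the same symbol, which in turn is what the later `cx.set_dim` loop relies on when it looks a symbol up by value.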
187
crates/luminal_python/rust/src/ops_parse/binary.rs
Normal file
@@ -0,0 +1,187 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::{
|
||||
prelude::{tracing::trace, *},
|
||||
shape::Expression,
|
||||
};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::{broadcast_to_expr, compute_broadcast_shape_expr};
|
||||
|
||||
/// Handle Where node: conditional select — output[i] = condition[i] ? x[i] : y[i]
|
||||
///
|
||||
/// ONNX Where uses numpy-style broadcasting across all three inputs.
|
||||
pub fn parse_where_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
assert!(node.input.len() == 3, "Where should have 3 inputs");
|
||||
let condition = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Where: missing condition tensor '{}'", node.input[0]))?;
|
||||
let x = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("Where: missing X tensor '{}'", node.input[1]))?;
|
||||
let y = *tensors
|
||||
.get(&node.input[2])
|
||||
.ok_or_else(|| format!("Where: missing Y tensor '{}'", node.input[2]))?;
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// ONNX Where broadcasts all 3 inputs to a common shape
|
||||
let bc_shape = compute_broadcast_shape_expr(
|
||||
&condition.dims(),
|
||||
&compute_broadcast_shape_expr(&x.dims(), &y.dims()),
|
||||
);
|
||||
let condition = broadcast_to_expr(condition, &bc_shape);
|
||||
let x = broadcast_to_expr(x, &bc_shape);
|
||||
let y = broadcast_to_expr(y, &bc_shape);
|
||||
|
||||
let result = x.cond(condition, y);
|
||||
tensors.insert(output_name.clone(), result);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn parse_binary_broadcast_op(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
op_name: &str,
|
||||
op: impl Fn(GraphTensor, GraphTensor) -> GraphTensor,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
known_values: &HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: {} Node", op_name);
|
||||
assert!(
|
||||
node.input.len() == 2,
|
||||
"{} should have 2 inputs, got {}",
|
||||
op_name,
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"{} should have 1 output, got {}",
|
||||
op_name,
|
||||
node.output.len()
|
||||
);
|
||||
// Shape-only path: if any input is shape-only (not in tensors), do Expression arithmetic
|
||||
let a_missing = !tensors.contains_key(&node.input[0]);
|
||||
let b_missing = !tensors.contains_key(&node.input[1]);
|
||||
if a_missing || b_missing {
|
||||
// At least one input is shape-only. Do shape_exprs arithmetic and return.
|
||||
let se_a = shape_exprs.get(&node.input[0]).cloned().or_else(|| {
|
||||
known_values
|
||||
.get(&node.input[0])
|
||||
.map(|kv| kv.iter().map(|&v| Expression::from(v as usize)).collect())
|
||||
});
|
||||
let se_b = shape_exprs.get(&node.input[1]).cloned().or_else(|| {
|
||||
known_values
|
||||
.get(&node.input[1])
|
||||
.map(|kv| kv.iter().map(|&v| Expression::from(v as usize)).collect())
|
||||
});
|
||||
if let (Some(se_a), Some(se_b)) = (se_a, se_b)
|
||||
&& se_a.len() == 1
|
||||
&& se_b.len() == 1
|
||||
{
|
||||
let result_expr = match op_name {
|
||||
"Add" => Some(se_a[0] + se_b[0]),
|
||||
"Sub" => Some(se_a[0] - se_b[0]),
|
||||
"Mul" => Some(se_a[0] * se_b[0]),
|
||||
"Div" => Some(se_a[0] / se_b[0]),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(expr) = result_expr {
|
||||
shape_exprs.insert(node.output[0].clone(), vec![expr]);
|
||||
}
|
||||
}
|
||||
trace!("Finished parse: {} Node (shape-only)", op_name);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("{}: missing input '{}'", op_name, node.input[0]))?;
|
||||
let b = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("{}: missing input '{}'", op_name, node.input[1]))?;
|
||||
let broadcast_shape = compute_broadcast_shape_expr(&a.dims(), &b.dims());
|
||||
let a_bc = broadcast_to_expr(a, &broadcast_shape);
|
||||
let b_bc = broadcast_to_expr(b, &broadcast_shape);
|
||||
let result = op(a_bc, b_bc);
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
|
||||
// Propagate shape_exprs for scalar shape arithmetic (e.g., Add(1, seq_len))
|
||||
// At least one input must be in shape_exprs; the other can come from known_values.
|
||||
let has_shape_expr =
|
||||
shape_exprs.contains_key(&node.input[0]) || shape_exprs.contains_key(&node.input[1]);
|
||||
if has_shape_expr {
|
||||
let se_a = shape_exprs.get(&node.input[0]).cloned().or_else(|| {
|
||||
known_values
|
||||
.get(&node.input[0])
|
||||
.map(|kv| kv.iter().map(|&v| Expression::from(v as usize)).collect())
|
||||
});
|
||||
let se_b = shape_exprs.get(&node.input[1]).cloned().or_else(|| {
|
||||
known_values
|
||||
.get(&node.input[1])
|
||||
.map(|kv| kv.iter().map(|&v| Expression::from(v as usize)).collect())
|
||||
});
|
||||
if let (Some(se_a), Some(se_b)) = (se_a, se_b)
|
||||
&& se_a.len() == 1
|
||||
&& se_b.len() == 1
|
||||
{
|
||||
let result_expr = match op_name {
|
||||
"Add" => Some(se_a[0] + se_b[0]),
|
||||
"Sub" => Some(se_a[0] - se_b[0]),
|
||||
"Mul" => Some(se_a[0] * se_b[0]),
|
||||
"Div" => Some(se_a[0] / se_b[0]),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(expr) = result_expr {
|
||||
shape_exprs.insert(node.output[0].clone(), vec![expr]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trace!("Finished parse: {} Node", op_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn parse_variadic_broadcast_op(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
op_name: &str,
|
||||
op: impl Fn(GraphTensor, GraphTensor) -> GraphTensor,
|
||||
_shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
_known_values: &HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: {} Node", op_name);
|
||||
assert!(
|
||||
node.input.len() >= 2,
|
||||
"{} needs at least two inputs, got {}",
|
||||
op_name,
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"{} nodes only have one output, got {}",
|
||||
op_name,
|
||||
node.output.len()
|
||||
);
|
||||
|
||||
let mut result = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("{}: missing input tensor '{}'", op_name, node.input[0]))?;
|
||||
|
||||
for input_name in &node.input[1..] {
|
||||
let rhs = *tensors
|
||||
.get(input_name)
|
||||
.ok_or_else(|| format!("{}: missing input tensor '{}'", op_name, input_name))?;
|
||||
let broadcast_shape = compute_broadcast_shape_expr(&result.dims(), &rhs.dims());
|
||||
let lhs_bc = broadcast_to_expr(result, &broadcast_shape);
|
||||
let rhs_bc = broadcast_to_expr(rhs, &broadcast_shape);
|
||||
result = op(lhs_bc, rhs_bc);
|
||||
}
|
||||
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: {} Node", op_name);
|
||||
Ok(())
|
||||
}
|
||||
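The handlers in `binary.rs` all lean on `compute_broadcast_shape_expr` and `broadcast_to_expr` from `crate::util`, which are not shown in this diff. As a rough illustration of the NumPy-style rule they are described as following, here is a sketch over concrete `usize` sizes rather than luminal `Expression`s. The names `broadcast_shape` and `broadcast_shape3` and the `Option` return type are assumptions made for this example.

```rust
/// Hypothetical helper: NumPy-style broadcast of two concrete shapes.
/// Returns `None` when a pair of dimensions is incompatible.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let rank = a.len().max(b.len());
    let mut out = Vec::with_capacity(rank);
    for i in 0..rank {
        // Shapes are aligned on the trailing dimension; missing leading dims count as 1.
        let da = if i < rank - a.len() { 1 } else { a[i - (rank - a.len())] };
        let db = if i < rank - b.len() { 1 } else { b[i - (rank - b.len())] };
        let d = match (da, db) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return None,
        };
        out.push(d);
    }
    Some(out)
}

/// Three-way broadcast, nested the same way `parse_where_node` nests its two calls.
fn broadcast_shape3(c: &[usize], x: &[usize], y: &[usize]) -> Option<Vec<usize>> {
    broadcast_shape(c, &broadcast_shape(x, y)?)
}
```

For example, `broadcast_shape(&[2, 1, 4], &[3, 1])` yields `[2, 3, 4]`, while `[2, 3]` against `[4, 3]` is rejected. Whether the real helper reports incompatible shapes or panics is not visible here.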
194
crates/luminal_python/rust/src/ops_parse/convolution.rs
Normal file
@@ -0,0 +1,194 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::{
|
||||
prelude::{tracing::trace, *},
|
||||
shape::Expression,
|
||||
};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::get_int_attr;
|
||||
|
||||
/// Get an integer-list attribute from a node, with a default value applied per element.
|
||||
fn get_ints_attr(node: &NodeProto, name: &str, default_elem: i64, spatial: usize) -> Vec<usize> {
|
||||
for attr in &node.attribute {
|
||||
if attr.name == name {
|
||||
return attr.ints.iter().map(|&v| v as usize).collect();
|
||||
}
|
||||
}
|
||||
vec![default_elem as usize; spatial]
|
||||
}
|
||||
|
||||
/// Parse an ONNX Conv node.
|
||||
///
|
||||
/// Supports N-dimensional convolution (1D, 2D, 3D) with group=1.
|
||||
/// Uses the unfold-based approach from `luminal_nn::ConvND`.
|
||||
///
|
||||
/// Input layout: [batch, C_in, spatial...]
|
||||
/// Weight layout: [C_out, C_in/group, kernel...]
|
||||
/// Optional bias: [C_out]
|
||||
pub fn parse_conv_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Conv Node");
|
||||
|
||||
assert!(
|
||||
node.input.len() >= 2,
|
||||
"Conv needs at least 2 inputs (X, W), got {}",
|
||||
node.input.len()
|
||||
);
|
||||
|
||||
let x = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Conv: missing input X '{}'", node.input[0]))?;
|
||||
let w = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("Conv: missing weight W '{}'", node.input[1]))?;
|
||||
let bias = if node.input.len() > 2 && !node.input[2].is_empty() {
|
||||
Some(
|
||||
*tensors
|
||||
.get(&node.input[2])
|
||||
.ok_or_else(|| format!("Conv: missing bias B '{}'", node.input[2]))?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let x_dims = x.dims();
|
||||
let w_dims = w.dims();
|
||||
let rank = x_dims.len();
|
||||
assert!(
|
||||
rank >= 3,
|
||||
"Conv: input must be at least 3D (batch, channels, spatial...), got {rank}D"
|
||||
);
|
||||
|
||||
let spatial = rank - 2; // number of spatial dimensions
|
||||
|
||||
// Parse attributes
|
||||
let kernel_shape = get_ints_attr(node, "kernel_shape", 1, spatial);
|
||||
let strides = get_ints_attr(node, "strides", 1, spatial);
|
||||
let dilations = get_ints_attr(node, "dilations", 1, spatial);
|
||||
let group = get_int_attr(node, "group", 1) as usize;
|
||||
|
||||
// Parse pads: ONNX format is [begin_0, begin_1, ..., end_0, end_1, ...]
|
||||
let pads_flat = get_ints_attr(node, "pads", 0, 2 * spatial);
|
||||
let mut pads_begin = vec![0usize; spatial];
|
||||
let mut pads_end = vec![0usize; spatial];
|
||||
if pads_flat.len() == 2 * spatial {
|
||||
pads_begin[..spatial].copy_from_slice(&pads_flat[..spatial]);
|
||||
pads_end[..spatial].copy_from_slice(&pads_flat[spatial..(spatial + spatial)]);
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
group, 1,
|
||||
"Conv: only group=1 is currently supported, got {group}"
|
||||
);
|
||||
|
||||
// Get channel dimensions
|
||||
let ch_out = w_dims[0]
|
||||
.to_usize()
|
||||
.ok_or("Conv: weight C_out must be concrete")?;
|
||||
let ch_in = x_dims[1]
|
||||
.to_usize()
|
||||
.ok_or("Conv: input C_in must be concrete")?;
|
||||
|
||||
let kernel_product: usize = kernel_shape.iter().product();
|
||||
|
||||
// Reshape weight from ONNX [C_out, C_in, *kernel] to [C_out, C_in * kernel_product]
|
||||
let w_reshaped = {
|
||||
let mut wt = w;
|
||||
wt.shape = ShapeTracker::new(vec![ch_out, ch_in * kernel_product]);
|
||||
wt
|
||||
};
|
||||
|
||||
// Pad spatial dimensions
|
||||
let mut padding: Vec<(Expression, Expression)> = vec![(0.into(), 0.into()); rank];
|
||||
for i in 0..spatial {
|
||||
let axis = 2 + i; // batch=0, channel=1, spatial starts at 2
|
||||
padding[axis] = (
|
||||
Expression::from(pads_begin[i]),
|
||||
Expression::from(pads_end[i]),
|
||||
);
|
||||
}
|
||||
let padded = x.pad(padding, 0.0);
|
||||
|
||||
// Build unfold parameters (ones for batch/channel, actual for spatial)
|
||||
let mut kernel_full = vec![1usize; rank];
|
||||
let mut stride_full = vec![1usize; rank];
|
||||
let mut dilation_full = vec![1usize; rank];
|
||||
for i in 0..spatial {
|
||||
let axis = 2 + i;
|
||||
kernel_full[axis] = kernel_shape[i];
|
||||
stride_full[axis] = strides[i];
|
||||
dilation_full[axis] = dilations[i];
|
||||
}
|
||||
|
||||
let unfolded = padded.unfold(kernel_full, stride_full, dilation_full);
|
||||
// unfolded shape: [win_N, win_C, win_spatial..., k_batch=1, k_chan=1, k_spatial...]
|
||||
// (2*rank dimensions total)
|
||||
|
||||
// Step 1: Permute to [N, win_spatial..., C_in, k_batch, k_chan, k_spatial...]
|
||||
// This groups: batch | output spatial | channel+kernel (for merging)
|
||||
let mut perm: Vec<usize> = Vec::with_capacity(2 * rank);
|
||||
perm.push(0); // win_N (batch)
|
||||
perm.extend(2..2 + spatial); // win_spatial dims
|
||||
perm.push(1); // win_C (= C_in)
|
||||
perm.extend(rank..2 * rank); // all kernel dims: k_batch=1, k_chan=1, k_spatial...
|
||||
let permuted = unfolded.permute(perm);
|
||||
|
||||
// Step 2: Capture output spatial dimensions (win_spatial sizes)
|
||||
let output_spatial_dims: Vec<Expression> = permuted.dims()[1..1 + spatial].to_vec();
|
||||
|
||||
// Step 3: Merge all channel+kernel dims into one (C_in * kernel_product)
|
||||
// From index (1+spatial) to end there are (1 + 2 + spatial) dims to merge
|
||||
let mut patches = permuted;
|
||||
let target_before_spatial_merge = 2 + spatial; // [N, spatial..., merged_patch]
|
||||
while patches.dims().len() > target_before_spatial_merge {
|
||||
let last = patches.dims().len();
|
||||
patches = patches.merge_dims(last - 2, last - 1);
|
||||
}
|
||||
// patches: [N, spatial_0, ..., spatial_{s-1}, C_in * kernel_product]
|
||||
|
||||
// Step 4: Merge spatial dims into one
|
||||
for _ in 1..spatial {
|
||||
patches = patches.merge_dims(1, 2);
|
||||
}
|
||||
// patches: [N, spatial_product, C_in * kernel_product]
|
||||
|
||||
// Step 5: Matmul with weight
|
||||
let mut out = patches.matmul(w_reshaped.permute((1, 0)));
|
||||
// out: [N, spatial_product, C_out]
|
||||
|
||||
// Step 6: Restore spatial dimensions via split_dims
|
||||
// Split from innermost spatial dim first (reverse order, skip outermost)
|
||||
for i in (1..spatial).rev() {
|
||||
out = out.split_dims(1, output_spatial_dims[i]);
|
||||
}
|
||||
// out: [N, spatial_0, spatial_1, ..., spatial_{s-1}, C_out]
|
||||
|
||||
// Step 7: Move C_out from last position to position 1 (after batch)
|
||||
let mut final_order: Vec<usize> = Vec::with_capacity(2 + spatial);
|
||||
final_order.push(0); // batch
|
||||
final_order.push(1 + spatial); // C_out
|
||||
final_order.extend(1..1 + spatial); // spatial dims
|
||||
out = out.permute(final_order);
|
||||
// out: [N, C_out, spatial_0, ..., spatial_{s-1}]
|
||||
|
||||
// Add bias if present: bias shape [C_out], broadcast to [1, C_out, 1, 1, ...]
|
||||
if let Some(b) = bias {
|
||||
let mut bias_expanded = b;
|
||||
// Expand to [1, C_out, 1, 1, ...]
|
||||
bias_expanded = bias_expanded.expand_dim(0, 1); // batch dim
|
||||
for i in 0..spatial {
|
||||
let out_dims = out.dims();
|
||||
let spatial_size = out_dims[2 + i];
|
||||
bias_expanded = bias_expanded.expand_dim(2 + i, spatial_size);
|
||||
}
|
||||
out += bias_expanded;
|
||||
}
|
||||
|
||||
tensors.insert(node.output[0].clone(), out);
|
||||
|
||||
trace!("Finished parse: Conv Node");
|
||||
Ok(())
|
||||
}
|
||||
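`parse_conv_node` leaves the output spatial sizes to `unfold`, so the arithmetic never appears explicitly. The sketch below spells out the standard convolution output-length formula and the ONNX `pads` layout (`[begin_0, ..., end_0, ...]`) on plain integers. `split_pads` and `conv_out_len` are hypothetical names, and unlike the source (which silently keeps zero padding when `pads` has an unexpected length) this sketch returns `None` for malformed input.

```rust
/// Hypothetical helper: split ONNX `pads` ([begin_0.., end_0..]) into per-axis (begin, end) pairs.
fn split_pads(pads: &[usize], spatial: usize) -> Option<Vec<(usize, usize)>> {
    if pads.len() != 2 * spatial {
        return None;
    }
    Some((0..spatial).map(|i| (pads[i], pads[spatial + i])).collect())
}

/// Output length of one spatial axis:
/// floor((in + pad_begin + pad_end - dilation * (kernel - 1) - 1) / stride) + 1
fn conv_out_len(
    input: usize,
    kernel: usize,
    stride: usize,
    dilation: usize,
    pad: (usize, usize),
) -> Option<usize> {
    if stride == 0 || kernel == 0 {
        return None;
    }
    let padded = input + pad.0 + pad.1;
    // Extent covered by a dilated kernel along this axis.
    let effective_kernel = dilation * (kernel - 1) + 1;
    if padded < effective_kernel {
        return None;
    }
    Some((padded - effective_kernel) / stride + 1)
}
```

With these sizes in hand, the reshaping steps in the function read as: gather `C_in * kernel_product` values per output position, multiply by the `[C_out, C_in * kernel_product]` weight, then restore the `[N, C_out, spatial...]` layout.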
70
crates/luminal_python/rust/src/ops_parse/matmul.rs
Normal file
@@ -0,0 +1,70 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::prelude::{tracing::trace, *};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::{broadcast_to_expr, get_float_attr, get_int_attr};
|
||||
|
||||
pub fn parse_matmul_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Started parse: MatMul Node");
|
||||
assert!(node.input.len() == 2, "MatMul should have exactly 2 inputs");
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("MatMul: missing input tensor '{}'", node.input[0]))?;
|
||||
let b = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("MatMul: missing input tensor '{}'", node.input[1]))?;
|
||||
|
||||
//TODO: enforce some kind of check here that they are broadcastable
|
||||
let result = a.matmul(b);
|
||||
let output_name = &node.output[0];
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: MatMul Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Gemm node: Y = alpha * (transA ? A.T : A) @ (transB ? B.T : B) + beta * C
|
||||
///
|
||||
/// Attributes: transA (default 0), transB (default 0), alpha (default 1.0), beta (default 1.0)
|
||||
/// Input C (bias) is optional.
|
||||
pub fn parse_gemm_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Started parse: Gemm Node");
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Gemm: missing input A '{}'", node.input[0]))?;
|
||||
let b = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("Gemm: missing input B '{}'", node.input[1]))?;
|
||||
|
||||
let trans_a = get_int_attr(node, "transA", 0) != 0;
|
||||
let trans_b = get_int_attr(node, "transB", 0) != 0;
|
||||
let alpha = get_float_attr(node, "alpha", 1.0);
|
||||
let beta = get_float_attr(node, "beta", 1.0);
|
||||
|
||||
let a_mat = if trans_a { a.permute(vec![1, 0]) } else { a };
|
||||
let b_mat = if trans_b { b.permute(vec![1, 0]) } else { b };
|
||||
|
||||
let mut result = a_mat.matmul(b_mat);
|
||||
if alpha != 1.0 {
|
||||
result *= alpha;
|
||||
}
|
||||
|
||||
if node.input.len() > 2 && !node.input[2].is_empty() {
|
||||
let c = *tensors
|
||||
.get(&node.input[2])
|
||||
.ok_or_else(|| format!("Gemm: missing bias C '{}'", node.input[2]))?;
|
||||
let c_scaled = if beta != 1.0 { c * beta } else { c };
|
||||
let result_shape = result.dims();
|
||||
result += broadcast_to_expr(c_scaled, &result_shape);
|
||||
}
|
||||
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: Gemm Node");
|
||||
Ok(())
|
||||
}
|
||||
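For checking `parse_gemm_node` against known numbers, a plain reference implementation of `Y = alpha * A' @ B' + beta * C` can help. The sketch below works on row-major `f32` slices and is not luminal code: `gemm_ref` is a hypothetical name, and it assumes the optional `C` is a length-`n` bias broadcast across rows, which is only one of the shapes ONNX permits.

```rust
/// Hypothetical reference Gemm on row-major buffers.
/// `a_shape` and `b_shape` are the stored (rows, cols) before any transpose.
/// `c`, if present, is assumed to be a length-n bias broadcast across rows.
fn gemm_ref(
    a: &[f32],
    a_shape: (usize, usize),
    trans_a: bool,
    b: &[f32],
    b_shape: (usize, usize),
    trans_b: bool,
    c: Option<&[f32]>,
    alpha: f32,
    beta: f32,
) -> Vec<f32> {
    // Logical shapes after the optional transposes.
    let (m, k) = if trans_a { (a_shape.1, a_shape.0) } else { a_shape };
    let (kb, n) = if trans_b { (b_shape.1, b_shape.0) } else { b_shape };
    assert_eq!(k, kb, "inner dimensions must agree");

    // Index into the stored buffers as if the transpose had been applied.
    let a_at = |i: usize, p: usize| {
        if trans_a { a[p * a_shape.1 + i] } else { a[i * a_shape.1 + p] }
    };
    let b_at = |p: usize, j: usize| {
        if trans_b { b[j * b_shape.1 + p] } else { b[p * b_shape.1 + j] }
    };

    let mut y = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            let dot: f32 = (0..k).map(|p| a_at(i, p) * b_at(p, j)).sum();
            let bias = c.map_or(0.0, |c| beta * c[j]);
            y[i * n + j] = alpha * dot + bias;
        }
    }
    y
}
```

Note that the graph version skips the multiply when `alpha` or `beta` equals 1.0 to avoid emitting extra ops; the reference above always multiplies, which gives the same result.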
15
crates/luminal_python/rust/src/ops_parse/mod.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
pub mod binary;
|
||||
pub mod convolution;
|
||||
pub mod matmul;
|
||||
pub mod movement;
|
||||
pub mod reduction;
|
||||
pub mod tensor;
|
||||
pub mod unary;
|
||||
|
||||
pub use binary::*;
|
||||
pub use convolution::*;
|
||||
pub use matmul::*;
|
||||
pub use movement::*;
|
||||
pub use reduction::*;
|
||||
pub use tensor::*;
|
||||
pub use unary::*;
|
||||
1787
crates/luminal_python/rust/src/ops_parse/movement.rs
Normal file
File diff suppressed because it is too large
172
crates/luminal_python/rust/src/ops_parse/reduction.rs
Normal file
@@ -0,0 +1,172 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::prelude::{tracing::trace, *};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::get_int_attr;
|
||||
|
||||
/// Handle TopK node: return the top-k values and indices along an axis.
|
||||
///
|
||||
/// output[0] = values (F32), output[1] = indices (Int, can be empty/unused).
|
||||
/// Both settings of `largest` (default true) take the same path: a full `argsort` along the axis,
|
||||
/// `gather_elements` with those indices, then `slice_along(..k)` on both values and indices.
|
||||
/// Indices output is stored as-is (Int dtype); downstream Cast handles F32 conversion.
|
||||
/// The "sorted" attribute is ignored — output is always sorted.
|
||||
pub fn parse_topk_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
let x = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("TopK: missing input '{}'", node.input[0]))?;
|
||||
let k = known_values
|
||||
.get(&node.input[1])
|
||||
.ok_or("TopK: k must be constant")?[0] as usize;
|
||||
|
||||
let rank = x.dims().len() as i64;
|
||||
let raw_axis = get_int_attr(node, "axis", -1);
|
||||
let axis = if raw_axis < 0 {
|
||||
(raw_axis + rank) as usize
|
||||
} else {
|
||||
raw_axis as usize
|
||||
};
|
||||
|
||||
let largest = get_int_attr(node, "largest", 1) != 0;
|
||||
|
||||
// Compute full argsort, then gather all sorted values, then slice both to top-k.
|
||||
// This avoids passing a non-contiguous sliced index tensor into gather_elements,
|
||||
// which triggers a CUDA kernel bug when data and index sizes differ along the axis.
|
||||
let full_argsort = x.argsort(axis, largest);
|
||||
let indices = full_argsort.slice_along(..k, axis);
|
||||
let values = x.gather_elements(full_argsort, axis).slice_along(..k, axis);
|
||||
|
||||
// ONNX output[0] = values, output[1] = indices
|
||||
if !node.output[0].is_empty() {
|
||||
tensors.insert(node.output[0].clone(), values);
|
||||
}
|
||||
if node.output.len() > 1 && !node.output[1].is_empty() {
|
||||
// Force materialization of Int indices; downstream Cast(INT64→FLOAT) handles the
|
||||
// F32 conversion via the *1.0 workaround in parse_cast_node.
|
||||
tensors.insert(node.output[1].clone(), indices * 1.0);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn parse_reduce_op(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
op_name: &str,
|
||||
reduce_op: impl Fn(GraphTensor, Vec<usize>) -> GraphTensor,
|
||||
all_axes_op: impl Fn(GraphTensor, usize) -> GraphTensor,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: {} Node", op_name);
|
||||
assert!(
|
||||
!node.input.is_empty(),
|
||||
"{} should have at least 1 input",
|
||||
op_name
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"{} should have exactly 1 output",
|
||||
op_name
|
||||
);
|
||||
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("{}: missing input tensor '{}'", op_name, node.input[0]))?;
|
||||
|
||||
let keepdims = get_int_attr(node, "keepdims", 1) != 0;
|
||||
let noop_with_empty_axes = get_int_attr(node, "noop_with_empty_axes", 0) != 0;
|
||||
|
||||
let ndim = input.dims().len();
|
||||
|
||||
// Resolve axes from the optional second input (ReduceSum since opset 13, other Reduce ops since opset 18) or from the legacy `axes` attribute
|
||||
let raw_axes: Vec<i64> = if node.input.len() > 1 && !node.input[1].is_empty() {
|
||||
let axes_vals = known_values.get(&node.input[1]).ok_or_else(|| {
|
||||
format!(
|
||||
"{}: axes input '{}' must be a known constant",
|
||||
op_name, node.input[1]
|
||||
)
|
||||
})?;
|
||||
axes_vals.iter().map(|&v| v as i64).collect()
|
||||
} else if let Some(attr) = node.attribute.iter().find(|a| a.name == "axes") {
|
||||
attr.ints.clone()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Handle empty axes: noop or reduce all
|
||||
let raw_axes: Vec<i64> = if raw_axes.is_empty() {
|
||||
if noop_with_empty_axes {
|
||||
tensors.insert(output_name.clone(), input);
|
||||
trace!("Finished parse: {} Node (noop)", op_name);
|
||||
return Ok(());
|
||||
} else {
|
||||
(0..ndim as i64).collect()
|
||||
}
|
||||
} else {
|
||||
raw_axes
|
||||
};
|
||||
|
||||
// Normalize negative axes and convert to usize
|
||||
let mut normalized_axes: Vec<usize> = raw_axes
|
||||
.iter()
|
||||
.map(|&a| {
|
||||
if a < 0 {
|
||||
(ndim as i64 + a) as usize
|
||||
} else {
|
||||
a as usize
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
normalized_axes.sort();
|
||||
normalized_axes.dedup();
|
||||
|
||||
// Save original sorted axes for keepdims unsqueeze bookkeeping
|
||||
let sorted_axes = normalized_axes.clone();
|
||||
|
||||
let input_dims = input.dims();
|
||||
|
||||
if normalized_axes.len() == ndim {
|
||||
// All-axes reduction: flatten to [1, N] and reduce axis 1 → [1].
|
||||
// luminal's Expression::product() returns 0 for empty iterators, so a reduce
|
||||
// producing a 0-dim tensor causes CUDA to launch with grid (0,1,1), which is
|
||||
// invalid. Using [1, N] → reduce(1) → [1] avoids this entirely.
|
||||
let total: usize = input_dims
|
||||
.iter()
|
||||
.map(|d| d.to_usize().expect("reduce: dim must be concrete"))
|
||||
.product();
|
||||
let mut flat = input;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
let mut result = all_axes_op(flat, total);
|
||||
|
||||
if keepdims {
|
||||
// Insert (ndim-1) additional size-1 dims to produce [1]*ndim
|
||||
for i in 1..ndim {
|
||||
result = result.unsqueeze(i);
|
||||
}
|
||||
}
|
||||
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: {} Node (all-axes)", op_name);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Partial reduction: luminal's ToAxes API handles axis shifting internally
|
||||
let mut result = reduce_op(input, normalized_axes);
|
||||
|
||||
// Re-insert size-1 dims at original positions (ascending order keeps positions correct)
|
||||
if keepdims {
|
||||
for &axis in &sorted_axes {
|
||||
result = result.unsqueeze(axis);
|
||||
}
|
||||
}
|
||||
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: {} Node", op_name);
|
||||
Ok(())
|
||||
}
|
||||
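The axis bookkeeping in `parse_reduce_op` (negative-axis normalization, sort and dedup, then re-inserting size-1 dims for `keepdims`) can be checked in isolation on plain shapes. The sketch below is illustrative only: `normalize_axes` and `reduced_shape` are hypothetical names, an empty axis list is treated as "reduce all" (the non-noop branch), and out-of-range axes return `None`, a check the source does not perform.

```rust
/// Hypothetical helper: normalize ONNX axes (possibly negative) to sorted, unique indices.
/// An empty list means "all axes", matching the non-noop branch of the parser.
fn normalize_axes(raw: &[i64], ndim: usize) -> Option<Vec<usize>> {
    if raw.is_empty() {
        return Some((0..ndim).collect());
    }
    let mut axes = Vec::with_capacity(raw.len());
    for &a in raw {
        let adj = if a < 0 { a + ndim as i64 } else { a };
        if adj < 0 || adj >= ndim as i64 {
            return None; // out of range for this rank
        }
        axes.push(adj as usize);
    }
    axes.sort_unstable();
    axes.dedup();
    Some(axes)
}

/// Shape after reducing `axes`: reduced dims become 1 with keepdims, or are dropped without.
fn reduced_shape(shape: &[usize], axes: &[usize], keepdims: bool) -> Vec<usize> {
    shape
        .iter()
        .enumerate()
        .filter_map(|(i, &d)| {
            if axes.contains(&i) {
                if keepdims { Some(1) } else { None }
            } else {
                Some(d)
            }
        })
        .collect()
}
```

Re-inserting the size-1 dims in ascending axis order, as the parser does, produces the same result as `reduced_shape(.., true)` because each `unsqueeze` at a lower index shifts the remaining dims into their original positions.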
453
crates/luminal_python/rust/src/ops_parse/tensor.rs
Normal file
@@ -0,0 +1,453 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::{
|
||||
prelude::{tracing::trace, *},
|
||||
shape::Expression,
|
||||
};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::{broadcast_to_expr, get_int_attr};
|
||||
|
||||
/// Handle Constant node: creates a tensor from embedded data in the node attributes.
|
||||
///
|
||||
/// Supports FLOAT, INT64, INT32, and FLOAT64 data types (all converted to f32).
|
||||
/// The resulting tensor is registered as a known constant for downstream folding.
|
||||
pub fn parse_constant_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Constant Node");
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Constant should have exactly one output"
|
||||
);
|
||||
|
||||
// Find the "value" attribute (type TENSOR)
|
||||
let value_attr = node
|
||||
.attribute
|
||||
.iter()
|
||||
.find(|a| a.name == "value")
|
||||
.ok_or_else(|| "Constant node missing 'value' attribute".to_string())?;
|
||||
|
||||
let tensor_proto = value_attr
|
||||
.t
|
||||
.as_ref()
|
||||
.ok_or_else(|| "Constant 'value' attribute has no TensorProto".to_string())?;
|
||||
|
||||
// Determine shape: empty dims = scalar = [1] for luminal
|
||||
let shape: Vec<usize> = if tensor_proto.dims.is_empty() {
|
||||
vec![1]
|
||||
} else {
|
||||
tensor_proto.dims.iter().map(|&d| d as usize).collect()
|
||||
};
|
||||
|
||||
// Extract float data based on data_type
|
||||
let floats: Vec<f32> = match tensor_proto.data_type {
|
||||
1 => {
|
||||
// FLOAT (f32)
|
||||
if !tensor_proto.float_data.is_empty() {
|
||||
tensor_proto.float_data.clone()
|
||||
} else {
|
||||
tensor_proto
|
||||
.raw_data
|
||||
.chunks_exact(4)
|
||||
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
6 => {
|
||||
// INT32
|
||||
if !tensor_proto.int32_data.is_empty() {
|
||||
tensor_proto.int32_data.iter().map(|&v| v as f32).collect()
|
||||
} else {
|
||||
tensor_proto
|
||||
.raw_data
|
||||
.chunks_exact(4)
|
||||
.map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
7 => {
|
||||
// INT64
|
||||
if !tensor_proto.int64_data.is_empty() {
|
||||
tensor_proto.int64_data.iter().map(|&v| v as f32).collect()
|
||||
} else {
|
||||
tensor_proto
|
||||
.raw_data
|
||||
.chunks_exact(8)
|
||||
.map(|c| {
|
||||
i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
dt => return Err(format!("Constant node: unsupported data_type {}", dt)),
|
||||
};
|
||||
|
||||
let output_name = &node.output[0];
|
||||
let tensor = cx.named_tensor(output_name.clone(), shape);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), floats.clone());
|
||||
// Also propagate as concrete shape_exprs for downstream shape computation chains
|
||||
shape_exprs.insert(
|
||||
output_name.clone(),
|
||||
floats
|
||||
.iter()
|
||||
.map(|&v| Expression::from(v as usize))
|
||||
.collect(),
|
||||
);
|
||||
weight_data.push((output_name.clone(), floats));
|
||||
|
||||
trace!("Finished parse: Constant Node");
|
||||
Ok(())
|
||||
}
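The `raw_data` branches above decode fixed-width little-endian values and widen them to f32. A minimal standalone sketch of that decoding follows, assuming the data_type codes matched in this file (1 = FLOAT, 6 = INT32, 7 = INT64). `decode_raw_data` is a hypothetical helper name used for illustration only, not a function of this crate.

```rust
/// Decode an ONNX-style `raw_data` byte buffer into f32 values.
/// Hypothetical helper; the data_type codes mirror the match arms above.
fn decode_raw_data(data_type: i32, raw: &[u8]) -> Result<Vec<f32>, String> {
    match data_type {
        // FLOAT: 4 bytes per element, already f32
        1 => Ok(raw
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect()),
        // INT32: 4 bytes per element, widened to f32
        6 => Ok(raw
            .chunks_exact(4)
            .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
            .collect()),
        // INT64: 8 bytes per element, converted to f32 (lossy for large values)
        7 => Ok(raw
            .chunks_exact(8)
            .map(|c| {
                let mut b = [0u8; 8];
                b.copy_from_slice(c);
                i64::from_le_bytes(b) as f32
            })
            .collect()),
        dt => Err(format!("unsupported data_type {}", dt)),
    }
}
```

Two properties are worth keeping in mind: `chunks_exact` silently ignores trailing bytes that do not fill a whole element, and an i64 above 2^24 in magnitude may not be exactly representable once converted to f32.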
|
||||
|
||||
/// Handle Shape node: extract the shape of the input tensor as a 1D constant.
|
||||
///
|
||||
/// For static shapes, stores as known_values. For dynamic shapes (containing
|
||||
/// Expression variables), stores in shape_exprs for downstream shape computation chains.
|
||||
pub fn parse_shape_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Shape Node");
|
||||
assert!(node.input.len() == 1, "Shape should have exactly 1 input");
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Shape: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
let all_dims = input.dims();
|
||||
|
||||
// Handle start/end attributes (ONNX Shape opset 15+: extract a slice of dims)
|
||||
let start = get_int_attr(node, "start", 0) as usize;
|
||||
let end_attr = get_int_attr(node, "end", all_dims.len() as i64);
|
||||
let end = if end_attr < 0 {
|
||||
(all_dims.len() as i64 + end_attr) as usize
|
||||
} else {
|
||||
(end_attr as usize).min(all_dims.len())
|
||||
};
|
||||
let dims: Vec<Expression> = all_dims[start..end].to_vec();
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Always store in shape_exprs (supports both concrete and symbolic dims)
|
||||
shape_exprs.insert(output_name.clone(), dims.clone());
|
||||
|
||||
// For concrete dims, also store in known_values for backward compat
|
||||
let all_concrete = dims.iter().all(|d| d.to_usize().is_some());
|
||||
let shape_values: Vec<f32> = dims
|
||||
.iter()
|
||||
.map(|d| d.to_usize().unwrap_or(1) as f32)
|
||||
.collect();
|
||||
|
||||
if all_concrete {
|
||||
// Concrete shape: create tensor + known_values + weight_data
|
||||
let tensor = cx.named_tensor(output_name.clone(), vec![shape_values.len()]);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), shape_values.clone());
|
||||
weight_data.push((output_name.clone(), shape_values));
|
||||
}
|
||||
// For symbolic shapes, don't create a tensor — it's shape-only
|
||||
|
||||
trace!("Finished parse: Shape Node");
|
||||
Ok(())
|
||||
}
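The `start`/`end` handling above normalizes a negative `end` and clamps it to the rank. A small sketch of the full bound resolution follows, assuming that negative values count from the back and that both bounds are clamped to `[0, rank]`. `shape_slice_bounds` is a hypothetical helper, and normalizing `start` the same way as `end` is an assumption that goes beyond what the handler does.

```rust
/// Resolve Shape `start`/`end` attributes into a half-open index range.
/// Hypothetical helper: negative values count from the back, both bounds
/// are clamped to [0, rank], and an inverted range collapses to empty.
fn shape_slice_bounds(rank: usize, start: i64, end: i64) -> (usize, usize) {
    let r = rank as i64;
    let norm = |v: i64| -> usize {
        let v = if v < 0 { v + r } else { v };
        v.clamp(0, r) as usize
    };
    let s = norm(start);
    let e = norm(end).max(s); // never let end fall below start
    (s, e)
}
```

With the bounds resolved this way, `all_dims[s..e]` cannot panic, because `s <= e <= rank` always holds.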
|
||||
|
||||
/// Handle ConstantOfShape node: creates a tensor of a given shape filled with a constant value.
|
||||
///
|
||||
/// The shape is taken from the input tensor (which must be a known constant).
|
||||
/// The fill value comes from the "value" attribute (default 0.0).
|
||||
pub fn parse_constant_of_shape(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: ConstantOfShape Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"ConstantOfShape should have exactly one input (shape)"
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"ConstantOfShape should have exactly one output"
|
||||
);
|
||||
|
||||
// Extract fill value from "value" attribute (TensorProto scalar), default 0.0
|
||||
let fill_value: f32 = node
|
||||
.attribute
|
||||
.iter()
|
||||
.find(|a| a.name == "value")
|
||||
.and_then(|attr| attr.t.as_ref())
|
||||
.map(|tp| {
|
||||
if !tp.float_data.is_empty() {
|
||||
tp.float_data[0]
|
||||
} else if !tp.int32_data.is_empty() {
|
||||
tp.int32_data[0] as f32
|
||||
} else if !tp.raw_data.is_empty() {
|
||||
match tp.data_type {
|
||||
1 => f32::from_le_bytes([
|
||||
tp.raw_data[0],
|
||||
tp.raw_data[1],
|
||||
tp.raw_data[2],
|
||||
tp.raw_data[3],
|
||||
]),
|
||||
6 => i32::from_le_bytes([
|
||||
tp.raw_data[0],
|
||||
tp.raw_data[1],
|
||||
tp.raw_data[2],
|
||||
tp.raw_data[3],
|
||||
]) as f32,
|
||||
7 => i64::from_le_bytes([
|
||||
tp.raw_data[0],
|
||||
tp.raw_data[1],
|
||||
tp.raw_data[2],
|
||||
tp.raw_data[3],
|
||||
tp.raw_data[4],
|
||||
tp.raw_data[5],
|
||||
tp.raw_data[6],
|
||||
tp.raw_data[7],
|
||||
]) as f32,
|
||||
_ => 0.0,
|
||||
}
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
})
|
||||
.unwrap_or(0.0);
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Try shape_exprs first (for dynamic shapes), then known_values
|
||||
if let Some(se) = shape_exprs.get(&node.input[0]) {
|
||||
let shape: Vec<Expression> = se.clone();
|
||||
|
||||
// Check if all dims are concrete
|
||||
if let Some(concrete) = shape
|
||||
.iter()
|
||||
.map(|e| e.to_usize())
|
||||
.collect::<Option<Vec<usize>>>()
|
||||
{
|
||||
// Fully concrete: create named tensor with weight data
|
||||
let numel: usize = concrete.iter().product();
|
||||
let floats: Vec<f32> = vec![fill_value; numel];
|
||||
let tensor = cx.named_tensor(output_name.clone(), concrete);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), floats.clone());
|
||||
weight_data.push((output_name.clone(), floats));
|
||||
} else {
|
||||
// Dynamic shape: create scalar constant and broadcast to symbolic shape.
|
||||
// The scalar always has concrete data (1 element), and the shape is
|
||||
// resolved at runtime via ShapeTracker/dyn_map. Broadcast uses stride-0
|
||||
// expansion, so only 1 float is needed in the backing buffer.
|
||||
let scalar = cx.constant_float(fill_value);
|
||||
let result = broadcast_to_expr(scalar, se);
|
||||
tensors.insert(output_name.clone(), result);
|
||||
}
|
||||
} else {
|
||||
let shape_values = known_values.get(&node.input[0]).ok_or_else(|| {
|
||||
format!(
|
||||
"ConstantOfShape: shape input '{}' must be a known constant or shape_expr",
|
||||
node.input[0]
|
||||
)
|
||||
})?;
|
||||
let shape: Vec<usize> = shape_values.iter().map(|&v| v as usize).collect();
|
||||
let numel: usize = shape.iter().product();
|
||||
let floats: Vec<f32> = vec![fill_value; numel];
|
||||
|
||||
let tensor = cx.named_tensor(output_name.clone(), shape);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), floats.clone());
|
||||
weight_data.push((output_name.clone(), floats));
|
||||
}
|
||||
|
||||
trace!("Finished parse: ConstantOfShape Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Identity node: output is a direct alias of the input tensor.
|
||||
///
|
||||
/// Propagates known constant values for downstream constant folding.
|
||||
pub fn parse_identity(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Identity Node");
|
||||
assert!(node.input.len() == 1, "Identity should only have one input");
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Identity: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Identity should only have a single output"
|
||||
);
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Force materialization using Expression-aware broadcast
|
||||
let dims = a.dims();
|
||||
let one = a.graph().constant_float(1.0);
|
||||
let one_expanded = broadcast_to_expr(one, &dims);
|
||||
let result = a * one_expanded;
|
||||
tensors.insert(output_name.clone(), result);
|
||||
|
||||
// Propagate known values
|
||||
if let Some(vals) = known_values.get(&node.input[0]).cloned() {
|
||||
known_values.insert(output_name.clone(), vals);
|
||||
}
|
||||
// Propagate shape_exprs
|
||||
if let Some(se) = shape_exprs.get(&node.input[0]).cloned() {
|
||||
shape_exprs.insert(output_name.clone(), se);
|
||||
}
|
||||
|
||||
trace!("Finished parse: Identity Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Range node: creates a 1D tensor [start, start+delta, start+2*delta, ...] up to limit.
|
||||
///
|
||||
/// Used by dynamo ONNX export for generating position indices (arange).
|
||||
/// Supports Expression-based limits for dynamic sequence lengths.
|
||||
pub fn parse_range_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Range Node");
|
||||
assert!(
|
||||
node.input.len() == 3,
|
||||
"Range needs 3 inputs: start, limit, delta"
|
||||
);
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Try to get concrete values from known_values first
|
||||
let start_val = known_values
|
||||
.get(&node.input[0])
|
||||
.and_then(|v| v.first().copied());
|
||||
let limit_val = known_values
|
||||
.get(&node.input[1])
|
||||
.and_then(|v| v.first().copied());
|
||||
let delta_val = known_values
|
||||
.get(&node.input[2])
|
||||
.and_then(|v| v.first().copied());
|
||||
|
||||
// Also check shape_exprs for symbolic limit
|
||||
let limit_expr = shape_exprs
|
||||
.get(&node.input[1])
|
||||
.and_then(|v| v.first().cloned());
|
||||
|
||||
let start = start_val.unwrap_or(0.0);
|
||||
let delta = delta_val.unwrap_or(1.0);
|
||||
|
||||
if start == 0.0 && delta == 1.0 {
|
||||
// Simple arange case — most common for position indices
|
||||
if let Some(expr) = limit_expr {
|
||||
// Dynamic limit: create arange with symbolic length
|
||||
let tensor = cx.arange(expr);
|
||||
// Cast to F32 (luminal arange returns Int dtype)
|
||||
let result = tensor.cast(DType::F32);
|
||||
tensors.insert(output_name.clone(), result);
|
||||
shape_exprs.insert(output_name.clone(), vec![expr]);
|
||||
} else if let Some(limit) = limit_val {
|
||||
let n = limit as usize;
|
||||
let floats: Vec<f32> = (0..n).map(|i| i as f32).collect();
|
||||
let tensor = cx.named_tensor(output_name.clone(), vec![n]);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), floats.clone());
|
||||
weight_data.push((output_name.clone(), floats));
|
||||
} else {
|
||||
return Err("Range: limit must be known or symbolic".to_string());
|
||||
}
|
||||
} else if let (Some(s), Some(l), Some(d)) = (start_val, limit_val, delta_val) {
|
||||
// Fully concrete range
|
||||
let mut floats = Vec::new();
|
||||
let mut v = s;
|
||||
while (d > 0.0 && v < l) || (d < 0.0 && v > l) {
|
||||
floats.push(v);
|
||||
v += d;
|
||||
}
|
||||
let tensor = cx.named_tensor(output_name.clone(), vec![floats.len()]);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), floats.clone());
|
||||
weight_data.push((output_name.clone(), floats));
|
||||
} else {
|
||||
return Err("Range: cannot handle non-trivial dynamic ranges yet".to_string());
|
||||
}
|
||||
|
||||
trace!("Finished parse: Range Node");
|
||||
Ok(())
|
||||
}
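The fully concrete branch above builds the range by repeatedly adding `delta`, which can accumulate rounding error for fractional steps. One possible alternative, sketched here under the assumption that the element count is `max(ceil((limit - start) / delta), 0)`, computes each element from its index instead. `concrete_range` is a hypothetical helper name.

```rust
/// Generate a concrete Range as f32 values.
/// Hypothetical helper: element i is start + i * delta, and the element
/// count is max(ceil((limit - start) / delta), 0).
fn concrete_range(start: f32, limit: f32, delta: f32) -> Vec<f32> {
    if delta == 0.0 {
        return Vec::new(); // a zero step never reaches the limit
    }
    let n = ((limit - start) / delta).ceil();
    if !n.is_finite() || n <= 0.0 {
        return Vec::new(); // wrong direction, or non-finite inputs
    }
    (0..n as usize).map(|i| start + i as f32 * delta).collect()
}
```

For the common position-index case (`start = 0`, `delta = 1`) this reduces to `0, 1, ..., limit - 1`, the same values the simple arange branch produces.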
|
||||
|
||||
/// Handle CumSum node: cumulative sum along an axis.
|
||||
///
|
||||
/// For the simple case of axis=0 on a 1D tensor [0, 1, 2, ...] (position indices),
|
||||
/// the cumsum is equivalent to [0, 1, 3, 6, ...]. For dynamic ONNX graphs,
|
||||
/// this is typically used for position_ids computation.
|
||||
pub fn parse_cumsum_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: CumSum Node");
|
||||
assert!(node.input.len() >= 2, "CumSum needs at least 2 inputs");
|
||||
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("CumSum: missing input '{}'", node.input[0]))?;
|
||||
|
||||
let axis_val = known_values
|
||||
.get(&node.input[1])
|
||||
.and_then(|v| v.first().copied())
|
||||
.unwrap_or(0.0) as i64;
|
||||
|
||||
let dims = input.dims();
|
||||
let ndim = dims.len();
|
||||
let _axis = if axis_val < 0 {
|
||||
(ndim as i64 + axis_val) as usize
|
||||
} else {
|
||||
axis_val as usize
|
||||
};
|
||||
|
||||
// For constant folding
|
||||
if let Some(vals) = known_values.get(&node.input[0]).cloned() {
|
||||
let output_name = &node.output[0];
|
||||
let mut cumsum = vals.clone();
|
||||
// Simple 1D cumsum
|
||||
if ndim == 1 {
|
||||
for i in 1..cumsum.len() {
|
||||
cumsum[i] += cumsum[i - 1];
|
||||
}
|
||||
}
|
||||
known_values.insert(output_name.clone(), cumsum);
|
||||
// Just alias the tensor (same shape)
|
||||
tensors.insert(output_name.clone(), input);
|
||||
trace!("Finished parse: CumSum Node (constant folded)");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// For dynamic: cumsum is hard to express in luminal primitives.
|
||||
// For the specific pattern used in Llama position_ids (cumsum of ones = arange),
|
||||
// we just pass through since arange is already handled by Range node.
|
||||
let output_name = &node.output[0];
|
||||
tensors.insert(output_name.clone(), input);
|
||||
|
||||
trace!("Finished parse: CumSum Node");
|
||||
Ok(())
|
||||
}
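The constant-folding path above computes an inclusive prefix sum for 1D inputs. A minimal standalone sketch of that computation, with `cumsum_1d` as a hypothetical helper name:

```rust
/// Inclusive prefix sum over a 1D slice: out[i] = vals[0] + ... + vals[i].
fn cumsum_1d(vals: &[f32]) -> Vec<f32> {
    let mut acc = 0.0_f32;
    vals.iter()
        .map(|&v| {
            acc += v;
            acc
        })
        .collect()
}
```

Note that the cumulative sum of a vector of ones is `[1, 2, 3, ...]`, which is offset by one from a zero-based arange. That difference may matter for the dynamic pass-through path, which returns the input unchanged.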
|
||||
440
crates/luminal_python/rust/src/ops_parse/unary.rs
Normal file
@@ -0,0 +1,440 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::{
|
||||
prelude::{tracing::trace, *},
|
||||
shape::Expression,
|
||||
};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::{broadcast_to_expr, get_float_attr, get_int_attr};
|
||||
|
||||
/// Handle Softmax node: output = softmax(input[0], axis)
|
||||
///
|
||||
/// ONNX axis attribute defaults to -1 (last dimension, opset 13+).
|
||||
/// Negative axis is normalized against the input rank.
|
||||
pub fn parse_softmax_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Softmax Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"Softmax nodes need to have one input, {} were present",
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Softmax nodes only have one output, {} were present",
|
||||
node.output.len(),
|
||||
);
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Softmax: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
let ndim = a.dims().len();
|
||||
let raw_axis = get_int_attr(node, "axis", -1);
|
||||
let axis = if raw_axis < 0 {
|
||||
(ndim as i64 + raw_axis) as usize
|
||||
} else {
|
||||
raw_axis as usize
|
||||
};
|
||||
|
||||
let result = a.softmax(axis);
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Softmax Node");
|
||||
Ok(())
|
||||
}
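Negative-axis normalization recurs in several handlers in this file (Softmax, CumSum, LayerNormalization). A sketch of a shared, bounds-checked version follows. `normalize_axis` is a hypothetical helper, and returning an error for out-of-range axes is a design assumption rather than the behavior of the code above, which casts without checking.

```rust
/// Map a possibly negative axis onto 0..rank.
/// Hypothetical helper; returns an error instead of wrapping on bad input.
fn normalize_axis(axis: i64, rank: usize) -> Result<usize, String> {
    let r = rank as i64;
    let a = if axis < 0 { axis + r } else { axis };
    if (0..r).contains(&a) {
        Ok(a as usize)
    } else {
        Err(format!("axis {} out of range for rank {}", axis, rank))
    }
}
```

Checking the range before the `as usize` cast avoids the case where a value such as `-4` on a rank-3 tensor stays negative after the adjustment and wraps to a very large index.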
|
||||
|
||||
/// Handle Not node: logical NOT — output = 1.0 - input[0]
|
||||
pub fn parse_not_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Not Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"Not nodes need to have one input, {} were present",
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Not nodes only have one output, {} were present",
|
||||
node.output.len()
|
||||
);
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Not: missing input tensor '{}'", node.input[0]))?;
|
||||
let a_f32 = a.cast(DType::F32);
|
||||
let result = 1.0_f32 - a_f32;
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: Not Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Clip node: output = clip(input[0], min, max)
|
||||
///
|
||||
/// Equivalent to torch.clamp. min and max are optional tensor inputs
|
||||
/// (typically constants) residing in known_values.
|
||||
pub fn parse_clip_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Clip Node");
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Clip: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
// input[1] = min (optional), input[2] = max (optional)
|
||||
let min_name = node.input.get(1).map(String::as_str).unwrap_or("");
|
||||
let max_name = node.input.get(2).map(String::as_str).unwrap_or("");
|
||||
|
||||
let min_val = if min_name.is_empty() {
|
||||
None
|
||||
} else {
|
||||
known_values.get(min_name).map(|v| v[0])
|
||||
};
|
||||
let max_val = if max_name.is_empty() {
|
||||
None
|
||||
} else {
|
||||
known_values.get(max_name).map(|v| v[0])
|
||||
};
|
||||
|
||||
let result = match (min_val, max_val) {
|
||||
(Some(lo), Some(hi)) => a.clip(lo, hi),
|
||||
(Some(lo), None) => a.maximum_f32(lo),
|
||||
(None, Some(hi)) => a.minimum_f32(hi),
|
||||
(None, None) => a,
|
||||
};
|
||||
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Clip Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Floor node: output = floor(input[0])
|
||||
///
|
||||
/// Implemented as: trunc(x) - (x < trunc(x) ? 1 : 0)
|
||||
/// where trunc is truncation toward zero via cast to Int then back to F32.
|
||||
/// This correctly handles negative non-integer values (e.g. floor(-1.5) = -2).
|
||||
pub fn parse_floor_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Floor Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"Floor nodes need to have one input, {} were present",
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Floor nodes only have one output, {} were present",
|
||||
node.output.len(),
|
||||
);
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Floor: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
// trunc(x): truncation toward zero
|
||||
let trunc = a.cast(DType::Int).cast(DType::F32);
|
||||
// For negative non-integers, x < trunc(x), so subtract 1
|
||||
// Cast lt result (Bool) to F32 before arithmetic
|
||||
let adjustment = a.lt(trunc).cast(DType::F32);
|
||||
let result = trunc - adjustment;
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Floor Node");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Ceil node: output = ceil(input[0])
|
||||
///
|
||||
/// Implemented as: trunc(x) + (x > trunc(x) ? 1 : 0)
|
||||
/// where trunc is truncation toward zero via cast to Int then back to F32.
|
||||
/// This correctly handles positive non-integer values (e.g. ceil(1.5) = 2).
|
||||
pub fn parse_ceil_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Ceil Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"Ceil nodes need to have one input, {} were present",
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Ceil nodes only have one output, {} were present",
|
||||
node.output.len(),
|
||||
);
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Ceil: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
// trunc(x): truncation toward zero
|
||||
let trunc = a.cast(DType::Int).cast(DType::F32);
|
||||
// For positive non-integers, x > trunc(x), so add 1
|
||||
let adjustment = a.gt(trunc).cast(DType::F32);
|
||||
let result = trunc + adjustment;
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Ceil Node");
|
||||
|
||||
Ok(())
|
||||
}
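The Floor and Ceil handlers both start from truncation toward zero and then correct by one where needed. The same identities can be checked on plain scalars. This is a sketch that uses an `i32` round-trip as a stand-in for the graph-level cast to `DType::Int`; the actual integer width used by luminal is not assumed here.

```rust
/// Truncate toward zero through an integer round-trip.
fn trunc_via_int(x: f32) -> f32 {
    (x as i32) as f32
}

/// floor(x) = trunc(x) - 1 when x < trunc(x) (negative non-integers), else trunc(x).
fn floor_via_trunc(x: f32) -> f32 {
    let t = trunc_via_int(x);
    t - if x < t { 1.0 } else { 0.0 }
}

/// ceil(x) = trunc(x) + 1 when x > trunc(x) (positive non-integers), else trunc(x).
fn ceil_via_trunc(x: f32) -> f32 {
    let t = trunc_via_int(x);
    t + if x > t { 1.0 } else { 0.0 }
}
```

The identities only hold while the value fits in the integer type; inputs outside that range do not survive the round-trip.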
|
||||
|
||||
pub fn parse_cast_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Cast Node");
|
||||
assert!(node.input.len() == 1, "Cast should have exactly 1 input");
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Cast: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
// ONNX data type enum → luminal DType
|
||||
let to = get_int_attr(node, "to", 1);
|
||||
let dtype = match to {
|
||||
1 => DType::F32, // FLOAT
|
||||
10 => DType::F16, // FLOAT16
|
||||
16 => DType::Bf16, // BFLOAT16
|
||||
6 | 7 => DType::Int, // INT32, INT64
|
||||
9 => DType::F32, // BOOL → treat as F32 (0.0/1.0)
|
||||
11 => DType::F32, // DOUBLE → F32 (downcast)
|
||||
_ => DType::F32, // fallback
|
||||
};
|
||||
|
||||
let cast_result = input.cast(dtype);
|
||||
let output_name = &node.output[0];
|
||||
|
||||
let result = if cast_result.id == input.id {
|
||||
input
|
||||
} else {
|
||||
cast_result
|
||||
};
|
||||
|
||||
tensors.insert(output_name.clone(), result);
|
||||
|
||||
// Propagate known values (cast is a no-op for our f32 storage)
|
||||
if let Some(vals) = known_values.get(&node.input[0]).cloned() {
|
||||
let folded = if to == 9 {
|
||||
vals.iter()
|
||||
.map(|&v| if v != 0.0 { 1.0 } else { 0.0 })
|
||||
.collect()
|
||||
} else if to == 6 || to == 7 {
|
||||
vals.iter().map(|&v| (v as i64) as f32).collect()
|
||||
} else {
|
||||
vals
|
||||
};
|
||||
known_values.insert(output_name.clone(), folded.clone());
|
||||
weight_data.push((output_name.clone(), folded));
|
||||
}
|
||||
// Propagate shape_exprs
|
||||
if let Some(se) = shape_exprs.get(&node.input[0]).cloned() {
|
||||
shape_exprs.insert(output_name.clone(), se);
|
||||
}
|
||||
|
||||
trace!("Finished parse: Cast Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn parse_unary_op(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
op_name: &str,
|
||||
op: impl Fn(GraphTensor) -> GraphTensor,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: {} Node", op_name);
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"{} should have 1 input, got {}",
|
||||
op_name,
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"{} should have 1 output, got {}",
|
||||
op_name,
|
||||
node.output.len()
|
||||
);
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("{}: missing input tensor '{}'", op_name, node.input[0]))?;
|
||||
let result = op(a);
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: {} Node", op_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Erf node: output = erf(input[0])
|
||||
///
|
||||
/// Uses the Abramowitz & Stegun 7.1.26 polynomial approximation (max error < 1.5e-7):
|
||||
/// For x ≥ 0: erf(x) ≈ 1 - (a1·t + a2·t² + a3·t³ + a4·t⁴ + a5·t⁵) · exp(-x²)
|
||||
/// where t = 1 / (1 + 0.3275911·x)
|
||||
/// a1 = 0.254829592
|
||||
/// a2 = -0.284496736
|
||||
/// a3 = 1.421413741
|
||||
/// a4 = -1.453152027
|
||||
/// a5 = 1.061405429
|
||||
/// Extended to all x via odd symmetry: erf(-x) = -erf(x).
|
||||
pub fn parse_erf_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
parse_unary_op(node, tensors, "Erf", |x| {
|
||||
let a = x.abs();
|
||||
let t = (1.0_f32 + 0.3275911_f32 * a).reciprocal();
|
||||
// Horner evaluation of a1*t + a2*t² + a3*t³ + a4*t⁴ + a5*t⁵
|
||||
// poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + a5*t))))
|
||||
let h = t * 1.061_405_4_f32 - 1.453_152_1_f32; // a4 + a5*t
|
||||
let h = t * h + 1.421_413_8_f32;
|
||||
let h = t * h - 0.284_496_72_f32;
|
||||
let h = t * h + 0.254_829_6_f32;
|
||||
let poly = t * h;
|
||||
let erf_abs = 1.0_f32 - poly * (-a * a).exp();
|
||||
x.sign() * erf_abs
|
||||
})
|
||||
}
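The same approximation can be evaluated on a plain scalar to sanity-check the coefficients and the Horner ordering. This is a sketch using the rounded f32 constants from the closure above; `erf_approx` is a hypothetical name.

```rust
/// Abramowitz & Stegun 7.1.26 approximation of erf, evaluated with Horner's rule.
fn erf_approx(x: f32) -> f32 {
    let a = x.abs();
    let t = 1.0 / (1.0 + 0.327_591_1 * a);
    // poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + a5*t))))
    let mut h = 1.061_405_4 * t - 1.453_152_1; // a4 + a5*t
    h = h * t + 1.421_413_8; // a3 + t*(...)
    h = h * t - 0.284_496_72; // a2 + t*(...)
    h = h * t + 0.254_829_6; // a1 + t*(...)
    let poly = t * h;
    let erf_abs = 1.0 - poly * (-a * a).exp();
    // Odd symmetry: erf(-x) = -erf(x)
    if x < 0.0 {
        -erf_abs
    } else {
        erf_abs
    }
}
```

At `x = 0` the five coefficients sum to approximately 1, so the result is approximately 0, which is a quick check that no coefficient sign was flipped.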
|
||||
|
||||
/// Handle LayerNormalization node (opset 17).
|
||||
///
|
||||
/// Inputs: X (required), scale (required), bias (optional)
|
||||
/// Attributes: axis (default -1), epsilon (default 1e-5)
|
||||
/// Normalizes over axes [axis, axis+1, ..., rank-1], then applies scale and bias.
|
||||
/// Only output 0 (the normalized result) is wired; outputs 1/2 (mean, inv_std_dev)
|
||||
/// are training-only and not supported for inference.
|
||||
pub fn parse_layernorm_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: LayerNormalization Node");
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("LayerNorm: missing input '{}'", node.input[0]))?;
|
||||
let scale = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("LayerNorm: missing scale '{}'", node.input[1]))?;
|
||||
|
||||
let ndim = input.dims().len();
|
||||
let axis_raw = get_int_attr(node, "axis", -1);
|
||||
let axis = if axis_raw < 0 {
|
||||
(ndim as i64 + axis_raw) as usize
|
||||
} else {
|
||||
axis_raw as usize
|
||||
};
|
||||
let epsilon = get_float_attr(node, "epsilon", 1e-5);
|
||||
let axes: Vec<usize> = (axis..ndim).collect();
|
||||
|
||||
let mut result = input.layer_norm(axes, epsilon);
|
||||
|
||||
// Apply scale (broadcast to input shape using Expression-aware broadcast)
|
||||
let input_shape = input.dims();
|
||||
result *= broadcast_to_expr(scale, &input_shape);
|
||||
|
||||
// Apply optional bias
|
||||
if node.input.len() > 2 && !node.input[2].is_empty() {
|
||||
let bias = *tensors
|
||||
.get(&node.input[2])
|
||||
.ok_or_else(|| format!("LayerNorm: missing bias '{}'", node.input[2]))?;
|
||||
result += broadcast_to_expr(bias, &input_shape);
|
||||
}
|
||||
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: LayerNormalization Node");
|
||||
Ok(())
|
||||
}
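For reference, the arithmetic behind the normalization can be written out for a single row. This sketch assumes the usual definition (population variance, epsilon added inside the square root); whether luminal's `layer_norm` matches it exactly is an assumption, and `layer_norm_row` is a hypothetical helper.

```rust
/// Layer normalization over one 1D row, followed by per-element scale and optional bias.
/// Assumes population variance with epsilon added inside the sqrt.
fn layer_norm_row(x: &[f32], scale: &[f32], bias: Option<&[f32]>, epsilon: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let inv_std = 1.0 / (var + epsilon).sqrt();
    x.iter()
        .enumerate()
        .map(|(i, v)| {
            let y = (v - mean) * inv_std * scale[i];
            y + bias.map_or(0.0, |b| b[i])
        })
        .collect()
}
```

For the row `[1, 2, 3]` the mean is 2 and the variance is 2/3, so with unit scale and no bias the middle element maps to 0 and the outer elements to roughly ±1.2247.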
|
||||
|
||||
/// Handle GroupNormalization node (opset 18).
|
||||
///
|
||||
/// Inputs: X [N, C, spatial...], scale [num_groups], bias [num_groups]
|
||||
/// Attributes: num_groups (required), epsilon (default 1e-5)
|
||||
///
|
||||
/// Normalizes over channels-per-group and spatial dims, then applies per-group scale/bias.
|
||||
/// Decomposed into: reshape [N, G, C/G, spatial...] -> layer_norm over [C/G, spatial...] ->
|
||||
/// reshape back to [N, C, spatial...] -> scale + bias (broadcast).
|
||||
pub fn parse_group_norm_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: GroupNormalization Node");
|
||||
|
||||
assert!(
|
||||
node.input.len() >= 3,
|
||||
"GroupNormalization needs 3 inputs (X, scale, bias), got {}",
|
||||
node.input.len()
|
||||
);
|
||||
|
||||
let x = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("GroupNorm: missing input X '{}'", node.input[0]))?;
|
||||
let scale = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("GroupNorm: missing scale '{}'", node.input[1]))?;
|
||||
let bias = *tensors
|
||||
.get(&node.input[2])
|
||||
.ok_or_else(|| format!("GroupNorm: missing bias '{}'", node.input[2]))?;
|
||||
|
||||
let x_dims = x.dims();
|
||||
let ndim = x_dims.len();
|
||||
assert!(
|
||||
ndim >= 3,
|
||||
"GroupNorm: input must be at least 3D [N, C, spatial...], got {ndim}D"
|
||||
);
|
||||
|
||||
let num_groups = get_int_attr(node, "num_groups", 1) as usize;
|
||||
let epsilon = get_float_attr(node, "epsilon", 1e-5);
|
||||
|
||||
let n = x_dims[0]
|
||||
.to_usize()
|
||||
.expect("GroupNorm: batch must be concrete");
|
||||
let c = x_dims[1]
|
||||
.to_usize()
|
||||
.expect("GroupNorm: channels must be concrete");
|
||||
assert_eq!(
|
||||
c % num_groups,
|
||||
0,
|
||||
"GroupNorm: channels {c} must be divisible by num_groups {num_groups}"
|
||||
);
|
||||
let cpg = c / num_groups; // channels per group
|
||||
|
||||
// Reshape X from [N, C, spatial...] to [N, G, C/G, spatial...]
|
||||
let spatial_dims: Vec<Expression> = x_dims[2..].to_vec();
|
||||
let mut reshaped = x;
|
||||
let mut new_shape = vec![n, num_groups, cpg];
|
||||
for d in &spatial_dims {
|
||||
new_shape.push(
|
||||
d.to_usize()
|
||||
.expect("GroupNorm: spatial dims must be concrete"),
|
||||
);
|
||||
}
|
||||
reshaped.shape = ShapeTracker::new(new_shape.clone());
|
||||
|
||||
// Normalize over axes [2, 3, ..., ndim] (C/G + spatial dims)
|
||||
let norm_axes: Vec<usize> = (2..new_shape.len()).collect();
|
||||
let mut normed = reshaped.layer_norm(norm_axes, epsilon);
|
||||
|
||||
// Reshape back to [N, C, spatial...]
|
||||
let mut orig_shape = vec![n, c];
|
||||
for d in &spatial_dims {
|
||||
orig_shape.push(d.to_usize().unwrap());
|
||||
}
|
||||
normed *= 1.0;
|
||||
normed.shape = ShapeTracker::new(orig_shape.clone());
|
||||
|
||||
// Apply scale and bias (both shape [C], broadcast to [N, C, spatial...])
|
||||
let target_shape: Vec<Expression> = orig_shape.iter().map(|&d| Expression::from(d)).collect();
|
||||
let result =
|
||||
normed * broadcast_to_expr(scale, &target_shape) + broadcast_to_expr(bias, &target_shape);
|
||||
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: GroupNormalization Node");
|
||||
Ok(())
|
||||
}
|
||||
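For orientation, the arithmetic that the GroupNormalization parser above expresses through a reshape plus `layer_norm` can be sketched on a plain buffer. This is a minimal sketch, not luminal code: the `group_norm` helper and its flat row-major `[n, c, spatial]` layout are assumptions made for illustration, and the variance used here is the biased (population) estimate.

```rust
/// Group normalization over a flat row-major buffer (hypothetical helper, not luminal API).
/// `x` has logical shape [n, c, spatial]; `scale` and `bias` have shape [c].
fn group_norm(
    x: &[f32],
    n: usize,
    c: usize,
    spatial: usize,
    num_groups: usize,
    scale: &[f32],
    bias: &[f32],
    eps: f32,
) -> Vec<f32> {
    assert_eq!(c % num_groups, 0, "channels must be divisible by num_groups");
    assert_eq!(x.len(), n * c * spatial);
    let cpg = c / num_groups; // channels per group
    let group_len = cpg * spatial; // elements that share one mean and variance
    let mut out = vec![0.0f32; x.len()];
    for b in 0..n {
        for g in 0..num_groups {
            // Channels of one group are adjacent, so the group is a contiguous slice.
            let start = (b * c + g * cpg) * spatial;
            let group = &x[start..start + group_len];
            let mean = group.iter().sum::<f32>() / group_len as f32;
            let var = group.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / group_len as f32;
            let inv_std = 1.0 / (var + eps).sqrt();
            for (i, v) in group.iter().enumerate() {
                // Scale and bias are per channel, not per group.
                let ch = g * cpg + i / spatial;
                out[start + i] = (v - mean) * inv_std * scale[ch] + bias[ch];
            }
        }
    }
    out
}
```

With `num_groups == 1` every channel of a sample shares one mean and variance; with `num_groups == c` each channel is normalized on its own.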
@@ -1,16 +1,15 @@
|
||||
-use luminal::prelude::tracing::warn;
 use luminal::prelude::*;
 use pyo3::prelude::*;
 use std::collections::HashMap;
 
-use crate::compiled_graph::{CompiledGraph, DimParamMap, GraphTranslation, WeightData};
+use crate::compiled_graph::{CompiledGraph, GraphTranslation, WeightData};
+use crate::pt2_parser;
 use crate::pt2_schema;
 use crate::translator;
-use crate::typed_data::TypedData;
-use crate::{pt2_parser, pt2_util};
+use crate::util::DimParamMap;
 
 /// Pre-loaded weight/constant data paired with tensor sizes.
-type PreloadResult = (Vec<(String, TypedData)>, HashMap<String, usize>);
+type PreloadResult = (Vec<(String, Vec<f32>)>, HashMap<String, usize>);
 
 fn resolve_dim_sizes(
     sizes: &[pt2_schema::DimSize],
|
||||
@@ -84,7 +83,7 @@ pub fn translate_pt2(
|
||||
}
|
||||
}
|
||||
|
||||
// Compute shape expressions and dtypes from PT2 tensor metadata
|
||||
// Compute shape expressions from PT2 tensor metadata
|
||||
let output_shape_exprs: Vec<Vec<Expression>> = translated
|
||||
.output_ids
|
||||
.iter()
|
||||
@@ -96,17 +95,6 @@ pub fn translate_pt2(
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output_dtypes: Vec<DType> = translated
|
||||
.output_ids
|
||||
.iter()
|
||||
.map(|(name, _id)| {
|
||||
parsed
|
||||
.tensor_meta(name)
|
||||
.map(|meta| pt2_util::torch_dtype_int_to_luminal(meta.dtype))
|
||||
.unwrap_or(DType::F32)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let input_names: Vec<String> = translated
|
||||
.user_input_ids
|
||||
.iter()
|
||||
@@ -139,7 +127,7 @@ pub fn translate_pt2(
|
||||
}
|
||||
|
||||
// Pre-load weights and compute tensor sizes for CUDA dummy data
|
||||
let mut weights: Vec<(String, TypedData)> = Vec::new();
|
||||
let mut weights: Vec<(String, Vec<f32>)> = Vec::new();
|
||||
let mut tensor_sizes: HashMap<String, usize> = HashMap::new();
|
||||
|
||||
// Load safetensors weights
|
||||
@@ -201,7 +189,6 @@ pub fn translate_pt2(
|
||||
tensor_ids,
|
||||
input_names,
|
||||
output_names,
|
||||
output_dtypes,
|
||||
output_shape_exprs,
|
||||
input_shape_exprs,
|
||||
dim_param_map,
|
||||
@@ -248,8 +235,8 @@ fn preload_safetensors(graph: &Graph, file_path: &str) -> anyhow::Result<Preload
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
&& let Ok(tensor) = st.tensor(&input.label)
|
||||
{
|
||||
let types = bytes_to_typed(tensor.data(), safetensors_dtype_to_pt2(tensor.dtype()));
|
||||
weights.push((input.label.clone(), types));
|
||||
let f32s = bytes_to_f32(tensor.data(), safetensors_dtype_to_pt2(tensor.dtype()));
|
||||
weights.push((input.label.clone(), f32s));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -286,12 +273,15 @@ fn preload_constants(
|
||||
) {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
warn!("failed to load constant '{}': {:#}", name, e);
|
||||
eprintln!(
|
||||
"[luminal] Warning: failed to load constant '{}': {:#}",
|
||||
name, e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let typed_data = bytes_to_typed(&raw_bytes, entry.tensor_meta.dtype);
|
||||
weights.push((name.clone(), typed_data));
|
||||
let f32_data = bytes_to_f32(&raw_bytes, entry.tensor_meta.dtype);
|
||||
weights.push((name.clone(), f32_data));
|
||||
}
|
||||
|
||||
Ok((weights, sizes))
|
||||
@@ -318,52 +308,49 @@ fn safetensors_dtype_to_pt2(dtype: safetensors::Dtype) -> u32 {
|
||||
     }
 }
 
-/// Convert raw bytes to TypedData using PT2 dtype numbering.
-/// Preserves native byte format for types luminal supports directly (f32, f16, bf16, i32, bool, u8, i8).
-/// Converts i64/f64/i16 to the closest luminal-native representation.
-fn bytes_to_typed(bytes: &[u8], dtype: u32) -> TypedData {
+/// Convert raw bytes to f32 using PT2 dtype numbering.
+fn bytes_to_f32(bytes: &[u8], dtype: u32) -> Vec<f32> {
     match dtype {
-        // Types that map directly — preserve raw bytes
-        7 => TypedData::from_raw(bytes.to_vec(), DType::F32),
-        6 => TypedData::from_raw(bytes.to_vec(), DType::F16),
-        13 => TypedData::from_raw(bytes.to_vec(), DType::Bf16),
-        4 => TypedData::from_raw(bytes.to_vec(), DType::Int), // i32
-        1 => TypedData::from_raw(bytes.to_vec(), DType::U8),
-        2 => TypedData::from_raw(bytes.to_vec(), DType::I8),
-        12 => TypedData::from_raw(bytes.to_vec(), DType::Bool),
-
-        // i64 → i32 (truncate, matching luminal's Int type)
-        5 => {
-            let i32s: Vec<i32> = bytes
-                .chunks_exact(8)
-                .map(|b| {
-                    i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as i32
-                })
-                .collect();
-            TypedData::from_i32_vec(i32s)
-        }
-        // f64 → f32 (downcast, luminal has no F64 in practice for most ops)
-        8 => {
-            let f32s: Vec<f32> = bytes
-                .chunks_exact(8)
-                .map(|b| {
-                    f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
-                })
-                .collect();
-            TypedData::from_f32_vec(f32s)
-        }
-        // i16 → i32 (widen to luminal's Int)
-        3 => {
-            let i32s: Vec<i32> = bytes
-                .chunks_exact(2)
-                .map(|b| i16::from_le_bytes([b[0], b[1]]) as i32)
-                .collect();
-            TypedData::from_i32_vec(i32s)
-        }
+        7 => bytes
+            .chunks_exact(4)
+            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
+            .collect(),
+        6 => bytes
+            .chunks_exact(2)
+            .map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
+            .collect(),
+        13 => bytes
+            .chunks_exact(2)
+            .map(|b| half::bf16::from_le_bytes([b[0], b[1]]).to_f32())
+            .collect(),
+        8 => bytes
+            .chunks_exact(8)
+            .map(|b| f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32)
+            .collect(),
+        5 => bytes
+            .chunks_exact(8)
+            .map(|b| i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32)
+            .collect(),
+        4 => bytes
+            .chunks_exact(4)
+            .map(|b| i32::from_le_bytes([b[0], b[1], b[2], b[3]]) as f32)
+            .collect(),
+        3 => bytes
+            .chunks_exact(2)
+            .map(|b| i16::from_le_bytes([b[0], b[1]]) as f32)
+            .collect(),
+        2 => bytes.iter().map(|&b| (b as i8) as f32).collect(),
+        1 => bytes.iter().map(|&b| b as f32).collect(),
+        12 => bytes
+            .iter()
+            .map(|&b| if b != 0 { 1.0 } else { 0.0 })
+            .collect(),
         _ => {
-            let luminal_dtype = pt2_util::torch_dtype_int_to_luminal(dtype);
-            warn!("Unrecognized dtype {dtype}, interpreting as {luminal_dtype:?}");
-            TypedData::from_raw(bytes.to_vec(), luminal_dtype)
+            eprintln!("[luminal] Warning: unrecognized dtype {dtype}, interpreting as f32");
+            bytes
+                .chunks_exact(4)
+                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
+                .collect()
         }
     }
 }
|
||||
|
||||
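The byte-decoding pattern in `bytes_to_f32` above (fixed-size little-endian chunks, then a numeric cast) can be tried in isolation with the sketch below. It is illustrative only: `decode_to_f32` is a hypothetical name, it covers just three of the dtype codes seen in the hunk (7 = f32, 5 = i64, 13 = bf16), it widens bf16 with a bit shift so that it needs no `half` crate, and it returns `None` for other codes, whereas the code in the diff falls back to reinterpreting the bytes as f32.

```rust
/// Decode little-endian raw bytes into f32 values (std-only sketch).
/// Dtype codes follow the match arms in the hunk above: 7 = f32, 5 = i64, 13 = bf16.
/// Returns None for any other code; that choice belongs to this sketch, not to the diff.
fn decode_to_f32(bytes: &[u8], dtype: u32) -> Option<Vec<f32>> {
    let out: Vec<f32> = match dtype {
        7 => bytes
            .chunks_exact(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect(),
        5 => bytes
            .chunks_exact(8)
            .map(|b| {
                let mut buf = [0u8; 8];
                buf.copy_from_slice(b);
                // Lossy for magnitudes above 2^24, like any i64 -> f32 cast.
                i64::from_le_bytes(buf) as f32
            })
            .collect(),
        // bf16 holds the upper 16 bits of an f32, so widening is a 16-bit left shift.
        13 => bytes
            .chunks_exact(2)
            .map(|b| f32::from_bits((u16::from_le_bytes([b[0], b[1]]) as u32) << 16))
            .collect(),
        _ => return None,
    };
    Some(out)
}
```

Note that `chunks_exact` silently drops a trailing partial chunk, so a buffer whose length is not a multiple of the element size yields fewer values rather than an error.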
@@ -77,7 +77,6 @@ pub enum Argument {
|
||||
SymInts(SymIntsArg),
|
||||
SymInt(SymIntArg),
|
||||
Expr(ExprArg),
|
||||
#[allow(dead_code)]
|
||||
ScalarType(ScalarTypeArg),
|
||||
Tensors(TensorsArg),
|
||||
OptionalTensors(OptionalTensorsArg),
|
||||
@@ -169,7 +168,6 @@ pub struct NoneArg {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct ScalarTypeArg {
|
||||
pub as_scalar_type: u32,
|
||||
}
|
||||
@@ -226,7 +224,6 @@ impl Argument {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn as_scalar_type(&self) -> Option<u32> {
|
||||
match self {
|
||||
Argument::ScalarType(s) => Some(s.as_scalar_type),
|
||||
|
||||
@@ -16,7 +16,6 @@ pub enum ReductionOp {
|
||||
Mean,
|
||||
Max,
|
||||
Min,
|
||||
Prod,
|
||||
}
|
||||
|
||||
/// Normalize a potentially negative dimension index.
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use luminal::hlir::NativeData;
|
||||
use luminal::prelude::*;
|
||||
#[cfg(feature = "cuda")]
|
||||
use luminal_cuda_lite::cudarc::driver::{CudaContext, CudaStream};
|
||||
@@ -8,8 +7,6 @@ use rustc_hash::FxHashMap;
|
||||
#[cfg(feature = "cuda")]
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::typed_data::TypedData;
|
||||
|
||||
/// Enum wrapper for runtime backends allowing runtime selection.
|
||||
pub enum RuntimeBackend {
|
||||
Native(NativeRuntime),
|
||||
@@ -18,23 +15,8 @@ pub enum RuntimeBackend {
|
||||
}
|
||||
|
||||
impl RuntimeBackend {
|
||||
/// Set input data for a tensor node (dtype-aware).
|
||||
pub fn set_data(&mut self, node: NodeIndex, data: TypedData) {
|
||||
match self {
|
||||
RuntimeBackend::Native(rt) => {
|
||||
let native: NativeData = data.into();
|
||||
rt.set_data(node, native);
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(rt) => {
|
||||
// CUDA runtime stores raw bytes — just upload directly
|
||||
rt.set_data(node, data.bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Set input data from a Vec<f32> (convenience for backward compatibility).
|
||||
pub fn set_data_f32(&mut self, node: NodeIndex, data: Vec<f32>) {
|
||||
/// Set input data for a tensor node.
|
||||
pub fn set_data(&mut self, node: NodeIndex, data: Vec<f32>) {
|
||||
match self {
|
||||
RuntimeBackend::Native(rt) => rt.set_data(node, data),
|
||||
#[cfg(feature = "cuda")]
|
||||
@@ -51,7 +33,7 @@ impl RuntimeBackend {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get output data as f32 from a tensor node.
|
||||
/// Get output data from a tensor node.
|
||||
pub fn get_f32(&self, node: NodeIndex) -> Vec<f32> {
|
||||
match self {
|
||||
RuntimeBackend::Native(rt) => rt.get_f32(node).to_vec(),
|
||||
|
||||
@@ -12,7 +12,6 @@ impl<'a> Translator<'a> {
|
||||
let arg1 = &node.inputs[1].arg;
|
||||
if let Some(name) = arg1.as_tensor_name() {
|
||||
let b = self.get_tensor(name)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
Ok(match op {
|
||||
BinaryOp::Add => a + b,
|
||||
|
||||
@@ -1,407 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_schema::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
const CONV_INPUT_ARG: usize = 0;
|
||||
const CONV_WEIGHT_ARG: usize = 1;
|
||||
const CONV_BIAS_ARG: usize = 2;
|
||||
const CONV_STRIDE_ARG: usize = 3;
|
||||
const CONV_PADDING_ARG: usize = 4;
|
||||
const CONV_DILATION_ARG: usize = 5;
|
||||
const CONV_GROUPS_ARG: usize = 6;
|
||||
|
||||
const CONVOLUTION_TRANSPOSED_ARG: usize = 6;
|
||||
const CONVOLUTION_OUTPUT_PADDING_ARG: usize = 7;
|
||||
const CONVOLUTION_GROUPS_ARG: usize = 8;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
/// Translate aten.conv{1,2,3}d.default and aten.convolution.default.
|
||||
///
|
||||
/// The PT2 export may omit defaulted trailing arguments entirely. In practice this means
|
||||
/// conv{N}d.default can show up as just `(input, weight)` for the no-bias, stride=1,
|
||||
/// padding=0, dilation=1, groups=1 case.
|
||||
pub(crate) fn translate_conv(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, CONV_INPUT_ARG)?;
|
||||
let weight = self.get_input_tensor(node, CONV_WEIGHT_ARG)?;
|
||||
let bias = self.get_input_tensor(node, CONV_BIAS_ARG).ok();
|
||||
|
||||
let x_dims = input.dims();
|
||||
let w_dims = weight.dims();
|
||||
let rank = x_dims.len();
|
||||
let spatial = rank - 2;
|
||||
|
||||
let stride = self
|
||||
.get_ints_arg(node, CONV_STRIDE_ARG)
|
||||
.unwrap_or_else(|_| vec![1; spatial]);
|
||||
let padding = self
|
||||
.get_ints_arg(node, CONV_PADDING_ARG)
|
||||
.unwrap_or_else(|_| vec![0; spatial]);
|
||||
let mut dilation = self
|
||||
.get_ints_arg(node, CONV_DILATION_ARG)
|
||||
.unwrap_or_else(|_| vec![1; spatial]);
|
||||
let groups = if node.target == "torch.ops.aten.convolution.default" {
|
||||
let transposed = self
|
||||
.get_bool_arg(node, CONVOLUTION_TRANSPOSED_ARG)
|
||||
.unwrap_or(false);
|
||||
anyhow::ensure!(
|
||||
!transposed,
|
||||
"conv: ConvTranspose / transposed=true is not supported yet"
|
||||
);
|
||||
let output_padding = self
|
||||
.get_ints_arg(node, CONVOLUTION_OUTPUT_PADDING_ARG)
|
||||
.unwrap_or_else(|_| vec![0; spatial]);
|
||||
anyhow::ensure!(
|
||||
output_padding.iter().all(|&v| v == 0),
|
||||
"conv: output_padding is not supported for non-transposed convolution"
|
||||
);
|
||||
self.get_int_arg(node, CONVOLUTION_GROUPS_ARG).unwrap_or(1) as usize
|
||||
} else {
|
||||
self.get_int_arg(node, CONV_GROUPS_ARG).unwrap_or(1) as usize
|
||||
};
|
||||
if dilation.len() != spatial {
|
||||
dilation = vec![1; spatial];
|
||||
}
|
||||
|
||||
let ch_out = w_dims[0]
|
||||
.to_usize()
|
||||
.ok_or_else(|| anyhow::anyhow!("conv: weight C_out must be concrete"))?;
|
||||
let ch_in = x_dims[1]
|
||||
.to_usize()
|
||||
.ok_or_else(|| anyhow::anyhow!("conv: input C_in must be concrete"))?;
|
||||
anyhow::ensure!(
|
||||
stride.len() == spatial && padding.len() == spatial && dilation.len() == spatial,
|
||||
"conv: stride/padding/dilation rank must match spatial rank {spatial}"
|
||||
);
|
||||
anyhow::ensure!(
|
||||
groups > 0 && ch_in % groups == 0 && ch_out % groups == 0,
|
||||
"conv: invalid group configuration (C_in={ch_in}, C_out={ch_out}, groups={groups})"
|
||||
);
|
||||
let ch_per_group = ch_in / groups;
|
||||
|
||||
let kernel_shape: Vec<usize> = w_dims[2..]
|
||||
.iter()
|
||||
.map(|d| {
|
||||
d.to_usize()
|
||||
.ok_or_else(|| anyhow::anyhow!("conv: kernel dims must be concrete"))
|
||||
})
|
||||
.collect::<Result<_>>()?;
|
||||
let kernel_product: usize = kernel_shape.iter().product();
|
||||
|
||||
// ATen uses symmetric padding (same begin/end)
|
||||
let stride_u: Vec<usize> = stride.iter().map(|&v| v as usize).collect();
|
||||
let padding_u: Vec<usize> = padding.iter().map(|&v| v as usize).collect();
|
||||
let dilation_u: Vec<usize> = dilation.iter().map(|&v| v as usize).collect();
|
||||
|
||||
let mut out = if groups > 1 {
|
||||
let group_out = ch_out / groups;
|
||||
|
||||
if ch_per_group == 1 {
|
||||
// Depthwise (including channel multiplier > 1): avoid per-channel slicing.
|
||||
depthwise_conv(
|
||||
input,
|
||||
weight,
|
||||
&kernel_shape,
|
||||
&stride_u,
|
||||
&dilation_u,
|
||||
&padding_u,
|
||||
&padding_u,
|
||||
ch_in,
|
||||
group_out,
|
||||
kernel_product,
|
||||
spatial,
|
||||
)
|
||||
} else {
|
||||
// General grouped: pre-pad full input then slice per group
|
||||
let padded_input = {
|
||||
let mut pad_spec: Vec<(Expression, Expression)> =
|
||||
vec![(0.into(), 0.into()); 2 + spatial];
|
||||
for i in 0..spatial {
|
||||
pad_spec[2 + i] = (padding_u[i].into(), padding_u[i].into());
|
||||
}
|
||||
input.pad(pad_spec, 0.0)
|
||||
};
|
||||
|
||||
let no_pad = vec![0usize; spatial];
|
||||
let mut group_outputs = Vec::with_capacity(groups);
|
||||
for g in 0..groups {
|
||||
let x_g = slice_channel_group(padded_input, g, ch_per_group, spatial);
|
||||
let w_g =
|
||||
slice_weight_group(weight, g, group_out, ch_per_group * kernel_product);
|
||||
group_outputs.push(conv_unfold(
|
||||
x_g,
|
||||
w_g,
|
||||
&kernel_shape,
|
||||
&stride_u,
|
||||
&dilation_u,
|
||||
&no_pad,
|
||||
&no_pad,
|
||||
ch_per_group,
|
||||
group_out,
|
||||
spatial,
|
||||
));
|
||||
}
|
||||
|
||||
let mut result = group_outputs[0];
|
||||
for g_out in &group_outputs[1..] {
|
||||
result = result.concat_along(*g_out, 1);
|
||||
}
|
||||
result
|
||||
}
|
||||
} else {
|
||||
let mut w_flat = weight;
|
||||
w_flat.shape = ShapeTracker::new_with_element_bits(
|
||||
vec![ch_out, ch_in * kernel_product],
|
||||
weight.dtype.bits(),
|
||||
);
|
||||
|
||||
conv_unfold(
|
||||
input,
|
||||
w_flat,
|
||||
&kernel_shape,
|
||||
&stride_u,
|
||||
&dilation_u,
|
||||
&padding_u,
|
||||
&padding_u,
|
||||
ch_in,
|
||||
ch_out,
|
||||
spatial,
|
||||
)
|
||||
};
|
||||
|
||||
if let Some(b) = bias {
|
||||
let out_dims = out.dims();
|
||||
let mut b_expanded = b.expand_dim(0, 1);
|
||||
for i in 0..spatial {
|
||||
b_expanded = b_expanded.expand_dim(2 + i, out_dims[2 + i]);
|
||||
}
|
||||
out += b_expanded;
|
||||
}
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
/// Slice input channels for one group.
|
||||
/// Caller must pre-pad `x` so no additional padding is applied to the slice.
|
||||
fn slice_channel_group(
|
||||
x: GraphTensor,
|
||||
g: usize,
|
||||
ch_per_group: usize,
|
||||
spatial: usize,
|
||||
) -> GraphTensor {
|
||||
let start = g * ch_per_group;
|
||||
let end = start + ch_per_group;
|
||||
let dims = x.dims();
|
||||
let rank = 2 + spatial;
|
||||
let mut slices: Vec<(Expression, Expression)> = Vec::with_capacity(rank);
|
||||
slices.push((0.into(), dims[0]));
|
||||
slices.push((start.into(), end.into()));
|
||||
for dim in dims.iter().take(rank).skip(2) {
|
||||
slices.push((0.into(), *dim));
|
||||
}
|
||||
x.slice(slices)
|
||||
}
|
||||
|
||||
/// Slice and flatten weight for one group.
|
||||
fn slice_weight_group(
|
||||
w: GraphTensor,
|
||||
g: usize,
|
||||
group_out: usize,
|
||||
flat_inner: usize,
|
||||
) -> GraphTensor {
|
||||
let start = g * group_out;
|
||||
let end = start + group_out;
|
||||
let w_dims = w.dims();
|
||||
let mut slices: Vec<(Expression, Expression)> = Vec::with_capacity(w_dims.len());
|
||||
slices.push((start.into(), end.into()));
|
||||
for dim in w_dims.iter().skip(1) {
|
||||
slices.push((0.into(), *dim));
|
||||
}
|
||||
// Materialize through Add: binary op outputs are contiguous in Luminal, which makes the
|
||||
// following flatten safe for the sliced weight buffer.
|
||||
let w_sliced = w.slice(slices) + 0.0;
|
||||
let mut w_flat = w_sliced;
|
||||
w_flat.shape =
|
||||
ShapeTracker::new_with_element_bits(vec![group_out, flat_inner], w_sliced.dtype.bits());
|
||||
w_flat
|
||||
}
|
||||
|
||||
/// Core unfold-based convolution for a single group.
|
||||
///
|
||||
/// `x`: [batch, ch_in, spatial...]
|
||||
/// `w_flat`: [ch_out, ch_in * kernel_product] (already reshaped)
|
||||
/// Returns: [batch, ch_out, out_spatial...]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn conv_unfold(
|
||||
x: GraphTensor,
|
||||
w_flat: GraphTensor,
|
||||
kernel_shape: &[usize],
|
||||
strides: &[usize],
|
||||
dilations: &[usize],
|
||||
pads_begin: &[usize],
|
||||
pads_end: &[usize],
|
||||
_ch_in: usize,
|
||||
_ch_out: usize,
|
||||
spatial: usize,
|
||||
) -> GraphTensor {
|
||||
let rank = 2 + spatial;
|
||||
|
||||
// Pad spatial dimensions (skip if all padding is zero)
|
||||
let needs_pad = pads_begin.iter().any(|&p| p > 0) || pads_end.iter().any(|&p| p > 0);
|
||||
let padded = if needs_pad {
|
||||
let mut padding: Vec<(Expression, Expression)> = vec![(0.into(), 0.into()); rank];
|
||||
for i in 0..spatial {
|
||||
padding[2 + i] = (pads_begin[i].into(), pads_end[i].into());
|
||||
}
|
||||
x.pad(padding, 0.0)
|
||||
} else {
|
||||
x
|
||||
};
|
||||
|
||||
// Build full-rank unfold parameters (1 for batch/channel, actual for spatial)
|
||||
let mut kernel_full = vec![1usize; rank];
|
||||
let mut stride_full = vec![1usize; rank];
|
||||
let mut dilation_full = vec![1usize; rank];
|
||||
kernel_full[2..(spatial + 2)].copy_from_slice(&kernel_shape[..spatial]);
|
||||
stride_full[2..(spatial + 2)].copy_from_slice(&strides[..spatial]);
|
||||
dilation_full[2..(spatial + 2)].copy_from_slice(&dilations[..spatial]);
|
||||
|
||||
let unfolded = padded.unfold(kernel_full, stride_full, dilation_full);
|
||||
// Shape: [win_N, win_C, win_spatial..., k_N=1, k_C=1, k_spatial...]
|
||||
|
||||
// Permute to [N, win_spatial..., C_in, k_N, k_C, k_spatial...]
|
||||
let mut perm: Vec<usize> = Vec::with_capacity(2 * rank);
|
||||
perm.push(0);
|
||||
perm.extend(2..2 + spatial);
|
||||
perm.push(1);
|
||||
perm.extend(rank..2 * rank);
|
||||
let permuted = unfolded.permute(perm);
|
||||
|
||||
let output_spatial_dims: Vec<Expression> = permuted.dims()[1..1 + spatial].to_vec();
|
||||
|
||||
// Merge all channel+kernel dims into [N, spatial..., ch_in * kernel_product]
|
||||
let mut patches = permuted;
|
||||
let target = 2 + spatial;
|
||||
while patches.dims().len() > target {
|
||||
let last = patches.dims().len();
|
||||
patches = patches.merge_dims(last - 2, last - 1);
|
||||
}
|
||||
|
||||
// Merge spatial dims into one
|
||||
for _ in 1..spatial {
|
||||
patches = patches.merge_dims(1, 2);
|
||||
}
|
||||
// patches: [N, spatial_product, ch_in * kernel_product]
|
||||
|
||||
let mut out = patches.matmul(w_flat.permute((1, 0)));
|
||||
// out: [N, spatial_product, ch_out]
|
||||
|
||||
// Restore spatial dimensions
|
||||
for i in (1..spatial).rev() {
|
||||
out = out.split_dims(1, output_spatial_dims[i]);
|
||||
}
|
||||
|
||||
// Move ch_out from last to position 1: [N, ch_out, spatial...]
|
||||
let mut final_order: Vec<usize> = Vec::with_capacity(2 + spatial);
|
||||
final_order.push(0);
|
||||
final_order.push(1 + spatial);
|
||||
final_order.extend(1..1 + spatial);
|
||||
out.permute(final_order)
|
||||
}
|
||||
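The unfold-then-matmul idea behind `conv_unfold` above can be sketched for the 1D, single-batch case on plain slices. This is a rough illustration under stated assumptions, not luminal code: `conv_out_len` and `conv1d_unfold` are hypothetical helpers, buffers are row-major, padding is symmetric and zero-filled, and the kernel is assumed to fit inside the padded input.

```rust
/// Output length of a convolution along one spatial axis (symmetric padding).
fn conv_out_len(input: usize, kernel: usize, stride: usize, pad: usize, dilation: usize) -> usize {
    (input + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1
}

/// 1D convolution via unfold (im2col) followed by a matmul.
/// x: [ch_in, len], w_flat: [ch_out, ch_in * k], returns [ch_out, out_len].
fn conv1d_unfold(
    x: &[f32],
    ch_in: usize,
    len: usize,
    w_flat: &[f32],
    ch_out: usize,
    k: usize,
    stride: usize,
    pad: usize,
    dilation: usize,
) -> Vec<f32> {
    let out_len = conv_out_len(len, k, stride, pad, dilation);
    let inner = ch_in * k;
    // patches: [out_len, ch_in * k]; taps that land in the padding stay zero.
    let mut patches = vec![0.0f32; out_len * inner];
    for o in 0..out_len {
        for c in 0..ch_in {
            for t in 0..k {
                let pos = (o * stride + t * dilation) as isize - pad as isize;
                if pos >= 0 && (pos as usize) < len {
                    patches[o * inner + c * k + t] = x[c * len + pos as usize];
                }
            }
        }
    }
    // out[co, o] = sum_i patches[o, i] * w_flat[co, i]
    let mut out = vec![0.0f32; ch_out * out_len];
    for co in 0..ch_out {
        for o in 0..out_len {
            out[co * out_len + o] = (0..inner)
                .map(|i| patches[o * inner + i] * w_flat[co * inner + i])
                .sum::<f32>();
        }
    }
    out
}
```

The patch matrix repeats input elements once per overlapping window, which is the memory cost paid for turning the convolution into a single matmul.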
|
||||
/// Depthwise convolution: groups == in_channels, ch_per_group == 1.
|
||||
///
|
||||
/// Processes all channels simultaneously using element-wise multiply + reduce,
|
||||
/// avoiding per-channel input slicing, which can cause index-expression bugs in luminal.
|
||||
///
|
||||
/// out[n, c, oh, ow] = sum_k patches[n, c, oh, ow, k] * weight[c, k]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn depthwise_conv(
|
||||
x: GraphTensor,
|
||||
w: GraphTensor, // [C * group_out, 1, *kernel]
|
||||
kernel_shape: &[usize],
|
||||
strides: &[usize],
|
||||
dilations: &[usize],
|
||||
pads_begin: &[usize],
|
||||
pads_end: &[usize],
|
||||
ch: usize,
|
||||
group_out: usize,
|
||||
kernel_product: usize,
|
||||
spatial: usize,
|
||||
) -> GraphTensor {
|
||||
let rank = 2 + spatial;
|
||||
|
||||
let needs_pad = pads_begin.iter().any(|&p| p > 0) || pads_end.iter().any(|&p| p > 0);
|
||||
let padded = if needs_pad {
|
||||
let mut padding: Vec<(Expression, Expression)> = vec![(0.into(), 0.into()); rank];
|
||||
for i in 0..spatial {
|
||||
padding[2 + i] = (pads_begin[i].into(), pads_end[i].into());
|
||||
}
|
||||
x.pad(padding, 0.0)
|
||||
} else {
|
||||
x
|
||||
};
|
||||
|
||||
// Unfold the full [N, C, H+2p, W+2p] with kernel [1, 1, kH, kW]
|
||||
let mut kernel_full = vec![1usize; rank];
|
||||
let mut stride_full = vec![1usize; rank];
|
||||
let mut dilation_full = vec![1usize; rank];
|
||||
kernel_full[2..(spatial + 2)].copy_from_slice(&kernel_shape[..spatial]);
|
||||
stride_full[2..(spatial + 2)].copy_from_slice(&strides[..spatial]);
|
||||
dilation_full[2..(spatial + 2)].copy_from_slice(&dilations[..spatial]);
|
||||
|
||||
let unfolded = padded.unfold(kernel_full, stride_full, dilation_full);
|
||||
// Shape: [N, C, out_H, out_W, 1, 1, kH, kW]
|
||||
|
||||
// Permute to [N, C, out_spatial..., k_all...]
|
||||
let mut perm: Vec<usize> = Vec::with_capacity(2 * rank);
|
||||
perm.push(0); // N
|
||||
perm.push(1); // C
|
||||
perm.extend(2..2 + spatial); // win_spatial
|
||||
perm.extend(rank..2 * rank); // all kernel dims
|
||||
let permuted = unfolded.permute(perm);
|
||||
|
||||
let out_spatial_dims: Vec<Expression> = permuted.dims()[2..2 + spatial].to_vec();
|
||||
|
||||
// Merge all kernel dims (including 1-size k_N, k_C) into kernel_product
|
||||
let target = 3 + spatial; // [N, C, spatial..., K]
|
||||
let mut patches = permuted;
|
||||
while patches.dims().len() > target {
|
||||
let last = patches.dims().len();
|
||||
patches = patches.merge_dims(last - 2, last - 1);
|
||||
}
|
||||
// patches: [N, C, out_H, ..., out_W, kernel_product]
|
||||
|
||||
// Merge spatial into one: [N, C, out_spatial_product, kernel_product]
|
||||
for _ in 1..spatial {
|
||||
patches = patches.merge_dims(2, 3);
|
||||
}
|
||||
|
||||
// Weight [C * group_out, 1, *kernel] -> [C, group_out, kernel_product]
|
||||
let mut w_flat = w;
|
||||
w_flat.shape =
|
||||
ShapeTracker::new_with_element_bits(vec![ch, group_out, kernel_product], w.dtype.bits());
|
||||
|
||||
// patches: [N, C, out_spatial_product, kernel_product]
|
||||
// Expand to [N, C, group_out, out_spatial_product, kernel_product]
|
||||
let patches = patches.expand_dim(2, group_out);
|
||||
|
||||
// Expand weight for broadcast: [1, C, group_out, out_spatial_product, kernel_product]
|
||||
let w_expanded = w_flat.expand_dim(0, 1).expand_dim(3, patches.dims()[3]);
|
||||
|
||||
// Element-wise multiply and sum over kernel dim
|
||||
let product = patches * w_expanded;
|
||||
let mut out = product.sum(vec![4]).merge_dims(1, 2);
|
||||
// out: [N, C * group_out, out_spatial_product]
|
||||
|
||||
// Restore spatial dimensions
|
||||
for i in (1..spatial).rev() {
|
||||
out = out.split_dims(2, out_spatial_dims[i]);
|
||||
}
|
||||
// out: [N, C * group_out, out_spatial_0, ..., out_spatial_{s-1}]
|
||||
|
||||
out
|
||||
}
|
||||
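The multiply-and-reduce formulation used by `depthwise_conv` above, including the channel-multiplier case, can be sketched in 1D as follows. This is an assumption-laden illustration rather than the repository's code: `depthwise_conv1d` is a hypothetical helper, there is no batch axis, padding, or dilation, and `mult` plays the role of `group_out`.

```rust
/// Depthwise 1D convolution with a channel multiplier (no padding, no dilation).
/// x: [ch, len], w: [ch * mult, k] with the size-1 input-channel axis dropped,
/// returns [ch * mult, out_len], all row-major.
fn depthwise_conv1d(
    x: &[f32],
    ch: usize,
    len: usize,
    w: &[f32],
    mult: usize,
    k: usize,
    stride: usize,
) -> Vec<f32> {
    let out_len = (len - k) / stride + 1;
    let mut out = vec![0.0f32; ch * mult * out_len];
    for c in 0..ch {
        for m in 0..mult {
            // Output channel `oc` reads only input channel `c`.
            let oc = c * mult + m;
            for o in 0..out_len {
                out[oc * out_len + o] = (0..k)
                    .map(|t| x[c * len + o * stride + t] * w[oc * k + t])
                    .sum::<f32>();
            }
        }
    }
    out
}
```

The output channel index `c * mult + m` mirrors viewing the weight as `[C, group_out, kernel_product]` and then merging the first two axes of the result.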
@@ -66,72 +66,74 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
"torch.ops.aten.sigmoid.default" => self.translate_unary_op(node, |a| a.sigmoid())?,
|
||||
"torch.ops.aten.relu.default" => self.translate_unary_op(node, |a| a.relu())?,
|
||||
"torch.ops.aten.silu.default" => self.translate_unary_op(node, |a| a.swish())?,
|
||||
"torch.ops.aten.tanh.default" => self.translate_unary_op(node, |a| a.tanh())?,
|
||||
"torch.ops.aten.abs.default" => self.translate_unary_op(node, |a| a.abs())?,
|
||||
"torch.ops.aten.log.default" => self.translate_unary_op(node, |a| a.log())?,
|
||||
"torch.ops.aten.log2.default" => self.translate_unary_op(node, |a| a.log2())?,
|
||||
"torch.ops.aten.exp2.default" => self.translate_unary_op(node, |a| a.exp2())?,
|
||||
|
||||
// Cast
|
||||
"torch.ops.aten._to_copy.default" => self.translate_to_copy(node)?,
|
||||
"torch.ops.aten.to.dtype" => self.translate_to_dtype(node)?,
|
||||
"torch.ops.aten.to.dtype_layout" => self.translate_to_dtype_layout(node)?,
|
||||
|
||||
// No-op
|
||||
"torch.ops.aten.alias.default" => self.get_input_tensor(node, 0)?,
|
||||
// No-op pass-throughs
|
||||
"torch.ops.aten.alias.default"
|
||||
| "torch.ops.aten.detach_.default"
|
||||
| "torch.ops.aten.lift_fresh_copy.default" => self.get_input_tensor(node, 0)?,
|
||||
"torch.ops.aten.dropout.default" => self.get_input_tensor(node, 0)?,
|
||||
|
||||
// Shape ops
|
||||
"torch.ops.aten.view.default" => self.translate_reshape(node)?,
|
||||
"torch.ops.aten.view.default"
|
||||
| "torch.ops.aten.reshape.default"
|
||||
| "torch.ops.aten._unsafe_view.default" => self.translate_reshape(node)?,
|
||||
"torch.ops.aten.permute.default" => self.translate_permute(node)?,
|
||||
"torch.ops.aten.transpose.int" => self.translate_transpose(node)?,
|
||||
"torch.ops.aten.t.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
a.t()
|
||||
}
|
||||
"torch.ops.aten.unsqueeze.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len() + 1);
|
||||
a.unsqueeze(dim)
|
||||
}
|
||||
"torch.ops.aten.squeeze.dims" => {
|
||||
"torch.ops.aten.squeeze.dim" | "torch.ops.aten.squeeze.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dims = self.get_ints_arg(node, 1)?;
|
||||
let ndim = a.shape.len();
|
||||
let mut sorted_dims: Vec<usize> =
|
||||
dims.iter().map(|&d| normalize_dim(d, ndim)).collect();
|
||||
sorted_dims.sort();
|
||||
let mut result = a;
|
||||
let mut offset = 0;
|
||||
for d in sorted_dims {
|
||||
if result.shape.dims[d - offset].to_usize() == Some(1) {
|
||||
result = result.squeeze(d - offset);
|
||||
offset += 1;
|
||||
if node.inputs.len() > 1 {
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
a.squeeze(dim)
|
||||
} else {
|
||||
let mut result = a;
|
||||
let dims = a.shape.dims;
|
||||
let mut offset = 0;
|
||||
for (i, d) in dims.iter().enumerate() {
|
||||
if d.to_usize() == Some(1) {
|
||||
result = result.squeeze(i - offset);
|
||||
offset += 1;
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
result
|
||||
}
|
||||
"torch.ops.aten.expand.default" => self.translate_expand(node)?,
|
||||
"torch.ops.aten.clone.default" => {
|
||||
"torch.ops.aten.contiguous.default" | "torch.ops.aten.clone.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
if !a.shape.is_contiguous() { a + 0.0 } else { a }
|
||||
}
|
||||
|
||||
// Matmul
|
||||
"torch.ops.aten.mm.default" | "torch.ops.aten.bmm.default" => {
|
||||
"torch.ops.aten.mm.default"
|
||||
| "torch.ops.aten.bmm.default"
|
||||
| "torch.ops.aten.matmul.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
a.matmul(b)
|
||||
}
|
||||
|
||||
// addmm: beta*input + alpha*(mat1 @ mat2)
|
||||
"torch.ops.aten.addmm.default" => {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let mat1 = self.get_input_tensor(node, 1)?;
|
||||
let mat2 = self.get_input_tensor(node, 2)?;
|
||||
let beta = self.get_float_arg(node, 3).unwrap_or(1.0) as f32;
|
||||
let alpha = self.get_float_arg(node, 4).unwrap_or(1.0) as f32;
|
||||
let mm = mat1.matmul(mat2);
|
||||
let (input, mm) = broadcast_binary(input, mm);
|
||||
input * beta + mm * alpha
|
||||
}
|
||||
|
||||
// Convolution
|
||||
"torch.ops.aten.convolution.default" => self.translate_conv(node)?,
|
||||
// Linear
|
||||
"torch.ops.aten.linear.default" => self.translate_linear(node)?,
|
||||
|
||||
// Reduction ops
|
||||
"torch.ops.aten.sum.dim_IntList" => self.translate_reduction(node, ReductionOp::Sum)?,
|
||||
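The offset bookkeeping in the `squeeze` arms of this hunk (each removed axis shifts the later indices left by one) can be sketched on a bare shape vector. This is only a sketch: `squeeze_dims` is a hypothetical helper that operates on `Vec<usize>` rather than on a luminal tensor, and the `dedup` call is an addition of the sketch so that a repeated axis cannot underflow the index.

```rust
/// Remove the requested size-1 axes from `shape`. Negative indices count from the end.
/// Axes whose extent is not 1 are left in place.
fn squeeze_dims(shape: &[usize], dims: &[i64]) -> Vec<usize> {
    let ndim = shape.len() as i64;
    let mut axes: Vec<usize> = dims
        .iter()
        .map(|&d| if d < 0 { (d + ndim) as usize } else { d as usize })
        .collect();
    axes.sort();
    axes.dedup(); // sketch-only guard against repeated axes
    let mut out = shape.to_vec();
    let mut offset = 0; // number of axes already removed to the left
    for d in axes {
        if out[d - offset] == 1 {
            out.remove(d - offset);
            offset += 1;
        }
    }
    out
}
```

Sorting first is what makes the single running `offset` sufficient: every axis already removed lies to the left of the one currently being examined.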
@@ -140,14 +142,16 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Slice/index ops
|
||||
"torch.ops.aten.slice.Tensor" => self.translate_slice(node)?,
|
||||
"torch.ops.aten.select.int" => self.translate_select(node)?,
|
||||
"torch.ops.aten.cat.default" => self.translate_cat(node)?,
|
||||
"torch.ops.aten.index_select.default" => self.translate_index_select(node)?,
|
||||
"torch.ops.aten.index.Tensor" => self.translate_index_tensor(node)?,
|
||||
|
||||
// Embedding
|
||||
"torch.ops.aten.embedding.default" => self.translate_embedding(node)?,
|
||||
|
||||
// Softmax
|
||||
"torch.ops.aten._softmax.default" => {
|
||||
"torch.ops.aten._softmax.default" | "torch.ops.aten.softmax.int" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
@@ -155,10 +159,11 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
// LayerNorm
|
||||
"torch.ops.aten.native_layer_norm.default" => self.translate_layer_norm(node)?,
|
||||
"torch.ops.aten.layer_norm.default" => self.translate_layer_norm(node)?,
|
||||
|
||||
// Where
|
||||
"torch.ops.aten.where.self" => self.translate_where(node)?,
|
||||
"torch.ops.aten.where.ScalarOther" => self.translate_where_scalar_other(node)?,
|
||||
|
||||
// Pow
|
||||
"torch.ops.aten.pow.Tensor_Scalar" => {
|
||||
@@ -174,12 +179,18 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
// Creation ops
|
||||
"torch.ops.aten.arange.start_step" => self.translate_arange(node)?,
|
||||
"torch.ops.aten.full.default" => self.translate_full(node)?,
|
||||
"torch.ops.aten.scalar_tensor.default" => {
|
||||
let val = self.get_float_arg(node, 0)? as f32;
|
||||
self.graph.constant_float(val)
|
||||
"torch.ops.aten.arange.default" | "torch.ops.aten.arange.start" => {
|
||||
self.translate_arange(node)?
|
||||
}
|
||||
"torch.ops.aten.full.default" => self.translate_full(node)?,
|
||||
"torch.ops.aten.zeros.default" | "torch.ops.aten.zeros_like.default" => {
|
||||
self.translate_zeros(node)?
|
||||
}
|
||||
"torch.ops.aten.ones.default" | "torch.ops.aten.ones_like.default" => {
|
||||
self.translate_ones(node)?
|
||||
}
|
||||
"torch.ops.aten.new_ones.default" => self.translate_new_ones(node)?,
|
||||
|
||||
// Scalar comparisons
|
||||
"torch.ops.aten.gt.Scalar" => self.translate_scalar_comparison(node, |a, s| a.gt(s))?,
|
||||
"torch.ops.aten.lt.Scalar" => self.translate_scalar_comparison(node, |a, s| a.lt(s))?,
|
||||
@@ -211,7 +222,7 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.le(b)
|
||||
}
|
||||
"torch.ops.aten.bitwise_and.Tensor" | "torch.ops.aten.logical_and.default" => {
|
||||
"torch.ops.aten.__and__.Tensor" | "torch.ops.aten.logical_and.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
@@ -237,7 +248,9 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
// Clamp
|
||||
"torch.ops.aten.clamp.default" => self.translate_clamp(node)?,
|
||||
"torch.ops.aten.clamp.default" | "torch.ops.aten.clamp_min.default" => {
|
||||
self.translate_clamp(node)?
|
||||
}
|
||||
|
||||
// Cumsum
|
||||
"torch.ops.aten.cumsum.default" => {
|
||||
@@ -252,6 +265,9 @@ impl<'a> Translator<'a> {
|
||||
a.cumsum(dim)
|
||||
}
|
||||
|
||||
// Diff
|
||||
"torch.ops.aten.diff.default" => self.translate_diff(node)?,
|
||||
|
||||
// Floor / Ceil / Erf (approximations)
|
||||
"torch.ops.aten.floor.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
@@ -336,12 +352,45 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.gt(b)
|
||||
}
|
||||
"torch.ops.aten.ne.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.ne(b)
|
||||
}
|
||||
|
||||
// Full-reduce variants (no dim arg) — handled by translate_reduction fallback
|
||||
"torch.ops.aten.sum.default" => self.translate_reduction(node, ReductionOp::Sum)?,
|
||||
"torch.ops.aten.mean.default" => self.translate_reduction(node, ReductionOp::Mean)?,
|
||||
"torch.ops.aten.max.default" => self.translate_reduction(node, ReductionOp::Max)?,
|
||||
"torch.ops.aten.min.default" => self.translate_reduction(node, ReductionOp::Min)?,
|
||||
// Reductions without dim arg (full reduce)
|
||||
// Flatten to [1, N] and reduce axis 1 to avoid multi-step HLIR
|
||||
// that CUDA can't schedule (grid (0,1,1) invalid launch).
|
||||
"torch.ops.aten.sum.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let total = concrete_numel(&a)?;
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
flat.sum(vec![1])
|
||||
}
|
||||
"torch.ops.aten.mean.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let total = concrete_numel(&a)?;
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
flat.sum(vec![1]) / total as f32
|
||||
}
|
||||
"torch.ops.aten.max.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let total = concrete_numel(&a)?;
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
flat.max(vec![1])
|
||||
}
|
||||
"torch.ops.aten.min.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let total = concrete_numel(&a)?;
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
flat.min(vec![1])
|
||||
}
|
||||
"torch.ops.aten.amin.default" => self.translate_reduction(node, ReductionOp::Min)?,
|
||||
|
||||
// Gather (axis-aware)
|
||||
@@ -349,7 +398,11 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Scatter ops
|
||||
"torch.ops.aten.scatter.src" => self.translate_scatter_src(node)?,
|
||||
"torch.ops.aten.index_put.default" => self.translate_index_put(node)?,
|
||||
"torch.ops.aten.index_put_.default" => self.translate_index_put(node)?,
|
||||
|
||||
// Triangular
|
||||
"torch.ops.aten.tril.default" => self.translate_tril(node)?,
|
||||
"torch.ops.aten.triu.default" => self.translate_triu(node)?,
|
||||
|
||||
// TopK — handles its own output storage, returns early
|
||||
"torch.ops.aten.topk.default" => {
|
||||
@@ -358,7 +411,12 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
// Split
|
||||
"torch.ops.aten.split_with_sizes.default" => self.translate_split_with_sizes(node)?,
|
||||
"torch.ops.aten.split.Tensor" | "torch.ops.aten.split_with_sizes.default" => {
|
||||
self.translate_split(node)?
|
||||
}
|
||||
|
||||
// One-hot
|
||||
"torch.ops.aten.one_hot.default" => self.translate_one_hot(node)?,
|
||||
|
||||
// Fmod
|
||||
"torch.ops.aten.fmod.Tensor" => {
|
||||
@@ -367,8 +425,12 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a % b
|
||||
}
|
||||
// Prod reduction
|
||||
"torch.ops.aten.prod.dim_int" => self.translate_reduction(node, ReductionOp::Prod)?,
|
||||
"torch.ops.aten.fmod.Scalar" | "torch.ops.aten.remainder.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
let b = self.graph.constant_float(val).expand_rhs(a.shape);
|
||||
a % b
|
||||
}
|
||||
|
||||
other => {
|
||||
bail!("Unsupported ATen op: {other}");
|
||||
@@ -382,6 +444,15 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute total element count, returning an error if any dimension is symbolic.
fn concrete_numel(a: &GraphTensor) -> Result<usize> {
    a.dims().iter().try_fold(1usize, |acc, d| {
        d.to_usize().map(|v| acc * v).ok_or_else(|| {
            anyhow::anyhow!("Full reduction requires concrete dimensions, got symbolic dim")
        })
    })
}
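The `try_fold` pattern in `concrete_numel` can be tried without Luminal. The sketch below is a standalone approximation, not the crate's code: it assumes a symbolic dimension can be modelled as `None` in a slice of `Option<usize>`, and it uses a `String` error in place of `anyhow`.

```rust
/// Multiply all dimensions together, failing on the first unknown one.
/// `None` stands in for a symbolic dimension (an assumption of this sketch).
fn concrete_numel(dims: &[Option<usize>]) -> Result<usize, String> {
    dims.iter().copied().try_fold(1usize, |acc, d| {
        d.map(|v| acc * v)
            .ok_or_else(|| "full reduction requires concrete dimensions".to_string())
    })
}
```

For example, `concrete_numel(&[Some(2), Some(3), Some(4)])` yields `Ok(24)`, while any `None` entry short-circuits the fold with an error.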
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
fn translate_scalar_comparison(
|
||||
&mut self,
|
||||
|
||||
23
crates/luminal_python/rust/src/translator/matmul.rs
Normal file
@@ -0,0 +1,23 @@
use anyhow::Result;
use luminal::prelude::*;

use crate::pt2_schema::*;
use crate::pt2_util::broadcast_binary;

use super::Translator;

impl<'a> Translator<'a> {
    pub(crate) fn translate_linear(&mut self, node: &Node) -> Result<GraphTensor> {
        let input = self.get_input_tensor(node, 0)?;
        let weight = self.get_input_tensor(node, 1)?;
        let result = input.matmul(weight.t());

        if node.inputs.len() > 2
            && let Ok(bias) = self.get_input_tensor(node, 2)
        {
            let (result, bias) = broadcast_binary(result, bias);
            return Ok(result + bias);
        }
        Ok(result)
    }
}
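The arithmetic behind `translate_linear` is `y = x * w^T + b`, with the bias optional. The following sketch shows that computation on plain nested `Vec`s. It is an illustration under stated assumptions: the `linear` function and its signature are hypothetical, and the `[out_features, in_features]` weight layout is inferred from the `weight.t()` call rather than confirmed by the source.

```rust
/// Compute `y = x * w^T + b` for row-major matrices stored as nested Vecs.
/// `x` is [batch, in_features]; `w` is [out_features, in_features] (assumed layout).
fn linear(x: &[Vec<f32>], w: &[Vec<f32>], bias: Option<&[f32]>) -> Vec<Vec<f32>> {
    x.iter()
        .map(|row| {
            w.iter()
                .enumerate()
                .map(|(j, w_row)| {
                    // Dot product of one input row with one weight row.
                    let dot: f32 = row.iter().zip(w_row).map(|(a, b)| a * b).sum();
                    // The bias is broadcast across the batch dimension.
                    dot + bias.map_or(0.0, |b| b[j])
                })
                .collect()
        })
        .collect()
}
```

Passing `None` for the bias corresponds to the `Ok(result)` path above.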
|
||||
@@ -3,8 +3,8 @@
|
||||
//! Walks the parsed PT2 graph and constructs an equivalent Luminal computation graph.
|
||||
|
||||
mod binary;
|
||||
mod conv;
|
||||
mod dispatch;
|
||||
mod matmul;
|
||||
mod movement;
|
||||
mod reduction;
|
||||
mod tensor;
|
||||
@@ -18,7 +18,6 @@ use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_parser::{InputKind, ParsedPT2, SymDimMap};
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util;
|
||||
|
||||
/// Result of translating a PT2 graph to a Luminal graph.
|
||||
pub struct TranslatedGraph {
|
||||
@@ -77,12 +76,7 @@ impl<'a> Translator<'a> {
|
||||
let output_names = self.parsed.output_names();
|
||||
for name in &output_names {
|
||||
let tensor = self.get_tensor(name)?;
|
||||
// Cast non-float outputs (Bool, Int) to F32 for the runtime.
|
||||
// Preserve F16/BF16/F32 as-is to avoid corrupting half-precision models.
|
||||
let tensor = match tensor.dtype {
|
||||
DType::Bool | DType::Int => tensor.cast(DType::F32) + 0.0,
|
||||
_ => tensor + 0.0,
|
||||
};
|
||||
let tensor = tensor + 0.0;
|
||||
tensor.output();
|
||||
self.output_ids.push((name.clone(), tensor.id));
|
||||
}
|
||||
@@ -103,12 +97,7 @@ impl<'a> Translator<'a> {
|
||||
.tensor_meta(graph_name)
|
||||
.with_context(|| format!("Missing tensor meta for param {graph_name}"))?;
|
||||
let shape = self.tensor_meta_to_shape(meta)?;
|
||||
let dtype = pt2_util::torch_dtype_int_to_luminal(meta.dtype);
|
||||
let tensor = self
|
||||
.graph
|
||||
.named_tensor(original_name, shape)
|
||||
.as_dtype(dtype);
|
||||
tensor.persist();
|
||||
let tensor = self.graph.named_tensor(original_name, shape);
|
||||
self.tensors.insert(graph_name.clone(), tensor);
|
||||
}
|
||||
InputKind::Buffer {
|
||||
@@ -120,12 +109,7 @@ impl<'a> Translator<'a> {
|
||||
.tensor_meta(graph_name)
|
||||
.with_context(|| format!("Missing tensor meta for buffer {graph_name}"))?;
|
||||
let shape = self.tensor_meta_to_shape(meta)?;
|
||||
let dtype = pt2_util::torch_dtype_int_to_luminal(meta.dtype);
|
||||
let tensor = self
|
||||
.graph
|
||||
.named_tensor(original_name, shape)
|
||||
.as_dtype(dtype);
|
||||
tensor.persist();
|
||||
let tensor = self.graph.named_tensor(original_name, shape);
|
||||
self.tensors.insert(graph_name.clone(), tensor);
|
||||
}
|
||||
InputKind::UserInput { graph_name } => {
|
||||
@@ -134,8 +118,7 @@ impl<'a> Translator<'a> {
|
||||
.tensor_meta(graph_name)
|
||||
.with_context(|| format!("Missing tensor meta for input {graph_name}"))?;
|
||||
let shape = self.tensor_meta_to_shape(meta)?;
|
||||
let dtype = pt2_util::torch_dtype_int_to_luminal(meta.dtype);
|
||||
let tensor = self.graph.named_tensor(graph_name, shape).as_dtype(dtype);
|
||||
let tensor = self.graph.named_tensor(graph_name, shape);
|
||||
self.user_input_ids.push((graph_name.clone(), tensor.id));
|
||||
self.tensors.insert(graph_name.clone(), tensor);
|
||||
}
|
||||
@@ -155,6 +138,13 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// --- Helper methods ---
|
||||
|
||||
    /// Look up tensor metadata by name, checking subgraph extras first.
    pub(crate) fn tensor_meta(&self, name: &str) -> Option<&TensorMeta> {
        self.extra_tensor_values
            .get(name)
            .or_else(|| self.parsed.tensor_meta(name))
    }
|
||||
|
||||
pub(crate) fn get_tensor(&self, name: &str) -> Result<GraphTensor> {
|
||||
self.tensors
|
||||
.get(name)
|
||||
|
||||
@@ -49,6 +49,15 @@ impl<'a> Translator<'a> {
|
||||
Ok(a.permute(axes))
|
||||
}
|
||||
|
||||
    pub(crate) fn translate_transpose(&mut self, node: &Node) -> Result<GraphTensor> {
        let a = self.get_input_tensor(node, 0)?;
        let dim0 = self.get_int_arg(node, 1)?;
        let dim1 = self.get_int_arg(node, 2)?;
        let dim0 = normalize_dim(dim0, a.shape.len());
        let dim1 = normalize_dim(dim1, a.shape.len());
        Ok(a.transpose(dim0, dim1))
    }
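`normalize_dim` comes from `pt2_util` and its body is not part of this diff. The sketch below assumes it follows the usual PyTorch convention, where a negative dimension counts from the end; both functions here are hypothetical stand-ins that operate on a plain shape vector rather than a `GraphTensor`.

```rust
/// Map a possibly negative PyTorch-style dimension onto `0..ndim`.
/// Assumes `-ndim <= dim < ndim`; out-of-range values are not checked here.
fn normalize_dim(dim: i64, ndim: usize) -> usize {
    if dim < 0 {
        (dim + ndim as i64) as usize
    } else {
        dim as usize
    }
}

/// Shape after swapping two dimensions, as a transpose would produce.
fn transpose_shape(shape: &[usize], dim0: i64, dim1: i64) -> Vec<usize> {
    let mut out = shape.to_vec();
    out.swap(
        normalize_dim(dim0, shape.len()),
        normalize_dim(dim1, shape.len()),
    );
    out
}
```

Normalizing both arguments first means `transpose(-1, 0)` and `transpose(2, 0)` take the same path on a rank-3 tensor.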
|
||||
|
||||
pub(crate) fn translate_expand(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let mut a = self.get_input_tensor(node, 0)?;
|
||||
let neg1_expr = Expression::from(-1i32);
|
||||
@@ -115,6 +124,20 @@ impl<'a> Translator<'a> {
|
||||
Ok(a.slice_along(start..end, dim))
|
||||
}
|
||||
|
||||
    pub(crate) fn translate_select(&mut self, node: &Node) -> Result<GraphTensor> {
        let a = self.get_input_tensor(node, 0)?;
        let dim = self.get_int_arg(node, 1)?;
        let dim = normalize_dim(dim, a.shape.len());
        let index = self.get_int_arg(node, 2)?;
        let index = if index < 0 {
            bail!("Negative select index not yet supported");
        } else {
            index as usize
        };

        Ok(a.slice_along(index..index + 1, dim).squeeze(dim))
    }
|
||||
|
||||
pub(crate) fn translate_cat(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let tensors: Vec<GraphTensor> = if let Some(names) = node.inputs[0].arg.as_tensors() {
|
||||
names
|
||||
@@ -161,6 +184,31 @@ impl<'a> Translator<'a> {
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
    pub(crate) fn translate_index_select(&mut self, node: &Node) -> Result<GraphTensor> {
        let a = self.get_input_tensor(node, 0)?;
        let dim = self.get_int_arg(node, 1)?;
        let dim = normalize_dim(dim, a.shape.len());
        let indices = self.get_input_tensor(node, 2)?.cast(DType::Int);
        let src_dims = a.shape.dims;
        let idx_len = indices.shape.dims[0];

        // Reshape 1D indices [K] → [1,..,K,..,1] with K at position `dim`
        let mut idx = indices;
        for _ in 0..dim {
            idx = idx.unsqueeze(0);
        }
        for _ in (dim + 1)..src_dims.len() {
            idx = idx.expand_dim(idx.shape.len(), Expression::from(1usize));
        }

        // Expand to output shape: src_dims with dim replaced by idx_len
        let mut target: Vec<Expression> = src_dims.to_vec();
        target[dim] = idx_len;
        idx.shape.expand(target);

        Ok(a.gather_elements(idx, dim))
    }
|
||||
|
||||
pub(crate) fn translate_embedding(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let weight = self.get_input_tensor(node, 0)?;
|
||||
let indices = self.get_input_tensor(node, 1)?;
|
||||
@@ -382,9 +430,9 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn translate_split_with_sizes(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
pub(crate) fn translate_split(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let sizes = self.get_ints_arg(node, 1)?;
|
||||
let split_size = self.get_int_arg(node, 1)? as usize;
|
||||
let dim = if node.inputs.len() > 2 {
|
||||
self.get_int_arg(node, 2).unwrap_or(0)
|
||||
} else {
|
||||
@@ -392,32 +440,35 @@ impl<'a> Translator<'a> {
|
||||
};
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
|
||||
let output_names: Vec<String> = node
|
||||
.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensors.as_ref())
|
||||
.map(|ts| ts.iter().map(|t| t.name.clone()).collect())
|
||||
.unwrap_or_else(|| {
|
||||
node.outputs
|
||||
.iter()
|
||||
.filter_map(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
|
||||
.collect()
|
||||
});
|
||||
let dim_size = a.shape.dims[dim];
|
||||
if let Some(total) = dim_size.to_usize() {
|
||||
// Collect output names from as_tensors (multi-output) or as_tensor (single)
|
||||
let output_names: Vec<String> = node
|
||||
.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensors.as_ref())
|
||||
.map(|ts| ts.iter().map(|t| t.name.clone()).collect())
|
||||
.unwrap_or_else(|| {
|
||||
node.outputs
|
||||
.iter()
|
||||
.filter_map(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
|
||||
.collect()
|
||||
});
|
||||
|
||||
let mut offset = 0usize;
|
||||
let mut first_chunk = None;
|
||||
for (i, &size) in sizes.iter().enumerate() {
|
||||
let size = size as usize;
|
||||
let chunk = a.slice_along(offset..offset + size, dim);
|
||||
if let Some(name) = output_names.get(i) {
|
||||
self.tensors.insert(name.clone(), chunk);
|
||||
// Store each chunk under its output name
|
||||
for (i, out_name) in output_names.iter().enumerate() {
|
||||
let start = i * split_size;
|
||||
let end = ((i + 1) * split_size).min(total);
|
||||
if start < total {
|
||||
let chunk = a.slice_along(start..end, dim);
|
||||
self.tensors.insert(out_name.clone(), chunk);
|
||||
}
|
||||
}
|
||||
if i == 0 {
|
||||
first_chunk = Some(chunk);
|
||||
}
|
||||
offset += size;
|
||||
|
||||
// Return the first chunk
|
||||
Ok(a.slice_along(0..split_size.min(total), dim))
|
||||
} else {
|
||||
Ok(a.slice_along(0..split_size, dim))
|
||||
}
|
||||
|
||||
first_chunk.ok_or_else(|| anyhow::anyhow!("split_with_sizes: empty sizes list"))
|
||||
}
|
||||
}
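The chunk boundaries used by `translate_split` (`start = i * split_size`, `end` clamped to the dimension size, chunks starting past the end skipped) can be isolated into a small helper. This is a sketch of that arithmetic only; `split_ranges` is a hypothetical name, and the real function slices a `GraphTensor` rather than returning ranges.

```rust
/// Half-open index ranges for splitting a dimension of length `total`
/// into chunks of `split_size`, producing at most `n_outputs` chunks.
/// The last chunk is shorter when `total` is not a multiple of `split_size`.
fn split_ranges(
    total: usize,
    split_size: usize,
    n_outputs: usize,
) -> Vec<std::ops::Range<usize>> {
    (0..n_outputs)
        .map(|i| (i * split_size, ((i + 1) * split_size).min(total)))
        // Skip outputs whose start would fall past the end of the dimension.
        .filter(|&(start, _)| start < total)
        .map(|(start, end)| start..end)
        .collect()
}
```

With `total = 10` and `split_size = 4`, three outputs cover `0..4`, `4..8`, and `8..10`.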
|
||||
|
||||
@@ -6,15 +6,6 @@ use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
/// Compute total element count, returning an error if any dimension is symbolic.
|
||||
fn concrete_numel(a: &GraphTensor) -> Result<usize> {
|
||||
a.dims().iter().try_fold(1usize, |acc, d| {
|
||||
d.to_usize().map(|v| acc * v).ok_or_else(|| {
|
||||
anyhow::anyhow!("Full reduction requires concrete dimensions, got symbolic dim")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_reduction(
|
||||
&mut self,
|
||||
@@ -22,42 +13,21 @@ impl<'a> Translator<'a> {
|
||||
op: ReductionOp,
|
||||
) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
|
||||
// Try to get dims arg; if missing or empty, fall back to full reduce
|
||||
let dims_result = self.get_ints_arg(node, 1);
|
||||
let (axes, keepdim) = match dims_result {
|
||||
Ok(ref dims) if !dims.is_empty() => {
|
||||
let ndim = a.shape.len();
|
||||
let axes: Vec<usize> = dims.iter().map(|&d| normalize_dim(d, ndim)).collect();
|
||||
let keepdim = if node.inputs.len() > 2 {
|
||||
self.get_bool_arg(node, 2).unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
(axes, keepdim)
|
||||
}
|
||||
_ => {
|
||||
// Full reduce: flatten to [1, N] and reduce axis 1
|
||||
let total = concrete_numel(&a)?;
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
let result = match op {
|
||||
ReductionOp::Sum => flat.sum(vec![1]),
|
||||
ReductionOp::Mean => flat.sum(vec![1]) / total as f32,
|
||||
ReductionOp::Max => flat.max(vec![1]),
|
||||
ReductionOp::Min => flat.min(vec![1]),
|
||||
ReductionOp::Prod => flat.prod(vec![1]),
|
||||
};
|
||||
return Ok(result);
|
||||
}
|
||||
let dims = self.get_ints_arg(node, 1)?;
|
||||
let keepdim = if node.inputs.len() > 2 {
|
||||
self.get_bool_arg(node, 2).unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
let ndim = a.shape.len();
|
||||
let axes: Vec<usize> = dims.iter().map(|&d| normalize_dim(d, ndim)).collect();
|
||||
|
||||
let mut result = match op {
|
||||
ReductionOp::Sum => a.sum(axes.clone()),
|
||||
ReductionOp::Mean => a.mean(axes.clone()),
|
||||
ReductionOp::Max => a.max(axes.clone()),
|
||||
ReductionOp::Min => a.min(axes.clone()),
|
||||
ReductionOp::Prod => a.prod(axes.clone()),
|
||||
};
|
||||
|
||||
if keepdim {
|
||||
|
||||
@@ -18,48 +18,139 @@ impl<'a> Translator<'a> {
|
||||
match positional_args.len() {
|
||||
0 => anyhow::bail!("arange: no positional args found"),
|
||||
1 => Ok(self.graph.arange(positional_args[0])),
|
||||
2 => Ok(self
|
||||
_ => Ok(self
|
||||
.graph
|
||||
.arange_options(positional_args[0], positional_args[1], 1)),
|
||||
_ => Ok(self.graph.arange_options(
|
||||
positional_args[0],
|
||||
positional_args[1],
|
||||
positional_args[2],
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn translate_full(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let shape = self.get_exprs_arg(node, 0)?;
|
||||
// fill_value can be float, int, or bool after decomposition
|
||||
let val = if let Ok(f) = self.get_float_arg(node, 1) {
|
||||
f as f32
|
||||
} else if let Ok(b) = self.get_bool_arg(node, 1) {
|
||||
if b { 1.0 } else { 0.0 }
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"full: unsupported fill value type: {:?}",
|
||||
node.inputs.get(1)
|
||||
);
|
||||
};
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
Ok(self.graph.constant_float(val).expand_rhs(shape))
|
||||
}
|
||||
|
||||
    pub(crate) fn translate_zeros(&mut self, node: &Node) -> Result<GraphTensor> {
        self.translate_constant_fill(node, 0.0)
    }

    pub(crate) fn translate_ones(&mut self, node: &Node) -> Result<GraphTensor> {
        self.translate_constant_fill(node, 1.0)
    }

    pub(crate) fn translate_new_ones(&mut self, node: &Node) -> Result<GraphTensor> {
        self.translate_constant_fill(node, 1.0)
    }

    fn translate_constant_fill(&mut self, node: &Node, val: f32) -> Result<GraphTensor> {
        let output_name = node
            .outputs
            .first()
            .and_then(|o| o.as_tensor.as_ref())
            .map(|t| t.name.clone())
            .unwrap_or_default();
        let meta = self
            .tensor_meta(&output_name)
            .context("Missing tensor meta for constant fill output")?;
        let shape = self.tensor_meta_to_shape(meta)?;
        if shape.is_empty() {
            Ok(self.graph.constant_float(val))
        } else {
            Ok(self.graph.constant_float(val).expand_rhs(shape))
        }
    }
|
||||
|
||||
pub(crate) fn translate_where(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let cond = self.get_input_tensor(node, 0)?;
|
||||
let x = self.get_input_tensor(node, 1)?;
|
||||
let y = self.get_input_tensor(node, 2)?;
|
||||
// Ensure x and y have the same dtype
|
||||
let (x, y) = ensure_same_dtype(x, y);
|
||||
// Broadcast all three tensors to a common shape first
|
||||
let (cond_b, x_b) = broadcast_binary(cond, x);
|
||||
let (cond_bc, y_b) = broadcast_binary(cond_b, y);
|
||||
let (x_bc, y_bc) = broadcast_binary(x_b, y_b);
|
||||
let c = cond_bc.cast(DType::F32);
|
||||
let x_f = x_bc.cast(DType::F32);
|
||||
let y_f = y_bc.cast(DType::F32);
|
||||
let one = self.graph.constant_float(1.0).expand_rhs(c.shape);
|
||||
Ok(c * x_f + (one - c) * y_f)
|
||||
Ok(c * x_bc + (one - c) * y_bc)
|
||||
}
|
||||
|
||||
    pub(crate) fn translate_where_scalar_other(&mut self, node: &Node) -> Result<GraphTensor> {
        let cond = self.get_input_tensor(node, 0)?;
        let x = self.get_input_tensor(node, 1)?;
        let other_val = self.get_float_arg(node, 2)? as f32;
        // Broadcast cond and x to a common shape
        let (cond_b, x_b) = broadcast_binary(cond, x);
        let c = cond_b.cast(DType::F32);
        let one = self.graph.constant_float(1.0).expand_rhs(c.shape);
        let other = self.graph.constant_float(other_val).expand_rhs(c.shape);
        Ok(c * x_b + (one - c) * other)
    }
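Both `where` translations express selection as the arithmetic blend `c * x + (1 - c) * y`, with the condition cast to 0.0 or 1.0. A minimal elementwise sketch of that formula follows; `where_blend` is a hypothetical helper on slices, and broadcasting is left out.

```rust
/// Elementwise `where(cond, x, y)` written as `c * x + (1 - c) * y`,
/// with `cond` converted to 0.0 / 1.0 first.
fn where_blend(cond: &[bool], x: &[f32], y: &[f32]) -> Vec<f32> {
    cond.iter()
        .zip(x.iter().zip(y))
        .map(|(&c, (&xv, &yv))| {
            let c = if c { 1.0f32 } else { 0.0 };
            c * xv + (1.0 - c) * yv
        })
        .collect()
}
```

One property of this formulation is that a non-finite value in the unselected branch still reaches the result, because `0.0 * inf` is NaN in IEEE 754 arithmetic. A true select would not have that behaviour.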
|
||||
|
||||
    pub(crate) fn translate_diff(&mut self, node: &Node) -> Result<GraphTensor> {
        let input = self.get_input_tensor(node, 0)?;
        let dim = if node.inputs.len() > 2 {
            self.get_int_arg(node, 2).unwrap_or(-1)
        } else {
            -1
        };
        let dim = normalize_dim(dim, input.shape.len());

        let prepend = if node.inputs.len() > 3 {
            self.get_input_tensor(node, 3).ok()
        } else {
            None
        };

        let x = if let Some(prep) = prepend {
            prep.concat_along(input, dim)
        } else {
            input
        };

        let dim_size = x.shape.dims[dim];
        let front = x.slice_along(Expression::from(1)..dim_size, dim);
        let back = x.slice_along(Expression::from(0)..dim_size - 1, dim);
        Ok(front - back)
    }
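`translate_diff` computes a first-order difference as `x[1..] - x[..n-1]` after optionally concatenating `prepend` in front of the input. The one-dimensional sketch below mirrors that "front minus back" structure on a plain slice; the `diff` helper and its signature are illustrative only.

```rust
/// First-order difference along a 1-D slice, with optional prepended values.
/// Pairs each element (front) with its predecessor (back) and subtracts.
fn diff(input: &[f32], prepend: &[f32]) -> Vec<f32> {
    let x: Vec<f32> = prepend.iter().chain(input).copied().collect();
    x.iter()
        .skip(1) // front: x[1..]
        .zip(x.iter()) // back: x[..n-1], truncated by zip
        .map(|(front, back)| front - back)
        .collect()
}
```

The output has one element fewer than the concatenated input, which matches the `1..dim_size` and `0..dim_size - 1` slices in the translation.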
|
||||
|
||||
    pub(crate) fn translate_tril(&mut self, node: &Node) -> Result<GraphTensor> {
        self.translate_triangular(node, false)
    }

    pub(crate) fn translate_triu(&mut self, node: &Node) -> Result<GraphTensor> {
        self.translate_triangular(node, true)
    }

    fn translate_triangular(&mut self, node: &Node, upper: bool) -> Result<GraphTensor> {
        let a = self.get_input_tensor(node, 0)?;
        let diagonal = if node.inputs.len() > 1 {
            self.get_int_arg(node, 1).unwrap_or(0) as i32
        } else {
            0
        };
        let dims = a.shape.dims;
        let rows = dims[dims.len() - 2];
        let cols = dims[dims.len() - 1];
        let (r_val, c_val) = match (rows.to_usize(), cols.to_usize()) {
            (Some(r), Some(c)) => (r, c),
            _ => anyhow::bail!("tril/triu requires concrete matrix dimensions"),
        };
        let size = r_val.max(c_val);
        let mask = if upper {
            self.graph.triu(size, diagonal)
        } else {
            self.graph.tril(size, diagonal)
        }
        .cast(DType::F32);
        let mask = if rows != cols {
            mask.slice_along(0..r_val, 0).slice_along(0..c_val, 1)
        } else {
            mask
        };
        let mut mask_expanded = mask;
        for i in (0..dims.len() - 2).rev() {
            mask_expanded = mask_expanded.expand_dim(0, dims[i]);
        }
        Ok(a * mask_expanded)
    }
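`translate_triangular` multiplies the input by a 0/1 mask. The sketch below shows the same idea on a single 2-D matrix, assuming the mask follows PyTorch's `tril`/`triu` convention (keep `col - row <= diagonal` for lower, `col - row >= diagonal` for upper). Whether Luminal's `graph.tril` and `graph.triu` use exactly this convention is not shown in the diff, so treat that as an assumption.

```rust
/// Zero out entries outside the lower (or upper) triangle of a 2-D matrix
/// by multiplying each entry with a 0/1 mask value.
fn triangular(a: &[Vec<f32>], diagonal: i64, upper: bool) -> Vec<Vec<f32>> {
    a.iter()
        .enumerate()
        .map(|(r, row)| {
            row.iter()
                .enumerate()
                .map(|(c, &v)| {
                    let offset = c as i64 - r as i64;
                    let keep = if upper {
                        offset >= diagonal
                    } else {
                        offset <= diagonal
                    };
                    v * (if keep { 1.0 } else { 0.0 })
                })
                .collect()
        })
        .collect()
}
```

Because the mask is computed per entry from row and column indices, non-square matrices need no special handling here, whereas the translation builds a square mask and slices it down.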
|
||||
|
||||
pub(crate) fn translate_topk(&mut self, node: &Node) -> Result<()> {
|
||||
@@ -109,6 +200,21 @@ impl<'a> Translator<'a> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
    pub(crate) fn translate_one_hot(&mut self, node: &Node) -> Result<GraphTensor> {
        let a = self.get_input_tensor(node, 0)?;
        let num_classes = self.get_int_arg(node, 1)? as usize;
        // one_hot: output[..., i] = 1 if input[...] == i else 0
        let a_int = a.cast(DType::Int);
        let classes = self.graph.arange(num_classes);
        // Expand a to [..., 1] and classes to [..., num_classes]
        let a_expanded = a_int.expand_dim(a.shape.len(), num_classes);
        let mut classes_expanded = classes;
        for d in a.shape.dims.iter().rev() {
            classes_expanded = classes_expanded.expand_dim(0, *d);
        }
        Ok(a_expanded.eq(classes_expanded).cast(DType::Int))
    }
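The one-hot translation compares each index against `arange(num_classes)` and casts the boolean result to an integer. A small sketch of that comparison for a 1-D index list follows; the `one_hot` helper is hypothetical and skips the broadcasting that the graph version performs.

```rust
/// One-hot encode a list of class indices:
/// output[i][c] = 1 if indices[i] == c, else 0.
fn one_hot(indices: &[usize], num_classes: usize) -> Vec<Vec<i32>> {
    indices
        .iter()
        .map(|&idx| (0..num_classes).map(|c| (c == idx) as i32).collect())
        .collect()
}
```

An index that is greater than or equal to `num_classes` produces an all-zero row in this sketch, since no class matches it.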
|
||||
|
||||
pub(crate) fn translate_wrap_set_grad(&mut self, node: &Node) -> Result<()> {
|
||||
let subgraph = node.inputs[1]
|
||||
.arg
|
||||
|
||||
@@ -29,6 +29,36 @@ impl<'a> Translator<'a> {
|
||||
Ok(a)
|
||||
}
|
||||
|
||||
    pub(crate) fn translate_to_dtype(&mut self, node: &Node) -> Result<GraphTensor> {
        let a = self.get_input_tensor(node, 0)?;
        if let Some(dtype_int) = node.inputs.get(1).and_then(|i| i.arg.as_scalar_type()) {
            let dtype = torch_dtype_int_to_luminal(dtype_int);
            Ok(a.cast(dtype))
        } else if let Some(dtype_int) = node.inputs.get(1).and_then(|i| i.arg.as_int()) {
            let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
            Ok(a.cast(dtype))
        } else {
            Ok(a)
        }
    }

    pub(crate) fn translate_to_dtype_layout(&mut self, node: &Node) -> Result<GraphTensor> {
        let a = self.get_input_tensor(node, 0)?;
        for input in &node.inputs {
            if input.name == "dtype" {
                if let Some(dtype_int) = input.arg.as_scalar_type() {
                    let dtype = torch_dtype_int_to_luminal(dtype_int);
                    return Ok(a.cast(dtype));
                }
                if let Some(dtype_int) = input.arg.as_int() {
                    let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
                    return Ok(a.cast(dtype));
                }
            }
        }
        Ok(a)
    }
|
||||
|
||||
pub(crate) fn translate_layer_norm(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let normalized_shape = self.get_ints_arg(node, 1)?;
|
||||
|
||||
@@ -1,352 +0,0 @@
|
||||
//! Dtype-aware buffer type for the luminal_python bridge.
|
||||
//!
|
||||
//! `TypedData` wraps raw bytes with a `DType` tag, enabling multi-dtype data flow
|
||||
//! through the PT2 path without forcing everything to f32.
|
||||
|
||||
use luminal::hlir::NativeData;
|
||||
use luminal::prelude::tracing::warn;
|
||||
use luminal::prelude::*;
|
||||
|
||||
/// A dtype-tagged byte buffer. All weight, constant, and input data flows through this type.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct TypedData {
|
||||
pub bytes: Vec<u8>,
|
||||
pub dtype: DType,
|
||||
}
|
||||
|
||||
impl TypedData {
|
||||
/// Wrap raw bytes with a dtype tag. Caller must ensure bytes are correctly formatted.
|
||||
pub fn from_raw(bytes: Vec<u8>, dtype: DType) -> Self {
|
||||
Self { bytes, dtype }
|
||||
}
|
||||
|
||||
/// Number of bytes in the buffer
|
||||
pub fn n_bytes(&self) -> usize {
|
||||
self.bytes.len()
|
||||
}
|
||||
|
||||
    /// Number of logical elements (for byte-aligned dtypes)
    pub fn n_elements(&self) -> usize {
        let bits = self.dtype.bits();
        if bits >= 8 {
            self.bytes.len() / (bits / 8)
        } else {
            // sub-byte types: multiple elements per byte
            self.bytes.len() * (8 / bits)
        }
    }
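The element-count rule in `n_elements` depends only on the buffer length and the bit width, so it can be checked in isolation. The free function below is a sketch with hypothetical parameters in place of `self.bytes.len()` and `self.dtype.bits()`; it assumes `bits` is non-zero.

```rust
/// Number of logical elements in a buffer of `n_bytes` bytes
/// holding values that are `bits` wide. `bits` must be non-zero.
fn n_elements(n_bytes: usize, bits: usize) -> usize {
    if bits >= 8 {
        // Byte-aligned types: each element spans bits / 8 bytes.
        n_bytes / (bits / 8)
    } else {
        // Sub-byte types: 8 / bits elements are packed into each byte.
        n_bytes * (8 / bits)
    }
}
```

For instance, 16 bytes hold four 32-bit values, while 4 bytes hold eight 4-bit values.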
|
||||
|
||||
/// Read element at `idx` as f64 (used by From<TypedData> for NativeData fallback).
|
||||
fn as_f64(&self, idx: usize) -> f64 {
|
||||
match self.dtype {
|
||||
DType::F32 => {
|
||||
let start = idx * 4;
|
||||
f32::from_le_bytes([
|
||||
self.bytes[start],
|
||||
self.bytes[start + 1],
|
||||
self.bytes[start + 2],
|
||||
self.bytes[start + 3],
|
||||
]) as f64
|
||||
}
|
||||
DType::F64 => {
|
||||
let start = idx * 8;
|
||||
f64::from_le_bytes([
|
||||
self.bytes[start],
|
||||
self.bytes[start + 1],
|
||||
self.bytes[start + 2],
|
||||
self.bytes[start + 3],
|
||||
self.bytes[start + 4],
|
||||
self.bytes[start + 5],
|
||||
self.bytes[start + 6],
|
||||
self.bytes[start + 7],
|
||||
])
|
||||
}
|
||||
DType::F16 => {
|
||||
let start = idx * 2;
|
||||
half::f16::from_le_bytes([self.bytes[start], self.bytes[start + 1]]).to_f64()
|
||||
}
|
||||
DType::Bf16 => {
|
||||
let start = idx * 2;
|
||||
half::bf16::from_le_bytes([self.bytes[start], self.bytes[start + 1]]).to_f64()
|
||||
}
|
||||
DType::Int => {
|
||||
let start = idx * 4;
|
||||
i32::from_le_bytes([
|
||||
self.bytes[start],
|
||||
self.bytes[start + 1],
|
||||
self.bytes[start + 2],
|
||||
self.bytes[start + 3],
|
||||
]) as f64
|
||||
}
|
||||
DType::I8 => self.bytes[idx] as i8 as f64,
|
||||
DType::U8 => self.bytes[idx] as f64,
|
||||
DType::I16 | DType::U16 => {
|
||||
let start = idx * 2;
|
||||
let val = i16::from_le_bytes([self.bytes[start], self.bytes[start + 1]]);
|
||||
if self.dtype == DType::U16 {
|
||||
val as u16 as f64
|
||||
} else {
|
||||
val as f64
|
||||
}
|
||||
}
|
||||
DType::Bool => {
|
||||
if self.bytes[idx] != 0 {
|
||||
1.0
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
_ => panic!("as_f64 not supported for {:?}", self.dtype),
|
||||
}
|
||||
}
|
||||
// -- Constructors from typed Vecs --
|
||||
|
||||
pub fn from_f32_vec(data: Vec<f32>) -> Self {
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 4).to_vec()
|
||||
};
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::F32,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_f16_vec(data: Vec<half::f16>) -> Self {
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 2).to_vec()
|
||||
};
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::F16,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_bf16_vec(data: Vec<half::bf16>) -> Self {
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 2).to_vec()
|
||||
};
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::Bf16,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_i32_vec(data: Vec<i32>) -> Self {
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 4).to_vec()
|
||||
};
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::Int,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_bool_vec(data: Vec<bool>) -> Self {
|
||||
let bytes: Vec<u8> = data.iter().map(|&b| b as u8).collect();
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::Bool,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert raw bytes from a PyTorch tensor (identified by PT2 dtype code) to TypedData
|
||||
/// in luminal's native format. Handles widening/narrowing conversions for types where
|
||||
/// PyTorch's byte layout differs from luminal's:
|
||||
/// - i64 → i32, f64 → f32 (luminal has no 64-bit types)
|
||||
/// - i16 → i32, u8 → i32, i8 → i32 (luminal maps all integer types to i32 for PT2)
|
||||
pub fn from_pytorch_bytes(bytes: Vec<u8>, dtype_code: u32) -> Self {
|
||||
match dtype_code {
|
||||
// Types that map directly — preserve raw bytes
|
||||
7 => Self::from_raw(bytes, DType::F32),
|
||||
6 => Self::from_raw(bytes, DType::F16),
|
||||
13 => Self::from_raw(bytes, DType::Bf16),
|
||||
4 => Self::from_raw(bytes, DType::Int), // i32
|
||||
12 => Self::from_raw(bytes, DType::Bool),
|
||||
// i64 → i32 (truncate)
|
||||
5 => {
|
||||
let i32s: Vec<i32> = bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as i32
|
||||
})
|
||||
.collect();
|
||||
Self::from_i32_vec(i32s)
|
||||
}
|
||||
// f64 → f32 (downcast)
|
||||
8 => {
|
||||
let f32s: Vec<f32> = bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
|
||||
})
|
||||
.collect();
|
||||
Self::from_f32_vec(f32s)
|
||||
}
|
||||
// i16 → i32 (widen)
|
||||
3 => {
|
||||
let i32s: Vec<i32> = bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| i16::from_le_bytes([b[0], b[1]]) as i32)
|
||||
.collect();
|
||||
Self::from_i32_vec(i32s)
|
||||
}
|
||||
// u8 → i32 (widen)
|
||||
1 => {
|
||||
let i32s: Vec<i32> = bytes.iter().map(|&b| b as i32).collect();
|
||||
Self::from_i32_vec(i32s)
|
||||
}
|
||||
// i8 → i32 (widen, signed)
|
||||
2 => {
|
||||
let i32s: Vec<i32> = bytes.iter().map(|&b| (b as i8) as i32).collect();
|
||||
Self::from_i32_vec(i32s)
|
||||
}
|
||||
// Unknown: best-effort pass-through as f32
|
||||
_ => {
|
||||
warn!("Unrecognized pytorch dtype code {dtype_code}, interpreting as f32");
|
||||
Self::from_raw(bytes, DType::F32)
|
||||
}
|
||||
}
|
||||
}
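The narrowing and widening arms above can be sketched in isolation. This is a minimal illustration, not the crate's API: `i64_bytes_to_i32` and `i8_bytes_to_i32` are hypothetical helper names, and the sketch assumes little-endian input as in the function above. It mainly shows that `as i32` on an `i64` keeps only the low 32 bits, so out-of-range values wrap rather than saturate.

```rust
/// Hypothetical helper: decode little-endian i64 bytes and narrow to i32.
/// `as i32` keeps the low 32 bits, so values outside the i32 range wrap.
fn i64_bytes_to_i32(bytes: &[u8]) -> Vec<i32> {
    bytes
        .chunks_exact(8)
        .map(|b| i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as i32)
        .collect()
}

/// Hypothetical helper: sign-extend raw i8 bytes to i32.
/// Casting through `i8` first is what preserves negative values.
fn i8_bytes_to_i32(bytes: &[u8]) -> Vec<i32> {
    bytes.iter().map(|&b| (b as i8) as i32).collect()
}
```

If silent wrapping is not acceptable for index-like tensors, a caller could use `i32::try_from` and handle the error instead; that is a design choice the original code does not make.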
|
||||
|
||||
/// Create an n-element buffer of "safe" dummy values (1.0 for floats, 1 for ints, true for bool).
|
||||
/// IMPORTANT: Must use 1, NOT 0. Zero inputs cause NaN in many ops (fmod, recip, log, etc.).
    /// Strictly, zeros produce NaN or infinity depending on the op (for example, `fmod` by zero gives NaN while `recip` gives infinity).
|
||||
pub fn ones(n_elements: usize, dtype: DType) -> Self {
|
||||
match dtype {
|
||||
DType::F32 | DType::TF32 => Self::from_f32_vec(vec![1.0f32; n_elements]),
|
||||
DType::F64 => {
|
||||
let data = vec![1.0f64; n_elements];
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 8).to_vec()
|
||||
};
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::F64,
|
||||
}
|
||||
}
|
||||
DType::F16 => Self::from_f16_vec(vec![half::f16::from_f32(1.0); n_elements]),
|
||||
DType::Bf16 => Self::from_bf16_vec(vec![half::bf16::from_f32(1.0); n_elements]),
|
||||
DType::Int => Self::from_i32_vec(vec![1i32; n_elements]),
|
||||
DType::I8 => Self::from_raw(vec![1u8; n_elements], DType::I8),
|
||||
DType::U8 => Self::from_raw(vec![1u8; n_elements], DType::U8),
|
||||
DType::I16 => {
|
||||
let data = vec![1i16; n_elements];
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 2).to_vec()
|
||||
};
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::I16,
|
||||
}
|
||||
}
|
||||
DType::U16 => {
|
||||
let data = vec![1u16; n_elements];
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 2).to_vec()
|
||||
};
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::U16,
|
||||
}
|
||||
}
|
||||
DType::Bool => Self::from_bool_vec(vec![true; n_elements]),
|
||||
_ => panic!("TypedData::ones not supported for {:?}", dtype),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert TypedData to NativeData for the native runtime.
|
||||
impl From<TypedData> for NativeData {
|
||||
fn from(td: TypedData) -> Self {
|
||||
match td.dtype {
|
||||
DType::F32 | DType::TF32 => {
|
||||
let data: Vec<f32> = td
|
||||
.bytes
|
||||
.chunks_exact(4)
|
||||
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
|
||||
.collect();
|
||||
NativeData::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
// Downcast f64 -> f32 for native runtime (which only has F32 variant for floats > 32-bit)
|
||||
let data: Vec<f32> = td
|
||||
.bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
|
||||
})
|
||||
.collect();
|
||||
NativeData::F32(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
let data: Vec<half::f16> = td
|
||||
.bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| half::f16::from_le_bytes([b[0], b[1]]))
|
||||
.collect();
|
||||
NativeData::F16(data)
|
||||
}
|
||||
DType::Bf16 => {
|
||||
let data: Vec<half::bf16> = td
|
||||
.bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| half::bf16::from_le_bytes([b[0], b[1]]))
|
||||
.collect();
|
||||
NativeData::Bf16(data)
|
||||
}
|
||||
DType::Int => {
|
||||
let data: Vec<i32> = td
|
||||
.bytes
|
||||
.chunks_exact(4)
|
||||
.map(|b| i32::from_le_bytes([b[0], b[1], b[2], b[3]]))
|
||||
.collect();
|
||||
NativeData::Int(data)
|
||||
}
|
||||
DType::Bool => {
|
||||
let data: Vec<bool> = td.bytes.iter().map(|&b| b != 0).collect();
|
||||
NativeData::Bool(data)
|
||||
}
|
||||
// Integer types that map to NativeData::Int
|
||||
DType::I8 => {
|
||||
let data: Vec<i32> = td.bytes.iter().map(|&b| b as i8 as i32).collect();
|
||||
NativeData::Int(data)
|
||||
}
|
||||
DType::U8 => {
|
||||
let data: Vec<i32> = td.bytes.iter().map(|&b| b as i32).collect();
|
||||
NativeData::Int(data)
|
||||
}
|
||||
DType::I16 => {
|
||||
let data: Vec<i32> = td
|
||||
.bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| i16::from_le_bytes([b[0], b[1]]) as i32)
|
||||
.collect();
|
||||
NativeData::Int(data)
|
||||
}
|
||||
DType::U16 => {
|
||||
let data: Vec<i32> = td
|
||||
.bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| u16::from_le_bytes([b[0], b[1]]) as i32)
|
||||
.collect();
|
||||
NativeData::Int(data)
|
||||
}
|
||||
// Sub-byte and F8 types: store as raw f32 for native runtime (best effort)
|
||||
_ => {
|
||||
// For exotic types, the native runtime can't handle them natively.
|
||||
// Store as f32 with element-wise conversion.
|
||||
let data: Vec<f32> = (0..td.n_elements()).map(|i| td.as_f64(i) as f32).collect();
|
||||
NativeData::F32(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert &TypedData to NativeData (clone the bytes).
|
||||
impl From<&TypedData> for NativeData {
|
||||
fn from(td: &TypedData) -> Self {
|
||||
td.clone().into()
|
||||
}
|
||||
}
|
||||
|
||||
// CUDA runtime conversion is implemented via ToCudaInput in runtime.rs
|
||||
// (behind the `cuda` feature gate) since it depends on cudarc types.
|
||||
465
crates/luminal_python/rust/src/util.rs
Normal file
@@ -0,0 +1,465 @@
|
||||
use std::{collections::HashMap, fs, path::Path};
|
||||
|
||||
use luminal::{prelude::GraphTensor, shape::Expression};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
/// Maps ONNX dim_param names (e.g. "seq_len") to luminal Expression variable chars ('a'..'w').
|
||||
pub type DimParamMap = HashMap<String, char>;
|
||||
|
||||
// Given a Value from the Onnx proto return its tensor Shape, if it exists
|
||||
// Note: sometimes PyTorch will create tensors with a 0 shape;
|
||||
// we might want to handle 0 shape and no shape as separate cases
|
||||
pub fn get_shape_for_onnx_value(value: &onnx_protobuf::ValueInfoProto) -> Vec<usize> {
|
||||
if let Some(type_proto) = value.type_.as_ref()
|
||||
&& let Some(onnx_protobuf::type_proto::Value::TensorType(tensor)) = &type_proto.value
|
||||
&& let Some(shape) = tensor.shape.as_ref()
|
||||
{
|
||||
// Scalar (0-dim) tensors have an empty dim list; represent as [1] in luminal
|
||||
if shape.dim.is_empty() {
|
||||
return vec![1];
|
||||
}
|
||||
return shape
|
||||
.dim
|
||||
.iter()
|
||||
.map(|dimension| {
|
||||
if let Some(onnx_protobuf::tensor_shape_proto::dimension::Value::DimValue(v)) =
|
||||
&dimension.value
|
||||
{
|
||||
*v as usize
|
||||
} else {
|
||||
1
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// Like `get_shape_for_onnx_value`, but returns `Vec<Expression>` with symbolic vars for DimParam dims.
|
||||
/// Allocates new variable chars in `dim_param_map` for unseen dim_param names.
|
||||
/// `next_char` is updated to the next available char after allocation.
|
||||
pub fn get_shape_for_onnx_value_expr(
|
||||
value: &onnx_protobuf::ValueInfoProto,
|
||||
dim_param_map: &mut DimParamMap,
|
||||
next_char: &mut char,
|
||||
) -> Vec<Expression> {
|
||||
if let Some(type_proto) = value.type_.as_ref()
|
||||
&& let Some(onnx_protobuf::type_proto::Value::TensorType(tensor)) = &type_proto.value
|
||||
&& let Some(shape) = tensor.shape.as_ref()
|
||||
{
|
||||
if shape.dim.is_empty() {
|
||||
return vec![Expression::from(1usize)];
|
||||
}
|
||||
return shape
|
||||
.dim
|
||||
.iter()
|
||||
.map(|dimension| match &dimension.value {
|
||||
Some(onnx_protobuf::tensor_shape_proto::dimension::Value::DimValue(v)) => {
|
||||
Expression::from(*v as usize)
|
||||
}
|
||||
Some(onnx_protobuf::tensor_shape_proto::dimension::Value::DimParam(name)) => {
|
||||
let ch = *dim_param_map.entry(name.clone()).or_insert_with(|| {
|
||||
let c = *next_char;
|
||||
*next_char = (c as u8 + 1) as char;
|
||||
c
|
||||
});
|
||||
Expression::from(ch)
|
||||
}
|
||||
_ => Expression::from(1usize),
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
|
||||
vec![]
|
||||
}
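The `DimParam` branch above allocates one variable char per distinct dim_param name. A minimal sketch of that allocation step on its own follows; `var_for_dim_param` is a hypothetical name used only for illustration, and the sketch assumes the same `HashMap<String, char>` layout as `DimParamMap`.

```rust
use std::collections::HashMap;

/// Hypothetical helper: return the variable char for `name`, allocating the
/// next free char the first time a name is seen.
fn var_for_dim_param(
    name: &str,
    map: &mut HashMap<String, char>,
    next_char: &mut char,
) -> char {
    *map.entry(name.to_string()).or_insert_with(|| {
        let c = *next_char;
        // Advance to the following ASCII letter for the next unseen name.
        *next_char = (c as u8 + 1) as char;
        c
    })
}
```

Repeated names reuse their existing char, so two inputs that share a dim_param such as "seq_len" end up with the same symbolic variable. Note that neither this sketch nor the function above checks that allocation stays within the documented 'a'..'w' range.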
|
||||
|
||||
/// Compute the broadcast output shape for two tensors using Expressions (numpy rules).
|
||||
pub fn compute_broadcast_shape_expr(a: &[Expression], b: &[Expression]) -> Vec<Expression> {
|
||||
let max_rank = a.len().max(b.len());
|
||||
let mut result = Vec::with_capacity(max_rank);
|
||||
|
||||
for i in 0..max_rank {
|
||||
let a_dim = if i < max_rank - a.len() {
|
||||
Expression::from(1usize)
|
||||
} else {
|
||||
a[i - (max_rank - a.len())]
|
||||
};
|
||||
let b_dim = if i < max_rank - b.len() {
|
||||
Expression::from(1usize)
|
||||
} else {
|
||||
b[i - (max_rank - b.len())]
|
||||
};
|
||||
|
||||
// If both are concrete, use max. If one is 1, use the other.
|
||||
// Otherwise, assume they match (same symbolic dim).
|
||||
let dim = match (a_dim.to_usize(), b_dim.to_usize()) {
|
||||
(Some(a_val), Some(b_val)) => Expression::from(a_val.max(b_val)),
|
||||
(Some(1), _) => b_dim,
|
||||
(_, Some(1)) => a_dim,
|
||||
_ => a_dim, // Both symbolic — assume compatible
|
||||
};
|
||||
result.push(dim);
|
||||
}
|
||||
result
|
||||
}
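For comparison, here is a sketch of the same right-aligned broadcast rule over concrete `usize` dims only. `broadcast_shape` is a hypothetical name, and unlike the function above (which takes the max of two concrete dims and assumes symbolic dims match), this version returns `None` when two dims are incompatible under numpy rules.

```rust
/// Hypothetical helper: numpy-style broadcast of two concrete shapes.
/// Returns None if a pair of dims differs and neither is 1.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let rank = a.len().max(b.len());
    let mut out = Vec::with_capacity(rank);
    for i in 0..rank {
        // Right-align the shapes: missing leading dims are treated as 1.
        let a_dim = if i < rank - a.len() { 1 } else { a[i - (rank - a.len())] };
        let b_dim = if i < rank - b.len() { 1 } else { b[i - (rank - b.len())] };
        let dim = match (a_dim, b_dim) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return None,
        };
        out.push(dim);
    }
    Some(out)
}
```

Whether to reject mismatches or silently take the max is a trade-off: rejecting surfaces malformed graphs early, while the permissive form tolerates shape metadata that is only approximately known at import time.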
|
||||
|
||||
/// Broadcast a tensor's shape to match a target Expression shape (numpy-style broadcasting).
|
||||
/// Left-pads with size-1 dims, then expands dims that are 1 to match target.
|
||||
pub fn broadcast_to_expr(mut tensor: GraphTensor, target_shape: &[Expression]) -> GraphTensor {
|
||||
let src_dims = tensor.dims();
|
||||
let src_len = src_dims.len();
|
||||
let tgt_len = target_shape.len();
|
||||
|
||||
if src_len == tgt_len {
|
||||
tensor.shape.expand(target_shape.to_vec());
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// Left-pad with size-1 dims
|
||||
for _ in 0..(tgt_len - src_len) {
|
||||
tensor = tensor.expand_dim(0, 1);
|
||||
}
|
||||
|
||||
tensor.shape.expand(target_shape.to_vec());
|
||||
tensor
|
||||
}
|
||||
|
||||
/// Convert inline data from a TensorProto to f32, based on data_type.
|
||||
/// Returns None if the tensor has no inline data (e.g. external storage).
|
||||
fn convert_inline_data(init: &onnx_protobuf::TensorProto) -> Option<Vec<f32>> {
|
||||
match init.data_type {
|
||||
1 => {
|
||||
// FLOAT
|
||||
if !init.float_data.is_empty() {
|
||||
return Some(init.float_data.clone());
|
||||
}
|
||||
if !init.raw_data.is_empty() {
|
||||
return Some(parse_raw_bytes_as_f32(&init.raw_data, 1));
|
||||
}
|
||||
}
|
||||
7 => {
|
||||
// INT64
|
||||
if !init.int64_data.is_empty() {
|
||||
return Some(init.int64_data.iter().map(|&v| v as f32).collect());
|
||||
}
|
||||
if !init.raw_data.is_empty() {
|
||||
return Some(parse_raw_bytes_as_f32(&init.raw_data, 7));
|
||||
}
|
||||
}
|
||||
6 => {
|
||||
// INT32
|
||||
if !init.int32_data.is_empty() {
|
||||
return Some(init.int32_data.iter().map(|&v| v as f32).collect());
|
||||
}
|
||||
if !init.raw_data.is_empty() {
|
||||
return Some(parse_raw_bytes_as_f32(&init.raw_data, 6));
|
||||
}
|
||||
}
|
||||
9 => {
|
||||
// BOOL
|
||||
if !init.raw_data.is_empty() {
|
||||
return Some(parse_raw_bytes_as_f32(&init.raw_data, 9));
|
||||
}
|
||||
if !init.int32_data.is_empty() {
|
||||
return Some(
|
||||
init.int32_data
|
||||
.iter()
|
||||
.map(|&v| if v != 0 { 1.0 } else { 0.0 })
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Fallback: try float_data or interpret raw_data as F32
|
||||
if !init.float_data.is_empty() {
|
||||
return Some(init.float_data.clone());
|
||||
}
|
||||
if !init.raw_data.is_empty() {
|
||||
return Some(parse_raw_bytes_as_f32(&init.raw_data, 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Parse a raw byte slice as f32 values, respecting the ONNX data_type.
|
||||
fn parse_raw_bytes_as_f32(bytes: &[u8], data_type: i32) -> Vec<f32> {
|
||||
match data_type {
|
||||
1 => bytes
|
||||
.chunks_exact(4)
|
||||
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
|
||||
.collect(),
|
||||
7 => bytes
|
||||
.chunks_exact(8)
|
||||
.map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32)
|
||||
.collect(),
|
||||
6 => bytes
|
||||
.chunks_exact(4)
|
||||
.map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
|
||||
.collect(),
|
||||
9 => bytes
|
||||
.iter()
|
||||
.map(|&b| if b != 0 { 1.0 } else { 0.0 })
|
||||
.collect(),
|
||||
_ => bytes
|
||||
.chunks_exact(4)
|
||||
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Load float data from a TensorProto, handling inline (float_data/raw_data) and external storage.
|
||||
/// Prefer `load_all_tensor_floats` for batch loading (avoids redundant file reads).
|
||||
#[allow(dead_code)]
|
||||
pub fn load_tensor_floats(init: &onnx_protobuf::TensorProto, model_dir: &Path) -> Option<Vec<f32>> {
|
||||
// Try inline data first
|
||||
if let Some(floats) = convert_inline_data(init) {
|
||||
return Some(floats);
|
||||
}
|
||||
    // Try external data (detected here by non-empty external_data entries)
|
||||
if !init.external_data.is_empty() {
|
||||
let mut location: Option<&str> = None;
|
||||
let mut offset: u64 = 0;
|
||||
let mut length: Option<u64> = None;
|
||||
for entry in &init.external_data {
|
||||
match entry.key.as_str() {
|
||||
"location" => location = Some(&entry.value),
|
||||
"offset" => offset = entry.value.parse().unwrap_or(0),
|
||||
"length" => length = entry.value.parse().ok(),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
if let Some(loc) = location {
|
||||
let ext_path = model_dir.join(loc);
|
||||
match fs::read(&ext_path) {
|
||||
Ok(file_data) => {
|
||||
let start = offset as usize;
|
||||
let end = match length {
|
||||
Some(len) => start + len as usize,
|
||||
None => file_data.len(),
|
||||
};
|
||||
                    if start > end || end > file_data.len() {
|
||||
return None;
|
||||
}
|
||||
return Some(parse_raw_bytes_as_f32(
|
||||
&file_data[start..end],
|
||||
init.data_type,
|
||||
));
|
||||
}
|
||||
Err(_) => {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Batch-load float data from multiple TensorProtos, reading each external file only once.
|
||||
/// Returns results in the same order as `inits`, with `None` for tensors that couldn't be loaded.
|
||||
pub fn load_all_tensor_floats(
|
||||
inits: &[onnx_protobuf::TensorProto],
|
||||
model_dir: &Path,
|
||||
) -> Vec<(String, Option<Vec<f32>>)> {
|
||||
let mut results: Vec<(String, Option<Vec<f32>>)> = Vec::with_capacity(inits.len());
|
||||
|
||||
// Pending external data entries: (result_index, offset, length, data_type)
|
||||
// grouped by file location
|
||||
type ExternalEntry = (usize, u64, Option<u64>, i32);
|
||||
let mut external_pending: HashMap<String, Vec<ExternalEntry>> = HashMap::new();
|
||||
|
||||
for (i, init) in inits.iter().enumerate() {
|
||||
// Try inline data first
|
||||
if let Some(floats) = convert_inline_data(init) {
|
||||
results.push((init.name.clone(), Some(floats)));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check for external data
|
||||
if !init.external_data.is_empty() {
|
||||
let mut location: Option<String> = None;
|
||||
let mut offset: u64 = 0;
|
||||
let mut length: Option<u64> = None;
|
||||
for entry in &init.external_data {
|
||||
match entry.key.as_str() {
|
||||
"location" => location = Some(entry.value.clone()),
|
||||
"offset" => offset = entry.value.parse().unwrap_or(0),
|
||||
"length" => length = entry.value.parse().ok(),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
if let Some(loc) = location {
|
||||
// Push placeholder, will fill in later
|
||||
results.push((init.name.clone(), None));
|
||||
external_pending
|
||||
.entry(loc)
|
||||
.or_default()
|
||||
.push((i, offset, length, init.data_type));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
results.push((init.name.clone(), None));
|
||||
}
|
||||
|
||||
// Read each external file once and extract all tensor slices
|
||||
for (loc, entries) in &external_pending {
|
||||
let ext_path = model_dir.join(loc);
|
||||
let file_data = match fs::read(&ext_path) {
|
||||
Ok(data) => data,
|
||||
Err(_) => continue, // results already have None
|
||||
};
|
||||
for &(idx, offset, length, data_type) in entries {
|
||||
let start = offset as usize;
|
||||
let end = match length {
|
||||
Some(len) => start + len as usize,
|
||||
None => file_data.len(),
|
||||
};
|
||||
            if start > end || end > file_data.len() {
|
||||
continue;
|
||||
}
|
||||
results[idx].1 = Some(parse_raw_bytes_as_f32(&file_data[start..end], data_type));
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
}
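The offset/length slicing used for external tensor data can be written so that no out-of-range input panics. The sketch below is illustrative only; `external_slice` is a hypothetical name, and it assumes the same semantics as above (a missing length means "to the end of the file").

```rust
/// Hypothetical helper: select the byte range for one externally stored tensor.
/// Returns None instead of panicking when the range does not fit the file.
fn external_slice(file_data: &[u8], offset: u64, length: Option<u64>) -> Option<&[u8]> {
    let start = offset as usize;
    let end = match length {
        // checked_add guards against overflow for corrupt metadata.
        Some(len) => start.checked_add(len as usize)?,
        None => file_data.len(),
    };
    // `get` returns None when start > end or end > file_data.len().
    file_data.get(start..end)
}
```

Using `slice::get` with a range folds both bounds checks into one call, which also covers an offset that lies past the end of the file when no length is given.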
|
||||
|
||||
/// Load initializer data as f32 values, handling multiple ONNX data types.
|
||||
/// Used to seed known_values with small constant initializers for constant folding.
|
||||
pub fn load_initializer_as_f32(init: &onnx_protobuf::TensorProto) -> Option<Vec<f32>> {
|
||||
match init.data_type {
|
||||
1 => {
|
||||
// FLOAT
|
||||
if !init.float_data.is_empty() {
|
||||
Some(init.float_data.clone())
|
||||
} else if !init.raw_data.is_empty() {
|
||||
Some(
|
||||
init.raw_data
|
||||
.chunks_exact(4)
|
||||
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
7 => {
|
||||
// INT64
|
||||
if !init.int64_data.is_empty() {
|
||||
Some(init.int64_data.iter().map(|&v| v as f32).collect())
|
||||
} else if !init.raw_data.is_empty() {
|
||||
Some(
|
||||
init.raw_data
|
||||
.chunks_exact(8)
|
||||
.map(|c| {
|
||||
i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
|
||||
as f32
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
6 => {
|
||||
// INT32
|
||||
if !init.int32_data.is_empty() {
|
||||
Some(init.int32_data.iter().map(|&v| v as f32).collect())
|
||||
} else if !init.raw_data.is_empty() {
|
||||
Some(
|
||||
init.raw_data
|
||||
.chunks_exact(4)
|
||||
.map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
16 => {
|
||||
// BFLOAT16 — 2 bytes per element, upper 16 bits of f32
|
||||
if !init.raw_data.is_empty() {
|
||||
Some(
|
||||
init.raw_data
|
||||
.chunks_exact(2)
|
||||
.map(|c| {
|
||||
let bits = u16::from_le_bytes([c[0], c[1]]);
|
||||
f32::from_bits((bits as u32) << 16)
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
9 => {
|
||||
// BOOL — 1 byte per element, 0 → 0.0, non-zero → 1.0
|
||||
if !init.raw_data.is_empty() {
|
||||
Some(
|
||||
init.raw_data
|
||||
.iter()
|
||||
.map(|&b| if b != 0 { 1.0 } else { 0.0 })
|
||||
.collect(),
|
||||
)
|
||||
} else if !init.int32_data.is_empty() {
|
||||
Some(
|
||||
init.int32_data
|
||||
.iter()
|
||||
.map(|&v| if v != 0 { 1.0 } else { 0.0 })
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
11 => {
|
||||
// FLOAT64
|
||||
if !init.raw_data.is_empty() {
|
||||
Some(
|
||||
init.raw_data
|
||||
.chunks_exact(8)
|
||||
.map(|c| {
|
||||
f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
|
||||
as f32
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
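The BFLOAT16 arm above relies on bf16 being the upper 16 bits of an IEEE f32. A standalone sketch of that conversion follows; `bf16_bytes_to_f32` is a hypothetical helper name, and little-endian input is assumed as in the original.

```rust
/// Hypothetical helper: widen little-endian bf16 values to f32.
/// Shifting the 16 stored bits into the high half of a u32 gives the exact
/// f32 with the same sign, exponent, and (truncated) mantissa.
fn bf16_bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(2)
        .map(|c| {
            let bits = u16::from_le_bytes([c[0], c[1]]);
            f32::from_bits((bits as u32) << 16)
        })
        .collect()
}
```

Because `chunks_exact` drops any trailing partial chunk, a buffer with an odd byte count yields one fewer element rather than an error; callers that need to detect truncated data would have to check the length themselves.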
|
||||
|
||||
/// Get an integer attribute from a node, with a default value
|
||||
pub fn get_int_attr(node: &NodeProto, name: &str, default: i64) -> i64 {
|
||||
for attr in &node.attribute {
|
||||
if attr.name == name {
|
||||
return attr.i;
|
||||
}
|
||||
}
|
||||
default
|
||||
}
|
||||
|
||||
/// Get a string attribute from a node, with a default value
|
||||
pub fn get_str_attr(node: &NodeProto, name: &str, default: &str) -> String {
|
||||
for attr in &node.attribute {
|
||||
if attr.name == name {
|
||||
return String::from_utf8_lossy(&attr.s).into_owned();
|
||||
}
|
||||
}
|
||||
default.to_string()
|
||||
}
|
||||
|
||||
/// Get a float attribute from a node, with a default value
|
||||
pub fn get_float_attr(node: &NodeProto, name: &str, default: f32) -> f32 {
|
||||
for attr in &node.attribute {
|
||||
if attr.name == name {
|
||||
return attr.f;
|
||||
}
|
||||
}
|
||||
default
|
||||
}
|
||||
@@ -1,21 +1,23 @@
|
||||
"""Luminal Python bindings - PyTorch backend using Luminal."""
|
||||
|
||||
# Import Python components
|
||||
# Register DynamicCache pytree serialization once at import time
|
||||
from .cache_utils import _register_cache_serialization
|
||||
from .compiled_model import CompiledModel
|
||||
|
||||
# Import Rust extension components (built by maturin)
|
||||
# These are available directly in the package namespace
|
||||
from .luminal import CompiledGraph, process_pt2
|
||||
from .luminal import CompiledGraph, process_onnx, process_pt2
|
||||
from .main import luminal_backend
|
||||
|
||||
# Register DynamicCache pytree serialization once at import time
|
||||
from .cache_utils import _register_cache_serialization
|
||||
|
||||
_register_cache_serialization()
|
||||
|
||||
# Re-export everything for clean package interface
|
||||
__all__ = [
|
||||
"CompiledModel",
|
||||
"luminal_backend",
|
||||
"process_onnx",
|
||||
"CompiledGraph",
|
||||
"process_pt2",
|
||||
]
|
||||
|
||||
@@ -4,9 +4,6 @@ from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from .dtype_util import code_to_torch_dtype
|
||||
from .dtype_util import torch_dtype_code as _torch_dtype_code
|
||||
|
||||
|
||||
class CompiledModel:
|
||||
"""Wrapper around CompiledGraph that handles PyTorch tensor conversion."""
|
||||
@@ -17,7 +14,7 @@ class CompiledModel:
|
||||
"""Initialize with a compiled CompiledGraph from Rust.
|
||||
|
||||
Args:
|
||||
graph_result: The CompiledGraph from luminal_python.process_pt2()
|
||||
graph_result: The CompiledGraph from luminal_python.process_onnx() or process_pt2()
|
||||
weight_refs: List of PyTorch tensors to keep alive (prevents GC of shared weights)
|
||||
input_names: Override for user input names. If None, uses graph_result.input_names.
|
||||
user_indices: When torch.compile lifts model parameters into extra args,
|
||||
@@ -32,14 +29,6 @@ class CompiledModel:
|
||||
self._weight_refs = weight_refs or []
|
||||
self._user_indices = user_indices
|
||||
self._is_cuda = graph_result.backend == "cuda"
|
||||
# Expected input dtypes from graph (used to convert user inputs)
|
||||
input_dtype_codes = graph_result.input_dtypes
|
||||
self._input_dtypes = [
|
||||
code_to_torch_dtype(input_dtype_codes[i])
|
||||
if i < len(input_dtype_codes)
|
||||
else torch.float32
|
||||
for i in range(len(self._input_names))
|
||||
]
|
||||
|
||||
def set_dim(self, param_name: str, value: int) -> None:
|
||||
"""Set a dynamic dimension value by its param name."""
|
||||
@@ -81,78 +70,44 @@ class CompiledModel:
|
||||
input_shapes = [list(t.shape) for t in user_inputs]
|
||||
self._graph.auto_set_dims_from_input_shapes(input_shapes)
|
||||
|
||||
# Set user input data via pointer.
|
||||
# Convert to the graph's expected dtype so bytes match the Input node's dtype tag.
|
||||
# Set user input data via pointer (avoids Python list conversion).
|
||||
# For CUDA inputs, keep references alive so the caching allocator doesn't
|
||||
# recycle GPU memory before run() reads the pointers.
|
||||
_input_refs = []
|
||||
for name, tensor, expected_dtype in zip(
|
||||
self._input_names, user_inputs, self._input_dtypes
|
||||
):
|
||||
for name, tensor in zip(self._input_names, user_inputs):
|
||||
if self._is_cuda and tensor.is_cuda:
|
||||
t = tensor.detach().contiguous().to(expected_dtype)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
self._graph.set_input_device_ptr(name, t.data_ptr(), n_bytes)
|
||||
t = tensor.detach().contiguous().float()
|
||||
self._graph.set_input_device_ptr(name, t.data_ptr(), t.numel() * 4)
|
||||
_input_refs.append(t)
|
||||
else:
|
||||
t = tensor.detach().cpu().contiguous()
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
dtype_code = _torch_dtype_code(t.dtype)
|
||||
self._graph.set_input_from_ptr(name, t.data_ptr(), n_bytes, dtype_code)
|
||||
t = tensor.detach().cpu().contiguous().float()
|
||||
self._graph.set_input_from_ptr(name, t.data_ptr(), t.numel())
|
||||
|
||||
# Resolve output shapes before run() (needed for pre-allocation).
|
||||
# Run the graph
|
||||
self._graph.run()
|
||||
|
||||
# Get output shapes — resolve dynamically if needed
|
||||
if self._has_dynamic_dims:
|
||||
output_shapes = self._graph.resolve_output_shapes()
|
||||
else:
|
||||
output_shapes = self._output_shapes
|
||||
|
||||
output_dtype_codes = self._graph.output_dtypes
|
||||
|
||||
# CUDA zero-copy path: pre-allocate output tensors and register their device
|
||||
# pointers so the final kernel writes directly into PyTorch's buffer.
|
||||
_use_zero_copy = self._is_cuda and hasattr(self._graph, "set_output_device_ptr")
|
||||
output_tensors = []
|
||||
if _use_zero_copy:
|
||||
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
|
||||
out_dtype = (
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
out = torch.empty(shape, dtype=out_dtype, device=input_device)
|
||||
self._graph.set_output_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
output_tensors.append(out)
|
||||
|
||||
# Run the graph
|
||||
self._graph.run()
|
||||
|
||||
# Collect outputs
|
||||
if _use_zero_copy:
|
||||
# For aliased outputs that couldn't be zero-copied, fall back to DtoD copy.
|
||||
for name, out in zip(self._output_names, output_tensors):
|
||||
if not self._graph.output_is_zero_copy(name):
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
outputs = output_tensors
|
||||
else:
|
||||
# Native path: retrieve as f32, then convert to target dtype if needed.
|
||||
outputs = []
|
||||
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
|
||||
out_dtype = (
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
# Get outputs and convert back to PyTorch tensors on the same device as inputs.
|
||||
# For CUDA: DtoD copy avoids the DtoH + HtoD round-trip.
|
||||
outputs = []
|
||||
for name, shape in zip(self._output_names, output_shapes):
|
||||
if self._is_cuda and hasattr(self._graph, "copy_output_to_device_ptr"):
|
||||
out = torch.empty(shape, dtype=torch.float32, device=input_device)
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * 4
|
||||
)
|
||||
else:
|
||||
data = self._graph.get_output(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.float32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
.to(input_device)
|
||||
)
|
||||
outputs.append(out)
|
||||
outputs.append(out)
|
||||
|
||||
return tuple(outputs)
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
"""Shared dtype utility functions for the luminal Python Bridge"""
|
||||
|
||||
import torch
|
||||
|
||||
_TORCH_DTYPE_TO_CODE = {
|
||||
torch.uint8: 1,
|
||||
torch.int8: 2,
|
||||
torch.int16: 3,
|
||||
torch.int32: 4,
|
||||
torch.int64: 5,
|
||||
torch.float16: 6,
|
||||
torch.float32: 7,
|
||||
torch.float64: 8,
|
||||
torch.bool: 12,
|
||||
torch.bfloat16: 13,
|
||||
}
|
||||
|
||||
_CODE_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_DTYPE_TO_CODE.items()}
|
||||
|
||||
|
||||
def torch_dtype_code(dtype):
|
||||
"""Map torch.dtype to PT2 dtype integer code."""
|
||||
return _TORCH_DTYPE_TO_CODE.get(dtype, 7) # default to f32
|
||||
|
||||
|
||||
def code_to_torch_dtype(code):
|
||||
"""Map PT2 dtype integer code to torch.dtype."""
|
||||
return _CODE_TO_TORCH_DTYPE.get(code, torch.float32)
|
||||
@@ -1,11 +1,16 @@
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
|
||||
from .dtype_util import torch_dtype_code as _torch_dtype_code
|
||||
import luminal
|
||||
|
||||
from .compiled_model import CompiledModel
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared helpers (used by PT2 path and compiled_model)
|
||||
# Shared helpers (used by both ONNX and PT2 paths)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -18,8 +23,6 @@ def _detect_backend(example_inputs):
|
||||
def _collect_weight_pointers(weights, backend):
|
||||
"""Partition weight tensors into CUDA device pointers and CPU host pointers.
|
||||
|
||||
Preserves native dtype — no forced conversion to float32.
|
||||
|
||||
Args:
|
||||
weights: dict of name -> torch.Tensor
|
||||
backend: "cuda", "gpu", "cpu", or "native"
|
||||
@@ -28,28 +31,29 @@ def _collect_weight_pointers(weights, backend):
|
||||
(keep_alive, device_ptrs, cpu_ptrs) where:
|
||||
- keep_alive: list[Tensor] to prevent GC of shared weight memory
|
||||
- device_ptrs: {name: (device_ptr, n_bytes)}
|
||||
- cpu_ptrs: {name: (host_ptr, n_bytes, dtype_code)}
|
||||
- cpu_ptrs: {name: (host_ptr, n_elements)}
|
||||
"""
|
||||
keep_alive = []
|
||||
device_ptrs = {}
|
||||
cpu_ptrs = {}
|
||||
for name, tensor in weights.items():
|
||||
t = tensor.detach().contiguous()
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
if t.dtype != torch.float32:
|
||||
t = t.float()
|
||||
if backend in ("cuda", "gpu") and t.is_cuda:
|
||||
keep_alive.append(t)
|
||||
device_ptrs[name] = (t.data_ptr(), n_bytes)
|
||||
device_ptrs[name] = (t.data_ptr(), t.numel() * 4)
|
||||
else:
|
||||
t = t.cpu() if t.is_cuda else t
|
||||
keep_alive.append(t)
|
||||
cpu_ptrs[name] = (t.data_ptr(), n_bytes, _torch_dtype_code(t.dtype))
|
||||
cpu_ptrs[name] = (t.data_ptr(), t.numel())
|
||||
return keep_alive, device_ptrs, cpu_ptrs
|
||||
|
||||
|
||||
def _load_cpu_weights(compiled_graph, cpu_weights):
|
||||
"""Load CPU weight data into a compiled graph after Rust compilation."""
|
||||
for name, (ptr, n_bytes, dtype_code) in cpu_weights.items():
|
||||
compiled_graph.set_weight_from_ptr(name, ptr, n_bytes, dtype_code)
|
||||
for name, (ptr, n_elements) in cpu_weights.items():
|
||||
compiled_graph.set_weight_from_ptr(name, ptr, n_elements)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -62,9 +66,85 @@ def luminal_backend(gm, example_inputs, options=None):
|
||||
|
||||
Usage:
|
||||
torch.compile(model, backend=luminal_backend)
|
||||
torch.compile(model, backend=luminal_backend, options={"export_mode": "pt2"})
|
||||
|
||||
Options:
|
||||
export_mode: "onnx" (default) or "pt2"
|
||||
opset: ONNX opset version (default 20)
|
||||
"""
|
||||
options = options or {}
|
||||
|
||||
# Env var override
|
||||
env_mode = os.getenv("LUMINAL_EXPORT_MODE", "").lower()
|
||||
export_mode = (
|
||||
env_mode if env_mode in ("pt2", "onnx") else options.get("export_mode", "onnx")
|
||||
)
|
||||
opset = options.get("opset", 20)
|
||||
|
||||
backend = _detect_backend(example_inputs)
|
||||
return _compile_pt2(gm, example_inputs, backend)
|
||||
|
||||
if export_mode == "pt2":
|
||||
return _compile_pt2(gm, example_inputs, backend)
|
||||
return _compile_onnx(gm, example_inputs, backend, opset=opset)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ONNX compilation path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compile_onnx(gm, example_inputs, backend, opset=20):
|
||||
"""ONNX compilation path."""
|
||||
# Identify weight vs user inputs from FX graph placeholders.
|
||||
# torch.compile lifts model parameters into graph inputs — we detect them by name prefix.
|
||||
weight_tensors = {} # onnx_name -> tensor
|
||||
user_indices = []
|
||||
ph_idx = 0
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "placeholder":
|
||||
onnx_name = f"input_{ph_idx}"
|
||||
if node.name.startswith(("l_self_", "l_model_", "l__self_")):
|
||||
weight_tensors[onnx_name] = example_inputs[ph_idx]
|
||||
else:
|
||||
user_indices.append(ph_idx)
|
||||
ph_idx += 1
|
||||
|
||||
# Collect weight pointers for Rust (avoids duplicate GPU buffer allocation)
|
||||
weight_refs, weight_device_ptrs, cpu_weights = _collect_weight_pointers(
|
||||
weight_tensors, backend
|
||||
)
|
||||
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
_ = gm.eval()
|
||||
try:
|
||||
_ = torch.onnx.export(
|
||||
gm,
|
||||
tuple(example_inputs),
|
||||
tmp_path,
|
||||
opset_version=opset,
|
||||
input_names=[f"input_{i}" for i in range(len(example_inputs))],
|
||||
)
|
||||
|
||||
result = luminal.process_onnx(
|
||||
tmp_path, backend, weight_device_ptrs=weight_device_ptrs
|
||||
)
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
# Load CPU weights after compilation
|
||||
_load_cpu_weights(result, cpu_weights)
|
||||
|
||||
# Only expose user input names to CompiledModel (weights are pre-loaded).
|
||||
# user_indices tells __call__ which args from torch.compile are real user inputs.
|
||||
user_input_names = [f"input_{i}" for i in user_indices]
|
||||
return CompiledModel(
|
||||
result,
|
||||
weight_refs=weight_refs,
|
||||
input_names=user_input_names,
|
||||
user_indices=user_indices,
|
||||
)
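The weight/user-input split at the top of `_compile_onnx` can be tried without an FX graph. The sketch below is a simplified stand-in under stated assumptions: `FakeNode` and `split_placeholders` are hypothetical names, and the prefix tuple is copied from the code above (whether dynamo always uses these prefixes for lifted parameters is not verified here).

```python
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Sequence, Tuple

# Prefixes copied from _compile_onnx; treat them as an assumption about dynamo naming.
WEIGHT_PREFIXES = ("l_self_", "l_model_", "l__self_")


@dataclass
class FakeNode:
    """Stand-in for an FX node; only the two attributes used here."""

    op: str
    name: str


def split_placeholders(
    nodes: Iterable[FakeNode], example_inputs: Sequence[Any]
) -> Tuple[Dict[str, Any], List[int]]:
    """Return ({onnx_name: weight}, [indices of real user inputs])."""
    weights: Dict[str, Any] = {}
    user_indices: List[int] = []
    ph_idx = 0
    for node in nodes:
        if node.op != "placeholder":
            continue
        if node.name.startswith(WEIGHT_PREFIXES):
            # Lifted parameter: keyed by the positional ONNX input name.
            weights[f"input_{ph_idx}"] = example_inputs[ph_idx]
        else:
            user_indices.append(ph_idx)
        ph_idx += 1
    return weights, user_indices


nodes = [
    FakeNode("placeholder", "l_self_linear_weight"),
    FakeNode("placeholder", "l_x_"),
    FakeNode("call_function", "linear"),
    FakeNode("output", "output"),
]
weights, user_indices = split_placeholders(nodes, ["W", "X"])
```

Here `weights` maps `"input_0"` to the first example input and `user_indices` is `[1]`, which is the information `CompiledModel` needs to pick the real user arguments at call time.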
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -193,7 +193,6 @@ def compile(
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
**extra,
|
||||
)
|
||||
ep = ep.run_decompositions()
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
@@ -206,7 +205,6 @@ def compile(
|
||||
dynamic_shapes=None,
|
||||
**extra,
|
||||
)
|
||||
ep = ep.run_decompositions()
|
||||
|
||||
return _save_and_compile(ep, backend, search_iterations)
|
||||
|
||||
@@ -225,7 +223,6 @@ def pt2_backend(gm, example_inputs, backend=None):
|
||||
gm, user_inputs, original_weights = _reinternalize_lifted_params(gm, example_inputs)
|
||||
|
||||
ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs())
|
||||
ep = ep.run_decompositions()
|
||||
|
||||
# When using shared memory (original_weights), strip large weight buffers from
|
||||
# the EP before saving. The Rust side uses device pointers for these weights,
|
||||
|
||||
176
crates/luminal_python/tests/_llama38b_artifacts.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""Helpers for caching Llama 3.1-8B test artifacts under pytest cache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from safetensors.torch import save_file
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
# This code is designed to be deleted.
|
||||
# We should not need to cache pt2 or onnx files to get reasonable compile performance.
|
||||
|
||||
MODEL_ID = "NousResearch/Meta-Llama-3.1-8B-Instruct"
|
||||
INPUT_IDS_LIST = [1, 2, 3, 4]
|
||||
INPUT_IDS = torch.tensor([INPUT_IDS_LIST], dtype=torch.long)
|
||||
ARTIFACT_SCHEMA_VERSION = 1
|
||||
ONNX_OPSET_VERSION = 20
|
||||
PT2_STRICT = False
|
||||
|
||||
_REF_LOGITS_META_KEY = "luminal_python/llama38b_artifacts/ref_logits_v1"
|
||||
_ONNX_META_KEY = "luminal_python/llama38b_artifacts/onnx_v1"
|
||||
_PT2_META_KEY = "luminal_python/llama38b_artifacts/pt2_v1"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Llama38BArtifactBundle:
|
||||
ref_logits_path: Path
|
||||
onnx_path: Path | None = None
|
||||
pt2_path: Path | None = None
|
||||
weights_path: Path | None = None
|
||||
|
||||
|
||||
def ensure_onnx_bundle(cache, cache_dir: Path) -> Llama38BArtifactBundle:
|
||||
"""Ensure ONNX artifacts and shared reference logits exist in pytest cache."""
|
||||
ref_logits_path = cache_dir / "ref_logits.pt"
|
||||
onnx_dir = cache_dir / "onnx"
|
||||
onnx_path = onnx_dir / "llama38b.onnx"
|
||||
|
||||
ref_metadata = _ref_logits_metadata()
|
||||
onnx_metadata = _onnx_metadata()
|
||||
needs_ref_logits = cache.get(_REF_LOGITS_META_KEY, None) != ref_metadata or not (
|
||||
ref_logits_path.is_file()
|
||||
)
|
||||
needs_onnx = cache.get(_ONNX_META_KEY, None) != onnx_metadata or not (
|
||||
onnx_path.is_file()
|
||||
)
|
||||
|
||||
if needs_ref_logits or needs_onnx:
|
||||
print(f"Generating cached ONNX artifacts for {MODEL_ID} in {cache_dir}")
|
||||
if needs_ref_logits:
|
||||
ref_logits_path.unlink(missing_ok=True)
|
||||
if needs_onnx:
|
||||
shutil.rmtree(onnx_dir, ignore_errors=True)
|
||||
onnx_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model = _load_model()
|
||||
|
||||
if needs_ref_logits:
|
||||
ref_logits = _compute_ref_logits(model)
|
||||
torch.save(ref_logits, ref_logits_path)
|
||||
cache.set(_REF_LOGITS_META_KEY, ref_metadata)
|
||||
|
||||
if needs_onnx:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(INPUT_IDS,),
|
||||
str(onnx_path),
|
||||
opset_version=ONNX_OPSET_VERSION,
|
||||
input_names=["input_ids"],
|
||||
output_names=["logits"],
|
||||
)
|
||||
cache.set(_ONNX_META_KEY, onnx_metadata)
|
||||
|
||||
return Llama38BArtifactBundle(ref_logits_path=ref_logits_path, onnx_path=onnx_path)
|
||||
|
||||
|
||||
def ensure_pt2_bundle(cache, cache_dir: Path) -> Llama38BArtifactBundle:
|
||||
"""Ensure PT2 artifacts and shared reference logits exist in pytest cache."""
|
||||
ref_logits_path = cache_dir / "ref_logits.pt"
|
||||
pt2_dir = cache_dir / "pt2"
|
||||
pt2_path = pt2_dir / "llama38b.pt2"
|
||||
weights_path = pt2_dir / "llama38b_weights.safetensors"
|
||||
|
||||
ref_metadata = _ref_logits_metadata()
|
||||
pt2_metadata = _pt2_metadata()
|
||||
needs_ref_logits = cache.get(_REF_LOGITS_META_KEY, None) != ref_metadata or not (
|
||||
ref_logits_path.is_file()
|
||||
)
|
||||
needs_pt2 = cache.get(_PT2_META_KEY, None) != pt2_metadata or not (
|
||||
pt2_path.is_file() and weights_path.is_file()
|
||||
)
|
||||
|
||||
if needs_ref_logits or needs_pt2:
|
||||
print(f"Generating cached PT2 artifacts for {MODEL_ID} in {cache_dir}")
|
||||
if needs_ref_logits:
|
||||
ref_logits_path.unlink(missing_ok=True)
|
||||
if needs_pt2:
|
||||
shutil.rmtree(pt2_dir, ignore_errors=True)
|
||||
pt2_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model = _load_model()
|
||||
|
||||
if needs_ref_logits:
|
||||
ref_logits = _compute_ref_logits(model)
|
||||
torch.save(ref_logits, ref_logits_path)
|
||||
cache.set(_REF_LOGITS_META_KEY, ref_metadata)
|
||||
|
||||
if needs_pt2:
|
||||
exported_program = torch.export.export(
|
||||
model, (INPUT_IDS,), strict=PT2_STRICT
|
||||
)
|
||||
torch.export.save(exported_program, str(pt2_path))
|
||||
|
||||
state_dict = {
|
||||
key: value.float().clone()
|
||||
for key, value in exported_program.state_dict.items()
|
||||
}
|
||||
save_file(state_dict, str(weights_path))
|
||||
cache.set(_PT2_META_KEY, pt2_metadata)
|
||||
|
||||
return Llama38BArtifactBundle(
|
||||
ref_logits_path=ref_logits_path,
|
||||
pt2_path=pt2_path,
|
||||
weights_path=weights_path,
|
||||
)
|
||||
|
||||
|
||||
def _load_model() -> LlamaForCausalLM:
|
||||
config = AutoConfig.from_pretrained(MODEL_ID)
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
return LlamaForCausalLM.from_pretrained(
|
||||
MODEL_ID,
|
||||
config=config,
|
||||
torch_dtype=torch.float32,
|
||||
).eval()
|
||||
|
||||
|
||||
def _compute_ref_logits(model: LlamaForCausalLM) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
return model(INPUT_IDS).logits.clone()
|
||||
|
||||
|
||||
def _ref_logits_metadata() -> dict[str, object]:
|
||||
return {
|
||||
"schema_version": ARTIFACT_SCHEMA_VERSION,
|
||||
"model_id": MODEL_ID,
|
||||
"input_ids": INPUT_IDS_LIST,
|
||||
"device": "cpu",
|
||||
"torch_dtype": "float32",
|
||||
"use_cache": False,
|
||||
"attn_implementation": "eager",
|
||||
"torch_version": torch.__version__,
|
||||
"transformers_version": transformers.__version__,
|
||||
}
|
||||
|
||||
|
||||
def _onnx_metadata() -> dict[str, object]:
|
||||
return {
|
||||
**_ref_logits_metadata(),
|
||||
"artifact_type": "onnx",
|
||||
"opset_version": ONNX_OPSET_VERSION,
|
||||
}
|
||||
|
||||
|
||||
def _pt2_metadata() -> dict[str, object]:
|
||||
return {
|
||||
**_ref_logits_metadata(),
|
||||
"artifact_type": "pt2",
|
||||
"strict": PT2_STRICT,
|
||||
}
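The invalidation rule shared by `ensure_onnx_bundle` and `ensure_pt2_bundle` (rebuild when the stored metadata differs or the file is missing, then record the new metadata) can be sketched generically. This is a minimal illustration, not the module's code: `DictCache` is a hypothetical stand-in that only imitates the `get(key, default)` / `set(key, value)` shape of pytest's cache object, and `ensure_artifact` is a hypothetical helper name.

```python
from pathlib import Path
from typing import Any, Callable, Dict


class DictCache:
    """Hypothetical in-memory stand-in with the same get/set shape as pytest's cache."""

    def __init__(self) -> None:
        self._data: Dict[str, Any] = {}

    def get(self, key: str, default: Any) -> Any:
        return self._data.get(key, default)

    def set(self, key: str, value: Any) -> None:
        self._data[key] = value


def ensure_artifact(
    cache: DictCache,
    key: str,
    metadata: dict,
    path: Path,
    build: Callable[[Path], None],
) -> bool:
    """Rebuild `path` if metadata changed or the file is missing; return True if rebuilt."""
    stale = cache.get(key, None) != metadata or not path.is_file()
    if stale:
        # Drop any outdated file first so a failed build cannot leave a stale artifact.
        path.unlink(missing_ok=True)
        path.parent.mkdir(parents=True, exist_ok=True)
        build(path)
        # Record metadata only after the build succeeded.
        cache.set(key, metadata)
    return stale
```

Because the metadata dictionaries above include the torch and transformers versions, upgrading either library changes the comparison result and triggers a rebuild under this rule.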
|
||||
194
crates/luminal_python/tests/_test_kimi_k25.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Kimi-K2.5 / DeepseekV3 model integration tests.
|
||||
|
||||
Tests the DeepseekV3 text backbone (MoE + MLA attention with LoRA-compressed KV,
|
||||
SwiGLU, YaRN RoPE) through the PyTorch -> ONNX -> luminal pipeline.
|
||||
|
||||
The model code requires trust_remote_code=True and uses custom HF modules from
|
||||
moonshotai/Kimi-K2.5. Since torch.compile cannot trace the MoE routing (it uses
|
||||
.numpy() and tensor indexing incompatible with dynamo), tests use manual ONNX
|
||||
export + onnxsim simplification + luminal.process_onnx.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import warnings
|
||||
|
||||
import onnx
|
||||
import onnxsim
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
def _get_deepseek_v3_classes():
|
||||
"""Import DeepseekV3Config and DeepseekV3ForCausalLM from the Kimi-K2.5 HF repo."""
|
||||
import importlib
|
||||
|
||||
from transformers import AutoConfig
|
||||
|
||||
config = AutoConfig.from_pretrained("moonshotai/Kimi-K2.5", trust_remote_code=True)
|
||||
tc = config.text_config
|
||||
DeepseekV3Config = type(tc)
|
||||
pkg = DeepseekV3Config.__module__.rsplit(".", 1)[0]
|
||||
modeling_mod = importlib.import_module(f"{pkg}.modeling_deepseek")
|
||||
return DeepseekV3Config, modeling_mod.DeepseekV3ForCausalLM
|
||||
|
||||
|
||||
def _make_deepseek_v3_config(
|
||||
DeepseekV3Config,
|
||||
hidden_size: int = 64,
|
||||
num_attention_heads: int = 4,
|
||||
num_key_value_heads: int = 4,
|
||||
num_hidden_layers: int = 1,
|
||||
intermediate_size: int = 128,
|
||||
vocab_size: int = 256,
|
||||
kv_lora_rank: int = 16,
|
||||
q_lora_rank: int = 32,
|
||||
qk_nope_head_dim: int = 8,
|
||||
qk_rope_head_dim: int = 8,
|
||||
v_head_dim: int = 8,
|
||||
n_routed_experts: int = 4,
|
||||
num_experts_per_tok: int = 2,
|
||||
n_shared_experts: int = 1,
|
||||
moe_intermediate_size: int = 32,
|
||||
first_k_dense_replace: int = 1,
|
||||
):
|
||||
"""Create a small DeepseekV3Config for testing."""
|
||||
config = DeepseekV3Config(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
intermediate_size=intermediate_size,
|
||||
vocab_size=vocab_size,
|
||||
max_position_embeddings=128,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
q_lora_rank=q_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
n_routed_experts=n_routed_experts,
|
||||
num_experts_per_tok=num_experts_per_tok,
|
||||
n_shared_experts=n_shared_experts,
|
||||
moe_intermediate_size=moe_intermediate_size,
|
||||
first_k_dense_replace=first_k_dense_replace,
|
||||
use_cache=False,
|
||||
n_group=1,
|
||||
topk_group=1,
|
||||
topk_method="noaux_tc",
|
||||
scoring_func="sigmoid",
|
||||
rope_scaling={
|
||||
"type": "yarn",
|
||||
"rope_type": "yarn",
|
||||
"factor": 4.0,
|
||||
"original_max_position_embeddings": 32,
|
||||
"beta_fast": 32.0,
|
||||
"beta_slow": 1.0,
|
||||
"mscale": 1.0,
|
||||
"mscale_all_dim": 1.0,
|
||||
"rope_theta": 10000.0,
|
||||
},
|
||||
rope_theta=10000.0,
|
||||
)
|
||||
config._attn_implementation = "eager"
|
||||
return config
|
||||
|
||||
|
||||
def _export_and_simplify(model, input_ids):
|
||||
"""Export model to ONNX and simplify with onnxsim to constant-fold shape chains."""
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
try:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(input_ids,),
|
||||
tmp_path,
|
||||
opset_version=20,
|
||||
input_names=["input_ids"],
|
||||
output_names=["logits"],
|
||||
dynamo=False,
|
||||
)
|
||||
m = onnx.load(tmp_path)
|
||||
m_sim, check = onnxsim.simplify(m)
|
||||
assert check, "onnxsim simplification failed"
|
||||
onnx.save(m_sim, tmp_path)
|
||||
return tmp_path
|
||||
except Exception:
|
||||
os.unlink(tmp_path)
|
||||
raise
|
||||
|
||||
|
||||
def _run_deepseek_v3_test(config, DeepseekV3ForCausalLM, backend: str, atol: float):
|
||||
"""Export DeepseekV3 to ONNX, simplify, run through luminal, compare."""
|
||||
import luminal
|
||||
|
||||
model = DeepseekV3ForCausalLM(config).eval()
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]])
|
||||
|
||||
onnx_path = _export_and_simplify(model, input_ids)
|
||||
try:
|
||||
graph = luminal.process_onnx(onnx_path, backend)
|
||||
graph.set_input("input_ids", [1.0, 2.0, 3.0, 4.0])
|
||||
graph.run()
|
||||
logits_data = graph.get_output("logits")
|
||||
logits = torch.tensor(logits_data, dtype=torch.float32).reshape(
|
||||
1, 4, config.vocab_size
|
||||
)
|
||||
finally:
|
||||
os.unlink(onnx_path)
|
||||
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
|
||||
assert torch.allclose(logits, ref.logits, atol=atol), (
|
||||
f"max_diff={torch.max(torch.abs(logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
# ========== Tests ==========
|
||||
|
||||
|
||||
def test_deepseek_v3_tiny_dense():
|
||||
"""Tiny DeepseekV3 with dense MLP (no MoE): 64 hidden, 1 layer, MLA attention."""
|
||||
DeepseekV3Config, DeepseekV3ForCausalLM = _get_deepseek_v3_classes()
|
||||
config = _make_deepseek_v3_config(
|
||||
DeepseekV3Config,
|
||||
first_k_dense_replace=1, # all layers use dense MLP
|
||||
)
|
||||
backend = os.environ.get("LUMINAL_BACKEND", "native")
|
||||
_run_deepseek_v3_test(config, DeepseekV3ForCausalLM, backend, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="MoE routing uses Int/F32 mixed ops not yet supported")
|
||||
def test_deepseek_v3_tiny_moe():
|
||||
"""Tiny DeepseekV3 with MoE: 64 hidden, 1 layer, 4 routed experts + 1 shared."""
|
||||
DeepseekV3Config, DeepseekV3ForCausalLM = _get_deepseek_v3_classes()
|
||||
config = _make_deepseek_v3_config(
|
||||
DeepseekV3Config,
|
||||
first_k_dense_replace=0, # all layers use MoE
|
||||
)
|
||||
backend = os.environ.get("LUMINAL_BACKEND", "native")
|
||||
_run_deepseek_v3_test(config, DeepseekV3ForCausalLM, backend, atol=1e-5)
|
||||
|
||||
|
||||
def test_deepseek_v3_small_dense():
|
||||
"""Small DeepseekV3 with dense MLP: 256 hidden, 1 layer."""
|
||||
DeepseekV3Config, DeepseekV3ForCausalLM = _get_deepseek_v3_classes()
|
||||
config = _make_deepseek_v3_config(
|
||||
DeepseekV3Config,
|
||||
hidden_size=256,
|
||||
num_attention_heads=8,
|
||||
num_key_value_heads=8,
|
||||
intermediate_size=512,
|
||||
vocab_size=1024,
|
||||
kv_lora_rank=32,
|
||||
q_lora_rank=64,
|
||||
qk_nope_head_dim=16,
|
||||
qk_rope_head_dim=16,
|
||||
v_head_dim=16,
|
||||
first_k_dense_replace=1,
|
||||
)
|
||||
backend = os.environ.get("LUMINAL_BACKEND", "native")
|
||||
_run_deepseek_v3_test(config, DeepseekV3ForCausalLM, backend, atol=1e-4)
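The tests above split temporary-file cleanup between an `except` in `_export_and_simplify` and a `finally` in the caller. One possible alternative, shown only as a sketch, is a context manager that owns the path for its whole lifetime; `temp_artifact` is a hypothetical name and is not part of the test suite.

```python
import contextlib
import os
import tempfile
from typing import Iterator


@contextlib.contextmanager
def temp_artifact(suffix: str) -> Iterator[str]:
    """Yield the path of a closed temp file and always remove it afterwards."""
    tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    tmp.close()  # close our handle so an exporter can reopen the path by name
    try:
        yield tmp.name
    finally:
        # Tolerate the file already having been removed by the body.
        with contextlib.suppress(FileNotFoundError):
            os.unlink(tmp.name)
```

With this shape, the export, the simplification pass, and the `luminal.process_onnx` call would all sit inside one `with temp_artifact(".onnx") as path:` block, and the file is removed whether the body returns or raises.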
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Qwen3-8B HuggingFace model integration tests.
|
||||
|
||||
Tests progressively larger HuggingFace Qwen3ForCausalLM configs through the
|
||||
PyTorch -> ONNX -> luminal pipeline via torch.compile. Qwen3 shares the same
|
||||
architecture family as Llama (GQA, RoPE, SwiGLU MLP, RMSNorm).
|
||||
"""
|
||||
|
||||
|
||||
426
crates/luminal_python/tests/_test_qwen_image.py
Normal file
@@ -0,0 +1,426 @@
|
||||
"""Qwen-Image diffusion model integration tests.
|
||||
|
||||
Tests the QwenImageTransformer2DModel (MMDiT denoiser) and AutoencoderKLQwenImage (VAE)
|
||||
through the PyTorch -> ONNX -> luminal pipeline.
|
||||
|
||||
The transformer uses complex-valued RoPE (torch.view_as_complex) which isn't ONNX-exportable,
|
||||
so tests use a wrapper that pre-computes RoPE as real-valued cos/sin and replaces the
|
||||
attention processor with a real-valued equivalent.
|
||||
|
||||
The VAE uses Conv3d, which is supported via the N-dimensional unfold-based conv parser.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import warnings
|
||||
|
||||
import onnx
|
||||
import onnxsim
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Transformer helpers
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _apply_rope_real(x, cos, sin):
|
||||
"""Apply RoPE using real-valued cos/sin. x: [B, S, H, D], cos/sin: [S, D/2]."""
|
||||
d = x.shape[-1]
|
||||
x1 = x[..., : d // 2]
|
||||
x2 = x[..., d // 2 :]
|
||||
cos = cos.unsqueeze(0).unsqueeze(2) # [1, S, 1, D/2]
|
||||
sin = sin.unsqueeze(0).unsqueeze(2)
|
||||
rotated_x1 = x1 * cos - x2 * sin
|
||||
rotated_x2 = x2 * cos + x1 * sin
|
||||
return torch.cat([rotated_x1, rotated_x2], dim=-1)
|
||||
|
||||
|
||||
class RealRoPEAttnProcessor:
|
||||
"""Attention processor that uses real-valued RoPE for ONNX compatibility.
|
||||
|
||||
Replaces the default QwenDoubleStreamAttnProcessor2_0 which uses
|
||||
torch.view_as_complex (not ONNX-exportable).
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
encoder_hidden_states_mask=None,
|
||||
attention_mask=None,
|
||||
image_rotary_emb=None,
|
||||
):
|
||||
seq_txt = encoder_hidden_states.shape[1]
|
||||
|
||||
img_query = attn.to_q(hidden_states)
|
||||
img_key = attn.to_k(hidden_states)
|
||||
img_value = attn.to_v(hidden_states)
|
||||
|
||||
txt_query = attn.add_q_proj(encoder_hidden_states)
|
||||
txt_key = attn.add_k_proj(encoder_hidden_states)
|
||||
txt_value = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
img_query = img_query.unflatten(-1, (attn.heads, -1))
|
||||
img_key = img_key.unflatten(-1, (attn.heads, -1))
|
||||
img_value = img_value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
txt_query = txt_query.unflatten(-1, (attn.heads, -1))
|
||||
txt_key = txt_key.unflatten(-1, (attn.heads, -1))
|
||||
txt_value = txt_value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
if attn.norm_q is not None:
|
||||
img_query = attn.norm_q(img_query)
|
||||
if attn.norm_k is not None:
|
||||
img_key = attn.norm_k(img_key)
|
||||
if attn.norm_added_q is not None:
|
||||
txt_query = attn.norm_added_q(txt_query)
|
||||
if attn.norm_added_k is not None:
|
||||
txt_key = attn.norm_added_k(txt_key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
img_cos, img_sin, txt_cos, txt_sin = image_rotary_emb
|
||||
img_query = _apply_rope_real(img_query, img_cos, img_sin)
|
||||
img_key = _apply_rope_real(img_key, img_cos, img_sin)
|
||||
txt_query = _apply_rope_real(txt_query, txt_cos, txt_sin)
|
||||
txt_key = _apply_rope_real(txt_key, txt_cos, txt_sin)
|
||||
|
||||
joint_query = torch.cat([txt_query, img_query], dim=1)
|
||||
joint_key = torch.cat([txt_key, img_key], dim=1)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=1)
|
||||
|
||||
joint_query = joint_query.transpose(1, 2)
|
||||
joint_key = joint_key.transpose(1, 2)
|
||||
joint_value = joint_value.transpose(1, 2)
|
||||
joint_hidden = torch.nn.functional.scaled_dot_product_attention(
|
||||
joint_query, joint_key, joint_value, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
joint_hidden = joint_hidden.transpose(1, 2)
|
||||
joint_hidden = joint_hidden.flatten(2, 3)
|
||||
|
||||
txt_attn = joint_hidden[:, :seq_txt, :]
|
||||
img_attn = joint_hidden[:, seq_txt:, :]
|
||||
|
||||
img_attn = attn.to_out[0](img_attn.contiguous())
|
||||
if len(attn.to_out) > 1:
|
||||
img_attn = attn.to_out[1](img_attn)
|
||||
txt_attn = attn.to_add_out(txt_attn.contiguous())
|
||||
|
||||
return img_attn, txt_attn
|
||||
|
||||
|
||||
class TransformerONNXWrapper(nn.Module):
|
||||
"""Wraps QwenImageTransformer2DModel for ONNX export.
|
||||
|
||||
Pre-computes complex RoPE frequencies as real cos/sin buffers and replaces
|
||||
the attention processors with ONNX-friendly real-valued versions.
|
||||
"""
|
||||
|
||||
def __init__(self, model, img_shapes, txt_seq_len):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
for block in self.model.transformer_blocks:
|
||||
block.attn.set_processor(RealRoPEAttnProcessor())
|
||||
|
||||
with torch.no_grad():
|
||||
img_freqs, txt_freqs = model.pos_embed(
|
||||
img_shapes, max_txt_seq_len=txt_seq_len
|
||||
)
|
||||
self.register_buffer("img_cos", img_freqs.real.float().contiguous())
|
||||
self.register_buffer("img_sin", img_freqs.imag.float().contiguous())
|
||||
self.register_buffer("txt_cos", txt_freqs.real.float().contiguous())
|
||||
self.register_buffer("txt_sin", txt_freqs.imag.float().contiguous())
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states, timestep):
|
||||
hidden_states = self.model.img_in(hidden_states)
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
|
||||
encoder_hidden_states = self.model.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.model.txt_in(encoder_hidden_states)
|
||||
|
||||
temb = self.model.time_text_embed(timestep, hidden_states)
|
||||
|
||||
rope = (self.img_cos, self.img_sin, self.txt_cos, self.txt_sin)
|
||||
|
||||
for block in self.model.transformer_blocks:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=None,
|
||||
temb=temb,
|
||||
image_rotary_emb=rope,
|
||||
)
|
||||
|
||||
hidden_states = self.model.norm_out(hidden_states, temb)
|
||||
output = self.model.proj_out(hidden_states)
|
||||
return output
|
||||
|
||||
|
||||
def _make_tiny_transformer_config():
|
||||
"""Tiny transformer config: ~100K params, 1 layer."""
|
||||
return dict(
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
num_layers=1,
|
||||
attention_head_dim=16,
|
||||
num_attention_heads=4,
|
||||
joint_attention_dim=64,
|
||||
axes_dims_rope=(4, 6, 6),
|
||||
)
|
||||
|
||||
|
||||
def _make_small_transformer_config():
|
||||
"""Small transformer config: ~1M params, 2 layers."""
|
||||
return dict(
|
||||
patch_size=2,
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
num_layers=2,
|
||||
attention_head_dim=32,
|
||||
num_attention_heads=8,
|
||||
joint_attention_dim=256,
|
||||
axes_dims_rope=(8, 12, 12),
|
||||
)
|
||||
|
||||
|
||||
def _make_medium_transformer_config():
|
||||
"""Medium transformer config: ~39M params, 4 layers."""
|
||||
return dict(
|
||||
patch_size=2,
|
||||
in_channels=32,
|
||||
out_channels=32,
|
||||
num_layers=4,
|
||||
attention_head_dim=64,
|
||||
num_attention_heads=8,
|
||||
joint_attention_dim=512,
|
||||
axes_dims_rope=(8, 28, 28),
|
||||
)
|
||||
|
||||
|
||||
def _run_transformer_test(config, atol):
|
||||
"""Compile transformer with luminal backend, compare to PyTorch reference."""
|
||||
from diffusers.models import QwenImageTransformer2DModel
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
model = QwenImageTransformer2DModel(**config).eval()
|
||||
img_seq_len = 4
|
||||
txt_seq_len = 3
|
||||
|
||||
wrapper = TransformerONNXWrapper(model, [(1, 2, 2)], txt_seq_len).eval()
|
||||
wrapper_compiled = torch.compile(wrapper, backend=luminal_backend)
|
||||
|
||||
hidden = torch.randn(1, img_seq_len, config["in_channels"])
|
||||
encoder_hs = torch.randn(1, txt_seq_len, config["joint_attention_dim"])
|
||||
timestep = torch.tensor([1.0])
|
||||
|
||||
with torch.no_grad():
|
||||
ref = wrapper(hidden, encoder_hs, timestep)
|
||||
out = wrapper_compiled(hidden, encoder_hs, timestep)
|
||||
|
||||
assert torch.allclose(out, ref, atol=atol), (
|
||||
f"max_diff={torch.max(torch.abs(out - ref)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# VAE helpers
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class _OnnxFriendlyUpsample(nn.Module):
|
||||
"""Replaces nn.Upsample with repeat_interleave for ONNX compatibility."""
|
||||
|
||||
def __init__(self, scale_factor):
|
||||
super().__init__()
|
||||
if isinstance(scale_factor, (tuple, list)):
|
||||
self.scale_factors = [int(s) for s in scale_factor]
|
||||
else:
|
||||
sf = int(scale_factor)
|
||||
self.scale_factors = [sf]
|
||||
|
||||
def forward(self, x):
|
||||
for dim_offset, sf in enumerate(self.scale_factors):
|
||||
if sf > 1:
|
||||
x = x.repeat_interleave(sf, dim=2 + dim_offset)
|
||||
return x
|
||||
|
||||
|
||||
def _make_tiny_vae_config():
|
||||
"""Tiny VAE config for testing."""
|
||||
return dict(
|
||||
base_dim=8,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2],
|
||||
num_res_blocks=1,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[False],
|
||||
dropout=0.0,
|
||||
input_channels=3,
|
||||
)
|
||||
|
||||
|
||||
def _make_medium_vae_config():
|
||||
"""Medium VAE config: base_dim=32, z_dim=8."""
|
||||
return dict(
|
||||
base_dim=32,
|
||||
z_dim=8,
|
||||
dim_mult=[1, 2, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[False, True],
|
||||
dropout=0.0,
|
||||
input_channels=3,
|
||||
)
|
||||
|
||||
|
||||
def _prepare_vae_for_onnx(vae):
|
||||
"""Replace non-ONNX-exportable modules in the VAE."""
|
||||
import diffusers.models.autoencoders.autoencoder_kl_qwenimage as vae_mod
|
||||
|
||||
def _replace(module):
|
||||
for name, child in module.named_children():
|
||||
if isinstance(child, vae_mod.QwenImageUpsample):
|
||||
setattr(module, name, _OnnxFriendlyUpsample(child.scale_factor))
|
||||
else:
|
||||
_replace(child)
|
||||
|
||||
_replace(vae)
|
||||
return vae
|
||||
|
||||
|
||||
class _VAEDecoderWrapper(nn.Module):
|
||||
def __init__(self, vae):
|
||||
super().__init__()
|
||||
self.vae = vae
|
||||
|
||||
def forward(self, z):
|
||||
return self.vae.decode(z).sample
|
||||
|
||||
|
||||
def _export_and_simplify(wrapper, inputs, input_names, output_names):
|
||||
"""Export model to ONNX and simplify with onnxsim."""
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
try:
|
||||
torch.onnx.export(
|
||||
wrapper,
|
||||
inputs,
|
||||
tmp_path,
|
||||
opset_version=20,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamo=False,
|
||||
)
|
||||
m = onnx.load(tmp_path)
|
||||
m_sim, check = onnxsim.simplify(m)
|
||||
assert check, "onnxsim simplification failed"
|
||||
onnx.save(m_sim, tmp_path)
|
||||
return tmp_path
|
||||
except Exception:
|
||||
os.unlink(tmp_path)
|
||||
raise
|
||||
|
||||
|
||||
def _run_vae_test(config, atol):
|
||||
"""Export VAE decoder to ONNX, run through luminal, compare."""
|
||||
from diffusers import AutoencoderKLQwenImage
|
||||
|
||||
import luminal
|
||||
|
||||
backend = os.environ.get("LUMINAL_BACKEND", "native")
|
||||
vae = AutoencoderKLQwenImage(**config).eval()
|
||||
vae = _prepare_vae_for_onnx(vae)
|
||||
|
||||
wrapper = _VAEDecoderWrapper(vae).eval()
|
||||
latents = torch.randn(1, config["z_dim"], 1, 4, 4)
|
||||
|
||||
with torch.no_grad():
|
||||
ref = wrapper(latents)
|
||||
|
||||
onnx_path = _export_and_simplify(wrapper, (latents,), ["latents"], ["output"])
|
||||
try:
|
||||
graph = luminal.process_onnx(onnx_path, backend)
|
||||
graph.set_input("latents", latents.flatten().tolist())
|
||||
graph.run()
|
||||
out_data = graph.get_output("output")
|
||||
out = torch.tensor(out_data, dtype=torch.float32).reshape(ref.shape)
|
||||
finally:
|
||||
os.unlink(onnx_path)
|
||||
|
||||
assert torch.allclose(out, ref, atol=atol), (
|
||||
f"max_diff={torch.max(torch.abs(out - ref)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_qwen_image_transformer_tiny():
|
||||
"""Tiny QwenImage transformer: 1 layer, 4 heads, dim=64."""
|
||||
_run_transformer_test(_make_tiny_transformer_config(), atol=1e-4)
|
||||
|
||||
|
||||
def test_qwen_image_transformer_small():
|
||||
"""Small QwenImage transformer: 2 layers, 8 heads, dim=256."""
|
||||
_run_transformer_test(_make_small_transformer_config(), atol=1e-4)
|
||||
|
||||
|
||||
def test_qwen_image_transformer_medium():
|
||||
"""Medium QwenImage transformer: 4 layers, 8 heads, dim=512."""
|
||||
_run_transformer_test(_make_medium_transformer_config(), atol=1e-4)
|
||||
|
||||
|
||||
def test_qwen_image_transformer_full():
|
||||
"""Full QwenImage transformer (production defaults)."""
|
||||
from diffusers.models import QwenImageTransformer2DModel
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
model = QwenImageTransformer2DModel().eval()
|
||||
config = {k: v for k, v in dict(model.config).items() if not k.startswith("_")}
|
||||
|
||||
wrapper = TransformerONNXWrapper(model, [(1, 2, 2)], txt_seq_len=3).eval()
|
||||
wrapper_compiled = torch.compile(wrapper, backend=luminal_backend)
|
||||
|
||||
hidden = torch.randn(1, 4, config["in_channels"])
|
||||
encoder_hs = torch.randn(1, 3, config["joint_attention_dim"])
|
||||
timestep = torch.tensor([1.0])
|
||||
|
||||
with torch.no_grad():
|
||||
ref = wrapper(hidden, encoder_hs, timestep)
|
||||
out = wrapper_compiled(hidden, encoder_hs, timestep)
|
||||
|
||||
assert torch.allclose(out, ref, atol=1e-4), (
|
||||
f"max_diff={torch.max(torch.abs(out - ref)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
def test_qwen_image_vae_decoder_tiny():
|
||||
"""Tiny QwenImage VAE decoder: base_dim=8, z_dim=4."""
|
||||
_run_vae_test(_make_tiny_vae_config(), atol=1e-3)
|
||||
|
||||
|
||||
def test_qwen_image_vae_decoder_medium():
|
||||
"""Medium QwenImage VAE decoder: base_dim=32, z_dim=8."""
|
||||
_run_vae_test(_make_medium_vae_config(), atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Full production VAE -- expected to be slow/OOM")
|
||||
def test_qwen_image_vae_decoder_full():
|
||||
"""Full QwenImage VAE decoder (production defaults)."""
|
||||
from diffusers import AutoencoderKLQwenImage
|
||||
|
||||
config = dict(AutoencoderKLQwenImage().config)
|
||||
config = {k: v for k, v in config.items() if not k.startswith("_")}
|
||||
_run_vae_test(config, atol=1e-3)
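The real-valued rotation in `_apply_rope_real` is the real and imaginary part of a complex multiplication, which is why it can stand in for a complex-valued RoPE during ONNX export. The sketch below checks that identity on plain Python lists. It assumes the half-split pairing used by `_apply_rope_real` (element `i` of the first half is paired with element `i` of the second half); whether the stock complex-valued processor pairs elements the same way is not verified here. `rope_real` and `rope_complex` are hypothetical names.

```python
import cmath
import math
from typing import List


def rope_real(x: List[float], cos: List[float], sin: List[float]) -> List[float]:
    """Rotate x = [x1 | x2] with real cos/sin tables of length len(x) // 2."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    out1 = [a * c - b * s for a, b, c, s in zip(x1, x2, cos, sin)]
    out2 = [b * c + a * s for a, b, c, s in zip(x1, x2, cos, sin)]
    return out1 + out2


def rope_complex(x: List[float], angles: List[float]) -> List[float]:
    """Same rotation written as (x1 + i*x2) * exp(i*theta)."""
    half = len(x) // 2
    rotated = [
        complex(a, b) * cmath.exp(1j * t)
        for a, b, t in zip(x[:half], x[half:], angles)
    ]
    return [z.real for z in rotated] + [z.imag for z in rotated]


angles = [0.3, 1.1]
x = [1.0, -2.0, 0.5, 4.0]
real_out = rope_real(x, [math.cos(t) for t in angles], [math.sin(t) for t in angles])
complex_out = rope_complex(x, angles)
```

The two outputs agree up to floating-point rounding, since `(a + ib)(cos t + i sin t) = (a cos t - b sin t) + i(b cos t + a sin t)`.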
|
||||
@@ -1,21 +1,64 @@
|
||||
"""Test configuration."""
|
||||
# ruff: noqa: E402
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from urllib.request import urlopen
|
||||
import warnings
|
||||
|
||||
try:
|
||||
import huggingface_hub
|
||||
from transformers import logging as transformers_logging
|
||||
except ImportError: # pragma: no cover - optional for non-HF test environments
|
||||
huggingface_hub = None
|
||||
transformers_logging = None
|
||||
|
||||
# Enable automatic Rust rebuilds during test development
|
||||
try:
|
||||
import maturin_import_hook
|
||||
from maturin_import_hook.settings import MaturinSettings
|
||||
from maturin_import_hook.project_importer import DefaultProjectFileSearcher
|
||||
|
||||
backend = os.getenv("LUMINAL_BACKEND", "native").lower()
|
||||
settings = MaturinSettings(features=["cuda"]) if backend == "cuda" else None
|
||||
maturin_import_hook.install(settings=settings)
|
||||
except ImportError:
|
||||
pass # Hook not available, rebuilds will be manual
|
||||
backend = os.getenv("LUMINAL_BACKEND", "native").lower()
|
||||
settings = MaturinSettings(
|
||||
release=(backend == "cuda"),
|
||||
features=["cuda"] if backend == "cuda" else None,
|
||||
skip_install=True,
|
||||
)
|
||||
searcher = DefaultProjectFileSearcher(
|
||||
source_excluded_dir_names=(
|
||||
DefaultProjectFileSearcher.DEFAULT_SOURCE_EXCLUDED_DIR_NAMES
|
||||
| {".claude", "docs", ".github", "examples"}
|
||||
),
|
||||
)
|
||||
maturin_import_hook.install(
|
||||
settings=settings,
|
||||
enable_automatic_installation=True,
|
||||
file_searcher=searcher,
|
||||
)
|
||||
logging.getLogger("maturin_import_hook").disabled = True
|
||||
logging.getLogger("maturin_import_hook.project_importer").disabled = True
|
||||
|
||||
# Silence noisy ONNX / onnxscript / httpx logging
|
||||
for _logger_name in (
|
||||
"onnxscript",
|
||||
"onnx_ir",
|
||||
"torch.onnx",
|
||||
"httpx",
|
||||
):
|
||||
logging.getLogger(_logger_name).setLevel(logging.WARNING)
|
||||
|
||||
# Suppress torch.onnx diagnostics/progress output and torchvision warnings
|
||||
os.environ.setdefault("TORCH_ONNX_VERBOSE", "0")
|
||||
os.environ.setdefault("TORCH_ONNX_LOG_LEVEL", "ERROR")
|
||||
warnings.filterwarnings("ignore", message=".*torchvision.*")
|
||||
warnings.filterwarnings("ignore", module="torch.onnx")
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from _llama38b_artifacts import ensure_onnx_bundle, ensure_pt2_bundle
|
||||
|
||||
torch.set_float32_matmul_precision("highest")
|
||||
|
||||
@@ -26,6 +69,139 @@ def device() -> torch.device:
|
||||
return torch.device("cuda") if backend == "cuda" else torch.device("cpu")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def configure_hf_test_output() -> None:
|
||||
if transformers_logging is not None:
|
||||
transformers_logging.disable_progress_bar()
|
||||
if huggingface_hub is not None:
|
||||
huggingface_hub.utils.disable_progress_bars()
|
||||
|
||||
|
||||
@pytest.fixture
def configure_dynamo():
    original_cache_size_limit = torch._dynamo.config.cache_size_limit
    original_suppress_errors = torch._dynamo.config.suppress_errors

    def _configure(
        *, cache_size_limit: int | None = None, suppress_errors: bool | None = None
    ) -> None:
        if cache_size_limit is not None:
            torch._dynamo.config.cache_size_limit = cache_size_limit
        if suppress_errors is not None:
            torch._dynamo.config.suppress_errors = suppress_errors

    yield _configure

    torch._dynamo.config.cache_size_limit = original_cache_size_limit
    torch._dynamo.config.suppress_errors = original_suppress_errors
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def _llama38b_cache_dir(pytestconfig: pytest.Config) -> Path:
|
||||
return pytestconfig.cache.mkdir("luminal_llama38b_artifacts_v1")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def _hf_multimodal_cache_dir(pytestconfig: pytest.Config) -> Path:
|
||||
return pytestconfig.cache.mkdir("luminal_hf_multimodal_v1")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def hf_multimodal_image_path(
    pytestconfig: pytest.Config, _hf_multimodal_cache_dir: Path
) -> Path:
    image_url = (
        "https://huggingface.co/datasets/huggingface/documentation-images/"
        "resolve/main/bee.jpg"
    )
    image_path = _hf_multimodal_cache_dir / "bee.jpg"
    metadata_key = "luminal_python/hf_multimodal_image_v1"
    metadata = {
        "schema_version": 1,
        "url": image_url,
        "filename": image_path.name,
    }

    needs_download = pytestconfig.cache.get(metadata_key, None) != metadata or not (
        image_path.is_file()
    )
    if not needs_download:
        return image_path

    image_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path: Path | None = None
    try:
        with urlopen(image_url, timeout=60) as response:
            with tempfile.NamedTemporaryFile(
                dir=image_path.parent, delete=False
            ) as tmp_file:
                tmp_path = Path(tmp_file.name)
                while chunk := response.read(1024 * 1024):
                    tmp_file.write(chunk)

        assert tmp_path is not None
        tmp_path.replace(image_path)
    except Exception:
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)
        raise

    pytestconfig.cache.set(metadata_key, metadata)
    return image_path
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def _llama38b_onnx_bundle(pytestconfig: pytest.Config, _llama38b_cache_dir: Path):
|
||||
return ensure_onnx_bundle(pytestconfig.cache, _llama38b_cache_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def _llama38b_pt2_bundle(pytestconfig: pytest.Config, _llama38b_cache_dir: Path):
|
||||
return ensure_pt2_bundle(pytestconfig.cache, _llama38b_cache_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def llama38b_ref_logits(request: pytest.FixtureRequest) -> torch.Tensor:
    fixturenames = set(request.fixturenames)
    uses_onnx = "llama38b_onnx_path" in fixturenames
    uses_pt2 = bool({"llama38b_pt2_path", "llama38b_weights_path"} & fixturenames)

    if uses_onnx and uses_pt2:
        raise pytest.UsageError(
            "llama38b_ref_logits cannot be requested with both ONNX and PT2 "
            "artifact fixtures in the same test"
        )
    if uses_onnx:
        bundle = request.getfixturevalue("_llama38b_onnx_bundle")
    elif uses_pt2:
        bundle = request.getfixturevalue("_llama38b_pt2_bundle")
    else:
        raise pytest.UsageError(
            "llama38b_ref_logits must be requested alongside llama38b_onnx_path "
            "or llama38b_pt2_path/llama38b_weights_path"
        )

    return torch.load(bundle.ref_logits_path, weights_only=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llama38b_onnx_path(_llama38b_onnx_bundle) -> Path:
|
||||
assert _llama38b_onnx_bundle.onnx_path is not None
|
||||
return _llama38b_onnx_bundle.onnx_path
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llama38b_pt2_path(_llama38b_pt2_bundle) -> Path:
|
||||
assert _llama38b_pt2_bundle.pt2_path is not None
|
||||
return _llama38b_pt2_bundle.pt2_path
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llama38b_weights_path(_llama38b_pt2_bundle) -> Path:
|
||||
assert _llama38b_pt2_bundle.weights_path is not None
|
||||
return _llama38b_pt2_bundle.weights_path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="function")
|
||||
def reset_torch_dynamo():
|
||||
# We need this for two reasons
|
||||
|
||||
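The `hf_multimodal_image_path` fixture above avoids leaving a half-written image in the cache: it streams the download into a temporary file in the destination directory and renames it into place only once the copy has finished. A minimal standalone sketch of that pattern follows, using an in-memory byte stream in place of `urlopen`. The helper name `atomic_write_stream` and its `chunk_size` parameter are illustrative assumptions, not part of the repository.

```python
import io
import tempfile
from pathlib import Path
from typing import BinaryIO


def atomic_write_stream(source: BinaryIO, dest: Path, chunk_size: int = 1024 * 1024) -> Path:
    """Copy `source` to `dest` so that `dest` never holds a partial file."""
    dest.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = None
    try:
        # Create the temp file next to the destination so the final rename
        # stays on one filesystem.
        with tempfile.NamedTemporaryFile(dir=dest.parent, delete=False) as tmp_file:
            tmp_path = Path(tmp_file.name)
            while chunk := source.read(chunk_size):
                tmp_file.write(chunk)
        tmp_path.replace(dest)
    except Exception:
        # Remove the partial temp file before propagating the error.
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)
        raise
    return dest


if __name__ == "__main__":
    target = Path(tempfile.mkdtemp()) / "cache" / "bee.jpg"
    atomic_write_stream(io.BytesIO(b"not really a jpeg"), target)
    print(target.read_bytes())
```

Writing into the destination directory rather than the system temp directory is the design choice that matters here: a rename across filesystems would fall back to a copy and lose the all-or-nothing property.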
@@ -1,62 +0,0 @@
|
||||
"""Generate pre-computed PT2 artifacts for test_hf_llama38b_cached.
|
||||
|
||||
Run once:
|
||||
uv run python tests/generate_llama38b_pt2_artifacts.py
|
||||
|
||||
Produces:
|
||||
tests/llama38b.pt2 — torch.export of Llama 3.1-8B
|
||||
tests/llama38b_weights.safetensors — model weights
|
||||
tests/llama38b_ref_logits.pt — reference logits for input_ids=[1,2,3,4]
|
||||
(shared with PT2 artifact script)
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
SCRIPT_DIR = Path(__file__).resolve().parent
|
||||
PT2_PATH = SCRIPT_DIR / "llama38b.pt2"
|
||||
WEIGHTS_PATH = SCRIPT_DIR / "llama38b_weights.safetensors"
|
||||
LOGITS_PATH = SCRIPT_DIR / "llama38b_ref_logits.pt"
|
||||
|
||||
INPUT_IDS = torch.tensor([[1, 2, 3, 4]])
|
||||
|
||||
|
||||
def main():
|
||||
config = AutoConfig.from_pretrained("NousResearch/Meta-Llama-3.1-8B-Instruct")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
print("Loading model on CPU...")
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
"NousResearch/Meta-Llama-3.1-8B-Instruct",
|
||||
config=config,
|
||||
torch_dtype=torch.float32,
|
||||
).eval()
|
||||
|
||||
# Generate reference logits (shared with PT2 artifact script)
|
||||
if not LOGITS_PATH.exists():
|
||||
print("Computing reference logits...")
|
||||
with torch.no_grad():
|
||||
ref_logits = model(INPUT_IDS).logits.clone()
|
||||
print(f"Reference logits shape: {ref_logits.shape}")
|
||||
print(f"Saving reference logits to {LOGITS_PATH}")
|
||||
torch.save(ref_logits, LOGITS_PATH)
|
||||
else:
|
||||
print(f"Reference logits already exist at {LOGITS_PATH}, skipping")
|
||||
|
||||
print(f"Exporting PT2 to {PT2_PATH}")
|
||||
ep = torch.export.export(model, (INPUT_IDS,), strict=False)
|
||||
torch.export.save(ep, str(PT2_PATH))
|
||||
|
||||
print(f"Saving weights to {WEIGHTS_PATH}")
|
||||
state_dict = {k: v.float().clone() for k, v in ep.state_dict.items()}
|
||||
save_file(state_dict, str(WEIGHTS_PATH))
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
282
crates/luminal_python/tests/test_hf_causal_lm_config_options.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""Hugging Face causal-LM config option support tests.
|
||||
|
||||
These tests verify that luminal matches eager Hugging Face execution across
|
||||
supported causal-LM config options using tiny public model definitions loaded
|
||||
through AutoConfig.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig
|
||||
from transformers.generation.configuration_utils import ContinuousBatchingConfig
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
# Attention implementations that require optional packages.
|
||||
_ATTN_REQUIRES_PACKAGE: dict[str, str] = {
|
||||
"flash_attention_2": "flash_attn",
|
||||
"flash_attention_3": "flash_attn_3",
|
||||
"flash_attention_4": "cutlass",
|
||||
}
|
||||
|
||||
# Attention implementations known to be incompatible with tiny random models
|
||||
# (e.g. head_dim < 16 or missing scaffolding).
|
||||
_ATTN_SKIP_TINY_MODEL: set[str] = {"flex_attention", "paged_attention"}
|
||||
|
||||
_PAGED_ATTN_IMPLEMENTATIONS = tuple(
|
||||
k for k in ALL_ATTENTION_FUNCTIONS.valid_keys() if k.startswith("paged|")
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _CausalLMConfigCase:
|
||||
case_id: str
|
||||
model_id: str
|
||||
input_ids: tuple[int, ...]
|
||||
atol: float
|
||||
rtol: float
|
||||
|
||||
|
||||
_MODEL_CASES = [
|
||||
_CausalLMConfigCase(
|
||||
case_id="llama_3.2_1B",
|
||||
model_id="meta-llama/Llama-3.2-1B",
|
||||
input_ids=(1, 2, 3, 4),
|
||||
atol=1e-5,
|
||||
rtol=1e-5,
|
||||
)
|
||||
]
|
||||
|
||||
_ATTN_IMPLEMENTATIONS = tuple(
|
||||
dict.fromkeys([None, "eager", *ALL_ATTENTION_FUNCTIONS.valid_keys()])
|
||||
)
|
||||
_CUDA_BACKEND_AVAILABLE = (
|
||||
os.getenv("LUMINAL_BACKEND", "native").lower() == "cuda"
|
||||
and torch.cuda.is_available()
|
||||
)
|
||||
|
||||
|
||||
def _attn_id(attn_impl: str | None) -> str:
    return "default" if attn_impl is None else attn_impl


def _base_attn_impl(attn_impl: str | None) -> str | None:
    if attn_impl is None:
        return None
    if attn_impl.startswith("paged|"):
        return attn_impl.split("|", maxsplit=1)[1]
    return attn_impl
|
||||
|
||||
|
||||
def _attn_param(attn_impl: str | None, *, allow_paged: bool) -> pytest.ParameterSet:
    marks = []
    base_attn_impl = _base_attn_impl(attn_impl)

    if base_attn_impl == "flash_attention_2":
        marks.append(pytest.mark.skip(reason="flash_attention_2 is very slow"))

    if attn_impl is not None and attn_impl.startswith("paged|") and not allow_paged:
        marks.append(
            pytest.mark.skip(reason=f"{attn_impl} requires continuous batching API")
        )

    if base_attn_impl in _ATTN_REQUIRES_PACKAGE:
        pkg = _ATTN_REQUIRES_PACKAGE[base_attn_impl]
        if importlib.util.find_spec(pkg) is None:
            marks.append(
                pytest.mark.skip(
                    reason=f"{attn_impl} requires package '{pkg}' which is not installed"
                )
            )

    if base_attn_impl in _ATTN_SKIP_TINY_MODEL:
        marks.append(
            pytest.mark.skip(
                reason=f"{attn_impl} is incompatible with tiny random test models"
            )
        )

    kwargs = {"id": _attn_id(attn_impl)}
    if marks:
        kwargs["marks"] = marks
    return pytest.param(attn_impl, **kwargs)
|
||||
|
||||
|
||||
_ATTN_PREFILL_PARAMS = tuple(
|
||||
_attn_param(attn_impl, allow_paged=False) for attn_impl in _ATTN_IMPLEMENTATIONS
|
||||
)
|
||||
_ATTN_GENERATE_BATCH_PARAMS = tuple(
|
||||
_attn_param(attn_impl, allow_paged=True) for attn_impl in _ATTN_IMPLEMENTATIONS
|
||||
)
|
||||
|
||||
|
||||
def _compare_past_key_values(lhs, rhs, *, atol: float, rtol: float) -> None:
|
||||
assert lhs is not None
|
||||
assert rhs is not None
|
||||
assert hasattr(lhs, "layers")
|
||||
assert hasattr(rhs, "layers")
|
||||
assert len(lhs.layers) == len(rhs.layers)
|
||||
|
||||
for lhs_layer, rhs_layer in zip(lhs.layers, rhs.layers):
|
||||
torch.testing.assert_close(lhs_layer.keys, rhs_layer.keys, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(
|
||||
lhs_layer.values, rhs_layer.values, atol=atol, rtol=rtol
|
||||
)
|
||||
|
||||
|
||||
def _instantiate_model(
|
||||
model_id: str, *, use_cache: bool, attn_impl: str | None, device: torch.device
|
||||
) -> AutoModelForCausalLM:
|
||||
config = AutoConfig.from_pretrained(model_id)
|
||||
config.use_cache = use_cache
|
||||
if attn_impl is not None:
|
||||
config._attn_implementation = attn_impl
|
||||
return AutoModelForCausalLM.from_config(config).eval().to(device)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_case", _MODEL_CASES, ids=lambda case: case.case_id)
|
||||
@pytest.mark.parametrize("use_cache", [False, True], ids=["no_cache", "cache"])
|
||||
@pytest.mark.parametrize("attn_impl", _ATTN_PREFILL_PARAMS)
|
||||
def test_hf_causal_lm_config_options_match_eager(
|
||||
model_case: _CausalLMConfigCase,
|
||||
use_cache: bool,
|
||||
attn_impl: str | None,
|
||||
device: torch.device,
|
||||
configure_dynamo,
|
||||
):
|
||||
"""Compare luminal against eager HF across causal-LM config options."""
|
||||
if use_cache:
|
||||
configure_dynamo(cache_size_limit=2)
|
||||
|
||||
model = _instantiate_model(
|
||||
model_case.model_id,
|
||||
use_cache=use_cache,
|
||||
attn_impl=attn_impl,
|
||||
device=device,
|
||||
)
|
||||
input_ids = torch.tensor([model_case.input_ids], device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_prefill = model(input_ids)
|
||||
|
||||
compiled_model = torch.compile(model, backend=luminal_backend)
|
||||
|
||||
with torch.no_grad():
|
||||
compiled_prefill = compiled_model(input_ids)
|
||||
|
||||
torch.testing.assert_close(
|
||||
compiled_prefill.logits,
|
||||
eager_prefill.logits,
|
||||
atol=model_case.atol,
|
||||
rtol=model_case.rtol,
|
||||
)
|
||||
|
||||
if not use_cache:
|
||||
assert eager_prefill.past_key_values is None
|
||||
assert compiled_prefill.past_key_values is None
|
||||
return
|
||||
|
||||
assert eager_prefill.past_key_values is not None
|
||||
assert compiled_prefill.past_key_values is not None
|
||||
_compare_past_key_values(
|
||||
compiled_prefill.past_key_values,
|
||||
eager_prefill.past_key_values,
|
||||
atol=model_case.atol,
|
||||
rtol=model_case.rtol,
|
||||
)
|
||||
|
||||
next_token = eager_prefill.logits[:, -1, :].argmax(dim=-1, keepdim=True)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_decode = model(next_token, past_key_values=eager_prefill.past_key_values)
|
||||
compiled_decode = compiled_model(
|
||||
next_token,
|
||||
past_key_values=compiled_prefill.past_key_values,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
compiled_decode.logits,
|
||||
eager_decode.logits,
|
||||
atol=model_case.atol,
|
||||
rtol=model_case.rtol,
|
||||
)
|
||||
assert eager_decode.past_key_values is not None
|
||||
assert compiled_decode.past_key_values is not None
|
||||
_compare_past_key_values(
|
||||
compiled_decode.past_key_values,
|
||||
eager_decode.past_key_values,
|
||||
atol=model_case.atol,
|
||||
rtol=model_case.rtol,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_case", _MODEL_CASES, ids=lambda case: case.case_id)
|
||||
@pytest.mark.parametrize("attn_impl", _ATTN_GENERATE_BATCH_PARAMS)
|
||||
@pytest.mark.skipif(not _CUDA_BACKEND_AVAILABLE, reason="generate_batch requires CUDA")
|
||||
def test_hf_generate_batch(
|
||||
model_case: _CausalLMConfigCase,
|
||||
attn_impl: str | None,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Compare generate_batch output for each attention variant against eager baseline."""
|
||||
config = AutoConfig.from_pretrained(model_case.model_id)
|
||||
config.use_cache = True
|
||||
model = (
|
||||
AutoModelForCausalLM.from_config(config)
|
||||
.to(dtype=torch.bfloat16)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
|
||||
gen_config = GenerationConfig(
|
||||
do_sample=False,
|
||||
max_new_tokens=5,
|
||||
temperature=None,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
)
|
||||
cb_config = ContinuousBatchingConfig(
|
||||
block_size=256,
|
||||
use_cuda_graph=False,
|
||||
)
|
||||
|
||||
inputs = [list(model_case.input_ids)]
|
||||
|
||||
# Baseline: eager generate_batch.
|
||||
model.set_attn_implementation("eager")
|
||||
eager_outputs = model.generate_batch(
|
||||
inputs,
|
||||
generation_config=gen_config,
|
||||
continuous_batching_config=cb_config,
|
||||
progress_bar=False,
|
||||
warmup=False,
|
||||
)
|
||||
|
||||
# Variant under test.
|
||||
if attn_impl is not None:
|
||||
model.set_attn_implementation(attn_impl)
|
||||
variant_outputs = model.generate_batch(
|
||||
inputs,
|
||||
generation_config=gen_config,
|
||||
continuous_batching_config=cb_config,
|
||||
progress_bar=False,
|
||||
warmup=False,
|
||||
)
|
||||
|
||||
assert len(eager_outputs) == len(variant_outputs)
|
||||
eager_out = next(iter(eager_outputs.values()))
|
||||
variant_out = next(iter(variant_outputs.values()))
|
||||
assert eager_out.error is None, f"Eager baseline failed: {eager_out.error}"
|
||||
assert variant_out.error is None, f"Variant {attn_impl} failed: {variant_out.error}"
|
||||
assert eager_out.generated_tokens == variant_out.generated_tokens, (
|
||||
f"Token mismatch for {attn_impl}: eager={eager_out.generated_tokens} "
|
||||
f"vs variant={variant_out.generated_tokens}"
|
||||
)
|
||||
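The parametrization helpers in the file above treat names of the form `paged|<base>` as paged variants of a base attention implementation, and they use `dict.fromkeys` to de-duplicate the implementation list while preserving order. A minimal standalone sketch of both steps follows. The key list is a made-up stand-in for `ALL_ATTENTION_FUNCTIONS.valid_keys()`, whose actual contents depend on the installed transformers version, and the un-prefixed function name is only for illustration.

```python
from typing import Optional


def base_attn_impl(attn_impl: Optional[str]) -> Optional[str]:
    """Strip a leading 'paged|' prefix and leave every other name unchanged."""
    if attn_impl is None:
        return None
    if attn_impl.startswith("paged|"):
        return attn_impl.split("|", maxsplit=1)[1]
    return attn_impl


# Hypothetical registry keys, for illustration only.
registry_keys = ["sdpa", "eager", "paged|sdpa", "paged|eager"]

# dict.fromkeys keeps first-seen order, so "eager" is listed once, right after None.
implementations = tuple(dict.fromkeys([None, "eager", *registry_keys]))

for name in implementations:
    print(repr(name), "->", repr(base_attn_impl(name)))
```

Keeping `None` first means the model's default attention implementation is always exercised, even if the registry changes between library versions.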
168
crates/luminal_python/tests/test_hf_causal_lm_experts_options.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Hugging Face causal-LM experts backend smoke tests.
|
||||
|
||||
These tests load a real pretrained text-only MoE model and compare eager
|
||||
PyTorch against torch.compile with the luminal backend across the standardized
|
||||
`experts_implementation` backends.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _HFMoeCase:
|
||||
case_id: str
|
||||
model_id: str
|
||||
input_ids: tuple[tuple[int, ...], ...]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _HFMoeBundle:
|
||||
case: _HFMoeCase
|
||||
model: AutoModelForCausalLM
|
||||
device: torch.device
|
||||
dtype: torch.dtype
|
||||
|
||||
|
||||
_MODEL_CASES = [
|
||||
_HFMoeCase(
|
||||
case_id="qwen15_moe_a27b",
|
||||
model_id="Qwen/Qwen1.5-MoE-A2.7B",
|
||||
input_ids=(
|
||||
(1, 2, 3, 4, 5, 6, 7, 8),
|
||||
(8, 7, 6, 5, 4, 3, 2, 1),
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
_EXPERTS_IMPLEMENTATIONS = ("eager", "batched_mm", "grouped_mm")
|
||||
_CUDA_BACKEND_AVAILABLE = (
|
||||
os.getenv("LUMINAL_BACKEND", "native").lower() == "cuda"
|
||||
and torch.cuda.is_available()
|
||||
)
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.slow,
|
||||
pytest.mark.skipif(
|
||||
not _CUDA_BACKEND_AVAILABLE,
|
||||
reason="HF MoE experts backend tests require the CUDA backend",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _model_dtype(device: torch.device) -> torch.dtype:
|
||||
if device.type != "cuda":
|
||||
return torch.float32
|
||||
return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||
|
||||
|
||||
def _output_tolerance(dtype: torch.dtype) -> float:
|
||||
if dtype == torch.bfloat16:
|
||||
return 5e-2
|
||||
if dtype == torch.float16:
|
||||
return 1e-2
|
||||
return 1e-3
|
||||
|
||||
|
||||
def _compare_router_logits(lhs, rhs, *, atol: float, rtol: float) -> None:
|
||||
assert lhs is not None
|
||||
assert rhs is not None
|
||||
|
||||
if isinstance(lhs, torch.Tensor):
|
||||
torch.testing.assert_close(lhs, rhs, atol=atol, rtol=rtol)
|
||||
return
|
||||
|
||||
assert len(lhs) == len(rhs)
|
||||
for lhs_layer, rhs_layer in zip(lhs, rhs):
|
||||
torch.testing.assert_close(lhs_layer, rhs_layer, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=_MODEL_CASES, ids=lambda case: case.case_id)
|
||||
def hf_moe_case(request: pytest.FixtureRequest) -> _HFMoeCase:
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hf_moe_bundle(hf_moe_case: _HFMoeCase) -> _HFMoeBundle:
|
||||
case = hf_moe_case
|
||||
device = torch.device("cuda")
|
||||
dtype = _model_dtype(device)
|
||||
|
||||
config = AutoConfig.from_pretrained(case.model_id)
|
||||
config.use_cache = False
|
||||
config.output_router_logits = True
|
||||
# Keep attention fixed so this test isolates experts backends.
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
model = (
|
||||
AutoModelForCausalLM.from_pretrained(
|
||||
case.model_id,
|
||||
config=config,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
return _HFMoeBundle(case=case, model=model, device=device, dtype=dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"experts_implementation",
|
||||
_EXPERTS_IMPLEMENTATIONS,
|
||||
ids=list(_EXPERTS_IMPLEMENTATIONS),
|
||||
)
|
||||
def test_hf_causal_lm_experts_implementation_matches_eager(
|
||||
hf_moe_bundle: _HFMoeBundle, experts_implementation: str
|
||||
):
|
||||
model = hf_moe_bundle.model
|
||||
model.set_experts_implementation(experts_implementation)
|
||||
assert model.config._experts_implementation == experts_implementation
|
||||
|
||||
input_ids = torch.tensor(hf_moe_bundle.case.input_ids, device=hf_moe_bundle.device)
|
||||
kwargs = {
|
||||
"input_ids": input_ids,
|
||||
"use_cache": False,
|
||||
"output_router_logits": True,
|
||||
"logits_to_keep": 1,
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
eager_output = model(**kwargs)
|
||||
|
||||
compiled_model = torch.compile(model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
compiled_output = compiled_model(**kwargs)
|
||||
|
||||
atol = _output_tolerance(hf_moe_bundle.dtype)
|
||||
rtol = 1e-3
|
||||
|
||||
torch.testing.assert_close(
|
||||
compiled_output.logits,
|
||||
eager_output.logits,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
)
|
||||
_compare_router_logits(
|
||||
compiled_output.router_logits,
|
||||
eager_output.router_logits,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
)
|
||||
|
||||
if eager_output.aux_loss is not None or compiled_output.aux_loss is not None:
|
||||
assert eager_output.aux_loss is not None
|
||||
assert compiled_output.aux_loss is not None
|
||||
torch.testing.assert_close(
|
||||
compiled_output.aux_loss,
|
||||
eager_output.aux_loss,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
)
|
||||
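The experts test above picks an absolute tolerance from the model dtype and passes it, together with a fixed `rtol`, to `torch.testing.assert_close`. That function accepts a pair of values when `|actual - expected| <= atol + rtol * |expected|`. A minimal pure-Python sketch of that acceptance rule follows, using the same tolerance table. Dtypes are represented by plain strings here, and `within_tolerance` is an illustrative helper rather than a PyTorch API.

```python
def output_tolerance(dtype_name: str) -> float:
    """Absolute tolerance per dtype, mirroring the values used in the test."""
    if dtype_name == "bfloat16":
        return 5e-2
    if dtype_name == "float16":
        return 1e-2
    return 1e-3


def within_tolerance(actual: float, expected: float, *, atol: float, rtol: float) -> bool:
    """Closeness rule for one element: |actual - expected| <= atol + rtol * |expected|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


rtol = 1e-3
# A gap of about 0.03 on a logit near 2.0 is accepted under the bfloat16
# tolerance but rejected under the float16 tolerance.
print(within_tolerance(2.03, 2.0, atol=output_tolerance("bfloat16"), rtol=rtol))
print(within_tolerance(2.03, 2.0, atol=output_tolerance("float16"), rtol=rtol))
```

Because the absolute term dominates for logits of small magnitude, the dtype-dependent `atol` is what mostly decides whether a reduced-precision run passes.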
242
crates/luminal_python/tests/test_hf_multimodal_generation.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""Hugging Face multimodal image-text-to-text smoke tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModelForImageTextToText, AutoProcessor
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
MODEL_ID = "google/gemma-3-4b-it"
|
||||
_CUDA_BACKEND_AVAILABLE = (
|
||||
os.getenv("LUMINAL_BACKEND", "native").lower() == "cuda"
|
||||
and torch.cuda.is_available()
|
||||
)
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.slow,
|
||||
pytest.mark.skipif(
|
||||
not _CUDA_BACKEND_AVAILABLE,
|
||||
reason="Gemma 3 multimodal tests require the CUDA backend",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HFMultimodalCase:
|
||||
case_id: str
|
||||
messages_builder: Callable[[Path], list[dict]]
|
||||
max_new_tokens: int
|
||||
expects_pixel_values: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Gemma3MultimodalBundle:
|
||||
model: AutoModelForImageTextToText
|
||||
processor: AutoProcessor
|
||||
device: torch.device
|
||||
dtype: torch.dtype
|
||||
|
||||
|
||||
def _build_text_only_messages(_: Path) -> list[dict]:
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": "You are a concise assistant."}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "In one short sentence, explain what a compiler does.",
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _build_image_to_text_messages(image_path: Path) -> list[dict]:
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": "You are a helpful assistant."}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "path": str(image_path)},
|
||||
{"type": "text", "text": "Describe this image in one short sentence."},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
MULTIMODAL_CASES = [
|
||||
HFMultimodalCase(
|
||||
case_id="chat_text_only",
|
||||
messages_builder=_build_text_only_messages,
|
||||
max_new_tokens=12,
|
||||
expects_pixel_values=False,
|
||||
),
|
||||
HFMultimodalCase(
|
||||
case_id="image_to_text",
|
||||
messages_builder=_build_image_to_text_messages,
|
||||
max_new_tokens=16,
|
||||
expects_pixel_values=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _model_dtype(device: torch.device) -> torch.dtype:
|
||||
if device.type != "cuda":
|
||||
return torch.float32
|
||||
return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||
|
||||
|
||||
def _set_greedy_generation(model) -> None:
|
||||
model.generation_config.temperature = None
|
||||
model.generation_config.top_p = None
|
||||
model.generation_config.top_k = None
|
||||
|
||||
|
||||
def _move_to_device(
|
||||
encoded: dict[str, torch.Tensor], device: torch.device, dtype: torch.dtype
|
||||
) -> dict[str, torch.Tensor]:
|
||||
result = {}
|
||||
for key, value in encoded.items():
|
||||
if not isinstance(value, torch.Tensor):
|
||||
result[key] = value
|
||||
continue
|
||||
|
||||
moved = value.to(device)
|
||||
if moved.is_floating_point():
|
||||
moved = moved.to(dtype=dtype)
|
||||
result[key] = moved
|
||||
return result
|
||||
|
||||
|
||||
def _encode_case(
|
||||
bundle: Gemma3MultimodalBundle,
|
||||
case: HFMultimodalCase,
|
||||
image_path: Path,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
encoded = bundle.processor.apply_chat_template(
|
||||
case.messages_builder(image_path),
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
encoded = _move_to_device(dict(encoded), bundle.device, bundle.dtype)
|
||||
|
||||
if case.expects_pixel_values:
|
||||
assert "pixel_values" in encoded
|
||||
assert "input_ids" in encoded
|
||||
return encoded
|
||||
|
||||
|
||||
def _generate_kwargs(
|
||||
bundle: Gemma3MultimodalBundle,
|
||||
encoded: dict[str, torch.Tensor],
|
||||
max_new_tokens: int,
|
||||
) -> dict:
|
||||
tokenizer = bundle.processor.tokenizer
|
||||
return dict(
|
||||
**encoded,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=False,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
|
||||
|
||||
def _logits_tolerance(dtype: torch.dtype) -> float:
|
||||
if dtype == torch.bfloat16:
|
||||
return 5e-2
|
||||
if dtype == torch.float16:
|
||||
return 1e-2
|
||||
return 1e-3
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def gemma3_multimodal_bundle() -> Gemma3MultimodalBundle:
|
||||
device = torch.device("cuda")
|
||||
dtype = _model_dtype(device)
|
||||
|
||||
config = AutoConfig.from_pretrained(MODEL_ID)
|
||||
processor = AutoProcessor.from_pretrained(MODEL_ID)
|
||||
tokenizer = processor.tokenizer
|
||||
if tokenizer.pad_token_id is None and tokenizer.eos_token is not None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = (
|
||||
AutoModelForImageTextToText.from_pretrained(
|
||||
MODEL_ID,
|
||||
config=config,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
_set_greedy_generation(model)
|
||||
return Gemma3MultimodalBundle(
|
||||
model=model, processor=processor, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
|
||||
class TestHFMultimodalGeneration:
|
||||
@pytest.mark.parametrize("case", MULTIMODAL_CASES, ids=lambda case: case.case_id)
|
||||
def test_generate_matches_eager(
|
||||
self,
|
||||
case: HFMultimodalCase,
|
||||
gemma3_multimodal_bundle: Gemma3MultimodalBundle,
|
||||
hf_multimodal_image_path: Path,
|
||||
):
|
||||
encoded = _encode_case(gemma3_multimodal_bundle, case, hf_multimodal_image_path)
|
||||
kwargs = _generate_kwargs(
|
||||
gemma3_multimodal_bundle, encoded, case.max_new_tokens
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_output = gemma3_multimodal_bundle.model.generate(**kwargs)
|
||||
|
||||
compiled_model = torch.compile(
|
||||
gemma3_multimodal_bundle.model, backend=luminal_backend
|
||||
)
|
||||
with torch.no_grad():
|
||||
compiled_output = compiled_model.generate(**kwargs)
|
||||
|
||||
torch.testing.assert_close(compiled_output, eager_output)
|
||||
|
||||
@pytest.mark.parametrize("case", MULTIMODAL_CASES, ids=lambda case: case.case_id)
|
||||
def test_forward_logits_match_eager(
|
||||
self,
|
||||
case: HFMultimodalCase,
|
||||
gemma3_multimodal_bundle: Gemma3MultimodalBundle,
|
||||
hf_multimodal_image_path: Path,
|
||||
):
|
||||
encoded = _encode_case(gemma3_multimodal_bundle, case, hf_multimodal_image_path)
|
||||
|
||||
with torch.no_grad():
|
||||
eager_out = gemma3_multimodal_bundle.model(**encoded)
|
||||
|
||||
compiled_model = torch.compile(
|
||||
gemma3_multimodal_bundle.model, backend=luminal_backend
|
||||
)
|
||||
with torch.no_grad():
|
||||
compiled_out = compiled_model(**encoded)
|
||||
|
||||
atol = _logits_tolerance(gemma3_multimodal_bundle.dtype)
|
||||
torch.testing.assert_close(
|
||||
compiled_out.logits,
|
||||
eager_out.logits,
|
||||
atol=atol,
|
||||
rtol=1e-3,
|
||||
)
|
||||
288
crates/luminal_python/tests/test_hf_text_generation.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""Hugging Face text-generation smoke tests.
|
||||
|
||||
These tests intentionally download real Hugging Face checkpoints, configs,
|
||||
and tokenizers. They compare eager PyTorch output against torch.compile
|
||||
with the luminal backend to verify numerical equivalence.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
SIMPLE_DENSE_MODELS = [
|
||||
"meta-llama/Llama-3.2-1B",
|
||||
"meta-llama/Llama-3.1-1B",
|
||||
"Qwen/Qwen3-8B",
|
||||
]
|
||||


@dataclass(frozen=True)
class HFTextGenerationBundle:
    model: AutoModelForCausalLM
    tokenizer: AutoTokenizer
    device: torch.device


def _load_model_and_tokenizer(
    model_id: str, device: torch.device
) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load a pretrained HF causal LM and its tokenizer, ready for generation."""
    config = AutoConfig.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    if tokenizer.pad_token_id is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    dtype = torch.float16 if device.type == "cuda" else torch.float32
    model = (
        AutoModelForCausalLM.from_pretrained(
            model_id,
            config=config,
            torch_dtype=dtype,
        )
        .eval()
        .to(device)
    )

    model.generation_config.temperature = None
    model.generation_config.top_p = None
    model.generation_config.top_k = None

    return model, tokenizer


def _encode(
    tokenizer: AutoTokenizer, prompt: str, device: torch.device
) -> dict[str, torch.Tensor]:
    """Tokenize a prompt and move tensors to device."""
    encoded = tokenizer(prompt, return_tensors="pt")
    result = {"input_ids": encoded["input_ids"].to(device)}
    if encoded.get("attention_mask") is not None:
        result["attention_mask"] = encoded["attention_mask"].to(device)
    return result


def _generate_kwargs(
    tokenizer: AutoTokenizer,
    encoded: dict[str, torch.Tensor],
    max_new_tokens: int = 6,
) -> dict:
    """Build kwargs dict for model.generate()."""
    return dict(
        **encoded,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )


@pytest.fixture
def hf_text_bundle(model_id: str, device: torch.device) -> HFTextGenerationBundle:
    model, tokenizer = _load_model_and_tokenizer(model_id, device)
    return HFTextGenerationBundle(model=model, tokenizer=tokenizer, device=device)


@pytest.mark.slow
class TestHFGeneration:
    """End-to-end tests comparing eager PyTorch against torch.compile with luminal."""

    @pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS)
    def test_capital_of_france(self, hf_text_bundle: HFTextGenerationBundle):
        """Basic greedy generation -- the original smoke test."""
        encoded = _encode(
            hf_text_bundle.tokenizer,
            "What is the capital of France ",
            hf_text_bundle.device,
        )
        kwargs = _generate_kwargs(hf_text_bundle.tokenizer, encoded)

        with torch.no_grad():
            eager_output = hf_text_bundle.model.generate(**kwargs)

        compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
        with torch.no_grad():
            compiled_output = compiled_model.generate(**kwargs)

        torch.testing.assert_close(compiled_output, eager_output)

    @pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS)
    def test_forward_logits(self, hf_text_bundle: HFTextGenerationBundle):
        """Forward pass only -- compare raw logits, not generated tokens."""
        encoded = _encode(
            hf_text_bundle.tokenizer, "The quick brown fox", hf_text_bundle.device
        )

        with torch.no_grad():
            eager_out = hf_text_bundle.model(**encoded)

        compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
        with torch.no_grad():
            compiled_out = compiled_model(**encoded)

        dtype = next(hf_text_bundle.model.parameters()).dtype
        atol = 1e-2 if dtype == torch.float16 else 1e-3
        torch.testing.assert_close(
            compiled_out.logits, eager_out.logits, atol=atol, rtol=1e-3
        )

    @pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
    @pytest.mark.parametrize(
        "prompt",
        [
            "Hi",
            "What is the capital of France",
            "Explain the theory of general relativity in simple terms that a high school student could understand",
        ],
        ids=["short", "medium", "long"],
    )
    def test_variable_length_prompts(
        self,
        prompt: str,
        hf_text_bundle: HFTextGenerationBundle,
    ):
        """Generate with prompts of different lengths -- tests dynamic shape handling."""
        encoded = _encode(hf_text_bundle.tokenizer, prompt, hf_text_bundle.device)
        kwargs = _generate_kwargs(hf_text_bundle.tokenizer, encoded, max_new_tokens=4)

        with torch.no_grad():
            eager_output = hf_text_bundle.model.generate(**kwargs)

        compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
        with torch.no_grad():
            compiled_output = compiled_model.generate(**kwargs)

        torch.testing.assert_close(compiled_output, eager_output)

    @pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
    def test_chat_template_generation(
        self,
        model_id: str,
        hf_text_bundle: HFTextGenerationBundle,
    ):
        """Generate using chat-templated input with special tokens."""
        if hf_text_bundle.tokenizer.chat_template is None:
            pytest.skip(f"{model_id} has no chat template")

        messages = [{"role": "user", "content": "What is 2+2?"}]
        encoded = hf_text_bundle.tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            add_generation_prompt=True,
            return_dict=True,
        )
        encoded = {k: v.to(hf_text_bundle.device) for k, v in encoded.items()}
        kwargs = _generate_kwargs(hf_text_bundle.tokenizer, encoded)

        with torch.no_grad():
            eager_output = hf_text_bundle.model.generate(**kwargs)

        compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
        with torch.no_grad():
            compiled_output = compiled_model.generate(**kwargs)

        torch.testing.assert_close(compiled_output, eager_output)

    @pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
    @pytest.mark.parametrize("max_new_tokens", [20, 50])
    def test_longer_generation(
        self,
        max_new_tokens: int,
        hf_text_bundle: HFTextGenerationBundle,
    ):
        """Generate many tokens to stress KV cache over extended decode loop."""
        encoded = _encode(
            hf_text_bundle.tokenizer, "Once upon a time", hf_text_bundle.device
        )
        kwargs = _generate_kwargs(
            hf_text_bundle.tokenizer,
            encoded,
            max_new_tokens=max_new_tokens,
        )

        with torch.no_grad():
            eager_output = hf_text_bundle.model.generate(**kwargs)

        compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
        with torch.no_grad():
            compiled_output = compiled_model.generate(**kwargs)

        torch.testing.assert_close(compiled_output, eager_output)

    @pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
    def test_greedy_determinism(
        self,
        hf_text_bundle: HFTextGenerationBundle,
        configure_dynamo,
    ):
        """Greedy generation produces identical results on repeated calls."""
        configure_dynamo(cache_size_limit=4)

        encoded = _encode(
            hf_text_bundle.tokenizer,
            "The meaning of life is",
            hf_text_bundle.device,
        )
        kwargs = _generate_kwargs(hf_text_bundle.tokenizer, encoded, max_new_tokens=10)

        compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
        with torch.no_grad():
            output_1 = compiled_model.generate(**kwargs)
            output_2 = compiled_model.generate(**kwargs)

        torch.testing.assert_close(output_1, output_2)

    @pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
    def test_reuse_compiled_model(
        self,
        hf_text_bundle: HFTextGenerationBundle,
        configure_dynamo,
    ):
        """Call the same compiled model multiple times with different prompts."""
        configure_dynamo(cache_size_limit=8)

        compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)

        prompts = [
            "The capital of France is",
            "Water boils at",
            "The largest planet in our solar system is",
        ]

        for prompt in prompts:
            encoded = _encode(hf_text_bundle.tokenizer, prompt, hf_text_bundle.device)
            kwargs = _generate_kwargs(
                hf_text_bundle.tokenizer, encoded, max_new_tokens=4
            )

            with torch.no_grad():
                eager_output = hf_text_bundle.model.generate(**kwargs)
                compiled_output = compiled_model.generate(**kwargs)

            torch.testing.assert_close(compiled_output, eager_output)

    @pytest.mark.parametrize("model_id", SIMPLE_DENSE_MODELS[:1])
    def test_batched_inference(self, hf_text_bundle: HFTextGenerationBundle):
        """Batched generation with multiple prompts and left-padding."""
        hf_text_bundle.tokenizer.padding_side = "left"

        prompts = ["Hello", "What is the capital of France"]
        encoded = hf_text_bundle.tokenizer(prompts, return_tensors="pt", padding=True)
        encoded = {k: v.to(hf_text_bundle.device) for k, v in encoded.items()}
        kwargs = _generate_kwargs(hf_text_bundle.tokenizer, encoded, max_new_tokens=4)

        with torch.no_grad():
            eager_output = hf_text_bundle.model.generate(**kwargs)

        compiled_model = torch.compile(hf_text_bundle.model, backend=luminal_backend)
        with torch.no_grad():
            compiled_output = compiled_model.generate(**kwargs)

        torch.testing.assert_close(compiled_output, eager_output)
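The batched test above sets `padding_side = "left"` before tokenizing prompts of different lengths. For a decoder-only model this keeps the real tokens right-aligned, so generation continues from the last real token of every row instead of from padding. The sketch below shows the idea in plain Python; `left_pad`, the token IDs, and `pad_id=0` are made-up illustrations, not the tokenizer's actual implementation or vocabulary.

```python
def left_pad(sequences, pad_id):
    """Left-pad token-id lists to a common width and build attention masks."""
    width = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = width - len(seq)
        # Padding goes in front, so the final column always holds a real token.
        input_ids.append([pad_id] * n_pad + list(seq))
        # Mask is 0 over padding and 1 over real tokens.
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask


# Two hypothetical prompts: one token long and three tokens long.
ids, mask = left_pad([[11], [21, 22, 23]], pad_id=0)
print(ids)
print(mask)
```

The attention mask is what lets the model ignore the padded positions, which is why the test forwards every key returned by the tokenizer rather than only `input_ids`.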
File diff suppressed because it is too large
@@ -2,7 +2,7 @@
 
 Tests individual Llama3 building blocks (RMSNorm, RoPE, SwiGLU, causal attention,
 full transformer block) and progressively larger HuggingFace LlamaForCausalLM configs
-through the PyTorch -> Pt2 -> luminal pipeline via torch.compile.
+through the PyTorch -> ONNX -> luminal pipeline via torch.compile.
 """
 
 from typing import Callable
@@ -362,55 +362,6 @@ def test_hf_llama3_large_full(device: torch.device):
|
||||
)
|
||||
|
||||
|
||||
# ========== Dynamic Dimension Tests ==========
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="CUDA graph in-place update test — requires CUDA",
|
||||
)
|
||||
def test_dynamic_dim_reuse_no_recompile(device: torch.device):
|
||||
"""Compile once with dynamic shapes, execute with varying seq lengths.
|
||||
|
||||
Validates that the luminal runtime correctly handles dynamic dimension
|
||||
changes without recompilation. This is the core scenario optimized by
|
||||
removing the unnecessary CUDA graph rebuild on dyn_map changes: a single
|
||||
compiled graph handles multiple sequence lengths via in-place parameter
|
||||
updates rather than rebuilding the entire CUDA graph each step.
|
||||
"""
|
||||
from luminal.pt2 import compile as luminal_compile
|
||||
|
||||
class DynamicSeqModel(torch.nn.Module):
|
||||
"""Embedding + linear projection with variable-length integer input."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.embed = torch.nn.Embedding(256, 64)
|
||||
self.proj = torch.nn.Linear(64, 64)
|
||||
|
||||
def forward(self, x):
|
||||
return self.proj(self.embed(x))
|
||||
|
||||
model = DynamicSeqModel().eval().to(device)
|
||||
backend = "cuda" if device.type == "cuda" else "native"
|
||||
|
||||
# Compile once with dynamic seq dim (auto-detected for integer inputs)
|
||||
example = torch.tensor([[1, 2, 3, 4]], device=device)
|
||||
compiled = luminal_compile(model, example, search_iterations=5, backend=backend)
|
||||
|
||||
# Execute with multiple different seq lengths — each call reuses the
|
||||
# same compiled graph, updating dynamic dims in-place.
|
||||
for seq_len in [4, 5, 6, 7, 8]:
|
||||
input_ids = torch.tensor([list(range(1, seq_len + 1))], device=device)
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
out = compiled(input_ids)
|
||||
assert torch.allclose(out[0], ref, atol=1e-5), (
|
||||
f"seq_len={seq_len}: "
|
||||
f"max_diff={torch.max(torch.abs(out[0] - ref)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
|
||||
def test_hf_llama38b_full(device: torch.device):
|
||||
"""HuggingFace LlamaForCausalLM — full Llama-3.1-8B-Instruct with real pretrained weights.
|
||||
@@ -441,3 +392,60 @@ def test_hf_llama38b_full(device: torch.device):
    assert torch.allclose(out.logits, ref.logits, atol=1e-5), (
        f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
    )


@pytest.mark.slow
def test_hf_llama38b_cached_onnx(
    llama38b_onnx_path, llama38b_ref_logits: torch.Tensor
):
    import os

    import luminal

    backend = os.environ.get("LUMINAL_BACKEND", "cuda")

    graph = luminal.process_onnx(str(llama38b_onnx_path), backend)
    print("Compiled luminal ONNX graph")

    graph.set_input("input_ids", [float(t) for t in [1, 2, 3, 4]])
    graph.run()

    logits_data = graph.get_output("logits")
    logits_shape = graph.output_shapes[0]
    logits = torch.tensor(logits_data, dtype=torch.float32).reshape(logits_shape)

    print(f"Loaded reference logits: {llama38b_ref_logits.shape}")
    print(f"Output logits shape: {logits.shape}")

    assert torch.allclose(logits, llama38b_ref_logits, atol=1e-3), (
        f"max_diff={torch.max(torch.abs(logits - llama38b_ref_logits)).item():.2e}"
    )


@pytest.mark.slow
def test_hf_llama38b_cached_pt2(
    llama38b_pt2_path, llama38b_weights_path, llama38b_ref_logits: torch.Tensor
):
    import os

    import luminal
    from luminal import CompiledModel

    backend = os.environ.get("LUMINAL_BACKEND", "cuda")
    backend_name = "cuda" if backend == "cuda" else "cpu"

    compiled_inner = luminal.compile_pt2(
        str(llama38b_pt2_path), str(llama38b_weights_path), backend_name, 0
    )
    compiled = CompiledModel(compiled_inner)
    print("Compiled luminal PT2 graph")

    input_ids = torch.tensor([[1, 2, 3, 4]])
    logits = compiled(input_ids)[0]

    print(f"Loaded reference logits: {llama38b_ref_logits.shape}")
    print(f"Output logits shape: {logits.shape}")

    assert torch.allclose(logits, llama38b_ref_logits, atol=1e-3), (
        f"max_diff={torch.max(torch.abs(logits - llama38b_ref_logits)).item():.2e}"
    )
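The cached ONNX test above reads `logits` back with `graph.get_output` and then reshapes the result with `graph.output_shapes[0]`, which suggests the output arrives as a flat sequence of values. Assuming row-major (C) order, which is the order `Tensor.reshape` uses, the mapping from flat data to a shaped result can be sketched in plain Python. `reshape_row_major` is a hypothetical helper written for illustration and is not part of luminal's API.

```python
from math import prod


def reshape_row_major(flat, shape):
    """Nest a flat sequence into lists following `shape`, in row-major order."""
    if prod(shape) != len(flat):
        raise ValueError(f"cannot reshape {len(flat)} values into shape {tuple(shape)}")
    if len(shape) == 0:
        return flat[0]  # zero-dimensional: a single scalar
    if len(shape) == 1:
        return list(flat)
    step = prod(shape[1:])  # number of values under each leading index
    return [
        reshape_row_major(flat[i * step:(i + 1) * step], shape[1:])
        for i in range(shape[0])
    ]


# Six made-up values viewed as (batch=1, seq=2, vocab=3).
nested = reshape_row_major([0, 1, 2, 3, 4, 5], (1, 2, 3))
print(nested)
```

The size check mirrors the failure mode worth catching early in such a test: if the reported output shape and the returned buffer length disagree, the reshape should fail loudly rather than produce misaligned logits.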
@@ -3,13 +3,6 @@
|
||||
import torch
|
||||
|
||||
|
||||
class SelfAddModel(torch.nn.Module):
|
||||
"""Adds input to itself (x + x). Preserves input dtype."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + x
|
||||
|
||||
|
||||
class AddTestModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@@ -48,9 +41,6 @@ class AddAddTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class AddConstantTestModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x + 10
|
||||
|
||||
@@ -66,25 +56,16 @@ class LinearLayerModel(torch.nn.Module):
|
||||
|
||||
|
||||
class SqrtTestModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.sqrt()
|
||||
|
||||
|
||||
class SinTestModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.sin(x)
|
||||
|
||||
|
||||
class CosTestModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.cos(x)
|
||||
|
||||
@@ -101,9 +82,6 @@ class SubTestModel(torch.nn.Module):
|
||||
class TransposeTestModel(torch.nn.Module):
|
||||
"""Test basic 2D transpose (matrix transpose)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.t() # 2D transpose
|
||||
|
||||
@@ -111,9 +89,6 @@ class TransposeTestModel(torch.nn.Module):
|
||||
class Transpose3DTestModel(torch.nn.Module):
|
||||
"""Test 3D transpose with explicit permutation."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.permute(2, 0, 1) # Rotate dimensions
|
||||
|
||||
@@ -121,9 +96,6 @@ class Transpose3DTestModel(torch.nn.Module):
|
||||
class Transpose4DTestModel(torch.nn.Module):
|
||||
"""Test 4D transpose (NCHW -> NHWC)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.permute(0, 2, 3, 1) # Common in CNNs
|
||||
|
||||
@@ -131,9 +103,6 @@ class Transpose4DTestModel(torch.nn.Module):
|
||||
class TransposeReverseTestModel(torch.nn.Module):
|
||||
"""Test reverse permutation (default transpose behavior)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
dims = list(range(x.ndim))
|
||||
return x.permute(*reversed(dims))
|
||||
@@ -152,15 +121,12 @@ class TransposeInExpressionModel(torch.nn.Module):
|
||||
|
||||
|
||||
# ========== Constant Node Test Models ==========
|
||||
# These models test PT2 Constant node handling via inline tensor literals
|
||||
# These models test ONNX Constant node handling via inline tensor literals
|
||||
|
||||
|
||||
class ConstantScalarFloatModel(torch.nn.Module):
|
||||
"""Test scalar constant (broadcasts to input shape)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor(10.5).to(x.device)
|
||||
return x + constant
|
||||
@@ -169,9 +135,6 @@ class ConstantScalarFloatModel(torch.nn.Module):
|
||||
class Constant1DArrayFloatModel(torch.nn.Module):
|
||||
"""Test 1D array constant."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]).to(x.device)
|
||||
return x * constant
|
||||
@@ -180,9 +143,6 @@ class Constant1DArrayFloatModel(torch.nn.Module):
|
||||
class Constant2DMatrixFloatModel(torch.nn.Module):
|
||||
"""Test 2D matrix constant."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]).to(x.device)
|
||||
return x + constant
|
||||
@@ -191,9 +151,6 @@ class Constant2DMatrixFloatModel(torch.nn.Module):
|
||||
class ConstantRawDataFloatModel(torch.nn.Module):
|
||||
"""Test constant with specific values (tests raw data format)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([7.5, 8.5, 9.5]).to(x.device)
|
||||
return x + constant
|
||||
@@ -202,9 +159,6 @@ class ConstantRawDataFloatModel(torch.nn.Module):
|
||||
class ConstantInt32ConversionModel(torch.nn.Module):
|
||||
"""Test INT32 constant values (PyTorch exports as integers)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32).to(x.device)
|
||||
return x + constant.float()
|
||||
@@ -213,9 +167,6 @@ class ConstantInt32ConversionModel(torch.nn.Module):
|
||||
class ConstantInt64ConversionModel(torch.nn.Module):
|
||||
"""Test INT64 constant values."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([100, 200, 300], dtype=torch.int64).to(x.device)
|
||||
return x * constant.float()
|
||||
@@ -224,9 +175,6 @@ class ConstantInt64ConversionModel(torch.nn.Module):
|
||||
class ConstantFloat64ConversionModel(torch.nn.Module):
|
||||
"""Test FLOAT64 (double) constant values."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([1.5, 2.5, 3.5], dtype=torch.float64).to(x.device)
|
||||
return x * constant.float()
|
||||
@@ -235,9 +183,6 @@ class ConstantFloat64ConversionModel(torch.nn.Module):
|
||||
class ConstantBoolConversionModel(torch.nn.Module):
|
||||
"""Test boolean constant values (converted to 0.0/1.0)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([True, False, True, False, True], dtype=torch.bool).to(
|
||||
x.device
|
||||
@@ -248,9 +193,6 @@ class ConstantBoolConversionModel(torch.nn.Module):
|
||||
class ConstantInt64RawDataModel(torch.nn.Module):
|
||||
"""Test INT64 constant with large values (tests raw data path)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([1000, 2000, 3000], dtype=torch.int64).to(x.device)
|
||||
return x + constant.float()
|
||||
@@ -259,9 +201,6 @@ class ConstantInt64RawDataModel(torch.nn.Module):
|
||||
class ConstantNegativeValuesModel(torch.nn.Module):
|
||||
"""Test negative constant values."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([-5.0, -10.0, -15.0]).to(x.device)
|
||||
return x + constant
|
||||
@@ -270,9 +209,6 @@ class ConstantNegativeValuesModel(torch.nn.Module):
|
||||
class ConstantZeroValueModel(torch.nn.Module):
|
||||
"""Test all-zero constant."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([0.0, 0.0, 0.0, 0.0]).to(x.device)
|
||||
return x * constant
|
||||
@@ -281,9 +217,6 @@ class ConstantZeroValueModel(torch.nn.Module):
|
||||
class ConstantMultipleInGraphModel(torch.nn.Module):
|
||||
"""Test multiple constants in one graph."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
const1 = torch.tensor([10.0, 20.0, 30.0]).to(x.device)
|
||||
const2 = torch.tensor([1.0, 2.0, 3.0]).to(x.device)
|
||||
@@ -291,15 +224,12 @@ class ConstantMultipleInGraphModel(torch.nn.Module):
|
||||
|
||||
|
||||
# ========== Cast Node Test Models ==========
|
||||
# These models test PT2 Cast node handling via .to(dtype) method
|
||||
# These models test ONNX Cast node handling via .to(dtype) method
|
||||
|
||||
|
||||
class CastDoubleToFloatModel(torch.nn.Module):
|
||||
"""Test downcast: Double (FLOAT64) -> Float."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# Input will be float64, cast to float32
|
||||
return x.to(torch.float32)
|
||||
@@ -308,9 +238,6 @@ class CastDoubleToFloatModel(torch.nn.Module):
|
||||
class CastInt32ToFloatModel(torch.nn.Module):
|
||||
"""Test INT32 -> Float conversion."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
|
||||
@@ -318,9 +245,6 @@ class CastInt32ToFloatModel(torch.nn.Module):
|
||||
class CastInt64ToFloatModel(torch.nn.Module):
|
||||
"""Test INT64 -> Float conversion."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
|
||||
@@ -328,9 +252,6 @@ class CastInt64ToFloatModel(torch.nn.Module):
|
||||
class CastBoolToFloatModel(torch.nn.Module):
|
||||
"""Test BOOL -> Float conversion (non-zero -> 1.0, zero -> 0.0)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
|
||||
@@ -338,9 +259,6 @@ class CastBoolToFloatModel(torch.nn.Module):
|
||||
class CastInComputationGraphModel(torch.nn.Module):
|
||||
"""Test Cast node followed by an operation (Cast + Add)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
casted = x.to(torch.float32)
|
||||
constant = torch.tensor([2.0, 2.0, 2.0]).to(x.device)
|
||||
@@ -350,9 +268,6 @@ class CastInComputationGraphModel(torch.nn.Module):
|
||||
class CastWith2DTensorModel(torch.nn.Module):
|
||||
"""Test Cast with 2D tensor (matrix)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
|
||||
@@ -360,9 +275,6 @@ class CastWith2DTensorModel(torch.nn.Module):
|
||||
class CastNegativeValuesModel(torch.nn.Module):
|
||||
"""Test Cast with negative integer values."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
|
||||
@@ -370,9 +282,6 @@ class CastNegativeValuesModel(torch.nn.Module):
|
||||
class CastScalarValueModel(torch.nn.Module):
|
||||
"""Test Cast with scalar (single element)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
|
||||
@@ -394,10 +303,7 @@ class ModTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class ModByConstantModel(torch.nn.Module):
|
||||
"""Tests modulo with an inline constant tensor (PT2 Constant node)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
"""Tests modulo with an inline constant tensor (ONNX Constant node)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([3.0, 4.0, 5.0]).to(x.device)
|
||||
@@ -453,7 +359,7 @@ class CeilInExpressionModel(torch.nn.Module):
|
||||
|
||||
|
||||
# ========== Reshape Node Test Models ==========
|
||||
# These models test PT2 Reshape node handling in ops_parse.rs
|
||||
# These models test ONNX Reshape node handling in ops_parse.rs
|
||||
|
||||
|
||||
class ReshapeToFlatModel(torch.nn.Module):
|
||||
@@ -541,7 +447,7 @@ class ShapeReshapeKeepBatchModel(torch.nn.Module):
|
||||
|
||||
|
||||
# ========== Less Node Test Models ==========
|
||||
# These models test PT2 Less node handling in ops_parse.rs
|
||||
# These models test ONNX Less node handling in ops_parse.rs
|
||||
|
||||
|
||||
class LessTestModel(torch.nn.Module):
|
||||
@@ -567,7 +473,7 @@ class LessBroadcastModel(torch.nn.Module):
|
||||
|
||||
|
||||
class LessWithConstantModel(torch.nn.Module):
|
||||
"""Tests less-than against an inline constant (PT2 Constant + Less nodes)."""
|
||||
"""Tests less-than against an inline constant (ONNX Constant + Less nodes)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([0.25, 0.5, 0.75]).to(x.device)
|
||||
@@ -575,7 +481,7 @@ class LessWithConstantModel(torch.nn.Module):
|
||||
|
||||
|
||||
# ========== Gather Node Test Models ==========
|
||||
# These models test PT2 Gather node handling in ops_parse.rs
|
||||
# These models test ONNX Gather node handling in ops_parse.rs
|
||||
|
||||
|
||||
class Gather1DModel(torch.nn.Module):
|
||||
@@ -628,7 +534,7 @@ class GatherNegativeIndicesModel(torch.nn.Module):
|
||||
|
||||
|
||||
class GatherConstantFoldModel(torch.nn.Module):
|
||||
"""Tests Gather constant folding: both data and indices are PT2 Constant nodes."""
|
||||
"""Tests Gather constant folding: both data and indices are ONNX Constant nodes."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
data = torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0]).to(x.device)
|
||||
@@ -637,7 +543,7 @@ class GatherConstantFoldModel(torch.nn.Module):
|
||||
|
||||
|
||||
# ========== Squeeze Node Test Models ==========
|
||||
# These models test PT2 Squeeze node handling in ops_parse.rs
|
||||
# These models test ONNX Squeeze node handling in ops_parse.rs
|
||||
|
||||
|
||||
class SqueezeAxisModel(torch.nn.Module):
|
||||
@@ -1147,7 +1053,7 @@ class MaxTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class MaxWithConstantModel(torch.nn.Module):
|
||||
"""Tests element-wise maximum against an inline constant (PT2 Max + Constant nodes)."""
|
||||
"""Tests element-wise maximum against an inline constant (ONNX Max + Constant nodes)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([0.2, 0.4, 0.6, 0.8, 1.0]).to(x.device)
|
||||
@@ -1169,7 +1075,7 @@ class MinTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class MinWithConstantModel(torch.nn.Module):
|
||||
"""Tests element-wise minimum against an inline constant (PT2 Min + Constant nodes)."""
|
||||
"""Tests element-wise minimum against an inline constant (ONNX Min + Constant nodes)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([0.2, 0.4, 0.6, 0.8, 1.0]).to(x.device)
|
||||
@@ -1295,7 +1201,7 @@ class LessOrEqualTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class LessOrEqualWithConstantModel(torch.nn.Module):
|
||||
"""Tests less-than-or-equal against an inline constant (PT2 Constant + LessOrEqual nodes)."""
|
||||
"""Tests less-than-or-equal against an inline constant (ONNX Constant + LessOrEqual nodes)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([0.25, 0.5, 0.75]).to(x.device)
|
||||
@@ -1317,7 +1223,7 @@ class GreaterOrEqualTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class GreaterOrEqualWithConstantModel(torch.nn.Module):
|
||||
"""Tests greater-than-or-equal against an inline constant (PT2 Constant + GreaterOrEqual nodes)."""
|
||||
"""Tests greater-than-or-equal against an inline constant (ONNX Constant + GreaterOrEqual nodes)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([0.25, 0.5, 0.75]).to(x.device)
|
||||
@@ -1439,7 +1345,7 @@ class GreaterTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class GreaterWithConstantModel(torch.nn.Module):
|
||||
"""Tests greater-than against a scalar constant (PT2 Greater + Constant nodes)."""
|
||||
"""Tests greater-than against a scalar constant (ONNX Greater + Constant nodes)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return (x > 0.5).to(torch.float32)
|
||||
@@ -1516,7 +1422,7 @@ class MLPBlockModel(torch.nn.Module):
|
||||
|
||||
|
||||
class GatherElementsTestModel(torch.nn.Module):
|
||||
"""Tests element-wise gather along axis=1 using torch.gather (→ PT2 GatherElements)."""
|
||||
"""Tests element-wise gather along axis=1 using torch.gather (→ ONNX GatherElements)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
idx = torch.tensor([[0, 1, 1], [1, 0, 0]], device=x.device)
|
||||
@@ -1537,7 +1443,7 @@ class GatherElementsLargeTestModel(torch.nn.Module):


 class ExpandTestModel(torch.nn.Module):
-    """Tests broadcasting a (1, 4) tensor to (3, 4) via .expand() (→ PT2 Expand)."""
+    """Tests broadcasting a (1, 4) tensor to (3, 4) via .expand() (→ ONNX Expand)."""

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x.expand(3, 4)
@@ -1557,7 +1463,7 @@ class IsNaNTestModel(torch.nn.Module):


 class LayerNormTestModel(torch.nn.Module):
-    """Tests nn.LayerNorm which exports as PT2 LayerNormalization."""
+    """Tests nn.LayerNorm which exports as ONNX LayerNormalization."""

     def __init__(self) -> None:
         super().__init__()
@@ -1571,7 +1477,7 @@ class LayerNormTestModel(torch.nn.Module):


 class GemmTestModel(torch.nn.Module):
-    """Tests Gemm: nn.Linear exports as PT2 Gemm (weight transposed)."""
+    """Tests Gemm: nn.Linear exports as ONNX Gemm (weight transposed)."""

     def __init__(self) -> None:
         super().__init__()
@@ -1595,14 +1501,14 @@ class ErfTestModel(torch.nn.Module):


 class SliceTestModel(torch.nn.Module):
-    """Tests PT2 Slice: slice axis 0 from index 1 to 3."""
+    """Tests ONNX Slice: slice axis 0 from index 1 to 3."""

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x[1:3]


 class SliceMultiAxisTestModel(torch.nn.Module):
-    """Tests PT2 Slice along multiple axes: x[1:3, 0:2]."""
+    """Tests ONNX Slice along multiple axes: x[1:3, 0:2]."""

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x[1:3, 0:2]
@@ -1691,7 +1597,7 @@ class ScatterNDTestModel(torch.nn.Module):
 class RMSNormModel(torch.nn.Module):
     """Tests RMS normalization: x * rsqrt(mean(x^2) + eps) * weight.

-    PT2 ops: Pow, ReduceMean, Add, Sqrt, Reciprocal, Mul.
+    ONNX ops: Pow, ReduceMean, Add, Sqrt, Reciprocal, Mul.
     Input: (1, 4, 32) -> Output: (1, 4, 32).
     """

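For readers checking the RMSNorm docstring above, the formula `x * rsqrt(mean(x^2) + eps) * weight` can be sketched in plain Python, following the op order the docstring lists (Pow, ReduceMean, Add, Sqrt, Reciprocal, Mul). This is an illustrative sketch and not code from the repository: the helper name `rms_norm`, the default `eps`, and the use of a flat list instead of a `(1, 4, 32)` tensor are all assumptions.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Hypothetical helper: x * rsqrt(mean(x^2) + eps) * weight for one vector."""
    squared = [v * v for v in x]              # Pow
    mean_sq = sum(squared) / len(squared)     # ReduceMean
    shifted = mean_sq + eps                   # Add
    inv_rms = 1.0 / math.sqrt(shifted)        # Sqrt, then Reciprocal
    return [v * inv_rms * w for v, w in zip(x, weight)]  # Mul (twice)


out = rms_norm([3.0, 4.0], [1.0, 1.0])
print(out)
```

With unit weights, the mean of the squared outputs should be very close to 1, which is a quick sanity check on any backend's result.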
@@ -1710,7 +1616,7 @@ class RotaryEmbeddingModel(torch.nn.Module):
     """Tests rotary position embeddings (RoPE) using rotate-half approach.

     Precomputes cos/sin caches as buffers; at runtime: slice, split halves, rotate.
-    PT2 ops: Slice, Unsqueeze, Mul, Sub, Add, Concat.
+    ONNX ops: Slice, Unsqueeze, Mul, Sub, Add, Concat.
     Input: (1, 4, 4, 8) [batch, seq, heads, head_dim] -> Output: same shape.
     """

@@ -1739,7 +1645,7 @@ class RotaryEmbeddingModel(torch.nn.Module):
 class SwiGLUMLPModel(torch.nn.Module):
     """Tests SwiGLU MLP: down_proj(silu(gate_proj(x)) * up_proj(x)).

-    silu(x) = x * sigmoid(x), decomposes to Sigmoid+Mul in PT2.
+    silu(x) = x * sigmoid(x), decomposes to Sigmoid+Mul in ONNX.
     Input: (1, 4, 32) -> Output: (1, 4, 32).
     """

@@ -1830,202 +1736,3 @@ class LlamaTransformerBlockModel(torch.nn.Module):
|
||||
h = x + self.attn(self.input_norm(x))
|
||||
out = h + self.mlp(self.post_attn_norm(h))
|
||||
return out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Convolution models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Conv1dNoPadModel(torch.nn.Module):
|
||||
"""Conv1d with no padding: output length shrinks by (kernel-1)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv1d(8, 16, kernel_size=3, padding=0, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv1dSamePadModel(torch.nn.Module):
|
||||
"""Conv1d with same-size padding (output length == input length)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv1d(8, 16, kernel_size=3, padding=1, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv1dBiasModel(torch.nn.Module):
|
||||
"""Conv1d with bias."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv1d(8, 16, kernel_size=3, padding=1, bias=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv2dNoPadModel(torch.nn.Module):
|
||||
"""Conv2d with no padding: output spatial dims shrink by (kernel-1)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=0, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv2dSamePadModel(torch.nn.Module):
|
||||
"""Conv2d with same-size padding."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv2dBiasModel(torch.nn.Module):
|
||||
"""Conv2d with bias."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv2dStrideModel(torch.nn.Module):
|
||||
"""Conv2d with stride=2 (output dims halved)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
3, 16, kernel_size=3, stride=2, padding=1, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv2dDilationModel(torch.nn.Module):
|
||||
"""Conv2d with dilation=2 and padding chosen to preserve spatial size."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
8, 16, kernel_size=3, dilation=2, padding=2, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv3dSamePadModel(torch.nn.Module):
|
||||
"""Conv3d with padding=1 to preserve spatial dimensions."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv3d(4, 8, kernel_size=3, padding=1, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class DepthwiseConv1dModel(torch.nn.Module):
|
||||
"""Depthwise Conv1d as used in Mamba (groups == in_channels)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv1d(
|
||||
16, 16, kernel_size=4, groups=16, padding=3, bias=True
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# Causal truncation: keep only the first L positions
|
||||
return self.conv(x)[:, :, : x.shape[2]]
|
||||
|
||||
|
||||
class DepthwiseConv2dModel(torch.nn.Module):
|
||||
"""Depthwise Conv2d (groups == in_channels)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
8, 8, kernel_size=3, groups=8, padding=1, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class DepthwiseMultiplierConv2dModel(torch.nn.Module):
|
||||
"""Depthwise Conv2d with channel multiplier 2 (out_channels = 2 * in_channels)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
8, 16, kernel_size=3, groups=8, padding=1, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class GroupedConv2dModel(torch.nn.Module):
|
||||
"""Conv2d with groups=4 (not depthwise, but grouped)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
16, 32, kernel_size=3, groups=4, padding=1, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class GroupedConv2dGroups3Model(torch.nn.Module):
|
||||
"""Conv2d with groups=3 and ch_per_group=4."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
12, 12, kernel_size=3, groups=3, padding=1, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MambaConvBlockModel(torch.nn.Module):
|
||||
"""Minimal Mamba-style SSM block: Linear -> split -> depthwise Conv1d -> SiLU gate -> Linear.
|
||||
|
||||
This is the core conv pattern used in Mamba / Mamba-2 models.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int = 16, d_conv: int = 4, expand: int = 2) -> None:
|
||||
super().__init__()
|
||||
d_inner = d_model * expand
|
||||
self.in_proj = torch.nn.Linear(d_model, d_inner * 2, bias=False)
|
||||
self.conv1d = torch.nn.Conv1d(
|
||||
d_inner, d_inner, d_conv, groups=d_inner, padding=d_conv - 1, bias=True
|
||||
)
|
||||
self.out_proj = torch.nn.Linear(d_inner, d_model, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
b, seq_len, _ = x.shape
|
||||
xz = self.in_proj(x)
|
||||
x_part, z = xz.chunk(2, dim=-1)
|
||||
x_part = self.conv1d(x_part.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
|
||||
return self.out_proj(
|
||||
torch.nn.functional.silu(x_part) * torch.nn.functional.silu(z)
|
||||
)
|
||||
|
||||
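The docstrings of the convolution models in this hunk make several claims about output size (shrinks by `kernel-1`, same-size padding, halved by `stride=2`, preserved under `dilation=2, padding=2`, causal truncation with `padding=d_conv - 1`). A minimal sketch of the standard output-length formula documented for `torch.nn.Conv1d`/`Conv2d` lets these be checked by hand. The helper name `conv_out_len` and the input length of 10 are assumptions chosen for illustration; the diff does not state the input sizes used by the tests.

```python
def conv_out_len(length, kernel, stride=1, padding=0, dilation=1):
    """Output length along one spatial axis of a convolution (hypothetical helper)."""
    return (length + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


L = 10  # assumed input length, for illustration only

no_pad = conv_out_len(L, kernel=3)                            # shrinks by kernel - 1
same_pad = conv_out_len(L, kernel=3, padding=1)               # unchanged
strided = conv_out_len(L, kernel=3, stride=2, padding=1)      # halved
dilated = conv_out_len(L, kernel=3, dilation=2, padding=2)    # unchanged
causal_raw = conv_out_len(L, kernel=4, padding=3)             # L + kernel - 1 before truncation

print(no_pad, same_pad, strided, dilated, causal_raw)
```

The last case shows why `DepthwiseConv1dModel` and `MambaConvBlockModel` slice the result back to the first `L` positions: symmetric padding of `kernel - 1` yields `kernel - 1` extra trailing outputs.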
@@ -1,154 +1,118 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from test_models import (
|
||||
SigmoidTestModel,
|
||||
SigmoidInExpressionModel,
|
||||
TanhTestModel,
|
||||
TanhInExpressionModel,
|
||||
ReluTestModel,
|
||||
ReluAllNegativeModel,
|
||||
ReluInExpressionModel,
|
||||
AbsTestModel,
|
||||
AbsAllNegativeModel,
|
||||
AbsInExpressionModel,
|
||||
NegTestModel,
|
||||
NegAllPositiveModel,
|
||||
NegInExpressionModel,
|
||||
ClipTestModel,
|
||||
ClipMinOnlyTestModel,
|
||||
ClipMaxOnlyTestModel,
|
||||
)
|
||||
import test_models as tm
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
# ── Sigmoid ──────────────────────────────────────────────────────────────────
|
||||
|
||||
Args = tuple[torch.Tensor, ...]
|
||||
Kwargs = dict[str, torch.Tensor]
|
||||
InputFactory = Callable[[torch.device], tuple[Args, Kwargs]]
|
||||
|
||||
|
||||
def test_sigmoid(device):
|
||||
model = SigmoidTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1 # mixed positive/negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
@dataclass(frozen=True)
|
||||
class UnaryCase:
|
||||
id: str
|
||||
model_factory: Callable[[], torch.nn.Module]
|
||||
input_factory: InputFactory
|
||||
|
||||
|
||||
def test_sigmoid_in_expression(device):
|
||||
model = SigmoidInExpressionModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device)
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
UNARY_CASES: list[UnaryCase] = [
|
||||
UnaryCase(
|
||||
id="sigmoid",
|
||||
model_factory=tm.SigmoidTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="sigmoid_in_expression",
|
||||
model_factory=tm.SigmoidInExpressionModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device),), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="tanh",
|
||||
model_factory=tm.TanhTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="tanh_in_expression",
|
||||
model_factory=tm.TanhInExpressionModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device),), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="relu",
|
||||
model_factory=tm.ReluTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="relu_all_negative",
|
||||
model_factory=tm.ReluAllNegativeModel,
|
||||
input_factory=lambda device: ((-torch.rand((5, 5), device=device),), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="relu_in_expression",
|
||||
model_factory=tm.ReluInExpressionModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="abs",
|
||||
model_factory=tm.AbsTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="abs_all_negative",
|
||||
model_factory=tm.AbsAllNegativeModel,
|
||||
input_factory=lambda device: ((-torch.rand((5, 5), device=device),), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="abs_in_expression",
|
||||
model_factory=tm.AbsInExpressionModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="neg",
|
||||
model_factory=tm.NegTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="neg_all_positive",
|
||||
model_factory=tm.NegAllPositiveModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device),), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="neg_in_expression",
|
||||
model_factory=tm.NegInExpressionModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 2 - 1,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="clip",
|
||||
model_factory=tm.ClipTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 4 - 2,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="clip_min_only",
|
||||
model_factory=tm.ClipMinOnlyTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 4 - 2,), {}),
|
||||
),
|
||||
UnaryCase(
|
||||
id="clip_max_only",
|
||||
model_factory=tm.ClipMaxOnlyTestModel,
|
||||
input_factory=lambda device: ((torch.rand((5, 5), device=device) * 4 - 2,), {}),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# ── Tanh ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_tanh(device):
|
||||
model = TanhTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_tanh_in_expression(device):
|
||||
model = TanhInExpressionModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device)
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ── Relu ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_relu(device):
|
||||
model = ReluTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1 # mixed positive/negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_relu_all_negative(device):
|
||||
model = ReluAllNegativeModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = -torch.rand((5, 5), device=device) # all negative -> output all zeros
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_relu_in_expression(device):
|
||||
model = ReluInExpressionModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ── Abs ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_abs(device):
|
||||
model = AbsTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1 # mixed positive/negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_abs_all_negative(device):
|
||||
model = AbsAllNegativeModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = -torch.rand((5, 5), device=device) # all negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_abs_in_expression(device):
|
||||
model = AbsInExpressionModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ── Neg ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_neg(device):
|
||||
model = NegTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1 # mixed positive/negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_neg_all_positive(device):
|
||||
model = NegAllPositiveModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) # all positive -> output all negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_neg_in_expression(device):
|
||||
model = NegInExpressionModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ── Clip ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_clip(device):
|
||||
"""Clip tensor values to [-0.5, 0.5]."""
|
||||
model = ClipTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 4 - 2 # range [-2, 2]
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_clip_min_only(device):
|
||||
"""Clip tensor values to [0.0, +inf]."""
|
||||
model = ClipMinOnlyTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 4 - 2
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_clip_max_only(device):
|
||||
"""Clip tensor values to [-inf, 0.5]."""
|
||||
model = ClipMaxOnlyTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 4 - 2
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
class TestUnaryOps:
|
||||
@pytest.mark.parametrize("case", UNARY_CASES, ids=lambda case: case.id)
|
||||
def test_matches_eager(self, case: UnaryCase, device: torch.device) -> None:
|
||||
model = case.model_factory().to(device)
|
||||
compiled_model = torch.compile(model, backend=luminal_backend)
|
||||
args, kwargs = case.input_factory(device)
|
||||
torch.testing.assert_close(
|
||||
compiled_model(*args, **kwargs),
|
||||
model(*args, **kwargs),
|
||||
atol=1e-5,
|
||||
rtol=1e-5,
|
||||
)
|
||||
|
||||
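The refactor above replaces one test function per op with a table of frozen dataclass cases consumed by a single parametrized test. The same table-driven shape can be sketched with only the standard library. Everything here is hypothetical and for illustration: the `Case` fields, `run_cases`, and the scalar functions standing in for the compiled and eager models are not part of the repository, and `math.isclose` combines its tolerances differently from `torch.testing.assert_close`.

```python
import math
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class Case:
    id: str
    candidate: Callable[[float], float]   # stands in for the compiled model
    reference: Callable[[float], float]   # stands in for the eager model
    input_factory: Callable[[], float]    # builds a fresh input per run


CASES = [
    Case("sigmoid", lambda v: 0.5 * (1 + math.tanh(v / 2)),
         lambda v: 1 / (1 + math.exp(-v)), lambda: 0.3),
    Case("relu_all_negative", lambda v: (v + abs(v)) / 2,
         lambda v: max(v, 0.0), lambda: -0.7),
    Case("clip", lambda v: sorted((-0.5, v, 0.5))[1],
         lambda v: min(max(v, -0.5), 0.5), lambda: 1.7),
]


def run_cases(cases, atol=1e-5, rtol=1e-5):
    """Return the ids of cases whose candidate disagrees with the reference."""
    failures = []
    for case in cases:
        x = case.input_factory()
        got, want = case.candidate(x), case.reference(x)
        if not math.isclose(got, want, rel_tol=rtol, abs_tol=atol):
            failures.append(case.id)
    return failures


failures = run_cases(CASES)
print(failures)
```

Keeping an `id` on each case serves the same purpose as `ids=lambda case: case.id` in the diff: a failure report names the op instead of a positional index.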
20
skills-lock.json
Normal file
@@ -0,0 +1,20 @@
+{
+  "version": 1,
+  "skills": {
+    "ruff": {
+      "source": "astral-sh/claude-code-plugins",
+      "sourceType": "github",
+      "computedHash": "905f1c2b66e722ab534d7cbeceafb6ee37d32e3bfe51f3e68c661c51eb19a776"
+    },
+    "ty": {
+      "source": "astral-sh/claude-code-plugins",
+      "sourceType": "github",
+      "computedHash": "5d82b319a84d296fae47f2664228bfac1a36dd832cb487b485a65ec36b3df21f"
+    },
+    "uv": {
+      "source": "astral-sh/claude-code-plugins",
+      "sourceType": "github",
+      "computedHash": "b1ca9e9826d906f6ee6ea172da50018ec13bb622ac9ba7204bda8b10e7fc8d0a"
+    }
+  }
+}
@@ -11,7 +11,6 @@ pub const ILIST: SortClass = SortClass::new("IList");
|
||||
pub const EXPRESSION: SortClass = SortClass::new("Expression");
|
||||
pub const ELIST: SortClass = SortClass::new("EList");
|
||||
pub const DTYPE: SortClass = SortClass::new("DType");
|
||||
pub const FUSED_INSTR: SortClass = SortClass::new("FusedInstr");
|
||||
pub const I64: SortClass = SortClass::new("i64");
|
||||
pub const F64: SortClass = SortClass::new("f64");
|
||||
pub const STRING: SortClass = SortClass::new("String");
|
||||
@@ -233,8 +232,6 @@ pub struct BaseSorts {
|
||||
pub bf16_dt: SortDef,
|
||||
pub int_dt: SortDef,
|
||||
pub bool_dt: SortDef,
|
||||
pub i4_dt: SortDef,
|
||||
pub tf32_dt: SortDef,
|
||||
// Egglog builtin primitives (for term construction only)
|
||||
pub p_add: SortDef,
|
||||
pub p_sub: SortDef,
|
||||
@@ -313,8 +310,6 @@ impl BaseSorts {
|
||||
bf16_dt: sort(DTYPE, "Bf16", &[]),
|
||||
int_dt: sort(DTYPE, "Int", &[]),
|
||||
bool_dt: sort(DTYPE, "Bool", &[]),
|
||||
i4_dt: sort(DTYPE, "I4", &[]),
|
||||
tf32_dt: sort(DTYPE, "TF32", &[]),
|
||||
p_add: func("+", &["a", "b"]),
|
||||
p_sub: func("-", &["a", "b"]),
|
||||
p_mul: func("*", &["a", "b"]),
|
||||
@@ -368,8 +363,6 @@ impl BaseSorts {
|
||||
&self.bf16_dt,
|
||||
&self.int_dt,
|
||||
&self.bool_dt,
|
||||
&self.i4_dt,
|
||||
&self.tf32_dt,
|
||||
] {
|
||||
p.add_sort(s);
|
||||
}
|
||||
@@ -443,7 +436,6 @@ pub fn base_expression_egglog() -> String {
|
||||
|
||||
// Rulesets
|
||||
p.add_ruleset("expr");
|
||||
p.add_ruleset("dtype_prop");
|
||||
p.add_ruleset("cleanup");
|
||||
p.add_ruleset("early");
|
||||
|
||||
|
||||
@@ -13,9 +13,8 @@ pub mod api;
|
||||
pub mod base;
|
||||
|
||||
pub const RUN_SCHEDULE: &str = "(run-schedule
|
||||
(repeat 10
|
||||
(repeat 100
|
||||
(saturate expr)
|
||||
(saturate dtype_prop)
|
||||
(run)
|
||||
)
|
||||
(saturate expr)
|
||||
@@ -62,18 +61,6 @@ fn op_defs_string(ops: &[Arc<Box<dyn EgglogOp>>]) -> String {
|
||||
(ICons IR IList)
|
||||
(INil)
|
||||
)
|
||||
(FusedInstr
|
||||
(FIInput i64 EList)
|
||||
(FIConstant f64)
|
||||
(FIExp2 FusedInstr)
|
||||
(FILog2 FusedInstr)
|
||||
(FISin FusedInstr)
|
||||
(FIRecip FusedInstr)
|
||||
(FISqrt FusedInstr)
|
||||
(FIAdd FusedInstr FusedInstr)
|
||||
(FIMul FusedInstr FusedInstr)
|
||||
(FIMod FusedInstr FusedInstr)
|
||||
)
|
||||
)
|
||||
(function dtype (IR) DType :merge new)
|
||||
"
|
||||
@@ -779,8 +766,6 @@ pub fn extract_dtype<'a>(egraph: &'a SerializedEGraph, node: &'a NodeId) -> DTyp
|
||||
"F4E2M1" => DType::F4E2M1,
|
||||
"F8E4M3" => DType::F8E4M3,
|
||||
"F8UE8M0" => DType::F8UE8M0,
|
||||
"I4" => DType::I4,
|
||||
"TF32" => DType::TF32,
|
||||
other => panic!("unknown dtype {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,35 +57,15 @@ impl GraphTensor {
|
||||
self.graph().get_op_mut::<Input>(self.id).label = name.to_string();
|
||||
}
|
||||
|
||||
/// Mark this tensor as an output.
|
||||
/// If the tensor has non-contiguous strides (e.g. from transpose + merge_dims),
|
||||
/// inserts a gather to materialize contiguous data before the output node.
|
||||
/// Mark this tensor as an output
|
||||
pub fn output(&self) -> GraphTensor {
|
||||
let source = if self.shape.is_contiguous() {
|
||||
*self
|
||||
} else {
|
||||
// Insert gather to make physically contiguous
|
||||
let dims = self.dims();
|
||||
let total = dims.iter().copied().reduce(|a, b| a * b).unwrap();
|
||||
let idx_expr = self.shape.index_expression();
|
||||
let idx = self.graph().iota(idx_expr, total);
|
||||
let mut gathered = self.gather(idx);
|
||||
gathered.shape = ShapeTracker::new(dims);
|
||||
gathered
|
||||
};
|
||||
self.output_raw(source)
|
||||
}
|
||||
|
||||
/// Mark a tensor as an output without any contiguous materialization.
|
||||
/// Used internally by graph_break and persist.
|
||||
fn output_raw(&self, source: GraphTensor) -> GraphTensor {
|
||||
self.graph().add_op(
|
||||
Output {
|
||||
node: source.id.index(),
|
||||
node: self.id.index(),
|
||||
},
|
||||
&[source.id],
|
||||
&[self.id],
|
||||
);
|
||||
source
|
||||
*self
|
||||
}
|
||||
|
||||
/// Required bytes to store this tensor's physical elements. Rounds up to nearest byte.
|
||||
@@ -97,7 +77,7 @@ impl GraphTensor {
|
||||
/// so the buffer is not consumed after execute(), but returns the original
|
||||
/// Input node's GraphTensor (not the Output node).
|
||||
pub fn persist(&self) -> GraphTensor {
|
||||
self.output_raw(*self);
|
||||
self.output();
|
||||
*self
|
||||
}
|
||||
|
||||
|
||||
96
src/graph.rs
@@ -82,58 +82,6 @@ impl DimBucket {
|
||||
}
|
||||
}
|
||||
|
||||
/// Options for controlling the genetic search algorithm.
|
||||
///
|
||||
/// Use the builder pattern to configure search parameters:
|
||||
/// ```
|
||||
/// use luminal::prelude::SearchOptions;
|
||||
/// let opts = SearchOptions::new(5)
|
||||
/// .generation_size(50)
|
||||
/// .mutations(40)
|
||||
/// .trials(15);
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SearchOptions {
|
||||
/// Maximum number of graphs to evaluate
|
||||
pub limit: usize,
|
||||
/// Number of offspring per generation (default: 30)
|
||||
pub generation_size: usize,
|
||||
/// Number of mutations applied to each offspring (default: 30)
|
||||
pub mutations: usize,
|
||||
/// Number of profiling trials per candidate (default: 10)
|
||||
pub trials: usize,
|
||||
}
|
||||
|
||||
impl SearchOptions {
|
||||
/// Create new search options with the given limit. Other fields use defaults.
|
||||
pub fn new(limit: usize) -> Self {
|
||||
Self {
|
||||
limit,
|
||||
generation_size: 30,
|
||||
mutations: 30,
|
||||
trials: 10,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the number of offspring per generation.
|
||||
pub fn generation_size(mut self, generation_size: usize) -> Self {
|
||||
self.generation_size = generation_size;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the number of mutations per offspring.
|
||||
pub fn mutations(mut self, mutations: usize) -> Self {
|
||||
self.mutations = mutations;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the number of profiling trials per candidate.
|
||||
pub fn trials(mut self, trials: usize) -> Self {
|
||||
self.trials = trials;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// A Luminal compute graph.
|
||||
///
|
||||
/// All computation is represented as a directed acyclic graph.
|
||||
@@ -385,6 +333,7 @@ impl Graph {
|
||||
subgraphs.len()
|
||||
);
|
||||
|
||||
// Build e-graphs only for representative chunks
|
||||
self.egraphs = groups
|
||||
.iter()
|
||||
.map(|g| {
|
||||
@@ -406,23 +355,27 @@ impl Graph {
|
||||
self.ops.as_ref()
|
||||
}
|
||||
|
||||
const DEFAULT_GENERATION_SIZE: usize = 30;
|
||||
const MUTATIONS_PER_OFFSPRING: usize = 30;
|
||||
const TRIALS_PER_PROFILE: usize = 10;
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn search<R: Runtime>(&mut self, runtime: R, limit: usize) -> R {
|
||||
let mut rng = rand::rng();
|
||||
self.search_options(runtime, SearchOptions::new(limit), &mut rng)
|
||||
self.search_rng(runtime, limit, &mut rng)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn search_options<R: Runtime, G: rand::Rng>(
|
||||
pub fn search_rng<R: Runtime, G: rand::Rng>(
|
||||
&mut self,
|
||||
mut runtime: R,
|
||||
options: SearchOptions,
|
||||
limit: usize,
|
||||
rng: &mut G,
|
||||
) -> R {
|
||||
if self.dim_buckets.is_empty() {
|
||||
// No buckets: existing single-search path
|
||||
let stitched =
|
||||
self.search_single(&mut runtime, &options, rng, &self.dyn_map.clone(), None);
|
||||
self.search_single(&mut runtime, limit, rng, &self.dyn_map.clone(), None);
|
||||
|
||||
runtime.clear_intermediate_buffers();
|
||||
runtime.load_llir(&stitched);
|
||||
@@ -447,7 +400,7 @@ impl Graph {
|
||||
|
||||
let stitched = self.search_single(
|
||||
&mut runtime,
|
||||
&options,
|
||||
limit,
|
||||
rng,
|
||||
&representative_dyn_map,
|
||||
Some((combo_idx, n_combos)),
|
||||
@@ -517,12 +470,11 @@ impl Graph {
|
||||
fn search_single<R: Runtime, G: rand::Rng>(
|
||||
&self,
|
||||
runtime: &mut R,
|
||||
options: &SearchOptions,
|
||||
limit: usize,
|
||||
rng: &mut G,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
bucket_progress: Option<(usize, usize)>,
|
||||
) -> LLIRGraph {
|
||||
let limit = options.limit;
|
||||
let n_chunks = self.subgraph_descriptors.len();
|
||||
let n_groups = self.chunk_groups.len();
|
||||
let multi_chunk = n_chunks > 1;
|
||||
@@ -659,7 +611,7 @@ impl Graph {
|
||||
None,
|
||||
);
|
||||
runtime.clear_intermediate_buffers();
|
||||
let profile = runtime.profile(&graph, dyn_map, options.trials);
|
||||
let profile = runtime.profile(&graph, dyn_map, Self::TRIALS_PER_PROFILE);
|
||||
let has_nan = runtime.has_nan_outputs(&graph, dyn_map);
|
||||
(graph, profile, has_nan)
|
||||
}));
|
||||
@@ -718,8 +670,8 @@ impl Graph {
|
||||
let offspring = extract_generation(
|
||||
egraph,
|
||||
&best_genome,
|
||||
(limit - n_graphs).min(options.generation_size),
|
||||
options.mutations,
|
||||
(limit - n_graphs).min(Self::DEFAULT_GENERATION_SIZE),
|
||||
Self::MUTATIONS_PER_OFFSPRING,
|
||||
&mut prev_selected,
|
||||
rng,
|
||||
);
|
||||
@@ -745,7 +697,7 @@ impl Graph {
|
||||
);
|
||||
runtime.clear_intermediate_buffers();
|
||||
let result =
|
||||
runtime.profile(&llir_graph, dyn_map, options.trials);
|
||||
runtime.profile(&llir_graph, dyn_map, Self::TRIALS_PER_PROFILE);
|
||||
let has_nan = runtime.has_nan_outputs(&llir_graph, dyn_map);
|
||||
(result, llir_graph, has_nan)
|
||||
}));
|
||||
@@ -861,7 +813,7 @@ impl Graph {
|
||||
&mut expr_cache,
|
||||
custom_remap,
|
||||
);
|
||||
remap_llir_io_nodes(&mut llir, &node_remap, &self.graph);
|
||||
remap_llir_io_nodes(&mut llir, &node_remap);
|
||||
chunk_best_llirs[chunk_idx] = Some(llir);
|
||||
}
|
||||
|
||||
@@ -1271,27 +1223,17 @@ fn build_chunk_remaps(
|
||||
}
|
||||
|
||||
/// Apply Input/Output node index remapping to an LLIR graph (in-place modification).
|
||||
fn remap_llir_io_nodes(
|
||||
llir: &mut LLIRGraph,
|
||||
node_remap: &FxHashMap<usize, usize>,
|
||||
hlir_graph: &HLIRGraph,
|
||||
) {
|
||||
fn remap_llir_io_nodes(llir: &mut LLIRGraph, node_remap: &FxHashMap<usize, usize>) {
|
||||
// We need to replace nodes in-place. Collect node indices first.
|
||||
let node_indices: Vec<NodeIndex> = llir.node_indices().collect();
|
||||
for node_idx in node_indices {
|
||||
let op = &llir[node_idx];
|
||||
let new_op = if let Some(input_op) = op.to_op::<crate::hlir::Input>() {
|
||||
if let Some(&new_node) = node_remap.get(&input_op.node) {
|
||||
// Look up the target HLIR Input's label so chunk copies get correct names
|
||||
let new_label = hlir_graph
|
||||
.node_weight(NodeIndex::new(new_node))
|
||||
.and_then(|w| w.as_any().downcast_ref::<crate::hlir::Input>())
|
||||
.map(|inp| inp.label.clone())
|
||||
.unwrap_or_else(|| input_op.label.clone());
|
||||
Some(LLIROp::new::<crate::hlir::Input>(Box::new(
|
||||
crate::hlir::Input {
|
||||
node: new_node,
|
||||
label: new_label,
|
||||
label: input_op.label.clone(),
|
||||
dtype: input_op.dtype,
|
||||
},
|
||||
)))
|
||||
@@ -1505,7 +1447,7 @@ mod tests {
|
||||
assert!(custom_op_remap.is_empty());
|
||||
|
||||
// Apply IO remap
|
||||
remap_llir_io_nodes(&mut llir, &node_remap, &hlir_graph);
|
||||
remap_llir_io_nodes(&mut llir, &node_remap);
|
||||
|
||||
// Verify remapped nodes
|
||||
let mut input_nodes: Vec<(usize, String)> = vec![];
|
||||
|
||||
@@ -25,7 +25,6 @@ fn dtype_propagation_rule(sort: &SortDef, dtype_source: &str) -> Rule {
|
||||
.fact(eq(e.clone(), op_match))
|
||||
.fact(eq(dty.clone(), dtype(args[dtype_source].clone())))
|
||||
.action(Action::Set(dtype(e), dty))
|
||||
.ruleset("dtype_prop")
|
||||
}
|
||||
|
||||
/// Helper: build a dtype-from-field rule for a direct IR op.
|
||||
@@ -35,7 +34,6 @@ fn dtype_from_field_rule(sort: &SortDef, dtype_field: &str) -> Rule {
|
||||
Rule::new()
|
||||
.fact(eq(e.clone(), op_match))
|
||||
.action(Action::Set(dtype(e), args[dtype_field].clone()))
|
||||
.ruleset("dtype_prop")
|
||||
}
|
||||
|
||||
// --- Dtype helpers for normalized ops (Op OpKind IList) ---
|
||||
@@ -60,7 +58,6 @@ fn dtype_propagation_op(kind_sort: &SortDef) -> Rule {
|
||||
))
|
||||
.fact(eq(dty.clone(), dtype(first_inp)))
|
||||
.action(Action::Set(dtype(e), dty))
|
||||
.ruleset("dtype_prop")
|
||||
}
|
||||
|
||||
/// Dtype from a field on the OpKind (e.g., Cast's dtype field).
|
||||
@@ -71,7 +68,6 @@ fn dtype_from_kind_field(kind_sort: &SortDef, field_name: &str) -> Rule {
|
||||
Rule::new()
|
||||
.fact(eq(e.clone(), op_term(kind_term, inputs)))
|
||||
.action(Action::Set(dtype(e), args[field_name].clone()))
|
||||
.ruleset("dtype_prop")
|
||||
}
|
||||
|
||||
/// Fixed dtype for a normalized op (e.g., Iota always Int).
|
||||
@@ -82,7 +78,6 @@ fn dtype_fixed_op(kind_sort: &SortDef, dtype_sort: &SortDef) -> Rule {
|
||||
Rule::new()
|
||||
.fact(eq(e.clone(), op_term(kind_term, inputs)))
|
||||
.action(Action::Set(dtype(e), dtype_sort.call(())))
|
||||
.ruleset("dtype_prop")
|
||||
}
|
||||
|
||||
/// Build an IList egglog string from input variable names.
|
||||
|
||||