Compare commits

..

1 Commits

Author SHA1 Message Date
Austin Glover
3e9f742bd7 try to get more stable GPU detection 2026-03-11 20:21:20 +00:00
148 changed files with 9432 additions and 24108 deletions

View File

@@ -1,55 +0,0 @@
{
"name": "Luminal (CUDA)",
"image": "ghcr.io/luminal-ai/luminal-docker:cuda",
"initializeCommand": "touch .env",
"runArgs": [
"--env-file",
".env",
"--runtime=nvidia",
"--env=NVIDIA_VISIBLE_DEVICES=nvidia.com/gpu=all",
"--env=NVIDIA_DRIVER_CAPABILITIES=compute,utility"
],
"containerEnv": {
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
},
"containerUser": "ubuntu",
"features": {
"ghcr.io/devcontainers/features/common-utils:2": {
"installZsh": false,
"installOhMyZsh": false,
"username": "ubuntu",
"userUid": "1000",
"userGid": "1000",
"configureZshAsDefaultShell": false
}
},
"remoteUser": "ubuntu",
"remoteEnv": {
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
},
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
"customizations": {
"vscode": {
"extensions": [
"ms-python.debugpy",
"ms-python.python",
"ms-python.vscode-pylance",
"ms-python.vscode-python-envs",
"ms-vscode.cmake-tools",
"ms-vscode.cpptools",
"ms-vscode.cpptools-extension-pack",
"ms-vscode.cpptools-themes",
"ms-vscode.makefile-tools",
"streetsidesoftware.code-spell-checker",
"hatookov.egglog-language",
"rust-lang.rust-analyzer",
"openai.chatgpt",
"anthropic.claude-code",
"tamasfe.even-better-toml",
"eamodio.gitlens",
"ms-vscode.live-server",
"tintinweb.graphviz-interactive-preview"
]
}
}
}

View File

@@ -1,29 +1,17 @@
{
"name": "Luminal (CPU)",
"image": "ghcr.io/luminal-ai/luminal-docker:cpu",
"initializeCommand": "touch .env",
"runArgs": [
"--env-file", ".env"
],
"containerEnv": {
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
},
"containerUser": "ubuntu",
"name": "Luminal",
"image": "ghcr.io/luminal-ai/luminal-docker:latest",
"features": {
"ghcr.io/devcontainers/features/common-utils:2": {
"installZsh": false,
"installOhMyZsh": false,
"username": "ubuntu",
"userUid": "1000",
"userGid": "1000",
"configureZshAsDefaultShell": false
}
"ghcr.io/devcontainers/features/github-cli:1": {}
},
"remoteUser": "ubuntu",
"remoteEnv": {
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
"GH_TOKEN": "${localEnv:GH_TOKEN}"
},
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
"hostRequirements": {
"gpu": "optional"
},
"shutdownAction": "stopContainer",
"postStartCommand": "git config --global --add safe.directory ${containerWorkspaceFolder} && (nvidia-smi > /dev/null 2>&1 && echo 'GPU available' || echo 'WARNING: GPU not detected - rebuild container if GPU is expected')",
"customizations": {
"vscode": {
"extensions": [
@@ -39,7 +27,6 @@
"streetsidesoftware.code-spell-checker",
"hatookov.egglog-language",
"rust-lang.rust-analyzer",
"openai.chatgpt",
"anthropic.claude-code",
"tamasfe.even-better-toml",
"eamodio.gitlens",

View File

@@ -1,30 +0,0 @@
---
# Runs the `cargo-clippy` pre-commit hook inside the CUDA image on the
# `cuda_t4_runner` runner for pushes and pull requests targeting main.
name: CUDA Clippy

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
  workflow_dispatch:

jobs:
  cuda_clippy:
    name: CUDA Clippy
    runs-on: cuda_t4_runner
    timeout-minutes: 20
    container:
      image: ghcr.io/luminal-ai/luminal-docker:cuda
      options: --gpus all
    steps:
      - uses: actions/checkout@v6
      # The checkout is owned by a different user than the container user,
      # so git must be told the workspace is trusted.
      - name: Mark workspace as safe for git
        run: git config --global --add safe.directory "$GITHUB_WORKSPACE"
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - name: Update Rust toolchain
        run: rustup update
      - uses: pre-commit/action@v3.0.1
        with:
          extra_args: cargo-clippy --all-files

View File

@@ -1,23 +0,0 @@
---
# Checks Rust formatting through the `cargo-fmt` pre-commit hook.
name: Fmt

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
  workflow_dispatch:

jobs:
  fmt:
    name: Fmt
    runs-on: ubuntu-latest
    timeout-minutes: 20
    steps:
      - uses: actions/checkout@v6
      # pre-commit itself is a Python tool, hence the interpreter setup.
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - uses: pre-commit/action@v3.0.1
        with:
          extra_args: cargo-fmt --all-files

View File

@@ -1,25 +0,0 @@
---
# Lints the Metal crate on a macOS runner via the manual-stage
# `cargo-clippy-metal` pre-commit hook.
name: Metal Clippy

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
  workflow_dispatch:

jobs:
  metal_clippy:
    name: Metal Clippy
    runs-on: macos-14
    timeout-minutes: 30
    steps:
      - uses: actions/checkout@v6
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - name: Update Rust toolchain
        run: rustup update
      # `--hook-stage manual` is required to select the hook.
      - uses: pre-commit/action@v3.0.1
        with:
          extra_args: --hook-stage manual cargo-clippy-metal --all-files

View File

@@ -1,47 +0,0 @@
---
# Builds and runs each example on a Modal GPU.
#
# SECURITY: `pull_request_target` runs with the base repository's secrets
# while the checkout step below fetches the PR head (untrusted code).
# NOTE(review): because `synchronize` is listed, commits pushed after the
# `modal-ready` label was applied run without a fresh review unless the
# `Modal` environment enforces required reviewers - confirm that it does.
name: Modal Examples

on:
  push:
    branches: ["main"]
  pull_request_target:
    branches: ["main"]
    types: [labeled, synchronize]
  workflow_dispatch:

# Least privilege: nothing in this workflow writes to the repository.
permissions:
  contents: read

jobs:
  modal_example:
    if: >-
      github.event_name == 'push'
      || github.event_name == 'workflow_dispatch'
      || (github.event_name == 'pull_request_target'
      && contains(github.event.pull_request.labels.*.name, 'modal-ready'))
    name: "${{ matrix.example }} (Modal ${{ matrix.gpu.type }})"
    runs-on: ubuntu-latest
    environment: Modal
    timeout-minutes: 70
    strategy:
      fail-fast: false
      matrix:
        example: [llama, gemma, qwen, qwen3_moe]
        gpu:
          - { type: "A100-80GB" }
          # To add more GPUs, just append another entry:
          # - { type: "H100" }
    steps:
      - uses: actions/checkout@v6
        with:
          ref: ${{ github.event.pull_request.head.sha || github.sha }}
          # Do not leave the workflow token in .git/config next to
          # untrusted PR code.
          persist-credentials: false
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - name: Install Modal
        run: pip install modal
      - name: "Run ${{ matrix.example }} on Modal ${{ matrix.gpu.type }}"
        env:
          MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
          MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
          EXAMPLE: ${{ matrix.example }}
          GPU_TYPE: ${{ matrix.gpu.type }}
        run: modal run ci/modal_example.py

View File

@@ -1,23 +0,0 @@
---
# Verifies Python formatting through the `ruff-format` pre-commit hook.
name: Ruff Format

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
  workflow_dispatch:

jobs:
  ruff_format:
    name: Ruff Format
    runs-on: ubuntu-latest
    timeout-minutes: 20
    steps:
      - uses: actions/checkout@v6
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - uses: pre-commit/action@v3.0.1
        with:
          extra_args: ruff-format --all-files

View File

@@ -1,23 +0,0 @@
---
# Lints Python sources through the `ruff-check` pre-commit hook.
name: Ruff

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
  workflow_dispatch:

jobs:
  ruff:
    name: Ruff
    runs-on: ubuntu-latest
    timeout-minutes: 20
    steps:
      - uses: actions/checkout@v6
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - uses: pre-commit/action@v3.0.1
        with:
          extra_args: ruff-check --all-files

View File

@@ -1,24 +0,0 @@
---
# Runs the workspace unit tests that need no GPU, inside the CPU image.
name: Test Core

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
  workflow_dispatch:

env:
  CARGO_TERM_COLOR: always

jobs:
  core_unit_test:
    name: Core Unit Tests
    runs-on: ubuntu-latest
    timeout-minutes: 20
    container:
      image: ghcr.io/luminal-ai/luminal-docker:cpu
    steps:
      - uses: actions/checkout@v6
      # The CUDA, Metal and benchmark crates are excluded from this job.
      - name: Run tests
        run: >-
          cargo test --workspace
          --exclude luminal_cuda_lite
          --exclude luminal_metal
          --exclude luminal_bench
          --verbose

View File

@@ -1,37 +0,0 @@
---
# Runs the CUDA crate tests on a Modal GPU via ci/modal_cargo_test.py.
#
# SECURITY: `pull_request_target` runs with the base repository's secrets
# while the checkout step below fetches the PR head (untrusted code).
# NOTE(review): because `synchronize` is listed, commits pushed after the
# `modal-ready` label was applied run without a fresh review unless the
# `Modal` environment enforces required reviewers - confirm that it does.
name: Test CUDA

on:
  push:
    branches: ["main"]
  pull_request_target:
    branches: ["main"]
    types: [labeled, synchronize]
  workflow_dispatch:

# Least privilege: nothing in this workflow writes to the repository.
permissions:
  contents: read

jobs:
  cuda_unit_test:
    if: >-
      github.event_name == 'push'
      || github.event_name == 'workflow_dispatch'
      || (github.event_name == 'pull_request_target'
      && contains(github.event.pull_request.labels.*.name, 'modal-ready'))
    name: Cuda Unit Tests
    runs-on: ubuntu-latest
    environment: Modal
    timeout-minutes: 30
    steps:
      - uses: actions/checkout@v6
        with:
          ref: ${{ github.event.pull_request.head.sha || github.sha }}
          # Do not leave the workflow token in .git/config next to
          # untrusted PR code.
          persist-credentials: false
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - name: Install Modal
        run: pip install modal
      - name: Run CUDA tests on Modal
        env:
          MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
          MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
        run: modal run ci/modal_cargo_test.py

View File

@@ -1,19 +0,0 @@
---
# Runs the Metal crate tests on a macOS runner, one test at a time.
name: Test Metal

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
  workflow_dispatch:

jobs:
  metal_unit_test:
    name: Metal Unit Tests
    runs-on: macos-14
    timeout-minutes: 30
    steps:
      - uses: actions/checkout@v6
      - name: Run Metal crate tests
        run: |
          rustup update
          cargo test -p luminal_metal --verbose -- --test-threads=1

View File

@@ -1,49 +0,0 @@
---
# Runs the Python test-suite against the CUDA backend on a Modal GPU and
# uploads the profiling output as an artifact.
#
# SECURITY: `pull_request_target` runs with the base repository's secrets
# while the checkout step below fetches the PR head (untrusted code).
# NOTE(review): because `synchronize` is listed, commits pushed after the
# `modal-ready` label was applied run without a fresh review unless the
# `Modal` environment enforces required reviewers - confirm that it does.
name: Test Python CUDA

on:
  push:
    branches: ["main"]
  pull_request_target:
    branches: ["main"]
    types: [labeled, synchronize]
  workflow_dispatch:

# Least privilege: nothing in this workflow writes to the repository.
permissions:
  contents: read

jobs:
  python_cuda_tests:
    if: >-
      github.event_name == 'push'
      || github.event_name == 'workflow_dispatch'
      || (github.event_name == 'pull_request_target'
      && contains(github.event.pull_request.labels.*.name, 'modal-ready'))
    name: Python CUDA Tests
    runs-on: ubuntu-latest
    environment: Modal
    timeout-minutes: 60
    defaults:
      run:
        working-directory: crates/luminal_python
    steps:
      - uses: actions/checkout@v6
        with:
          ref: ${{ github.event.pull_request.head.sha || github.sha }}
          # Do not leave the workflow token in .git/config next to
          # untrusted PR code.
          persist-credentials: false
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - name: Install Modal
        run: pip install modal
      - name: Run pytest with CUDA backend on Modal
        env:
          MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
          MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        # Folded scalar: the lines below are joined with single spaces into
        # one command line.
        run: >-
          modal run modal_pytest_runner.py
          --gpu A100
          --timeout 3300
          --profile
          --profile-output-dir luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }}
          tests/ -v -s -m "not slow"
      # `always()` keeps the profiling data available when pytest fails.
      - name: Upload Modal pytest profiling artifacts
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: python-cuda-pytest-profiling-${{ github.run_id }}-${{ github.run_attempt }}
          path: crates/luminal_python/luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }}
          retention-days: 7
          if-no-files-found: warn

View File

@@ -1,28 +0,0 @@
---
# Builds the maturin extension and runs the Python tests that need no GPU,
# inside the CPU image.
name: Test Python Native

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
  workflow_dispatch:

jobs:
  python_native_tests:
    name: Python Native Tests
    runs-on: ubuntu-latest
    timeout-minutes: 45
    container:
      image: ghcr.io/luminal-ai/luminal-docker:cpu
    defaults:
      run:
        working-directory: crates/luminal_python
    steps:
      - uses: actions/checkout@v6
      - name: Update Rust toolchain
        run: rustup update
      - name: Build maturin extension
        run: uv run maturin develop --manifest-path rust/Cargo.toml
      - name: Run pytest
        run: >-
          uv run pytest
          tests/test_hlir_ops.py
          tests/test_unary.py
          -v -m "not slow"

87
.github/workflows/test.yml vendored Normal file
View File

@@ -0,0 +1,87 @@
---
# Main CI: CPU unit tests, clippy, rustfmt, and CUDA crate tests on the
# `cuda_t4_runner` runner.
name: Test

on:
  push:
    branches: ["main"]
  pull_request:
    branches: ["main"]

# Least privilege: no job here writes to the repository.
permissions:
  contents: read

env:
  CARGO_TERM_COLOR: always

jobs:
  core_unit_test:
    name: Core Unit Tests
    runs-on: ubuntu-latest
    timeout-minutes: 20
    steps:
      - uses: actions/checkout@v6
      - name: Run tests
        run: |
          rustup update
          cargo test --workspace --exclude luminal_cuda --exclude luminal_metal --exclude luminal_bench --verbose

  clippy:
    name: Clippy
    runs-on: ubuntu-latest
    timeout-minutes: 20
    steps:
      - uses: actions/checkout@v6
      - name: Run clippy
        run: |
          rustup update
          cargo clippy --workspace --exclude luminal_cuda --exclude luminal_metal --exclude luminal_bench --all-targets -- -D warnings

  fmt:
    name: Fmt
    runs-on: ubuntu-latest
    timeout-minutes: 20
    steps:
      - uses: actions/checkout@v6
      - name: Format
        run: cargo fmt --all --check

  cuda_unit_test:
    name: Cuda Unit Tests
    runs-on: cuda_t4_runner
    container:
      image: ghcr.io/luminal-ai/luminal-docker:latest
      options: --gpus all
    timeout-minutes: 30
    steps:
      - uses: actions/checkout@v6
      - name: Detect GPU compute capability
        # The default run shell does not enable pipefail, so a failing
        # nvidia-smi would otherwise export an empty or garbage value.
        # Validate that the result is purely numeric (e.g. "7.5" -> "75")
        # and fail loudly if not.
        run: |
          CAP=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -1 | tr -d '.')
          case "$CAP" in
            '' | *[!0-9]*)
              echo "::error::Could not detect GPU compute capability (got '${CAP}')"
              exit 1
              ;;
          esac
          echo "CUDA_COMPUTE_CAP=${CAP}" >> "$GITHUB_ENV"
      - name: Run CUDA crate tests
        run: cargo test -p luminal_cuda --verbose -- --test-threads=1

  # cuda_llama: # disabled because t4 doesn't have enough memory for full precision llama. re-enable when we can run on larger machines or use 8-bit precision
  #   name: Cuda Llama
  #   runs-on: cuda_t4_runner
  #   timeout-minutes: 30
  #   env:
  #     CUDA_HOME: /usr/local/cuda-12.8
  #     LD_LIBRARY_PATH: /usr/local/cuda-12.8/lib64
  #   steps:
  #     - uses: actions/checkout@v6
  #     - name: Install system deps
  #       run: |
  #         sudo apt-get update
  #         sudo apt-get install -y --no-install-recommends \
  #           protobuf-compiler \
  #           cuda-nvrtc-12-8
  #     - name: Install Rust
  #       run: |
  #         curl -sSf https://sh.rustup.rs | sh -s -- -y --profile minimal
  #         echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"
  #     - name: Update Rust
  #       run: rustup update
  #     - name: Install uv
  #       run: curl -LsSf https://astral.sh/uv/install.sh | sh
  #     - name: Download Llama
  #       working-directory: examples/llama
  #       run: uv run --script setup/setup.py
  #     - name: Run Llama
  #       working-directory: examples/llama
  #       run: SEARCH=1 cargo run --release

15
.gitignore vendored
View File

@@ -15,10 +15,6 @@ Cargo.lock
*.gguf
.claude-project
.claude-memory
.codex
*.pftrace
*.safetensors
*.safetensors.index.json
@@ -26,14 +22,3 @@ tokenizer.json
**/.cache
**/proptest-regressions
opencode.json
# Python build artifacts
*.so
*.pyd
__pycache__/
*.pyc
*.pyo
*.egg-info/
dist/
build/
uv.lock

View File

@@ -1,38 +0,0 @@
---
# pre-commit hooks: ruff for the Python bridge, cargo fmt/clippy for Rust.
# Hooks limited to the `manual` stage only run when requested with
# `--hook-stage manual`.
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.14.5
    hooks:
      - id: ruff-check
        name: ruff check
        files: '^crates/luminal_python/.*\.py$'
      - id: ruff-format
        name: ruff format
        files: '^crates/luminal_python/.*\.py$'

  - repo: local
    hooks:
      - id: cargo-fmt
        name: cargo fmt
        entry: cargo fmt --all --check
        language: system
        pass_filenames: false
        files: '\.(rs|toml)$'

      # Default clippy run: the GPU and benchmark crates are excluded.
      - id: cargo-clippy
        name: cargo clippy
        entry: cargo clippy --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --all-targets -- -D warnings
        language: system
        pass_filenames: false
        files: '\.(rs|toml)$'

      - id: cargo-clippy-metal
        name: cargo clippy metal
        entry: cargo clippy -p luminal_metal --all-targets -- -D warnings
        language: system
        pass_filenames: false
        files: '\.(rs|toml)$'
        stages:
          - manual

      - id: cargo-clippy-cuda-lite
        name: cargo clippy cuda_lite
        entry: cargo clippy -p luminal_cuda_lite --all-targets -- -D warnings
        language: system
        pass_filenames: false
        files: '\.(rs|toml)$'
        stages:
          - manual

View File

@@ -3,9 +3,9 @@
## Structure
Luminal is a core-and-plugin design, where the core crate `.` contains everything core to Luminal including the graph and the GraphTensor api, the shapetracker, and the primitive ops.
All other functionality is split into crates in the `crates/` directory. For instance, the Cuda compiler is in `luminal_cuda_lite` and the autograd engine is in `luminal_training`. `luminal_nn` has common nn modules.
All other functionality is split into crates in the `crates/` directory. For instance, the Cuda compiler is in `luminal_cuda` and the autograd engine is in `luminal_training`. `luminal_nn` has common nn modules.
## Testing Instructions
- Find the CI plan in the .github/workflows folder.
- Currently running `cargo test` in luminal_metal and luminal_cuda_lite require access to an Apple and Nvidia GPU respectively.
- Currently, running `cargo test` in luminal_metal and luminal_cuda requires access to an Apple and Nvidia GPU respectively.
- PRs must have no clippy errors, and `cargo fmt` must be run before a PR is submitted.

View File

@@ -37,8 +37,8 @@ lru = "0.16.2"
edition = "2024"
[dev-dependencies]
candle-core = "0.9.2"
candle-nn = "0.9.2"
candle-core = "0.9.2-alpha.1"
candle-nn = "0.9.2-alpha.1"
ordered-float = "5.1.0"
proptest = "1.9.0"
@@ -46,12 +46,11 @@ proptest = "1.9.0"
members = [
"examples/*",
"crates/luminal_nn",
"crates/luminal_cuda_lite",
"crates/luminal_cuda",
"crates/luminal_metal",
"crates/luminal_tracing",
"crates/luminal_bench",
"crates/luminal_python/rust",
]
[patch.crates-io]
candle-kernels = { git = "https://github.com/huggingface/candle.git", rev = "a0dbd8b8aef6bde9adca3e8ad90791609d64974b" }
candle-kernels = { git = "https://github.com/asglover/candle.git", branch = "fix/disable-bf16-wmma-pre-ampere" }

View File

@@ -4,7 +4,7 @@
Luminal is a high-performance general-purpose inference compiler.
</h3>
[![CI Status](https://img.shields.io/github/actions/workflow/status/luminal-ai/luminal/test-core.yml?style=for-the-badge&logo=github-actions&logoColor=white&branch=main)](https://github.com/luminal-ai/luminal/actions)
[![CI Status](https://img.shields.io/github/actions/workflow/status/jafioti/luminal/test.yml?style=for-the-badge&logo=github-actions&logoColor=white&branch=main)](https://github.com/jafioti/luminal/actions)
[![Docs](https://img.shields.io/badge/Documentation-green?style=for-the-badge&color=0D9373)](https://docs.luminalai.com)
[![Current Crates.io Version](https://img.shields.io/crates/v/luminal.svg?style=for-the-badge&logo=rust)](https://crates.io/crates/luminal)
[![discord](https://dcbadge.limes.pink/api/server/APjuwHAbGy)](https://discord.gg/APjuwHAbGy)
@@ -45,18 +45,6 @@ cd ./examples/llama
cargo run --release
```
**PyTorch models via `torch.compile`**
Any PyTorch model can be run through Luminal by swapping the backend:
```python
import torch
from luminal import luminal_backend
model_compiled = torch.compile(model, backend=luminal_backend)
output = model_compiled(x)
```
See `crates/luminal_python/` for the PT2-based bridge.
## Features
### Speed
@@ -87,7 +75,7 @@ The current ML ecosystem is too fragmented, and the solution isn't another layer
### Validated against Pytorch
Correctness matters. We write as much tests as possible to cover all ops and verify they work the same as an equivalent Pytorch implementation.
Correctness matters. We write as many tests as possible to cover all ops and verify they work the same as an equivalent PyTorch implementation. ([Improvements needed!](https://github.com/jafioti/luminal/issues/20))
## Ideology
@@ -114,12 +102,12 @@ Now we can do:
## Where are we?
- Search is the default execution path — compile via `build_search_space` and `search` (see the Usage example above).
- Search is partially merged. We are between 1.0 and 2.0 (search), which will be completed within the next month or so.
- Metal and Cuda are supported for running models on Macs and Nvidia GPUs respectively, in both full and half precision.
- Llama 3, Gemma, Qwen (incl. MoE variants), and a paged-attention Llama are implemented in `examples/`. See instructions above for running.
- Full training support with graph-based autograd.
- Llama 3, Phi 3, Whisper and Yolo v8 are implemented in `examples/`. See instructions above for running.
- We have a small library of NN modules in `luminal_nn`, including transformers.
- A large surface of high-level ops lives in `src/frontend/` aiming to match the most used ~80% of the PyTorch api.
- PyTorch models can be run through luminal via `torch.compile` — see `crates/luminal_python/`.
- A significant number of high-level ops are implemented in `hl_ops`. We are aiming to match the most used ~80% of the PyTorch API.
Some things on the roadmap:

View File

@@ -1,68 +0,0 @@
import modal
import subprocess
import os
# GPU type requested from Modal; overridable through the GPU_TYPE env var.
gpu_type = os.environ.get("GPU_TYPE", "T4")
CUDARC_CUDA_VERSION = "12080"

app = modal.App("luminal-ci-cargo-test")

WORKDIR = "/workspace/luminal"

# Image: NVIDIA PyTorch base + protoc + a rustup-installed toolchain, with the
# local checkout copied to WORKDIR.
cuda_image = (
    modal.Image.from_registry("nvcr.io/nvidia/pytorch:25.03-py3")
    .apt_install("protobuf-compiler")
    .run_commands(
        "curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y",
    )
    .env(
        {
            "PATH": "/root/.cargo/bin:$PATH",
            "CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
        }
    )
    .add_local_dir(".", remote_path=WORKDIR, copy=True)
)


def _first_compute_cap(nvidia_smi_output):
    """Return the compute capability of the first GPU with the dot removed.

    `nvidia-smi --query-gpu=compute_cap --format=csv,noheader` prints one line
    per GPU (e.g. "7.5"). Only the first non-empty line is used so that a
    multi-GPU container does not yield a multi-line value.

    Raises:
        RuntimeError: if the output contains no non-empty line.
    """
    for line in nvidia_smi_output.splitlines():
        cap = line.strip().replace(".", "")
        if cap:
            return cap
    raise RuntimeError("nvidia-smi reported no GPU compute capability")


@app.function(
    image=cuda_image,
    gpu=gpu_type,
    timeout=1800,  # 30 minutes
)
def run_cargo_test():
    """Run cargo test for luminal_cuda_lite on a Modal GPU."""
    # Print the GPU inventory into the log; fails fast when no GPU is visible.
    subprocess.run(["nvidia-smi"], check=True)

    # Detect GPU compute capability
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=True,
    )
    compute_cap = _first_compute_cap(result.stdout)

    subprocess.run(
        [
            "cargo",
            "test",
            "-p",
            "luminal_cuda_lite",
            "--verbose",
            "--",
            "--test-threads=1",
        ],
        cwd=WORKDIR,
        env={
            **os.environ,
            "CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
            "CUDA_COMPUTE_CAP": compute_cap,
        },
        check=True,
    )


@app.local_entrypoint()
def main():
    """Local entry point for `modal run`: execute the tests remotely."""
    run_cargo_test.remote()

View File

@@ -1,67 +0,0 @@
import modal
import subprocess
import os
# Which example to run and on which GPU; both come from the environment.
example = os.environ.get("EXAMPLE", "llama")
gpu_type = os.environ.get("GPU_TYPE", "A100-80GB")

CUDARC_CUDA_VERSION = "12080"
HF_CACHE_VOLUME_NAME = "luminal-hf-cache-v2"
HF_CACHE_PATH = "/root/.cache/huggingface"
WORKDIR = "/workspace/luminal"

app = modal.App(f"luminal-ci-{example}")

# Persistent volume mounted at HF_CACHE_PATH and used as HF_HOME below.
hf_cache = modal.Volume.from_name(
    HF_CACHE_VOLUME_NAME, create_if_missing=True, version=2
)

_base_image = modal.Image.from_registry("nvcr.io/nvidia/pytorch:25.03-py3")
cuda_image = (
    _base_image.apt_install("protobuf-compiler")
    .run_commands(
        "curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y",
    )
    .env(
        {
            "PATH": "/root/.cargo/bin:$PATH",
            "CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
        }
    )
    .add_local_dir(".", remote_path=WORKDIR, copy=True)
)


@app.function(
    image=cuda_image,
    gpu=gpu_type,
    timeout=3600,  # 60 minutes
    volumes={HF_CACHE_PATH: hf_cache},
)
def run_example(example: str):
    """Compile and execute `examples/<example>` in release mode on the GPU.

    The Hugging Face cache volume is committed only after the example exits
    successfully, because `check=True` raises on a non-zero exit first.
    """
    subprocess.run(["nvidia-smi"], check=True)

    child_env = dict(os.environ)
    child_env["CUDARC_CUDA_VERSION"] = CUDARC_CUDA_VERSION
    child_env["HF_HOME"] = HF_CACHE_PATH

    example_dir = f"{WORKDIR}/examples/{example}"
    subprocess.run(
        ["cargo", "run", "--release"],
        cwd=example_dir,
        env=child_env,
        check=True,
    )
    hf_cache.commit()


@app.local_entrypoint()
def main():
    """Local entry point for `modal run`: run the selected example remotely."""
    run_example.remote(example)

View File

@@ -1,5 +1,5 @@
[package]
name = "luminal_cuda_lite"
name = "luminal_cuda"
version = "0.2.0"
edition = "2024"
description = "Cuda compiler for luminal"
@@ -26,7 +26,7 @@ libc = "0.2"
colorize = "*"
[dev-dependencies]
candle-core = { version = "0.9.2", features = ["cuda"] }
candle-core = { version = "0.9.2-alpha.1", features = ["cuda"] }
proptest = "1.9.0"
rand = "0.9.2"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

View File

@@ -1,4 +1,4 @@
## luminal_cuda_lite
## luminal_cuda
This crate contains the CUDA backend for Luminal.
@@ -26,4 +26,4 @@ Thread ops are not yet merged. Stay tuned!
### Architecture
`luminal_cuda_lite` can model a joint search space that smoothly searches through various mixed configurations of these ops. At compile time, a waterfall process takes place to iteratively raise each op to the level above, resulting in all host-level ops in the final runtime graph. For instance, block ops get combined into megakernels, implemented as kernel ops. Kernel ops get combined into cuda graphs, implemented as host ops.
`luminal_cuda` can model a joint search space that smoothly searches through various mixed configurations of these ops. At compile time, a waterfall process takes place to iteratively raise each op to the level above, resulting in all host-level ops in the final runtime graph. For instance, block ops get combined into megakernels, implemented as kernel ops. Kernel ops get combined into cuda graphs, implemented as host ops.

View File

@@ -0,0 +1,252 @@
use itertools::Itertools;
use luminal::{prelude::FxHashMap, shape::Expression};
/// C-side type of one serialized field, used when rendering the struct
/// declaration. Array variants carry the element count; `Bytes` carries the
/// length of the raw blob in bytes.
#[derive(Debug, PartialEq, Eq)]
enum CStructType {
    /// Rendered as `float`.
    Float,
    FloatArr(usize),
    /// Rendered as `int`; also used for resolved expressions.
    Int,
    IntArr(usize),
    /// Rendered as `long`; written as 8 bytes with 8-byte alignment.
    Long,
    LongArr(usize),
    /// Rendered as `bool`; written as one byte (0 or 1).
    Bool,
    BoolArr(usize),
    /// Rendered as `float*`; written as a `usize`-sized address.
    Ptr,
    PtrArr(usize),
    /// Raw bytes (e.g. a nested struct), rendered as a `char` array.
    Bytes(usize),
}
/// Builder that serializes fields into a byte buffer, inserting padding
/// before each field so it starts at that field's alignment, while recording
/// the matching C declaration for each field.
#[derive(Debug)]
pub struct CStruct<'a> {
    // Serialized bytes written so far, including padding.
    buf: Vec<u8>,
    // Largest alignment requested by any field so far (starts at 1).
    max_align: usize,
    // (field name, C type) for every serialized field, in insertion order.
    struct_types: Vec<(String, CStructType)>,
    // Lookup used by `expr`/`expr_arr` to turn an expression into an `i32`.
    // When `None`, those methods only record the expressions they are given.
    expressions: Option<&'a FxHashMap<Expression, i32>>,
    // Expressions passed to `expr`/`expr_arr` while `expressions` is `None`.
    pub(crate) recorded_expressions: Vec<Expression>,
}
impl<'a> CStruct<'a> {
pub fn new(expressions: Option<&'a FxHashMap<Expression, i32>>) -> Self {
Self {
max_align: 1,
struct_types: vec![],
buf: vec![],
expressions,
recorded_expressions: vec![],
}
}
fn align_to(&mut self, align: usize) {
self.max_align = self.max_align.max(align);
let len = self.buf.len();
let rem = len % align;
if rem != 0 {
let pad = align - rem;
self.buf.extend(std::iter::repeat_n(0u8, pad));
}
}
pub fn int(mut self, name: impl ToString, v: i32) -> Self {
self.struct_types.push((name.to_string(), CStructType::Int));
self.align_to(4);
self.buf.extend_from_slice(&v.to_ne_bytes());
self
}
pub fn int_arr(mut self, name: impl ToString, vs: &[i32]) -> Self {
self.struct_types
.push((name.to_string(), CStructType::IntArr(vs.len())));
self.align_to(4);
for &v in vs {
self.buf.extend_from_slice(&v.to_ne_bytes());
}
self
}
pub fn expr(mut self, name: impl ToString, v: impl Into<Expression>) -> Self {
if let Some(expressions) = self.expressions {
self.struct_types.push((name.to_string(), CStructType::Int));
let v = expressions[&v.into()];
self.align_to(4);
self.buf.extend_from_slice(&v.to_ne_bytes());
} else {
self.recorded_expressions.push(v.into());
}
self
}
pub fn expr_arr(mut self, name: impl ToString, vs: &[Expression]) -> Self {
if let Some(expressions) = self.expressions {
self.struct_types
.push((name.to_string(), CStructType::IntArr(vs.len())));
self.align_to(4);
for &v in vs {
let v = expressions[&v];
self.buf.extend_from_slice(&v.to_ne_bytes());
}
} else {
self.recorded_expressions.extend(vs.iter().copied());
}
self
}
pub fn long(mut self, name: impl ToString, v: i64) -> Self {
self.struct_types
.push((name.to_string(), CStructType::Long));
self.align_to(8);
self.buf.extend_from_slice(&v.to_ne_bytes());
self
}
pub fn long_arr(mut self, name: impl ToString, vs: &[i64]) -> Self {
self.struct_types
.push((name.to_string(), CStructType::LongArr(vs.len())));
self.align_to(8);
for &v in vs {
self.buf.extend_from_slice(&v.to_ne_bytes());
}
self
}
pub fn float(mut self, name: impl ToString, v: f32) -> Self {
self.struct_types
.push((name.to_string(), CStructType::Float));
self.align_to(4);
self.buf.extend_from_slice(&v.to_ne_bytes());
self
}
pub fn float_arr(mut self, name: impl ToString, vs: &[f32]) -> Self {
self.struct_types
.push((name.to_string(), CStructType::FloatArr(vs.len())));
self.align_to(4);
for &v in vs {
self.buf.extend_from_slice(&v.to_ne_bytes());
}
self
}
pub fn bool(mut self, name: impl ToString, v: bool) -> Self {
self.struct_types
.push((name.to_string(), CStructType::Bool));
self.align_to(1);
self.buf.push(if v { 1 } else { 0 });
self
}
pub fn bool_arr(mut self, name: impl ToString, vs: &[bool]) -> Self {
self.struct_types
.push((name.to_string(), CStructType::BoolArr(vs.len())));
self.align_to(1);
for &v in vs {
self.buf.push(if v { 1 } else { 0 });
}
self
}
pub fn ptr_const_f32(mut self, name: impl ToString, p: *const f32) -> Self {
self.struct_types.push((name.to_string(), CStructType::Ptr));
let ptr_size = std::mem::size_of::<usize>(); // usually 8
let ptr_align = ptr_size;
self.align_to(ptr_align);
let addr = p as usize;
let bytes = addr.to_ne_bytes();
self.buf.extend_from_slice(&bytes[..ptr_size]);
self
}
pub fn ptr_mut_f32(self, name: impl ToString, p: *mut f32) -> Self {
self.ptr_const_f32(name, p as *const f32)
}
pub fn ptr_const_f32_arr(mut self, name: impl ToString, p: &[*const f32]) -> Self {
self.struct_types
.push((name.to_string(), CStructType::PtrArr(p.len())));
let ptr_size = std::mem::size_of::<usize>(); // usually 8
let ptr_align = ptr_size;
self.align_to(ptr_align);
for &p in p {
let addr = p as usize;
let bytes = addr.to_ne_bytes();
self.buf.extend_from_slice(&bytes[..ptr_size]);
}
self
}
/// Returns the current size of the buffer after alignment for a pointer field.
/// Useful for computing field offsets.
pub fn current_size(&self) -> usize {
let ptr_align = std::mem::size_of::<usize>();
let len = self.buf.len();
let rem = len % ptr_align;
if rem != 0 {
len + (ptr_align - rem)
} else {
len
}
}
/// Pad the struct size to a multiple of max_align.
pub fn finish_struct(mut self) -> Vec<u8> {
assert!(
self.expressions.is_some(),
"Can only create cstruct bytes when expression map is provided!"
);
let align = self.max_align;
if align > 1 {
let len = self.buf.len();
let rem = len % align;
if rem != 0 {
let pad = align - rem;
self.buf.extend(std::iter::repeat_n(0u8, pad));
}
}
self.buf
}
/// Returns (size, alignment) of the struct.
pub fn size_and_align(&self) -> (usize, usize) {
let align = self.max_align;
let len = self.buf.len();
let rem = len % align;
let size = if rem != 0 { len + (align - rem) } else { len };
(size, align)
}
/// Insert a raw byte field (e.g., another struct).
/// `align` must be the alignment of the nested struct.
pub fn bytes(mut self, align: usize, name: impl ToString, data: &[u8]) -> Self {
self.struct_types
.push((name.to_string(), CStructType::Bytes(data.len())));
self.align_to(align);
self.buf.extend_from_slice(data);
self
}
}
impl std::fmt::Display for CStruct<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let s = self
.struct_types
.iter()
.map(|(name, ty)| match ty {
CStructType::Bool => format!("bool {name};"),
CStructType::BoolArr(l) => format!("bool {name}[{l}];"),
CStructType::Float => format!("float {name};"),
CStructType::FloatArr(l) => format!("float {name}[{l}];"),
CStructType::Int => format!("int {name};"),
CStructType::IntArr(l) => format!("int {name}[{l}];"),
CStructType::Long => format!("long {name};"),
CStructType::LongArr(l) => format!("long {name}[{l}];"),
CStructType::Ptr => format!("float* {name};"),
CStructType::PtrArr(l) => format!("float* {name}[{l}];"),
CStructType::Bytes(l) => format!("char payload[{l}];"),
})
.join("\n");
write!(f, "{s}")
}
}

View File

@@ -0,0 +1,327 @@
// Kernel-wide constants and op tables.
// NOTE(review): the percent-delimited marker comments in this file look like
// placeholders that the host-side renderer substitutes before compilation;
// keep them verbatim - confirm against the renderer.
const int N_OPS = 0;
// Number of SMEvent slots per block: worker_kernel offsets `timings` by
// blockIdx.x * N_TIMING_SLOTS and record_event stops recording at this count.
const int N_TIMING_SLOTS = 0;
// Passed to fetch_next_task as the number of entries in the task array.
const int N_TASKS = 0; // Rendered at compile time
//%n_barriers_const%
enum OpCode {
//%extra_op_codes%
};
//%extra_op_structs%
union Payload {
//%extra_op_payloads%
};
// One schedulable unit of work. Fields holding expressions store an index
// that eval_expression() resolves at run time.
// NOTE(review): presumably this layout must match what the host serializes -
// do not reorder fields without checking the host side.
struct Task {
  OpCode op;
  // Expression index; fetch_next_task evaluates it to initialize `remaining`.
  int range;
  // Claim counter: -1 = uninitialized, > 0 = iterations left, <= 0 = exhausted.
  int remaining;
  // Input dependencies a/b/c. A base of -1 means "no dependency"; otherwise
  // the dependency index is eval(base, 0) + eval(stride, iteration).
  int in_dep_a_stride;
  int in_dep_a_base;
  int in_dep_b_stride;
  int in_dep_b_base;
  int in_dep_c_stride;
  int in_dep_c_base;
  // NOTE(review): presumably the output-side counterpart of the in_dep_*
  // fields; its use is not visible in this part of the file.
  int out_dep_stride;
  int out_dep_base;
  // Indices into the kernel's `buffers` array for the six inputs...
  int source_indices[6];
  // ...and for the output.
  int out_index;
  Payload payload;
};
// One timing record written by record_event(). `start` and `stop` are
// read_globaltimer() values; `stop` stays 0 until the next event is recorded.
struct SMEvent {
  unsigned long long start;
  unsigned long long stop;
  // Event type as passed to record_event().
  int event;
};
//%constants%
// Evaluates the expression with index `expression`; `const_z` is the value
// made available to the rendered expression bodies (callers pass either 0 or
// the current iteration).
__device__ __noinline__ int eval_expression(int expression, int const_z) {
  switch (expression) {
//%expr_fns%
  }
  // No case matched. Flowing off the end of a non-void function is undefined
  // behavior, so return a defined value instead.
  return 0;
}
// Returns the current value of the PTX %globaltimer special register.
__device__ __forceinline__ unsigned long long read_globaltimer() {
  unsigned long long t;
  // `volatile` keeps the compiler from hoisting or merging timer reads.
  asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(t));
  return t;
}
//%extra_op_functions%
//%extra_prologue_functions%
// Suspends the calling thread using the PTX `nanosleep.u32` instruction.
// NOTE(review): the parameter is named `cycles`, but the PTX ISA describes
// the nanosleep operand as a duration in nanoseconds - confirm and consider
// renaming.
__device__ __forceinline__ void nanosleep(unsigned int cycles) {
  asm volatile("nanosleep.u32 %0;" ::"r"(cycles));
}
// Loads a 32-bit value from global memory with acquire semantics at GPU
// scope (`ld.global.acquire.gpu.b32`).
__device__ __forceinline__ int atomic_load_acquire(int *addr) {
  int val;
  asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(val) : "l"(addr));
  return val;
}
// Result of fetch_next_task(): which task was claimed and which iteration.
struct NextTask {
  // Iteration index of the claimed unit (value of `remaining` at claim time, minus 1).
  int current;
  // Index of the claimed task in the task array.
  int task_idx;
};
// Lock-free task fetching.
// remaining encoding:
//   -1  = uninitialized
//   > 0 = iterations remaining (claiming value r yields iteration r - 1)
//   0   = exhausted
//
// Claims use a compare-and-swap loop instead of a blind atomicSub so that
// `remaining` is never driven below 0. With atomicSub, blocks that lose the
// race on an exhausted task take it from 0 to -1, which is indistinguishable
// from the "uninitialized" sentinel: a block still looking at that task would
// re-initialize it and its iterations could be handed out a second time.
//
// Returns true and fills `out` when a unit of work was claimed; returns false
// once `head` has moved past the last task.
__device__ inline bool fetch_next_task(Task *tasks, int num_tasks, int *head,
                                       NextTask *out) {
  while (true) {
    int idx = atomic_load_acquire(head);
    if (idx >= num_tasks)
      return false;

    Task *t = &tasks[idx];
    int remaining = atomicAdd(&t->remaining, 0); // atomic read

    // Uninitialized task: one CAS to publish its iteration count. A
    // non-positive range is stored as 0 (exhausted); in particular a range of
    // -1 must not be written back, or this branch would spin forever.
    if (remaining == -1) {
      int range = eval_expression(t->range, 0);
      atomicCAS(&t->remaining, -1, range > 0 ? range : 0);
      continue;
    }

    // Try to claim one iteration. On CAS failure retry with the value that
    // was actually observed; some other block made progress in that case.
    while (remaining > 0) {
      int seen = atomicCAS(&t->remaining, remaining, remaining - 1);
      if (seen == remaining) {
        out->task_idx = idx;
        out->current = remaining - 1;
        if (remaining == 1) {
          // Last iteration claimed: move the queue on to the next task.
          atomicMax(head, idx + 1);
        }
        return true;
      }
      remaining = seen;
    }

    // Task exhausted (possibly while we were trying to claim): advance head.
    atomicMax(head, idx + 1);
  }
}
// Appends a timing event to this block's slot array and closes the previous
// one. Does nothing once all N_TIMING_SLOTS slots have been used.
//   timings    - this block's slot array
//   event_idx  - in/out: index of the next free slot
//   event_type - tag stored in the new slot
__device__ inline void record_event(SMEvent *__restrict__ timings,
                                    int *event_idx, int event_type) {
  const int slot = *event_idx;
  if (slot >= N_TIMING_SLOTS)
    return;

  const unsigned long long now = read_globaltimer();

  // The previous event ends exactly where the new one begins.
  if (slot > 0)
    timings[slot - 1].stop = now;

  SMEvent *entry = &timings[slot];
  entry->start = now;
  entry->stop = 0ull; // still open
  entry->event = event_type;

  *event_idx = slot + 1;
}
extern "C" {
// Worker loop run by every thread block: thread 0 claims a (task, iteration)
// pair with fetch_next_task, the block waits until the task's input barriers
// in `ready` are <= 0, runs the op selected by t->op, then decrements the
// task's output barrier. The loop ends when fetch_next_task returns false.
//
// The `//%...%` lines inside the switch statements look like substitution
// markers for generated case code (presumably filled in when the kernel
// string is assembled - confirm); they must stay byte-for-byte as they are.
//
// Kernel params: internal buffers in order, then dyn_dims
// tasks, head, ready, queue_lock, timings, start_times, buffers, dyn_dims
__global__ void worker_kernel(
Task* __restrict__ tasks,
int* __restrict__ head,
int* __restrict__ ready,
int* __restrict__ queue_lock,
SMEvent* __restrict__ timings,
unsigned long long* __restrict__ start_times,
float* const* buffers,
int* __restrict__ dyn_dims
) {
// Constants N_TASKS and N_BARRIERS are baked into the kernel string
// Note: Reset is now done on host side in pre_execute
// All buffers (head, queue_lock, ready, tasks) are pre-initialized
// NOTE(review): an earlier DEBUG comment here described counting fetched
// tasks in queue_lock, but nothing in the visible body touches queue_lock
// (or dyn_dims); they may only be used by the generated op code - confirm.
// Block-wide state, written by thread 0 and read by all threads after a
// __syncthreads().
__shared__ NextTask nt; // task/iteration claimed for this round
__shared__ int done; // non-zero once the queue is exhausted
__shared__ int dep_out; // index into `ready` of this task's output barrier
__shared__ bool run_a_prologue;
__shared__ bool run_b_prologue;
__shared__ bool run_c_prologue;
__shared__ bool stop_wait_loop; // set when all three inputs are ready
// scratchpad, source_ptrs and out_ptr are not read in the visible body;
// presumably they are consumed by the generated prologue/op code.
__shared__ float scratchpad[8192]; // 32 KB scratchpad
__shared__ const float* source_ptrs[6];
__shared__ float* out_ptr;
// Number of timing events this block has recorded; only thread 0 updates it.
int recorded_event = 0;
// Each block writes into its own run of N_TIMING_SLOTS event slots.
timings += blockIdx.x * N_TIMING_SLOTS;
if (threadIdx.x == 0) {
start_times[blockIdx.x] = read_globaltimer();
}
while (true) {
if (threadIdx.x == 0) {
record_event(timings, &recorded_event, 0); // Record issue start
done = !fetch_next_task(tasks, N_TASKS, head, &nt);
}
// Publishes `done` and `nt` to the whole block.
__syncthreads();
if (done)
break;
const Task *t = &tasks[nt.task_idx];
// Resolve buffer pointers from indices
if (threadIdx.x == 0) {
source_ptrs[0] = buffers[t->source_indices[0]];
source_ptrs[1] = buffers[t->source_indices[1]];
source_ptrs[2] = buffers[t->source_indices[2]];
source_ptrs[3] = buffers[t->source_indices[3]];
source_ptrs[4] = buffers[t->source_indices[4]];
source_ptrs[5] = buffers[t->source_indices[5]];
out_ptr = buffers[t->out_index];
}
__syncthreads();
// Per-thread locals: only thread 0 assigns real barrier indices below, every
// other thread keeps 0. They are only read again inside thread-0 sections.
int dep_a = 0;
int dep_b = 0;
int dep_c = 0;
// Thread 0 calculates dependencies and waits for inputs
if (threadIdx.x == 0) {
// Note: atomic_load_acquire provides visibility for ready array
// A base of -1 means "no such input"; barrier index 0 is used in that case.
dep_a = (t->in_dep_a_base == -1
? 0
: (eval_expression(t->in_dep_a_base, 0) +
eval_expression(t->in_dep_a_stride, nt.current)));
dep_b = (t->in_dep_b_base == -1
? 0
: (eval_expression(t->in_dep_b_base, 0) +
eval_expression(t->in_dep_b_stride, nt.current)));
dep_c = (t->in_dep_c_base == -1
? 0
: (eval_expression(t->in_dep_c_base, 0) +
eval_expression(t->in_dep_c_stride, nt.current)));
dep_out = eval_expression(t->out_dep_base, 0) +
eval_expression(t->out_dep_stride, nt.current);
// Increment the output barrier to signal an op is in-flight
atomicAdd(&ready[dep_out], 1);
record_event(timings, &recorded_event, 1); // Record wait start
// Wait on input dependencies and run prologues as inputs become ready
run_a_prologue = false;
run_b_prologue = false;
run_c_prologue = false;
stop_wait_loop = false;
}
__syncthreads();
// a_done/b_done/c_done/tmp are per-thread; only thread 0 ever changes them.
bool a_done = false, b_done = false, c_done = false, tmp;
// Optimize: if deps are same, reuse atomic load result
const bool ab_same = (dep_a == dep_b);
const bool ac_same = (dep_a == dep_c);
const bool bc_same = (dep_b == dep_c);
// Wait loop: thread 0 polls the input barriers; each round the block runs the
// prologue of any input that became ready, until all three are ready.
while (true) {
if (threadIdx.x == 0) {
// Derive x_done and run_x_prologue with optimized atomic loads
// An input counts as ready when its barrier value is <= 0.
if (!a_done) {
tmp = atomic_load_acquire(&ready[dep_a]) <= 0;
if (tmp) {
run_a_prologue = true;
a_done = true;
// Propagate to same deps
if (ab_same) {
run_b_prologue = true;
b_done = true;
}
if (ac_same) {
run_c_prologue = true;
c_done = true;
}
}
}
if (!b_done && !ab_same) {
tmp = atomic_load_acquire(&ready[dep_b]) <= 0;
if (tmp) {
run_b_prologue = true;
b_done = true;
if (bc_same) {
run_c_prologue = true;
c_done = true;
}
}
}
if (!c_done && !ac_same && !bc_same) {
tmp = atomic_load_acquire(&ready[dep_c]) <= 0;
if (tmp) {
run_c_prologue = true;
c_done = true;
}
}
if (a_done && b_done && c_done)
stop_wait_loop = true;
}
__syncthreads();
// Early exit if all dependencies satisfied (skip prologue checks)
// Note that this break comes before the prologue dispatch, so prologues that
// were flagged in this same poll do not run in this loop.
if (stop_wait_loop)
break;
// NOTE(review): thread 0 clears each run_*_prologue flag right after its own
// pass through the switch, with no barrier between the other threads' read of
// the flag and that write. Unless the generated prologue code ends in a
// block-wide sync, a slower warp could read the flag after it was cleared and
// skip its share of the prologue - confirm against the generated cases.
if (run_a_prologue) {
switch (t->op) {
//%prologue_a_calls%
}
if (threadIdx.x == 0) {
run_a_prologue = false;
}
}
if (run_b_prologue) {
switch (t->op) {
//%prologue_b_calls%
}
if (threadIdx.x == 0) {
run_b_prologue = false;
}
}
if (run_c_prologue) {
switch (t->op) {
//%prologue_c_calls%
}
if (threadIdx.x == 0) {
run_c_prologue = false;
}
}
// Keeps thread 0's next poll (which rewrites the flags) behind every thread's
// pass over this round's prologue section.
__syncthreads();
}
// Event types 0 and 1 are "issue" and "wait"; op events are offset by 2.
if (threadIdx.x == 0)
record_event(timings, &recorded_event,
t->op + 2); // Record main op, ends Wait
// Execute main operation
switch (t->op) {
//%extra_op_calls%
}
// All threads have left the op before the barrier is released and before
// thread 0 overwrites `nt`/`done` in the next round.
__syncthreads();
// Arrive at output barrier
if (threadIdx.x == 0) {
// Fence first so the decrement is ordered after this thread's prior writes.
__threadfence();
atomicSub(&ready[dep_out], 1);
}
}
// Close the last recorded event, which record_event left with stop == 0.
if (threadIdx.x == 0 && recorded_event > 0) {
timings[recorded_event - 1].stop = read_globaltimer();
}
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,82 @@
//! Compiles BlockOp subgraphs into KernelOp (MegakernelOp).
use std::sync::Arc;
use cudarc::driver::{CudaFunction, CudaModule, CudaStream};
use luminal::{
graph::LLIRGraph,
op::LLIROp,
prelude::{
FxHashMap, FxHashSet, NodeIndex,
petgraph::{Direction, visit::EdgeRef},
},
};
use tracing::{Level, span};
use crate::{kernel::KernelOp, runtime::partition_marked_convex};
use super::{BlockOp, MegakernelOp};
/// Lower the BlockOp nodes of `llir_graph` into MegakernelOp nodes.
///
/// The pass works in four steps:
/// - gather every node whose op is a `BlockOp`;
/// - split those nodes into convex partitions via `partition_marked_convex`;
/// - build one `MegakernelOp` per partition and insert it into the graph as a
///   `KernelOp` node;
/// - add an edge to the new node from every node outside the partition that
///   has an edge into it, so the megakernel is ordered after its inputs.
///
/// No edges are added from the megakernel node to downstream consumers: the
/// original block op -> consumer edges are left in place and keep providing
/// that ordering.
///
/// Returns a map from each inserted megakernel node to the BlockOp nodes of
/// its partition (used later to collect buffer pointers for the kernel). The
/// map is empty when the graph contains no BlockOp nodes.
///
/// Panics if the result of `partition_marked_convex` cannot be unwrapped.
#[allow(clippy::type_complexity)]
pub fn block_to_kernel(
    llir_graph: &mut LLIRGraph,
    cuda_stream: &Arc<CudaStream>,
    kernel_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> FxHashMap<NodeIndex, Vec<NodeIndex>> {
    let _span = span!(Level::TRACE, "block_to_kernel").entered();

    let mut blocks_by_megakernel: FxHashMap<NodeIndex, Vec<NodeIndex>> = FxHashMap::default();

    let block_nodes: FxHashSet<_> = llir_graph
        .node_indices()
        .filter(|&idx| llir_graph[idx].to_dialect::<dyn BlockOp>().is_some())
        .collect();
    if block_nodes.is_empty() {
        return blocks_by_megakernel;
    }

    let partitions = partition_marked_convex(llir_graph, &block_nodes).unwrap();
    for members in partitions {
        // One megakernel per partition, stored in the graph as a KernelOp.
        let op = MegakernelOp::new(llir_graph, &members, cuda_stream, kernel_cache);
        let mk_node = llir_graph.add_node(LLIROp::new(Box::new(op) as Box<dyn KernelOp>));

        // Sources of incoming edges that originate outside the partition.
        let feeders: FxHashSet<NodeIndex> = members
            .iter()
            .flat_map(|&member| {
                llir_graph
                    .edges_directed(member, Direction::Incoming)
                    .filter_map(|edge| {
                        let origin = edge.source();
                        (!members.contains(&origin)).then_some(origin)
                    })
            })
            .collect();

        // The megakernel must run after each of those external producers.
        for &origin in &feeders {
            llir_graph.add_edge(origin, mk_node, ());
        }

        blocks_by_megakernel.insert(mk_node, members.into_iter().collect());
    }

    blocks_by_megakernel
}

View File

@@ -3,7 +3,7 @@ use std::sync::{Arc, OnceLock};
use luminal::{
egglog_utils::{
api::{Rule, SortDef, sort},
base::{EXPRESSION, OP_KIND, STRING},
base::{EXPRESSION, IR, STRING},
extract_expr,
},
op::{EgglogOp, LLIROp},
@@ -74,9 +74,11 @@ impl Default for CuBlasSgemmV2 {
impl EgglogOp for CuBlasSgemmV2 {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
IR,
"cublasSgemmV2",
&[
("a", IR),
("b", IR),
("m", EXPRESSION),
("n", EXPRESSION),
("k", EXPRESSION),
@@ -89,10 +91,6 @@ impl EgglogOp for CuBlasSgemmV2 {
)
}
fn n_inputs(&self) -> usize {
2
}
fn rewrites(&self) -> Vec<Rule> {
vec![
Rule::raw(include_str!["sgemm_v2_RmRm_rewrite.egg"]), // row row
@@ -106,26 +104,25 @@ impl EgglogOp for CuBlasSgemmV2 {
fn extract<'a>(
&'a self,
egraph: &'a luminal::egglog_utils::SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
children: &[&'a ENodeId],
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
// Extract dimensions from egglog
let m = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
let n = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
let k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
let m = extract_expr(egraph, children[2], expr_cache).unwrap();
let n = extract_expr(egraph, children[3], expr_cache).unwrap();
let k = extract_expr(egraph, children[4], expr_cache).unwrap();
// Extract layout strings from egglog
let a_layout_str = &egraph.enodes[kind_children[3]].0;
let b_layout_str = &egraph.enodes[kind_children[4]].0;
let a_layout_str = &egraph.enodes[children[5]].0;
let b_layout_str = &egraph.enodes[children[6]].0;
let a_layout = parse_cublas_op(a_layout_str);
let b_layout = parse_cublas_op(b_layout_str);
// Extract leading dimensions from egglog
let lda = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
let ldb = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
let ldc = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
let lda = extract_expr(egraph, children[7], expr_cache).unwrap();
let ldb = extract_expr(egraph, children[8], expr_cache).unwrap();
let ldc = extract_expr(egraph, children[9], expr_cache).unwrap();
let extracted_state = Self {
m,
@@ -142,7 +139,7 @@ impl EgglogOp for CuBlasSgemmV2 {
let extracted = LLIROp::new::<dyn HostOp>(Box::new(extracted_state) as Box<dyn HostOp>);
(extracted, input_enodes)
(extracted, vec![children[0], children[1]])
}
fn cleanup(&self) -> bool {

View File

@@ -12,13 +12,10 @@
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
(= (len ?out_shape) 2)
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
@@ -37,17 +34,17 @@
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
(= ?k_stride (MNum 1))
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MIter))
(= ?a_m_stride (MNum 1))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
(= ?a_k_stride ?m)
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= ?b_n_stride ?k)
(= ?b_k_stride (MNum 1))
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))
@@ -55,7 +52,9 @@
(
; For column-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
(let ?sgemm (Op (cublasSgemmV2
(let ?sgemm (cublasSgemmV2
?b ; First matrix = B (swapped)
?a ; Second matrix = A (swapped)
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
@@ -63,8 +62,7 @@
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?k ; lda = k (column-major B[k,n])
?m ; ldb = m (column-major A[m,k])
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(ICons ?b (ICons ?a (INil)))))
?n)) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(union ?sum ?sgemm)
(set (dtype ?sgemm) (F32))
)

View File

@@ -12,13 +12,10 @@
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
(= (len ?out_shape) 2)
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
@@ -37,17 +34,17 @@
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
(= ?k_stride (MNum 1))
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MIter))
(= ?a_m_stride (MNum 1))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
(= ?a_k_stride ?m)
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MIter))
(= ?b_k_stride (MMul (MIter) ?n))
(= ?b_n_stride (MNum 1))
(= ?b_k_stride ?n)
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))
@@ -55,7 +52,9 @@
(
; For column-major A × row-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
(let ?sgemm (Op (cublasSgemmV2
(let ?sgemm (cublasSgemmV2
?b ; First matrix = B (swapped)
?a ; Second matrix = A (swapped)
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
@@ -63,8 +62,7 @@
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
?m ; ldb = m (column-major A[m,k])
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(ICons ?b (ICons ?a (INil)))))
?n)) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(union ?sum ?sgemm)
(set (dtype ?sgemm) (F32))
)

View File

@@ -12,13 +12,10 @@
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
(= (len ?out_shape) 2)
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
@@ -37,17 +34,17 @@
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
(= ?k_stride (MNum 1))
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_m_stride ?k)
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MIter))
(= ?a_k_stride (MNum 1))
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= ?b_n_stride ?k)
(= ?b_k_stride (MNum 1))
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))
@@ -55,7 +52,9 @@
(
; For row-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
(let ?sgemm (Op (cublasSgemmV2
(let ?sgemm (cublasSgemmV2
?b ; First matrix = B (swapped)
?a ; Second matrix = A (swapped)
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
@@ -63,8 +62,7 @@
"N" ; transb = No transpose
?k ; lda = k (column-major B[k,n])
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(ICons ?b (ICons ?a (INil)))))
?n)) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(union ?sum ?sgemm)
(set (dtype ?sgemm) (F32))
)

View File

@@ -12,13 +12,10 @@
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
(= (len ?out_shape) 2)
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
@@ -37,17 +34,17 @@
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
(= ?k_stride (MNum 1))
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_m_stride ?k)
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MIter))
(= ?a_k_stride (MNum 1))
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MIter))
(= ?b_k_stride (MMul (MIter) ?n))
(= ?b_n_stride (MNum 1))
(= ?b_k_stride ?n)
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))
@@ -55,7 +52,9 @@
(
; For row-major C = A × B with cuBLAS (column-major):
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
(let ?sgemm (Op (cublasSgemmV2
(let ?sgemm (cublasSgemmV2
?b ; First matrix = B (swapped)
?a ; Second matrix = A (swapped)
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
@@ -63,8 +62,7 @@
"N" ; transb = No transpose
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(ICons ?b (ICons ?a (INil)))))
?n)) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(union ?sum ?sgemm)
(set (dtype ?sgemm) (F32))
)

View File

@@ -0,0 +1,71 @@
; Column-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
;
; Row-major output via the operand-swap trick (cuBLAS is column-major):
; Column-major A[m,k] is used as-is, leading dimension m
; Column-major B[k,n] is used as-is, leading dimension k
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; Because the operands are swapped, B is the FIRST matrix of the call (its
; leading dimension is passed as lda) and A is the SECOND (passed as ldb).
; Equivalent classic call: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
; The rule emits a `cublaslt` node with these arguments plus the dtype.
(rule
(
; Match Mul node
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
; Get dimensions from output shape
; NOTE(review): only the last two entries of ?out_shape and the last three
; stride entries are inspected; nothing visible here rejects extra leading
; (batch) dims - confirm that batched shapes cannot match this rule.
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
; Skip shapes where m or n is the literal 0
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Get A strides in [m, n, k] space
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; Get B strides in [m, n, k] space
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Require the k stride bound from the Sum node to be 1 (contiguous reduction)
(= ?k_stride (MNum 1))
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MNum 1))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride ?m)
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride ?k)
(= ?b_k_stride (MNum 1))
; Both inputs must share one dtype; it is forwarded to the cublaslt node
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For column-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
(let ?sgemm (cublaslt
?b ; First matrix = B (swapped)
?a ; Second matrix = A (swapped)
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?k ; lda = k (column-major B[k,n])
?m ; ldb = m (column-major A[m,k])
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
?dt)) ; dtype
; Put the cublaslt node in the same e-class as the matched Sum
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt column-major × column-major"
)

View File

@@ -0,0 +1,71 @@
; Column-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
;
; Row-major output via the operand-swap trick (cuBLAS is column-major):
; Column-major A[m,k] is used as-is, leading dimension m
; Row-major B[k,n] ≡ column-major B^T[n,k], leading dimension n
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; Because the operands are swapped, B is the FIRST matrix of the call (its
; leading dimension is passed as lda) and A is the SECOND (passed as ldb).
; Equivalent classic call: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
; The rule emits a `cublaslt` node with these arguments plus the dtype.
(rule
(
; Match Mul node
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
; Get dimensions from output shape
; NOTE(review): only the last two entries of ?out_shape and the last three
; stride entries are inspected; nothing visible here rejects extra leading
; (batch) dims - confirm that batched shapes cannot match this rule.
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
; Skip shapes where m or n is the literal 0
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Get A strides in [m, n, k] space
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; Get B strides in [m, n, k] space
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Require the k stride bound from the Sum node to be 1 (contiguous reduction)
(= ?k_stride (MNum 1))
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MNum 1))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride ?m)
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MNum 1))
(= ?b_k_stride ?n)
; Both inputs must share one dtype; it is forwarded to the cublaslt node
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For column-major A × row-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
(let ?sgemm (cublaslt
?b ; First matrix = B (swapped)
?a ; Second matrix = A (swapped)
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
?m ; ldb = m (column-major A[m,k])
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
?dt)) ; dtype
; Put the cublaslt node in the same e-class as the matched Sum
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt column-major × row-major"
)

View File

@@ -0,0 +1,71 @@
; Row-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
;
; Row-major output via the operand-swap trick (cuBLAS is column-major):
; Row-major A[m,k] ≡ column-major A^T[k,m], leading dimension k
; Column-major B[k,n] is used as-is, leading dimension k
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; Because the operands are swapped, B is the FIRST matrix of the call (its
; leading dimension is passed as lda) and A is the SECOND (passed as ldb).
; Equivalent classic call: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
; The rule emits a `cublaslt` node with these arguments plus the dtype.
(rule
(
; Match Mul node
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
; Get dimensions from output shape
; NOTE(review): only the last two entries of ?out_shape and the last three
; stride entries are inspected; nothing visible here rejects extra leading
; (batch) dims - confirm that batched shapes cannot match this rule.
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
; Skip shapes where m or n is the literal 0
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Get A strides in [m, n, k] space
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; Get B strides in [m, n, k] space
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Require the k stride bound from the Sum node to be 1 (contiguous reduction)
(= ?k_stride (MNum 1))
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride ?k)
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MNum 1))
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride ?k)
(= ?b_k_stride (MNum 1))
; Both inputs must share one dtype; it is forwarded to the cublaslt node
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For row-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
(let ?sgemm (cublaslt
?b ; First matrix = B (swapped)
?a ; Second matrix = A (swapped)
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"T" ; transa = Transpose (B is column-major, need B^T)
"N" ; transb = No transpose
?k ; lda = k (column-major B[k,n])
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
?dt)) ; dtype
; Put the cublaslt node in the same e-class as the matched Sum
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt row-major × column-major"
)

View File

@@ -0,0 +1,71 @@
; Row-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
;
; Row-major output via the operand-swap trick (cuBLAS is column-major):
; Row-major A[m,k] ≡ column-major [k,m], leading dimension k
; Row-major B[k,n] ≡ column-major [n,k], leading dimension n
; Row-major C[m,n] ≡ column-major [n,m] with ldc=n
;
; cuBLAS computes: C_col[n,m] = B_col[n,k] × A_col[k,m]
; Because the operands are swapped, B is the FIRST matrix of the call (its
; leading dimension is passed as lda) and A is the SECOND (passed as ldb).
; Equivalent classic call: cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
; The rule emits a `cublaslt` node with these arguments plus the dtype.
(rule
(
; Match Mul node
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
; Get dimensions from output shape
; NOTE(review): only the last two entries of ?out_shape and the last three
; stride entries are inspected; nothing visible here rejects extra leading
; (batch) dims - confirm that batched shapes cannot match this rule.
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
; Skip shapes where m or n is the literal 0
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Get A strides in [m, n, k] space
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; Get B strides in [m, n, k] space
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Require the k stride bound from the Sum node to be 1 (contiguous reduction)
(= ?k_stride (MNum 1))
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride ?k)
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MNum 1))
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MNum 1))
(= ?b_k_stride ?n)
; Both inputs must share one dtype; it is forwarded to the cublaslt node
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For row-major C = A × B with cuBLAS (column-major):
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
(let ?sgemm (cublaslt
?b ; First matrix = B (swapped)
?a ; Second matrix = A (swapped)
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"N" ; transa = No transpose
"N" ; transb = No transpose
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
?dt)) ; dtype
; Put the cublaslt node in the same e-class as the matched Sum
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt row-major x row-major"
)

View File

@@ -4,7 +4,7 @@ use luminal::{
dtype::DType,
egglog_utils::{
api::{Rule, SortDef, sort},
base::{DTYPE, EXPRESSION, OP_KIND, STRING},
base::{DTYPE, EXPRESSION, IR, STRING},
extract_dtype, extract_expr,
},
op::{EgglogOp, LLIROp},
@@ -45,10 +45,6 @@ pub struct CuBlasLt {
lda: Expression,
ldb: Expression,
ldc: Expression,
batch_count: Expression,
stride_a: Expression,
stride_b: Expression,
stride_c: Expression,
dtype: DType,
cublaslt: OnceLock<Arc<CudaBlasLT>>,
}
@@ -60,15 +56,11 @@ impl Default for CuBlasLt {
m: Expression::default(),
n: Expression::default(),
k: Expression::default(),
a_layout: cublasOperation_t::CUBLAS_OP_N,
b_layout: cublasOperation_t::CUBLAS_OP_T,
a_layout: cublasOperation_t::CUBLAS_OP_N, // IGNORE NOT REAL
b_layout: cublasOperation_t::CUBLAS_OP_T, // IGNORE NOT REAL
lda: Expression::default(),
ldb: Expression::default(),
ldc: Expression::default(),
batch_count: 1.into(),
stride_a: 0.into(),
stride_b: 0.into(),
stride_c: 0.into(),
dtype: DType::F32,
cublaslt: OnceLock::new(),
}
@@ -78,9 +70,11 @@ impl Default for CuBlasLt {
impl EgglogOp for CuBlasLt {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
IR,
"cublaslt",
&[
("a", IR),
("b", IR),
("m", EXPRESSION),
("n", EXPRESSION),
("k", EXPRESSION),
@@ -89,48 +83,17 @@ impl EgglogOp for CuBlasLt {
("lda", EXPRESSION),
("ldb", EXPRESSION),
("ldc", EXPRESSION),
("batch_count", EXPRESSION),
("stride_a", EXPRESSION),
("stride_b", EXPRESSION),
("stride_c", EXPRESSION),
("dtype", DTYPE),
],
)
}
fn n_inputs(&self) -> usize {
2
}
fn rewrites(&self) -> Vec<Rule> {
vec![
Rule::raw(include_str!["cublaslt_RmRm_rewrite.egg"]), // row row
Rule::raw(include_str!["cublaslt_RmCm_rewrite.egg"]), // row col
Rule::raw(include_str!["cublaslt_CmRm_rewrite.egg"]), // col row
Rule::raw(include_str!["cublaslt_CmCm_rewrite.egg"]), // col col
// Delete KernelMul matmul broadcast intermediates when the Sum eclass
// has a cublaslt or KernelBatchMatMul alternative. This prevents OOM
// from O(m*k*n) intermediates at large seq_len. cuBLAS, TileMatmulFullSplit,
// KernelBatchMatVec, and KernelBatchMatMul all take original inputs
// (not the Mul eclass), so they survive the cascade.
Rule::raw("(rule
((= ?mul (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs))
(= (MNum 0) (nth_from_end ?as 1))
(= (MNum 0) (nth_from_end ?bs 2))
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
(= ?sum (Op (cublaslt ?cm ?cn ?ck ?cta ?ctb ?clda ?cldb ?cldc ?cbc ?csa ?csb ?csc ?cdt) ?ci)))
((delete (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs)))
:ruleset cleanup
)"),
Rule::raw("(rule
((= ?mul (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs))
(= (MNum 0) (nth_from_end ?as 1))
(= (MNum 0) (nth_from_end ?bs 2))
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
(= ?sum (Op (KernelBatchMatMul ?bos ?bk ?bas ?baks ?bbs ?bbks ?bouts ?bdt) ?bi)))
((delete (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs)))
:ruleset cleanup
)"),
]
}
@@ -138,35 +101,28 @@ impl EgglogOp for CuBlasLt {
fn extract<'a>(
&'a self,
egraph: &'a luminal::egglog_utils::SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
children: &[&'a ENodeId],
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
// Extract dimensions from egglog
let m = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
let n = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
let k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
let m = extract_expr(egraph, children[2], expr_cache).unwrap();
let n = extract_expr(egraph, children[3], expr_cache).unwrap();
let k = extract_expr(egraph, children[4], expr_cache).unwrap();
// Extract layout strings from egglog
let a_layout_str = &egraph.enodes[kind_children[3]].0;
let b_layout_str = &egraph.enodes[kind_children[4]].0;
let a_layout_str = &egraph.enodes[children[5]].0;
let b_layout_str = &egraph.enodes[children[6]].0;
let a_layout = parse_cublas_op(a_layout_str);
let b_layout = parse_cublas_op(b_layout_str);
// Extract leading dimensions from egglog
let lda = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
let ldb = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
let ldc = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
// Extract batch parameters
let batch_count = extract_expr(egraph, kind_children[8], expr_cache).unwrap();
let stride_a = extract_expr(egraph, kind_children[9], expr_cache).unwrap();
let stride_b = extract_expr(egraph, kind_children[10], expr_cache).unwrap();
let stride_c = extract_expr(egraph, kind_children[11], expr_cache).unwrap();
let lda = extract_expr(egraph, children[7], expr_cache).unwrap();
let ldb = extract_expr(egraph, children[8], expr_cache).unwrap();
let ldc = extract_expr(egraph, children[9], expr_cache).unwrap();
// Extract dtype from egglog
let dtype = extract_dtype(egraph, kind_children[12]);
let dtype = extract_dtype(egraph, children[10]);
let extracted_state = Self {
m,
@@ -177,10 +133,6 @@ impl EgglogOp for CuBlasLt {
lda,
ldb,
ldc,
batch_count,
stride_a,
stride_b,
stride_c,
dtype,
cublaslt: OnceLock::new(),
};
@@ -188,7 +140,7 @@ impl EgglogOp for CuBlasLt {
let extracted = LLIROp::new::<dyn HostOp>(Box::new(extracted_state) as Box<dyn HostOp>);
(extracted, input_enodes)
(extracted, vec![children[0], children[1]])
}
fn cleanup(&self) -> bool {
@@ -257,24 +209,15 @@ impl HostOp for CuBlasLt {
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
use crate::cudarc::cublaslt::sys::{
cublasLtMatrixLayoutAttribute_t, cublasLtMatrixLayoutSetAttribute,
};
// GEMM parameters — resolve z→1 for element stride before exec
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
let m = resolve(&self.m).exec(dyn_map).unwrap() as u64;
let n = resolve(&self.n).exec(dyn_map).unwrap() as u64;
let k = resolve(&self.k).exec(dyn_map).unwrap() as u64;
// GEMM parameters
let m = self.m.exec(dyn_map).unwrap() as u64;
let n = self.n.exec(dyn_map).unwrap() as u64;
let k = self.k.exec(dyn_map).unwrap() as u64;
let a_layout = self.a_layout;
let b_layout = self.b_layout;
let lda = resolve(&self.lda).exec(dyn_map).unwrap() as i64;
let ldb = resolve(&self.ldb).exec(dyn_map).unwrap() as i64;
let ldc = resolve(&self.ldc).exec(dyn_map).unwrap() as i64;
let batch_count = resolve(&self.batch_count).exec(dyn_map).unwrap() as i32;
let stride_a = resolve(&self.stride_a).exec(dyn_map).unwrap() as i64;
let stride_b = resolve(&self.stride_b).exec(dyn_map).unwrap() as i64;
let stride_c = resolve(&self.stride_c).exec(dyn_map).unwrap() as i64;
let lda = self.lda.exec(dyn_map).unwrap() as i64;
let ldb = self.ldb.exec(dyn_map).unwrap() as i64;
let ldc = self.ldc.exec(dyn_map).unwrap() as i64;
// Get CUDA types based on dtype
let (cuda_dtype, compute_type, scale_dtype) = dtype_to_cuda_types(self.dtype);
@@ -299,28 +242,20 @@ impl HostOp for CuBlasLt {
let (b_ptr, _b_guard) = b_buf.device_ptr(stream);
let (c_ptr, _c_guard) = c_buf.device_ptr(stream);
// Clamp leading dimensions to minimum valid values.
// When a dimension is 1 (e.g., k=1 outer product), the stride along that
// dimension may be 0 in the egglog representation, but cuBLAS requires
// lda >= rows_of_A and ldb >= rows_of_B.
let a_ld_min = if a_layout == cublasOperation_t::CUBLAS_OP_N {
m
} else {
k
};
let b_ld_min = if b_layout == cublasOperation_t::CUBLAS_OP_N {
k
} else {
n
};
let lda = std::cmp::max(lda, a_ld_min as i64);
let ldb = std::cmp::max(ldb, b_ld_min as i64);
let ldc = std::cmp::max(ldc, m as i64);
// Debug tracing
trace!(
"buffer_validation {}=={},{}=={},{}=={}",
a_buf.len(),
m * k * element_size,
b_buf.len(),
k * n * element_size,
c_buf.len(),
m * n * element_size
);
let _span = span!(
Level::TRACE,
"cuBLASLT",
m, n, k, lda, ldb, ldc, batch_count, ?a_layout, ?b_layout, ?self.dtype,
m, n, k, lda, ldb, ldc, ?a_layout, ?b_layout, ?self.dtype,
)
.entered();
@@ -377,26 +312,6 @@ impl HostOp for CuBlasLt {
cublasLtMatrixLayoutCreate(&mut b_desc, cuda_dtype, b_rows, b_cols, ldb).result()?;
cublasLtMatrixLayoutCreate(&mut c_desc, cuda_dtype, m, n, ldc).result()?;
// Set batched GEMM attributes if batch_count > 1
if batch_count > 1 {
for (desc, stride) in [(a_desc, stride_a), (b_desc, stride_b), (c_desc, stride_c)] {
cublasLtMatrixLayoutSetAttribute(
desc,
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
&batch_count as *const _ as *const std::ffi::c_void,
std::mem::size_of::<i32>(),
)
.result()?;
cublasLtMatrixLayoutSetAttribute(
desc,
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&stride as *const _ as *const std::ffi::c_void,
std::mem::size_of::<i64>(),
)
.result()?;
}
}
// Create preference and set workspace size
cublasLtMatmulPreferenceCreate(&mut preference).result()?;
cublasLtMatmulPreferenceSetAttribute(
@@ -423,6 +338,7 @@ impl HostOp for CuBlasLt {
.result()?;
if algo_count == 0 {
// Cleanup before returning error
cublasLtMatmulPreferenceDestroy(preference);
cublasLtMatrixLayoutDestroy(c_desc);
cublasLtMatrixLayoutDestroy(b_desc);
@@ -431,6 +347,7 @@ impl HostOp for CuBlasLt {
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
}
// All dtypes use F32 scale type for alpha/beta
let alpha_ptr = &alpha_f32 as *const _ as *const std::ffi::c_void;
let beta_ptr = &beta_f32 as *const _ as *const std::ffi::c_void;
cublasLtMatmul(
@@ -445,7 +362,7 @@ impl HostOp for CuBlasLt {
c_ptr as *const std::ffi::c_void,
c_desc,
c_ptr as *mut std::ffi::c_void,
c_desc,
c_desc, // D layout same as C
&heuristic.algo,
workspace_ptr as *mut std::ffi::c_void,
WORKSPACE_SIZE,
@@ -461,14 +378,12 @@ impl HostOp for CuBlasLt {
cublasLtMatmulDescDestroy(matmul_desc);
}
// No stream.synchronize() here — CUDA stream ordering guarantees
// sequential execution. The runtime syncs once at the end of execute().
stream.synchronize()?;
Ok(())
}
fn output_size(&self) -> Expression {
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
resolve(&self.batch_count) * resolve(&self.m) * resolve(&self.n)
self.m * self.n
}
fn output_bytes(&self) -> Expression {

View File

@@ -0,0 +1,127 @@
; GLUMoE: Match the expert computation subgraph of a Gated MoE (SwiGLU variant).
;
; This matches the pattern produced by QwenMoE::forward() starting from the
; expert gathers through to the final weighted sum, and replaces it with a
; fused GLUMoE HostOp.
;
; Inputs extracted:
; ?x - input activations [s, H] F32
; ?topk_idx - top-k expert indices [s, k] Int (from argsort+slice)
; ?topk_vals - top-k routing values [s, k] F32 (from gather on softmax)
; ?gate_up_w - stacked gate+up expert weights [E, intermediate*2, H] BF16
; ?down_w - stacked down expert weights [E, H, intermediate] BF16
;
; The pattern captures:
; 1. Gate-up expert gather (Iota, Mul, Cast, Iota, Cast, Add, Cast, Gather)
; 2. Cast BF16→F32 of gathered gate-up weights
; 3. Gate-up batched matmul (Mul + SumReduce)
; 4. Up-half slice via Iota+Gather (slice semantics); the gate path reads ?gu_matmul directly
; 5. SwiGLU: silu(gate) * up
; 6. Down expert gather (same pattern as gate-up)
; 7. Cast BF16→F32 of gathered down weights
; 8. Down batched matmul (Mul + SumReduce)
; 9. Weighted sum: (down_out * topk_values) summed over k
;
; Variables with ? prefix are egglog pattern variables.
; Shapes/strides we don't extract are bound to named pattern variables
; (e.g. ?gu_mul_base_shape) that the action never reads.
; Rule layout: the first parenthesised list is the query (every fact must match),
; the second is the action run for each match.
(rule
(
; ===== Gate-up expert gather =====
; t51: Iota for base index (expert_idx * io_gu)
(= ?gu_iota_base (Iota ?gu_io ?gu_iota_base_range))
; t52: Mul topk_indices * io → base offsets [s, k]
(= ?gu_mul_base (Mul ?gu_mul_base_shape ?topk_idx ?gu_mul_base_a_stride ?gu_iota_base ?gu_mul_base_b_stride ?gu_mul_base_out_stride))
; t53: Cast to F32
(= ?gu_cast_base (Cast ?gu_mul_base ?gu_cast_base_size (F32)))
; t54: Iota for within-expert index
(= ?gu_iota_within (Iota (MIter) ?gu_iota_within_range))
; t55: Cast within to F32
(= ?gu_cast_within (Cast ?gu_iota_within ?gu_cast_within_size (F32)))
; t56: Add base + within → flat gather indices
(= ?gu_add_idx (Add ?gu_add_shape ?gu_cast_base ?gu_add_a_stride ?gu_cast_within ?gu_add_b_stride ?gu_add_out_stride))
; t57: Cast to Int
(= ?gu_cast_idx (Cast ?gu_add_idx ?gu_cast_idx_size (Int)))
; t58: Gather gate_up weights
(= ?gu_gathered (Gather ?gu_cast_idx ?gu_gather_idx_shape ?gu_gather_idx_stride ?gate_up_w ?gu_gather_data_shape ?gu_gather_data_stride))
; ===== Cast BF16→F32 =====
; t59: Cast gathered gate_up to F32
(= ?gu_f32 (Cast ?gu_gathered ?gu_f32_size (F32)))
; ===== Gate-up batched matmul =====
; t60: Mul x * gathered_gu (broadcast multiply)
(= ?gu_matmul_mul (Mul ?gu_matmul_mul_shape ?x ?gu_matmul_a_stride ?gu_f32 ?gu_matmul_b_stride ?gu_matmul_mul_out_stride))
; t61: SumReduce over K dimension
(= ?gu_matmul (Sum ?gu_matmul_out_shape ?gu_matmul_k ?gu_matmul_mul ?gu_matmul_in_stride ?gu_matmul_k_stride ?gu_matmul_out_stride))
; ===== Up slice via Iota+Gather =====
; Only the "up" half is gathered out of ?gu_matmul; the SiLU chain below
; consumes ?gu_matmul directly as its gate operand.
; t62: Iota with complex expression (slicing the "up" half)
(= ?up_iota (Iota ?up_iota_expr ?up_iota_range))
; t63: Gather to select up portion from matmul result
(= ?up_slice (Gather ?up_iota ?up_gather_idx_shape ?up_gather_idx_stride ?gu_matmul ?up_gather_data_shape ?up_gather_data_stride))
; ===== SwiGLU: silu(gate) * up =====
; The sigmoid is matched in its expanded form: 1 / (1 + exp2(gate * -1 * log2e)).
; t64: Constant(-1)
(= ?neg1 (Constant -1.000000))
; t65: gate * -1
(= ?neg_gate (Mul ?silu_shape1 ?gu_matmul ?silu_a_stride1 ?neg1 ?silu_b_stride1 ?silu_out_stride1))
; t66: Constant(log2e)
(= ?log2e (Constant 1.442695))
; t67: neg_gate * log2e
(= ?scaled (Mul ?silu_shape2 ?neg_gate ?silu_a_stride2 ?log2e ?silu_b_stride2 ?silu_out_stride2))
; t68: exp2
(= ?exp2_val (Exp2 ?silu_shape3 ?scaled ?silu_in_stride3 ?silu_out_stride3))
; t69: Constant(1)
(= ?one (Constant 1.000000))
; t70: exp2 + 1
(= ?plus1 (Add ?silu_shape4 ?exp2_val ?silu_a_stride4 ?one ?silu_b_stride4 ?silu_out_stride4))
; t71: recip
(= ?sigmoid (Recip ?silu_shape5 ?plus1 ?silu_in_stride5 ?silu_out_stride5))
; t72: gate * sigmoid(gate) = silu(gate)
(= ?silu_out (Mul ?silu_shape6 ?gu_matmul ?silu_a_stride6 ?sigmoid ?silu_b_stride6 ?silu_out_stride6))
; t73: silu(gate) * up
(= ?swiglu_out (Mul ?swiglu_shape ?silu_out ?swiglu_a_stride ?up_slice ?swiglu_b_stride ?swiglu_out_stride))
; ===== Down expert gather =====
; Reuses ?topk_idx, so the down gather must be indexed by the same top-k
; tensor as the gate-up gather for the rule to match.
; t74: Iota for base index (expert_idx * io_down)
(= ?dn_iota_base (Iota ?dn_io ?dn_iota_base_range))
; t75: Mul topk_indices * io_down
(= ?dn_mul_base (Mul ?dn_mul_base_shape ?topk_idx ?dn_mul_base_a_stride ?dn_iota_base ?dn_mul_base_b_stride ?dn_mul_base_out_stride))
; t76: Cast to F32
(= ?dn_cast_base (Cast ?dn_mul_base ?dn_cast_base_size (F32)))
; t77: Iota for within-expert index
(= ?dn_iota_within (Iota (MIter) ?dn_iota_within_range))
; t78: Cast within to F32
(= ?dn_cast_within (Cast ?dn_iota_within ?dn_cast_within_size (F32)))
; t79: Add base + within
(= ?dn_add_idx (Add ?dn_add_shape ?dn_cast_base ?dn_add_a_stride ?dn_cast_within ?dn_add_b_stride ?dn_add_out_stride))
; t80: Cast to Int
(= ?dn_cast_idx (Cast ?dn_add_idx ?dn_cast_idx_size (Int)))
; t81: Gather down weights
(= ?dn_gathered (Gather ?dn_cast_idx ?dn_gather_idx_shape ?dn_gather_idx_stride ?down_w ?dn_gather_data_shape ?dn_gather_data_stride))
; ===== Cast BF16→F32 =====
; t82: Cast gathered down to F32
(= ?dn_f32 (Cast ?dn_gathered ?dn_f32_size (F32)))
; ===== Down batched matmul =====
; t83: Mul swiglu_out * gathered_down (broadcast multiply)
(= ?dn_matmul_mul (Mul ?dn_matmul_mul_shape ?swiglu_out ?dn_matmul_a_stride ?dn_f32 ?dn_matmul_b_stride ?dn_matmul_mul_out_stride))
; t84: SumReduce
(= ?dn_matmul (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_mul ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride))
; ===== Weighted sum over k experts =====
; t85: Mul down_out * topk_values
(= ?weighted (Mul ?weighted_shape ?dn_matmul ?weighted_a_stride ?topk_vals ?weighted_b_stride ?weighted_out_stride))
; t86: SumReduce over k dimension → [s, H]
(= ?output (Sum ?output_shape ?output_k ?weighted ?output_in_stride ?output_k_stride ?output_out_stride))
)
(
; Action: build the fused node from the five tensor inputs followed by the
; seven extracted expressions (both Iota multipliers, the three Sum
; reduction sizes, and both within-expert Iota ranges), then merge it into
; the e-class of the matched final Sum.
(let ?glumoe (GLUMoE ?x ?topk_idx ?topk_vals ?gate_up_w ?down_w
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
?gu_iota_within_range ?dn_iota_within_range))
(union ?output ?glumoe)
)
:name "GLUMoE fused expert computation"
)

View File

@@ -3,7 +3,7 @@ use std::sync::{Arc, OnceLock};
use luminal::{
egglog_utils::{
api::{Rule, SortDef, sort},
base::{EXPRESSION, OP_KIND},
base::{EXPRESSION, IR},
extract_expr,
},
op::{EgglogOp, LLIROp},
@@ -12,7 +12,6 @@ use luminal::{
};
use crate::{
compile_module_image_for_current_device,
cudarc::{
cublas::sys::cublasOperation_t,
cublaslt::{
@@ -31,6 +30,7 @@ use crate::{
driver::{
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, LaunchConfig, PushKernelArg,
},
nvrtc::{CompileOptions, compile_ptx_with_opts},
},
host::HostOp,
};
@@ -146,7 +146,17 @@ extern "C" __global__ void swiglu_bf16(unsigned long long gate_up_ptr, unsigned
}
}
"#;
let ptx = compile_module_image_for_current_device(stream.context(), src).unwrap();
let ptx = compile_ptx_with_opts(
src,
CompileOptions {
include_paths: vec![
"/usr/local/cuda/include".to_string(),
"/usr/include".to_string(),
],
..Default::default()
},
)
.unwrap();
let module = stream.context().load_module(ptx).unwrap();
let f32_to_bf16 = module.load_function("f32_to_bf16").unwrap();
let swiglu = module.load_function("swiglu_bf16").unwrap();
@@ -158,9 +168,14 @@ extern "C" __global__ void swiglu_bf16(unsigned long long gate_up_ptr, unsigned
impl EgglogOp for GLUMoE {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
IR,
"GLUMoE",
&[
("x", IR),
("topk_idx", IR),
("topk_vals", IR),
("gate_up_w", IR),
("down_w", IR),
("gu_io", EXPRESSION),
("dn_io", EXPRESSION),
("gu_matmul_k", EXPRESSION),
@@ -172,10 +187,6 @@ impl EgglogOp for GLUMoE {
)
}
fn n_inputs(&self) -> usize {
5
}
fn early_rewrites(&self) -> Vec<Rule> {
vec![Rule::raw(include_str!["glumoe_rewrite.egg"])]
}
@@ -183,18 +194,17 @@ impl EgglogOp for GLUMoE {
fn extract<'a>(
&'a self,
egraph: &'a luminal::egglog_utils::SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
children: &[&'a ENodeId],
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
let gu_io = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
let dn_io = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
let gu_matmul_k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
let dn_matmul_k = extract_expr(egraph, kind_children[3], expr_cache).unwrap();
let output_k = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
let gu_within_range = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
let dn_within_range = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
let gu_io = extract_expr(egraph, children[5], expr_cache).unwrap();
let dn_io = extract_expr(egraph, children[6], expr_cache).unwrap();
let gu_matmul_k = extract_expr(egraph, children[7], expr_cache).unwrap();
let dn_matmul_k = extract_expr(egraph, children[8], expr_cache).unwrap();
let output_k = extract_expr(egraph, children[9], expr_cache).unwrap();
let gu_within_range = extract_expr(egraph, children[10], expr_cache).unwrap();
let dn_within_range = extract_expr(egraph, children[11], expr_cache).unwrap();
let extracted = GLUMoE {
gu_io,
@@ -210,7 +220,16 @@ impl EgglogOp for GLUMoE {
let op = LLIROp::new::<dyn HostOp>(Box::new(extracted) as Box<dyn HostOp>);
// Return the 5 IR inputs: x, topk_idx, topk_vals, gate_up_w, down_w
(op, input_enodes)
(
op,
vec![
children[0],
children[1],
children[2],
children[3],
children[4],
],
)
}
fn cleanup(&self) -> bool {

View File

@@ -425,7 +425,7 @@ mod tests {
fn test_raw_function_extraction() {
let Ok(ctx) = CudaContext::new(0) else { return };
let kernel_src = r#"extern "C" __global__ void test_kernel(float* out) { out[0] = 1.0f; }"#;
let Ok(ptx) = crate::compile_module_image_for_current_device(&ctx, kernel_src) else {
let Ok(ptx) = cudarc::nvrtc::compile_ptx(kernel_src) else {
return;
};
let module = ctx.load_module(ptx).unwrap();
@@ -448,7 +448,7 @@ mod tests {
use cudarc::driver::{CudaSlice, DevicePtr};
let Ok(ctx) = CudaContext::new(0) else { return };
let kernel_src = r#"extern "C" __global__ void test_kernel(float* out, float* in1) { if (threadIdx.x == 0) out[0] = in1[0] + 1.0f; }"#;
let Ok(ptx) = crate::compile_module_image_for_current_device(&ctx, kernel_src) else {
let Ok(ptx) = cudarc::nvrtc::compile_ptx(kernel_src) else {
return;
};
let module = ctx.load_module(ptx).unwrap();
@@ -492,13 +492,12 @@ mod tests {
let a = cx.tensor(size).persist();
let b = cx.tensor(size).persist();
let c = ((a + b) * a + b).output();
cx.build_search_space_exclude_ops::<CudaRuntime, crate::block::Ops>();
let mut rt = CudaRuntime::initialize(stream);
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let result1 = rt.get_f32(c);
@@ -524,13 +523,12 @@ mod tests {
let a = cx.tensor(size).persist();
let b = cx.tensor(size).persist();
let c = (a + b + a + b).output();
cx.build_search_space_exclude_ops::<CudaRuntime, crate::block::Ops>();
let mut rt = CudaRuntime::initialize(stream);
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
let mut results = Vec::new();
for _ in 0..5 {
@@ -561,14 +559,13 @@ mod tests {
let b = cx.tensor('s');
let c = (a + b).output();
let d = (c * a).output();
cx.build_search_space_exclude_ops::<CudaRuntime, crate::block::Ops>();
let mut rt = CudaRuntime::initialize(stream);
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.set_dim('s', size);
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let expected: Vec<f32> = data_a
@@ -604,13 +601,12 @@ mod tests {
let a = cx.tensor(size);
let b = cx.tensor(size);
let c = (a + b).output();
cx.build_search_space_exclude_ops::<CudaRuntime, crate::block::Ops>();
let mut rt = CudaRuntime::initialize(stream);
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
rt.execute(&cx.dyn_map);
let expected: Vec<f32> = data_a.iter().zip(&data_b).map(|(a, b)| a + b).collect();
@@ -635,13 +631,12 @@ mod tests {
result *= b;
}
let output = result.output();
cx.build_search_space_exclude_ops::<CudaRuntime, crate::block::Ops>();
let mut rt = CudaRuntime::initialize(stream);
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
for _ in 0..10 {
rt.execute(&cx.dyn_map);
@@ -653,53 +648,4 @@ mod tests {
}
assert_close(&rt.get_f32(output), &expected, 1e-2, 1e-2);
}
/// Test that CUDA graphs produce correct results when dynamic dimensions
/// change incrementally across many executions (simulating a decode loop
/// where position offset increments each step).
#[test]
fn test_cuda_graph_incremental_dim_changes() {
let Some(stream) = get_cuda_stream() else {
return;
};
let mut cx = Graph::default();
let a = cx.tensor('s');
let b = cx.tensor('s');
let c = ((a + b) * a).output();
let initial_size = 128;
cx.set_dim('s', initial_size);
let mut rt = CudaRuntime::initialize(stream);
let data_a = random_f32_vec(initial_size, 42, -0.5, 0.5);
let data_b = random_f32_vec(initial_size, 43, -0.5, 0.5);
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>();
rt = cx.search(rt, 5);
// Initial execution
rt.execute(&cx.dyn_map);
let eps = dtype_epsilon(luminal::dtype::DType::F32);
let tol = eps * TOLERANCE_SAFETY_FACTOR;
let expected: Vec<f32> = data_a
.iter()
.zip(&data_b)
.map(|(a, b)| (a + b) * a)
.collect();
assert_close(&rt.get_f32(c), &expected, tol, tol);
// Incrementally change the dynamic dimension 10 times,
// simulating decode steps where position offset grows.
for step in 1..=10usize {
let size = initial_size + step;
cx.set_dim('s', size);
let da = random_f32_vec(size, 100 + step as u64, -0.5, 0.5);
let db = random_f32_vec(size, 200 + step as u64, -0.5, 0.5);
rt.set_data(a, da.clone());
rt.set_data(b, db.clone());
rt.execute(&cx.dyn_map);
let expected: Vec<f32> = da.iter().zip(&db).map(|(a, b)| (a + b) * a).collect();
assert_close(&rt.get_f32(c), &expected, tol, tol);
}
}
}

View File

@@ -173,23 +173,9 @@ pub trait KernelOp: std::fmt::Debug + as_any::AsAny {
/// Returns the output buffer size in elements.
fn output_size(&self) -> Expression;
/// Returns all dynamic variables used by this kernel (for grid dims, strides, etc).
/// Default: returns dyn vars from output_size(). Override if the kernel has dyn vars
/// in expressions not captured by output_size (e.g., KernelScatter's index_shape).
fn all_dyn_vars(&self) -> FxHashSet<char> {
self.output_size().dyn_vars().into_iter().collect()
}
/// Returns the output buffer size in bytes (accounts for dtype).
fn output_bytes(&self) -> Expression;
/// Returns the DType of this kernel's output buffer.
/// Used by has_nan_outputs to interpret buffer bytes correctly.
/// Default: F32 (most kernels output float).
fn output_dtype(&self) -> DType {
DType::F32
}
/// Returns the number of bytes this kernel will load from global memory.
fn bytes_loaded(&self) -> Expression {
0.into()
@@ -258,21 +244,18 @@ pub trait KernelOp: std::fmt::Debug + as_any::AsAny {
) {
}
/// If this kernel's output aliases one of its inputs (i.e., writes in-place),
/// return the input index. Used to propagate buffer pointers in CUDA graphs.
fn output_aliases_input(&self) -> Option<usize> {
None
}
/// If this kernel's output is derived from one of its inputs (copy-then-modify
/// or in-place write), return that input index. Used by `resolve_data_node` to
/// trace buffer ownership back to HLIR inputs for the remove_buffer/set_buffer
/// roundtrip pattern.
///
/// Defaults to `output_aliases_input()`. Override for copy-then-modify ops
/// (like Scatter which copies dest→output then scatters into it).
fn output_data_input(&self) -> Option<usize> {
self.output_aliases_input()
/// Called before each CUDA graph launch. Runs stream-level work outside the graph.
/// Used by ops like KernelScatter that need a copy kernel before the main graph kernel.
/// Default: no-op.
fn pre_launch(
&self,
_stream: &Arc<CudaStream>,
_output_ptr: u64,
_input_ptrs: &[u64],
_dyn_dims_ptr: u64,
_dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
Ok(())
}
/// Returns indices of internal buffers containing timing data, if any.

View File

@@ -0,0 +1,209 @@
use std::sync::Arc;
use crate::{
cuda_dtype,
kernel::KernelOp,
kernel::hlir::{compile_kernel, dtype_includes, generate_dyn_dims_defines},
};
use cudarc::{
driver::{CudaContext, CudaFunction, CudaModule, CudaSlice, CudaStream},
nvrtc::CompileOptions,
};
use itertools::Itertools;
use luminal::{
egglog_utils::{
api::{Rule, SortDef, sort},
base::{DTYPE, ELIST, EXPRESSION, IR},
extract_dtype, extract_expr, extract_expr_list,
},
op::*,
prelude::*,
};
/// Tuple of the kernel op types defined in this file.
pub type Ops = (KernelMeanReduce,);
/// Mean reduction kernel op: each output element is the average of `iters`
/// input elements read at a fixed stride.
#[derive(Default, Debug, Clone)]
pub struct KernelMeanReduce {
    /// Output shape; the product of its entries is the number of output
    /// elements and also the launch grid size.
    out_shape: Vec<Expression>,
    /// Number of input elements accumulated (and divided by) per output element.
    iters: Expression,
    /// Input strides paired with `out_shape` to compute where each reduction starts.
    in_stride: Vec<Expression>,
    /// Stride between consecutive elements of a single reduction.
    iter_stride: Expression,
    /// Output strides paired with `out_shape` to compute the output index.
    out_stride: Vec<Expression>,
    /// Element type used for the input pointer, output pointer, and accumulator.
    dtype: DType,
}
impl EgglogOp for KernelMeanReduce {
    /// Declares the `KernelMean` constructor in the `IR` sort together with
    /// its seven fields, in the order `extract` later indexes them.
    fn sort(&self) -> SortDef {
        sort(
            IR,
            "KernelMean",
            &[
                ("shape", ELIST),
                ("iters", EXPRESSION),
                ("inp", IR),
                ("strides", ELIST),
                ("iter_stride", EXPRESSION),
                ("out_strides", ELIST),
                ("dtype", DTYPE),
            ],
        )
    }

    /// No rewrite rules are registered for this op.
    fn rewrites(&self) -> Vec<Rule> {
        // Disabled: the e-graph union introduced by this rule can cause the search
        // to select genomes with accumulated FP precision issues over many layers.
        // The unfused Sum + Mul(Recip(Cast(Iota))) path produces equivalent results.
        Vec::new()
    }

    fn cleanup(&self) -> bool {
        false
    }

    /// Rebuilds a `KernelMeanReduce` from a serialized `KernelMean` e-node.
    ///
    /// `children` follows the field order declared in `sort`; index 2 (`inp`)
    /// is the only tensor input and is returned as the op's sole input e-node.
    fn extract<'a>(
        &self,
        egraph: &'a SerializedEGraph,
        children: &[&'a ENodeId],
        list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
        expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
    ) -> (LLIROp, Vec<&'a ENodeId>) {
        // Fields are pulled out in declaration order so the caches are
        // populated in the same sequence as before.
        let shape = extract_expr_list(egraph, children[0], list_cache, expr_cache).unwrap();
        let reduce_len = extract_expr(egraph, children[1], expr_cache).unwrap();
        let input_strides =
            extract_expr_list(egraph, children[3], list_cache, expr_cache).unwrap();
        let reduce_stride = extract_expr(egraph, children[4], expr_cache).unwrap();
        let output_strides =
            extract_expr_list(egraph, children[5], list_cache, expr_cache).unwrap();
        let elem_dtype = extract_dtype(egraph, children[6]);

        let boxed = Box::new(Self {
            out_shape: shape,
            iters: reduce_len,
            in_stride: input_strides,
            iter_stride: reduce_stride,
            out_stride: output_strides,
            dtype: elem_dtype,
        }) as Box<dyn KernelOp>;

        (LLIROp::new::<dyn KernelOp>(boxed), vec![children[2]])
    }
}
impl KernelOp for KernelMeanReduce {
    /// Generates, compiles (or fetches from `compile_cache`) and returns the
    /// `reduce_mean_k` CUDA kernel for this op.
    ///
    /// The returned tuple is: the kernel function, its module, the generated
    /// source (also used as the cache key), the grid dimensions, the block
    /// dimensions, the shared-memory size, and an empty map.
    ///
    /// The kernel is launched with one single-thread block per output
    /// element; each thread loops serially over `iters` inputs.
    fn compile(
        &self,
        stream: &Arc<CudaStream>,
        compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
    ) -> (
        CudaFunction,
        Arc<CudaModule>,
        String,
        (Expression, Expression, Expression),
        (Expression, Expression, Expression),
        Expression,
        FxHashMap<char, CudaSlice<u8>>,
    ) {
        // Every dynamic variable that appears in any expression baked into
        // the kernel source.
        let vars = self
            .out_shape
            .iter()
            .flat_map(|e| e.dyn_vars())
            .chain(self.in_stride.iter().flat_map(|e| e.dyn_vars()))
            .chain(self.out_stride.iter().flat_map(|e| e.dyn_vars()))
            .chain(self.iters.dyn_vars())
            .chain(self.iter_stride.dyn_vars())
            .collect::<FxHashSet<_>>();

        let dtype = cuda_dtype(self.dtype);
        let includes = dtype_includes(&[self.dtype]);
        let n_outputs = self.output_size();
        let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
        // The extra pointer parameter is only emitted when at least one
        // dynamic variable is referenced.
        let dyn_dims_param = if vars.is_empty() {
            ""
        } else {
            ", const int* dyn_dims"
        };

        let kernel = format!(
            "{includes}
{dyn_defines}
extern \"C\" {{
__global__ void reduce_mean_k({dtype} *out, const {dtype} *in{dyn_dims_param}) {{
long long const_z = blockIdx.x;
long long n_elements = {n_outputs};
if (const_z >= n_elements) return;
long long in_start = {in_index};
long long iters = {iters};
long long iter_stride = {iter_stride};
{dtype} sum = 0;
for (long long i = 0; i < iters; i++) {{
sum += in[in_start + i * iter_stride];
}}
out[{out_index}] = ({dtype})(sum / ({dtype})iters);
}}
}}",
            dtype = dtype,
            in_index = flatten_strides(&self.out_shape, &self.in_stride).to_kernel(),
            out_index = flatten_strides(&self.out_shape, &self.out_stride).to_kernel(),
            n_outputs = n_outputs.to_kernel(),
            iters = self.iters.to_kernel(),
            // 'z' is replaced by 1 so the stride is emitted as a plain
            // per-step element stride.
            iter_stride = self
                .iter_stride
                .substitute('z', Expression::from(1))
                .simplify()
                .to_kernel(),
        );

        // Identical source text shares one compiled module.
        let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
            (module.clone(), func.clone())
        } else {
            let ptx = compile_kernel(&kernel, &[self.dtype]);
            let module = stream.context().load_module(ptx).unwrap();
            let func = module.load_function("reduce_mean_k").unwrap();
            compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
            (module, func)
        };

        (
            func,
            module,
            kernel,
            (n_outputs, 1.into(), 1.into()), // grid
            (1.into(), 1.into(), 1.into()),  // blocks (single-threaded)
            0.into(),                        // shmem size
            FxHashMap::default(),
        )
    }

    /// Number of output elements: the product of `out_shape`.
    fn output_size(&self) -> Expression {
        self.out_shape.iter().copied().product()
    }

    /// Output buffer size in bytes, rounding sub-byte dtypes up.
    fn output_bytes(&self) -> Expression {
        (self.output_size() * self.dtype.bits()).ceil_div(8)
    }

    /// Bytes read: every output element reads `iters` inputs.
    fn bytes_loaded(&self) -> Expression {
        (self.output_size() * self.iters * self.dtype.bits()).ceil_div(8)
    }

    /// Bytes written: one element per output.
    fn bytes_stored(&self) -> Expression {
        self.output_bytes()
    }

    /// `iters` additions plus one division per output element.
    fn flops(&self) -> Expression {
        let n_outputs = self.output_size();
        n_outputs * self.iters + n_outputs
    }

    fn kernel_name(&self) -> &'static str {
        "MeanReduce"
    }
}

View File

@@ -11,7 +11,7 @@ use cudarc::driver::{
};
use itertools::Itertools;
use luminal::{
egglog_utils::{api::Rule, base::OP_KIND},
egglog_utils::{api::Rule, base::IR},
graph::LLIRGraph,
op::{EgglogOp, LLIROp},
prelude::{
@@ -26,7 +26,6 @@ use crate::{
kernel::{
CudaFunctionExt, CudaGraphExecHandle, CudaGraphHandle, KernelOp, create_cuda_event,
destroy_cuda_event,
hlir::{clear_global_dyn_dims, get_global_dyn_dims, set_global_dyn_dims},
},
runtime::partition_marked_convex,
};
@@ -196,7 +195,7 @@ impl std::fmt::Debug for CudaGraphOp {
impl EgglogOp for CudaGraphOp {
fn sort(&self) -> luminal::egglog_utils::api::SortDef {
luminal::egglog_utils::api::sort(OP_KIND, "CudaGraphOp", &[])
luminal::egglog_utils::api::sort(IR, "CudaGraphOp", &[])
}
fn rewrites(&self) -> Vec<Rule> {
@@ -206,8 +205,7 @@ impl EgglogOp for CudaGraphOp {
fn extract<'a>(
&'a self,
_egraph: &'a luminal::egglog_utils::SerializedEGraph,
_kind_children: &[&'a luminal::prelude::ENodeId],
_input_enodes: Vec<&'a luminal::prelude::ENodeId>,
_children: &[&'a luminal::prelude::ENodeId],
_list_cache: &mut FxHashMap<&'a luminal::prelude::ENodeId, Vec<Expression>>,
_expr_cache: &mut FxHashMap<&'a luminal::prelude::ENodeId, Expression>,
) -> (LLIROp, Vec<&'a luminal::prelude::ENodeId>) {
@@ -301,11 +299,7 @@ impl CudaGraphOp {
for kernel in state.kernels.iter_mut() {
kernel.internal_bufs = kernel.kernel_op.allocate_internal_buffers(stream, dyn_map);
}
}
// Only force full rebuild when internal buffer sizes change.
// Dim-only changes (e.g. position offset `p` incrementing each decode step) are
// handled by updating the dyn_dims device buffer + kernel node params in-place.
if needs_internal_realloc {
// Internal buffer pointers changed, need to rebuild CUDA graph
state.cuda_graph = None;
state.cuda_graph_exec = None;
state.node_to_graph_node.clear();
@@ -346,15 +340,6 @@ impl CudaGraphOp {
}
}
// Apply output-aliases-input
for kernel in state.kernels.iter() {
if let Some(input_idx) = kernel.kernel_op.output_aliases_input()
&& let Some(&input_ptr) = current_buffer_ptrs.get(&kernel.inputs[input_idx])
{
current_buffer_ptrs.insert(kernel.node, input_ptr);
}
}
// Always call pre_execute for each kernel to reset internal state
// (e.g., MegakernelOps need work queue, head, barriers, lock reset every execution)
for idx in 0..state.kernels.len() {
@@ -440,9 +425,43 @@ impl CudaGraphOp {
state.last_buffer_ptrs = current_buffer_ptrs;
}
// Call pre_launch for each kernel (e.g., KernelScatter copies dest→output before graph)
{
let dyn_dims_ptr = state
.dyn_dims_buffer
.as_ref()
.map(|buf| buf.device_ptr(stream).0)
.unwrap_or(0);
for kernel in state.kernels.iter() {
let output_ptr = state
.last_buffer_ptrs
.get(&kernel.node)
.copied()
.unwrap_or(0);
let input_ptrs: Vec<u64> = kernel
.inputs
.iter()
.map(|inp| state.last_buffer_ptrs.get(inp).copied().unwrap_or(0))
.collect();
kernel.kernel_op.pre_launch(
stream,
output_ptr,
&input_ptrs,
dyn_dims_ptr,
dyn_map,
)?;
}
}
// Sync before launch
stream.synchronize()?;
// Launch the graph
state.cuda_graph_exec.as_ref().unwrap().launch(stream)?;
// Sync after launch
stream.synchronize()?;
Ok(())
}
@@ -599,7 +618,7 @@ impl Drop for CudaGraphOp {
fn drop(&mut self) {
let mut state = self.state.borrow_mut();
// Destroy timing events first
// Destroy timing events - extract ctx first to avoid borrow issues
let ctx = state.cuda_graph_exec.as_ref().map(|exec| exec.ctx.clone());
if let Some(ctx) = ctx {
for event in state.timing_events.drain(..) {
@@ -607,22 +626,22 @@ impl Drop for CudaGraphOp {
}
}
// Destroy CUDA graph handles BEFORE freeing buffers they reference.
// The graph exec holds device pointers to dyn_dims_buffer and internal_bufs,
// so it must be destroyed first to avoid dangling pointer issues.
drop(state.cuda_graph_exec.take());
drop(state.cuda_graph.take());
// Forget dyn_dims buffer (managed by runtime)
if let Some(buf) = state.dyn_dims_buffer.take() {
std::mem::forget(buf);
}
// Now safe to free dynamically allocated GPU buffers
// (dyn_dims_buffer and internal_bufs are freed by normal Drop)
// Constants point to __constant__ memory in the CUDA module,
// not dynamically allocated — must not be freed.
// Handle kernel resources
for kernel in state.kernels.iter_mut() {
// Forget constants (they point to __constant__ memory)
let constants = std::mem::take(&mut kernel.constants);
for (_k, v) in constants {
std::mem::forget(v);
}
// Forget internal buffers (managed by runtime)
for buf in kernel.internal_bufs.drain(..) {
std::mem::forget(buf);
}
}
}
}
@@ -642,6 +661,7 @@ pub fn kernel_to_host(
llir_graph: &mut LLIRGraph,
cuda_stream: &Arc<CudaStream>,
kernel_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
megakernel_to_blocks: &FxHashMap<NodeIndex, Vec<NodeIndex>>,
) {
let _span = span!(Level::TRACE, "kernel_to_host").entered();
@@ -669,28 +689,11 @@ pub fn kernel_to_host(
.filter(|n| subgraph.contains(n))
.collect();
let mut kernels = Vec::with_capacity(topo_order.len());
let mut all_dyn_dims = FxHashSet::default();
let mut all_buffer_nodes = FxHashSet::default();
let mut all_buffer_sizes: FxHashMap<NodeIndex, Expression> = FxHashMap::default();
// Pre-scan: collect all dynamic vars from all kernel ops without compiling.
// This uses KernelOp::all_dyn_vars() which inspects struct expression fields.
for kernel_node_idx in &topo_order {
let kernel_op_ref = llir_graph[*kernel_node_idx]
.to_dialect::<dyn KernelOp>()
.unwrap();
all_dyn_dims.extend(kernel_op_ref.all_dyn_vars());
}
// Set global dyn dims ordering so compiles use consistent indices
let mut global_dyn_dims: Vec<char> = all_dyn_dims.iter().copied().collect();
global_dyn_dims.sort();
if !global_dyn_dims.is_empty() {
set_global_dyn_dims(global_dyn_dims.clone());
}
// Compile all kernels with global ordering for correct dyn_dims indices
let mut kernels = Vec::with_capacity(topo_order.len());
for kernel_node_idx in &topo_order {
let kernel_op_ref = llir_graph[*kernel_node_idx]
.to_dialect::<dyn KernelOp>()
@@ -706,6 +709,21 @@ pub fn kernel_to_host(
.map(|e| e.source())
.collect_vec();
// If this is a megakernel, include all its block op nodes for buffer access
if let Some(block_nodes) = megakernel_to_blocks.get(kernel_node_idx) {
inputs.extend(block_nodes.iter().copied());
}
// Collect dyn dims used by this kernel
all_dyn_dims.extend(grid.0.dyn_vars());
all_dyn_dims.extend(grid.1.dyn_vars());
all_dyn_dims.extend(grid.2.dyn_vars());
all_dyn_dims.extend(block.0.dyn_vars());
all_dyn_dims.extend(block.1.dyn_vars());
all_dyn_dims.extend(block.2.dyn_vars());
all_dyn_dims.extend(shared_mem.dyn_vars());
all_dyn_dims.extend(kernel_op_ref.output_size().dyn_vars());
// Collect buffer nodes and sizes
// Only add kernel nodes with non-zero output size (MegakernelOps have size 0)
let output_size = kernel_op_ref.output_size();
@@ -730,19 +748,9 @@ pub fn kernel_to_host(
));
}
// Get the possibly-extended global ordering (kernels may have discovered new dims)
let final_global = get_global_dyn_dims();
// Clear global ordering now that all kernels are compiled
clear_global_dyn_dims();
// Use the final global ordering if it was extended during compilation
let mut dyn_dims_order: Vec<char> = if let Some(final_order) = final_global {
final_order
} else {
let mut dims: Vec<char> = all_dyn_dims.into_iter().collect();
dims.sort();
dims
};
// Sort dyn dims alphabetically for consistent buffer layout
let mut dyn_dims_order: Vec<char> = all_dyn_dims.into_iter().collect();
dyn_dims_order.sort();
let buffer_nodes: Vec<NodeIndex> = all_buffer_nodes.into_iter().collect();
@@ -765,6 +773,14 @@ pub fn kernel_to_host(
for kernel_node in &subgraph {
kernel_to_cuda_graph.insert(*kernel_node, cuda_graph_node);
}
// Also track block op nodes inside megakernels
for kernel_node in &subgraph {
if let Some(block_nodes) = megakernel_to_blocks.get(kernel_node) {
for block_node in block_nodes {
kernel_to_cuda_graph.insert(*block_node, cuda_graph_node);
}
}
}
cuda_graph_subgraphs.push((cuda_graph_node, subgraph.clone()));
// Find external inputs: nodes outside subgraph that have edges into subgraph
@@ -792,15 +808,23 @@ pub fn kernel_to_host(
// Second pass: Add edges between CudaGraphOps based on kernel dependencies.
// This ensures proper execution ordering when a kernel in one CudaGraphOp
// produces output consumed by a kernel in another CudaGraphOp.
// produces output consumed by a kernel (or BlockOp inside a megakernel) in another CudaGraphOp.
let mut edges_to_add: Vec<(NodeIndex, NodeIndex)> = Vec::new();
for (cuda_graph_node, subgraph) in &cuda_graph_subgraphs {
// Find all nodes that this subgraph produces output for (including BlockOp nodes in megakernels)
let mut all_producer_nodes: FxHashSet<NodeIndex> = subgraph.clone();
for kernel_node in subgraph {
if let Some(block_nodes) = megakernel_to_blocks.get(kernel_node) {
all_producer_nodes.extend(block_nodes.iter().copied());
}
}
// Find external consumers that are kernels belonging to other CudaGraphOps
for producer_node in subgraph {
for producer_node in &all_producer_nodes {
for edge in llir_graph.edges_directed(*producer_node, Direction::Outgoing) {
let consumer = edge.target();
if subgraph.contains(&consumer) {
if all_producer_nodes.contains(&consumer) {
continue; // Same subgraph
}
// Check if consumer is a kernel in another CudaGraphOp

View File

@@ -0,0 +1,57 @@
pub mod block;
pub mod host;
pub mod kernel;
pub mod logical;
pub mod runtime;
use std::sync::Arc;
pub use cudarc;
#[cfg(test)]
mod tests;
use cudarc::driver::CudaContext;
use luminal::dtype::DType;
/// Maps a luminal `DType` to the name of the CUDA C type used to store one
/// element of that dtype.
///
/// Several dtypes share a storage type: `TF32` maps to `float`, `Bool` and `U8`
/// both map to `unsigned char`, and the 4-bit integer types map to
/// `unsigned char` as packed storage.
///
/// The match has no wildcard arm, so adding a new `DType` variant fails to
/// compile until it is given a mapping here.
fn cuda_dtype(dtype: DType) -> &'static str {
match dtype {
DType::F64 => "double",
DType::F32 => "float",
DType::F16 => "half",
DType::Bf16 => "__nv_bfloat16",
DType::TF32 => "float", // TF32 uses float storage, tensor cores handle the format
DType::Int => "int",
DType::I16 => "short",
DType::U16 => "unsigned short",
DType::I8 => "signed char",
DType::U8 => "unsigned char",
DType::Bool => "unsigned char",
DType::F8E4M3 => "__nv_fp8_e4m3",
DType::F8E5M2 => "__nv_fp8_e5m2",
DType::F8UE8M0 => "__nv_fp8_e8m0",
DType::F6E2M3 => "__nv_fp6_e2m3",
DType::F6E3M2 => "__nv_fp6_e3m2",
DType::F4E2M1 => "__nv_fp4_e2m1",
DType::I4 | DType::U4 => "unsigned char", // Sub-byte, packed storage
}
}
/// Returns the bandwidth of the device in GB/s.
///
/// The value is looked up in a fixed table keyed by the exact device name
/// reported for `ctx`. Returns `None` when the device is not in the table or
/// when the device name cannot be queried.
pub fn cuda_bandwidth_gbps(ctx: &Arc<CudaContext>) -> Option<usize> {
    // The function already signals "unknown" through `None`, so a failed name
    // query is reported the same way instead of panicking.
    let name = ctx.name().ok()?;
    match name.as_str() {
        "NVIDIA Thor" => Some(273),
        "NVIDIA H100 PCIe" => Some(2_000),
        "NVIDIA H100 SXM" => Some(3_350),
        _ => None,
    }
}
/// Returns the F32 compute throughput of the device in TFLOPs.
///
/// The value is looked up in a fixed table keyed by the exact device name
/// reported for `ctx`. Returns `None` when the device is not in the table or
/// when the device name cannot be queried.
pub fn cuda_compute_f32_tflops(ctx: &Arc<CudaContext>) -> Option<usize> {
    // The function already signals "unknown" through `None`, so a failed name
    // query is reported the same way instead of panicking.
    let name = ctx.name().ok()?;
    match name.as_str() {
        "NVIDIA Thor" => Some(125), // forced to use tf32 flops
        "NVIDIA H100 PCIe" => Some(756),
        "NVIDIA H100 SXM" => Some(989),
        _ => None,
    }
}

View File

@@ -0,0 +1,71 @@
use std::fmt::Debug;
use luminal::{
egglog_utils::{
api::{Rule, SortDef},
base::OP_SORTS,
},
op::EgglogOp,
};
pub type Ops = (Exp, Sigmoid);
/// Natural exponential op (`e^x`), recognised from its `Exp2`-based
/// decomposition by the rewrite rule in `rewrites`.
#[derive(Debug, Default)]
pub struct Exp;
impl EgglogOp for Exp {
/// Builds the egglog sort definition for this op via `OP_SORTS.unary("Exp")`.
fn sort(&self) -> SortDef {
OP_SORTS.unary("Exp")
}
/// Always returns `true`.
/// NOTE(review): the meaning of this flag is defined by `EgglogOp`, which is
/// not visible here — confirm before relying on it.
fn cleanup(&self) -> bool {
true
}
/// Returns the single egglog rule that introduces `Exp` nodes.
///
/// The rule matches `Exp2` applied to `Mul(?x, Constant 1.442695)`; since
/// 1.442695 ≈ log2(e), that expression computes `e^x`. On a match it creates
/// `Exp` over `?x` (using `?x`'s stride and the `Exp2` output stride), unions
/// it with the matched `Exp2` node, and copies `?x`'s dtype onto the new node.
///
/// Only the operand order `(?x, constant)` is matched by this rule.
fn rewrites(&self) -> Vec<Rule> {
vec![Rule::raw(
"(rule
(
(= ?exp_const (Constant 1.442695))
(= ?mul (Mul ?shape ?x ?x_stride ?exp_const ?const_stride ?intermediate_stride))
(= ?exp2 (Exp2 ?shape ?mul ?intermediate_stride ?out_stride))
(= ?dt (dtype ?x))
)
(
(let ?exp (Exp ?shape ?x ?x_stride ?out_stride))
(union ?exp2 ?exp)
(set (dtype ?exp) ?dt)
)
)",
)]
}
}
/// Sigmoid op (`1 / (1 + e^(-x))`), recognised from its decomposed form by the
/// rewrite rule in `rewrites`.
#[derive(Default, Debug, Clone)]
pub struct Sigmoid;
impl EgglogOp for Sigmoid {
/// Builds the egglog sort definition for this op via `OP_SORTS.unary("Sigmoid")`.
fn sort(&self) -> SortDef {
OP_SORTS.unary("Sigmoid")
}
/// Always returns `true`.
/// NOTE(review): the meaning of this flag is defined by `EgglogOp`, which is
/// not visible here — confirm before relying on it.
fn cleanup(&self) -> bool {
true
}
/// Returns the single egglog rule (named "sigmoid") that introduces `Sigmoid` nodes.
///
/// The rule matches the chain `Recip(Add(Exp(Mul(?input, Constant -1.0)), Constant 1.0))`.
/// On a match it creates `Sigmoid` over `?input` (using the input stride and
/// the `Recip` output stride), unions it with the matched `Recip` node, and
/// copies `?input`'s dtype onto the new node.
///
/// The pattern matches on an `Exp` node, so it presumably relies on `Exp`
/// nodes already being present in the e-graph (e.g. from the `Exp` rewrite).
/// The same `?const_stride` variable is used for both constants, so the rule
/// only fires when the `-1.0` and `1.0` operands have equal strides.
fn rewrites(&self) -> Vec<Rule> {
vec![Rule::raw("(rule
(
(= ?neg_input (Mul ?input_range ?input ?input_stride (Constant -1.0) ?const_stride ?intermediate_stride))
(= ?exp (Exp ?input_range ?neg_input ?intermediate_stride ?exp_stride))
(= ?plus_one (Add ?input_range ?exp ?exp_stride (Constant 1.0) ?const_stride ?plus_one_stride))
(= ?sig_out (Recip ?input_range ?plus_one ?plus_one_stride ?out_stride))
(= ?dt (dtype ?input))
)
(
(let ?sig (Sigmoid ?input_range ?input ?input_stride ?out_stride))
(union ?sig_out ?sig)
(set (dtype ?sig) ?dt)
)
:name \"sigmoid\"
)")]
}
}

View File

@@ -1,9 +1,5 @@
pub mod utilities;
#[cfg(test)]
mod bucket_tests;
#[cfg(test)]
mod consumed_buffer_tests;
#[cfg(test)]
mod model_fuzz;
#[cfg(test)]

View File

@@ -299,9 +299,11 @@ fn fuzz_layer_no_attn(
);
}
/// Test a SwiGLU MLP with HLIR-only to specifically verify
/// Test a SwiGLU MLP with HLIR-only (no block ops) to specifically verify
/// the HLIR matmul decomposition (KernelMul + KernelSumReduce).
fn fuzz_mlp_hlir_only(seq: usize, hidden: usize, intermediate: usize, seed: u64) {
use crate::block::Ops as BlockOps;
let Some(stream) = get_cuda_stream() else {
return;
};
@@ -313,7 +315,8 @@ fn fuzz_mlp_hlir_only(seq: usize, hidden: usize, intermediate: usize, seed: u64)
let w_down = cx.tensor((hidden, intermediate));
let out = swiglu_mlp(input, w_gate, w_up, w_down).output();
cx.build_search_space::<CudaRuntime>();
// Exclude all block ops to force HLIR kernel fallback
cx.build_search_space_exclude_ops::<CudaRuntime, BlockOps>();
let mut rt = CudaRuntime::initialize(stream.clone());
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);

View File

@@ -24,8 +24,8 @@ proptest! {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
let eps = dtype_epsilon(luminal::dtype::DType::F32);
let (rtol, atol) = (eps * TOLERANCE_SAFETY_FACTOR, eps * TOLERANCE_SAFETY_FACTOR);
test_binary_cuda(x, x, |a, b| a + b, |a, b| (&a + &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
test_binary_cuda((y, x), (y, x), |a, b| a + b, |a, b| (&a + &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
test_binary_cuda(x, x, |a, b| a + b, |a, b| (&a + &b).unwrap(), &gen_lambda, &gen_lambda, seed, rtol, atol);
test_binary_cuda((y, x), (y, x), |a, b| a + b, |a, b| (&a + &b).unwrap(), &gen_lambda, &gen_lambda, seed, rtol, atol);
}
#[test]
@@ -33,20 +33,20 @@ proptest! {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
let eps = dtype_epsilon(luminal::dtype::DType::F32);
let (rtol, atol) = (eps * TOLERANCE_SAFETY_FACTOR, eps * TOLERANCE_SAFETY_FACTOR);
test_binary_cuda(x, x, |a, b| a * b, |a, b| (&a * &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
test_binary_cuda((y, x), (y, x), |a, b| a * b, |a, b| (&a * &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
test_binary_cuda(x, x, |a, b| a * b, |a, b| (&a * &b).unwrap(), &gen_lambda, &gen_lambda, seed, rtol, atol);
test_binary_cuda((y, x), (y, x), |a, b| a * b, |a, b| (&a * &b).unwrap(), &gen_lambda, &gen_lambda, seed, rtol, atol);
}
#[test]
fn test_max(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_unary_cuda((rows, cols), |a| a.max(1), |a| a.max(1).unwrap(), gen_lambda, seed);
test_unary_cuda((rows, cols), |a| a.max(1), |a| a.max(1).unwrap(), &gen_lambda, seed);
}
#[test]
fn test_mean(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_unary_cuda((rows, cols), |a| a.mean(1), |a| a.mean(1).unwrap(), gen_lambda, seed);
test_unary_cuda((rows, cols), |a| a.mean(1), |a| a.mean(1).unwrap(), &gen_lambda, seed);
}
#[test]
@@ -115,7 +115,7 @@ proptest! {
let atol = 5.0 * eps;
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_binary_cuda(a_shape, b_shape, luminal_op, candle_op, gen_lambda, gen_lambda, seed, rtol, atol);
test_binary_cuda(a_shape, b_shape, luminal_op, candle_op, &gen_lambda, &gen_lambda, seed, rtol, atol);
}
// Unary ops tests
@@ -123,37 +123,37 @@ proptest! {
fn test_exp2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
// exp2(x) = 2^x, verified by computing 2^x using exp(x * ln(2))
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_unary_cuda(x, |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), gen_lambda, seed);
test_unary_cuda((y, x), |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), gen_lambda, seed);
test_unary_cuda(x, |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), &gen_lambda, seed);
test_unary_cuda((y, x), |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), &gen_lambda, seed);
}
#[test]
fn test_log2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
// log2(x) = ln(x) / ln(2)
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.6);
test_unary_cuda(x, |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), gen_lambda, seed);
test_unary_cuda((y, x), |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), gen_lambda, seed);
test_unary_cuda(x, |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), &gen_lambda, seed);
test_unary_cuda((y, x), |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), &gen_lambda, seed);
}
#[test]
fn test_sin(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_unary_cuda(x, |a| a.sin(), |a| a.sin().unwrap(), gen_lambda, seed);
test_unary_cuda((y, x), |a| a.sin(), |a| a.sin().unwrap(), gen_lambda, seed);
test_unary_cuda(x, |a| a.sin(), |a| a.sin().unwrap(), &gen_lambda, seed);
test_unary_cuda((y, x), |a| a.sin(), |a| a.sin().unwrap(), &gen_lambda, seed);
}
#[test]
fn test_recip(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.5);
test_unary_cuda(x, |a| a.reciprocal(), |a| a.recip().unwrap(), gen_lambda, seed);
test_unary_cuda((y, x), |a| a.reciprocal(), |a| a.recip().unwrap(), gen_lambda, seed);
test_unary_cuda(x, |a| a.reciprocal(), |a| a.recip().unwrap(), &gen_lambda, seed);
test_unary_cuda((y, x), |a| a.reciprocal(), |a| a.recip().unwrap(), &gen_lambda, seed);
}
#[test]
fn test_sqrt(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.6);
test_unary_cuda(x, |a| a.sqrt(), |a| a.sqrt().unwrap(), gen_lambda, seed);
test_unary_cuda((y, x), |a| a.sqrt(), |a| a.sqrt().unwrap(), gen_lambda, seed);
test_unary_cuda(x, |a| a.sqrt(), |a| a.sqrt().unwrap(), &gen_lambda, seed);
test_unary_cuda((y, x), |a| a.sqrt(), |a| a.sqrt().unwrap(), &gen_lambda, seed);
}
// Binary ops tests
@@ -166,12 +166,11 @@ proptest! {
#[test]
fn test_less_than(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -99.0, 100.0).into_iter().map(|v| v.floor()).collect();
test_binary_cuda(x, x, |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), gen_lambda, gen_lambda, seed, 0.0, 0.0);
test_binary_cuda((y, x), (y, x), |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), gen_lambda, gen_lambda, seed, 0.0, 0.0);
test_binary_cuda(x, x, |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), &gen_lambda, &gen_lambda, seed, 0.0, 0.0);
test_binary_cuda((y, x), (y, x), |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), &gen_lambda, &gen_lambda, seed, 0.0, 0.0);
}
}
#[allow(dead_code)]
fn run_argsort_test(rows: usize, cols: usize, seed: u64) {
let total = rows * cols;
@@ -276,19 +275,17 @@ fn run_argsort_test(rows: usize, cols: usize, seed: u64) {
}
}
// NOTE: Argsort proptest disabled due to pre-existing bug where argsort output shape
// through e-graph compilation returns only `rows` elements instead of `rows * cols`.
// proptest! {
// #![proptest_config(ProptestConfig::with_cases(10))]
// #[test]
// fn test_argsort(seed in any::<u64>()) {
// run_argsort_test(5, 500, seed);
// }
// }
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
fn test_argsort(seed in any::<u64>()) {
run_argsort_test(5, 500, seed);
}
}
/// Test F32 -> F16 -> F32 cast roundtrip with edge-case values.
#[test]
#[allow(clippy::approx_constant, clippy::excessive_precision)]
pub fn test_cast_f16_edge_cases() {
use luminal::dtype::DType;
@@ -326,7 +323,7 @@ pub fn test_cast_f16_edge_cases() {
.to_dtype(candle_core::DType::F32)
.unwrap()
},
gen_edge_cases,
&gen_edge_cases,
0,
);
}
@@ -352,7 +349,7 @@ proptest! {
.to_dtype(candle_core::DType::F32)
.unwrap()
},
gen_lambda,
&gen_lambda,
seed,
);
}

View File

@@ -173,7 +173,6 @@ fn swiglu_mlp_ref(
}
/// CPU reference for one transformer layer
#[allow(clippy::too_many_arguments)]
fn transformer_layer_ref(
x: &candle_core::Tensor,
attn_norm_w: &candle_core::Tensor,

View File

@@ -235,7 +235,6 @@ pub fn test_unary_cuda<T: TestDType>(
/// Base binary test function with input generators
/// Generic over dtype T - comparison happens in native precision.
/// Requires explicit rtol and atol tolerances (as f32, converted to T internally).
#[allow(clippy::too_many_arguments)]
pub fn test_binary_cuda<T: TestDType>(
a_shape: impl ToShape,
b_shape: impl ToShape,
@@ -411,7 +410,6 @@ pub fn gen_slice_range(
/// produce incorrect computation.
///
/// `setup_inputs` is called for each genome's fresh runtime to load input data.
#[allow(clippy::too_many_arguments)]
pub fn fuzz_genomes<T: TestDType>(
cx: &Graph,
stream: &Arc<cudarc::driver::CudaStream>,

View File

@@ -1,133 +0,0 @@
; Column-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] column-major → expand to [m, n, k] with strides [MIter, 0, m*MIter]
; B[k,n] column-major → expand to [m, n, k] with strides [0, k*MIter, MIter]
;
; Row-major viewed as column-major (swap trick):
; Column-major A[m,k] is already column-major with lda=m
; Column-major B[k,n] is already column-major with ldb=k
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Match exactly 2D output shape
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
; Match exactly 3D strides [m, n, k]
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [MIter, 0, m*MIter] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
; Assert B has strides [0, k*MIter, MIter] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For column-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
(let ?sgemm (Op (cublaslt
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(MNum 1) ; batch_count = 1
(MNum 0) ; stride_a = 0
(MNum 0) ; stride_b = 0
(MNum 0) ; stride_c = 0
?dt) ; dtype
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt column-major × column-major"
)
; Batched Column-major × Column-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
; A column-major per batch: a_m_stride=MIter, a_n_stride=0
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
(= ?k_stride (MIter))
; A column-major: m=MIter, n=0, k_stride=m*MIter
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
; B column-major: k=MIter, m=0, n_stride=k*MIter
(= ?b_k_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
; Uniform batch strides (contiguous per batch)
(= ?a_batch_stride (MMul ?k ?a_k_stride))
(= ?b_batch_stride (MMul ?n ?b_n_stride))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; cuBLAS: cublas(OP_T, OP_T, n, m, k, B, lda=b_n_stride, A, ldb=a_k_stride, C, ldc=n)
(let ?sgemm (Op (cublaslt
?n ?m ?k
"T" "T"
?b_n_stride ; lda (cuBLAS A = our B, column stride)
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
?n ; ldc
?batch
?b_batch_stride ; stride_a (cuBLAS A = our B)
?a_batch_stride ; stride_b (cuBLAS B = our A)
(MMul ?m ?n) ; stride_c
?dt)
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt batched column-major × column-major"
)

View File

@@ -1,133 +0,0 @@
; Column-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] column-major → expand to [m, n, k] with strides [MIter, 0, m*MIter]
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, MIter, n*MIter]
;
; Row-major viewed as column-major (swap trick):
; Column-major A[m,k] is already column-major with lda=m
; Row-major B[k,n] ≡ column-major B^T[n,k] with ldb=n
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Match exactly 2D output shape
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
; Match exactly 3D strides [m, n, k]
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [MIter, 0, m*MIter] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
; Assert B has strides [0, MIter, n*MIter] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MIter))
(= ?b_k_stride (MMul (MIter) ?n))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For column-major A × row-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
(let ?sgemm (Op (cublaslt
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(MNum 1) ; batch_count = 1
(MNum 0) ; stride_a = 0
(MNum 0) ; stride_b = 0
(MNum 0) ; stride_c = 0
?dt) ; dtype
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt column-major × row-major"
)
; Batched Column-major × Row-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
; A column-major per batch: a_m_stride=MIter, a_n_stride=0
; B row-major per batch: b_n_stride=MIter, b_m_stride=0
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
(= ?k_stride (MIter))
; A column-major: m=MIter, n=0, k_stride=m*MIter
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MMul (MIter) ?m))
; B row-major: n=MIter, m=0, k_stride=n*MIter
(= ?b_n_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_k_stride (MMul (MIter) ?n))
; Uniform batch strides (contiguous per batch)
(= ?a_batch_stride (MMul ?k ?a_k_stride))
(= ?b_batch_stride (MMul ?k ?b_k_stride))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; cuBLAS: cublas(OP_N, OP_T, n, m, k, B, lda=b_k_stride, A, ldb=a_k_stride, C, ldc=n)
(let ?sgemm (Op (cublaslt
?n ?m ?k
"N" "T"
?b_k_stride ; lda (cuBLAS A = our B, row stride)
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
?n ; ldc
?batch
?b_batch_stride ; stride_a (cuBLAS A = our B)
?a_batch_stride ; stride_b (cuBLAS B = our A)
(MMul ?m ?n) ; stride_c
?dt)
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt batched column-major × row-major"
)

View File

@@ -1,133 +0,0 @@
; Row-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] row-major → expand to [m, n, k] with strides [k*MIter, 0, MIter]
; B[k,n] column-major → expand to [m, n, k] with strides [0, k*MIter, MIter]
;
; Row-major viewed as column-major (swap trick):
; Row-major A[m,k] ≡ column-major A^T[k,m] with lda=k
; Column-major B[k,n] is already column-major with ldb=k
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Match exactly 2D output shape
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
; Match exactly 3D strides [m, n, k]
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [k*MIter, 0, MIter] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MIter))
; Assert B has strides [0, k*MIter, MIter] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For row-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
(let ?sgemm (Op (cublaslt
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"T" ; transa = Transpose (B is column-major, need B^T)
"N" ; transb = No transpose
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(MNum 1) ; batch_count = 1
(MNum 0) ; stride_a = 0
(MNum 0) ; stride_b = 0
(MNum 0) ; stride_c = 0
?dt) ; dtype
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt row-major × column-major"
)
; Batched Row-major × Column-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
; A row-major per batch: a_k_stride=MIter, a_n_stride=0
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
(= ?k_stride (MIter))
; A row-major: k=MIter, n=0, m_stride=k*MIter
(= ?a_k_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_m_stride (MMul (MIter) ?k))
; B column-major: k=MIter, m=0, n_stride=k*MIter
(= ?b_k_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MMul (MIter) ?k))
; Uniform batch strides (contiguous per batch)
(= ?a_batch_stride (MMul ?m ?a_m_stride))
(= ?b_batch_stride (MMul ?n ?b_n_stride))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; cuBLAS: cublas(OP_T, OP_N, n, m, k, B, lda=b_n_stride, A, ldb=a_m_stride, C, ldc=n)
(let ?sgemm (Op (cublaslt
?n ?m ?k
"T" "N"
?b_n_stride ; lda (cuBLAS A = our B, column stride)
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
?n ; ldc
?batch
?b_batch_stride ; stride_a (cuBLAS A = our B)
?a_batch_stride ; stride_b (cuBLAS B = our A)
(MMul ?m ?n) ; stride_c
?dt)
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt batched row-major × column-major"
)

View File

@@ -1,139 +0,0 @@
; Row-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] row-major → expand to [m, n, k] with strides [k*MIter, 0, MIter]
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, MIter, n*MIter]
;
; Row-major viewed as column-major (swap trick):
; Row-major A[m,k] ≡ column-major [k,m] with lda=k
; Row-major B[k,n] ≡ column-major [n,k] with ldb=n
; Row-major C[m,n] ≡ column-major [n,m] with ldc=n
;
; cuBLAS computes: C_col[n,m] = B_col[n,k] × A_col[k,m]
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
(rule
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Match exactly 2D output shape
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
; Match exactly 3D strides [m, n, k]
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [k*MIter, 0, MIter] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MIter))
; Assert B has strides [0, MIter, n*MIter] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MIter))
(= ?b_k_stride (MMul (MIter) ?n))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For row-major C = A × B with cuBLAS (column-major):
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
(let ?sgemm (Op (cublaslt
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"N" ; transa = No transpose
"N" ; transb = No transpose
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(MNum 1) ; batch_count = 1
(MNum 0) ; stride_a = 0
(MNum 0) ; stride_b = 0
(MNum 0) ; stride_c = 0
?dt) ; dtype
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt row-major x row-major"
)
; Batched Row-major × Row-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
; In broadcast [batch, m, n, k] space:
; A row-major per batch: a_k_stride=MIter, a_n_stride=0
; B row-major per batch: b_n_stride=MIter, b_m_stride=0
; This rule only matches contiguous batches: the row strides are pinned to
; k*MIter (A) and n*MIter (B), so the leading dimensions are exactly k and n.
(rule
(
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Output shape: [batch, m, n]
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
; A strides in [batch, m, n, k]
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; B strides in [batch, m, n, k]
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
(= ?k_stride (MIter))
; A row-major: k=MIter, n=0, m_stride=k*MIter
(= ?a_k_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_m_stride (MMul (MIter) ?k))
; B row-major: n=MIter, m=0, k_stride=n*MIter
(= ?b_n_stride (MIter))
(= ?b_m_stride (MNum 0))
(= ?b_k_stride (MMul (MIter) ?n))
; Uniform batch strides (contiguous per batch, no GQA-style repetition)
(= ?a_batch_stride (MMul ?m ?a_m_stride))
(= ?b_batch_stride (MMul ?k ?b_k_stride))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; cuBLAS swap: C^T[n,m] = B^T[n,k] × A^T[k,m] per batch
; cublas(OP_N, OP_N, n, m, k, B, lda=b_k_stride, A, ldb=a_m_stride, C, ldc=n)
(let ?sgemm (Op (cublaslt
?n ?m ?k
"N" "N"
?b_k_stride ; lda (cuBLAS A = our B, row stride)
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
?n ; ldc (contiguous output per batch)
?batch ; batch_count
?b_batch_stride ; stride_a (cuBLAS A = our B)
?a_batch_stride ; stride_b (cuBLAS B = our A)
(MMul ?m ?n) ; stride_c
?dt)
(ICons ?b (ICons ?a (INil)))))
(union ?sum ?sgemm)
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt batched row-major × row-major"
)

View File

@@ -1,128 +0,0 @@
; GLUMoE: Match the expert computation subgraph of a Gated MoE (SwiGLU variant).
;
; This matches the pattern produced by QwenMoE::forward() starting from the
; expert gathers through to the final weighted sum, and replaces it with a
; fused GLUMoE HostOp.
;
; Inputs extracted:
; ?x - input activations [s, H] F32
; ?topk_idx - top-k expert indices [s, k] Int (from argsort+slice)
; ?topk_vals - top-k routing values [s, k] F32 (from gather on softmax)
; ?gate_up_w - stacked gate+up expert weights [E, intermediate*2, H] BF16
; ?down_w - stacked down expert weights [E, H, intermediate] BF16
;
; The pattern captures:
; 1. Gate-up expert gather (Iota, Mul, Cast, Iota, Cast, Add, Cast, Gather)
; 2. Cast BF16→F32 of gathered gate-up weights
; 3. Gate-up batched matmul (Mul + SumReduce)
; 4. Up-half selection via Iota+Gather (slice semantics); the gate operand is
;    taken from the gate-up matmul result directly (no separate gather is matched)
; 5. SwiGLU: silu(gate) * up
; 6. Down expert gather (same pattern as gate-up)
; 7. Cast BF16→F32 of gathered down weights
; 8. Down batched matmul (Mul + SumReduce)
; 9. Weighted sum: (down_out * topk_values) summed over k
;
; Variables with ? prefix are egglog pattern variables.
; Shapes/strides we don't extract are still bound to named variables
; (e.g. ?gu_mul_base_shape) but are not used in the rewrite below.
; Note that ?topk_idx is shared by the gate-up and down gathers, so both must
; index with the same top-k tensor for the rule to match.
(rule
(
; ===== Gate-up expert gather =====
; t51: Iota for base index (expert_idx * io_gu)
(= ?gu_iota_base (Op (Iota ?gu_io ?gu_iota_base_range) (INil)))
; t52: Mul topk_indices * io → base offsets [s, k]
(= ?gu_mul_base (Op (Mul ?gu_mul_base_shape ?gu_mul_base_a_stride ?gu_mul_base_b_stride ?gu_mul_base_out_stride) (ICons ?topk_idx (ICons ?gu_iota_base (INil)))))
; t53: Cast to F32
(= ?gu_cast_base (Op (Cast ?gu_cast_base_size (F32)) (ICons ?gu_mul_base (INil))))
; t54: Iota for within-expert index
(= ?gu_iota_within (Op (Iota (MIter) ?gu_iota_within_range) (INil)))
; t55: Cast within to F32
(= ?gu_cast_within (Op (Cast ?gu_cast_within_size (F32)) (ICons ?gu_iota_within (INil))))
; t56: Add base + within → flat gather indices
(= ?gu_add_idx (Op (Add ?gu_add_shape ?gu_add_a_stride ?gu_add_b_stride ?gu_add_out_stride) (ICons ?gu_cast_base (ICons ?gu_cast_within (INil)))))
; t57: Cast to Int
(= ?gu_cast_idx (Op (Cast ?gu_cast_idx_size (Int)) (ICons ?gu_add_idx (INil))))
; t58: Gather gate_up weights
(= ?gu_gathered (Op (Gather ?gu_gather_idx_shape ?gu_gather_idx_stride ?gu_gather_data_shape ?gu_gather_data_stride) (ICons ?gu_cast_idx (ICons ?gate_up_w (INil)))))
; ===== Cast BF16→F32 =====
; t59: Cast gathered gate_up to F32
(= ?gu_f32 (Op (Cast ?gu_f32_size (F32)) (ICons ?gu_gathered (INil))))
; ===== Gate-up batched matmul =====
; t60: Mul x * gathered_gu (broadcast multiply)
(= ?gu_matmul_mul (Op (Mul ?gu_matmul_mul_shape ?gu_matmul_a_stride ?gu_matmul_b_stride ?gu_matmul_mul_out_stride) (ICons ?x (ICons ?gu_f32 (INil)))))
; t61: SumReduce over K dimension
(= ?gu_matmul (Op (Sum ?gu_matmul_out_shape ?gu_matmul_k ?gu_matmul_in_stride ?gu_matmul_k_stride ?gu_matmul_out_stride) (ICons ?gu_matmul_mul (INil))))
; ===== Up slice via Iota+Gather =====
; t62: Iota with complex expression (slicing the "up" half)
(= ?up_iota (Op (Iota ?up_iota_expr ?up_iota_range) (INil)))
; t63: Gather to select up portion from matmul result
(= ?up_slice (Op (Gather ?up_gather_idx_shape ?up_gather_idx_stride ?up_gather_data_shape ?up_gather_data_stride) (ICons ?up_iota (ICons ?gu_matmul (INil)))))
; ===== SwiGLU: silu(gate) * up =====
; NOTE(review): the gate operand at t65 and t72 is ?gu_matmul itself; presumably
; the gate half is selected through the Mul strides — confirm.
; t64: Constant(-1)
(= ?neg1 (Op (Constant -1.000000) (INil)))
; t65: gate * -1
(= ?neg_gate (Op (Mul ?silu_shape1 ?silu_a_stride1 ?silu_b_stride1 ?silu_out_stride1) (ICons ?gu_matmul (ICons ?neg1 (INil)))))
; t66: Constant(log2e)
(= ?log2e (Op (Constant 1.442695) (INil)))
; t67: neg_gate * log2e
(= ?scaled (Op (Mul ?silu_shape2 ?silu_a_stride2 ?silu_b_stride2 ?silu_out_stride2) (ICons ?neg_gate (ICons ?log2e (INil)))))
; t68: exp2
(= ?exp2_val (Op (Exp2 ?silu_shape3 ?silu_in_stride3 ?silu_out_stride3) (ICons ?scaled (INil))))
; t69: Constant(1)
(= ?one (Op (Constant 1.000000) (INil)))
; t70: exp2 + 1
(= ?plus1 (Op (Add ?silu_shape4 ?silu_a_stride4 ?silu_b_stride4 ?silu_out_stride4) (ICons ?exp2_val (ICons ?one (INil)))))
; t71: recip
(= ?sigmoid (Op (Recip ?silu_shape5 ?silu_in_stride5 ?silu_out_stride5) (ICons ?plus1 (INil))))
; t72: gate * sigmoid(gate) = silu(gate)
(= ?silu_out (Op (Mul ?silu_shape6 ?silu_a_stride6 ?silu_b_stride6 ?silu_out_stride6) (ICons ?gu_matmul (ICons ?sigmoid (INil)))))
; t73: silu(gate) * up
(= ?swiglu_out (Op (Mul ?swiglu_shape ?swiglu_a_stride ?swiglu_b_stride ?swiglu_out_stride) (ICons ?silu_out (ICons ?up_slice (INil)))))
; ===== Down expert gather =====
; t74: Iota for base index (expert_idx * io_down)
(= ?dn_iota_base (Op (Iota ?dn_io ?dn_iota_base_range) (INil)))
; t75: Mul topk_indices * io_down
(= ?dn_mul_base (Op (Mul ?dn_mul_base_shape ?dn_mul_base_a_stride ?dn_mul_base_b_stride ?dn_mul_base_out_stride) (ICons ?topk_idx (ICons ?dn_iota_base (INil)))))
; t76: Cast to F32
(= ?dn_cast_base (Op (Cast ?dn_cast_base_size (F32)) (ICons ?dn_mul_base (INil))))
; t77: Iota for within-expert index
(= ?dn_iota_within (Op (Iota (MIter) ?dn_iota_within_range) (INil)))
; t78: Cast within to F32
(= ?dn_cast_within (Op (Cast ?dn_cast_within_size (F32)) (ICons ?dn_iota_within (INil))))
; t79: Add base + within
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_cast_base (ICons ?dn_cast_within (INil)))))
; t80: Cast to Int
(= ?dn_cast_idx (Op (Cast ?dn_cast_idx_size (Int)) (ICons ?dn_add_idx (INil))))
; t81: Gather down weights
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_cast_idx (ICons ?down_w (INil)))))
; ===== Cast BF16→F32 =====
; t82: Cast gathered down to F32
(= ?dn_f32 (Op (Cast ?dn_f32_size (F32)) (ICons ?dn_gathered (INil))))
; ===== Down batched matmul =====
; t83: Mul swiglu_out * gathered_down (broadcast multiply)
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?swiglu_out (ICons ?dn_f32 (INil)))))
; t84: SumReduce
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
; ===== Weighted sum over k experts =====
; t85: Mul down_out * topk_values
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?topk_vals (INil)))))
; t86: SumReduce over k dimension → [s, H]
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
)
(
; Build the fused op from the two per-expert Iota expressions, the three
; reduction sizes and the two within-expert Iota ranges, then make it
; equivalent to the final weighted-sum node.
; Operand order: x, topk_idx, topk_vals, gate_up_w, down_w.
(let ?glumoe (Op (GLUMoE
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
?gu_iota_within_range ?dn_iota_within_range)
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (INil))))))))
(union ?output ?glumoe)
)
:name "GLUMoE fused expert computation"
)

File diff suppressed because it is too large Load Diff

View File

@@ -1,309 +0,0 @@
pub mod host;
pub mod kernel;
pub mod runtime;
use std::{
ffi::{CStr, CString},
path::Path,
sync::Arc,
};
pub use cudarc;
#[cfg(test)]
mod tests;
use cudarc::{
driver::{CudaContext, DriverError, sys as driver_sys},
nvrtc::{
Ptx,
result::{self as nvrtc_result, NvrtcError},
sys as nvrtc_sys,
},
};
use luminal::dtype::DType;
/// Maps a luminal [`DType`] to the name of the CUDA C type used for it.
///
/// Several dtypes share a storage type: `TF32` is stored as `float`, and
/// `Bool` as well as the packed sub-byte integers use `unsigned char`.
fn cuda_dtype(dtype: DType) -> &'static str {
    let type_name = match dtype {
        DType::F64 => "double",
        // TF32 uses float storage, tensor cores handle the format
        DType::F32 | DType::TF32 => "float",
        DType::F16 => "half",
        DType::Bf16 => "__nv_bfloat16",
        DType::Int => "int",
        DType::I16 => "short",
        DType::U16 => "unsigned short",
        DType::I8 => "signed char",
        // I4/U4 are sub-byte and use packed storage
        DType::U8 | DType::Bool | DType::I4 | DType::U4 => "unsigned char",
        DType::F8E4M3 => "__nv_fp8_e4m3",
        DType::F8E5M2 => "__nv_fp8_e5m2",
        DType::F8UE8M0 => "__nv_fp8_e8m0",
        DType::F6E2M3 => "__nv_fp6_e2m3",
        DType::F6E3M2 => "__nv_fp6_e3m2",
        DType::F4E2M1 => "__nv_fp4_e2m1",
    };
    type_name
}
/// Fallback include directories offered to NVRTC. They are appended (when they
/// exist on disk) after the directories derived from the `CUDA_HOME`,
/// `CUDA_PATH` and `CUDA_ROOT` environment variables.
const CUDA_NVRTC_INCLUDE_PATHS: [&str; 2] = ["/usr/local/cuda/include", "/usr/include"];
/// The specific step that failed while compiling a CUDA module image.
#[derive(Debug)]
pub(crate) enum CudaModuleImageCompileFailure {
/// Querying the device's compute capability from the CUDA context failed.
ComputeCapability(DriverError),
/// An NVRTC call failed. `stage` names the call ("create_program",
/// "compile_program", "get_cubin" or "destroy_program").
Nvrtc {
stage: &'static str,
error: NvrtcError,
},
/// NVRTC reported success but the resulting CUBIN was empty.
NoModuleImageProduced,
}
/// Error returned when a CUDA module image cannot be compiled, carrying the
/// diagnostic context that is printed by its `Display` implementation.
#[derive(Debug)]
pub(crate) struct CudaModuleImageCompileError {
/// Target architecture string (`sm_<major><minor>`); `None` when the compute
/// capability query itself failed.
pub target_arch: Option<String>,
/// Raw CUDA driver version as reported by `cuDriverGetVersion`, if available.
pub driver_version: Option<i32>,
/// Raw CUDA runtime version; `cuda_driver_diagnostics` currently always
/// reports `None` here.
pub runtime_version: Option<i32>,
/// Options that were (or would have been) passed to NVRTC.
pub nvrtc_options: Vec<String>,
/// NVRTC program log, when one was produced and is non-empty.
pub nvrtc_log: Option<String>,
/// The step that failed.
pub failure: CudaModuleImageCompileFailure,
}
/// Renders the error as a single line: a fixed prefix, the target architecture
/// (if known), the failing step, then `|`-separated diagnostics (driver and
/// runtime versions, NVRTC options, NVRTC log) for whichever are present.
impl std::fmt::Display for CudaModuleImageCompileError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("failed to compile CUDA module image")?;
        if let Some(arch) = self.target_arch.as_deref() {
            write!(f, " for {arch}")?;
        }

        // Which step went wrong.
        match &self.failure {
            CudaModuleImageCompileFailure::ComputeCapability(cause) => {
                write!(f, ": failed to query compute capability: {cause}")?
            }
            CudaModuleImageCompileFailure::Nvrtc {
                stage,
                error: cause,
            } => write!(f, ": NVRTC {stage} failed: {cause}")?,
            CudaModuleImageCompileFailure::NoModuleImageProduced => {
                f.write_str(": NVRTC produced no CUBIN for the selected target")?
            }
        }

        // Optional version diagnostics, driver first.
        let versions = [
            ("driver", self.driver_version),
            ("runtime", self.runtime_version),
        ];
        for (label, raw_version) in versions {
            if let Some(raw_version) = raw_version {
                write!(f, " | {label} {}", format_cuda_version(raw_version))?;
            }
        }

        if !self.nvrtc_options.is_empty() {
            write!(f, " | options {:?}", self.nvrtc_options)?;
        }
        match &self.nvrtc_log {
            Some(log) => write!(f, " | log: {log}"),
            None => Ok(()),
        }
    }
}

impl std::error::Error for CudaModuleImageCompileError {}
/// Formats a raw CUDA version integer (encoded as `1000 * major + 10 * minor`)
/// as `"major.minor"`.
fn format_cuda_version(version: i32) -> String {
    let major = version / 1000;
    let minor = (version % 1000) / 10;
    format!("{major}.{minor}")
}
/// Collects the include directories to hand to NVRTC.
///
/// Candidates are `<root>/include` for each of the `CUDA_HOME`, `CUDA_PATH` and
/// `CUDA_ROOT` environment variables that is set, followed by the entries of
/// `CUDA_NVRTC_INCLUDE_PATHS`. Only directories that exist are kept, in that
/// order, and duplicates are dropped.
fn cuda_nvrtc_include_paths() -> Vec<String> {
    let from_env = ["CUDA_HOME", "CUDA_PATH", "CUDA_ROOT"]
        .iter()
        .filter_map(|name| std::env::var(name).ok())
        .map(|root| format!("{root}/include"));
    let fallbacks = CUDA_NVRTC_INCLUDE_PATHS
        .iter()
        .map(|dir| dir.to_string());

    let mut found: Vec<String> = Vec::new();
    for candidate in from_env.chain(fallbacks) {
        let usable = Path::new(&candidate).exists() && !found.contains(&candidate);
        if usable {
            found.push(candidate);
        }
    }
    found
}
/// Returns `(driver_version, runtime_version)` for diagnostics.
///
/// The driver version comes from `cuDriverGetVersion` and is `None` if that
/// call fails. The runtime version is always `None`.
fn cuda_driver_diagnostics() -> (Option<i32>, Option<i32>) {
    let mut raw_version = 0;
    let status = unsafe { driver_sys::cuDriverGetVersion(&mut raw_version as *mut _) };
    let driver = match status.result() {
        Ok(_) => Some(raw_version),
        Err(_) => None,
    };
    // Avoid touching cudarc's runtime loader here. On some environments it eagerly
    // resolves newer libcudart symbols that may not exist in the installed runtime.
    let runtime = None;
    (driver, runtime)
}
/// Builds the NVRTC option list for `target_arch`: one `--include-path=` entry
/// per discovered include directory, followed by `--gpu-architecture=`.
fn cuda_nvrtc_compile_options(target_arch: &str) -> Vec<String> {
    let mut flags = Vec::new();
    for dir in cuda_nvrtc_include_paths() {
        flags.push(format!("--include-path={dir}"));
    }
    flags.push(format!("--gpu-architecture={target_arch}"));
    flags
}
/// Assembles a [`CudaModuleImageCompileError`] from its parts, copying the
/// borrowed NVRTC options into an owned vector.
fn build_module_image_compile_error(
    target_arch: Option<String>,
    driver_version: Option<i32>,
    runtime_version: Option<i32>,
    nvrtc_options: &[String],
    nvrtc_log: Option<String>,
    failure: CudaModuleImageCompileFailure,
) -> CudaModuleImageCompileError {
    let owned_options: Vec<String> = nvrtc_options.iter().cloned().collect();
    CudaModuleImageCompileError {
        failure,
        nvrtc_log,
        nvrtc_options: owned_options,
        runtime_version,
        driver_version,
        target_arch,
    }
}
/// Fetches the NVRTC program log for `program` as trimmed text.
///
/// Returns `None` when the log cannot be retrieved, or is empty once trailing
/// NULs and surrounding whitespace are removed. Invalid UTF-8 is replaced
/// lossily.
fn read_nvrtc_log(program: nvrtc_sys::nvrtcProgram) -> Option<String> {
    let raw = unsafe { nvrtc_result::get_program_log(program).ok()? };
    if raw.is_empty() {
        return None;
    }
    // `c_char` is `i8` or `u8` depending on the platform; normalise to bytes so the
    // text can be decoded without reading through a raw pointer.
    let bytes: Vec<u8> = raw.iter().map(|&c| c as u8).collect();
    // Stop at the first NUL if there is one. Unlike `CStr::from_ptr`, this never
    // reads past the end of the buffer when the terminator is missing; in that
    // case the whole buffer is decoded instead.
    let text = match CStr::from_bytes_until_nul(&bytes) {
        Ok(c_str) => c_str.to_string_lossy().into_owned(),
        Err(_) => String::from_utf8_lossy(&bytes).into_owned(),
    };
    let log = text.trim();
    if log.is_empty() {
        None
    } else {
        Some(log.to_string())
    }
}
/// Copies the compiled CUBIN out of an NVRTC program.
///
/// Returns an empty vector when NVRTC reports a CUBIN size of zero, and the
/// NVRTC error if either the size query or the copy fails.
fn get_cubin(program: nvrtc_sys::nvrtcProgram) -> Result<Vec<u8>, NvrtcError> {
    let mut image_len = 0usize;
    unsafe { nvrtc_sys::nvrtcGetCUBINSize(program, &mut image_len as *mut _) }.result()?;
    if image_len == 0 {
        return Ok(Vec::new());
    }

    // Zero-filled buffer of exactly the reported size for NVRTC to write into.
    let mut image = vec![0; image_len];
    unsafe { nvrtc_sys::nvrtcGetCUBIN(program, image.as_mut_ptr()) }.result()?;

    // NVRTC writes C chars; reinterpret each one as a byte.
    let mut bytes = Vec::with_capacity(image_len);
    for raw in image {
        bytes.push(raw as u8);
    }
    Ok(bytes)
}
pub(crate) fn compile_module_image_for_current_device<S: AsRef<str>>(
ctx: &Arc<CudaContext>,
src: S,
) -> Result<Ptx, CudaModuleImageCompileError> {
let (driver_version, runtime_version) = cuda_driver_diagnostics();
let (major, minor) = ctx.compute_capability().map_err(|error| {
build_module_image_compile_error(
None,
driver_version,
runtime_version,
&[],
None,
CudaModuleImageCompileFailure::ComputeCapability(error),
)
})?;
let target_arch = format!("sm_{major}{minor}");
let nvrtc_options = cuda_nvrtc_compile_options(&target_arch);
let source = CString::new(src.as_ref().as_bytes())
.expect("CUDA source code cannot contain null terminators");
let program = nvrtc_result::create_program(&source, None).map_err(|error| {
build_module_image_compile_error(
Some(target_arch.clone()),
driver_version,
runtime_version,
&nvrtc_options,
None,
CudaModuleImageCompileFailure::Nvrtc {
stage: "create_program",
error,
},
)
})?;
if let Err(error) = unsafe { nvrtc_result::compile_program(program, &nvrtc_options) } {
let nvrtc_log = read_nvrtc_log(program);
let _ = unsafe { nvrtc_result::destroy_program(program) };
return Err(build_module_image_compile_error(
Some(target_arch),
driver_version,
runtime_version,
&nvrtc_options,
nvrtc_log,
CudaModuleImageCompileFailure::Nvrtc {
stage: "compile_program",
error,
},
));
}
let nvrtc_log = read_nvrtc_log(program);
let cubin = match get_cubin(program) {
Ok(cubin) => cubin,
Err(error) => {
let _ = unsafe { nvrtc_result::destroy_program(program) };
return Err(build_module_image_compile_error(
Some(target_arch),
driver_version,
runtime_version,
&nvrtc_options,
nvrtc_log,
CudaModuleImageCompileFailure::Nvrtc {
stage: "get_cubin",
error,
},
));
}
};
if let Err(error) = unsafe { nvrtc_result::destroy_program(program) } {
return Err(build_module_image_compile_error(
Some(target_arch),
driver_version,
runtime_version,
&nvrtc_options,
nvrtc_log,
CudaModuleImageCompileFailure::Nvrtc {
stage: "destroy_program",
error,
},
));
}
if cubin.is_empty() {
return Err(build_module_image_compile_error(
Some(target_arch),
driver_version,
runtime_version,
&nvrtc_options,
nvrtc_log,
CudaModuleImageCompileFailure::NoModuleImageProduced,
));
}
Ok(Ptx::from_binary(cubin))
}
/// Returns the memory bandwidth of the device in GB/s.
///
/// The value comes from a fixed table keyed by the exact device name. Returns
/// `None` for devices that are not in the table, and also when the device name
/// cannot be queried (instead of panicking).
pub fn cuda_bandwidth_gbps(ctx: &Arc<CudaContext>) -> Option<usize> {
    // A failed name query is treated the same as an unknown device.
    let name = ctx.name().ok()?;
    match name.as_str() {
        "NVIDIA Thor" => Some(273),
        "NVIDIA H100 PCIe" => Some(2_000),
        "NVIDIA H100 SXM" => Some(3_350),
        _ => None,
    }
}
/// Returns the f32 compute throughput of the device in TFLOPs.
///
/// The value comes from a fixed table keyed by the exact device name. Returns
/// `None` for devices that are not in the table, and also when the device name
/// cannot be queried (instead of panicking).
pub fn cuda_compute_f32_tflops(ctx: &Arc<CudaContext>) -> Option<usize> {
    // A failed name query is treated the same as an unknown device.
    let name = ctx.name().ok()?;
    match name.as_str() {
        "NVIDIA Thor" => Some(125), // forced to use tf32 flops
        "NVIDIA H100 PCIe" => Some(756),
        "NVIDIA H100 SXM" => Some(989),
        _ => None,
    }
}

View File

@@ -1,344 +0,0 @@
use crate::runtime::CudaRuntime;
use crate::tests::utilities::*;
use luminal::prelude::*;
use rand::{SeedableRng, rngs::SmallRng};
/// Helper: builds a graph with one `('s', 4)` input tensor whose output is the
/// input added to itself.
/// Returns `(graph, input_node, output_node)`.
fn build_dynamic_add_graph() -> (Graph, NodeIndex, NodeIndex) {
    let mut graph = Graph::default();
    let input = graph.tensor(('s', 4));
    let doubled = (input + input).output();
    (graph, input.id, doubled.id)
}
/// Helper: builds a matmul graph over the dynamic dim 's':
/// `(s, k) @ (k, n) -> (s, n)`.
/// Returns `(graph, lhs_node, rhs_node, output_node)`.
fn build_dynamic_matmul_graph(k: usize, n: usize) -> (Graph, NodeIndex, NodeIndex, NodeIndex) {
    let mut graph = Graph::default();
    let lhs = graph.tensor(('s', k));
    let rhs = graph.tensor((k, n));
    let product = lhs.matmul(rhs).output();
    (graph, lhs.id, rhs.id, product.id)
}
/// Bucketed compilation must give correct results for a dim value in each bucket.
#[test]
fn test_bucket_dispatch_simple() {
    let Some(stream) = get_cuda_stream() else {
        return;
    };
    let (mut cx, a, b) = build_dynamic_add_graph();
    cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
    cx.build_search_space::<CudaRuntime>();
    let mut rt = CudaRuntime::initialize(stream);

    // The search needs some input to be present.
    cx.set_dim('s', 1);
    rt.set_data(a, vec![1.0f32; 4]);
    let mut rng = SmallRng::seed_from_u64(42);
    rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);

    // One case per bucket: s=1 (first bucket), then s=3 (second bucket).
    let cases: [(usize, Vec<f32>); 2] = [
        (1, vec![1.0f32, 2.0, 3.0, 4.0]),
        (3, (0..12).map(|i| i as f32).collect()),
    ];
    for (s, input_data) in cases {
        cx.set_dim('s', s);
        rt.set_data(a, input_data.clone());
        rt.execute(&cx.dyn_map);
        let result = rt.get_f32(b);
        let expected: Vec<f32> = input_data.iter().map(|x| x * 2.0).collect();
        assert_close(&result[..input_data.len()], &expected, 1e-5, 1e-5);
    }
}
/// Matmul with a bucketed dynamic dim must match a CPU reference at s=1 and s=4.
#[test]
fn test_bucket_matmul_dynamic() {
    let Some(stream) = get_cuda_stream() else {
        return;
    };
    let k = 8;
    let n = 4;
    let (mut cx, a, b_tensor, c) = build_dynamic_matmul_graph(k, n);
    cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 8)]);
    cx.build_search_space::<CudaRuntime>();
    let mut rt = CudaRuntime::initialize(stream);

    cx.set_dim('s', 1);
    let a_data = random_f32_vec(k, 100, -1.0, 1.0);
    let b_data = random_f32_vec(k * n, 101, -1.0, 1.0);
    rt.set_data(a, a_data.clone());
    rt.set_data(b_tensor, b_data.clone());
    let mut rng = SmallRng::seed_from_u64(42);
    rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);

    // CPU reference: (rows x K) @ (K x N) -> (rows x N), accumulated over K innermost.
    let reference = |lhs: &[f32], rows: usize| -> Vec<f32> {
        let mut out = vec![0.0f32; rows * n];
        for row in 0..rows {
            for col in 0..n {
                for inner in 0..k {
                    out[row * n + col] += lhs[row * k + inner] * b_data[inner * n + col];
                }
            }
        }
        out
    };

    // s = 1
    cx.set_dim('s', 1);
    rt.set_data(a, a_data.clone());
    rt.set_data(b_tensor, b_data.clone());
    rt.execute(&cx.dyn_map);
    let result_s1 = rt.get_f32(c);
    let expected_s1 = reference(&a_data, 1);
    assert_close(&result_s1[..n], &expected_s1, 1e-4, 1e-4);

    // s = 4
    cx.set_dim('s', 4);
    let a_data_4 = random_f32_vec(4 * k, 200, -1.0, 1.0);
    rt.set_data(a, a_data_4.clone());
    rt.set_data(b_tensor, b_data.clone());
    rt.execute(&cx.dyn_map);
    let result_s4 = rt.get_f32(c);
    let expected_s4 = reference(&a_data_4, 4);
    assert_close(&result_s4[..4 * n], &expected_s4, 1e-4, 1e-4);
}
/// The same graph must give the same result with and without buckets.
#[test]
fn test_bucket_results_match_unbucketed() {
    let Some(stream) = get_cuda_stream() else {
        return;
    };
    let seed = 42u64;
    let input_data = random_f32_vec(12, seed, -1.0, 1.0);

    // Reference run: no buckets.
    let (mut plain_cx, plain_in, plain_out) = build_dynamic_add_graph();
    plain_cx.set_dim('s', 3);
    plain_cx.build_search_space::<CudaRuntime>();
    let mut plain_rt = CudaRuntime::initialize(stream.clone());
    plain_rt.set_data(plain_in, input_data.clone());
    let mut plain_rng = SmallRng::seed_from_u64(seed);
    plain_rt = plain_cx.search_options(plain_rt, SearchOptions::new(5), &mut plain_rng);
    plain_rt.set_data(plain_in, input_data.clone());
    plain_rt.execute(&plain_cx.dyn_map);
    let result_unbucketed = plain_rt.get_f32(plain_out);

    // Bucketed run: a single bucket that covers s=3.
    let (mut bucket_cx, bucket_in, bucket_out) = build_dynamic_add_graph();
    bucket_cx.set_dim('s', 3);
    bucket_cx.set_dim_buckets('s', &[DimBucket::new(1, 4)]);
    bucket_cx.build_search_space::<CudaRuntime>();
    let mut bucket_rt = CudaRuntime::initialize(stream.clone());
    bucket_rt.set_data(bucket_in, input_data.clone());
    let mut bucket_rng = SmallRng::seed_from_u64(seed);
    bucket_rt = bucket_cx.search_options(bucket_rt, SearchOptions::new(5), &mut bucket_rng);
    bucket_rt.set_data(bucket_in, input_data.clone());
    bucket_rt.execute(&bucket_cx.dyn_map);
    let result_bucketed = bucket_rt.get_f32(bucket_out);

    // Same graph, same search seed, same dyn_map: the outputs should agree.
    assert_eq!(result_unbucketed.len(), result_bucketed.len());
    assert_close(&result_unbucketed[..12], &result_bucketed[..12], 1e-5, 1e-5);
}
/// Executing with a dim value outside every bucket must panic.
#[test]
#[should_panic(expected = "No bucket matches")]
fn test_bucket_out_of_range_panics() {
    let stream = match get_cuda_stream() {
        Some(stream) => stream,
        // Without a GPU the real panic can't be triggered; raise the expected
        // message so the test is effectively skipped.
        None => panic!("No bucket matches dyn_map"),
    };
    let (mut cx, a, _b) = build_dynamic_add_graph();
    cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
    cx.build_search_space::<CudaRuntime>();
    let mut rt = CudaRuntime::initialize(stream);

    cx.set_dim('s', 1);
    rt.set_data(a, vec![1.0f32; 4]);
    let mut rng = SmallRng::seed_from_u64(42);
    rt = cx.search_options(rt, SearchOptions::new(3), &mut rng);

    // s=10 lies outside both buckets, so this execute is expected to panic.
    cx.set_dim('s', 10);
    rt.set_data(a, vec![1.0f32; 40]);
    rt.execute(&cx.dyn_map);
}
/// With no buckets configured the graph must still compile and run correctly.
#[test]
fn test_bucket_no_buckets_backward_compat() {
    let Some(stream) = get_cuda_stream() else {
        return;
    };
    let (mut cx, input, output) = build_dynamic_add_graph();
    // Deliberately no set_dim_buckets call here.
    cx.set_dim('s', 2);
    cx.build_search_space::<CudaRuntime>();
    let mut rt = CudaRuntime::initialize(stream);

    let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    rt.set_data(input, values.clone());
    let mut rng = SmallRng::seed_from_u64(42);
    rt = cx.search_options(rt, SearchOptions::new(3), &mut rng);

    rt.set_data(input, values.clone());
    rt.execute(&cx.dyn_map);
    let actual = rt.get_f32(output);
    let doubled: Vec<f32> = values.iter().map(|x| x * 2.0).collect();
    assert_close(&actual[..8], &doubled, 1e-5, 1e-5);
}
/// A custom representative overrides the default one.
#[test]
fn test_bucket_representative_override() {
    // Explicit override.
    let overridden = DimBucket::new(2, 32).representative(16);
    assert_eq!(overridden.representative_value(), 16);

    // Default for [2, 32]: (2+32)/2 = 17.
    assert_eq!(DimBucket::new(2, 32).representative_value(), 17);

    // A single-value bucket represents itself.
    assert_eq!(DimBucket::new(1, 1).representative_value(), 1);
}
/// Switching between buckets and back must keep producing correct results with
/// the same weight data.
#[test]
fn test_bucket_switch_preserves_weights() {
    let Some(stream) = get_cuda_stream() else {
        return;
    };
    let k = 4;
    let n = 4;
    let (mut cx, lhs, weights, out) = build_dynamic_matmul_graph(k, n);
    cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
    cx.build_search_space::<CudaRuntime>();
    let mut rt = CudaRuntime::initialize(stream);

    cx.set_dim('s', 1);
    let lhs_one = random_f32_vec(k, 300, -1.0, 1.0);
    let weight_data = random_f32_vec(k * n, 301, -1.0, 1.0);
    rt.set_data(lhs, lhs_one.clone());
    rt.set_data(weights, weight_data.clone());
    let mut rng = SmallRng::seed_from_u64(42);
    rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);

    // First pass through bucket 1 (s=1).
    cx.set_dim('s', 1);
    rt.set_data(lhs, lhs_one.clone());
    rt.set_data(weights, weight_data.clone());
    rt.execute(&cx.dyn_map);
    let first_s1 = rt.get_f32(out);

    // Move to bucket 2 (s=3).
    cx.set_dim('s', 3);
    let lhs_three = random_f32_vec(3 * k, 302, -1.0, 1.0);
    rt.set_data(lhs, lhs_three.clone());
    rt.set_data(weights, weight_data.clone());
    rt.execute(&cx.dyn_map);
    let result_s3 = rt.get_f32(out);

    // Back to bucket 1 (s=1); the weights should still work.
    cx.set_dim('s', 1);
    rt.set_data(lhs, lhs_one.clone());
    rt.set_data(weights, weight_data.clone());
    rt.execute(&cx.dyn_map);
    let second_s1 = rt.get_f32(out);

    // Both s=1 runs should agree.
    assert_close(&first_s1[..n], &second_s1[..n], 1e-6, 1e-6);

    // The s=3 run is checked against a CPU reference.
    let mut reference_s3 = vec![0.0f32; 3 * n];
    for row in 0..3 {
        for col in 0..n {
            for inner in 0..k {
                reference_s3[row * n + col] +=
                    lhs_three[row * k + inner] * weight_data[inner * n + col];
            }
        }
    }
    assert_close(&result_s3[..3 * n], &reference_s3, 1e-4, 1e-4);
}
/// Several executions inside one bucket, each with a different dim value.
#[test]
fn test_bucket_multiple_executions_same_bucket() {
    let Some(stream) = get_cuda_stream() else {
        return;
    };
    let (mut cx, input, output) = build_dynamic_add_graph();
    cx.set_dim_buckets('s', &[DimBucket::new(1, 8)]);
    cx.build_search_space::<CudaRuntime>();
    let mut rt = CudaRuntime::initialize(stream);

    cx.set_dim('s', 1);
    rt.set_data(input, vec![1.0f32; 4]);
    let mut rng = SmallRng::seed_from_u64(42);
    rt = cx.search_options(rt, SearchOptions::new(3), &mut rng);

    // Every size below falls in the single [1, 8] bucket.
    for s in [1, 2, 4, 8] {
        cx.set_dim('s', s);
        let len = s * 4;
        let values: Vec<f32> = (0..len).map(|i| i as f32).collect();
        rt.set_data(input, values.clone());
        rt.execute(&cx.dyn_map);
        let actual = rt.get_f32(output);
        let doubled: Vec<f32> = values.iter().map(|x| x * 2.0).collect();
        assert_close(&actual[..len], &doubled, 1e-5, 1e-5);
    }
}
/// Registering buckets whose ranges overlap must panic.
#[test]
#[should_panic(expected = "Overlapping buckets")]
fn test_bucket_overlapping_ranges_panics() {
    let mut graph = Graph::default();
    // [1, 4] and [3, 8] share the values 3 and 4.
    let overlapping = [DimBucket::new(1, 4), DimBucket::new(3, 8)];
    graph.set_dim_buckets('s', &overlapping);
}
/// `DimBucket::contains` is inclusive at both ends of the range.
#[test]
fn test_dim_bucket_contains() {
    let range = DimBucket::new(2, 10);
    for (value, inside) in [(1, false), (2, true), (5, true), (10, true), (11, false)] {
        assert_eq!(range.contains(value), inside);
    }

    // A single-value bucket contains only that value.
    let exact = DimBucket::new(3, 3);
    for (value, inside) in [(2, false), (3, true), (4, false)] {
        assert_eq!(exact.contains(value), inside);
    }
}

View File

@@ -1,416 +0,0 @@
use cudarc::driver::CudaContext;
use luminal::prelude::*;
use rand::SeedableRng;
use luminal::egglog_utils::{egglog_to_llir, random_initial_choice, validate_choice_set};
use crate::kernel::KernelOp;
use crate::runtime::CudaRuntime;
/// Helper: builds the search space for `cx`, then performs 20 random extractions
/// and returns every distinct kernel name seen, in first-seen order.
fn extract_all_kernel_names(cx: &mut Graph) -> Vec<String> {
    cx.build_search_space::<CudaRuntime>();
    let egraph = cx.egraph().expect("egraph not built");
    let ops = cx.egglog_ops().expect("ops not built");
    let custom_ops = &cx.custom_ops;

    let mut seen: Vec<String> = Vec::new();
    // Sample repeatedly so that alternative lowerings have a chance to appear.
    for _attempt in 0..20 {
        let choices = random_initial_choice(egraph, &mut rand::rng());
        let mut list_cache = Default::default();
        let mut expr_cache = Default::default();
        let llir = egglog_to_llir(
            egraph,
            choices,
            ops,
            custom_ops,
            &mut list_cache,
            &mut expr_cache,
            None,
        );
        for node in llir.node_weights() {
            let Some(kernel) = node.to_dialect::<dyn KernelOp>() else {
                continue;
            };
            let kernel_name = kernel.kernel_name().to_string();
            if !seen.contains(&kernel_name) {
                seen.push(kernel_name);
            }
        }
    }
    seen
}
/// If the scatter destination is used by no other op, ScatterNoCopy should be
/// one of the available kernels: the ConsumedBuffer cleanup rule should NOT fire
/// because dest only appears inside the ConsumedBuffer (not in any other ICons).
#[test]
fn test_scatter_nocopy_selected_when_dest_unshared() {
    let ctx = CudaContext::new(0).unwrap();
    ctx.bind_to_thread().unwrap();
    let mut cx = Graph::default();

    // A 10-element destination, 3 source values and 3 Int indices.
    let dest = cx.tensor(10).persist();
    let src = cx.tensor(3).persist();
    let indexes = cx.tensor(3).as_dtype(DType::Int).persist();
    // The scatter is the only consumer of dest.
    let _result = src.scatter(indexes, dest).output();

    let names = extract_all_kernel_names(&mut cx);
    println!("All possible kernels: {:?}", names);

    let nocopy_available = names.iter().any(|name| name.as_str() == "ScatterNoCopy");
    assert!(
        nocopy_available,
        "Expected ScatterNoCopy to be available but got: {:?}",
        names
    );
}
/// If the scatter destination is also read by another op, the ConsumedBuffer
/// cleanup rule should fire and delete the ConsumedBuffer. That makes
/// ScatterNoCopy invalid, so it must not show up in any extraction, while the
/// regular Scatter kernel must.
#[test]
fn test_scatter_nocopy_not_selected_when_dest_shared() {
    let ctx = CudaContext::new(0).unwrap();
    ctx.bind_to_thread().unwrap();
    let mut cx = Graph::default();

    // A 10-element destination, 3 source values and 3 Int indices.
    let dest = cx.tensor(10).persist();
    let src = cx.tensor(3).persist();
    let indexes = cx.tensor(3).as_dtype(DType::Int).persist();
    let scatter_result = src.scatter(indexes, dest);
    // A second consumer of dest (dest + dest) is what makes it shared.
    let _dest_also_used = (dest + dest).output();
    let _result = scatter_result.output();

    let names = extract_all_kernel_names(&mut cx);
    println!("All possible kernels: {:?}", names);

    let has_kernel = |wanted: &str| names.iter().any(|name| name.as_str() == wanted);
    // The no-copy variant must be ruled out because dest is shared with the add.
    assert!(
        !has_kernel("ScatterNoCopy"),
        "ScatterNoCopy should NOT be available when dest is shared, got: {:?}",
        names
    );
    // The copying variant must still be there.
    assert!(
        has_kernel("Scatter"),
        "Expected Scatter but got: {:?}",
        names
    );
}
/// Runs the scatter on the GPU for many random extractions and checks the
/// result each time, covering both the Scatter and ScatterNoCopy kernels.
/// Fails if ScatterNoCopy is never exercised.
#[test]
fn test_scatter_execution_correctness() {
    let ctx = CudaContext::new(0).unwrap();
    ctx.bind_to_thread().unwrap();
    let stream = ctx.default_stream();
    let mut cx = Graph::default();

    // dest holds [0.0, 1.0, 2.0, 3.0, 4.0], src holds [10.0, 20.0, 30.0] and
    // indexes holds [1, 3, 4]; the data itself is uploaded per attempt below.
    let dest = cx.tensor(5).persist();
    let src = cx.tensor(3).persist();
    let indexes = cx.tensor(3).as_dtype(DType::Int).persist();
    let result = src.scatter(indexes, dest).output();

    cx.build_search_space::<CudaRuntime>();
    let egraph = cx.egraph().expect("egraph not built");
    let ops = cx.egglog_ops().expect("ops not built");

    // Positions 1, 3 and 4 are overwritten; 0 and 2 keep their values.
    let expected = vec![0.0f32, 10.0, 2.0, 20.0, 30.0];

    let mut rng = rand::rng();
    let mut tested_scatter = false;
    let mut tested_nocopy = false;
    for _attempt in 0..50 {
        let choices = random_initial_choice(egraph, &mut rng);
        // Skip extractions that are not valid choice sets.
        if validate_choice_set(egraph, &choices, ops).is_err() {
            continue;
        }
        let mut list_cache = Default::default();
        let mut expr_cache = Default::default();
        let llir = egglog_to_llir(
            egraph,
            choices,
            ops,
            &cx.custom_ops,
            &mut list_cache,
            &mut expr_cache,
            None,
        );

        // Record which scatter kernel(s) this extraction contains.
        let mut has_nocopy = false;
        let mut has_scatter = false;
        for node in llir.node_weights() {
            let Some(kernel) = node.to_dialect::<dyn KernelOp>() else {
                continue;
            };
            let kernel_name = kernel.kernel_name();
            if kernel_name == "ScatterNoCopy" {
                has_nocopy = true;
            } else if kernel_name == "Scatter" {
                has_scatter = true;
            }
        }

        let mut rt = CudaRuntime::initialize(stream.clone());
        rt.load_llir(&llir);
        rt.set_data(dest, vec![0.0f32, 1.0, 2.0, 3.0, 4.0]);
        rt.set_data(src, vec![10.0f32, 20.0, 30.0]);
        rt.set_data(indexes, vec![1i32, 3, 4]);
        rt.execute(&cx.dyn_map);
        let actual = rt.get_f32(result);

        // ScatterNoCopy takes precedence when both kernels are present.
        let variant;
        if has_nocopy {
            tested_nocopy = true;
            variant = "ScatterNoCopy";
        } else if has_scatter {
            tested_scatter = true;
            variant = "Scatter";
        } else {
            variant = "Unknown";
        }
        assert_eq!(
            actual, expected,
            "Scatter result mismatch with variant {variant}: got {:?}, expected {:?}",
            actual, expected
        );
    }

    println!(
        "Tested Scatter: {}, Tested ScatterNoCopy: {}",
        tested_scatter, tested_nocopy
    );
    assert!(
        tested_nocopy,
        "ScatterNoCopy was never selected in 50 attempts — can't verify correctness"
    );
}
/// Test the KV-cache round-trip pattern: scatter → remove_buffer → set_buffer → scatter again.
/// This mimics how the llama model uses scatter for KV cache updates.
///
/// Three scatters are run back to back. After each one the cache output buffer
/// is taken out of the runtime and installed as the next cache input, so every
/// step must see the values written by the earlier steps.
#[test]
fn test_scatter_kv_cache_roundtrip() {
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
let stream = ctx.default_stream();
let mut cx = Graph::default();
// KV cache: [5] elements (simulating a small cache)
let cache_in = cx.named_tensor("cache", 5).persist();
// New value to scatter: [1] element
let src = cx.tensor(1).persist();
// Index: [1] element (position to write)
let indexes = cx.tensor(1).as_dtype(DType::Int).persist();
// scatter src into cache at index position
let cache_out = src.scatter(indexes, cache_in);
// Also read the scatter output (simulates attention reading from cache)
let read_out = (cache_out + 0.0).output();
// Return cache for round-trip
let cache_output = cache_out.output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
// Must set input data BEFORE search (profiler needs valid buffers)
rt.set_data(cache_in, vec![0.0f32; 5]);
rt.set_data(src, vec![10.0f32]);
rt.set_data(indexes, vec![0i32]);
rt = cx.search(rt, 5);
// Print which scatter variant was selected
// (matching on "catter" covers both "Scatter" and "ScatterNoCopy")
for node in rt.llir_graph().node_weights() {
if let Some(k) = node.to_dialect::<dyn KernelOp>()
&& k.kernel_name().contains("catter")
{
println!("Selected: {}", k.kernel_name());
}
}
// Step 1: Initialize cache to zeros, scatter 10.0 at position 0
rt.set_data(cache_in, vec![0.0f32; 5]);
rt.set_data(src, vec![10.0f32]);
rt.set_data(indexes, vec![0i32]);
rt.execute(&cx.dyn_map);
let read1 = rt.get_f32(read_out);
println!("After step 1 (scatter 10.0 at pos 0): {:?}", read1);
assert_eq!(
read1,
vec![10.0, 0.0, 0.0, 0.0, 0.0],
"Step 1 read_out mismatch"
);
// Round-trip: remove cache output buffer, set as new cache input
let cache_buf = rt.remove_buffer(cache_output);
rt.set_buffer(cache_in, cache_buf);
// Step 2: Scatter 20.0 at position 1
// (cache_in is not re-uploaded: it now holds the buffer from step 1)
rt.set_data(src, vec![20.0f32]);
rt.set_data(indexes, vec![1i32]);
rt.execute(&cx.dyn_map);
let read2 = rt.get_f32(read_out);
println!("After step 2 (scatter 20.0 at pos 1): {:?}", read2);
assert_eq!(
read2,
vec![10.0, 20.0, 0.0, 0.0, 0.0],
"Step 2 read_out mismatch"
);
// Round-trip again
let cache_buf = rt.remove_buffer(cache_output);
rt.set_buffer(cache_in, cache_buf);
// Step 3: Scatter 30.0 at position 2
rt.set_data(src, vec![30.0f32]);
rt.set_data(indexes, vec![2i32]);
rt.execute(&cx.dyn_map);
let read3 = rt.get_f32(read_out);
println!("After step 3 (scatter 30.0 at pos 2): {:?}", read3);
assert_eq!(
read3,
vec![10.0, 20.0, 30.0, 0.0, 0.0],
"Step 3 read_out mismatch"
);
}
/// Test scatter with TWO cache buffers and dual outputs (closer to llama K+V pattern).
///
/// Each step scatters one value into a K cache and one into a V cache, multiplies the two
/// scattered caches elementwise, and checks the product. Between steps both cache output
/// buffers are moved back onto their cache inputs.
///
/// NOTE(review): the test name and the previous comment mention graph_break, but no explicit
/// graph break is inserted anywhere in this body — confirm whether one was intended.
#[test]
fn test_scatter_dual_cache_with_graph_break() {
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
let stream = ctx.default_stream();
let mut cx = Graph::default();
// Two caches (like K and V)
let k_cache = cx.named_tensor("k_cache", 5).persist();
let v_cache = cx.named_tensor("v_cache", 5).persist();
// Input values
let k_new = cx.tensor(1).persist();
let v_new = cx.tensor(1).persist();
// A single index tensor is shared by both scatters.
let indexes = cx.tensor(1).as_dtype(DType::Int).persist();
// Scatter into both caches
let k_out = k_new.scatter(indexes, k_cache);
let v_out = v_new.scatter(indexes, v_cache);
// Read both (simulates attention using the scattered caches)
let k_read = k_out + 0.0;
let v_read = v_out + 0.0;
// Compute something from the scattered values (simulates attention output)
let attn = k_read * v_read;
// Output everything
let attn_out = attn.output();
let k_cache_out = k_out.output();
let v_cache_out = v_out.output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream.clone());
// Inputs are populated before the search runs.
rt.set_data(k_cache, vec![0.0f32; 5]);
rt.set_data(v_cache, vec![0.0f32; 5]);
rt.set_data(k_new, vec![2.0f32]);
rt.set_data(v_new, vec![3.0f32]);
rt.set_data(indexes, vec![0i32]);
// Use seeded search for deterministic scatter variant selection.
// Seed 0 reliably selects Scatter (not ScatterNoCopy) for both caches.
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
// Print selected variants
for node in rt.llir_graph().node_weights() {
if let Some(k) = node.to_dialect::<dyn KernelOp>()
&& k.kernel_name().contains("catter")
{
println!("Dual test selected: {}", k.kernel_name());
}
}
// Step 1: scatter k=2.0, v=3.0 at position 0
rt.set_data(k_cache, vec![0.0f32; 5]);
rt.set_data(v_cache, vec![0.0f32; 5]);
rt.set_data(k_new, vec![2.0f32]);
rt.set_data(v_new, vec![3.0f32]);
rt.set_data(indexes, vec![0i32]);
rt.execute(&cx.dyn_map);
let attn1 = rt.get_f32(attn_out);
println!("Attn step 1: {:?}", attn1);
// k=[2,0,0,0,0], v=[3,0,0,0,0], attn = k*v = [6,0,0,0,0]
assert_eq!(attn1, vec![6.0, 0.0, 0.0, 0.0, 0.0], "Step 1 attn mismatch");
// Round-trip
// Both output buffers are removed first, then re-attached as the cache inputs.
let k_buf = rt.remove_buffer(k_cache_out);
let v_buf = rt.remove_buffer(v_cache_out);
rt.set_buffer(k_cache, k_buf);
rt.set_buffer(v_cache, v_buf);
// Step 2: scatter k=4.0, v=5.0 at position 1
rt.set_data(k_new, vec![4.0f32]);
rt.set_data(v_new, vec![5.0f32]);
rt.set_data(indexes, vec![1i32]);
rt.execute(&cx.dyn_map);
let attn2 = rt.get_f32(attn_out);
println!("Attn step 2: {:?}", attn2);
// k=[2,4,0,0,0], v=[3,5,0,0,0], attn = k*v = [6,20,0,0,0]
assert_eq!(
attn2,
vec![6.0, 20.0, 0.0, 0.0, 0.0],
"Step 2 attn mismatch"
);
// Round-trip
let k_buf = rt.remove_buffer(k_cache_out);
let v_buf = rt.remove_buffer(v_cache_out);
rt.set_buffer(k_cache, k_buf);
rt.set_buffer(v_cache, v_buf);
// Step 3: scatter k=6.0, v=7.0 at position 2
rt.set_data(k_new, vec![6.0f32]);
rt.set_data(v_new, vec![7.0f32]);
rt.set_data(indexes, vec![2i32]);
rt.execute(&cx.dyn_map);
let attn3 = rt.get_f32(attn_out);
println!("Attn step 3: {:?}", attn3);
// k=[2,4,6,0,0], v=[3,5,7,0,0], attn = k*v = [6,20,42,0,0]
assert_eq!(
attn3,
vec![6.0, 20.0, 42.0, 0.0, 0.0],
"Step 3 attn mismatch"
);
}

View File

@@ -15,8 +15,4 @@ half = "2.7.1"
tracing = "0.1.43"
[dev-dependencies]
candle-core = "0.9.2-alpha.1"
proptest = "1.9.0"
[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(feature, values("cargo-clippy"))'] }

View File

@@ -1,227 +0,0 @@
use super::{MetalMulInfo, MetalSumReduceInfo};
use luminal::prelude::*;
/// Kernel family a matmul plan is assigned to.
///
/// `MetalMatmulPlanner::plan` picks `RegularTiled` only for unbatched problems whose
/// `m`, `n` and `k` are all known constants of at least 32; everything else is `Naive`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MetalMatmulFamily {
/// Fallback family; also the `Default` value.
#[default]
Naive,
/// Family chosen for large, unbatched, constant-sized problems.
RegularTiled,
}
/// Shape and stride description of a matmul recovered from a multiply + sum-reduce pair.
///
/// `MatmulDescriptor::from_mul_and_sum` is the constructor visible in this file; the notes
/// on each field describe what that constructor stores.
#[derive(Debug, Clone)]
pub struct MatmulDescriptor {
// Taken from the first dimension of the reduction's output shape.
pub m: Expression,
// Taken from the second dimension of the reduction's output shape.
pub n: Expression,
// Taken from the reduction's iteration count.
pub k: Expression,
// Always empty when built by `from_mul_and_sum`; the planner only tiles when it is empty.
pub batch_shape: Vec<Expression>,
// Copy of the multiply's `a_strides` (three entries when built by `from_mul_and_sum`).
pub lhs_strides: Vec<Expression>,
// Copy of the multiply's `b_strides` (three entries when built by `from_mul_and_sum`).
pub rhs_strides: Vec<Expression>,
// Copy of the reduction's output strides (two entries when built by `from_mul_and_sum`).
pub out_strides: Vec<Expression>,
// Always `false` when built by `from_mul_and_sum`.
pub transpose_lhs: bool,
// Always `false` when built by `from_mul_and_sum`.
pub transpose_rhs: bool,
}
impl MatmulDescriptor {
    /// Tries to recognise a simple 2-D matmul written as an elementwise multiply
    /// followed by a sum reduction.
    ///
    /// Returns `None` unless the multiply is rank 3, the reduction output is rank 2,
    /// their dimensions line up, and the strides follow the expected broadcast pattern
    /// (zero stride on the broadcast axis, the bare `z` symbol on the contiguous axis).
    /// On success the descriptor is unbatched and non-transposed.
    pub fn from_mul_and_sum(
        mul_info: &MetalMulInfo,
        sum_info: &MetalSumReduceInfo,
    ) -> Option<Self> {
        let zero = Expression::from(0);
        let z = Expression::from('z');

        // Rank checks run first so that none of the indexing below can go out of bounds.
        let ranks_match = mul_info.shape.len() == 3
            && sum_info.shape.len() == 2
            && mul_info.a_strides.len() == 3
            && mul_info.b_strides.len() == 3
            && sum_info.strides.len() == 2;
        if !ranks_match {
            return None;
        }

        // The multiply must iterate over (m, n, k) where (m, n) is the reduction output
        // and k is the number of reduction iterations.
        let dims_match = mul_info.shape[0] == sum_info.shape[0]
            && mul_info.shape[1] == sum_info.shape[1]
            && mul_info.shape[2] == sum_info.iters;
        if !dims_match {
            return None;
        }

        // Operand A is broadcast along axis 1, operand B along axis 0, and the
        // reduction walks its input and writes its last output axis with stride `z`.
        let strides_match = mul_info.a_strides[1] == zero
            && mul_info.a_strides[2] == z
            && mul_info.b_strides[0] == zero
            && mul_info.b_strides[1] == z
            && sum_info.strides[1] == z
            && sum_info.iter_stride == z;
        if !strides_match {
            return None;
        }

        let descriptor = Self {
            m: sum_info.shape[0],
            n: sum_info.shape[1],
            k: sum_info.iters,
            batch_shape: Vec::new(),
            lhs_strides: mul_info.a_strides.clone(),
            rhs_strides: mul_info.b_strides.clone(),
            out_strides: sum_info.strides.clone(),
            transpose_lhs: false,
            transpose_rhs: false,
        };
        Some(descriptor)
    }
}
/// Concrete launch plan produced by `MetalMatmulPlanner::plan` for one matmul.
///
/// The notes on each field describe what `plan` currently writes into it.
#[derive(Debug, Clone)]
pub struct MatmulPlan {
// Kernel family selected by the planner.
pub family: MetalMatmulFamily,
// Problem sizes, copied from the descriptor.
pub m: Expression,
pub n: Expression,
pub k: Expression,
// Copied from `lhs_strides[0]` of the descriptor.
pub lda: Expression,
// Copied from `rhs_strides[2]` of the descriptor.
pub ldb: Expression,
// Copied from `out_strides[0]` of the descriptor.
pub ldd: Expression,
// The planner always writes 1 here and 0 for the three batch strides.
pub batch_size: u32,
pub batch_stride_a: u32,
pub batch_stride_b: u32,
pub batch_stride_d: u32,
// The planner always writes 16, 16, 8 here.
// NOTE(review): presumably block tile sizes along m / n / k — confirm against the kernel.
pub bm: u16,
pub bn: u16,
pub bk: u16,
// The planner always writes 2, 2 here.
// NOTE(review): presumably a per-threadgroup warp/simdgroup layout — confirm against the kernel.
pub wm: u16,
pub wn: u16,
}
#[derive(Debug, Default, Clone, Copy)]
pub struct MetalMatmulPlanner;
impl MetalMatmulPlanner {
pub fn plan(&self, desc: &MatmulDescriptor) -> MatmulPlan {
let family = if desc.batch_shape.is_empty()
&& desc.m.as_num().is_some_and(|m| m >= 32)
&& desc.n.as_num().is_some_and(|n| n >= 32)
&& desc.k.as_num().is_some_and(|k| k >= 32)
{
MetalMatmulFamily::RegularTiled
} else {
MetalMatmulFamily::Naive
};
MatmulPlan {
family,
m: desc.m,
n: desc.n,
k: desc.k,
lda: desc.lhs_strides[0],
ldb: desc.rhs_strides[2],
ldd: desc.out_strides[0],
batch_size: 1,
batch_stride_a: 0,
batch_stride_b: 0,
batch_stride_d: 0,
bm: 16,
bn: 16,
bk: 8,
wm: 2,
wn: 2,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The stride symbol used by every fixture in this module.
    fn z() -> Expression {
        Expression::from('z')
    }

    /// Builds an unbatched, non-transposed descriptor whose strides follow the
    /// broadcast pattern used in these tests; only the leading strides vary.
    fn descriptor(
        m: Expression,
        n: Expression,
        k: Expression,
        lhs_row: Expression,
        rhs_col: Expression,
        out_row: Expression,
    ) -> MatmulDescriptor {
        MatmulDescriptor {
            m,
            n,
            k,
            batch_shape: Vec::new(),
            lhs_strides: vec![lhs_row, Expression::from(0), z()],
            rhs_strides: vec![Expression::from(0), z(), rhs_col],
            out_strides: vec![out_row, z()],
            transpose_lhs: false,
            transpose_rhs: false,
        }
    }

    /// Checks the fixed tile configuration the planner emits for every family.
    fn assert_default_tiles(plan: &MatmulPlan) {
        assert_eq!(plan.bm, 16);
        assert_eq!(plan.bn, 16);
        assert_eq!(plan.bk, 8);
        assert_eq!(plan.wm, 2);
        assert_eq!(plan.wn, 2);
    }

    /// A 4x8x16 multiply + reduce pair is recognised and its sizes are recovered.
    #[test]
    fn descriptor_recovers_simple_2d_matmul() {
        let mul = MetalMulInfo {
            shape: vec![
                Expression::from(4),
                Expression::from(8),
                Expression::from(16),
            ],
            a_strides: vec![z() * 16, Expression::from(0), z()],
            b_strides: vec![Expression::from(0), z(), z() * 8],
            output_strides: vec![z() * 16, z() * 8, z()],
        };
        let sum = MetalSumReduceInfo {
            shape: vec![Expression::from(4), Expression::from(8)],
            strides: vec![z() * 8, z()],
            iters: Expression::from(16),
            iter_stride: z(),
        };

        let desc = MatmulDescriptor::from_mul_and_sum(&mul, &sum).unwrap();

        assert_eq!(desc.m, Expression::from(4));
        assert_eq!(desc.n, Expression::from(8));
        assert_eq!(desc.k, Expression::from(16));
    }

    /// Dimensions below 32 stay on the naive family and leading strides are forwarded.
    #[test]
    fn planner_keeps_small_problems_on_naive_path() {
        let desc = descriptor(
            Expression::from(4),
            Expression::from(8),
            Expression::from(16),
            z() * 16,
            z() * 8,
            z() * 8,
        );

        let plan = MetalMatmulPlanner.plan(&desc);

        assert_eq!(plan.family, MetalMatmulFamily::Naive);
        assert_default_tiles(&plan);
        assert_eq!(plan.lda, Expression::from('z') * 16);
        assert_eq!(plan.ldb, Expression::from('z') * 8);
        assert_eq!(plan.ldd, Expression::from('z') * 8);
    }

    /// A 64x64x64 problem is promoted to the tiled family.
    #[test]
    fn planner_promotes_large_problems_to_regular_tiled() {
        let desc = descriptor(
            Expression::from(64),
            Expression::from(64),
            Expression::from(64),
            z() * 64,
            z() * 64,
            z() * 64,
        );

        let plan = MetalMatmulPlanner.plan(&desc);

        assert_eq!(plan.family, MetalMatmulFamily::RegularTiled);
        assert_default_tiles(&plan);
    }
}

View File

@@ -1,42 +1,15 @@
mod matmul;
mod ops;
pub use matmul::*;
pub use ops::*;
use luminal::dtype::DType;
use luminal::op::EgglogOp;
use luminal::prelude::*;
use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, Device};
pub const DYN_BUFFER_INDEX: u64 = 30;
pub const DYN_SLOT_COUNT: usize = 26;
/// Shape and stride summary of an elementwise multiply kernel.
///
/// Returned by `MetalKernelOp::mul_info` and consumed by
/// `MatmulDescriptor::from_mul_and_sum` when looking for a multiply + reduce matmul.
#[derive(Debug, Clone)]
pub struct MetalMulInfo {
// Iteration shape of the multiply; the matmul matcher expects three entries.
pub shape: Vec<Expression>,
// Strides for the first operand, one per entry of `shape`.
pub a_strides: Vec<Expression>,
// Strides for the second operand, one per entry of `shape`.
pub b_strides: Vec<Expression>,
// NOTE(review): presumably the strides of the multiply's output — not read by the
// matmul matcher visible in this file; confirm against the producing op.
pub output_strides: Vec<Expression>,
}
/// Shape and stride summary of a sum-reduction kernel.
///
/// Returned by `MetalKernelOp::sum_reduce_info` and consumed by
/// `MatmulDescriptor::from_mul_and_sum` when looking for a multiply + reduce matmul.
#[derive(Debug, Clone)]
pub struct MetalSumReduceInfo {
// Output shape of the reduction; the matmul matcher expects two entries (m, n).
pub shape: Vec<Expression>,
// Strides paired with `shape`; the matcher requires the last one to be the bare `z` symbol.
pub strides: Vec<Expression>,
// Number of reduced elements per output; becomes `k` of the matmul.
pub iters: Expression,
// Stride used while iterating the reduced axis; the matcher requires the bare `z` symbol.
pub iter_stride: Expression,
}
pub trait MetalKernelOp: EgglogOp {
fn compile(
&self,
device: &Device,
input_dtypes: &[DType],
output_dtype: DType,
) -> ComputePipelineState;
fn infer_output_dtype(&self, input_dtypes: &[DType]) -> DType {
input_dtypes.first().copied().unwrap_or(DType::F32)
}
fn compile(&self, device: &Device) -> ComputePipelineState;
fn output_size(&self) -> Expression;
@@ -64,18 +37,6 @@ pub trait MetalKernelOp: EgglogOp {
fn flops(&self, _dyn_map: &FxHashMap<char, usize>) -> usize {
0
}
fn mul_info(&self) -> Option<MetalMulInfo> {
None
}
fn sum_reduce_info(&self) -> Option<MetalSumReduceInfo> {
None
}
fn is_matmul(&self) -> bool {
false
}
}
luminal::impl_into_ops!(MetalKernelOp);

File diff suppressed because it is too large Load Diff

View File

@@ -1,12 +1,10 @@
use crate::kernel::{
MatmulDescriptor, MetalKernelOp, MetalMatmul, MetalMatmulPlanner, DYN_SLOT_COUNT,
};
use half::f16;
#![allow(unexpected_cfgs)]
use crate::kernel::{MetalKernelOp, DYN_BUFFER_INDEX, DYN_SLOT_COUNT};
use itertools::Itertools;
use luminal::{
dtype::DType,
graph::LLIRGraph,
hlir::{Input, NativeData, Output},
hlir::{Input, Output},
op::{ExecutionStats, Runtime, RuntimeStats, TimingMethod},
prelude::{
petgraph::{algo::toposort, prelude::StableGraph, visit::EdgeRef, Direction},
@@ -20,8 +18,6 @@ use std::time::Duration;
pub struct MetalRuntime {
device: Device,
command_queue: CommandQueue,
/// Host-side input tensors provided by the user.
input_data: FxHashMap<NodeIndex, NativeData>,
/// Buffers for HLIR input tensors (set by user)
pub hlir_buffers: FxHashMap<NodeIndex, Buffer>,
/// Buffers for LLIR intermediate/output tensors
@@ -30,110 +26,18 @@ pub struct MetalRuntime {
dyn_buffer: Buffer,
/// The current LLIR graph
llir_graph: LLIRGraph,
/// Inferred runtime dtype for each LLIR node.
node_dtypes: FxHashMap<NodeIndex, DType>,
/// Compiled pipeline states for each kernel node
pipelines: FxHashMap<NodeIndex, ComputePipelineState>,
}
impl MetalRuntime {
/// Returns a copy of `llir_graph` in which every "multiply feeding a sum reduction"
/// pair that describes a simple 2-D matmul is replaced by a single `MetalMatmul` node.
///
/// The reduction node is overwritten in place with the fused op and rewired to read the
/// multiply's two inputs directly. The multiply node is removed only when nothing else
/// consumes it: removing a node also removes all of its edges, so dropping a multiply
/// that still feeds another kernel would silently disconnect that kernel's input.
fn fuse_matmuls(llir_graph: &LLIRGraph) -> LLIRGraph {
    let mut graph = llir_graph.clone();
    let planner = MetalMatmulPlanner;

    // Pass 1: collect candidates without mutating the graph.
    let mut rewrites = Vec::new();
    for sum_node in graph.node_indices().collect::<Vec<_>>() {
        let Some(sum_info) = graph[sum_node]
            .to_dialect::<dyn MetalKernelOp>()
            .and_then(|op| op.sum_reduce_info())
        else {
            continue;
        };
        // Inputs are ordered by edge id so operand order is deterministic.
        let input_edges: Vec<_> = graph
            .edges_directed(sum_node, Direction::Incoming)
            .sorted_by_key(|e| e.id())
            .map(|e| e.source())
            .collect();
        if input_edges.len() != 1 {
            continue;
        }
        let mul_node = input_edges[0];
        let Some(mul_info) = graph[mul_node]
            .to_dialect::<dyn MetalKernelOp>()
            .and_then(|op| op.mul_info())
        else {
            continue;
        };
        let Some(desc) = MatmulDescriptor::from_mul_and_sum(&mul_info, &sum_info) else {
            continue;
        };
        let mul_inputs: Vec<_> = graph
            .edges_directed(mul_node, Direction::Incoming)
            .sorted_by_key(|e| e.id())
            .map(|e| e.source())
            .collect();
        if mul_inputs.len() != 2 {
            continue;
        }
        rewrites.push((sum_node, mul_node, mul_inputs, planner.plan(&desc)));
    }

    // Pass 2: apply the rewrites.
    for (sum_node, mul_node, mul_inputs, plan) in rewrites {
        graph[sum_node] =
            luminal::op::LLIROp::new::<dyn MetalKernelOp>(Box::new(MetalMatmul {
                m: plan.m,
                n: plan.n,
                k: plan.k,
                lda: plan.lda,
                ldb: plan.ldb,
                ldd: plan.ldd,
                family: plan.family,
                bm: plan.bm,
                bn: plan.bn,
                bk: plan.bk,
                wm: plan.wm,
                wn: plan.wn,
                batch_size: plan.batch_size,
                batch_stride_a: plan.batch_stride_a,
                batch_stride_b: plan.batch_stride_b,
                batch_stride_d: plan.batch_stride_d,
            }));

        // Detach the multiply from this reduction only.
        if let Some(edge) = graph.find_edge(mul_node, sum_node) {
            graph.remove_edge(edge);
        }
        // Drop the multiply once it has no remaining consumers; keep it alive when
        // another node (including another fused reduction not yet processed) reads it.
        let mul_is_dead = graph.contains_node(mul_node)
            && graph
                .edges_directed(mul_node, Direction::Outgoing)
                .next()
                .is_none();
        if mul_is_dead {
            graph.remove_node(mul_node);
        }

        // The fused op reads the multiply's operands directly, in their original order.
        graph.add_edge(mul_inputs[0], sum_node, ());
        graph.add_edge(mul_inputs[1], sum_node, ());
    }
    graph
}
/// Test helper: reports whether any Metal kernel in the loaded graph identifies
/// itself as a matmul via `MetalKernelOp::is_matmul`.
#[cfg(test)]
pub(crate) fn contains_matmul(&self) -> bool {
    for node in self.llir_graph.node_indices() {
        // Nodes that are not Metal kernels are ignored.
        if let Some(op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
            if op.is_matmul() {
                return true;
            }
        }
    }
    false
}
/// Test helper: returns the `Debug` rendering of every Metal kernel op in the
/// loaded graph, in node-index iteration order. Non-kernel nodes are skipped.
#[cfg(test)]
pub(crate) fn debug_kernel_ops(&self) -> Vec<String> {
    let mut rendered = Vec::new();
    for node in self.llir_graph.node_indices() {
        if let Some(op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
            rendered.push(format!("{op:?}"));
        }
    }
    rendered
}
pub fn set_data(&mut self, id: impl ToId, data: impl Into<NativeData>) {
self.input_data.insert(id.to_id(), data.into());
pub fn set_data(&mut self, id: impl ToId, data: &[f32]) {
let buffer = self.device.new_buffer_with_data(
data.as_ptr() as *const _,
std::mem::size_of_val(data) as u64,
MTLResourceOptions::StorageModeShared,
);
self.hlir_buffers.insert(id.to_id(), buffer);
}
pub fn get_f32(&self, id: impl ToId) -> Vec<f32> {
@@ -168,42 +72,10 @@ impl MetalRuntime {
}
})
.expect("Cannot find tensor in runtime!");
let dtype = self
.node_dtypes
.get(&data_id)
.copied()
.or_else(|| {
self.llir_graph[data_id]
.to_op::<Input>()
.map(|inp| inp.dtype)
})
.unwrap_or(DType::F32);
let ptr = buffer.contents() as *const f32;
let len = buffer.length() as usize / std::mem::size_of::<f32>();
unsafe {
match dtype {
DType::F16 => {
let ptr = buffer.contents() as *const f16;
let len = buffer.length() as usize / std::mem::size_of::<f16>();
std::slice::from_raw_parts(ptr, len)
.iter()
.map(|v| v.to_f32())
.collect()
}
DType::Int => {
let ptr = buffer.contents() as *const i32;
let len = buffer.length() as usize / std::mem::size_of::<i32>();
std::slice::from_raw_parts(ptr, len)
.iter()
.map(|v| *v as f32)
.collect()
}
_ => {
let ptr = buffer.contents() as *const f32;
let len = buffer.length() as usize / std::mem::size_of::<f32>();
std::slice::from_raw_parts(ptr, len).to_vec()
}
}
}
unsafe { std::slice::from_raw_parts(ptr, len) }.to_vec()
}
}
@@ -224,12 +96,10 @@ impl Runtime for MetalRuntime {
Self {
device,
command_queue,
input_data: FxHashMap::default(),
hlir_buffers: FxHashMap::default(),
buffers: FxHashMap::default(),
dyn_buffer,
llir_graph: StableGraph::default(),
node_dtypes: FxHashMap::default(),
pipelines: FxHashMap::default(),
}
}
@@ -238,48 +108,16 @@ impl Runtime for MetalRuntime {
fn load_llir(&mut self, llir_graph: &LLIRGraph) {
self.pipelines.clear();
self.buffers.clear();
self.hlir_buffers.clear();
self.node_dtypes.clear();
self.llir_graph = Self::fuse_matmuls(llir_graph);
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
for node in topo_order {
if let Some(input) = self.llir_graph[node].to_op::<Input>() {
self.node_dtypes.insert(node, input.dtype);
let hlir_id = NodeIndex::new(input.node);
if let Some(data) = self.input_data.get(&hlir_id) {
let buffer = self.create_input_buffer(data, input.dtype);
self.hlir_buffers.insert(hlir_id, buffer);
}
continue;
}
if self.llir_graph[node].to_op::<Output>().is_some() {
continue;
}
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
let input_nodes: Vec<NodeIndex> = self
.llir_graph
.edges_directed(node, Direction::Incoming)
.sorted_by_key(|e| e.id())
.map(|e| e.source())
.collect();
let input_dtypes: Vec<DType> = input_nodes
.iter()
.map(|n| {
self.node_dtypes
.get(n)
.copied()
.unwrap_or_else(|| panic!("Missing inferred dtype for node {n:?}"))
})
.collect();
let output_dtype = kernel_op.infer_output_dtype(&input_dtypes);
let pipeline = kernel_op.compile(&self.device, &input_dtypes, output_dtype);
self.node_dtypes.insert(node, output_dtype);
// Compile all kernel ops
for node in llir_graph.node_indices() {
if let Some(kernel_op) = llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
let pipeline = kernel_op.compile(&self.device);
self.pipelines.insert(node, pipeline);
}
}
self.llir_graph = llir_graph.clone();
}
#[tracing::instrument(skip_all)]
@@ -323,6 +161,7 @@ impl Runtime for MetalRuntime {
self.update_dyn_buffer(dyn_map);
let command_buffer = self.command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_buffer(DYN_BUFFER_INDEX, Some(&self.dyn_buffer), 0);
for node in topo_order {
if self.llir_graph[node].to_op::<Input>().is_some()
@@ -361,11 +200,6 @@ impl Runtime for MetalRuntime {
.get(&node)
.expect("Output buffer not allocated!");
// Bind dyn dims right after the output slot:
// [inputs..., output, dyn, bytes...]
let dyn_idx = input_buffers.len() as u64 + 1;
encoder.set_buffer(dyn_idx, Some(&self.dyn_buffer), 0);
kernel_op.encode(encoder, pipeline, &input_buffers, output_buffer, dyn_map);
}
}
@@ -402,36 +236,6 @@ impl RuntimeStats for MetalRuntime {
}
impl MetalRuntime {
fn create_input_buffer(&self, data: &NativeData, dtype: DType) -> Buffer {
match dtype {
DType::F32 => {
let values: Vec<f32> = (0..data.len()).map(|i| data.f32(i)).collect();
self.device.new_buffer_with_data(
values.as_ptr() as *const _,
std::mem::size_of_val(values.as_slice()) as u64,
MTLResourceOptions::StorageModeShared,
)
}
DType::F16 => {
let values: Vec<f16> = (0..data.len()).map(|i| data.f16(i)).collect();
self.device.new_buffer_with_data(
values.as_ptr() as *const _,
std::mem::size_of_val(values.as_slice()) as u64,
MTLResourceOptions::StorageModeShared,
)
}
DType::Int => {
let values: Vec<i32> = (0..data.len()).map(|i| data.i32(i)).collect();
self.device.new_buffer_with_data(
values.as_ptr() as *const _,
std::mem::size_of_val(values.as_slice()) as u64,
MTLResourceOptions::StorageModeShared,
)
}
unsupported => panic!("Metal input dtype {unsupported:?} is not supported yet"),
}
}
pub fn allocate_intermediate_buffers(&mut self, dyn_map: &FxHashMap<char, usize>) {
for node in self.llir_graph.node_indices() {
if self.llir_graph[node].to_op::<Input>().is_some() {
@@ -440,9 +244,8 @@ impl MetalRuntime {
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
let size = kernel_op.output_size().exec(dyn_map).unwrap();
let dtype = self.node_dtypes.get(&node).copied().unwrap_or(DType::F32);
let buffer = self.device.new_buffer(
(size * dtype.bits().div_ceil(8)) as u64,
(size * std::mem::size_of::<f32>()) as u64,
MTLResourceOptions::StorageModeShared,
);
self.buffers.insert(node, buffer);
@@ -486,6 +289,7 @@ impl MetalRuntime {
self.update_dyn_buffer(dyn_map);
let command_buffer = self.command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_buffer(DYN_BUFFER_INDEX, Some(&self.dyn_buffer), 0);
for node in topo_order {
if self.llir_graph[node].to_op::<Input>().is_some()
@@ -524,9 +328,6 @@ impl MetalRuntime {
.get(&node)
.expect("Output buffer not allocated!");
let dyn_idx = input_buffers.len() as u64 + 1;
encoder.set_buffer(dyn_idx, Some(&self.dyn_buffer), 0);
kernel_op.encode(encoder, pipeline, &input_buffers, output_buffer, dyn_map);
}
}

View File

@@ -1,6 +1,4 @@
use crate::{kernel::lower_expression_for_metal, runtime::MetalRuntime};
use candle_core::{Device as CandleDevice, Tensor as CandleTensor};
use half::f16;
use luminal::prelude::*;
use proptest::prelude::*;
@@ -26,194 +24,6 @@ fn assert_close(actual: &[f32], expected: &[f32], tolerance: f32) {
}
}
const TRANSFORMER_SEQ: usize = 4;
const TRANSFORMER_HIDDEN: usize = 16;
const TRANSFORMER_INTERMEDIATE: usize = 32;
/// Graph-side RMS norm: normalises `x` over its last axis with `std_norm`, then
/// scales by `weight` broadcast across all leading dimensions of `x`.
fn rms_norm(x: GraphTensor, weight: GraphTensor, eps: f32) -> GraphTensor {
    let normalized = x.std_norm(x.shape.last_axis(), eps);
    // Every dimension except the last one is a leading (broadcast) dimension.
    let dims = x.dims();
    let leading = &dims[..dims.len() - 1];
    normalized * weight.expand_lhs(leading)
}
/// Graph-side single-head self-attention over `x` with projection weights stored
/// as (out, in) matrices, hence the transposes. Scores are scaled by
/// `1 / sqrt(TRANSFORMER_HIDDEN)` and soft-maxed along axis 1.
fn self_attention(
    x: GraphTensor,
    wq: GraphTensor,
    wk: GraphTensor,
    wv: GraphTensor,
    wo: GraphTensor,
) -> GraphTensor {
    // Projects the input through one transposed weight matrix.
    let project = |w: GraphTensor| x.matmul(w.t());

    let queries = project(wq);
    let keys = project(wk);
    let values = project(wv);

    let scale = 1.0 / (TRANSFORMER_HIDDEN as f32).sqrt();
    let scaled_scores = queries.matmul(keys.t()) * scale;
    let probabilities = scaled_scores.softmax(1);

    let context = probabilities.matmul(values);
    context.matmul(wo.t())
}
/// Graph-side SwiGLU MLP: `down(swish(gate(x)) * up(x))`, with all weights stored
/// as (out, in) matrices and therefore transposed before use.
fn swiglu_mlp(
    x: GraphTensor,
    w_gate: GraphTensor,
    w_up: GraphTensor,
    w_down: GraphTensor,
) -> GraphTensor {
    let gated = x.matmul(w_gate.t()).swish();
    let lifted = x.matmul(w_up.t());
    let hidden = gated * lifted;
    hidden.matmul(w_down.t())
}
struct MiniTransformerLayer {
attn_norm_w: GraphTensor,
wq: GraphTensor,
wk: GraphTensor,
wv: GraphTensor,
wo: GraphTensor,
mlp_norm_w: GraphTensor,
w_gate: GraphTensor,
w_up: GraphTensor,
w_down: GraphTensor,
}
impl MiniTransformerLayer {
fn init(cx: &mut Graph) -> Self {
Self {
attn_norm_w: cx.tensor(TRANSFORMER_HIDDEN),
wq: cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN)),
wk: cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN)),
wv: cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN)),
wo: cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN)),
mlp_norm_w: cx.tensor(TRANSFORMER_HIDDEN),
w_gate: cx.tensor((TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN)),
w_up: cx.tensor((TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN)),
w_down: cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_INTERMEDIATE)),
}
}
fn forward(&self, x: GraphTensor) -> GraphTensor {
let normed = rms_norm(x, self.attn_norm_w, 1e-5);
let attn_out = self_attention(normed, self.wq, self.wk, self.wv, self.wo);
let x = x + attn_out;
let normed = rms_norm(x, self.mlp_norm_w, 1e-5);
let mlp_out = swiglu_mlp(normed, self.w_gate, self.w_up, self.w_down);
x + mlp_out
}
fn weights(&self) -> Vec<(GraphTensor, usize)> {
vec![
(self.attn_norm_w, TRANSFORMER_HIDDEN),
(self.wq, TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN),
(self.wk, TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN),
(self.wv, TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN),
(self.wo, TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN),
(self.mlp_norm_w, TRANSFORMER_HIDDEN),
(self.w_gate, TRANSFORMER_INTERMEDIATE * TRANSFORMER_HIDDEN),
(self.w_up, TRANSFORMER_INTERMEDIATE * TRANSFORMER_HIDDEN),
(self.w_down, TRANSFORMER_HIDDEN * TRANSFORMER_INTERMEDIATE),
]
}
}
/// Candle reference for RMS norm: divides `x` by `sqrt(mean(x^2) + eps)` along the
/// last axis, then multiplies by `weight` reshaped to a `(1, last_dim)` row.
fn rms_norm_ref(x: &CandleTensor, weight: &CandleTensor, eps: f64) -> CandleTensor {
    let shape = x.dims();
    let last_axis = shape.len() - 1;
    let last_dim = shape[last_axis];

    let mean_of_squares = x.sqr().unwrap().mean_keepdim(last_axis).unwrap();
    let inv_rms = (mean_of_squares + eps)
        .unwrap()
        .sqrt()
        .unwrap()
        .recip()
        .unwrap();

    let normalized = x.broadcast_mul(&inv_rms).unwrap();
    let weight_row = weight.reshape((1, last_dim)).unwrap();
    normalized.broadcast_mul(&weight_row).unwrap()
}
/// Candle reference for `self_attention`: projects `x` with transposed weights,
/// applies a max-shifted softmax along axis 1 to the scaled scores, and projects
/// the weighted values through `wo`.
fn self_attention_ref(
    x: &CandleTensor,
    wq: &CandleTensor,
    wk: &CandleTensor,
    wv: &CandleTensor,
    wo: &CandleTensor,
) -> CandleTensor {
    // Multiplies the input by the transpose of one weight matrix.
    let project = |w: &CandleTensor| x.matmul(&w.t().unwrap()).unwrap();

    let queries = project(wq);
    let keys = project(wk);
    let values = project(wv);

    let scale = 1.0 / (TRANSFORMER_HIDDEN as f64).sqrt();
    let raw_scores = queries.matmul(&keys.t().unwrap()).unwrap();
    let scores = (raw_scores * scale).unwrap();

    // Softmax written out by hand, shifted by the row maximum.
    let row_max = scores.max(1).unwrap().unsqueeze(1).unwrap();
    let exponentials = scores.broadcast_sub(&row_max).unwrap().exp().unwrap();
    let row_sums = exponentials.sum(1).unwrap().unsqueeze(1).unwrap();
    let probabilities = exponentials.broadcast_div(&row_sums).unwrap();

    let context = probabilities.matmul(&values).unwrap();
    context.matmul(&wo.t().unwrap()).unwrap()
}
/// Candle reference for `swiglu_mlp`: `down(silu(gate(x)) * up(x))` with
/// transposed weights.
fn swiglu_mlp_ref(
    x: &CandleTensor,
    w_gate: &CandleTensor,
    w_up: &CandleTensor,
    w_down: &CandleTensor,
) -> CandleTensor {
    let gated = x.matmul(&w_gate.t().unwrap()).unwrap().silu().unwrap();
    let lifted = x.matmul(&w_up.t().unwrap()).unwrap();
    let hidden = (gated * lifted).unwrap();
    hidden.matmul(&w_down.t().unwrap()).unwrap()
}
/// Candle reference for `MiniTransformerLayer::forward`: pre-norm attention and
/// pre-norm SwiGLU MLP, each added back onto its input as a residual.
#[allow(clippy::too_many_arguments)]
fn transformer_layer_ref(
    x: &CandleTensor,
    attn_norm_w: &CandleTensor,
    wq: &CandleTensor,
    wk: &CandleTensor,
    wv: &CandleTensor,
    wo: &CandleTensor,
    mlp_norm_w: &CandleTensor,
    w_gate: &CandleTensor,
    w_up: &CandleTensor,
    w_down: &CandleTensor,
) -> CandleTensor {
    // Attention sub-block.
    let attn_input = rms_norm_ref(x, attn_norm_w, 1e-5);
    let attn_out = self_attention_ref(&attn_input, wq, wk, wv, wo);
    let after_attn = (x + attn_out).unwrap();

    // MLP sub-block.
    let mlp_input = rms_norm_ref(&after_attn, mlp_norm_w, 1e-5);
    let mlp_out = swiglu_mlp_ref(&mlp_input, w_gate, w_up, w_down);
    (after_attn + mlp_out).unwrap()
}
/// Produces `len` deterministic pseudo-random-looking values.
///
/// Element `i` is `((i * 37 + 11) % 97) / 97 * scale + bias`, so the unscaled
/// sequence lies in `[0, 1)` and repeats with period 97.
fn seeded_data(len: usize, scale: f32, bias: f32) -> Vec<f32> {
    let mut values = Vec::with_capacity(len);
    for i in 0..len {
        let residue = (i * 37 + 11) % 97;
        values.push((residue as f32 / 97.0) * scale + bias);
    }
    values
}
/// Converts each `f32` to half precision with `f16::from_f32`, preserving order.
fn to_f16_vec(values: &[f32]) -> Vec<f16> {
    let mut halves = Vec::with_capacity(values.len());
    for &value in values {
        halves.push(f16::from_f32(value));
    }
    halves
}
/// Generates deterministic data for every weight of `layer`.
///
/// The scale and bias passed to `seeded_data` drift with the weight's position in
/// `layer.weights()`. Weights whose element count equals `TRANSFORMER_HIDDEN`
/// (the two norm weights) are additionally shifted up by 1.0.
fn generate_layer_weights(layer: &MiniTransformerLayer) -> Vec<(GraphTensor, Vec<f32>)> {
    let mut generated = Vec::new();
    for (position, (tensor, size)) in layer.weights().into_iter().enumerate() {
        let scale = 0.8 - position as f32 * 0.03;
        let bias = -0.4 + position as f32 * 0.02;
        let mut data = seeded_data(size, scale, bias);
        if size == TRANSFORMER_HIDDEN {
            for value in data.iter_mut() {
                *value += 1.0;
            }
        }
        generated.push((tensor, data));
    }
    generated
}
/// dynamic symbols in kernel expressions should route through dyn buffer.
#[test]
fn dynamic_const_codegen_uses_dyn_buffer() {
@@ -531,425 +341,6 @@ fn metal_simple_max_reduce() {
assert_close(&out, &[4.0, 8.0], 0.001);
}
/// Casting an f32 input to F16 and reading it back yields the same values
/// within half-precision tolerance.
#[test]
fn metal_f16_cast_roundtrip() {
    let mut graph = Graph::default();
    let source = graph.tensor(4);
    let casted = source.cast(DType::F16).output();

    graph.build_search_space::<MetalRuntime>();
    let mut runtime = MetalRuntime::initialize(());
    runtime.set_data(source, &[1.0, -2.5, 3.25, 4.75]);
    runtime = graph.search(runtime, 3);
    runtime.allocate_intermediate_buffers(&graph.dyn_map);
    runtime.execute(&graph.dyn_map);

    let observed = runtime.get_f32(casted);
    assert_close(&observed, &[1.0, -2.5, 3.25, 4.75], 0.002);
}
/// Two f32 inputs are cast to F16, added in half precision, and cast back to F32.
#[test]
fn metal_f16_intermediate_add_roundtrip() {
    let mut graph = Graph::default();
    let lhs = graph.tensor(4);
    let rhs = graph.tensor(4);
    let half_sum = lhs.cast(DType::F16) + rhs.cast(DType::F16);
    let summed = half_sum.cast(DType::F32).output();

    graph.build_search_space::<MetalRuntime>();
    let mut runtime = MetalRuntime::initialize(());
    runtime.set_data(lhs, &[1.0, 2.0, -3.0, 4.5]);
    runtime.set_data(rhs, &[0.5, -1.0, 3.0, 0.25]);
    runtime = graph.search(runtime, 3);
    runtime.allocate_intermediate_buffers(&graph.dyn_map);
    runtime.execute(&graph.dyn_map);

    let observed = runtime.get_f32(summed);
    assert_close(&observed, &[1.5, 1.0, 0.0, 4.75], 0.003);
}
/// A (SEQ x HIDDEN) by (HIDDEN x HIDDEN) matmul must be fused into the dedicated
/// Metal matmul op and agree with Candle's CPU result.
#[test]
fn metal_specialized_matmul() {
    let lhs_shape = (TRANSFORMER_SEQ, TRANSFORMER_HIDDEN);
    let rhs_shape = (TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN);

    let mut graph = Graph::default();
    let lhs = graph.tensor(lhs_shape);
    let rhs = graph.tensor(rhs_shape);
    let product = lhs.matmul(rhs).output();

    graph.build_search_space::<MetalRuntime>();
    let mut runtime = MetalRuntime::initialize(());
    let lhs_values = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
    let rhs_values = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN, 0.8, -0.4);
    runtime.set_data(lhs, &lhs_values);
    runtime.set_data(rhs, &rhs_values);
    runtime = graph.search(runtime, 1);

    // The fusion has to have happened before anything is executed.
    assert!(
        runtime.contains_matmul(),
        "expected Metal runtime to fuse matmul, kernels: {:?}",
        runtime.debug_kernel_ops()
    );

    runtime.allocate_intermediate_buffers(&graph.dyn_map);
    runtime.execute(&graph.dyn_map);
    let observed = runtime.get_f32(product);

    // CPU reference.
    let cpu = CandleDevice::Cpu;
    let ref_lhs = CandleTensor::from_vec(lhs_values, lhs_shape, &cpu).unwrap();
    let ref_rhs = CandleTensor::from_vec(rhs_values, rhs_shape, &cpu).unwrap();
    let reference = ref_lhs.matmul(&ref_rhs).unwrap();
    let reference: Vec<f32> = reference.flatten_all().unwrap().to_vec1().unwrap();

    assert_close(&observed, &reference, 1e-3);
}
/// A 64x64x64 matmul must be planned onto the `RegularTiled` family and agree
/// with Candle's CPU result.
#[test]
fn metal_regular_tiled_matmul_path() {
    let mut graph = Graph::default();
    let rows = 64;
    let inner = 64;
    let cols = 64;
    let lhs = graph.tensor((rows, inner));
    let rhs = graph.tensor((inner, cols));
    let product = lhs.matmul(rhs).output();

    graph.build_search_space::<MetalRuntime>();
    let mut runtime = MetalRuntime::initialize(());
    let lhs_values = seeded_data(rows * inner, 0.4, -0.2);
    let rhs_values = seeded_data(inner * cols, 0.3, -0.15);
    runtime.set_data(lhs, &lhs_values);
    runtime.set_data(rhs, &rhs_values);
    runtime = graph.search(runtime, 1);

    // The family shows up in the Debug rendering of the fused op.
    let kernels = runtime.debug_kernel_ops();
    let uses_tiled = kernels
        .iter()
        .any(|kernel| kernel.contains("family: RegularTiled"));
    assert!(
        uses_tiled,
        "expected regular tiled matmul path, kernels: {:?}",
        kernels
    );

    runtime.allocate_intermediate_buffers(&graph.dyn_map);
    runtime.execute(&graph.dyn_map);
    let observed = runtime.get_f32(product);

    // CPU reference.
    let cpu = CandleDevice::Cpu;
    let ref_lhs = CandleTensor::from_vec(lhs_values, (rows, inner), &cpu).unwrap();
    let ref_rhs = CandleTensor::from_vec(rhs_values, (inner, cols), &cpu).unwrap();
    let reference = ref_lhs.matmul(&ref_rhs).unwrap();
    let reference: Vec<f32> = reference.flatten_all().unwrap().to_vec1().unwrap();

    assert_close(&observed, &reference, 2e-3);
}
/// RMS norm on the Metal runtime must agree with the Candle CPU reference.
#[test]
fn metal_rms_norm() {
    let activations_shape = (TRANSFORMER_SEQ, TRANSFORMER_HIDDEN);

    let mut graph = Graph::default();
    let activations = graph.tensor(activations_shape);
    let gain = graph.tensor(TRANSFORMER_HIDDEN);
    let normed = rms_norm(activations, gain, 1e-5).output();

    graph.build_search_space::<MetalRuntime>();
    let mut runtime = MetalRuntime::initialize(());
    let activation_values = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
    let gain_values: Vec<f32> = seeded_data(TRANSFORMER_HIDDEN, 0.5, 0.75);
    runtime.set_data(activations, &activation_values);
    runtime.set_data(gain, &gain_values);
    runtime = graph.search(runtime, 1);
    runtime.allocate_intermediate_buffers(&graph.dyn_map);
    runtime.execute(&graph.dyn_map);
    let observed = runtime.get_f32(normed);

    // CPU reference.
    let cpu = CandleDevice::Cpu;
    let ref_activations =
        CandleTensor::from_vec(activation_values, activations_shape, &cpu).unwrap();
    let ref_gain = CandleTensor::from_vec(gain_values, TRANSFORMER_HIDDEN, &cpu).unwrap();
    let reference = rms_norm_ref(&ref_activations, &ref_gain, 1e-5);
    let reference: Vec<f32> = reference.flatten_all().unwrap().to_vec1().unwrap();

    assert_close(&observed, &reference, 1e-3);
}
/// Single-head self-attention on the Metal runtime must agree with the Candle
/// CPU reference.
#[test]
fn metal_self_attention() {
    let activations_shape = (TRANSFORMER_SEQ, TRANSFORMER_HIDDEN);
    let weight_shape = (TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN);
    let weight_len = TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN;

    let mut graph = Graph::default();
    let activations = graph.tensor(activations_shape);
    let wq = graph.tensor(weight_shape);
    let wk = graph.tensor(weight_shape);
    let wv = graph.tensor(weight_shape);
    let wo = graph.tensor(weight_shape);
    let attended = self_attention(activations, wq, wk, wv, wo).output();

    graph.build_search_space::<MetalRuntime>();
    let mut runtime = MetalRuntime::initialize(());
    let activation_values = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
    let wq_values = seeded_data(weight_len, 0.8, -0.4);
    let wk_values = seeded_data(weight_len, 0.7, -0.35);
    let wv_values = seeded_data(weight_len, 0.6, -0.3);
    let wo_values = seeded_data(weight_len, 0.5, -0.25);
    runtime.set_data(activations, &activation_values);
    runtime.set_data(wq, &wq_values);
    runtime.set_data(wk, &wk_values);
    runtime.set_data(wv, &wv_values);
    runtime.set_data(wo, &wo_values);
    runtime = graph.search(runtime, 1);
    runtime.allocate_intermediate_buffers(&graph.dyn_map);
    runtime.execute(&graph.dyn_map);
    let observed = runtime.get_f32(attended);

    // CPU reference, built from the same host data.
    let cpu = CandleDevice::Cpu;
    let ref_activations =
        CandleTensor::from_vec(activation_values, activations_shape, &cpu).unwrap();
    let ref_wq = CandleTensor::from_vec(wq_values, weight_shape, &cpu).unwrap();
    let ref_wk = CandleTensor::from_vec(wk_values, weight_shape, &cpu).unwrap();
    let ref_wv = CandleTensor::from_vec(wv_values, weight_shape, &cpu).unwrap();
    let ref_wo = CandleTensor::from_vec(wo_values, weight_shape, &cpu).unwrap();
    let reference = self_attention_ref(&ref_activations, &ref_wq, &ref_wk, &ref_wv, &ref_wo);
    let reference: Vec<f32> = reference.flatten_all().unwrap().to_vec1().unwrap();

    assert_close(&observed, &reference, 1e-2);
}
/// Same computation as the F32 self-attention test, but every graph tensor is
/// declared `DType::F16`, the host data is uploaded through `to_f16_vec`, and the
/// output is cast to `DType::F32` before being read back. The reference is still
/// built from the original `f32` host vectors, and the tolerance is 2e-2.
#[test]
fn metal_self_attention_f16_weights() {
    let activation_shape = (TRANSFORMER_SEQ, TRANSFORMER_HIDDEN);
    let proj_shape = (TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN);
    let proj_len = TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN;

    // Graph: all five tensors are tagged F16; only the output is cast to F32.
    let mut graph = Graph::default();
    let x = graph.tensor(activation_shape).as_dtype(DType::F16);
    let q_proj = graph.tensor(proj_shape).as_dtype(DType::F16);
    let k_proj = graph.tensor(proj_shape).as_dtype(DType::F16);
    let v_proj = graph.tensor(proj_shape).as_dtype(DType::F16);
    let o_proj = graph.tensor(proj_shape).as_dtype(DType::F16);
    let attn = self_attention(x, q_proj, k_proj, v_proj, o_proj);
    let y = attn.cast(DType::F32).output();
    graph.build_search_space::<MetalRuntime>();

    // Host data is generated as f32 and converted on upload.
    let mut runtime = MetalRuntime::initialize(());
    let x_host = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
    let q_host = seeded_data(proj_len, 0.8, -0.4);
    let k_host = seeded_data(proj_len, 0.7, -0.35);
    let v_host = seeded_data(proj_len, 0.6, -0.3);
    let o_host = seeded_data(proj_len, 0.5, -0.25);
    runtime.set_data(x, to_f16_vec(&x_host));
    runtime.set_data(q_proj, to_f16_vec(&q_host));
    runtime.set_data(k_proj, to_f16_vec(&k_host));
    runtime.set_data(v_proj, to_f16_vec(&v_host));
    runtime.set_data(o_proj, to_f16_vec(&o_host));
    runtime = graph.search(runtime, 1);
    runtime.allocate_intermediate_buffers(&graph.dyn_map);
    runtime.execute(&graph.dyn_map);
    let actual = runtime.get_f32(y);

    // Reference tensors built from the unconverted f32 host vectors.
    let cpu = CandleDevice::Cpu;
    let as_proj = |host: Vec<f32>| CandleTensor::from_vec(host, proj_shape, &cpu).unwrap();
    let x_ref = CandleTensor::from_vec(x_host, activation_shape, &cpu).unwrap();
    let q_ref = as_proj(q_host);
    let k_ref = as_proj(k_host);
    let v_ref = as_proj(v_host);
    let o_ref = as_proj(o_host);
    let expected: Vec<f32> = self_attention_ref(&x_ref, &q_ref, &k_ref, &v_ref, &o_ref)
        .flatten_all()
        .unwrap()
        .to_vec1()
        .unwrap();

    assert_close(&actual, &expected, 2e-2);
}
/// Runs `swiglu_mlp` (gate, up and down weights) through the Metal runtime and
/// compares the result with `swiglu_mlp_ref` on the Candle CPU device.
#[test]
fn metal_swiglu_mlp() {
    let activation_shape = (TRANSFORMER_SEQ, TRANSFORMER_HIDDEN);
    // Gate and up weights share one shape; the down weight uses the transposed shape.
    let expand_shape = (TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN);
    let contract_shape = (TRANSFORMER_HIDDEN, TRANSFORMER_INTERMEDIATE);

    let mut graph = Graph::default();
    let x = graph.tensor(activation_shape);
    let gate_w = graph.tensor(expand_shape);
    let up_w = graph.tensor(expand_shape);
    let down_w = graph.tensor(contract_shape);
    let y = swiglu_mlp(x, gate_w, up_w, down_w).output();
    graph.build_search_space::<MetalRuntime>();

    // Upload host data, then search, allocate and execute on the runtime.
    let mut runtime = MetalRuntime::initialize(());
    let x_host = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
    let gate_host = seeded_data(TRANSFORMER_INTERMEDIATE * TRANSFORMER_HIDDEN, 0.8, -0.4);
    let up_host = seeded_data(TRANSFORMER_INTERMEDIATE * TRANSFORMER_HIDDEN, 0.7, -0.35);
    let down_host = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_INTERMEDIATE, 0.6, -0.3);
    runtime.set_data(x, &x_host);
    runtime.set_data(gate_w, &gate_host);
    runtime.set_data(up_w, &up_host);
    runtime.set_data(down_w, &down_host);
    runtime = graph.search(runtime, 1);
    runtime.allocate_intermediate_buffers(&graph.dyn_map);
    runtime.execute(&graph.dyn_map);
    let actual = runtime.get_f32(y);

    // Reference tensors built from the same host vectors.
    let cpu = CandleDevice::Cpu;
    let x_ref = CandleTensor::from_vec(x_host, activation_shape, &cpu).unwrap();
    let gate_ref = CandleTensor::from_vec(gate_host, expand_shape, &cpu).unwrap();
    let up_ref = CandleTensor::from_vec(up_host, expand_shape, &cpu).unwrap();
    let down_ref = CandleTensor::from_vec(down_host, contract_shape, &cpu).unwrap();
    let expected: Vec<f32> = swiglu_mlp_ref(&x_ref, &gate_ref, &up_ref, &down_ref)
        .flatten_all()
        .unwrap()
        .to_vec1()
        .unwrap();

    assert_close(&actual, &expected, 1e-3);
}
/// Runs `MiniTransformerLayer::forward` through the Metal runtime and compares
/// the result with `transformer_layer_ref` on the Candle CPU device.
#[test]
fn metal_mini_transformer_layer() {
    let activation_shape = (TRANSFORMER_SEQ, TRANSFORMER_HIDDEN);

    let mut graph = Graph::default();
    let x = graph.tensor(activation_shape);
    let layer = MiniTransformerLayer::init(&mut graph);
    let y = layer.forward(x).output();
    graph.build_search_space::<MetalRuntime>();

    // Upload the activation and every (tensor, data) pair from `generate_layer_weights`.
    let mut runtime = MetalRuntime::initialize(());
    let x_host = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
    let layer_weights = generate_layer_weights(&layer);
    runtime.set_data(x, &x_host);
    for (handle, values) in &layer_weights {
        runtime.set_data(*handle, values);
    }
    runtime = graph.search(runtime, 1);
    runtime.allocate_intermediate_buffers(&graph.dyn_map);
    runtime.execute(&graph.dyn_map);
    let actual = runtime.get_f32(y);

    let cpu = CandleDevice::Cpu;
    let x_ref = CandleTensor::from_vec(x_host, activation_shape, &cpu).unwrap();

    // Shape of each of the nine weights, listed in the positional order in which
    // they are handed to `transformer_layer_ref` (index i pairs with layer_weights[i]).
    let vec_shape: &[usize] = &[TRANSFORMER_HIDDEN];
    let proj_shape: &[usize] = &[TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN];
    let expand_shape: &[usize] = &[TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN];
    let contract_shape: &[usize] = &[TRANSFORMER_HIDDEN, TRANSFORMER_INTERMEDIATE];
    let shapes = [
        vec_shape,
        proj_shape,
        proj_shape,
        proj_shape,
        proj_shape,
        vec_shape,
        expand_shape,
        expand_shape,
        contract_shape,
    ];
    let w: Vec<_> = shapes
        .iter()
        .enumerate()
        .map(|(idx, shape)| {
            CandleTensor::from_vec(layer_weights[idx].1.clone(), *shape, &cpu).unwrap()
        })
        .collect();

    let expected: Vec<f32> = transformer_layer_ref(
        &x_ref, &w[0], &w[1], &w[2], &w[3], &w[4], &w[5], &w[6], &w[7], &w[8],
    )
    .flatten_all()
    .unwrap()
    .to_vec1()
    .unwrap();

    assert_close(&actual, &expected, 1e-2);
}
/// Hand-assembles the transformer layer so that the attention and MLP sub-blocks
/// run on `DType::F16` casts of their inputs and weights, with each sub-block's
/// output cast to `DType::F32` before the residual add. The result is compared
/// with `transformer_layer_ref` at a tolerance of 3e-2.
#[test]
fn metal_mini_transformer_layer_f16_intermediate() {
    let activation_shape = (TRANSFORMER_SEQ, TRANSFORMER_HIDDEN);

    let mut graph = Graph::default();
    let x = graph.tensor(activation_shape);
    let layer = MiniTransformerLayer::init(&mut graph);

    // Attention sub-block: norm -> F16 -> attention -> F32, then residual add.
    let attn_in = rms_norm(x, layer.attn_norm_w, 1e-5).cast(DType::F16);
    let attn_f16 = self_attention(
        attn_in,
        layer.wq.cast(DType::F16),
        layer.wk.cast(DType::F16),
        layer.wv.cast(DType::F16),
        layer.wo.cast(DType::F16),
    );
    let after_attn = x + attn_f16.cast(DType::F32);

    // MLP sub-block: norm -> F16 -> SwiGLU -> F32, then residual add.
    let mlp_in = rms_norm(after_attn, layer.mlp_norm_w, 1e-5).cast(DType::F16);
    let mlp_f16 = swiglu_mlp(
        mlp_in,
        layer.w_gate.cast(DType::F16),
        layer.w_up.cast(DType::F16),
        layer.w_down.cast(DType::F16),
    );
    let y = (after_attn + mlp_f16.cast(DType::F32)).output();
    graph.build_search_space::<MetalRuntime>();

    // Upload the activation and every (tensor, data) pair from `generate_layer_weights`.
    let mut runtime = MetalRuntime::initialize(());
    let x_host = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
    let layer_weights = generate_layer_weights(&layer);
    runtime.set_data(x, &x_host);
    for (handle, values) in &layer_weights {
        runtime.set_data(*handle, values);
    }
    runtime = graph.search(runtime, 1);
    runtime.allocate_intermediate_buffers(&graph.dyn_map);
    runtime.execute(&graph.dyn_map);
    let actual = runtime.get_f32(y);

    let cpu = CandleDevice::Cpu;
    let x_ref = CandleTensor::from_vec(x_host, activation_shape, &cpu).unwrap();

    // Shape of each of the nine weights, listed in the positional order in which
    // they are handed to `transformer_layer_ref` (index i pairs with layer_weights[i]).
    let vec_shape: &[usize] = &[TRANSFORMER_HIDDEN];
    let proj_shape: &[usize] = &[TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN];
    let expand_shape: &[usize] = &[TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN];
    let contract_shape: &[usize] = &[TRANSFORMER_HIDDEN, TRANSFORMER_INTERMEDIATE];
    let shapes = [
        vec_shape,
        proj_shape,
        proj_shape,
        proj_shape,
        proj_shape,
        vec_shape,
        expand_shape,
        expand_shape,
        contract_shape,
    ];
    let w: Vec<_> = shapes
        .iter()
        .enumerate()
        .map(|(idx, shape)| {
            CandleTensor::from_vec(layer_weights[idx].1.clone(), *shape, &cpu).unwrap()
        })
        .collect();

    let expected: Vec<f32> = transformer_layer_ref(
        &x_ref, &w[0], &w[1], &w[2], &w[3], &w[4], &w[5], &w[6], &w[7], &w[8],
    )
    .flatten_all()
    .unwrap()
    .to_vec1()
    .unwrap();

    assert_close(&actual, &expected, 3e-2);
}
#[test]
fn test_scatter_basic() {
let mut cx = Graph::default();
@@ -961,7 +352,7 @@ fn test_scatter_basic() {
cx.build_search_space::<MetalRuntime>();
let mut rt = MetalRuntime::initialize(());
rt.set_data(src, &[10.0, 20.0, 30.0]);
rt.set_data(indexes, &[1.0, 3.0, 4.0]);
rt.set_data(indexes, &[1i32, 3, 4]);
rt.set_data(dest, &[0.0, 0.0, 0.0, 0.0, 0.0]);
rt = cx.search(rt, 1);
rt.allocate_intermediate_buffers(&cx.dyn_map);
@@ -982,7 +373,7 @@ fn test_scatter_into_nonzero_dest() {
cx.build_search_space::<MetalRuntime>();
let mut rt = MetalRuntime::initialize(());
rt.set_data(src, &[99.0]);
rt.set_data(indexes, &[2f32]);
rt.set_data(indexes, &[2i32]);
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0, 5.0]);
rt = cx.search(rt, 1);
rt.allocate_intermediate_buffers(&cx.dyn_map);
@@ -1003,7 +394,7 @@ fn test_scatter_all_positions() {
cx.build_search_space::<MetalRuntime>();
let mut rt = MetalRuntime::initialize(());
rt.set_data(src, &[40.0, 30.0, 20.0, 10.0]);
rt.set_data(indexes, &[3.0, 2.0, 1.0, 0.0]);
rt.set_data(indexes, &[3i32, 2, 1, 0]);
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0]);
rt = cx.search(rt, 1);
rt.allocate_intermediate_buffers(&cx.dyn_map);

View File

@@ -149,7 +149,8 @@ pub fn paged_attention(
// ── Phase 5: Reshape output ──
// (n_kv_heads, kv_groups, s, head_dim) → (s, n_kv_heads, kv_groups, head_dim)
let mut out = out.permute((2, 0, 1, 3));
// Force materialization with * 1.0 to make contiguous, then reinterpret shape
let mut out = out.permute((2, 0, 1, 3)) * 1.0;
out.shape = ShapeTracker::new((s, n_heads * head_dim));
(out, k_cache, v_cache)

View File

@@ -1,6 +0,0 @@
*.onnx
tests/llama38b_ref_logits.pt
__pycache__/
*.pyc
uv.lock
.venv

View File

@@ -1,116 +0,0 @@
A couple of short things to keep in mind
## Lessons Learned
At the end of any session that involved a hard or non-obvious bug, append an entry to
`LessonsLearned.md` in this directory. A "hard bug" means any bug that required significant
investigation — intermittent failures, wrong output without a crash, egglog/optimizer issues,
or anything that took more than a few minutes to locate.
Each entry should cover:
1. **What the symptom was** (test failure, wrong output, panic, etc.)
2. **What the actual root cause was** (the specific code/logic that was wrong)
3. **Why it was hard to find** (what made it non-obvious or intermittent)
4. **The fix** (what changed and why it works)
5. **A general principle** extracted from the bug — something that helps avoid the same
class of mistake in future code
The goal is to build a living record of codebase-specific pitfalls that future sessions can
consult before writing new egglog rules, CUDA kernels, or optimizer passes.
1. If you want to run tests:
- `./run_test.sh` - runs tests with the native backend
- `./run_tests_cuda.sh` - runs tests with the CUDA backend
## Testing Best Practices
### Overview
The luminal_python crate provides a bridge between PyTorch models and the luminal library via the PT2 Export pipeline. Tests should verify this integration end-to-end by testing the actual user workflow: PyTorch model → torch.compile → luminal backend.
### Test Pattern (CORRECT)
All tests should follow this standard pattern:
```python
def test_operation():
"""Brief description of what operation is being tested."""
# 1. Instantiate PyTorch model
model: torch.nn.Module = OperationTestModel()
# 2. Compile with luminal backend
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
# 3. Create test input
x: torch.Tensor = torch.tensor([...]) # or torch.rand(...)
# 4. Run both original and compiled versions
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
# 5. Verify outputs match
assert torch.allclose(output, original)
```
### Test Models
- Define test model classes in `tests/test_models.py`
- Each model should be a simple `torch.nn.Module` that demonstrates one operation or pattern
- Use clear, descriptive class names (e.g., `AddTestModel`, `TransposeTestModel`)
- Include docstrings explaining what the model tests
Example:
```python
class AddTestModel(torch.nn.Module):
"""Tests element-wise addition."""
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + x
```
### What NOT to Do
**❌ DO NOT create pt2 files directly in tests:**
```python
# WRONG - bypasses the PyTorch integration
model_path = create_pt2_model(...)
graph_result = luminal.process_pt(model_path, backend='native')
```
**✓ DO create PyTorch models and use torch.compile:**
```python
# CORRECT - tests actual user workflow
model: torch.nn.Module = MyTestModel()
model_compiled = torch.compile(model, backend=luminal_backend)
```
### Rationale
- **End-to-end testing**: Tests verify the complete PyTorch → PT2 → luminal pipeline
- **User-facing API**: Tests use the same API that users will use (torch.compile)
- **Correctness**: Comparing compiled vs original PyTorch output ensures correctness
- **Maintainability**: Consistent pattern across all tests makes the codebase easier to understand
- **Simplicity**: No manual PT2 file creation, no tempfile cleanup, no numpy comparisons
### Special Cases
**Testing constants:**
Use inline tensor literals in the forward method - these are exported as constant tensors:
```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
constant = torch.tensor([1.0, 2.0, 3.0])
return x + constant
```
**Testing type casts:**
Use `.to(dtype)` method - these are exported as type cast operations:
```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.to(torch.float32)
```
**Testing complex operations:**
Chain operations naturally in PyTorch - the export pipeline handles the conversion:
```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
transposed = x.transpose(0, 1)
scaled = transposed * 2.0
return scaled + 1.0
```

View File

@@ -1,758 +0,0 @@
# Lessons Learned
This file documents hard bugs encountered in this codebase, their root causes, and principles
to prevent similar issues in the future.
---
## 2026-02-24 — Intermittent CUDA Backend Failures: Embed False Match + Batched Matmul Dimension Drop
### Background: Why the Failures Were Intermittent
Both bugs only appeared on roughly 50% of test runs. The source of non-determinism is
`FxHashMap` (a fixed-seed hash map). The egglog optimizer's `SerializedEGraph::new` builds
`Vec<NodeId>` orderings for each e-class by iterating a `FxHashMap`, producing non-deterministic
node orderings. `random_initial_choice()` in `src/egglog_utils/mod.rs` then randomly picks one
e-node per e-class as the starting representation for the profiling phase. The combination means
some runs pick a correct kernel and some pick a broken one from the same e-class.
**Lesson**: When a test fails intermittently at a roughly 50% rate, suspect the egglog extractor
choosing between two e-nodes in the same e-class — one correct, one broken. The fix is always in
the broken e-node's rewrite rule.
---
### Bug 1: `test_gather_elements` — KernelEmbed and RowEmbed False Match
**Files changed**:
- `crates/luminal_cuda/src/kernel/hlir.rs` (KernelEmbed, 4 rules)
- `crates/luminal_cuda/src/block/ops.rs` (RowEmbed, 4 rules)
#### What happened
`gather_elements` (axis-aware gather) decomposes into a flat gather by computing:
```
flat_idx = Add(
Mul(indices, stride[axis]),
Mul(Expand(Iota(dim_size)), stride[non_axis])
)
```
`KernelEmbed` and `RowEmbed` are optimized embedding lookup kernels. A genuine embedding
lookup produces:
```
flat_idx = Add(
Mul(Cast(token_ids), embed_dim),
Iota(embed_dim) ← bare Iota, the position within an embedding row
)
```
The egglog rewrite rules for both ops matched `Add(?mul_result, ?iota_result)` where
`?iota_result` was **unconstrained** — it could bind to anything, including
`Mul(Expand(Iota(n)), stride)` from `gather_elements`. This created a `KernelEmbed`/`RowEmbed`
node in the same e-class as the `Gather` node. When the extractor picked it, `build_payload`
called `flatten_mul_strides(range, token_stride)` which asserted `range.len() == token_stride.len()`:
- `range` came from `RemoveNthFromEnd(idx_shape, 0)` → length 1
- `token_stride` came from the indices strides → length 2
- Assertion failed → panic.
#### The fix
Add `(= ?iota_result (Iota ?iota_expr ?iota_range))` to all 8 rules, requiring the positional
component to be a bare `Iota` node:
```egglog
(= ?indices (Add ?add_shape ?mul_result ?mul_stride ?iota_result ?iota_stride ?add_out_stride))
(= ?iota_result (Iota ?iota_expr ?iota_range)) ← added
(= ?mul_result (Mul ...))
```
#### Investigation note
The initial plan correctly identified `KernelEmbed` as faulty, but missed `RowEmbed`. The two
ops are structurally identical but live in different parts of the codebase (`kernel/` vs
`block/`). The second bug was only discovered when the backtrace pointed to
`RowEmbed::build_payload` instead of `KernelEmbed::compile`. Always search for sibling
implementations when fixing a pattern-matching bug in one op.
---
### Bug 2: `test_matmul_batched` — CuBlasLt Drops Batch Dimension
**Files changed**:
- `crates/luminal_cuda/src/host/cublaslt/cublaslt_RmRm_rewrite.egg`
- `crates/luminal_cuda/src/host/cublaslt/cublaslt_RmCm_rewrite.egg`
- `crates/luminal_cuda/src/host/cublaslt/cublaslt_CmRm_rewrite.egg`
- `crates/luminal_cuda/src/host/cublaslt/cublaslt_CmCm_rewrite.egg`
#### What happened
The luminal frontend decomposes `(2,3,4) @ (2,4,5)` into:
```rust
let w = rhs.permute((0, 2, 1)); // (2,4,5) → (2,5,4)
let mul = self.expand_dim(2, d) // (2,3,4) → (2,3,5,4)
* w.expand_dim(1, b); // (2,5,4) → (2,3,5,4)
mul.sum(3) // → (2,3,5), correct out_shape
```
All four cublaslt rewrite rules extracted `m` and `n` from the output shape using
`nth_from_end`, which succeeds for any rank:
```egglog
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
```
For `out_shape = [2, 3, 5]`: `?m = 3`, `?n = 5`. The batch dim `2` is never extracted or
stored. The rules also validated stride patterns using `nth_from_end` on the stride arrays —
but for this batched case, **all stride checks coincidentally passed** because the last three
strides of the 4D expanded tensors happened to satisfy the 2D row/column-major patterns.
The resulting `CuBlasLt` node had `output_size() = m * n = 15`. The batch dimension was
silently discarded. The runtime allocated a 15-element output buffer, cuBLAS wrote a 3×5
result, and the test got back 15 values instead of 30.
#### The fix
Add `(= (len ?out_shape) 2)` to all 4 rules:
```egglog
(= (len ?out_shape) 2) ← added: cuBLAS is 2D only
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
```
`len` counts elements in the `ECons`-list shape. With this constraint, any `Sum` node with a
3D+ output shape (batched matmul) is not matched by cuBLAS rules and falls through to
`KernelSumReduce + KernelMul` (or the tiling block ops), which correctly use
`out_shape.iter().product()` for their output sizes.
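Note: at the time of this entry, `TileMatmulSplitK` and `TileMatmulFullSplit` were believed NOT to need this fix, because their `output_size()`
already returns `untiled_range.iter().product()` which includes all dimensions. The 2026-03-07 entry below supersedes this: both rules also turned out to need the `(= (len ?out_shape) 2)` constraint.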
---
### General Principle: Always Constrain Shape Rank in Egglog Rules
Both bugs share the same structural cause: **egglog rewrite rules that used `nth_from_end` to
extract dimensions from a shape list without constraining the list's length.** Since
`nth_from_end` silently succeeds for any list with enough trailing elements, rules written for
2D tensors accidentally matched higher-rank tensors.
**Rule for writing egglog rewrite rules in this codebase**:
> If a rule is designed for a specific tensor rank, always add an explicit
> `(= (len ?shape) N)` constraint. If a rule is designed to handle arbitrary ranks but an
> op's output only covers a subset of dimensions (like cuBLAS covering only the last 2),
> that is a correctness bug — either implement strided batched cuBLAS or add the rank
> constraint and fall back to a kernel that handles all dimensions.
---
### Debugging Intermittent CUDA Failures: Effective Approach
The investigation used extensive `eprintln!` debug logging to trace which kernels were compiled
vs. skipped. Key observations:
1. **In the passing case**: `KernelSumReduce::compile()` was called, kernels were allocated.
2. **In the failing case**: `KernelSumReduce::compile()` was never called, yet output was produced.
This asymmetry pointed to a `HostOp` path (cuBLAS) executing instead of the `KernelOp` path,
which narrowed the search to cublaslt rewrite rules. The HLIR-level `SumReduce::to_egglog` log
confirmed the correct HLIR node existed — the bug was in the e-graph optimization choosing
a different (broken) e-node from the same e-class.
**Effective debug strategy for egglog non-determinism bugs**:
1. Add logging at compile time for each kernel type (`KernelFoo::compile`, `HostFoo::execute`)
2. Compare passing vs. failing runs to see which kernels are/aren't invoked
3. The missing kernel's e-class contains a broken alternative — find it via the egglog rewrite rules
4. Check the op that *is* executing — its `output_size()` reveals what's wrong with the false match
---
## 2026-02-25 — OneHot Test Panic: Cast(Int→F32) Produces Int Output
### What the symptom was
`test_onehot` panicked at `src/hlir.rs:1625` in `get_f32()`: the output buffer was
`NativeData::Int` instead of the expected `NativeData::F32`.
### What the actual root cause was
The Cast parser's `* 1.0` workaround for `Int → F32` casts used `input * one_expanded`
(Int GraphTensor on the left, F32 constant on the right). However, `Mul for GraphTensor`
always uses `self.dtype` (the **left** operand's dtype) for the result, and the native
runtime's `Mul::execute` dispatches on the **first** input's `NativeData` variant. So
`Int * F32` produced `DType::Int` / `NativeData::Int` — the exact opposite of the intended
F32 output.
### Why it was hard to find
1. **The OneHot parser was a red herring**: The initial plan assumed the OneHot ONNX node
was being parsed, but `torch.onnx.export` decomposes `one_hot` into
`Unsqueeze → Equal → Cast(Bool→Int) → Cast(Int→F32)`. The OneHot parser was never called.
2. **The `* 1.0` workaround looked correct**: It was used successfully in many other parsers,
but those all had F32 inputs (where `F32 * F32 = F32`). The Int→F32 case was the only
path where the left operand was Int.
3. **Operand order matters silently**: Nothing warns about mixed-dtype Mul — it just takes
the left operand's dtype.
### The fix
In `ops_parse/unary.rs` `parse_cast_node`, split the combined condition into two cases:
- **No-op cast** (`cast_result.id == input.id`): `input * one_expanded` — preserves dtype
- **Int source** (`input.dtype == DType::Int`): `one_expanded * input` — F32 on the left
ensures F32 output
### General principle
**In luminal, binary op dtype is always the LEFT operand's dtype.** When constructing
`GraphTensor * constant_float(1.0)` for type materialization, always put the operand
whose dtype you want to preserve on the LEFT side. When converting Int→F32, the F32
constant must be the left operand.
---
## 2026-02-26 — ScatterND Fails on CUDA: "does not produce an egraph"
### What the symptom was
`test_scatter_nd` passed on native backend but failed on CUDA with "does not produce an
egraph". The CUDA compilation could not extract a valid program from the e-graph.
### What the actual root cause was
`scatter_nd` in `movement.rs` does `indices * 1` (line 353) to materialize the tensor for
reshaping. The `* 1` dispatches to `Mul<S: Into<Expression>>`, which creates a `constant(1)`
→ `Iota(1,1)` → `DType::Int`. But the ONNX parser creates all tensors as `DType::F32`
(via `named_tensor()` in `compiled_graph.rs:70`), so indices arrive as F32. This produces
`Mul(F32, Int)` — mixed dtypes.
The HLIR Mul dtype rule (`hlir.rs:886-888`) uses `(= ?dty (dtype ?lhs))` and
`(= ?dty (dtype ?rhs))` with the same `?dty` variable, requiring both inputs to have
matching dtypes. `F32 != Int` → the rule never fires → the Mul node gets **no dtype**.
Every downstream op checks `(= ?dty (dtype ?upstream))`. Without dtype on the Mul, no
CUDA kernel rewrite rules fire for any downstream op (KernelMul, KernelAdd, KernelLessThan,
etc.). When `cleanup_hlir` runs (enabled for CUDA, disabled for native), it deletes all
unrewritten HLIR ops, leaving empty e-classes → egraph extraction fails.
### Why it was hard to find
1. **Works on native**: `cleanup_hlir = false` for NativeRuntime, so unrewritten HLIR ops
are never deleted. NativeOp dispatches on actual runtime data, not egglog dtype.
2. **Cascading failure**: The root cause (missing dtype on one Mul) silently propagated
through every downstream op, making it look like a systemic CUDA issue rather than a
single dtype mismatch.
3. **`scatter_elements` works fine**: The sibling op already cast indices via
`(idx_f32 + (is_neg * adj)).cast(DType::Int)`, so only `scatter_nd` had this bug.
### The fix
Added `let indices = indices.cast(DType::Int);` at the top of `scatter_nd` in
`movement.rs`, before any arithmetic on indices. `GraphTensor::cast()` short-circuits
when `self.dtype == dtype`, so this is safe for callers already passing Int indices.
Also added the same cast in `parse_scatter_nd_node` for explicitness.
### General principle
**Always cast index tensors to `DType::Int` before arithmetic in graph-building code.**
ONNX tensors arrive as F32 from the Python bridge. Any `indices * stride` or
`indices * 1` will produce `Mul(F32, Int)` which breaks HLIR dtype propagation on CUDA.
The pattern `let indices = indices.cast(DType::Int);` at the top of any index-consuming
function is defensive and free (no-op when already Int).
---
## 2026-03-04 — Dynamic Shapes: Empty Buffer for BOOL Scalar Initializer
### What the symptom was
`test_hf_llama_decode_loop_dynamic` panicked at `bin_fn: a index 0 out of bounds (a.len=0), shape=[1, 1, 4, 4], strides=[0, 0, 0, 0]`. An Input node labeled `"new_ones"` had an empty buffer at runtime.
### What the actual root cause was
Two issues combined:
1. **`load_tensor_floats` didn't handle ONNX data_type=9 (BOOL)**. The `new_ones` initializer was a BOOL scalar (1 byte in `raw_data`). `load_tensor_floats` fell through to the fallback case, which tried `chunks_exact(4)` on 1 byte → produced 0 chunks → returned empty vec `[]`. The buffer was set with empty data.
2. **Scalar initializers with empty `dims` created 0-dimensional tensors**. ONNX represents scalars with `dims=[]`. The initializer loop computed `shape = init.dims.iter().map(|&d| d as usize).collect()` → empty vec `[]`, then called `named_tensor(name, [])` which created a tensor with 0 dimensions instead of the intended scalar `[1]`.
### Why it was hard to find
1. **Misdiagnosed as ConstantOfShape issue**: The original plan targeted `ConstantOfShape` with dynamic shapes. The shape `[1,1,4,4]` with strides `[0,0,0,0]` looked like a broadcast from a constant fill. But `parse_constant_of_shape` was never called — the `new_ones` tensor came from an ONNX initializer, not a computation node.
2. **The BOOL data type is unusual**: Most ONNX tensors are FLOAT, INT32, or INT64. BOOL initializers only appear in specific patterns (like `torch.ones()` in attention mask computation). `load_initializer_as_f32` already handled BOOL, but its sibling `load_tensor_floats` didn't.
3. **Empty vec is valid data**: `set_data(node_id, [])` doesn't panic — it silently sets an empty buffer. The error only manifests later when a downstream op tries to read index 0.
### The fix
1. Added `data_type=9` (BOOL) handling to `load_tensor_floats` in `util.rs` — same logic as `load_initializer_as_f32`: 1 byte per element, non-zero → 1.0, zero → 0.0.
2. In `compiled_graph.rs`, initializer tensor creation: if `shape.is_empty()`, set `shape = vec![1]` (scalar representation in luminal).
### General principle
**Keep data loading functions in sync.** `load_tensor_floats` and `load_initializer_as_f32` serve the same purpose (loading ONNX TensorProto data as f32) but had different data type coverage. When adding a new data type to one, check and update the other. Better yet, refactor them into a single function.
**ONNX scalars have `dims=[]`, luminal scalars have shape `[1]`.** Always convert empty dims to `[1]` when creating luminal tensors from ONNX data.
---
## 2026-03-04 — Where Node Missing Broadcast: KernelMul flatten_strides Panic on CUDA
### What the symptom was
`test_hf_llama3_1b_decode_loop_dynamic` panicked at `flatten_strides` with `left: 4, right: 1` during
CUDA `KernelMul::compile`. The `KernelMul` had `out_shape=[1, 1, a, a]` but `b_stride=[z]` (1D).
### What the actual root cause was
`parse_where_node` called `x.cond(condition, y)` without broadcasting the inputs to matching ranks.
The ONNX Where op for the attention mask had condition=[1,1,a,a] (4D), x=[1] (scalar), y=[1] (scalar).
Luminal's `cond` doesn't auto-broadcast — it passes the shape trackers directly to the HLIR node.
The resulting Mul had input A with 4D strides and input B with 1D strides.
### Why it was hard to find
1. **Only triggered by 1B model**: The tiny model's Where inputs all had matching ranks (no scalars).
2. **CUDA-only**: The native runtime's `bin_fn` uses `StridedIterator` which handles mismatched
strides more gracefully. CUDA's `KernelMul::compile` calls `flatten_strides` which asserts
`range.len() == strides.len()`.
3. **Delayed crash**: The mismatch was created during ONNX parsing but only manifested during
CUDA kernel compilation (graph search phase).
### The fix
Added numpy-style broadcasting to `parse_where_node`: compute the broadcast shape across all 3
inputs, then `broadcast_to_expr` each to the common shape before calling `cond`.
### General principle
**ONNX binary/ternary ops all use numpy broadcasting.** When parsing ONNX ops that take multiple
tensor inputs (Where, Add, Mul, etc.), always broadcast all inputs to a common shape BEFORE
calling the luminal graph operation. Luminal graph ops do NOT auto-broadcast — they expect inputs
with matching shape tracker dimensions.
---
## 2026-03-05 — TopK Values Wrong on CUDA (gather_elements with sliced non-contiguous indices)
1. **Symptom**: `test_topk_values` failed on CUDA — rows 0-1 were correct but rows 2+ returned
the value at column 0 of each row (all three top-k positions got the same value).
Native backend was fine.
2. **Root cause**: `gather_elements` was called with a non-contiguous index tensor produced by
`argsort(axis=1) → slice_along(..k, axis=1)`. The slice creates a ShapeTracker view of the
[4,8] argsort buffer with dims [4,3] and strides [8,1]. When this flowed through the
gather_elements Int arithmetic chain (cast, multiply, add) and into the final Gather CUDA
kernel, the non-contiguous strides caused incorrect index reads for later rows.
3. **Why it was hard to find**: `test_topk_indices` passed (it only tests argsort+slice, not
the downstream gather_elements). A standalone `test_gather_elements` with constant indices
also passed because constant indices are contiguous. The bug only manifested when runtime-
computed non-contiguous indices were used with data of a different size along the gather axis.
4. **Fix**: In `parse_topk_node`, compute `gather_elements(x, full_argsort, axis)` with the
full [4,8] argsort result (same size as data), then slice the gathered values to [4,3].
This ensures gather_elements always operates on same-sized contiguous tensors.
5. **General principle**: When building graph operations that chain shape-tracker views
(slice, transpose, etc.) into downstream HLIR ops on CUDA, prefer operating on full
contiguous tensors first and slicing the result afterward. Non-contiguous views flowing
through multiple CUDA kernels can trigger stride-related bugs in the egglog-compiled code.
---
## 2026-03-07 — Non-deterministic CUDA_ERROR_ILLEGAL_ADDRESS: Multiple Missing Rank Constraints
### What the symptom was
`test_hf_llama_tiny` on CUDA failed ~70% of runs with `CUDA_ERROR_ILLEGAL_ADDRESS`. Failures
were non-deterministic due to egglog's `FxHashMap` iteration order in `random_initial_choice()`.
### What the actual root cause was
**Multiple** matmul egglog rules lacked `(= (len ?out_shape) 2)` constraints:
1. `TileMatmulSplitK` in `block/ops.rs` (disabled via comment but rule still registered)
2. `TileMatmulFullSplit` in `block/ops.rs`
3. All 4 `sgemm_v2_*.egg` rules in `host/cublas/`
The `cublaslt_*.egg` rules already had the constraint. When egglog picked TileMatmul or sgemm
for a 3D+ batched matmul, the generated CUDA kernels accessed out-of-bounds memory.
Additionally, `KernelEmbed` in `kernel/hlir.rs` had an output indexing bug:
`out[out_offset * embed_dim + embed_idx]` should be `out[out_offset + embed_idx]` because
`out_offset` already includes the embed_dim factor from `flatten_strides`.
**Most critically**, the KernelEmbed and RowEmbed "with cast" egglog rules passed the
**pre-cast** float token_ids (`?token_ids`) to the embed kernel instead of the **post-cast**
int token_ids (`?token_ids_cast`). The CUDA kernel reads token_ids as `const int*`, so float
data gets reinterpreted as enormous garbage integers, causing out-of-bounds embed table access.
### Why it was hard to find
1. **Multiple independent bug sources**: The ~70% failure rate was caused by three separate bugs
(matmul rank, embed output indexing, embed pre-cast input). Each fix only reduced the rate
partially, making it seem like each fix was insufficient.
2. **CudaGraph wrapping**: The crash occurred inside `CudaGraphOp::execute_internal` which
batches multiple kernels via CUDA graphs. The error just said "CudaGraph" — it
didn't identify which kernel crashed. Adding per-kernel debug launches was essential.
3. **Cascading failures**: When the Megakernel (containing RowEmbed with the pre-cast bug)
corrupted the embed output, the NEXT CudaGraph group's kernels crashed reading the garbage.
This made the Megakernel appear to be the victim, not the source.
4. **The pre-cast bug only crashes SOMETIMES**: Egglog's random choice determines whether
KernelEmbed/RowEmbed is selected (crash) or the generic Gather path is used (works).
Float token_id 1.0 (= 0x3F800000 = 1065353216 as int) produces an astronomically large
embed table index, causing ILLEGAL_ADDRESS.
### The fix
- Added `(= (len ?out_shape) 2)` to TileMatmulSplitK, TileMatmulFullSplit, and all 4 sgemm_v2 rules
- Fixed KernelEmbed output indexing: `out[out_offset + embed_idx]`
- **Fixed KernelEmbed/RowEmbed "with cast" rules**: Changed input from `?token_ids` to
`?token_ids_cast` — using the post-Cast int tensor instead of the pre-Cast float tensor
### Results
Failure rate: ~70% → 0% (20/20 passing). All three bugs needed to be fixed together.
### General principle
**When an egglog rule matches a sub-expression chain (like Cast→Mul→Add), be precise about
which intermediate result becomes each input.** The "with cast" embed rules matched
`Cast(?token_ids, ...)` to verify the Cast existed, but then passed `?token_ids` (the Cast
INPUT) instead of `?token_ids_cast` (the Cast OUTPUT) to the embed kernel. The kernel expects
int data, so the pre-cast float data was reinterpreted as garbage ints.
**Always search for sibling implementations**: KernelEmbed (in `kernel/hlir.rs`) and RowEmbed
(in `block/ops.rs`) had the SAME bug in their "with cast" rules. Fixing one without the other
only reduces the failure rate — both must be fixed.
---
## 2026-03-09 — TileMatmulFullSplit Matches Element-wise Square+Sum from LayerNorm
### What the symptom was
`test_qwen_image_transformer_tiny` on CUDA produced NaN in specific output rows. The failure
was non-deterministic (~85% failure rate) due to egglog's random e-class extraction picking
TileMatmulFullSplit for some operations.
### What the actual root cause was
The `TileMatmulFullSplit` rewrite rule in `block/ops.rs` matched any `Mul + Sum` pattern with
a 2D output, contiguous K-strides, and F32 inputs. This correctly matched real matmuls, but
ALSO matched the element-wise `x * x → Sum(last_dim)` pattern from LayerNorm/RMSNorm
(Pow(x, 2) → ReduceMean).
For a [1, 4, 64] activation tensor `x`:
- `Mul(x, x)` shape: [1, 4, 64], strides: [256z, 64z, z] for both inputs
- `Sum(dim=2)` output: [1, 4], len=2 ✓
TileMatmulFullSplit interpreted this as a [1, 64] × [64, 4] → [1, 4] matmul with:
- A = row 0 of x (64 elements), B = same buffer at column offsets
The kernel computed `C[j] = sum_k x[k] * x[j*64+k]` (cross-products) instead of the correct
`C[j] = sum_k x[j*64+k]^2` (squared sums). This produced subtly wrong values for j > 0
(correct for j=0 since cross-product with self = squared sum). These wrong values propagated
through LayerNorm → downstream operations → softmax → NaN.
Key diagnostic: adding `printf` to the kernel showed `a_ptr == b_ptr` (same buffer for both
inputs), confirming the kernel was operating on `x * x` not a real matmul.
### Why it was hard to find
1. **Individual op tests passed**: Simple Gemm tests, attention tests, and all other bisection
tests passed because they didn't have the specific `x*x → Sum` pattern.
2. **Non-deterministic**: The bug only manifested when egglog selected TileMatmulFullSplit
over the kernel fallback for the square+sum operation.
3. **No NaN from TileMatmulFullSplit itself**: The kernel produced wrong-but-finite values.
NaN only appeared downstream through softmax (exp(large) → ∞ → ∞/∞ = NaN).
4. **Systematic elimination needed**: Had to disable all block ops, then enable one at a time,
to narrow down TileMatmulFullSplit as the culprit.
### The fix
Added matmul broadcast constraints to both `TileMatmulFullSplit` and `TileMatmulSplitK` rules:
```egglog
; Assert proper matmul broadcast pattern:
; A is broadcast over N (a_n_stride = 0), B is broadcast over M (b_m_stride = 0)
(= ?a_n_stride (MNum 0))
(= ?b_m_stride (MNum 0))
```
In a real matmul `[M, K] × [K, N]`, the Mul is created by expanding dims:
- A is broadcast over N → a_n_stride = 0
- B is broadcast over M → b_m_stride = 0
In element-wise `x * x`, both strides are identical (non-zero for all dims), so the
constraints correctly reject it. The cuBLAS `.egg` rules already had these constraints.
### General principle
**Matmul Mul+Sum patterns have specific broadcast structure: one input is broadcast over M
and the other over N.** When writing egglog rules that match `Mul + Sum` patterns for matmul
optimization, always verify the broadcast pattern (`a_n_stride = 0` and `b_m_stride = 0`).
This prevents matching element-wise operations like `x*x → sum` that happen to have a 2D
output and contiguous strides.
---
## 2026-03-09 — Conv3D Permute Axis Mismatch in ONNX Conv Parser
### Symptom
`test_qwen_image_vae_decoder_tiny` panicked with:
> Permute axes (5) doesn't match shape axes (6)
at `src/shape/tracker.rs:153`, during `parse_conv_node`.
### Root cause
The Conv parser's unfold → matmul algorithm used two consecutive permutes with incorrect
index calculations. After unfold produces a 2N-dimensional tensor
`[win_0..win_{N-1}, k_0..k_{N-1}]`, the first permute swapped kernel dims to the front.
But the second permute's index math still assumed the original (pre-first-permute) ordering,
confusing kernel dimensions with window dimensions. Additionally:
1. `output_spatial_dims` was captured from wrong indices (kernel dims instead of window
spatial dims)
2. The `split_dims` loop iterated `spatial` times instead of `spatial-1`, creating a
spurious size-1 dimension
3. The final permute array had `1+spatial` elements for a tensor with `2+spatial` dims
For Conv2D (spatial=2) this was never caught because the xfail'd VAE decoder test was the
only test exercising the Conv parser — the transformer tests don't use Conv ONNX nodes.
### Why it was hard to find
The Conv parser was written and the VAE test immediately xfail'd due to a *different* bug
(`merge_dims` being `todo!()`). Once `merge_dims` was implemented, the Conv parser's own
bugs surfaced for the first time.
### Fix
Rewrote the unfold → matmul section with a single correct permute:
1. **One permute** to `[N, win_spatial..., C_in, k_batch, k_chan, k_spatial...]`
— groups batch | output spatial | channel+kernel
2. **Capture** `output_spatial_dims` from correct indices `[1..1+spatial]`
3. **Merge** all channel+kernel dims from the end into one
4. **Merge** spatial dims into one → `[N, spatial_product, C_in*kernel_product]`
5. **Matmul** → `[N, spatial_product, C_out]`
6. **Split** spatial back with `spatial-1` splits (not `spatial`)
7. **Permute** C_out to position 1 with correct `2+spatial` element array
### General principle
**When chaining permutes on high-dimensional tensors, prefer a single combined permute.**
Multiple permutes with hand-computed index arrays are error-prone because each permute
redefines what indices mean. A single permute from the original layout to the target layout
is easier to verify and less likely to confuse source/destination ordering. Also, ensure
`split_dims` loop counts match: splitting N dims out of a product requires N-1 splits
(the outermost dim is the quotient, not split out separately).
---
## 2026-03-18 — CUDA Search Rejects All Candidates: Zero Dummy Data Causes NaN for Div/Pow/Mod/Erf
### What the symptom was
6 CUDA tests (`test_pow`, `test_pow_broadcast`, `test_div`, `test_mod`, `test_mod_broadcast`,
`test_erf`) consistently failed with `Failed to find a viable initial genome for group 0 after
100 attempts`. All 6 passed on native backend.
### What the actual root cause was
The CUDA two-phase initialization in `build_cuda_backend` set ALL input tensor buffers to
`0.0f32` as dummy data for profiling. When `torch.compile` decomposes a model, it passes
model weights as additional ONNX graph inputs (not initializers). Since there were no ONNX
initializers to overwrite the zeros, weight buffers stayed all-zero during search.
Operations with zero inputs produced NaN:
- `fmod(0, 0) = NaN` (Mod test)
- `weight * recip(0) = weight * inf` → with any zero weight → `0 * inf = NaN` (Div test)
- `abs(0).log() = log(0) = -inf` → downstream NaN (Pow test)
- `sign(0)` chain → operations on zero inputs (Erf test)
The `has_nan_outputs` check rejected every candidate genome, exhausting all 100 attempts.
### Why it was hard to find
1. **No panic, no crash — silent NaN rejection**: The error message said "Failed to find a
viable initial genome" which suggested an egglog rewrite issue, not a data issue.
2. **Works on native**: `NativeRuntime::has_nan_outputs()` returns `false` by default (no NaN
check), so zero inputs never caused problems on native.
3. **torch.compile vs direct export difference**: Directly exporting a model via
`torch.onnx.export(model, ...)` produces initializers. But `torch.compile`'s backend
receives a `GraphModule` where weights are graph inputs, not initializers. The ONNX file
from `torch.compile` has 0 initializers.
4. **CudaRuntime's own `allocate_dummy_input` already uses 1.0**: The runtime knew zeros
were problematic (comment: "Zero inputs often hide numerical issues"), but the
`compiled_graph.rs` code used `0.0f32` independently.
### The fix
Changed dummy data from `vec![0.0f32; n_elements]` to `vec![1.0f32; n_elements]` in
`build_cuda_backend`. Using 1.0 is numerically safe: `fmod(1,1)=0`, `recip(1)=1`,
`log(1)=0`, `exp(1)≈2.7` — no NaN or inf. Profiling timing is unaffected (same number
of FLOPs and memory accesses).
### General principle
**Use small non-zero values (1.0) for dummy profiling data, never zeros.** Zero is a
singularity for many floating-point operations (division, log, fmod with zero divisor).
The CUDA runtime's `allocate_dummy_input` already followed this principle — the ONNX
pipeline's `build_cuda_backend` was inconsistent. When creating dummy data for GPU
profiling, always match the runtime's safer default.
---
## 2026-03-18 — Dynamic Decode Loop Fails: HLIR Weight Buffers Consumed After First Execute
### What the symptom was
`test_hf_llama3_1b_decode_loop_dynamic` passed step 0 (seq_len=6) but panicked on step 1
(seq_len=7) with `no entry found for key` at `cublaslt/mod.rs:294` — the CuBlasLt op couldn't
find its weight input buffer.
### What the actual root cause was
**Two bugs:**
1. **Missing `)` in egglog rule** (`luminal_cuda_lite/src/kernel/hlir.rs:3042`): The fourth
KernelEmbed rule ("kernel embed with mul reversed") had 3 closing parens after `INil` instead
of 4. The missing `)` failed to close the `(= ?mul_result ...)` form. This caused an egglog
parse error during search, caught by `catch_unwind`. The rule was dead code — it never fired,
but the parse error consumed a search iteration.
2. **HLIR buffer consumption killed weight buffers** (`luminal_cuda_lite/src/runtime.rs:1010-1057`):
After each `execute()`, the runtime removed all HLIR buffers (weights, constants) except those
directly connected to Output nodes. This was intended to free one-shot input data, but it also
deleted all 168 weight buffers. On the next `graph.run()`, CuBlasLt couldn't find any of its
weight inputs — `hlir_buffers` had 1 entry (the just-set `input_ids`) instead of 169.
### Why it was hard to find
1. **Misdirection by the egglog syntax error**: The plan identified the missing `)` as THE cause.
Fixing it allowed the rule to parse correctly, but the real runtime failure was independent.
2. **Step 0 always succeeds**: The weight consumption happens AFTER a successful execution. So
the first `graph.run()` works perfectly — all 169 HLIR buffers exist. The panic only occurs
on the second call, after consumption has cleared 168 of them.
3. **The consumption code was deliberately designed**: Comments said "weight tensors must have
`.persist()` to survive." The ONNX pipeline didn't call `.persist()` on weights, but this
had never been a problem before because single-shot inference only calls `execute()` once.
4. **Search phase panics masked by `catch_unwind`**: The same "no entry found for key" error
occurred during profiling of search candidates, but was silently caught. This made it look
like only certain LLIR variants had the issue, not all of them.
5. **Debug output needed 4 iterations to find**: The first debug showed which NodeIndex was
missing, the second showed it was an Input node, the third showed the HLIR mapping, and
the fourth revealed `hlir_buffers_count` dropping from 169 to 1 between steps.
### The fix
1. Added missing `)` to the KernelEmbed egglog rule at `hlir.rs:3042`.
2. In `compiled_graph.rs`, added `.persist()` calls on all weight/constant tensors (anything
not in `input_names`) after `process_onnx_nodes` completes. `.persist()` creates an Output
node connected to the Input, which the consumption code recognizes as "do not consume."
User inputs (like `input_ids`) are intentionally NOT persisted — they are consumed after
each `execute()` and re-set via `set_input()` before the next call.
### General principle
**Mark weight/constant tensors as persistent in the graph-building pipeline.** The runtime's
`execute()` consumes all HLIR buffers not connected to Output nodes. This is correct behavior
for one-shot user inputs, but weights must survive across calls. Always call `.persist()` on
tensors that should outlive a single execution. In the ONNX pipeline, the distinction is clear:
`input_names` (user-provided data per step) vs everything else (weights/constants loaded once).
---
## 2026-03-20 — PT2 CUDA Search Rejects All Candidates: Integer Buffers Misinterpreted as Float NaN
### What the symptom was
`test_hf_llama_tiny` on CUDA via PT2 failed with:
`pyo3_runtime.PanicException: Failed to find a viable initial genome for group 0 after 100 attempts`
The search tried 100 different egglog rewrites and ALL were rejected by the `has_nan_outputs` check.
### What the actual root cause was
**Two issues, both required to fix:**
1. **Integer buffers misinterpreted as float in NaN check.** `has_nan_outputs` in
`luminal_cuda_lite/src/runtime.rs` checks ALL `self.buffers` by reinterpreting raw bytes
as `f32` and calling `is_nan()`. The PT2-translated graph has integer intermediate
buffers (from `arange`, `cast(Int)`, integer arithmetic for embedding index computation).
Certain valid `i32` bit patterns (e.g., large integers from `token_id * hidden_dim`)
have exponent=0xFF and non-zero mantissa when reinterpreted as f32 — matching the
IEEE 754 NaN pattern. This caused false NaN rejections for EVERY candidate genome.
2. **Real weights/constants loaded before search contain -inf.** The PT2 path loaded real
safetensors weights and model constants (including the causal attention mask with `-inf`
values) BEFORE the search. While the ONNX path also loads real initializer data before
search, the PT2 graph's different structure (more explicit integer operations) made the
integer NaN false-positive the blocking issue.
### Why it was hard to find
- The original plan diagnosed this as the same zero-dummy-data bug fixed on 2026-03-18.
Changing `0.0` to `1.0` was insufficient because the root cause was different.
- `has_nan_outputs` checking ALL intermediate buffers (not just outputs) masked the real
issue — the NaN was in integer index-computation buffers, not in the model's float outputs.
- The ONNX-translated graph didn't have this problem because it doesn't produce as many
integer intermediate buffers (ONNX embedding uses different ops).
- The NaN pattern was identical across all 100 search attempts, which was the key clue:
it was deterministic and independent of egglog rewrite choices, pointing to input data
or buffer interpretation rather than graph optimization issues.
### The fix
Four changes:
1. **`luminal_cuda_lite/src/kernel/mod.rs`** (`KernelOp` trait): Added `output_dtype()`
method with default `DType::F32`. Each kernel now reports its actual output dtype.
2. **`luminal_cuda_lite/src/kernel/hlir.rs`** and **`other_ops.rs`**: Overrode
`output_dtype()` in all kernels with a `dtype` field (returns `self.dtype`), plus
special cases: `KernelIota` → `DType::Int`, `KernelLessThan` → `DType::Bool`,
`KernelCast` → `self.out_dtype`.
3. **`luminal_cuda_lite/src/runtime.rs`** (`has_nan_outputs`): Replaced fragile
`format!("{:?}").contains("dtype: Int")` string matching with proper
`op.to_dialect::<dyn KernelOp>().output_dtype()` check. Only F32 buffers are
checked for NaN; integer and bool buffers are skipped.
4. **`rust/src/pt2_compiled_model.rs`** (`init_cuda_runtime`): Set ALL input nodes
(weights, constants, user inputs) to `vec![1.0f32; n_elements]` before search via
new `set_all_inputs_dummy_cuda` function, then reload real data after search.
This prevents any -inf values from the causal mask from polluting intermediate
float computations during profiling.
### General principle
**Never reinterpret integer buffer bytes as float for NaN checking.** When a graph has
mixed-dtype operations (float model computation + integer index computation), raw byte
buffers from integer kernels contain valid i32 values that look like NaN when cast to f32.
The search's `has_nan_outputs` must be dtype-aware — use the kernel's `output_dtype()`
method rather than string-matching on Debug output. Additionally, when diagnosing "all
candidates rejected" during search, check whether the rejection is from actual float NaN
or from dtype misinterpretation — the key diagnostic is whether the NaN pattern is
identical across all attempts (dtype issue) vs varying (actual numerical issue).
---
## 2026-03-25 — KernelExp/KernelSigmoid: Fused CUDA Kernels for Precision
1. **Symptom**: `test_hf_llama3_full` (16-layer Llama-3.2-1B) had ~1e-4 max diff vs PyTorch.
2. **Root cause**: `exp(x)` was computed as `exp2(x * 1.442695)` — the constant truncated by `{:.6}` format + extra multiply adds rounding. Sigmoid was 5 separate kernels. SumReduce had naive accumulation.
3. **Why hard**: Per-operation error was ~1e-7 but compounded over 16 layers × ~25 extra materializations. The egglog `Exp` rewrite depends on exact constant format matching.
4. **Fix**: Added `KernelExp` (uses `expf()`), `KernelSigmoid` (uses `1/(1+expf(-x))`), and Kahan summation in SumReduce. Each uses both `kernel_rewrite` and a direct egglog pattern match with range checks (e.g., `(> ?val 1.44) (< ?val 1.45)`) to bypass constant format dependency.
5. **Principle**: When decomposed CUDA kernel chains cause precision loss, add fused kernels via `kernel_rewrite`. For robustness, add BOTH the logical-op rewrite path AND a direct HLIR pattern match — the constant format in egglog can be fragile.

View File

@@ -1,369 +0,0 @@
"""Run pytest on Modal with a dynamically selected GPU.
Usage:
uv run modal run modal_pytest_runner.py --gpu A100 tests/test_llama3.py::test_hf_llama3_full -v
uv run modal run modal_pytest_runner.py --gpu T4 tests/
uv run modal run modal_pytest_runner.py --gpu A100 --profile tests/ -v
"""
import argparse
import json
import os
import shutil
import subprocess
import sys
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import modal
from modal.volume import FileEntryType
# Modal application that the remote test runner class below is registered on.
app = modal.App("luminal-tests")
# Default remote execution timeout in seconds (30 minutes); `--timeout` overrides it.
DEFAULT_TIMEOUT = 30 * 60
# Exported as the CUDARC_CUDA_VERSION environment variable, both when the image
# is built and for the pytest subprocess.
# NOTE(review): presumably selects CUDA 12.8 for cudarc — confirm.
CUDARC_CUDA_VERSION = "12080"
# Directory containing this file on the machine that invokes `modal run`.
LOCAL_PROJECT_DIR = Path(__file__).resolve().parent
# Location of this Python project inside the container (the repository root is
# copied to /root/luminal by `add_local_dir` below).
PROJECT_DIR = "/root/luminal/crates/luminal_python"
# uv project environment, shared between the image build and the test run.
VENV_PATH = "/root/.cache/luminal/uv-project-environments/luminal_python"
# Prepended to PYTHONPATH for the pytest subprocess.
SRC_PATH = f"{PROJECT_DIR}/src"
# Profiling artifacts: Modal volume name, its mount point in the container,
# the default local download root (relative to LOCAL_PROJECT_DIR), and the
# in-container scratch root that pytest runs from when profiling.
PROFILE_VOLUME_NAME = "luminal-pytest-profiling"
PROFILE_VOLUME_PATH = "/root/pytest-profile-artifacts"
PROFILE_LOCAL_DEFAULT_ROOT = "luminal_artifacts/pytest-profiling"
PROFILE_SCRATCH_ROOT = "/tmp/luminal-pytest-profiling"
# Hugging Face cache: Modal volume name and the path it is mounted at (also
# exported as HF_HOME for the pytest subprocess).
HF_CACHE_VOLUME_NAME = "luminal-hf-cache-v2"
HF_CACHE_PATH = "/root/.cache/huggingface"
# Name of the local environment variable forwarded to the container as a secret.
HF_TOKEN_ENV_KEY = "HF_TOKEN"
# Both volumes are created on first use if they do not exist yet.
PROFILE_VOLUME = modal.Volume.from_name(PROFILE_VOLUME_NAME, create_if_missing=True)
HF_CACHE_VOLUME = modal.Volume.from_name(
    HF_CACHE_VOLUME_NAME,
    create_if_missing=True,
    version=2,
)
# Container image: the CUDA base image, the project's uv environment (including
# the "dev" dependency group), and a copy of the repository.
image = (
    modal.Image.from_registry("ghcr.io/luminal-ai/luminal-docker:cuda")
    .env({"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION})
    .uv_sync(
        str(LOCAL_PROJECT_DIR),
        frozen=False,
        groups=["dev"],
        env={"UV_PROJECT_ENVIRONMENT": VENV_PATH},
    )
    .workdir(PROJECT_DIR)
    .add_local_dir(
        # Repository root: two directory levels above this file's directory.
        str(LOCAL_PROJECT_DIR.parent.parent),
        remote_path="/root/luminal",
        copy=True,
        # VCS metadata, local caches, virtualenvs, build output and docs are
        # excluded from the copy.
        ignore=[
            ".git",
            ".claude-project",
            ".cargo-local",
            "**/.venv",
            "**/.pytest_cache",
            "**/__pycache__",
            "**/luminal_artifacts",
            "**/target",
            "docs",
        ],
    )
)
def _utc_now() -> str:
return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
def _hf_token_secret() -> modal.Secret | None:
    """Wrap the local Hugging Face token in a Modal secret.

    Returns:
        A secret holding the token under ``HF_TOKEN_ENV_KEY``, or ``None`` when
        that environment variable is unset or empty.
    """
    token_value = os.environ.get(HF_TOKEN_ENV_KEY)
    if token_value:
        return modal.Secret.from_dict({HF_TOKEN_ENV_KEY: token_value})
    return None
def _has_pytest_flag(pytest_args: list[str], flag: str) -> bool:
return any(arg == flag for arg in pytest_args)
def _profiling_enabled(cli_profile: bool, pytest_args: list[str]) -> bool:
    """Decide whether this run should collect profiling artifacts.

    Profiling is on when the runner's own ``--profile`` option was given, or
    when the forwarded pytest arguments contain ``--profile`` or
    ``--profile-svg``.
    """
    if cli_profile:
        return cli_profile
    if _has_pytest_flag(pytest_args, "--profile"):
        return True
    return _has_pytest_flag(pytest_args, "--profile-svg")
def _run_id() -> str:
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
return f"{timestamp}-{uuid.uuid4().hex[:8]}"
def _prepare_scratch_dir(scratch_dir: Path) -> None:
    """Populate ``scratch_dir`` with symlinks to the project's entries.

    Every entry of ``PROJECT_DIR`` is linked into ``scratch_dir`` except
    environment, cache and artifact directories, which are left out. Entries
    already present in ``scratch_dir`` are left untouched.
    """
    scratch_dir.mkdir(parents=True, exist_ok=True)
    excluded_names = frozenset(
        (".venv", ".pytest_cache", "__pycache__", "luminal_artifacts", "prof")
    )
    for source in Path(PROJECT_DIR).iterdir():
        link = scratch_dir / source.name
        # is_symlink() also catches dangling links, which exists() reports as absent.
        if source.name in excluded_names or link.exists() or link.is_symlink():
            continue
        link.symlink_to(source, target_is_directory=source.is_dir())
def _default_profile_output_dir(run_id: str) -> Path:
    """Return the resolved default local download directory for ``run_id``."""
    artifacts_root = LOCAL_PROJECT_DIR / PROFILE_LOCAL_DEFAULT_ROOT
    return (artifacts_root / run_id).resolve()
def _prepare_local_profile_dir(output_dir: Path) -> None:
if output_dir.exists() and not output_dir.is_dir():
raise NotADirectoryError(f"{output_dir} is not a directory")
output_dir.mkdir(parents=True, exist_ok=True)
prof_dir = output_dir / "prof"
if prof_dir.exists():
shutil.rmtree(prof_dir)
manifest_path = output_dir / "manifest.json"
if manifest_path.exists():
manifest_path.unlink()
def _download_profile_artifacts(run_id: str, output_dir: Path) -> None:
    """Copy the artifacts stored under ``run_id`` in the profile volume to ``output_dir``.

    The remote listing is fetched before the local directory is prepared, so a
    failing listing leaves ``output_dir`` untouched. Entries that are neither
    files nor directories are ignored.
    """
    remote_entries = PROFILE_VOLUME.listdir(run_id, recursive=True)
    _prepare_local_profile_dir(output_dir)
    for remote in remote_entries:
        relative = Path(remote.path).relative_to(run_id)
        if relative == Path("."):
            # The run directory itself; output_dir already stands in for it.
            continue
        local_path = output_dir / relative
        if remote.type == FileEntryType.DIRECTORY:
            local_path.mkdir(parents=True, exist_ok=True)
        elif remote.type == FileEntryType.FILE:
            local_path.parent.mkdir(parents=True, exist_ok=True)
            with local_path.open("wb") as sink:
                for chunk in PROFILE_VOLUME.read_file(remote.path):
                    sink.write(chunk)
def _cleanup_remote_profile_artifacts(run_id: str) -> None:
    """Delete the ``run_id`` directory from the profile volume, if it exists."""
    try:
        PROFILE_VOLUME.remove_file(run_id, recursive=True)
    except FileNotFoundError:
        # Already gone: nothing left to clean up.
        pass
@app.cls(image=image, timeout=DEFAULT_TIMEOUT)
class TestRunner:
    """Modal class that runs the project's pytest suite inside the CUDA image."""

    @modal.method()
    def run(
        self,
        pytest_args: list[str],
        pytest_addopts: str = "",
        profile_enabled: bool = False,
    ) -> dict[str, Any]:
        """Run pytest via ``uv run`` and optionally persist profiling artifacts.

        Args:
            pytest_args: Arguments forwarded to pytest. Any ``--profile`` /
                ``--profile-svg`` entries are stripped and re-added according
                to ``profile_enabled`` and whether Graphviz ``dot`` is installed.
            pytest_addopts: Exported as ``PYTEST_ADDOPTS`` when non-empty.
            profile_enabled: When true, pytest runs from a scratch directory
                and the resulting ``prof`` directory plus a ``manifest.json``
                are written to the profile volume.

        Returns:
            A dict with ``exit_code``, ``run_id``, ``profile_enabled``,
            ``remote_profile_dir`` and ``local_default_dirname``; profiling
            runs additionally carry ``svg_generated`` and ``svg_requested``.

        Raises:
            RuntimeError: if profiling was requested but the profile volume is
                not mounted at ``PROFILE_VOLUME_PATH``.
        """
        started_at = _utc_now()

        # Profiling runs get their own id and a scratch working directory.
        run_id = None
        scratch_dir = None
        if profile_enabled:
            run_id = _run_id()
            scratch_dir = Path(PROFILE_SCRATCH_ROOT) / run_id
            _prepare_scratch_dir(scratch_dir)

        # Environment for the pytest subprocess.
        env = os.environ.copy()
        inherited_pythonpath = env.get("PYTHONPATH")
        if inherited_pythonpath:
            python_path = f"{SRC_PATH}:{inherited_pythonpath}"
        else:
            python_path = SRC_PATH
        env.update(
            {
                "PYTHONPATH": python_path,
                "LUMINAL_BACKEND": "cuda",
                "UV_PROJECT_ENVIRONMENT": VENV_PATH,
                "MATURIN_PEP517_ARGS": "--features cuda --profile release",
                "CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
                "HF_HOME": HF_CACHE_PATH,
            }
        )
        if pytest_addopts:
            env["PYTEST_ADDOPTS"] = pytest_addopts

        # Normalise the profiling flags: drop whatever the caller passed and
        # re-add them based on profile_enabled and the availability of `dot`.
        original_svg_requested = _has_pytest_flag(pytest_args, "--profile-svg")
        dot_available = shutil.which("dot") is not None
        profile_flags = ("--profile", "--profile-svg")
        sanitized_pytest_args = []
        for arg in pytest_args:
            if arg not in profile_flags:
                sanitized_pytest_args.append(arg)
        if profile_enabled:
            sanitized_pytest_args.append("--profile")
            if dot_available:
                sanitized_pytest_args.append("--profile-svg")
            elif original_svg_requested:
                print(
                    "Graphviz 'dot' is unavailable in the Modal container; "
                    "falling back to raw .prof artifacts only.",
                    file=sys.stderr,
                )
        svg_requested = profile_enabled and dot_available

        cmd = [
            "uv",
            "run",
            "--project",
            PROJECT_DIR,
            "--group",
            "dev",
            "--reinstall-package",
            "luminal_python",
            "python",
            "-m",
            "pytest",
        ] + sanitized_pytest_args
        working_dir = PROJECT_DIR if scratch_dir is None else str(scratch_dir)
        completed = subprocess.run(cmd, env=env, cwd=working_dir)
        exit_code = completed.returncode

        # Persist the Hugging Face cache volume whatever the test outcome.
        HF_CACHE_VOLUME.commit()
        finished_at = _utc_now()

        if not profile_enabled:
            return {
                "exit_code": exit_code,
                "run_id": None,
                "profile_enabled": False,
                "remote_profile_dir": None,
                "local_default_dirname": None,
            }

        volume_root = Path(PROFILE_VOLUME_PATH)
        if not volume_root.exists():
            raise RuntimeError(
                "Profiling requested but the profile volume is not mounted."
            )

        # Copy the profiler output from the scratch directory into the volume.
        remote_run_dir = volume_root / run_id
        remote_run_dir.mkdir(parents=True, exist_ok=True)
        local_prof_dir = scratch_dir / "prof"
        if local_prof_dir.is_dir():
            shutil.copytree(local_prof_dir, remote_run_dir / "prof")
        svg_generated = (remote_run_dir / "prof" / "combined.svg").is_file()

        manifest = {
            "exit_code": exit_code,
            "finished_at": finished_at,
            "profile_enabled": True,
            "pytest_args": sanitized_pytest_args,
            "run_id": run_id,
            "started_at": started_at,
            "svg_generated": svg_generated,
            "svg_requested": svg_requested,
        }
        manifest_text = json.dumps(manifest, indent=2, sort_keys=True) + "\n"
        (remote_run_dir / "manifest.json").write_text(manifest_text, encoding="utf-8")
        PROFILE_VOLUME.commit()

        return {
            "exit_code": exit_code,
            "run_id": run_id,
            "profile_enabled": True,
            "remote_profile_dir": f"{PROFILE_VOLUME_PATH}/{run_id}",
            "local_default_dirname": run_id,
            "svg_generated": svg_generated,
            "svg_requested": svg_requested,
        }
def _parse_cli_args(
cli_args: tuple[str, ...],
) -> tuple[str, int | None, bool, str | None, list[str]]:
parser = argparse.ArgumentParser(
prog="modal run modal_pytest_runner.py",
add_help=False,
allow_abbrev=False,
description="Run pytest on Modal with a dynamically selected GPU.",
)
parser.add_argument(
"--gpu",
required=True,
help="GPU type to request from Modal (for example: A100, T4, H100).",
)
parser.add_argument(
"--timeout",
type=int,
help="Optional Modal execution timeout in seconds. Defaults to 1800 seconds.",
)
parser.add_argument(
"--profile",
action="store_true",
help="Enable pytest-profiling and download the resulting artifacts locally.",
)
parser.add_argument(
"--profile-output-dir",
help="Directory to download profiling artifacts into when profiling is enabled.",
)
parsed, pytest_args = parser.parse_known_args(cli_args)
if pytest_args and pytest_args[0] == "--":
pytest_args = pytest_args[1:]
if not pytest_args:
pytest_args = ["tests/"]
return (
parsed.gpu,
parsed.timeout,
parsed.profile,
parsed.profile_output_dir,
pytest_args,
)
@app.local_entrypoint()
def main(*cli_args: str):
    """Local entrypoint: run pytest remotely, fetch profiles, exit with pytest's code.

    Download failures are reported on stderr and do not change the exit code.
    Remote artifacts are removed only after a successful download.
    """
    gpu, timeout, cli_profile, profile_output_dir, pytest_args = _parse_cli_args(
        cli_args
    )
    profile_enabled = _profiling_enabled(cli_profile, pytest_args)
    pytest_addopts = os.environ.get("PYTEST_ADDOPTS", "")

    # The HF cache volume is always mounted; the profile volume only when needed.
    volumes = {HF_CACHE_PATH: HF_CACHE_VOLUME}
    if profile_enabled:
        volumes[PROFILE_VOLUME_PATH] = PROFILE_VOLUME

    options: dict[str, Any] = {"gpu": gpu, "volumes": volumes}
    if timeout is not None:
        options["timeout"] = timeout
    secret = _hf_token_secret()
    if secret is not None:
        options["secrets"] = [secret]

    runner = TestRunner.with_options(**options)()
    result = runner.run.remote(
        pytest_args=pytest_args,
        pytest_addopts=pytest_addopts,
        profile_enabled=profile_enabled,
    )

    has_artifacts = result["profile_enabled"] and result["run_id"] is not None
    if not has_artifacts:
        sys.exit(result["exit_code"])

    if profile_output_dir:
        output_dir = Path(profile_output_dir).expanduser().resolve()
    else:
        output_dir = _default_profile_output_dir(result["local_default_dirname"])
    try:
        _download_profile_artifacts(result["run_id"], output_dir)
        print(f"Profile artifacts downloaded to {output_dir}")
        _cleanup_remote_profile_artifacts(result["run_id"])
    except FileNotFoundError as exc:
        print(f"Unable to download profile artifacts: {exc}", file=sys.stderr)
    except OSError as exc:
        print(f"Failed to write local profile artifacts: {exc}", file=sys.stderr)
    sys.exit(result["exit_code"])

View File

@@ -1,49 +0,0 @@
# Packaging metadata for the Luminal Python bindings (built with maturin).
[project]
name = "luminal_python"
version = "0.1.0"
# TODO: placeholder text left over from project scaffolding — replace with a real description.
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"numpy>=2.0.2",
"torch>=2.10.0",
"safetensors",
]
# Extra package index; `explicit = true` means it is only used for packages
# pinned to it under [tool.uv.sources].
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true
# torch comes from the index above on Linux and Windows only.
[tool.uv.sources]
torch = [
{ index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
[build-system]
requires = ["maturin>=1.0,<2.0"]
build-backend = "maturin"
# Python sources live in src/; the Rust crate in rust/ is built as the
# extension module `luminal.luminal`.
[tool.maturin]
python-source = "src"
manifest-path = "rust/Cargo.toml"
module-name = "luminal.luminal"
[tool.pytest.ini_options]
# Custom markers; `slow` tests can be deselected with `-m "not slow"`.
markers = [
"slow: tests that download large models or require pre-generated artifacts",
]
# Development-only dependencies (test tooling, profiling, model libraries, Modal).
[dependency-groups]
dev = [
"maturin>=1.0,<2.0",
"pytest>=9.0.2",
"pytest-profiling",
"snakeviz",
"maturin-import-hook>=0.3.0",
"pytest-randomly>=4.0.1",
"transformers>=4.40.0",
"diffusers>=0.35.0",
"modal>=1.3.5",
]

View File

@@ -1,36 +0,0 @@
#!/bin/bash
# Full test suite: build and test the native backend, then rebuild with the
# `cuda` feature and test the CUDA backend.
# All paths are relative, so run this from the luminal_python project directory.
# Abort on the first failing command.
set -e
echo "=========================================="
echo " Luminal Python: Full Test Suite"
echo "=========================================="
# Space-separated test file lists; they are expanded unquoted below on purpose
# so that each path becomes a separate pytest argument.
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py"
CUDA_TESTS="tests/test_hlir_ops.py tests/test_unary.py tests/test_llama3.py"
# ── Phase 1: Native Backend ─────────────────────────────────
echo ""
echo "=== Phase 1: Building native backend ==="
# Remove previous wheels and build output to force a clean rebuild.
rm -rf rust/target/wheels rust/target/debug rust/target/release
uv run maturin develop --manifest-path rust/Cargo.toml
echo ""
echo "--- 1a: Native backend tests ---"
uv run pytest $NATIVE_TESTS -v
# ── Phase 2: CUDA Backend ───────────────────────────────────
echo ""
echo "=== Phase 2: Building CUDA backend ==="
rm -rf rust/target/wheels rust/target/debug rust/target/release
# Release build (-r) with the `cuda` cargo feature enabled.
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
echo ""
echo "--- 2a: CUDA ---"
# Tests marked `slow` are deselected here.
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest $CUDA_TESTS -m "not slow" -v
echo ""
echo "=========================================="
echo " All tests passed!"
echo "=========================================="

View File

@@ -1,22 +0,0 @@
#!/bin/bash
# Rebuild the Rust extension without extra features and run the op-level tests.
# All paths are relative, so run this from the luminal_python project directory.
# Abort on the first failing command.
set -e
echo "=== Luminal Python Test Runner ==="
echo ""
# Force clean rebuild of Rust extension
echo "Step 1: Cleaning previous builds..."
rm -rf rust/target/wheels rust/target/debug rust/target/release
# Rebuild in development mode (faster compilation)
echo "Step 2: Building Rust extension..."
uv run maturin develop --manifest-path rust/Cargo.toml
# Run pytest
echo "Step 3: Running pytest..."
# It is best not to add the full model tests here: they end up running
# billion-parameter models on the CPU, which takes far too long.
uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v
echo ""
echo "=== Tests Complete ==="

View File

@@ -1,20 +0,0 @@
#!/bin/bash
# CUDA-backend test runner: clean rebuild of the Rust extension with the
# `cuda` feature, then pytest with LUMINAL_BACKEND=cuda. Paths are relative to
# the current directory.
set -e

echo "=== Luminal Python Test Runner (CUDA Backend) ==="
echo

# Step 1: drop earlier build outputs so the extension is rebuilt from scratch.
echo "Step 1: Cleaning previous builds..."
for artifact_dir in wheels debug release; do
    rm -rf "rust/target/${artifact_dir}"
done

# Step 2: build and install the extension with the `cuda` feature enabled
# (`-r` is forwarded to maturin).
echo "Step 2: Building Rust extension..."
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r

# Step 3: run the test files against the CUDA backend with Rust backtraces on.
echo "Step 3: Running pytest with CUDA backend..."
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v

echo
echo "=== Tests Complete ==="

View File

@@ -1,28 +0,0 @@
# Crate manifest for the Luminal Python extension module.
[package]
name = "luminal_python"
version = "0.1.0"
# The Rust edition is inherited from the enclosing Cargo workspace.
edition.workspace = true
# Built as a C-compatible dynamic library; the library target is named `luminal`.
[lib]
name = "luminal"
crate-type = ["cdylib"]
path = "src/lib.rs"
[features]
# `cuda` enables the optional `luminal_cuda_lite` dependency declared below.
cuda = ["dep:luminal_cuda_lite"]
[dependencies]
rustc-hash = "2.1.1"
# Core luminal crate and the optional CUDA backend, both referenced by relative path.
luminal = {path= "../../.."}
luminal_cuda_lite = {path="../../luminal_cuda_lite", optional = true}
serde = { version = "1", features = ["derive"] }
serde_json = "1"
zip = "2"
anyhow = "1"
memmap2 = "0.9"
safetensors = "0.5"
half = "2"
[dependencies.pyo3]
version = "0.28.0"
# NOTE(review): presumably selects the stable ABI with Python 3.8 as the
# minimum supported version — confirm against the PyO3 feature reference.
features = ["abi3-py38"]

View File

@@ -1,651 +0,0 @@
#[cfg(feature = "cuda")]
use luminal::prelude::tracing::{trace, warn};
use luminal::{
hlir::{NativeData, Output},
prelude::*,
shape::Expression,
visualization::ToDot,
};
use pyo3::prelude::*;
use std::collections::HashMap;
#[cfg(feature = "cuda")]
use std::collections::HashSet;
use crate::{runtime::RuntimeBackend, typed_data::TypedData};
/// Maps symbolic dimension parameter names (e.g. "seq_len") to the single-character
/// variable that stands for that dimension in luminal `Expression`s.
pub type DimParamMap = HashMap<String, char>;
/// Translate a luminal `DType` into the integer dtype code used on the PT2 / Python side.
///
/// Dtypes without a direct PyTorch counterpart are folded into the closest safe
/// code: `U16` shares the i32 code and `TF32` shares the f32 code.
///
/// # Panics
/// Panics for any dtype that has no mapping below.
fn luminal_dtype_to_pt2_code(dtype: DType) -> u32 {
    match dtype {
        DType::U8 => 1,
        DType::I8 => 2,
        DType::I16 => 3,
        // i32, plus u16 which is reported as i32 (PyTorch has no u16 in older versions).
        DType::Int | DType::U16 => 4,
        DType::F16 => 6,
        DType::F32 | DType::TF32 => 7,
        DType::F64 => 8,
        DType::Bool => 12,
        DType::Bf16 => 13,
        unsupported => panic!(
            "luminal_dtype_to_pt2_code: unsupported dtype {:?}",
            unsupported
        ),
    }
}
/// Common intermediate result from translating a model graph.
///
/// `CompiledGraph::parse_graph` destructures this value and carries every field
/// except `graph` (which is first handed to the backend builder) over into the
/// resulting `CompiledGraph`.
pub struct GraphTranslation {
/// The translated luminal graph.
pub graph: Graph,
/// Tensor name → node index in `graph`.
pub tensor_ids: HashMap<String, NodeIndex>,
/// Names of the user-supplied input tensors, in order.
pub input_names: Vec<String>,
/// Names of the output tensors, in order.
pub output_names: Vec<String>,
/// Per-output shape as one `Expression` per dimension (may be symbolic).
pub output_shape_exprs: Vec<Vec<Expression>>,
/// Per-output luminal dtype.
pub output_dtypes: Vec<DType>,
/// Per-input shape as one `Expression` per dimension (may be symbolic).
pub input_shape_exprs: Vec<Vec<Expression>>,
/// Symbolic dimension parameter name → expression variable character.
pub dim_param_map: DimParamMap,
}
/// Pre-loaded weight data from any model format (dtype-aware).
///
/// Entries are keyed by the label of the graph's `Input` nodes; the backend
/// builders match them against a label → node map.
pub struct WeightData {
/// (Input node label, typed data) for weights and constants.
pub weights: Vec<(String, TypedData)>,
/// label → element count for ALL Input nodes (for CUDA dummy data sizing).
pub tensor_sizes: HashMap<String, usize>,
/// label → (device_ptr, n_bytes) for zero-copy CUDA weight sharing.
/// Labels present here are skipped when host-side `weights` are uploaded by the CUDA builder.
pub device_ptrs: HashMap<String, (u64, usize)>,
}
/// A translated graph bundled with the runtime backend that executes it,
/// exposed to Python as a class.
///
/// NOTE(review): `unsendable` presumably restricts the object to the thread
/// that created it — confirm against the PyO3 documentation.
#[pyclass(unsendable)]
pub struct CompiledGraph {
/// The luminal graph; its `dyn_map` is passed to the runtime on every execution.
pub graph: Graph,
/// The backend runtime that holds tensor data and executes the graph.
pub runtime: RuntimeBackend,
/// Tensor name → node index, used by the name-based input/output accessors.
pub tensor_ids: HashMap<String, NodeIndex>,
/// Cached label → NodeIndex map for O(1) lookups in set_weight_* methods.
label_map: HashMap<String, NodeIndex>,
/// Names of the user-supplied input tensors, in order.
pub input_names: Vec<String>,
/// Names of the output tensors, in order.
pub output_names: Vec<String>,
/// Output shapes resolved at build time; a dimension whose expression has no
/// `to_usize()` value is recorded as 1.
pub output_shapes: Vec<Vec<usize>>,
/// Per-output shape expressions, re-evaluated by `resolve_output_shapes`.
pub output_shape_exprs: Vec<Vec<Expression>>,
/// Per-output luminal dtype.
pub output_dtypes: Vec<DType>,
/// Per-input shape expressions, matched against concrete shapes to set dynamic dims.
pub input_shape_exprs: Vec<Vec<Expression>>,
/// Symbolic dimension parameter name → expression variable character.
pub dim_param_map: DimParamMap,
}
impl CompiledGraph {
    /// Compilation pipeline for PT2/FX graphs.
    ///
    /// Takes a `GraphTranslation` (produced by `translate_pt2`) and `WeightData`,
    /// builds the requested backend (which runs the search and loads weights),
    /// and returns a ready-to-execute `CompiledGraph`.
    ///
    /// `backend` accepts `"native"`/`"cpu"` and, when built with the `cuda`
    /// feature, `"cuda"`/`"gpu"`. `search_iters` is forwarded to the graph search.
    ///
    /// # Errors
    /// Returns a descriptive message for an unknown backend name, for a CUDA
    /// request (`"cuda"` or `"gpu"`) in a build without the `cuda` feature, or
    /// when backend construction fails.
    pub fn parse_graph(
        translation: GraphTranslation,
        weight_data: WeightData,
        backend: &str,
        search_iters: usize,
    ) -> Result<CompiledGraph, String> {
        let GraphTranslation {
            mut graph,
            tensor_ids,
            input_names,
            output_names,
            output_shape_exprs,
            output_dtypes,
            input_shape_exprs,
            dim_param_map,
        } = translation;

        let rt = match backend {
            #[cfg(feature = "cuda")]
            "cuda" | "gpu" => {
                CompiledGraph::build_cuda_backend(&mut graph, &weight_data, search_iters)?
            }
            "native" | "cpu" => {
                CompiledGraph::build_native_backend(&mut graph, &weight_data, search_iters)?
            }
            _ => {
                #[cfg(feature = "cuda")]
                {
                    return Err(format!(
                        "Invalid backend '{}'. Must be 'native' or 'cuda'",
                        backend
                    ));
                }
                #[cfg(not(feature = "cuda"))]
                {
                    // "gpu" is an alias for "cuda" in CUDA-enabled builds, so it
                    // must receive the same rebuild hint here rather than being
                    // reported as an unknown backend name.
                    if matches!(backend, "cuda" | "gpu") {
                        return Err(
                            "CUDA backend requested, but this luminal extension was built without the `cuda` feature. Rebuild with `maturin develop --features cuda -r` or use backend='native'."
                                .to_string(),
                        );
                    }
                    return Err(format!(
                        "Invalid backend '{}'. This build only supports 'native'. Rebuild with the `cuda` feature to enable 'cuda'.",
                        backend
                    ));
                }
            }
        };

        // Resolve concrete output shapes from expressions; a dimension without a
        // `to_usize()` value is recorded as 1.
        let output_shapes: Vec<Vec<usize>> = output_shape_exprs
            .iter()
            .map(|exprs| exprs.iter().map(|e| e.to_usize().unwrap_or(1)).collect())
            .collect();

        let label_map = CompiledGraph::build_label_map(&graph);

        Ok(CompiledGraph {
            graph,
            runtime: rt,
            tensor_ids,
            label_map,
            input_names,
            output_names,
            output_shapes,
            output_shape_exprs,
            output_dtypes,
            input_shape_exprs,
            dim_param_map,
        })
    }

    /// Build a label → NodeIndex map for all Input nodes in the graph.
    /// Used for efficient weight loading by label matching.
    fn build_label_map(graph: &Graph) -> HashMap<String, NodeIndex> {
        graph
            .graph
            .node_indices()
            .filter_map(|node_id| {
                (*graph.graph[node_id])
                    .as_any()
                    .downcast_ref::<luminal::hlir::Input>()
                    .map(|input| (input.label.clone(), node_id))
            })
            .collect()
    }

    /// Build the CUDA runtime for `graph`.
    ///
    /// Registers the caller-provided device pointers, fills every other sized
    /// Input node with dummy ones, runs the search, and finally uploads the
    /// host-side weights that were not supplied as device pointers.
    ///
    /// # Errors
    /// Returns an error string when the CUDA context for device 0 cannot be created.
    #[cfg(feature = "cuda")]
    fn build_cuda_backend(
        graph: &mut Graph,
        weight_data: &WeightData,
        search_iters: usize,
    ) -> Result<RuntimeBackend, String> {
        let device_ptrs = &weight_data.device_ptrs;
        use luminal_cuda_lite::cudarc::driver::CudaContext;
        use luminal_cuda_lite::runtime::CudaRuntime;

        let cuda_ctx = CudaContext::new(0).map_err(|e| format!("CUDA context init failed: {e}"))?;
        let stream = cuda_ctx.default_stream();
        graph.build_search_space::<CudaRuntime>();
        let mut rt = CudaRuntime::initialize(stream);

        // Build label → NodeIndex map for device pointer matching.
        let label_map = CompiledGraph::build_label_map(graph);

        // For weights with device pointers: use them directly (zero-copy).
        // This avoids allocating ~N GB of dummy data during search.
        // The pointers survive search because profiling mode skips buffer consumption,
        // and graph-level .persist() ensures they survive post-search execution too.
        let mut device_ptr_nodes: HashSet<NodeIndex> = HashSet::new();
        let mut matched_count = 0usize;
        let mut missed_labels: Vec<String> = Vec::new();
        for (label, &(ptr, n_bytes)) in device_ptrs {
            if let Some(&node_id) = label_map.get(label) {
                unsafe { rt.set_device_ptr(node_id, ptr, n_bytes) };
                device_ptr_nodes.insert(node_id);
                matched_count += 1;
            } else {
                missed_labels.push(label.clone());
            }
        }
        let total_device_bytes: usize = device_ptrs.values().map(|(_, n)| *n).sum();
        trace!(
            "[CUDA BUILD] Device pointers: {} matched, {} missed out of {} total ({:.3} GiB)",
            matched_count,
            missed_labels.len(),
            device_ptrs.len(),
            total_device_bytes as f64 / (1024.0 * 1024.0 * 1024.0),
        );
        if !missed_labels.is_empty() {
            warn!(
                "[CUDA BUILD] {} device-ptr labels did not match any Input node (first 10): {:?}",
                missed_labels.len(),
                &missed_labels[..missed_labels.len().min(10)]
            );
            let available: Vec<&String> = label_map.keys().take(10).collect();
            warn!(
                "[CUDA BUILD] Available label_map keys (first 10): {:?}",
                available
            );
        }

        // Set dummy 1.0 data for remaining Input nodes (user inputs, constants without
        // device pointers) for safe search profiling.
        // IMPORTANT: Must use 1.0, NOT 0.0. Zero inputs cause NaN in many ops:
        // - fmod(0, 0) = NaN (Mod)
        // - recip(0) = inf → weight * inf = NaN (Div)
        // - log(0) = -inf (Pow)
        // - chain ops with zero produce NaN (Erf)
        let mut dummy_total_elements = 0usize;
        let mut dummy_count = 0usize;
        for node_id in graph.graph.node_indices() {
            if device_ptr_nodes.contains(&node_id) {
                continue;
            }
            if let Some(input) = (*graph.graph[node_id])
                .as_any()
                .downcast_ref::<luminal::hlir::Input>()
            {
                if let Some(&n) = weight_data.tensor_sizes.get(&input.label) {
                    if n > 0 {
                        dummy_total_elements += n;
                        dummy_count += 1;
                        // Use dtype-aware dummy data: TypedData::ones produces correct
                        // byte patterns for every dtype (f32, f16, bf16, i32, bool, f8, etc.).
                        // Must use 1, not 0 — zero inputs cause NaN in many ops.
                        rt.set_data(node_id, TypedData::ones(n, input.dtype).bytes);
                    }
                }
            }
        }
        trace!(
            "[CUDA BUILD] Dummy data: {} nodes, {} elements ({:.3} GiB as f32)",
            dummy_count,
            dummy_total_elements,
            (dummy_total_elements * 4) as f64 / (1024.0 * 1024.0 * 1024.0),
        );

        // Search (device-pointer weights are used directly; dummy data for the rest)
        let mut rt = graph.search(rt, search_iters);

        // Load real weight data for non-device-ptr weights (constants from PT2 archive, etc.)
        let mut loaded_weight_bytes = 0usize;
        let mut loaded_weight_count = 0usize;
        for (label, data) in &weight_data.weights {
            if !device_ptrs.contains_key(label) {
                if let Some(&node_id) = label_map.get(label) {
                    loaded_weight_bytes += data.n_bytes();
                    loaded_weight_count += 1;
                    rt.set_data(node_id, data.bytes.clone());
                }
            }
        }
        trace!(
            "[CUDA BUILD] Post-search weight load: {} weights, {:.3} GiB",
            loaded_weight_count,
            loaded_weight_bytes as f64 / (1024.0 * 1024.0 * 1024.0),
        );
        Ok(RuntimeBackend::Cuda(Box::new(rt)))
    }

    /// Build the native runtime for `graph`: run the search first, then load
    /// every weight whose label matches an Input node.
    fn build_native_backend(
        graph: &mut Graph,
        weight_data: &WeightData,
        search_iters: usize,
    ) -> Result<RuntimeBackend, String> {
        graph.build_search_space::<NativeRuntime>();
        let mut rt = graph.search(NativeRuntime::default(), search_iters);

        // Load weight data after search, preserving native dtype.
        // TypedData -> NativeData conversion (From<TypedData>) handles mapping to the
        // correct NativeData variant (F32, F16, Bf16, Int, Bool).
        let label_map = CompiledGraph::build_label_map(graph);
        for (label, data) in &weight_data.weights {
            if let Some(&node_id) = label_map.get(label) {
                let native: NativeData = data.into();
                rt.set_data(node_id, native);
            }
        }
        Ok(RuntimeBackend::Native(rt))
    }
}
#[pymethods]
impl CompiledGraph {
    /// Get the list of input tensor names.
    #[getter]
    fn input_names(&self) -> Vec<String> {
        self.input_names.clone()
    }

    /// Get the PT2 dtype codes for all inputs (in order of input_names).
    /// An input that cannot be resolved to an Input node is reported as f32 (code 7).
    #[getter]
    fn input_dtypes(&self) -> Vec<u32> {
        self.input_names
            .iter()
            .map(|name| {
                if let Some(&node_id) = self.tensor_ids.get(name)
                    && let Some(input) = (*self.graph.graph[node_id])
                        .as_any()
                        .downcast_ref::<luminal::hlir::Input>()
                {
                    return luminal_dtype_to_pt2_code(input.dtype);
                }
                7 // default to f32
            })
            .collect()
    }

    /// Get the list of output tensor names.
    #[getter]
    fn output_names(&self) -> Vec<String> {
        self.output_names.clone()
    }

    /// Get the output shapes as resolved when the graph was built.
    #[getter]
    fn output_shapes(&self) -> Vec<Vec<usize>> {
        self.output_shapes.clone()
    }

    /// Get all tensor names in the graph.
    #[getter]
    fn tensor_names(&self) -> Vec<String> {
        self.tensor_ids.keys().cloned().collect()
    }

    /// Get the name of the active backend (native or cuda).
    #[getter]
    fn backend(&self) -> &'static str {
        self.runtime.name()
    }

    /// Whether this graph has dynamic (symbolic) dimensions.
    #[getter]
    fn has_dynamic_dims(&self) -> bool {
        !self.dim_param_map.is_empty()
    }

    /// Get the dynamic dimension parameter names (e.g. ["seq_len"]).
    #[getter]
    fn dim_params(&self) -> Vec<String> {
        self.dim_param_map.keys().cloned().collect()
    }

    /// Set a dynamic dimension value by its param name (e.g. "seq_len").
    /// Raises `KeyError` when the param name is unknown.
    fn set_dim(&mut self, param_name: &str, value: usize) -> PyResult<()> {
        let ch = self.dim_param_map.get(param_name).ok_or_else(|| {
            PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
                "Unknown dim param '{}'. Available: {:?}",
                param_name,
                self.dim_param_map.keys().collect::<Vec<_>>()
            ))
        })?;
        self.graph.set_dim(*ch, value);
        Ok(())
    }

    /// Auto-detect and set dynamic dimensions from input tensor shapes.
    /// For each user input, matches the concrete shape against its symbolic
    /// shape expressions and sets the corresponding dyn_map entries.
    /// Only dimensions whose expression is a single bare variable are set.
    fn auto_set_dims_from_input_shapes(&mut self, input_shapes: Vec<Vec<usize>>) {
        for (shape_exprs, shape) in self.input_shape_exprs.iter().zip(input_shapes.iter()) {
            for (dim_expr, &dim_val) in shape_exprs.iter().zip(shape.iter()) {
                // Check if this expression is a bare symbolic variable
                let terms = dim_expr.terms.read();
                if terms.len() == 1
                    && let luminal::shape::Term::Var(c) = terms[0]
                {
                    self.graph.set_dim(c, dim_val);
                }
            }
        }
    }

    /// Resolve output shapes using current dynamic dimension values.
    /// Returns concrete shapes after substituting all symbolic dims.
    /// Raises `RuntimeError` when an expression cannot be evaluated.
    fn resolve_output_shapes(&self) -> PyResult<Vec<Vec<usize>>> {
        let dyn_map = &self.graph.dyn_map;
        let mut result = Vec::new();
        for shape_exprs in &self.output_shape_exprs {
            let shape: Vec<usize> = shape_exprs
                .iter()
                .map(|e| {
                    e.exec(dyn_map).ok_or_else(|| {
                        PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                            "Cannot resolve dimension expression {:?}. Set all dynamic dims first.",
                            e
                        ))
                    })
                })
                .collect::<PyResult<Vec<usize>>>()?;
            result.push(shape);
        }
        Ok(result)
    }

    /// Set input tensor data by name (f32, for backward compatibility).
    fn set_input(&mut self, name: &str, data: Vec<f32>) -> PyResult<()> {
        let node_id = self.tensor_ids.get(name).ok_or_else(|| {
            PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
        })?;
        self.runtime.set_data_f32(*node_id, data);
        Ok(())
    }

    /// Set input tensor data from a CPU host memory pointer (dtype-aware).
    /// The pointer must point to contiguous data. `n_bytes` is the total byte count.
    /// `dtype_code` uses PT2 numbering (7=f32, 6=f16, 13=bf16, etc.).
    /// Converts source format to luminal's native format (e.g., i64→i32, f64→f32).
    ///
    /// A zero-length request is accepted with any pointer value; a null pointer
    /// with `n_bytes > 0` raises `ValueError` instead of being dereferenced.
    fn set_input_from_ptr(
        &mut self,
        name: &str,
        ptr: u64,
        n_bytes: usize,
        dtype_code: u32,
    ) -> PyResult<()> {
        let node_id = self.tensor_ids.get(name).ok_or_else(|| {
            PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
        })?;
        // SAFETY: the caller guarantees `ptr` addresses `n_bytes` readable bytes;
        // the null / zero-length cases are rejected or short-circuited by the helper.
        let raw_bytes = unsafe { copy_host_bytes(ptr, n_bytes, "set_input_from_ptr") }?;
        let typed = TypedData::from_pytorch_bytes(raw_bytes, dtype_code);
        self.runtime.set_data(*node_id, typed);
        Ok(())
    }

    /// Set input from a CUDA device pointer. Zero-copy on device.
    /// The pointer must be a valid CUDA device allocation with at least n_bytes bytes.
    #[cfg(feature = "cuda")]
    fn set_input_device_ptr(
        &mut self,
        name: &str,
        device_ptr: u64,
        n_bytes: usize,
    ) -> PyResult<()> {
        let node_id = self.tensor_ids.get(name).ok_or_else(|| {
            PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
        })?;
        match &mut self.runtime {
            RuntimeBackend::Cuda(rt) => unsafe { rt.set_device_ptr(*node_id, device_ptr, n_bytes) },
            _ => {
                return Err(pyo3::exceptions::PyValueError::new_err(
                    "set_input_device_ptr requires CUDA backend",
                ));
            }
        }
        Ok(())
    }

    /// Set a weight from a CUDA device pointer, matching by Input node label.
    /// For PT2 weights (e.g. "fc1.weight"). Persistence is handled at graph level via .persist().
    #[cfg(feature = "cuda")]
    fn set_weight_device_ptr(
        &mut self,
        label: &str,
        device_ptr: u64,
        n_bytes: usize,
    ) -> PyResult<()> {
        let &node_id = self.label_map.get(label).ok_or_else(|| {
            pyo3::exceptions::PyKeyError::new_err(format!("No Input node with label: {}", label))
        })?;
        match &mut self.runtime {
            RuntimeBackend::Cuda(rt) => {
                unsafe { rt.set_device_ptr(node_id, device_ptr, n_bytes) };
            }
            _ => {
                return Err(pyo3::exceptions::PyValueError::new_err(
                    "set_weight_device_ptr requires CUDA backend",
                ));
            }
        }
        Ok(())
    }

    /// Register an external device pointer for an output tensor (zero-copy output).
    /// Call before run() — the runtime will write kernel results directly into this buffer.
    /// For aliased outputs (in-place ops), falls back to DtoD copy; check output_is_zero_copy() after run().
    #[cfg(feature = "cuda")]
    fn set_output_device_ptr(
        &mut self,
        name: &str,
        device_ptr: u64,
        n_bytes: usize,
    ) -> PyResult<()> {
        let node_id = self.tensor_ids.get(name).ok_or_else(|| {
            PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
                "Unknown output tensor: {}",
                name
            ))
        })?;
        match &mut self.runtime {
            RuntimeBackend::Cuda(rt) => {
                unsafe { rt.set_output_device_ptr(*node_id, device_ptr, n_bytes) };
            }
            _ => {
                return Err(pyo3::exceptions::PyValueError::new_err(
                    "set_output_device_ptr requires CUDA backend",
                ));
            }
        }
        Ok(())
    }

    /// Check whether an output tensor was zero-copied (written directly to the registered pointer).
    /// Returns false for aliased outputs that need a fallback DtoD copy. Must be called after run().
    /// Always false on a non-CUDA backend.
    #[cfg(feature = "cuda")]
    fn output_is_zero_copy(&self, name: &str) -> PyResult<bool> {
        let node_id = self.tensor_ids.get(name).ok_or_else(|| {
            PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
                "Unknown output tensor: {}",
                name
            ))
        })?;
        match &self.runtime {
            RuntimeBackend::Cuda(rt) => Ok(rt.output_is_zero_copy(*node_id)),
            _ => Ok(false),
        }
    }

    /// Set a weight tensor from a CPU host pointer, matching by Input node label (dtype-aware).
    /// `n_bytes` is the total byte count. `dtype_code` uses PT2 numbering (7=f32, 6=f16, 13=bf16, etc.).
    ///
    /// A zero-length request is accepted with any pointer value; a null pointer
    /// with `n_bytes > 0` raises `ValueError` instead of being dereferenced.
    fn set_weight_from_ptr(
        &mut self,
        label: &str,
        ptr: u64,
        n_bytes: usize,
        dtype_code: u32,
    ) -> PyResult<()> {
        let &node_id = self.label_map.get(label).ok_or_else(|| {
            pyo3::exceptions::PyKeyError::new_err(format!("No Input node with label: {}", label))
        })?;
        // SAFETY: the caller guarantees `ptr` addresses `n_bytes` readable bytes;
        // the null / zero-length cases are rejected or short-circuited by the helper.
        let bytes = unsafe { copy_host_bytes(ptr, n_bytes, "set_weight_from_ptr") }?;
        let typed = TypedData::from_pytorch_bytes(bytes, dtype_code);
        self.runtime.set_data(node_id, typed);
        Ok(())
    }

    /// Execute the graph with the current dynamic dimension values.
    fn run(&mut self) {
        self.runtime.execute(&self.graph.dyn_map);
    }

    /// Return the HLIR graph as a DOT string for visualization.
    fn to_dot(&self) -> PyResult<String> {
        self.graph.graph.to_dot().map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("DOT generation failed: {e}"))
        })
    }

    /// Get the PT2 dtype codes for all outputs (in order).
    #[getter]
    fn output_dtypes(&self) -> Vec<u32> {
        self.output_dtypes
            .iter()
            .map(|d| luminal_dtype_to_pt2_code(*d))
            .collect()
    }

    /// Get output tensor data by name as f32 (copies to host).
    /// For native backend: handles any NativeData variant by converting to f32.
    /// The native runtime may produce NativeData::Int or NativeData::Bool for some ops
    /// (e.g., Cast chains), so we can't assume NativeData::F32.
    fn get_output(&self, name: &str) -> PyResult<Vec<f32>> {
        let node_id = self.tensor_ids.get(name).ok_or_else(|| {
            PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
                "Unknown output tensor: {}",
                name
            ))
        })?;
        match &self.runtime {
            RuntimeBackend::Native(rt) => {
                let id = *node_id;
                // Find the runtime's Output node that refers to the requested tensor.
                let output_id = rt
                    .graph
                    .node_indices()
                    .find(|n| {
                        if let Some(out) = (**rt.graph[*n]).as_any().downcast_ref::<Output>() {
                            out.node == id.index()
                        } else {
                            false
                        }
                    })
                    .ok_or_else(|| {
                        pyo3::exceptions::PyRuntimeError::new_err(format!(
                            "No output node found for tensor: {}",
                            name
                        ))
                    })?;
                let data = rt.buffers.get(&output_id).ok_or_else(|| {
                    pyo3::exceptions::PyRuntimeError::new_err(format!(
                        "No buffer data for output tensor: {}",
                        name
                    ))
                })?;
                // Convert any NativeData variant to f32
                Ok((0..data.len()).map(|i| data.f32(i)).collect())
            }
            #[cfg(feature = "cuda")]
            RuntimeBackend::Cuda(rt) => Ok(rt.get_f32(*node_id)),
        }
    }

    /// Copy output tensor data directly to a CUDA device pointer (DtoD).
    /// Avoids the DtoH + HtoD round-trip of get_output() + .to(device).
    #[cfg(feature = "cuda")]
    fn copy_output_to_device_ptr(&self, name: &str, dest_ptr: u64, n_bytes: usize) -> PyResult<()> {
        let node_id = self.tensor_ids.get(name).ok_or_else(|| {
            PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
                "Unknown output tensor: {}",
                name
            ))
        })?;
        match &self.runtime {
            RuntimeBackend::Cuda(rt) => {
                unsafe { rt.copy_output_to_device_ptr(*node_id, dest_ptr, n_bytes) };
                Ok(())
            }
            _ => Err(pyo3::exceptions::PyValueError::new_err(
                "copy_output_to_device_ptr requires CUDA backend",
            )),
        }
    }
}

/// Copy `n_bytes` bytes of host memory starting at address `ptr` into an owned buffer.
///
/// Returns an empty buffer when `n_bytes` is zero (without touching `ptr`), and a
/// `ValueError` naming `caller` when `ptr` is null but bytes were requested.
///
/// # Safety
/// For `n_bytes > 0`, a non-null `ptr` must be the address of at least `n_bytes`
/// readable, initialized bytes that stay valid for the duration of the call.
unsafe fn copy_host_bytes(ptr: u64, n_bytes: usize, caller: &str) -> PyResult<Vec<u8>> {
    if n_bytes == 0 {
        // Nothing to read, so the pointer is never dereferenced.
        return Ok(Vec::new());
    }
    if ptr == 0 {
        return Err(pyo3::exceptions::PyValueError::new_err(format!(
            "{caller} called with null pointer"
        )));
    }
    // SAFETY: `ptr` is non-null and the caller upholds the length contract above.
    let bytes = unsafe { std::slice::from_raw_parts(ptr as *const u8, n_bytes) };
    Ok(bytes.to_vec())
}

View File

@@ -1,21 +0,0 @@
mod compiled_graph;
mod runtime;
pub mod typed_data;
// PT2 modules
mod pt2_compiled_model;
mod pt2_parser;
mod pt2_schema;
mod pt2_util;
mod translator;
use compiled_graph::CompiledGraph;
use pt2_compiled_model::process_pt2;
use pyo3::prelude::*;
/// Python module entry point for the `luminal` extension.
///
/// Registers the `process_pt2` function first and the `CompiledGraph` class second.
#[pymodule]
fn luminal(m: &Bound<'_, PyModule>) -> PyResult<()> {
    let process_pt2_fn = wrap_pyfunction!(process_pt2, m)?;
    m.add_function(process_pt2_fn)?;
    m.add_class::<CompiledGraph>()?;
    Ok(())
}

View File

@@ -1,369 +0,0 @@
use luminal::prelude::tracing::warn;
use luminal::prelude::*;
use pyo3::prelude::*;
use std::collections::HashMap;
use crate::compiled_graph::{CompiledGraph, DimParamMap, GraphTranslation, WeightData};
use crate::pt2_schema;
use crate::translator;
use crate::typed_data::TypedData;
use crate::{pt2_parser, pt2_util};
/// Pre-loaded weight/constant data paired with tensor sizes:
/// `(label, typed data)` entries plus a `label → element count` map.
type PreloadResult = (Vec<(String, TypedData)>, HashMap<String, usize>);
/// Convert PT2 dimension sizes into luminal shape expressions.
///
/// Integer sizes become constant expressions. Symbolic sizes become the
/// variable mapped to their symbol in `sym_to_char`; when no symbol name can be
/// extracted, or the symbol has no mapping, the dimension falls back to the
/// constant 1.
fn resolve_dim_sizes(
    sizes: &[pt2_schema::DimSize],
    sym_to_char: &HashMap<String, char>,
) -> Vec<Expression> {
    let mut resolved = Vec::new();
    for size in sizes {
        let expr = match size {
            pt2_schema::DimSize::Int(i) => Expression::from(i.as_int as usize),
            pt2_schema::DimSize::Expr(e) => {
                pt2_parser::extract_symbol_name_pub(&e.as_expr.expr_str)
                    .and_then(|sym| sym_to_char.get(&sym).copied())
                    .map(Expression::from)
                    .unwrap_or_else(|| Expression::from(1usize))
            }
        };
        resolved.push(expr);
    }
    resolved
}
/// Python entry point: compile a PT2 export into a `CompiledGraph`.
///
/// `weight_device_ptrs` is optional and maps a weight label to
/// `(device_ptr, n_bytes)`; when omitted, an empty map is used. Any failure in
/// the pipeline is raised as `RuntimeError` carrying the full error chain.
#[pyfunction]
#[pyo3(signature = (pt2_path, weights_path, backend, search_iters, weight_device_ptrs=None))]
pub fn process_pt2(
    pt2_path: &str,
    weights_path: &str,
    backend: &str,
    search_iters: usize,
    weight_device_ptrs: Option<HashMap<String, (u64, usize)>>,
) -> PyResult<CompiledGraph> {
    let device_ptrs = weight_device_ptrs.unwrap_or_default();
    match compile_pt2(pt2_path, weights_path, backend, search_iters, device_ptrs) {
        Ok(compiled) => Ok(compiled),
        Err(e) => Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
            "{e:#}"
        ))),
    }
}
fn compile_pt2(
pt2_path: &str,
weights_path: &str,
backend: &str,
search_iters: usize,
weight_device_ptrs: HashMap<String, (u64, usize)>,
) -> anyhow::Result<CompiledGraph> {
let (translation, mut weights) = translate_pt2(pt2_path, weights_path)?;
weights.device_ptrs = weight_device_ptrs;
CompiledGraph::parse_graph(translation, weights, backend, search_iters)
.map_err(|e| anyhow::anyhow!(e))
}
/// Translate a PT2 exported model into a format-neutral GraphTranslation + WeightData.
pub fn translate_pt2(
pt2_path: &str,
weights_path: &str,
) -> anyhow::Result<(GraphTranslation, WeightData)> {
let parsed = pt2_parser::parse_pt2(pt2_path)?;
let translated = translator::translate(&parsed)?;
let mut graph = translated.graph;
// Set initial dynamic dim values from symbol ranges
for (sym_name, c) in &translated.sym_map.sym_to_char {
if let Some(rc) = translated.sym_map.ranges.get(sym_name) {
graph.set_dim(*c, rc.min_val as usize);
}
}
// Compute shape expressions and dtypes from PT2 tensor metadata
let output_shape_exprs: Vec<Vec<Expression>> = translated
.output_ids
.iter()
.map(|(name, _id)| {
parsed
.tensor_meta(name)
.map(|meta| resolve_dim_sizes(&meta.sizes, &translated.sym_map.sym_to_char))
.unwrap_or_default()
})
.collect();
let output_dtypes: Vec<DType> = translated
.output_ids
.iter()
.map(|(name, _id)| {
parsed
.tensor_meta(name)
.map(|meta| pt2_util::torch_dtype_int_to_luminal(meta.dtype))
.unwrap_or(DType::F32)
})
.collect();
let input_names: Vec<String> = translated
.user_input_ids
.iter()
.map(|(name, _)| name.clone())
.collect();
let output_names: Vec<String> = translated
.output_ids
.iter()
.map(|(name, _)| name.clone())
.collect();
let input_shape_exprs: Vec<Vec<Expression>> = translated
.user_input_ids
.iter()
.map(|(name, _id)| {
parsed
.tensor_meta(name)
.map(|meta| resolve_dim_sizes(&meta.sizes, &translated.sym_map.sym_to_char))
.unwrap_or_default()
})
.collect();
// Build tensor_ids from user inputs and outputs
let mut tensor_ids: HashMap<String, NodeIndex> = HashMap::new();
for (name, id) in &translated.user_input_ids {
tensor_ids.insert(name.clone(), *id);
}
for (name, id) in &translated.output_ids {
tensor_ids.insert(name.clone(), *id);
}
// Pre-load weights and compute tensor sizes for CUDA dummy data
let mut weights: Vec<(String, TypedData)> = Vec::new();
let mut tensor_sizes: HashMap<String, usize> = HashMap::new();
// Load safetensors weights
if !weights_path.is_empty() {
let (st_weights, st_sizes) = preload_safetensors(&graph, weights_path)?;
weights.extend(st_weights);
tensor_sizes.extend(st_sizes);
}
// Load PT2 constants from ZIP archive
let (const_weights, const_sizes) = preload_constants(&graph, &parsed)?;
weights.extend(const_weights);
tensor_sizes.extend(const_sizes);
// Add tensor sizes from PT2 metadata for parameters/buffers not in safetensors
// (covers case when weights are loaded via device pointers after compilation)
for input_kind in parsed.classify_inputs() {
let (graph_name, original_name) = match &input_kind {
pt2_parser::InputKind::Parameter {
graph_name,
original_name,
} => (graph_name.as_str(), original_name.as_str()),
pt2_parser::InputKind::Buffer {
graph_name,
original_name,
} => (graph_name.as_str(), original_name.as_str()),
pt2_parser::InputKind::UserInput { .. } => continue,
};
// Always use authoritative sizes from model.json tensor_meta,
// even if preload_constants inserted a different (possibly stripped) size.
if let Some(meta) = parsed.tensor_meta(graph_name) {
let n: usize = meta
.sizes
.iter()
.map(|s| s.hint().unwrap_or(1) as usize)
.product();
tensor_sizes.insert(original_name.to_string(), n);
}
}
// Add user input sizes
for (name, _id) in &translated.user_input_ids {
if !tensor_sizes.contains_key(name)
&& let Some(meta) = parsed.tensor_meta(name)
{
let n: usize = meta
.sizes
.iter()
.map(|s| s.hint().unwrap_or(1) as usize)
.product();
tensor_sizes.insert(name.clone(), n);
}
}
let dim_param_map: DimParamMap = translated.sym_map.sym_to_char;
let translation = GraphTranslation {
graph,
tensor_ids,
input_names,
output_names,
output_dtypes,
output_shape_exprs,
input_shape_exprs,
dim_param_map,
};
let weight_data = WeightData {
weights,
tensor_sizes,
device_ptrs: HashMap::new(),
};
Ok((translation, weight_data))
}
// ---------------------------------------------------------------------------
// Weight pre-loading helpers
// ---------------------------------------------------------------------------
/// Pre-load all safetensors weights that match Input nodes in the graph.
/// Returns (weight data, tensor sizes for all tensors in the file).
fn preload_safetensors(graph: &Graph, file_path: &str) -> anyhow::Result<PreloadResult> {
use memmap2::MmapOptions;
use safetensors::SafeTensors;
use std::fs::File;
let f = File::open(file_path)?;
let mmap = unsafe { MmapOptions::new().map(&f)? };
let st = SafeTensors::deserialize(&mmap)
.map_err(|e| anyhow::anyhow!("SafeTensors deserialize error: {e}"))?;
let mut weights = Vec::new();
let mut sizes = HashMap::new();
// Get sizes for ALL tensors in the file (for dummy data allocation)
for (name, info) in st.tensors() {
let n: usize = info.shape().iter().product();
sizes.insert(name.to_string(), n);
}
// Load weight data for Input nodes that match safetensors tensor names
for node_id in graph.graph.node_indices() {
if let Some(input) = (*graph.graph[node_id])
.as_any()
.downcast_ref::<luminal::hlir::Input>()
&& let Ok(tensor) = st.tensor(&input.label)
{
let types = bytes_to_typed(tensor.data(), safetensors_dtype_to_pt2(tensor.dtype()));
weights.push((input.label.clone(), types));
}
}
Ok((weights, sizes))
}
/// Pre-load all PT2 constants from the ZIP archive.
///
/// Returns the typed data of every constant that could be read, plus the
/// element count of every constant listed in the constants config. A constant
/// whose bytes cannot be read is logged with a warning and skipped, but its
/// size is still recorded. Without a constants config both results are empty.
fn preload_constants(
    _graph: &Graph,
    parsed: &pt2_parser::ParsedPT2,
) -> anyhow::Result<PreloadResult> {
    let Some(constants_config) = &parsed.constants_config else {
        return Ok((Vec::new(), HashMap::new()));
    };

    let mut weights = Vec::new();
    let mut sizes = HashMap::new();
    for (name, entry) in &constants_config.config {
        // A dimension without a hint counts as 1.
        let element_count: usize = entry
            .tensor_meta
            .sizes
            .iter()
            .map(|s| s.hint().unwrap_or(1) as usize)
            .product();
        sizes.insert(name.clone(), element_count);

        match pt2_parser::read_constant_bytes(&parsed.pt2_path, &parsed.archive_prefix, entry) {
            Ok(raw_bytes) => {
                let typed_data = bytes_to_typed(&raw_bytes, entry.tensor_meta.dtype);
                weights.push((name.clone(), typed_data));
            }
            Err(e) => {
                warn!("failed to load constant '{}': {:#}", name, e);
            }
        }
    }
    Ok((weights, sizes))
}
// ---------------------------------------------------------------------------
// Byte conversion helpers
// ---------------------------------------------------------------------------
/// Convert safetensors Dtype to PT2 dtype number.
fn safetensors_dtype_to_pt2(dtype: safetensors::Dtype) -> u32 {
match dtype {
safetensors::Dtype::BOOL => 12,
safetensors::Dtype::U8 => 1,
safetensors::Dtype::I8 => 2,
safetensors::Dtype::I16 => 3,
safetensors::Dtype::I32 => 4,
safetensors::Dtype::I64 => 5,
safetensors::Dtype::F16 => 6,
safetensors::Dtype::F32 => 7,
safetensors::Dtype::F64 => 8,
safetensors::Dtype::BF16 => 13,
_ => 7, // default to f32
}
}
/// Convert raw bytes to TypedData using PT2 dtype numbering.
/// Preserves native byte format for types luminal supports directly (f32, f16, bf16, i32, bool, u8, i8).
/// Converts i64/f64/i16 to the closest luminal-native representation.
fn bytes_to_typed(bytes: &[u8], dtype: u32) -> TypedData {
match dtype {
// Types that map directly — preserve raw bytes
7 => TypedData::from_raw(bytes.to_vec(), DType::F32),
6 => TypedData::from_raw(bytes.to_vec(), DType::F16),
13 => TypedData::from_raw(bytes.to_vec(), DType::Bf16),
4 => TypedData::from_raw(bytes.to_vec(), DType::Int), // i32
1 => TypedData::from_raw(bytes.to_vec(), DType::U8),
2 => TypedData::from_raw(bytes.to_vec(), DType::I8),
12 => TypedData::from_raw(bytes.to_vec(), DType::Bool),
// i64 → i32 (truncate, matching luminal's Int type)
5 => {
let i32s: Vec<i32> = bytes
.chunks_exact(8)
.map(|b| {
i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as i32
})
.collect();
TypedData::from_i32_vec(i32s)
}
// f64 → f32 (downcast, luminal has no F64 in practice for most ops)
8 => {
let f32s: Vec<f32> = bytes
.chunks_exact(8)
.map(|b| {
f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
})
.collect();
TypedData::from_f32_vec(f32s)
}
// i16 → i32 (widen to luminal's Int)
3 => {
let i32s: Vec<i32> = bytes
.chunks_exact(2)
.map(|b| i16::from_le_bytes([b[0], b[1]]) as i32)
.collect();
TypedData::from_i32_vec(i32s)
}
_ => {
let luminal_dtype = pt2_util::torch_dtype_int_to_luminal(dtype);
warn!("Unrecognized dtype {dtype}, interpreting as {luminal_dtype:?}");
TypedData::from_raw(bytes.to_vec(), luminal_dtype)
}
}
}

View File

@@ -1,295 +0,0 @@
//! PT2 ZIP + JSON parser.
//!
//! Opens a .pt2 file (ZIP archive), reads the model JSON, and extracts
//! the graph structure, weight mapping, and symbolic shape info.
use std::collections::HashMap;
use std::fs::File;
use std::io::Read;
use anyhow::{Context, Result};
use zip::ZipArchive;
use crate::pt2_schema::*;
/// Parsed PT2 file contents — everything needed for graph translation.
#[derive(Debug)]
pub struct ParsedPT2 {
/// The exported program (graph, signature, etc.)
pub program: ExportedProgram,
/// Constants config: tensor constant name -> (file path in zip, tensor metadata).
/// `None` when the archive provides no constants config.
pub constants_config: Option<WeightsConfig>,
/// Archive name prefix (e.g., "luminal_mlp")
pub archive_prefix: String,
/// Path to the original .pt2 file (for re-reading constants)
pub pt2_path: String,
}
/// Classification of a graph input.
///
/// `graph_name` is the name of the argument inside the exported graph;
/// `original_name` is the parameter/buffer name taken from the input spec.
#[derive(Debug, Clone)]
pub enum InputKind {
/// A model parameter (e.g., "fc1.weight")
Parameter {
graph_name: String,
original_name: String,
},
/// A model buffer (e.g., "running_mean").
/// `ParsedPT2::classify_inputs` also reports tensor constants with this variant.
Buffer {
graph_name: String,
original_name: String,
},
/// A user-provided input tensor (e.g., "x")
UserInput { graph_name: String },
}
/// Symbolic dimension mapping: PT2 symbol name -> luminal char variable.
#[derive(Debug, Clone)]
pub struct SymDimMap {
/// Maps PT2 symbol names (e.g., "s77") to luminal char variables ('a', 'b', ...)
pub sym_to_char: HashMap<String, char>,
/// Range constraints for each symbol, keyed by the same PT2 symbol name.
pub ranges: HashMap<String, RangeConstraint>,
}
impl ParsedPT2 {
/// Classify all graph inputs into parameters, buffers, and user inputs.
pub fn classify_inputs(&self) -> Vec<InputKind> {
self.program
.graph_module
.signature
.input_specs
.iter()
.filter_map(|spec| match spec {
InputSpec::Parameter(p) => Some(InputKind::Parameter {
graph_name: p.parameter.arg.name.clone(),
original_name: p.parameter.parameter_name.clone(),
}),
InputSpec::Buffer(b) => Some(InputKind::Buffer {
graph_name: b.buffer.arg.name.clone(),
original_name: b.buffer.buffer_name.clone(),
}),
InputSpec::TensorConstant(tc) => Some(InputKind::Buffer {
graph_name: tc.tensor_constant.arg.name.clone(),
original_name: tc.tensor_constant.tensor_constant_name.clone(),
}),
InputSpec::UserInput(u) => {
u.user_input
.arg
.as_tensor_name()
.map(|name| InputKind::UserInput {
graph_name: name.to_string(),
})
}
InputSpec::Other(_) => None,
})
.collect()
}
/// Get the output tensor names.
pub fn output_names(&self) -> Vec<String> {
self.program
.graph_module
.graph
.outputs
.iter()
.filter_map(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
.collect()
}
/// Get tensor metadata by name.
pub fn tensor_meta(&self, name: &str) -> Option<&TensorMeta> {
self.program.graph_module.graph.tensor_values.get(name)
}
/// Build the symbolic dimension mapping.
pub fn build_sym_dim_map(&self) -> SymDimMap {
let mut sym_to_char = HashMap::new();
let mut next_char = b'a';
// Collect all symbolic dimension names from tensor_values
let mut sym_set = std::collections::HashSet::new();
for meta in self.program.graph_module.graph.tensor_values.values() {
for size in &meta.sizes {
if let Some(sym_str) = size.symbol_name()
&& let Some(name) = extract_symbol_name(sym_str)
{
sym_set.insert(name);
}
}
}
let mut sym_names: Vec<String> = sym_set.into_iter().collect();
sym_names.sort();
for name in &sym_names {
if next_char <= b'z' {
sym_to_char.insert(name.clone(), next_char as char);
next_char += 1;
}
}
SymDimMap {
sym_to_char,
ranges: self.program.range_constraints.clone(),
}
}
}
/// Public alias of `extract_symbol_name`, for use by the translator.
///
/// Extracts the symbol name from a string like
/// "Symbol('s77', positive=True, integer=True)"; returns whatever the private
/// helper returns for `expr_str`.
pub fn extract_symbol_name_pub(expr_str: &str) -> Option<String> {
extract_symbol_name(expr_str)
}
/// Pull the quoted name out of the first `Symbol('name' ...` / `Symbol("name" ...`
/// occurrence in `expr_str`.
///
/// Returns `None` when there is no `Symbol(` marker, when the character right after
/// it is not a single or double quote, or when the matching closing quote is missing.
fn extract_symbol_name(expr_str: &str) -> Option<String> {
    const MARKER: &str = "Symbol(";
    let (_, after_marker) = expr_str.split_once(MARKER)?;

    // The name must be introduced by a quote; remember which kind so we can find its mate.
    let mut chars = after_marker.chars();
    let quote = match chars.next() {
        Some(q @ ('\'' | '"')) => q,
        _ => return None,
    };

    let (name, _) = chars.as_str().split_once(quote)?;
    Some(name.to_owned())
}
/// Open the `.pt2` archive at `path` and deserialize what graph translation needs.
///
/// The archive prefix is taken from the first path component of the first ZIP entry.
/// `{prefix}/models/model.json` is required; the constants config under
/// `{prefix}/data/constants/` is optional and any failure to find, read, or parse it
/// yields `None` rather than an error.
pub fn parse_pt2(path: &str) -> Result<ParsedPT2> {
    let file = File::open(path).with_context(|| format!("Failed to open PT2 file: {path}"))?;
    let mut archive = ZipArchive::new(file).context("Failed to read PT2 ZIP archive")?;

    // Everything in the archive lives under one top-level directory; use the first entry to find it.
    let first_entry = archive
        .file_names()
        .next()
        .context("Empty PT2 archive")?
        .to_string();
    let archive_prefix = match first_entry.split_once('/') {
        Some((top_level, _)) => top_level.to_string(),
        None => first_entry,
    };

    // Required: the serialized exported program.
    let model_json_path = format!("{archive_prefix}/models/model.json");
    let mut model_json = String::new();
    archive
        .by_name(&model_json_path)
        .with_context(|| format!("Missing {model_json_path} in PT2 archive"))?
        .read_to_string(&mut model_json)?;
    let program: ExportedProgram =
        serde_json::from_str(&model_json).with_context(|| "Failed to parse model.json")?;

    // Optional: constants config — not all models have constants, so failures become `None`.
    let constants_config_path =
        format!("{archive_prefix}/data/constants/model_constants_config.json");
    let constants_config: Option<WeightsConfig> = match archive.by_name(&constants_config_path) {
        Ok(mut entry) => {
            let mut raw = String::new();
            match entry.read_to_string(&mut raw) {
                Ok(_) => serde_json::from_str(&raw).ok(),
                Err(_) => None,
            }
        }
        Err(_) => None,
    };

    Ok(ParsedPT2 {
        program,
        constants_config,
        archive_prefix,
        pt2_path: path.to_string(),
    })
}
/// Read raw constant bytes from the PT2 archive for a given constant entry.
pub fn read_constant_bytes(
pt2_path: &str,
archive_prefix: &str,
entry: &WeightEntry,
) -> Result<Vec<u8>> {
let file = File::open(pt2_path)?;
let mut archive = ZipArchive::new(file)?;
let path = format!("{archive_prefix}/data/constants/{}", entry.path_name);
let mut zip_entry = archive
.by_name(&path)
.with_context(|| format!("Missing constant file: {path}"))?;
let mut buf = Vec::new();
zip_entry.read_to_end(&mut buf)?;
Ok(buf)
}
// Unit tests for the PT2 parser.
// The `test_parse_*` cases depend on fixture archives under /tmp and return
// early (printing a "Skipping" message) when the fixture is absent, so they
// pass trivially on machines that have not generated those files.
#[cfg(test)]
mod tests {
use super::*;
// Covers single-quoted, double-quoted, and non-matching inputs.
#[test]
fn test_extract_symbol_name() {
assert_eq!(
extract_symbol_name("Symbol('s77', positive=True, integer=True)"),
Some("s77".to_string())
);
assert_eq!(
extract_symbol_name("Symbol(\"batch\", positive=True)"),
Some("batch".to_string())
);
assert_eq!(extract_symbol_name("not_a_symbol"), None);
}
// Fixture: a graph with a single add node, one user input "x", one output "add".
#[test]
fn test_parse_addone_pt2() {
let path = "/tmp/luminal_addone.pt2";
if !std::path::Path::new(path).exists() {
eprintln!("Skipping: {path} not found");
return;
}
let parsed = parse_pt2(path).unwrap();
assert_eq!(parsed.program.graph_module.graph.nodes.len(), 1);
assert_eq!(
parsed.program.graph_module.graph.nodes[0].target,
"torch.ops.aten.add.Tensor"
);
let inputs = parsed.classify_inputs();
assert_eq!(inputs.len(), 1);
assert!(matches!(&inputs[0], InputKind::UserInput { graph_name } if graph_name == "x"));
let outputs = parsed.output_names();
assert_eq!(outputs, vec!["add"]);
}
// Fixture: an MLP with three graph nodes, three parameters and one user input.
#[test]
fn test_parse_mlp_pt2() {
let path = "/tmp/luminal_mlp.pt2";
if !std::path::Path::new(path).exists() {
eprintln!("Skipping: {path} not found");
return;
}
let parsed = parse_pt2(path).unwrap();
assert_eq!(parsed.program.graph_module.graph.nodes.len(), 3);
let inputs = parsed.classify_inputs();
let params: Vec<_> = inputs
.iter()
.filter(|i| matches!(i, InputKind::Parameter { .. }))
.collect();
let user_inputs: Vec<_> = inputs
.iter()
.filter(|i| matches!(i, InputKind::UserInput { .. }))
.collect();
assert_eq!(params.len(), 3); // fc1.weight, fc2.weight, fc2.bias
assert_eq!(user_inputs.len(), 1);
}
// Fixture: a model exported with one dynamic dimension, which must map to 'a'.
#[test]
fn test_parse_dynamic_pt2() {
let path = "/tmp/luminal_dyn.pt2";
if !std::path::Path::new(path).exists() {
eprintln!("Skipping: {path} not found");
return;
}
let parsed = parse_pt2(path).unwrap();
let sym_map = parsed.build_sym_dim_map();
// Should have one symbolic dim (s77)
assert_eq!(sym_map.sym_to_char.len(), 1);
assert!(sym_map.sym_to_char.contains_key("s77"));
assert_eq!(sym_map.sym_to_char["s77"], 'a');
}
}

View File

@@ -1,390 +0,0 @@
//! PT2 serialized model JSON schema types (torch 2.10+ format).
//!
//! The .pt2 ZIP archive contains `{name}/models/model.json` with this structure.
//! We only model the subset needed for graph translation.
use serde::Deserialize;
use std::collections::HashMap;
/// Top-level object deserialized from `model.json`.
#[derive(Debug, Deserialize)]
pub struct ExportedProgram {
/// The graph together with its input signature.
pub graph_module: GraphModule,
/// Per-symbol range constraints; an empty map when the key is absent.
#[serde(default)]
pub range_constraints: HashMap<String, RangeConstraint>,
}
/// Range constraint attached to one symbol. Only the lower bound is modelled.
// NOTE(review): `min_val` is a required i64 here, so a JSON `null` or a missing
// key would fail deserialization — confirm the exporter always emits an integer.
#[derive(Debug, Clone, Deserialize)]
pub struct RangeConstraint {
pub min_val: i64,
}
/// A graph plus the signature describing what each graph input represents.
#[derive(Debug, Deserialize)]
pub struct GraphModule {
pub graph: Graph,
pub signature: Signature,
}
/// A serialized graph: its inputs, outputs, nodes, and per-tensor metadata.
/// Also used for subgraphs embedded in higher-order ops (see `SubGraph`).
#[derive(Debug, Clone, Deserialize)]
pub struct Graph {
pub inputs: Vec<TensorRef>,
pub outputs: Vec<TensorRef>,
pub nodes: Vec<Node>,
/// Metadata keyed by tensor name.
pub tensor_values: HashMap<String, TensorMeta>,
/// Symbolic-int values, kept as raw JSON; an empty map when the key is absent.
#[serde(default)]
pub sym_int_values: HashMap<String, serde_json::Value>,
}
/// A reference to a tensor by name (used in graph inputs/outputs).
/// Single-output nodes use `as_tensor`, multi-output nodes (split, topk) use `as_tensors`.
/// Both fields are optional, so either key may be missing from the JSON object.
#[derive(Debug, Clone, Deserialize)]
pub struct TensorRef {
pub as_tensor: Option<TensorName>,
#[serde(default)]
pub as_tensors: Option<Vec<TensorName>>,
}
/// The `{"name": ...}` object that identifies a tensor.
#[derive(Debug, Clone, Deserialize)]
pub struct TensorName {
pub name: String,
}
/// One operation in the graph.
#[derive(Debug, Clone, Deserialize)]
pub struct Node {
/// Operator identifier, e.g. "torch.ops.aten.add.Tensor".
pub target: String,
pub inputs: Vec<NodeInput>,
pub outputs: Vec<TensorRef>,
}
/// A single named argument passed to a node.
#[derive(Debug, Clone, Deserialize)]
pub struct NodeInput {
pub name: String,
pub arg: Argument,
/// 1 = positional, 2 = keyword (not formally documented, but observed).
/// Defaults to 0 when the key is absent.
#[serde(default)]
pub kind: u32,
}
/// A node argument — one of several typed variants.
/// ORDER MATTERS for #[serde(untagged)]: more specific variants must come first.
///
/// With an untagged enum serde tries the variants top to bottom and keeps the
/// first one that deserializes, so `Other` must stay last: as a raw JSON value
/// it accepts anything.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum Argument {
Tensor(TensorArg),
Int(IntArg),
Float(FloatArg),
Bool(BoolArg),
Ints(IntsArg),
SymInts(SymIntsArg),
SymInt(SymIntArg),
Expr(ExprArg),
#[allow(dead_code)]
ScalarType(ScalarTypeArg),
Tensors(TensorsArg),
OptionalTensors(OptionalTensorsArg),
Graph(GraphArg),
/// Fallback for anything we don't handle (Floats, Str, Layout,
/// OptionalTensor, None, Device, etc.)
#[allow(dead_code)]
Other(serde_json::Value),
}
// Payload structs for `Argument`. Each has a single field whose name is the
// JSON key that identifies the variant during untagged deserialization.

/// `{"as_tensor": {"name": ...}}` — a reference to a tensor.
#[derive(Debug, Clone, Deserialize)]
pub struct TensorArg {
pub as_tensor: TensorName,
}
/// `{"as_int": ...}` — a concrete integer.
#[derive(Debug, Clone, Deserialize)]
pub struct IntArg {
pub as_int: i64,
}
/// `{"as_float": ...}` — a floating-point scalar.
#[derive(Debug, Clone, Deserialize)]
pub struct FloatArg {
pub as_float: f64,
}
/// `{"as_bool": ...}` — a boolean.
#[derive(Debug, Clone, Deserialize)]
pub struct BoolArg {
pub as_bool: bool,
}
/// `{"as_ints": [...]}` — a list of concrete integers.
#[derive(Debug, Clone, Deserialize)]
pub struct IntsArg {
pub as_ints: Vec<i64>,
}
/// An entry in an optional_tensors list — either a tensor ref or None.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum OptionalTensorEntry {
Tensor(TensorArg),
#[allow(dead_code)] // NoneArg needed as serde discriminant for untagged enum
None(NoneArg),
}
/// `{"as_optional_tensors": [...]}` — a list whose entries may be absent tensors.
#[derive(Debug, Clone, Deserialize)]
pub struct OptionalTensorsArg {
pub as_optional_tensors: Vec<OptionalTensorEntry>,
}
/// `{"as_sym_ints": [...]}` — a list mixing concrete and symbolic integers.
#[derive(Debug, Clone, Deserialize)]
pub struct SymIntsArg {
pub as_sym_ints: Vec<SymIntEntry>,
}
/// An entry in a sym_ints list — either a concrete int or a symbolic name reference.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum SymIntEntry {
Int(IntArg),
Name(SymIntValue),
}
/// `{"as_sym_int": {"as_name": ...}}` — a single symbolic integer.
#[derive(Debug, Clone, Deserialize)]
pub struct SymIntArg {
pub as_sym_int: SymIntValue,
}
/// `{"as_name": ...}` — the name of a symbolic integer.
#[derive(Debug, Clone, Deserialize)]
pub struct SymIntValue {
pub as_name: String,
}
/// `{"as_expr": {...}}` — a symbolic expression.
#[derive(Debug, Clone, Deserialize)]
pub struct ExprArg {
pub as_expr: ExprValue,
}
/// A symbolic expression in string form, with an optional hint argument.
#[derive(Debug, Clone, Deserialize)]
pub struct ExprValue {
pub expr_str: String,
/// Optional hint; boxed because `Argument` contains `ExprValue` recursively.
pub hint: Option<Box<Argument>>,
}
/// `{"as_none": ...}` — an explicit "no value" marker.
/// `deny_unknown_fields` makes this reject objects that carry any other key.
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct NoneArg {
#[allow(dead_code)] // Serde discriminating key for OptionalTensorEntry untagged enum
pub as_none: serde_json::Value,
}
/// `{"as_scalar_type": ...}` — a dtype given as its integer code.
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)]
pub struct ScalarTypeArg {
pub as_scalar_type: u32,
}
/// `{"as_tensors": [...]}` — a list of tensor references.
#[derive(Debug, Clone, Deserialize)]
pub struct TensorsArg {
pub as_tensors: Vec<TensorName>,
}
/// `{"as_graph": {...}}` — an embedded subgraph.
#[derive(Debug, Clone, Deserialize)]
pub struct GraphArg {
pub as_graph: SubGraph,
}
/// A subgraph embedded in a higher-order op (e.g. wrap_with_set_grad_enabled).
#[derive(Debug, Clone, Deserialize)]
pub struct SubGraph {
pub graph: Graph,
}
/// Typed accessors: each returns `Some` only when `self` is the matching variant.
impl Argument {
    /// Tensor name, if this is a `Tensor` argument.
    pub fn as_tensor_name(&self) -> Option<&str> {
        if let Argument::Tensor(arg) = self {
            Some(arg.as_tensor.name.as_str())
        } else {
            None
        }
    }

    /// Integer value, if this is an `Int` argument.
    pub fn as_int(&self) -> Option<i64> {
        if let Argument::Int(arg) = self {
            Some(arg.as_int)
        } else {
            None
        }
    }

    /// Float value, if this is a `Float` argument.
    pub fn as_float(&self) -> Option<f64> {
        if let Argument::Float(arg) = self {
            Some(arg.as_float)
        } else {
            None
        }
    }

    /// Boolean value, if this is a `Bool` argument.
    pub fn as_bool(&self) -> Option<bool> {
        if let Argument::Bool(arg) = self {
            Some(arg.as_bool)
        } else {
            None
        }
    }

    /// Integer list, if this is an `Ints` argument.
    pub fn as_ints(&self) -> Option<&[i64]> {
        if let Argument::Ints(arg) = self {
            Some(arg.as_ints.as_slice())
        } else {
            None
        }
    }

    /// Dtype code, if this is a `ScalarType` argument.
    #[allow(dead_code)]
    pub fn as_scalar_type(&self) -> Option<u32> {
        if let Argument::ScalarType(arg) = self {
            Some(arg.as_scalar_type)
        } else {
            None
        }
    }

    /// Tensor name list, if this is a `Tensors` argument.
    pub fn as_tensors(&self) -> Option<&[TensorName]> {
        if let Argument::Tensors(arg) = self {
            Some(arg.as_tensors.as_slice())
        } else {
            None
        }
    }

    /// Embedded subgraph, if this is a `Graph` argument.
    pub fn as_graph(&self) -> Option<&SubGraph> {
        if let Argument::Graph(arg) = self {
            Some(&arg.as_graph)
        } else {
            None
        }
    }

    /// Symbol name, if this is a `SymInt` argument.
    pub fn as_sym_int_name(&self) -> Option<&str> {
        if let Argument::SymInt(arg) = self {
            Some(arg.as_sym_int.as_name.as_str())
        } else {
            None
        }
    }

    /// Mixed concrete/symbolic integer list, if this is a `SymInts` argument.
    pub fn as_sym_ints(&self) -> Option<&[SymIntEntry]> {
        if let Argument::SymInts(arg) = self {
            Some(arg.as_sym_ints.as_slice())
        } else {
            None
        }
    }

    /// Optional-tensor list, if this is an `OptionalTensors` argument.
    pub fn as_optional_tensors(&self) -> Option<&[OptionalTensorEntry]> {
        if let Argument::OptionalTensors(arg) = self {
            Some(arg.as_optional_tensors.as_slice())
        } else {
            None
        }
    }
}
/// Tensor metadata: dtype code and per-dimension sizes.
/// Other keys present in the JSON (for example strides) are not modelled here.
#[derive(Debug, Clone, Deserialize)]
pub struct TensorMeta {
/// Torch dtype as its integer code.
pub dtype: u32,
pub sizes: Vec<DimSize>,
}
/// A dimension size — either a concrete integer or a symbolic expression.
/// Untagged: an object with `as_int` is tried first, then one with `as_expr`.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum DimSize {
Int(DimInt),
Expr(DimExpr),
}
/// `{"as_int": ...}` — a concrete dimension size.
#[derive(Debug, Clone, Deserialize)]
pub struct DimInt {
pub as_int: i64,
}
/// `{"as_expr": {...}}` — a symbolic dimension size.
#[derive(Debug, Clone, Deserialize)]
pub struct DimExpr {
pub as_expr: ExprValue,
}
impl DimSize {
    /// The raw expression string of a symbolic dimension; `None` for concrete sizes.
    pub fn symbol_name(&self) -> Option<&str> {
        if let DimSize::Expr(dim) = self {
            Some(dim.as_expr.expr_str.as_str())
        } else {
            None
        }
    }

    /// A concrete value for this dimension: the size itself when concrete, otherwise
    /// the expression's integer hint (if a hint exists and is an `Int` argument).
    pub fn hint(&self) -> Option<i64> {
        match self {
            DimSize::Int(dim) => Some(dim.as_int),
            DimSize::Expr(dim) => {
                let hint = dim.as_expr.hint.as_deref()?;
                hint.as_int()
            }
        }
    }
}
/// Signature describing which inputs are parameters vs user inputs.
#[derive(Debug, Deserialize)]
pub struct Signature {
pub input_specs: Vec<InputSpec>,
}
/// An input spec. Deserialized as an untagged enum: the variant is chosen by
/// which top-level JSON key (`parameter`, `buffer`, `tensor_constant`,
/// `user_input`) the object carries, trying the variants in declaration order.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum InputSpec {
Parameter(ParameterInput),
Buffer(BufferInput),
TensorConstant(TensorConstantInput),
UserInput(UserInputSpec),
#[allow(dead_code)] // Serde catch-all for untagged enum
Other(serde_json::Value),
}
/// `{"parameter": {...}}`
#[derive(Debug, Deserialize)]
pub struct ParameterInput {
pub parameter: ParameterDetail,
}
/// Graph argument plus the parameter's name as recorded in the signature.
#[derive(Debug, Deserialize)]
pub struct ParameterDetail {
pub arg: ParameterArg,
pub parameter_name: String,
}
/// The `{"name": ...}` graph argument shared by parameter, buffer and
/// tensor-constant specs.
#[derive(Debug, Deserialize)]
pub struct ParameterArg {
pub name: String,
}
/// `{"buffer": {...}}`
#[derive(Debug, Deserialize)]
pub struct BufferInput {
pub buffer: BufferDetail,
}
/// Graph argument plus the buffer's name as recorded in the signature.
#[derive(Debug, Deserialize)]
pub struct BufferDetail {
pub arg: ParameterArg,
pub buffer_name: String,
}
/// `{"tensor_constant": {...}}`
#[derive(Debug, Deserialize)]
pub struct TensorConstantInput {
pub tensor_constant: TensorConstantDetail,
}
/// Graph argument plus the constant's name as recorded in the signature.
#[derive(Debug, Deserialize)]
pub struct TensorConstantDetail {
pub arg: ParameterArg,
pub tensor_constant_name: String,
}
/// `{"user_input": {...}}`
#[derive(Debug, Deserialize)]
pub struct UserInputSpec {
pub user_input: UserInputDetail,
}
/// A user input's argument; it is a full `Argument`, so it need not be a tensor.
#[derive(Debug, Deserialize)]
pub struct UserInputDetail {
pub arg: Argument,
}
/// Weights configuration from model_weights_config.json.
// NOTE(review): the parser deserializes `model_constants_config.json` into this
// same type, so both files presumably share this layout — confirm.
#[derive(Debug, Deserialize)]
pub struct WeightsConfig {
/// Entries keyed by tensor name.
pub config: HashMap<String, WeightEntry>,
}
/// Where one tensor's bytes live inside the archive, and its metadata.
#[derive(Debug, Deserialize)]
pub struct WeightEntry {
/// File name of the tensor's data, relative to its data directory in the archive.
pub path_name: String,
pub tensor_meta: TensorMeta,
}

View File

@@ -1,209 +0,0 @@
use luminal::prelude::*;
/// Binary operation type.
/// Selects which element-wise arithmetic operator a translated node applies.
#[derive(Clone, Copy)]
pub enum BinaryOp {
Add,
Mul,
Sub,
Div,
}
/// Reduction operation type.
/// Selects which reduction a translated node applies.
#[derive(Clone, Copy)]
pub enum ReductionOp {
Sum,
Mean,
Max,
Min,
Prod,
}
/// Convert a possibly negative axis index into a non-negative one.
///
/// Negative `dim` counts from the end (`-1` is the last of `ndim` axes);
/// non-negative `dim` is returned as-is. No range checking is performed.
pub fn normalize_dim(dim: i64, ndim: usize) -> usize {
    let resolved = if dim >= 0 { dim } else { dim + ndim as i64 };
    resolved as usize
}
/// Make two tensors shape-compatible using NumPy-style broadcasting.
///
/// The lower-rank operand is left-padded with new leading axes until the ranks
/// match. Then, for each axis where the sizes differ, the side whose size is a
/// concrete 1 takes the other side's size with a stride of 0. Axes that differ
/// with neither side equal to 1 are left untouched.
pub fn broadcast_binary(mut a: GraphTensor, mut b: GraphTensor) -> (GraphTensor, GraphTensor) {
    let (rank_a, rank_b) = (a.shape.len(), b.shape.len());

    // At most one of these ranges is non-empty.
    for _ in rank_a..rank_b {
        a = a.unsqueeze(0);
    }
    for _ in rank_b..rank_a {
        b = b.unsqueeze(0);
    }

    for axis in 0..a.shape.len() {
        let (dim_a, dim_b) = (a.shape.dims[axis], b.shape.dims[axis]);
        if dim_a == dim_b {
            continue;
        }
        match (dim_a.to_usize(), dim_b.to_usize()) {
            (Some(1), _) => {
                // Stretch `a` along this axis by repeating its single element.
                a.shape.dims[axis] = dim_b;
                a.shape.strides[axis] = Expression::from(0usize);
            }
            (_, Some(1)) => {
                b.shape.dims[axis] = dim_a;
                b.shape.strides[axis] = Expression::from(0usize);
            }
            _ => {}
        }
    }

    (a, b)
}
/// Ensure two tensors have the same dtype, casting Int->F32 or Bool->F32 if needed.
pub fn ensure_same_dtype(a: GraphTensor, b: GraphTensor) -> (GraphTensor, GraphTensor) {
if a.dtype == b.dtype {
return (a, b);
}
let target = match (a.dtype, b.dtype) {
(DType::F32, _) | (_, DType::F32) => DType::F32,
(DType::Int, _) | (_, DType::Int) => DType::Int,
_ => DType::F32,
};
let a = if a.dtype != target { a.cast(target) } else { a };
let b = if b.dtype != target { b.cast(target) } else { b };
(a, b)
}
/// Give `t` a brand-new shape tracker built from `shape`, without adding a graph node.
///
/// The tensor's id, graph reference and dtype are carried over unchanged; only the
/// view changes.
pub fn reshape_tensor(t: GraphTensor, shape: Vec<Expression>) -> GraphTensor {
    GraphTensor {
        shape: ShapeTracker::new(shape),
        ..t
    }
}
/// Turn a reshape target that may contain a `-1` wildcard into concrete/symbolic dims.
///
/// Every entry other than `-1` is converted directly. The wildcard entry becomes the
/// product of `current_dims` divided by the product of the other target entries:
/// computed as a plain integer when all current dims are concrete, otherwise as a
/// symbolic expression. If several `-1` entries are present only the last is resolved.
pub fn resolve_neg1_dim(target: &[i64], current_dims: &[Expression]) -> Vec<Expression> {
    let mut wildcard: Option<usize> = None;
    let mut known_product: i64 = 1;
    let mut resolved: Vec<Expression> = Vec::with_capacity(target.len());

    for (pos, &extent) in target.iter().enumerate() {
        if extent == -1 {
            wildcard = Some(pos);
            // Placeholder, overwritten below.
            resolved.push(Expression::from(0usize));
        } else {
            known_product *= extent;
            resolved.push(Expression::from(extent as usize));
        }
    }

    let Some(pos) = wildcard else {
        return resolved;
    };

    // Track the element count both symbolically and, while possible, concretely.
    let mut symbolic_total = Expression::from(1usize);
    let mut concrete_total: Option<i64> = Some(1);
    for dim in current_dims {
        symbolic_total *= *dim;
        concrete_total = match (concrete_total, dim.to_usize()) {
            (Some(acc), Some(v)) => Some(acc * v as i64),
            _ => None,
        };
    }

    resolved[pos] = match concrete_total {
        Some(total) => Expression::from((total / known_product) as usize),
        None => symbolic_total / Expression::from(known_product as usize),
    };
    resolved
}
/// Resolve -1 in a reshape target shape that contains Expression values.
///
/// Finds the first target entry equal to the expression `-1` and replaces it with
/// the size implied by `current_dims`. Both shapes are split into a concrete
/// integer product and a list of symbolic factors; symbolic factors that appear
/// in the target are cancelled one-for-one against those of the input, and the
/// wildcard becomes `(input_concrete / target_concrete)` times whatever symbolic
/// input factors remain. If no entry is `-1`, the target is returned unchanged.
pub fn resolve_neg1_dim_exprs(
target: &[Expression],
current_dims: &[Expression],
) -> Vec<Expression> {
let neg1_expr = Expression::from(-1i32);
let neg1_idx = target.iter().position(|e| *e == neg1_expr);
if let Some(idx) = neg1_idx {
let mut result = target.to_vec();
// Split the input shape into a concrete product and leftover symbolic factors.
let mut input_concrete: i64 = 1;
let mut input_symbolic: Vec<Expression> = Vec::new();
for d in current_dims {
if let Some(v) = d.to_usize() {
input_concrete *= v as i64;
} else {
input_symbolic.push(*d);
}
}
// Same split for the target, skipping the wildcard position itself.
let mut target_concrete: i64 = 1;
let mut target_symbolic: Vec<Expression> = Vec::new();
for (i, e) in target.iter().enumerate() {
if i == idx {
continue;
}
if let Some(v) = e.to_usize() {
target_concrete *= v as i64;
} else {
target_symbolic.push(*e);
}
}
// Cancel each symbolic target factor against one equal input factor.
// Target factors with no equal input factor are silently ignored.
for ts in &target_symbolic {
if let Some(pos) = input_symbolic.iter().position(|is| is == ts) {
input_symbolic.remove(pos);
}
}
// NOTE(review): integer division — panics if the other target dims multiply
// to 0, and truncates if they do not divide the input's concrete product.
if input_symbolic.is_empty() {
result[idx] = Expression::from((input_concrete / target_concrete) as usize);
} else {
let mut expr = Expression::from((input_concrete / target_concrete) as usize);
for s in &input_symbolic {
expr *= *s;
}
result[idx] = expr;
}
result
} else {
target.to_vec()
}
}
/// Translate a PT2 dtype code into the luminal `DType` used to represent it.
///
/// PT2 numbering: 1=uint8, 2=int8, 3=int16, 4=int32, 5=int64, 6=float16, 7=float32,
/// 8=float64, 12=bool, 13=bfloat16. All integer widths collapse to `Int`, float64 is
/// narrowed to `F32`, and any unlisted code also maps to `F32`.
pub fn torch_dtype_int_to_luminal(dtype: u32) -> DType {
    match dtype {
        // uint8, int8, int16, int32, int64
        1 | 2 | 3 | 4 | 5 => DType::Int,
        6 => DType::F16,
        // float32, and float64 (no F64 in luminal)
        7 | 8 => DType::F32,
        12 => DType::Bool,
        13 => DType::Bf16,
        _ => DType::F32,
    }
}

View File

@@ -1,99 +0,0 @@
use luminal::hlir::NativeData;
use luminal::prelude::*;
#[cfg(feature = "cuda")]
use luminal_cuda_lite::cudarc::driver::{CudaContext, CudaStream};
#[cfg(feature = "cuda")]
use luminal_cuda_lite::runtime::CudaRuntime;
use rustc_hash::FxHashMap;
#[cfg(feature = "cuda")]
use std::sync::Arc;
use crate::typed_data::TypedData;
/// Enum wrapper for runtime backends allowing runtime selection.
///
/// The `Cuda` variant only exists when the crate is built with the `cuda`
/// feature; its runtime is boxed.
pub enum RuntimeBackend {
Native(NativeRuntime),
#[cfg(feature = "cuda")]
Cuda(Box<CudaRuntime>),
}
// Every method dispatches to whichever runtime is active; the CUDA arms are
// compiled only with the `cuda` feature.
impl RuntimeBackend {
/// Set input data for a tensor node (dtype-aware).
///
/// The native runtime receives the data converted to `NativeData`; the CUDA
/// runtime receives the underlying byte buffer as-is.
pub fn set_data(&mut self, node: NodeIndex, data: TypedData) {
match self {
RuntimeBackend::Native(rt) => {
let native: NativeData = data.into();
rt.set_data(node, native);
}
#[cfg(feature = "cuda")]
RuntimeBackend::Cuda(rt) => {
// CUDA runtime stores raw bytes — just upload directly
rt.set_data(node, data.bytes);
}
}
}
/// Set input data from a Vec<f32> (convenience for backward compatibility).
pub fn set_data_f32(&mut self, node: NodeIndex, data: Vec<f32>) {
match self {
RuntimeBackend::Native(rt) => rt.set_data(node, data),
#[cfg(feature = "cuda")]
RuntimeBackend::Cuda(rt) => rt.set_data(node, data),
}
}
/// Execute the compiled graph.
///
/// `dyn_map` supplies a concrete value for each dynamic-dimension variable.
pub fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) {
match self {
RuntimeBackend::Native(rt) => rt.execute(dyn_map),
#[cfg(feature = "cuda")]
RuntimeBackend::Cuda(rt) => rt.execute(dyn_map),
}
}
/// Get output data as f32 from a tensor node.
///
/// Always returns an owned vector; the native result is copied with `to_vec`.
pub fn get_f32(&self, node: NodeIndex) -> Vec<f32> {
match self {
RuntimeBackend::Native(rt) => rt.get_f32(node).to_vec(),
#[cfg(feature = "cuda")]
RuntimeBackend::Cuda(rt) => rt.get_f32(node),
}
}
/// Get the name of the active backend ("native" or "cuda").
pub fn name(&self) -> &'static str {
match self {
RuntimeBackend::Native(_) => "native",
#[cfg(feature = "cuda")]
RuntimeBackend::Cuda(_) => "cuda",
}
}
}
// ============================================================================
// Two-phase initialization for CUDA (required because profiling executes graph)
// ============================================================================
/// First phase of CUDA setup: create the device context, build the graph's search
/// space, and return a fresh (not yet searched) runtime along with its stream.
///
/// Intended to be paired with `finalize_cuda`:
/// 1. Call `prepare_cuda` to obtain the runtime.
/// 2. Load inputs with `rt.set_data(node_id, data)`.
/// 3. Call `finalize_cuda`, which runs the search with that data available.
///
/// Returns an error string if the CUDA context for device 0 cannot be created.
#[cfg(feature = "cuda")]
pub fn prepare_cuda(context: &mut Graph) -> Result<(CudaRuntime, Arc<CudaStream>), String> {
    let cuda_ctx = match CudaContext::new(0) {
        Ok(ctx) => ctx,
        Err(e) => return Err(format!("Failed to init CUDA context: {}", e)),
    };
    let stream = cuda_ctx.default_stream();

    context.build_search_space::<CudaRuntime>();

    let runtime = CudaRuntime::initialize(stream.clone());
    Ok((runtime, stream))
}
/// Second phase of CUDA setup: run the graph search on a runtime that already has
/// its data loaded, and wrap the result as a `RuntimeBackend::Cuda`.
///
/// The search is invoked with a fixed limit argument of 10.
#[cfg(feature = "cuda")]
pub fn finalize_cuda(context: &mut Graph, rt: CudaRuntime) -> RuntimeBackend {
    RuntimeBackend::Cuda(Box::new(context.search(rt, 10)))
}

View File

@@ -1,57 +0,0 @@
use anyhow::Result;
use luminal::prelude::*;
use crate::pt2_schema::*;
use crate::pt2_util::*;
use super::Translator;
impl<'a> Translator<'a> {
pub(crate) fn translate_binary_op(&mut self, node: &Node, op: BinaryOp) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let arg1 = &node.inputs[1].arg;
if let Some(name) = arg1.as_tensor_name() {
let b = self.get_tensor(name)?;
let (a, b) = ensure_same_dtype(a, b);
let (a, b) = broadcast_binary(a, b);
Ok(match op {
BinaryOp::Add => a + b,
BinaryOp::Mul => a * b,
BinaryOp::Sub => a - b,
BinaryOp::Div => a / b,
})
} else {
let val = self.get_float_arg(node, 1)? as f32;
Ok(self.apply_scalar_op(a, val, op))
}
}
pub(crate) fn translate_binary_scalar_op(
&mut self,
node: &Node,
op: BinaryOp,
) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let val = self.get_float_arg(node, 1)? as f32;
Ok(self.apply_scalar_op(a, val, op))
}
pub(crate) fn apply_scalar_op(
&mut self,
a: GraphTensor,
val: f32,
op: BinaryOp,
) -> GraphTensor {
let scalar = self
.graph
.constant_float(val)
.cast(a.dtype)
.expand_rhs(a.shape);
match op {
BinaryOp::Add => a + scalar,
BinaryOp::Mul => a * scalar,
BinaryOp::Sub => a - scalar,
BinaryOp::Div => a / scalar,
}
}
}

View File

@@ -1,407 +0,0 @@
use anyhow::Result;
use luminal::prelude::*;
use crate::pt2_schema::*;
use super::Translator;
// Positional argument indices for the convolution ops handled by `translate_conv`.
// The first six positions are read the same way for every supported target.
const CONV_INPUT_ARG: usize = 0;
const CONV_WEIGHT_ARG: usize = 1;
const CONV_BIAS_ARG: usize = 2;
const CONV_STRIDE_ARG: usize = 3;
const CONV_PADDING_ARG: usize = 4;
const CONV_DILATION_ARG: usize = 5;
// Position 6 is `groups` for every target except aten.convolution.default ...
const CONV_GROUPS_ARG: usize = 6;
// ... where position 6 is `transposed`, followed by `output_padding` and `groups`.
// The shared index 6 is therefore intentional, not a typo.
const CONVOLUTION_TRANSPOSED_ARG: usize = 6;
const CONVOLUTION_OUTPUT_PADDING_ARG: usize = 7;
const CONVOLUTION_GROUPS_ARG: usize = 8;
impl<'a> Translator<'a> {
    /// Translate aten.conv{1,2,3}d.default and aten.convolution.default.
    ///
    /// The PT2 export may omit defaulted trailing arguments entirely. In practice this means
    /// conv{N}d.default can show up as just `(input, weight)` for the no-bias, stride=1,
    /// padding=0, dilation=1, groups=1 case.
    ///
    /// Input is expected as `[N, C_in, spatial...]` and weight as
    /// `[C_out, C_in / groups, kernel...]`. `stride`, `padding` and `dilation` may each be
    /// missing (defaults 1 / 0 / 1), a single value (applied to every spatial axis), or one
    /// value per spatial axis.
    ///
    /// # Errors
    /// Fails when the input rank is below 3, the weight rank differs from the input rank,
    /// a per-axis parameter has the wrong length or an invalid value, channel or kernel
    /// dims are symbolic, the group configuration is inconsistent, or a transposed
    /// convolution / non-zero output padding is requested.
    pub(crate) fn translate_conv(&mut self, node: &Node) -> Result<GraphTensor> {
        let input = self.get_input_tensor(node, CONV_INPUT_ARG)?;
        let weight = self.get_input_tensor(node, CONV_WEIGHT_ARG)?;
        let bias = self.get_input_tensor(node, CONV_BIAS_ARG).ok();

        let x_dims = input.dims();
        let w_dims = weight.dims();
        let rank = x_dims.len();
        // Validate ranks up front: `rank - 2` would underflow and the `w_dims[..]`
        // accesses below would panic on malformed tensors.
        anyhow::ensure!(
            rank >= 3,
            "conv: input must be [N, C, spatial...] (rank >= 3), got rank {rank}"
        );
        anyhow::ensure!(
            w_dims.len() == rank,
            "conv: weight rank {} does not match input rank {rank}",
            w_dims.len()
        );
        let spatial = rank - 2;

        // A missing argument yields an empty list, which the helper turns into the default.
        let stride = broadcast_conv_param(
            self.get_ints_arg(node, CONV_STRIDE_ARG).unwrap_or_default(),
            spatial,
            1,
        );
        let padding = broadcast_conv_param(
            self.get_ints_arg(node, CONV_PADDING_ARG).unwrap_or_default(),
            spatial,
            0,
        );
        let dilation = broadcast_conv_param(
            self.get_ints_arg(node, CONV_DILATION_ARG).unwrap_or_default(),
            spatial,
            1,
        );

        let groups = if node.target == "torch.ops.aten.convolution.default" {
            let transposed = self
                .get_bool_arg(node, CONVOLUTION_TRANSPOSED_ARG)
                .unwrap_or(false);
            anyhow::ensure!(
                !transposed,
                "conv: ConvTranspose / transposed=true is not supported yet"
            );
            let output_padding = self
                .get_ints_arg(node, CONVOLUTION_OUTPUT_PADDING_ARG)
                .unwrap_or_else(|_| vec![0; spatial]);
            anyhow::ensure!(
                output_padding.iter().all(|&v| v == 0),
                "conv: output_padding is not supported for non-transposed convolution"
            );
            self.get_int_arg(node, CONVOLUTION_GROUPS_ARG).unwrap_or(1) as usize
        } else {
            self.get_int_arg(node, CONV_GROUPS_ARG).unwrap_or(1) as usize
        };

        let ch_out = w_dims[0]
            .to_usize()
            .ok_or_else(|| anyhow::anyhow!("conv: weight C_out must be concrete"))?;
        let ch_in = x_dims[1]
            .to_usize()
            .ok_or_else(|| anyhow::anyhow!("conv: input C_in must be concrete"))?;

        anyhow::ensure!(
            stride.len() == spatial && padding.len() == spatial && dilation.len() == spatial,
            "conv: stride/padding/dilation rank must match spatial rank {spatial}"
        );
        // Reject values that would wrap around when converted to usize below.
        anyhow::ensure!(
            stride.iter().all(|&v| v >= 1)
                && dilation.iter().all(|&v| v >= 1)
                && padding.iter().all(|&v| v >= 0),
            "conv: stride and dilation must be >= 1 and padding must be >= 0"
        );
        anyhow::ensure!(
            groups > 0 && ch_in % groups == 0 && ch_out % groups == 0,
            "conv: invalid group configuration (C_in={ch_in}, C_out={ch_out}, groups={groups})"
        );

        let ch_per_group = ch_in / groups;
        let kernel_shape: Vec<usize> = w_dims[2..]
            .iter()
            .map(|d| {
                d.to_usize()
                    .ok_or_else(|| anyhow::anyhow!("conv: kernel dims must be concrete"))
            })
            .collect::<Result<_>>()?;
        let kernel_product: usize = kernel_shape.iter().product();

        // ATen uses symmetric padding (same begin/end)
        let stride_u: Vec<usize> = stride.iter().map(|&v| v as usize).collect();
        let padding_u: Vec<usize> = padding.iter().map(|&v| v as usize).collect();
        let dilation_u: Vec<usize> = dilation.iter().map(|&v| v as usize).collect();

        let mut out = if groups > 1 {
            let group_out = ch_out / groups;
            if ch_per_group == 1 {
                // Depthwise (including channel multiplier > 1): avoid per-channel slicing.
                depthwise_conv(
                    input,
                    weight,
                    &kernel_shape,
                    &stride_u,
                    &dilation_u,
                    &padding_u,
                    &padding_u,
                    ch_in,
                    group_out,
                    kernel_product,
                    spatial,
                )
            } else {
                // General grouped: pre-pad full input then slice per group
                let padded_input = {
                    let mut pad_spec: Vec<(Expression, Expression)> =
                        vec![(0.into(), 0.into()); 2 + spatial];
                    for i in 0..spatial {
                        pad_spec[2 + i] = (padding_u[i].into(), padding_u[i].into());
                    }
                    input.pad(pad_spec, 0.0)
                };
                let no_pad = vec![0usize; spatial];
                let mut group_outputs = Vec::with_capacity(groups);
                for g in 0..groups {
                    let x_g = slice_channel_group(padded_input, g, ch_per_group, spatial);
                    let w_g =
                        slice_weight_group(weight, g, group_out, ch_per_group * kernel_product);
                    group_outputs.push(conv_unfold(
                        x_g,
                        w_g,
                        &kernel_shape,
                        &stride_u,
                        &dilation_u,
                        &no_pad,
                        &no_pad,
                        ch_per_group,
                        group_out,
                        spatial,
                    ));
                }
                // Stitch the per-group results back together along the channel axis.
                let mut result = group_outputs[0];
                for g_out in &group_outputs[1..] {
                    result = result.concat_along(*g_out, 1);
                }
                result
            }
        } else {
            // Single group: view the weight as [C_out, C_in * kernel_product].
            let mut w_flat = weight;
            w_flat.shape = ShapeTracker::new_with_element_bits(
                vec![ch_out, ch_in * kernel_product],
                weight.dtype.bits(),
            );
            conv_unfold(
                input,
                w_flat,
                &kernel_shape,
                &stride_u,
                &dilation_u,
                &padding_u,
                &padding_u,
                ch_in,
                ch_out,
                spatial,
            )
        };

        if let Some(b) = bias {
            // Expand the per-channel bias across the batch and spatial axes.
            let out_dims = out.dims();
            let mut b_expanded = b.expand_dim(0, 1);
            for i in 0..spatial {
                b_expanded = b_expanded.expand_dim(2 + i, out_dims[2 + i]);
            }
            out += b_expanded;
        }
        Ok(out)
    }
}

/// Expand a per-axis convolution parameter to one entry per spatial axis.
///
/// An empty list becomes `default` repeated `spatial` times and a single value is
/// repeated for every axis. Lists of any other length are returned unchanged so the
/// caller's length check can reject them.
fn broadcast_conv_param<T: Copy>(values: Vec<T>, spatial: usize, default: T) -> Vec<T> {
    match values.len() {
        0 => vec![default; spatial],
        1 => vec![values[0]; spatial],
        _ => values,
    }
}
/// Select the input channels that belong to group `g`.
///
/// Keeps the whole batch axis and the first `spatial` spatial axes, and restricts
/// axis 1 to `[g * ch_per_group, (g + 1) * ch_per_group)`.
/// Caller must pre-pad `x` so no additional padding is applied to the slice.
fn slice_channel_group(
    x: GraphTensor,
    g: usize,
    ch_per_group: usize,
    spatial: usize,
) -> GraphTensor {
    let channel_lo = g * ch_per_group;
    let channel_hi = channel_lo + ch_per_group;
    let dims = x.dims();

    let mut ranges: Vec<(Expression, Expression)> = vec![
        (0.into(), dims[0]),
        (channel_lo.into(), channel_hi.into()),
    ];
    // Full extent on each spatial axis that actually exists.
    for axis in 2..(2 + spatial).min(dims.len()) {
        ranges.push((0.into(), dims[axis]));
    }
    x.slice(ranges)
}
/// Slice and flatten weight for one group.
fn slice_weight_group(
w: GraphTensor,
g: usize,
group_out: usize,
flat_inner: usize,
) -> GraphTensor {
let start = g * group_out;
let end = start + group_out;
let w_dims = w.dims();
let mut slices: Vec<(Expression, Expression)> = Vec::with_capacity(w_dims.len());
slices.push((start.into(), end.into()));
for dim in w_dims.iter().skip(1) {
slices.push((0.into(), *dim));
}
// Materialize through Add: binary op outputs are contiguous in Luminal, which makes the
// following flatten safe for the sliced weight buffer.
let w_sliced = w.slice(slices) + 0.0;
let mut w_flat = w_sliced;
w_flat.shape =
ShapeTracker::new_with_element_bits(vec![group_out, flat_inner], w_sliced.dtype.bits());
w_flat
}
/// Core unfold-based convolution for a single group.
///
/// `x`: [batch, ch_in, spatial...]
/// `w_flat`: [ch_out, ch_in * kernel_product] (already reshaped)
/// Returns: [batch, ch_out, out_spatial...]
///
/// Steps: optionally zero-pad the spatial axes, unfold into sliding windows,
/// rearrange so each output position holds one flattened patch, multiply the
/// patches by the transposed weight matrix, then restore the spatial axes and
/// move the channel axis back to position 1.
///
/// `kernel_shape`, `strides`, `dilations`, `pads_begin` and `pads_end` must each
/// hold at least `spatial` entries. `_ch_in` and `_ch_out` are unused.
#[allow(clippy::too_many_arguments)]
fn conv_unfold(
x: GraphTensor,
w_flat: GraphTensor,
kernel_shape: &[usize],
strides: &[usize],
dilations: &[usize],
pads_begin: &[usize],
pads_end: &[usize],
_ch_in: usize,
_ch_out: usize,
spatial: usize,
) -> GraphTensor {
let rank = 2 + spatial;
// Pad spatial dimensions (skip if all padding is zero)
let needs_pad = pads_begin.iter().any(|&p| p > 0) || pads_end.iter().any(|&p| p > 0);
let padded = if needs_pad {
let mut padding: Vec<(Expression, Expression)> = vec![(0.into(), 0.into()); rank];
for i in 0..spatial {
padding[2 + i] = (pads_begin[i].into(), pads_end[i].into());
}
x.pad(padding, 0.0)
} else {
x
};
// Build full-rank unfold parameters (1 for batch/channel, actual for spatial)
let mut kernel_full = vec![1usize; rank];
let mut stride_full = vec![1usize; rank];
let mut dilation_full = vec![1usize; rank];
kernel_full[2..(spatial + 2)].copy_from_slice(&kernel_shape[..spatial]);
stride_full[2..(spatial + 2)].copy_from_slice(&strides[..spatial]);
dilation_full[2..(spatial + 2)].copy_from_slice(&dilations[..spatial]);
let unfolded = padded.unfold(kernel_full, stride_full, dilation_full);
// Shape: [win_N, win_C, win_spatial..., k_N=1, k_C=1, k_spatial...]
// Permute to [N, win_spatial..., C_in, k_N, k_C, k_spatial...]
let mut perm: Vec<usize> = Vec::with_capacity(2 * rank);
perm.push(0);
perm.extend(2..2 + spatial);
perm.push(1);
perm.extend(rank..2 * rank);
let permuted = unfolded.permute(perm);
// Remember the output spatial extents before they are merged away below.
let output_spatial_dims: Vec<Expression> = permuted.dims()[1..1 + spatial].to_vec();
// Merge all channel+kernel dims into [N, spatial..., ch_in * kernel_product]
let mut patches = permuted;
let target = 2 + spatial;
while patches.dims().len() > target {
let last = patches.dims().len();
patches = patches.merge_dims(last - 2, last - 1);
}
// Merge spatial dims into one
for _ in 1..spatial {
patches = patches.merge_dims(1, 2);
}
// patches: [N, spatial_product, ch_in * kernel_product]
let mut out = patches.matmul(w_flat.permute((1, 0)));
// out: [N, spatial_product, ch_out]
// Restore spatial dimensions
for i in (1..spatial).rev() {
out = out.split_dims(1, output_spatial_dims[i]);
}
// Move ch_out from last to position 1: [N, ch_out, spatial...]
let mut final_order: Vec<usize> = Vec::with_capacity(2 + spatial);
final_order.push(0);
final_order.push(1 + spatial);
final_order.extend(1..1 + spatial);
out.permute(final_order)
}
/// Depthwise convolution: groups == in_channels, ch_per_group == 1.
///
/// Every channel is handled in one pass with an element-wise multiply followed by a sum over
/// the kernel axis. This avoids per-channel input slicing, which can cause index-expression
/// bugs in luminal.
///
/// out[n, c, oh, ow] = sum_k patches[n, c, oh, ow, k] * weight[c, k]
///
/// `w` is reinterpreted in place as `[ch, group_out, kernel_product]`; the result is laid out
/// as `[N, ch * group_out, out_spatial...]`.
#[allow(clippy::too_many_arguments)]
fn depthwise_conv(
    x: GraphTensor,
    w: GraphTensor, // [C * group_out, 1, *kernel]
    kernel_shape: &[usize],
    strides: &[usize],
    dilations: &[usize],
    pads_begin: &[usize],
    pads_end: &[usize],
    ch: usize,
    group_out: usize,
    kernel_product: usize,
    spatial: usize,
) -> GraphTensor {
    let full_rank = spatial + 2;

    // Zero-pad the spatial axes; skipped entirely when every pad amount is zero.
    let any_pad = pads_begin.iter().chain(pads_end.iter()).any(|&p| p > 0);
    let input = if any_pad {
        let mut pad_spec: Vec<(Expression, Expression)> = vec![(0.into(), 0.into()); full_rank];
        for (axis, slot) in pad_spec.iter_mut().skip(2).enumerate() {
            *slot = (pads_begin[axis].into(), pads_end[axis].into());
        }
        x.pad(pad_spec, 0.0)
    } else {
        x
    };

    // Unfold the whole padded [N, C, spatial...] input with a [1, 1, kernel...] window.
    let to_full_rank = |per_axis: &[usize]| -> Vec<usize> {
        let mut full = vec![1usize; full_rank];
        full[2..].copy_from_slice(&per_axis[..spatial]);
        full
    };
    let windows = input.unfold(
        to_full_rank(kernel_shape),
        to_full_rank(strides),
        to_full_rank(dilations),
    );

    // `windows` is [N, C, out_spatial..., 1, 1, kernel...]; the order below keeps
    // N, C and the window axes in place and lists every kernel axis afterwards.
    let window_order: Vec<usize> = [0, 1]
        .into_iter()
        .chain(2..full_rank)
        .chain(full_rank..2 * full_rank)
        .collect();
    let reordered = windows.permute(window_order);
    let out_extents: Vec<Expression> = reordered.dims()[2..2 + spatial].to_vec();

    // Collapse all kernel axes (including the unit k_N / k_C ones) into a single axis:
    // [N, C, out_spatial..., kernel_product].
    let collapsed_rank = 3 + spatial;
    let mut patches = reordered;
    while patches.dims().len() > collapsed_rank {
        let ndim = patches.dims().len();
        patches = patches.merge_dims(ndim - 2, ndim - 1);
    }
    // Collapse the spatial axes: [N, C, out_spatial_product, kernel_product].
    for _ in 1..spatial {
        patches = patches.merge_dims(2, 3);
    }

    // Reinterpret the weight [C * group_out, 1, *kernel] as [C, group_out, kernel_product].
    let mut weight = w;
    weight.shape =
        ShapeTracker::new_with_element_bits(vec![ch, group_out, kernel_product], w.dtype.bits());

    // Broadcast both operands to [N, C, group_out, out_spatial_product, kernel_product]
    // (the weight keeps a leading axis of size 1 in place of N).
    let patches_grouped = patches.expand_dim(2, group_out);
    let weight_broadcast = weight
        .expand_dim(0, 1)
        .expand_dim(3, patches_grouped.dims()[3]);

    // Multiply element-wise, reduce over the kernel axis, then fold group_out into the
    // channel axis: [N, C * group_out, out_spatial_product].
    let weighted = patches_grouped * weight_broadcast;
    let mut out = weighted.sum(vec![4]).merge_dims(1, 2);

    // Split the recorded spatial extents back off axis 2, last extent first.
    for &extent in out_extents.iter().skip(1).rev() {
        out = out.split_dims(2, extent);
    }
    // out: [N, C * group_out, out_spatial_0, ..., out_spatial_{s-1}]
    out
}

View File

@@ -1,400 +0,0 @@
use anyhow::{Result, bail};
use luminal::prelude::*;
use crate::pt2_schema::*;
use crate::pt2_util::*;
use super::Translator;
impl<'a> Translator<'a> {
    /// Translates one PT2 node into luminal graph operations.
    ///
    /// The result tensor is stored in `self.tensors` under the name of the node's first
    /// tensor output (or the first entry of its first tensor-list output).
    ///
    /// Special cases:
    /// * `_assert_tensor_metadata` / `_assert_scalar` are ignored.
    /// * `wrap_with_set_grad_enabled` is delegated to `translate_wrap_set_grad`.
    /// * Nodes without any tensor output are skipped.
    /// * `topk` stores its own outputs and returns early.
    ///
    /// # Errors
    /// Returns an error for unsupported ATen targets, or when any per-op translation or
    /// argument lookup fails.
    pub(crate) fn translate_node(&mut self, node: &Node) -> Result<()> {
        let target = &node.target;
        let output_name = node
            .outputs
            .first()
            .and_then(|o| {
                o.as_tensor.as_ref().map(|t| t.name.clone()).or_else(|| {
                    o.as_tensors
                        .as_ref()
                        .and_then(|ts| ts.first().map(|t| t.name.clone()))
                })
            })
            .unwrap_or_default();

        // No-output ops
        match target.as_str() {
            "torch.ops.aten._assert_tensor_metadata.default"
            | "torch.ops.aten._assert_scalar.default" => return Ok(()),
            "torch.ops.higher_order.wrap_with_set_grad_enabled" => {
                return self.translate_wrap_set_grad(node);
            }
            _ => {}
        }

        let has_tensor_output = node
            .outputs
            .iter()
            .any(|o| o.as_tensor.is_some() || o.as_tensors.is_some());
        if !has_tensor_output {
            return Ok(());
        }

        let result = match target.as_str() {
            // Binary ops
            // Note: rsub/rdiv are not handled here because torch.export decomposes them
            // into sub/div with swapped operands before emission.
            "torch.ops.aten.add.Tensor" => self.translate_binary_op(node, BinaryOp::Add)?,
            "torch.ops.aten.add.Scalar" => self.translate_binary_scalar_op(node, BinaryOp::Add)?,
            "torch.ops.aten.mul.Tensor" => self.translate_binary_op(node, BinaryOp::Mul)?,
            "torch.ops.aten.mul.Scalar" => self.translate_binary_scalar_op(node, BinaryOp::Mul)?,
            "torch.ops.aten.sub.Tensor" => self.translate_binary_op(node, BinaryOp::Sub)?,
            "torch.ops.aten.sub.Scalar" => self.translate_binary_scalar_op(node, BinaryOp::Sub)?,
            "torch.ops.aten.div.Tensor" => self.translate_binary_op(node, BinaryOp::Div)?,
            "torch.ops.aten.div.Scalar" => self.translate_binary_scalar_op(node, BinaryOp::Div)?,

            // Unary ops
            "torch.ops.aten.neg.default" => self.translate_unary_op(node, |a| a * (-1.0))?,
            "torch.ops.aten.exp.default" => self.translate_unary_op(node, |a| a.exp())?,
            "torch.ops.aten.sin.default" => self.translate_unary_op(node, |a| a.sin())?,
            "torch.ops.aten.cos.default" => self.translate_unary_op(node, |a| a.cos())?,
            "torch.ops.aten.sqrt.default" => self.translate_unary_op(node, |a| a.sqrt())?,
            "torch.ops.aten.rsqrt.default" => {
                self.translate_unary_op(node, |a| a.sqrt().reciprocal())?
            }
            "torch.ops.aten.reciprocal.default" => {
                self.translate_unary_op(node, |a| a.reciprocal())?
            }
            "torch.ops.aten.sigmoid.default" => self.translate_unary_op(node, |a| a.sigmoid())?,
            "torch.ops.aten.relu.default" => self.translate_unary_op(node, |a| a.relu())?,
            "torch.ops.aten.tanh.default" => self.translate_unary_op(node, |a| a.tanh())?,
            "torch.ops.aten.abs.default" => self.translate_unary_op(node, |a| a.abs())?,
            "torch.ops.aten.log.default" => self.translate_unary_op(node, |a| a.log())?,
            "torch.ops.aten.log2.default" => self.translate_unary_op(node, |a| a.log2())?,
            "torch.ops.aten.exp2.default" => self.translate_unary_op(node, |a| a.exp2())?,

            // Cast
            "torch.ops.aten._to_copy.default" => self.translate_to_copy(node)?,

            // No-op
            "torch.ops.aten.alias.default" => self.get_input_tensor(node, 0)?,

            // Shape ops
            "torch.ops.aten.view.default" => self.translate_reshape(node)?,
            "torch.ops.aten.permute.default" => self.translate_permute(node)?,
            "torch.ops.aten.unsqueeze.default" => {
                let a = self.get_input_tensor(node, 0)?;
                let dim = self.get_int_arg(node, 1)?;
                // The new axis may be placed at any of `rank + 1` positions.
                let dim = normalize_dim(dim, a.shape.len() + 1);
                a.unsqueeze(dim)
            }
            "torch.ops.aten.squeeze.dims" => {
                let a = self.get_input_tensor(node, 0)?;
                let dims = self.get_ints_arg(node, 1)?;
                let ndim = a.shape.len();
                let mut sorted_dims: Vec<usize> =
                    dims.iter().map(|&d| normalize_dim(d, ndim)).collect();
                sorted_dims.sort();
                // A repeated axis (e.g. `0` and `-ndim`) would otherwise be visited twice
                // and could underflow `d - offset` below.
                sorted_dims.dedup();
                let mut result = a;
                // Number of axes removed so far; later axes shift left by this amount.
                let mut offset = 0;
                for d in sorted_dims {
                    // Only axes whose extent is statically 1 are removed.
                    if result.shape.dims[d - offset].to_usize() == Some(1) {
                        result = result.squeeze(d - offset);
                        offset += 1;
                    }
                }
                result
            }
            "torch.ops.aten.expand.default" => self.translate_expand(node)?,
            "torch.ops.aten.clone.default" => {
                let a = self.get_input_tensor(node, 0)?;
                // `+ 0.0` is applied only to non-contiguous inputs.
                if !a.shape.is_contiguous() { a + 0.0 } else { a }
            }

            // Matmul
            "torch.ops.aten.mm.default" | "torch.ops.aten.bmm.default" => {
                let a = self.get_input_tensor(node, 0)?;
                let b = self.get_input_tensor(node, 1)?;
                let (a, b) = ensure_same_dtype(a, b);
                a.matmul(b)
            }
            // addmm: beta*input + alpha*(mat1 @ mat2)
            "torch.ops.aten.addmm.default" => {
                let input = self.get_input_tensor(node, 0)?;
                let mat1 = self.get_input_tensor(node, 1)?;
                let mat2 = self.get_input_tensor(node, 2)?;
                // beta / alpha default to 1 when absent.
                let beta = self.get_float_arg(node, 3).unwrap_or(1.0) as f32;
                let alpha = self.get_float_arg(node, 4).unwrap_or(1.0) as f32;
                let mm = mat1.matmul(mat2);
                let (input, mm) = broadcast_binary(input, mm);
                input * beta + mm * alpha
            }

            // Convolution
            "torch.ops.aten.convolution.default" => self.translate_conv(node)?,

            // Reduction ops
            "torch.ops.aten.sum.dim_IntList" => self.translate_reduction(node, ReductionOp::Sum)?,
            "torch.ops.aten.mean.dim" => self.translate_reduction(node, ReductionOp::Mean)?,
            "torch.ops.aten.amax.default" => self.translate_reduction(node, ReductionOp::Max)?,

            // Slice/index ops
            "torch.ops.aten.slice.Tensor" => self.translate_slice(node)?,
            "torch.ops.aten.cat.default" => self.translate_cat(node)?,
            "torch.ops.aten.index.Tensor" => self.translate_index_tensor(node)?,

            // Embedding
            "torch.ops.aten.embedding.default" => self.translate_embedding(node)?,

            // Softmax
            "torch.ops.aten._softmax.default" => {
                let a = self.get_input_tensor(node, 0)?;
                let dim = self.get_int_arg(node, 1)?;
                let dim = normalize_dim(dim, a.shape.len());
                a.softmax(dim)
            }

            // LayerNorm
            "torch.ops.aten.native_layer_norm.default" => self.translate_layer_norm(node)?,

            // Where
            "torch.ops.aten.where.self" => self.translate_where(node)?,

            // Pow
            "torch.ops.aten.pow.Tensor_Scalar" => {
                let a = self.get_input_tensor(node, 0)?;
                let exp = self.get_float_arg(node, 1)?;
                a.pow(exp as f32)
            }
            "torch.ops.aten.pow.Tensor_Tensor" => {
                let a = self.get_input_tensor(node, 0)?;
                let b = self.get_input_tensor(node, 1)?;
                let (a, b) = broadcast_binary(a, b);
                // a^b computed as 2^(b * log2(a)).
                (b * a.log2()).exp2()
            }

            // Creation ops
            "torch.ops.aten.arange.start_step" => self.translate_arange(node)?,
            "torch.ops.aten.full.default" => self.translate_full(node)?,
            "torch.ops.aten.scalar_tensor.default" => {
                let val = self.get_float_arg(node, 0)? as f32;
                self.graph.constant_float(val)
            }

            // Scalar comparisons
            "torch.ops.aten.gt.Scalar" => self.translate_scalar_comparison(node, |a, s| a.gt(s))?,
            "torch.ops.aten.lt.Scalar" => self.translate_scalar_comparison(node, |a, s| a.lt(s))?,
            "torch.ops.aten.ge.Scalar" => self.translate_scalar_comparison(node, |a, s| a.ge(s))?,
            "torch.ops.aten.le.Scalar" => self.translate_scalar_comparison(node, |a, s| a.le(s))?,
            "torch.ops.aten.ne.Scalar" => self.translate_scalar_comparison(node, |a, s| a.ne(s))?,

            // Tensor comparisons
            "torch.ops.aten.eq.Tensor" => {
                let a = self.get_input_tensor(node, 0)?;
                let b = self.get_input_tensor(node, 1)?;
                let (a, b) = ensure_same_dtype(a, b);
                let (a, b) = broadcast_binary(a, b);
                a.eq(b)
            }
            "torch.ops.aten.le.Tensor" => {
                let a = self.get_input_tensor(node, 0)?;
                let b = self.get_input_tensor(node, 1)?;
                let (a, b) = ensure_same_dtype(a, b);
                let (a, b) = broadcast_binary(a, b);
                a.le(b)
            }
            // Logical ops are evaluated arithmetically on F32 casts of the operands.
            "torch.ops.aten.bitwise_and.Tensor" | "torch.ops.aten.logical_and.default" => {
                let a = self.get_input_tensor(node, 0)?;
                let b = self.get_input_tensor(node, 1)?;
                let (a, b) = broadcast_binary(a, b);
                let a = a.cast(DType::F32);
                let b = b.cast(DType::F32);
                (a * b).cast(DType::Bool)
            }
            "torch.ops.aten.logical_or.default" => {
                let a = self.get_input_tensor(node, 0)?;
                let b = self.get_input_tensor(node, 1)?;
                let (a, b) = broadcast_binary(a, b);
                let a = a.cast(DType::F32);
                let b = b.cast(DType::F32);
                (a + b - a * b).cast(DType::Bool)
            }
            "torch.ops.aten.logical_xor.default" => {
                let a = self.get_input_tensor(node, 0)?;
                let b = self.get_input_tensor(node, 1)?;
                let (a, b) = broadcast_binary(a, b);
                let a = a.cast(DType::F32);
                let b = b.cast(DType::F32);
                a.ne(b)
            }

            // Clamp
            "torch.ops.aten.clamp.default" => self.translate_clamp(node)?,

            // Cumsum
            "torch.ops.aten.cumsum.default" => {
                let a = self.get_input_tensor(node, 0)?;
                let dim = self.get_int_arg(node, 1)?;
                let dim = normalize_dim(dim, a.shape.len());
                // Bool inputs are summed as integers.
                let a = if a.dtype == DType::Bool {
                    a.cast(DType::Int)
                } else {
                    a
                };
                a.cumsum(dim)
            }

            // Floor / Ceil / Erf (approximations)
            "torch.ops.aten.floor.default" => {
                let a = self.get_input_tensor(node, 0)?;
                // floor(x) = trunc(x) - (x < trunc(x))
                let trunc = a.cast(DType::Int).cast(DType::F32);
                let adjust = a.lt(trunc).cast(DType::F32);
                trunc - adjust
            }
            "torch.ops.aten.ceil.default" => {
                let a = self.get_input_tensor(node, 0)?;
                // ceil(x) = -floor(-x)
                let neg_a = a * (-1.0);
                let trunc = neg_a.cast(DType::Int).cast(DType::F32);
                let adjust = neg_a.lt(trunc).cast(DType::F32);
                let floor_neg = trunc - adjust;
                floor_neg * (-1.0)
            }
            "torch.ops.aten.erf.default" => {
                let a = self.get_input_tensor(node, 0)?;
                // Abramowitz & Stegun approximation 7.1.26 (max error ~1.5e-7)
                // erf(x) = sign(x) * (1 - poly(t) * exp(-x^2))
                // where t = 1/(1 + 0.3275911*|x|), poly in Horner form
                let ax = a.abs();
                let x2 = a * a;
                let t = (ax * 0.3275911_f32 + 1.0).reciprocal();
                // Horner: t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
                let poly = t
                    * (t * (t
                        * (t * (t * 1.061_405_4_f32 + (-1.453_152_1_f32)) + 1.421_413_8_f32)
                        + (-0.284_496_72_f32))
                        + 0.254_829_6_f32);
                let result_abs =
                    self.graph.constant_float(1.0).expand_rhs(a.shape) - poly * (x2 * (-1.0)).exp();
                // sign(x) = 2*(x >= 0) - 1
                let zero = self.graph.constant_float(0.0).expand_rhs(a.shape);
                let sign = a.ge(zero).cast(DType::F32) * 2.0 - 1.0;
                result_abs * sign
            }
            "torch.ops.aten.isnan.default" => {
                let a = self.get_input_tensor(node, 0)?;
                // NaN is the only value that compares unequal to itself.
                a.ne(a)
            }
            "torch.ops.aten.logical_not.default" => {
                let a = self.get_input_tensor(node, 0)?;
                let one = self.graph.constant_float(1.0).expand_rhs(a.shape);
                (one - a.cast(DType::F32)).cast(DType::Bool)
            }

            // Element-wise min/max (tensor-tensor)
            "torch.ops.aten.maximum.default" => {
                let a = self.get_input_tensor(node, 0)?;
                let b = self.get_input_tensor(node, 1)?;
                let (a, b) = broadcast_binary(a, b);
                a.maximum(b)
            }
            "torch.ops.aten.minimum.default" => {
                let a = self.get_input_tensor(node, 0)?;
                let b = self.get_input_tensor(node, 1)?;
                let (a, b) = broadcast_binary(a, b);
                a.minimum(b)
            }

            // Tensor comparisons (additional)
            "torch.ops.aten.ge.Tensor" => {
                let a = self.get_input_tensor(node, 0)?;
                let b = self.get_input_tensor(node, 1)?;
                let (a, b) = ensure_same_dtype(a, b);
                let (a, b) = broadcast_binary(a, b);
                a.ge(b)
            }
            "torch.ops.aten.lt.Tensor" => {
                let a = self.get_input_tensor(node, 0)?;
                let b = self.get_input_tensor(node, 1)?;
                let (a, b) = ensure_same_dtype(a, b);
                let (a, b) = broadcast_binary(a, b);
                a.lt(b)
            }
            "torch.ops.aten.gt.Tensor" => {
                let a = self.get_input_tensor(node, 0)?;
                let b = self.get_input_tensor(node, 1)?;
                let (a, b) = ensure_same_dtype(a, b);
                let (a, b) = broadcast_binary(a, b);
                a.gt(b)
            }

            // Full-reduce variants (no dim arg) — handled by translate_reduction fallback
            "torch.ops.aten.sum.default" => self.translate_reduction(node, ReductionOp::Sum)?,
            "torch.ops.aten.mean.default" => self.translate_reduction(node, ReductionOp::Mean)?,
            "torch.ops.aten.max.default" => self.translate_reduction(node, ReductionOp::Max)?,
            "torch.ops.aten.min.default" => self.translate_reduction(node, ReductionOp::Min)?,
            "torch.ops.aten.amin.default" => self.translate_reduction(node, ReductionOp::Min)?,

            // Gather (axis-aware)
            "torch.ops.aten.gather.default" => self.translate_gather(node)?,

            // Scatter ops
            "torch.ops.aten.scatter.src" => self.translate_scatter_src(node)?,
            "torch.ops.aten.index_put.default" => self.translate_index_put(node)?,

            // TopK — handles its own output storage, returns early
            "torch.ops.aten.topk.default" => {
                self.translate_topk(node)?;
                return Ok(());
            }

            // Split
            "torch.ops.aten.split_with_sizes.default" => self.translate_split_with_sizes(node)?,

            // Fmod
            "torch.ops.aten.fmod.Tensor" => {
                let a = self.get_input_tensor(node, 0)?;
                let b = self.get_input_tensor(node, 1)?;
                let (a, b) = broadcast_binary(a, b);
                a % b
            }

            // Prod reduction
            "torch.ops.aten.prod.dim_int" => self.translate_reduction(node, ReductionOp::Prod)?,

            other => {
                bail!("Unsupported ATen op: {other}");
            }
        };

        // A node whose first output carries no name produces nothing to record.
        if !output_name.is_empty() {
            self.tensors.insert(output_name, result);
        }
        Ok(())
    }
}
impl<'a> Translator<'a> {
    /// Compares the tensor at input 0 against the scalar at input 1.
    ///
    /// The scalar is read as a float (integer arguments are accepted too), materialised as a
    /// constant, cast to the tensor's dtype and expanded to the tensor's shape before `cmp`
    /// is applied as `cmp(tensor, scalar)`.
    ///
    /// # Errors
    /// Fails when input 0 is not a known tensor or input 1 is not numeric.
    fn translate_scalar_comparison(
        &mut self,
        node: &Node,
        cmp: impl Fn(GraphTensor, GraphTensor) -> GraphTensor,
    ) -> Result<GraphTensor> {
        let lhs = self.get_input_tensor(node, 0)?;
        let threshold = self.get_float_arg(node, 1)? as f32;

        let constant = self.graph.constant_float(threshold);
        let typed = constant.cast(lhs.dtype);
        let rhs = typed.expand_rhs(lhs.shape);

        Ok(cmp(lhs, rhs))
    }
}

View File

@@ -1,331 +0,0 @@
//! PT2 graph nodes -> Luminal Graph translation.
//!
//! Walks the parsed PT2 graph and constructs an equivalent Luminal computation graph.
mod binary;
mod conv;
mod dispatch;
mod movement;
mod reduction;
mod tensor;
mod unary;
use std::collections::HashMap;
use anyhow::{Context, Result};
use luminal::graph::Graph;
use luminal::prelude::*;
use crate::pt2_parser::{InputKind, ParsedPT2, SymDimMap};
use crate::pt2_schema::*;
use crate::pt2_util;
/// Result of translating a PT2 graph to a Luminal graph.
///
/// Built by `Translator::finish` from the translator's accumulated state.
pub struct TranslatedGraph {
/// The luminal computation graph.
pub graph: Graph,
/// `(graph name, node id)` for each user input, in the order the inputs were created.
/// Parameters and buffers are not listed here.
pub user_input_ids: Vec<(String, NodeIndex)>,
/// `(output name, node id)` for each graph output, in order. The id refers to the node
/// marked as an output, i.e. after Bool/Int outputs have been cast to F32.
pub output_ids: Vec<(String, NodeIndex)>,
/// Symbolic dimension mapping.
pub sym_map: SymDimMap,
}
/// Main translation entry point.
///
/// Creates a `Translator` for `parsed`, translates the whole graph, and returns the
/// resulting `TranslatedGraph`.
///
/// # Errors
/// Propagates any error raised while constructing the translator or translating the graph.
pub fn translate(parsed: &ParsedPT2) -> Result<TranslatedGraph> {
    Translator::new(parsed).and_then(|mut translator| {
        translator.translate_graph()?;
        Ok(translator.finish())
    })
}
/// Mutable state carried while walking a parsed PT2 program and building the
/// equivalent luminal graph.
pub(crate) struct Translator<'a> {
/// The parsed PT2 program being translated (borrowed for the translator's lifetime).
pub(crate) parsed: &'a ParsedPT2,
/// The luminal graph under construction.
pub(crate) graph: Graph,
/// Maps tensor name -> GraphTensor
pub(crate) tensors: HashMap<String, GraphTensor>,
/// Symbol mapping built from `parsed` when the translator is created.
pub(crate) sym_map: SymDimMap,
/// `(graph name, node id)` of every user input created so far, in creation order.
pub(crate) user_input_ids: Vec<(String, NodeIndex)>,
/// `(output name, node id)` of every graph output registered so far, in order.
pub(crate) output_ids: Vec<(String, NodeIndex)>,
/// Extra tensor metadata from inlined subgraphs.
/// NOTE(review): starts empty; it is not populated in this part of the file, so
/// presumably the subgraph-inlining code fills it in. Confirm there.
pub(crate) extra_tensor_values: HashMap<String, TensorMeta>,
}
impl<'a> Translator<'a> {
    /// Creates a translator over `parsed` with an empty graph and empty lookup tables.
    fn new(parsed: &'a ParsedPT2) -> Result<Self> {
        Ok(Self {
            sym_map: parsed.build_sym_dim_map(),
            parsed,
            graph: Graph::new(),
            tensors: HashMap::new(),
            user_input_ids: Vec::new(),
            output_ids: Vec::new(),
            extra_tensor_values: HashMap::new(),
        })
    }

    /// Creates the graph inputs, translates every node in order, then registers outputs.
    ///
    /// # Errors
    /// Node failures are wrapped with the node index and target; an output name that was
    /// never produced yields an "Unknown tensor" error.
    fn translate_graph(&mut self) -> Result<()> {
        self.create_inputs()?;

        let parsed = self.parsed;
        for (i, node) in parsed.program.graph_module.graph.nodes.iter().enumerate() {
            self.translate_node(node)
                .with_context(|| format!("Failed to translate node {i}: {}", node.target))?;
        }

        let output_names = parsed.output_names();
        for name in &output_names {
            let produced = self.get_tensor(name)?;
            // Cast non-float outputs (Bool, Int) to F32 for the runtime.
            // Preserve F16/BF16/F32 as-is to avoid corrupting half-precision models.
            let as_float = if matches!(produced.dtype, DType::Bool | DType::Int) {
                produced.cast(DType::F32)
            } else {
                produced
            };
            // Every output gets a trailing `+ 0.0` before being marked as a graph output.
            let materialized = as_float + 0.0;
            materialized.output();
            self.output_ids.push((name.clone(), materialized.id));
        }
        Ok(())
    }

    /// Declares one named graph tensor per classified input.
    ///
    /// Parameters and buffers are registered under their original name and persisted; user
    /// inputs are registered under their graph name and recorded in `user_input_ids`. All
    /// of them are indexed in `self.tensors` by graph name.
    fn create_inputs(&mut self) -> Result<()> {
        let classified = self.parsed.classify_inputs();
        for input in &classified {
            match input {
                InputKind::Parameter {
                    graph_name,
                    original_name,
                } => {
                    let meta = self
                        .parsed
                        .tensor_meta(graph_name)
                        .with_context(|| format!("Missing tensor meta for param {graph_name}"))?;
                    let dims = self.tensor_meta_to_shape(meta)?;
                    let dtype = pt2_util::torch_dtype_int_to_luminal(meta.dtype);
                    let param = self.graph.named_tensor(original_name, dims).as_dtype(dtype);
                    param.persist();
                    self.tensors.insert(graph_name.clone(), param);
                }
                InputKind::Buffer {
                    graph_name,
                    original_name,
                } => {
                    let meta = self
                        .parsed
                        .tensor_meta(graph_name)
                        .with_context(|| format!("Missing tensor meta for buffer {graph_name}"))?;
                    let dims = self.tensor_meta_to_shape(meta)?;
                    let dtype = pt2_util::torch_dtype_int_to_luminal(meta.dtype);
                    let buffer = self.graph.named_tensor(original_name, dims).as_dtype(dtype);
                    buffer.persist();
                    self.tensors.insert(graph_name.clone(), buffer);
                }
                InputKind::UserInput { graph_name } => {
                    let meta = self
                        .parsed
                        .tensor_meta(graph_name)
                        .with_context(|| format!("Missing tensor meta for input {graph_name}"))?;
                    let dims = self.tensor_meta_to_shape(meta)?;
                    let dtype = pt2_util::torch_dtype_int_to_luminal(meta.dtype);
                    let user_input = self.graph.named_tensor(graph_name, dims).as_dtype(dtype);
                    self.user_input_ids
                        .push((graph_name.clone(), user_input.id));
                    self.tensors.insert(graph_name.clone(), user_input);
                }
            }
        }
        Ok(())
    }

    /// Consumes the translator and hands back the finished graph with its bookkeeping.
    fn finish(self) -> TranslatedGraph {
        let Self {
            graph,
            user_input_ids,
            output_ids,
            sym_map,
            ..
        } = self;
        TranslatedGraph {
            graph,
            user_input_ids,
            output_ids,
            sym_map,
        }
    }

    // --- Helper methods ---

    /// Looks up a previously translated tensor by name.
    pub(crate) fn get_tensor(&self, name: &str) -> Result<GraphTensor> {
        self.tensors
            .get(name)
            .copied()
            .with_context(|| format!("Unknown tensor: {name}"))
    }

    /// Resolves input `idx` of `node` to the tensor it names.
    pub(crate) fn get_input_tensor(&self, node: &Node, idx: usize) -> Result<GraphTensor> {
        let arg = input_argument(node, idx)?;
        let name = arg.as_tensor_name().with_context(|| {
            format!("Input {idx} of {} is not a tensor: {:?}", node.target, arg)
        })?;
        self.get_tensor(name)
    }

    /// Reads input `idx` of `node` as an integer.
    pub(crate) fn get_int_arg(&self, node: &Node, idx: usize) -> Result<i64> {
        let arg = input_argument(node, idx)?;
        arg.as_int()
            .with_context(|| format!("Input {idx} of {} is not an int: {:?}", node.target, arg))
    }

    /// Reads input `idx` of `node` as a float; an integer argument is widened to `f64`.
    pub(crate) fn get_float_arg(&self, node: &Node, idx: usize) -> Result<f64> {
        let arg = input_argument(node, idx)?;
        let numeric = arg
            .as_float()
            .or_else(|| arg.as_int().map(|i| i as f64));
        if let Some(value) = numeric {
            return Ok(value);
        }
        anyhow::bail!("Input {idx} of {} is not a float: {:?}", node.target, arg)
    }

    /// Reads input `idx` of `node` as a list of integers.
    pub(crate) fn get_ints_arg(&self, node: &Node, idx: usize) -> Result<Vec<i64>> {
        let arg = input_argument(node, idx)?;
        arg.as_ints()
            .map(|v| v.to_vec())
            .with_context(|| format!("Input {idx} of {} is not int list: {:?}", node.target, arg))
    }

    /// Resolves input `idx` of `node` to an `Expression` (see `resolve_arg_as_expression`).
    pub(crate) fn get_expr_arg(&self, node: &Node, idx: usize) -> Result<Expression> {
        let arg = input_argument(node, idx)?;
        self.resolve_arg_as_expression(arg).with_context(|| {
            format!(
                "Input {idx} of {} cannot be resolved to Expression: {:?}",
                node.target, arg
            )
        })
    }

    /// Resolves input `idx` of `node` to a list of `Expression`s.
    ///
    /// Accepts either a plain int list or a sym-int list whose named entries are resolved
    /// through `resolve_sym_int`.
    pub(crate) fn get_exprs_arg(&self, node: &Node, idx: usize) -> Result<Vec<Expression>> {
        use crate::pt2_schema::SymIntEntry;

        let arg = input_argument(node, idx)?;

        if let Some(ints) = arg.as_ints() {
            return Ok(ints.iter().map(|&v| Expression::from(v as usize)).collect());
        }

        if let Some(entries) = arg.as_sym_ints() {
            let mut resolved = Vec::new();
            for entry in entries.iter() {
                let expr = match entry {
                    SymIntEntry::Int(i) => Expression::from(i.as_int as usize),
                    SymIntEntry::Name(s) => self
                        .resolve_sym_int(&s.as_name)
                        .with_context(|| format!("Cannot resolve sym_int: {}", s.as_name))?,
                };
                resolved.push(expr);
            }
            return Ok(resolved);
        }

        anyhow::bail!(
            "Input {idx} of {} is not int list or sym_int list: {:?}",
            node.target,
            arg
        )
    }

    /// Reads input `idx` of `node` as a boolean.
    pub(crate) fn get_bool_arg(&self, node: &Node, idx: usize) -> Result<bool> {
        let arg = input_argument(node, idx)?;
        arg.as_bool()
            .with_context(|| format!("Input {idx} of {} is not a bool: {:?}", node.target, arg))
    }

    /// Converts every size in `meta` to an `Expression`, failing on the first bad dimension.
    pub(crate) fn tensor_meta_to_shape(&self, meta: &TensorMeta) -> Result<Vec<Expression>> {
        let mut dims = Vec::new();
        for size in meta.sizes.iter() {
            dims.push(self.dim_size_to_expr(size)?);
        }
        Ok(dims)
    }

    /// Converts one dimension size to an `Expression`.
    ///
    /// Concrete sizes become constants; symbolic sizes are mapped to the variable character
    /// registered for the symbol in `sym_map`.
    pub(crate) fn dim_size_to_expr(&self, dim: &DimSize) -> Result<Expression> {
        let symbolic = match dim {
            DimSize::Int(i) => return Ok(Expression::from(i.as_int as usize)),
            DimSize::Expr(e) => e,
        };
        let expr_str = &symbolic.as_expr.expr_str;
        let sym_name = crate::pt2_parser::extract_symbol_name_pub(expr_str)
            .with_context(|| format!("Cannot parse symbol: {}", symbolic.as_expr.expr_str))?;
        let var = self
            .sym_map
            .sym_to_char
            .get(&sym_name)
            .with_context(|| format!("Unknown symbol: {sym_name}"))?;
        Ok(Expression::from(*var))
    }

    /// Resolves a named sym-int from the graph's `sym_int_values` table.
    ///
    /// A symbol known to `sym_map` wins; otherwise the entry's integer hint is used.
    /// Returns `None` when the name is unknown or neither form is available.
    pub(crate) fn resolve_sym_int(&self, name: &str) -> Option<Expression> {
        let entry = self
            .parsed
            .program
            .graph_module
            .graph
            .sym_int_values
            .get(name)?;
        let as_expr = entry.get("as_expr");

        if let Some(expr_str) = as_expr
            .and_then(|e| e.get("expr_str"))
            .and_then(|s| s.as_str())
            && let Some(sym) = crate::pt2_parser::extract_symbol_name_pub(expr_str)
            && let Some(&var) = self.sym_map.sym_to_char.get(&sym)
        {
            return Some(Expression::from(var));
        }

        as_expr
            .and_then(|e| e.get("hint"))
            .and_then(|h| h.get("as_int"))
            .and_then(|v| v.as_i64())
            .map(|hint| Expression::from(hint as usize))
    }

    /// Resolves a single argument to an `Expression`.
    ///
    /// Tried in order: a plain integer, a named sym-int, then an inline expression (symbol
    /// lookup first, integer hint second). Returns `None` when none of these apply.
    pub(crate) fn resolve_arg_as_expression(&self, arg: &Argument) -> Option<Expression> {
        if let Some(v) = arg.as_int() {
            return Some(Expression::from(v as usize));
        }
        if let Some(name) = arg.as_sym_int_name() {
            return self.resolve_sym_int(name);
        }

        let Argument::Expr(e) = arg else {
            return None;
        };
        if let Some(sym) = crate::pt2_parser::extract_symbol_name_pub(&e.as_expr.expr_str)
            && let Some(&var) = self.sym_map.sym_to_char.get(&sym)
        {
            return Some(Expression::from(var));
        }
        e.as_expr
            .hint
            .as_ref()
            .and_then(|h| h.as_int())
            .map(|hint| Expression::from(hint as usize))
    }
}

/// Returns the argument at position `idx` of `node`, or an error when the node has fewer
/// inputs than that.
fn input_argument(node: &Node, idx: usize) -> Result<&Argument> {
    let input = node
        .inputs
        .get(idx)
        .with_context(|| format!("Node {} missing input {idx}", node.target))?;
    Ok(&input.arg)
}

View File

@@ -1,423 +0,0 @@
use anyhow::{Context, Result, bail};
use luminal::prelude::*;
use crate::pt2_schema::*;
use crate::pt2_util::*;
use super::Translator;
impl<'a> Translator<'a> {
/// Translates `aten.view`: reinterprets input 0 with the target shape given by input 1.
///
/// The target shape is read as an int list when possible and as a list of expressions
/// otherwise; in both cases it is passed through the matching `resolve_neg1_dim*` helper
/// together with the input's current dims. Inputs that are non-contiguous, or that have
/// a stride-0 axis whose extent is not 1, get `+ 0.0` applied before the new shape is
/// attached.
pub(crate) fn translate_reshape(&mut self, node: &Node) -> Result<GraphTensor> {
    let input = self.get_input_tensor(node, 0)?;

    let target_dims = match self.get_ints_arg(node, 1) {
        Ok(ints) => resolve_neg1_dim(&ints, &input.shape.dims),
        Err(_) => {
            let exprs = self.get_exprs_arg(node, 1)?;
            resolve_neg1_dim_exprs(&exprs, &input.shape.dims)
        }
    };

    // True when some axis has stride 0 but an extent other than 1.
    let broadcast_view = input
        .shape
        .dims
        .iter()
        .zip(input.shape.strides.iter())
        .any(|(extent, stride)| stride.to_usize() == Some(0) && extent.to_usize() != Some(1));
    let needs_copy = broadcast_view || !input.shape.is_contiguous();
    let base = if needs_copy { input + 0.0 } else { input };

    Ok(GraphTensor {
        id: base.id,
        graph_ref: base.graph_ref,
        shape: ShapeTracker::new(target_dims),
        dtype: base.dtype,
    })
}
/// Translates `aten.permute`: reorders the axes of input 0 according to the int list in
/// input 1, after normalizing each entry against the input's rank.
pub(crate) fn translate_permute(&mut self, node: &Node) -> Result<GraphTensor> {
    let input = self.get_input_tensor(node, 0)?;
    let requested = self.get_ints_arg(node, 1)?;

    let mut axes: Vec<usize> = Vec::new();
    for &d in requested.iter() {
        axes.push(normalize_dim(d, input.shape.len()));
    }
    Ok(input.permute(axes))
}
/// Translates `aten.expand`: broadcasts input 0 to the sizes given by input 1.
///
/// A size of `-1` keeps the input's existing extent for that axis. When the target has
/// more axes than the input, the input is first given leading unit axes so that the
/// remaining axes line up from the right, which is how `Tensor.expand` adds new dimensions.
///
/// The sizes are read as an int list when possible and as a list of expressions otherwise.
pub(crate) fn translate_expand(&mut self, node: &Node) -> Result<GraphTensor> {
    let mut a = self.get_input_tensor(node, 0)?;

    // `None` marks "keep the input's extent" (-1 in the source graph).
    let requested: Vec<Option<Expression>> = if let Ok(sizes) = self.get_ints_arg(node, 1) {
        sizes
            .iter()
            .map(|&s| {
                if s == -1 {
                    None
                } else {
                    Some(Expression::from(s as usize))
                }
            })
            .collect()
    } else {
        let neg1_expr = Expression::from(-1i32);
        self.get_exprs_arg(node, 1)?
            .into_iter()
            .map(|e| if e == neg1_expr { None } else { Some(e) })
            .collect()
    };

    // Pad with leading unit axes until the ranks match. Without this, index `i` of the
    // target would be paired with the wrong input axis (or run past the end of the
    // input's dims). The range is empty when the ranks already agree.
    for _ in a.shape.len()..requested.len() {
        a = a.expand_dim(0, Expression::from(1usize));
    }

    let target_shape: Vec<Expression> = requested
        .into_iter()
        .enumerate()
        .map(|(i, size)| size.unwrap_or_else(|| a.shape.dims[i]))
        .collect();
    a.shape.expand(target_shape);
    Ok(a)
}
/// Translates `aten.slice.Tensor`: slices input 0 along one axis.
///
/// Inputs are read positionally as `(tensor, dim, start, end)`:
/// * `dim` defaults to 0 and is normalized against the input's rank.
/// * `start` defaults to 0 when absent or unresolvable.
/// * `end` may be absent or the `i64::MAX` sentinel, both meaning "to the end of the axis".
///
/// NOTE(review): a fifth `step` input is not read here, so a step other than 1 is not
/// applied; confirm whether such graphs can reach this function.
pub(crate) fn translate_slice(&mut self, node: &Node) -> Result<GraphTensor> {
    let a = self.get_input_tensor(node, 0)?;
    let dim = self.get_int_arg(node, 1).unwrap_or(0);
    let dim = normalize_dim(dim, a.shape.len());

    let start: Expression = if node.inputs.len() > 2 {
        self.get_expr_arg(node, 2)
            .unwrap_or_else(|_| Expression::from(0usize))
    } else {
        Expression::from(0usize)
    };

    if node.inputs.len() <= 3 {
        // `end` was omitted. Previously the whole tensor was returned here even when a
        // non-zero `start` had been supplied. Honor `start`, but only when the third
        // input really is the `start` argument, so that a lone keyword argument in that
        // position is never misread as an offset.
        let has_start = node
            .inputs
            .get(2)
            .is_some_and(|input| input.name == "start");
        return Ok(if has_start && start.to_usize() != Some(0) {
            a.slice_along(start.., dim)
        } else {
            a
        });
    }

    let end_is_sentinel = self
        .get_int_arg(node, 3)
        .map(|e| e == i64::MAX)
        .unwrap_or(false);
    if end_is_sentinel {
        return Ok(if start.to_usize() == Some(0) {
            a
        } else {
            a.slice_along(start.., dim)
        });
    }

    let end: Expression = self.get_expr_arg(node, 3)?;
    // Use a concrete range when both bounds are known constants.
    if let Some(s) = start.to_usize()
        && let Some(e) = end.to_usize()
    {
        return Ok(a.slice_along(s..e, dim));
    }
    Ok(a.slice_along(start..end, dim))
}
/// Translates `aten.cat`: concatenates the input tensors along one axis.
///
/// The tensors are taken from the tensor list in the first input when it is one; otherwise
/// every input that names a known tensor is collected. The axis is the first integer input
/// not named `tensors` (default 0). Tensors with a statically zero-sized dimension are
/// dropped before concatenation.
///
/// # Errors
/// Fails when no tensor inputs are found, when every tensor is empty, or when a listed
/// tensor name is unknown.
pub(crate) fn translate_cat(&mut self, node: &Node) -> Result<GraphTensor> {
    // `first()` rather than `inputs[0]`: a node without inputs must produce the
    // "no tensor inputs" error below instead of panicking on the index.
    let listed = node
        .inputs
        .first()
        .and_then(|input| input.arg.as_tensors());
    let tensors: Vec<GraphTensor> = if let Some(names) = listed {
        names
            .iter()
            .map(|n| self.get_tensor(&n.name))
            .collect::<Result<_>>()?
    } else {
        let mut ts = Vec::new();
        for input in &node.inputs {
            if let Some(name) = input.arg.as_tensor_name()
                && let Ok(t) = self.get_tensor(name)
            {
                ts.push(t);
            }
        }
        ts
    };
    if tensors.is_empty() {
        bail!("cat: no tensor inputs found");
    }

    let dim = node
        .inputs
        .iter()
        .find(|i| i.arg.as_int().is_some() && i.name != "tensors")
        .and_then(|i| i.arg.as_int())
        .unwrap_or(0);

    // Drop tensors that have a dimension known to be zero.
    let tensors: Vec<GraphTensor> = tensors
        .into_iter()
        .filter(|t| !t.shape.dims.iter().any(|d| d.to_usize() == Some(0)))
        .collect();
    if tensors.is_empty() {
        bail!("cat: all tensor inputs are empty");
    }

    let dim = normalize_dim(dim, tensors[0].shape.len());
    let mut result = tensors[0];
    for t in &tensors[1..] {
        result = result.concat_along(*t, dim);
    }
    Ok(result)
}
/// Translates `aten.embedding`: looks up rows of the weight (input 0) by the indices in
/// input 1.
///
/// Each index is turned into `index * hidden_dim + column`, with `column` running over
/// `0..hidden_dim` on a new trailing axis, and the weight is gathered with those values.
pub(crate) fn translate_embedding(&mut self, node: &Node) -> Result<GraphTensor> {
    let table = self.get_input_tensor(node, 0)?;
    let ids = self.get_input_tensor(node, 1)?;

    let row_width = table.shape.dims[1];
    let id_dims = ids.shape.dims;

    // index * hidden_dim, repeated along a new trailing axis of length hidden_dim.
    let row_starts = (ids.cast(DType::Int) * row_width).expand_dim(id_dims.len(), row_width);

    // 0..hidden_dim, given one leading axis per index axis (last index axis added first).
    let col_offsets = id_dims
        .iter()
        .rev()
        .fold(self.graph.arange(row_width), |acc, extent| {
            acc.expand_dim(0, *extent)
        });

    Ok(table.gather(row_starts + col_offsets))
}
pub(crate) fn translate_index_tensor(&mut self, node: &Node) -> Result<GraphTensor> {
let source = self.get_input_tensor(node, 0)?;
// Handle indices as_tensors (all non-None) or as individual args with None entries
let index_names: Vec<crate::pt2_schema::TensorName>;
let mut first_non_none_dim = 0usize;
if let Some(names) = node.inputs[1].arg.as_tensors() {
index_names = names.to_vec();
} else {
let indices_arg = &node.inputs[1].arg;
// Check if it's a single tensor (1D indexing)
if let Some(name) = indices_arg.as_tensor_name() {
index_names = vec![crate::pt2_schema::TensorName {
name: name.to_string(),
}];
} else if let Some(opt_tensors) = indices_arg.as_optional_tensors() {
// Optional tensors list: [None, tensor, None, ...] for selective dim indexing
use crate::pt2_schema::OptionalTensorEntry;
let mut found_tensors: Vec<crate::pt2_schema::TensorName> = Vec::new();
for (i, entry) in opt_tensors.iter().enumerate() {
if let OptionalTensorEntry::Tensor(t) = entry {
if found_tensors.is_empty() {
first_non_none_dim = i;
}
found_tensors.push(t.as_tensor.clone());
}
}
if found_tensors.is_empty() {
bail!("index.Tensor: no index tensors in optional_tensors list");
}
index_names = found_tensors;
// Simple case: single non-None index on a specific dim → gather_elements
if first_non_none_dim > 0 && index_names.len() == 1 {
let idx = self.get_tensor(&index_names[0].name)?.cast(DType::Int);
// gather_elements requires indices to have the same rank as data.
// PyTorch fancy indexing gives 1D indices that broadcast across other dims.
// Add unit leading dims to match rank, then broadcast to output shape.
let src_dims = source.shape.dims;
let src_rank = src_dims.len();
let mut expanded = idx;
for _ in 0..(src_rank - expanded.shape.len()) {
expanded = expanded.expand_dim(0, Expression::from(1usize));
}
// Build target shape: source dims everywhere except the indexed dim
let idx_dim_size = expanded.shape.dims[first_non_none_dim];
let mut target: Vec<Expression> = src_dims.to_vec();
target[first_non_none_dim] = idx_dim_size;
expanded.shape.expand(target);
return Ok(source.gather_elements(expanded, first_non_none_dim));
}
} else {
bail!(
"index.Tensor: unsupported indices format: {:?}",
indices_arg
);
}
}
let index_names = &index_names;
let src_shape = source.shape.dims;
let n_indexed = index_names.len();
let mut strides: Vec<Expression> = vec![Expression::from(1usize); n_indexed];
for i in (0..n_indexed - 1).rev() {
strides[i] = strides[i + 1] * src_shape[i + 1];
}
let mut flat_idx: Option<GraphTensor> = None;
for (dim_idx, idx_name) in index_names.iter().enumerate() {
let idx_tensor = self.get_tensor(&idx_name.name)?;
// Normalize negative indices for this dimension
let axis_size = src_shape[dim_idx].to_usize().ok_or_else(|| {
anyhow::anyhow!(
"index.Tensor: dim {} must be concrete for negative index normalization",
dim_idx
)
})?;
let idx_f32 = idx_tensor.cast(DType::F32);
let zero = self.graph.constant_float(0.0).expand_rhs(idx_f32.shape);
let adjustment = self
.graph
.constant_float(axis_size as f32)
.expand_rhs(idx_f32.shape);
let is_negative = idx_f32.lt(zero).cast(DType::F32);
let idx_int = (idx_f32 + is_negative * adjustment).cast(DType::Int);
let stride = &strides[dim_idx];
let weighted = if stride.to_usize() == Some(1) {
idx_int
} else {
idx_int * *stride
};
flat_idx = Some(match flat_idx {
Some(acc) => {
let (acc_b, w_b) = broadcast_binary(acc, weighted);
acc_b + w_b
}
None => weighted,
});
}
let mut indexed_size = Expression::from(1usize);
for i in 0..n_indexed {
indexed_size *= src_shape[i];
}
let remaining_dims: Vec<Expression> = src_shape[n_indexed..].to_vec();
let mut flat_shape = vec![indexed_size];
flat_shape.extend_from_slice(&remaining_dims);
let flat_source = reshape_tensor(source, flat_shape);
let flat_idx = flat_idx.context("index.Tensor: no indices")?;
if remaining_dims.is_empty() {
Ok(flat_source.gather(flat_idx))
} else {
let mut remaining_size = Expression::from(1usize);
for d in &remaining_dims {
remaining_size *= *d;
}
let idx_shape = flat_idx.shape.dims;
let mut expanded_idx = flat_idx * remaining_size;
expanded_idx = expanded_idx.expand_dim(idx_shape.len(), remaining_size);
let arange = self.graph.arange(remaining_size);
let mut arange_expanded = arange;
for d in idx_shape.iter().rev() {
arange_expanded = arange_expanded.expand_dim(0, *d);
}
let final_idx = expanded_idx + arange_expanded;
let total_elements = indexed_size * remaining_size;
let fully_flat = reshape_tensor(flat_source, vec![total_elements]);
let gathered = fully_flat.gather(final_idx);
let mut result_shape: Vec<Expression> = idx_shape.to_vec();
result_shape.extend_from_slice(&remaining_dims);
Ok(reshape_tensor(gathered, result_shape))
}
}
/// Lower a `gather` node.
///
/// Input 0 is the data tensor, input 1 the (possibly negative) axis and
/// input 2 the index tensor. Negative indices are wrapped by adding the
/// length of the gathered axis, which is why that axis must be concrete.
pub(crate) fn translate_gather(&mut self, node: &Node) -> Result<GraphTensor> {
    let data = self.get_input_tensor(node, 0)?;
    let raw_axis = self.get_int_arg(node, 1)?;
    let axis = normalize_dim(raw_axis, data.shape.len());
    let raw_indices = self.get_input_tensor(node, 2)?;

    // Wrapping needs the axis length as a plain number.
    let Some(axis_len) = data.shape.dims[axis].to_usize() else {
        return Err(anyhow::anyhow!(
            "Gather: axis dim must be concrete for negative index normalization"
        ));
    };

    // wrapped = idx + (idx < 0) * axis_len, evaluated in F32 and cast back to Int.
    let idx_f32 = raw_indices.cast(DType::F32);
    let zeros = self.graph.constant_float(0.0).expand_rhs(idx_f32.shape);
    let wrap_amount = self
        .graph
        .constant_float(axis_len as f32)
        .expand_rhs(idx_f32.shape);
    let negative_mask = idx_f32.lt(zeros).cast(DType::F32);
    let wrapped = (idx_f32 + negative_mask * wrap_amount).cast(DType::Int);

    Ok(data.gather_elements(wrapped, axis))
}
/// Lower a `scatter` (src variant) node.
///
/// Inputs: 0 = destination tensor, 1 = axis (may be negative), 2 = index
/// tensor, 3 = source values. The indices are cast to `Int` before being
/// handed to `scatter_elements`.
pub(crate) fn translate_scatter_src(&mut self, node: &Node) -> Result<GraphTensor> {
    let target = self.get_input_tensor(node, 0)?;
    let axis = normalize_dim(self.get_int_arg(node, 1)?, target.shape.len());
    let index_tensor = self.get_input_tensor(node, 2)?;
    let updates = self.get_input_tensor(node, 3)?;
    let int_indices = index_tensor.cast(DType::Int);
    Ok(target.scatter_elements(int_indices, updates, axis))
}
/// Lower an `index_put` node with a single index tensor onto `scatter_nd`.
///
/// Inputs: 0 = destination tensor, 1 = list of index tensors, 2 = values.
///
/// # Errors
/// Returns an error (instead of panicking) when the indices argument is
/// missing or is not a tensor list, when the list is empty, or when more
/// than one index tensor is supplied (not supported yet).
pub(crate) fn translate_index_put(&mut self, node: &Node) -> Result<GraphTensor> {
    let a = self.get_input_tensor(node, 0)?;
    // Bounds-checked access: a malformed node must surface as an error, not a panic.
    let index_names = node
        .inputs
        .get(1)
        .context("index_put: missing indices argument")?
        .arg
        .as_tensors()
        .context("index_put: indices not as_tensors")?;
    let values = self.get_input_tensor(node, 2)?;

    match index_names.len() {
        0 => bail!("index_put: empty index tensor list"),
        1 => {}
        _ => bail!("index_put with multiple index tensors not yet supported"),
    }

    let indices = self.get_tensor(&index_names[0].name)?.cast(DType::Int);
    // scatter_nd expects indices of shape [batch, K] where K = number of index dims.
    // PT2's index_put gives 1D indices [batch]; reshape to [batch, 1].
    let indices = if indices.shape.len() == 1 {
        indices.expand_dim(1, Expression::from(1usize))
    } else {
        indices
    };
    Ok(a.scatter_nd(indices, values))
}
/// Lower a `split_with_sizes` node into consecutive slices along one axis.
///
/// Inputs: 0 = tensor to split, 1 = list of chunk sizes, 2 = optional axis
/// (defaults to 0 when absent or unreadable). Every chunk that has a
/// matching output name is registered in `self.tensors`; the first chunk
/// is also returned.
///
/// # Errors
/// Fails when a size is negative or when the size list is empty. All sizes
/// are validated before any chunk is registered, so an invalid node leaves
/// `self.tensors` untouched.
pub(crate) fn translate_split_with_sizes(&mut self, node: &Node) -> Result<GraphTensor> {
    let a = self.get_input_tensor(node, 0)?;
    let sizes = self.get_ints_arg(node, 1)?;
    let dim = if node.inputs.len() > 2 {
        self.get_int_arg(node, 2).unwrap_or(0)
    } else {
        0
    };
    let dim = normalize_dim(dim, a.shape.len());

    // A plain `as usize` would turn a negative size into an enormous slice bound.
    let sizes: Vec<usize> = sizes
        .iter()
        .enumerate()
        .map(|(i, &size)| {
            usize::try_from(size).map_err(|_| {
                anyhow::anyhow!("split_with_sizes: invalid negative size {size} at position {i}")
            })
        })
        .collect::<Result<_>>()?;

    // Output names come either packed in the first output or one per output.
    let output_names: Vec<String> = node
        .outputs
        .first()
        .and_then(|o| o.as_tensors.as_ref())
        .map(|ts| ts.iter().map(|t| t.name.clone()).collect())
        .unwrap_or_else(|| {
            node.outputs
                .iter()
                .filter_map(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
                .collect()
        });

    let mut offset = 0usize;
    let mut first_chunk = None;
    for (i, &size) in sizes.iter().enumerate() {
        let chunk = a.slice_along(offset..offset + size, dim);
        if let Some(name) = output_names.get(i) {
            self.tensors.insert(name.clone(), chunk);
        }
        if i == 0 {
            first_chunk = Some(chunk);
        }
        offset += size;
    }
    first_chunk.ok_or_else(|| anyhow::anyhow!("split_with_sizes: empty sizes list"))
}
}

View File

@@ -1,73 +0,0 @@
use anyhow::Result;
use luminal::prelude::*;
use crate::pt2_schema::*;
use crate::pt2_util::*;
use super::Translator;
/// Compute the total element count of `a`.
///
/// # Errors
/// Returns an error if any dimension is symbolic (cannot be resolved to a
/// `usize`) or if the product of the dimensions does not fit in `usize`.
fn concrete_numel(a: &GraphTensor) -> Result<usize> {
    a.dims().iter().try_fold(1usize, |acc, d| {
        let size = d.to_usize().ok_or_else(|| {
            anyhow::anyhow!("Full reduction requires concrete dimensions, got symbolic dim")
        })?;
        // Checked multiply: a silent wrap would produce a bogus flattened shape.
        acc.checked_mul(size)
            .ok_or_else(|| anyhow::anyhow!("Full reduction: element count overflows usize"))
    })
}
impl<'a> Translator<'a> {
    /// Lower a reduction node (`sum`, `mean`, `max`, `min`, `prod`).
    ///
    /// Input 0 is the tensor to reduce. Input 1, when it resolves to a
    /// non-empty integer list, names the axes to reduce; input 2 is an
    /// optional `keepdim` flag (defaults to `false`).
    ///
    /// When the axis list is missing, unreadable or empty, the tensor is
    /// viewed as `[1, N]` and reduced over axis 1. That path needs every
    /// dimension to be concrete and returns before `keepdim` is read.
    pub(crate) fn translate_reduction(
        &mut self,
        node: &Node,
        op: ReductionOp,
    ) -> Result<GraphTensor> {
        let input = self.get_input_tensor(node, 0)?;

        let requested_dims = match self.get_ints_arg(node, 1) {
            Ok(dims) if !dims.is_empty() => dims,
            _ => {
                // Full reduce: flatten to [1, N] and reduce axis 1
                let numel = concrete_numel(&input)?;
                let mut flattened = input;
                flattened.shape = ShapeTracker::new(vec![1, numel]);
                return Ok(match op {
                    ReductionOp::Sum => flattened.sum(vec![1]),
                    ReductionOp::Mean => flattened.sum(vec![1]) / numel as f32,
                    ReductionOp::Max => flattened.max(vec![1]),
                    ReductionOp::Min => flattened.min(vec![1]),
                    ReductionOp::Prod => flattened.prod(vec![1]),
                });
            }
        };

        let rank = input.shape.len();
        let axes: Vec<usize> = requested_dims
            .iter()
            .map(|&d| normalize_dim(d, rank))
            .collect();
        let keepdim = node.inputs.len() > 2 && self.get_bool_arg(node, 2).unwrap_or(false);

        let mut reduced = match op {
            ReductionOp::Sum => input.sum(axes.clone()),
            ReductionOp::Mean => input.mean(axes.clone()),
            ReductionOp::Max => input.max(axes.clone()),
            ReductionOp::Min => input.min(axes.clone()),
            ReductionOp::Prod => input.prod(axes.clone()),
        };

        if keepdim {
            // Re-insert the reduced axes in ascending order so each
            // position is valid at the time it is inserted.
            let mut ascending = axes.clone();
            ascending.sort();
            for &axis in &ascending {
                reduced = reduced.unsqueeze(axis);
            }
        }
        Ok(reduced)
    }
}

View File

@@ -1,152 +0,0 @@
use anyhow::{Context, Result};
use luminal::prelude::*;
use crate::pt2_schema::*;
use crate::pt2_util::*;
use super::Translator;
impl<'a> Translator<'a> {
/// Lower an `arange` node.
///
/// Collects every input whose `kind` is at most 1 and that resolves to an
/// expression, then interprets them positionally: one value is `end`, two
/// are `start, end` (step 1), three or more are `start, end, step`.
pub(crate) fn translate_arange(&mut self, node: &Node) -> Result<GraphTensor> {
    let mut bounds: Vec<Expression> = Vec::new();
    for input in &node.inputs {
        if input.kind <= 1 {
            if let Some(expr) = self.resolve_arg_as_expression(&input.arg) {
                bounds.push(expr);
            }
        }
    }

    let range = match bounds.as_slice() {
        [] => anyhow::bail!("arange: no positional args found"),
        [end] => self.graph.arange(*end),
        [start, end] => self.graph.arange_options(*start, *end, 1),
        [start, end, step, ..] => self.graph.arange_options(*start, *end, *step),
    };
    Ok(range)
}
/// Lower a `full` node: a float constant broadcast to the requested shape.
///
/// Input 0 is the shape; input 1 is the fill value, read first as a float
/// and, failing that, as a bool (`true` becomes 1.0, `false` becomes 0.0).
pub(crate) fn translate_full(&mut self, node: &Node) -> Result<GraphTensor> {
    let dims = self.get_exprs_arg(node, 0)?;
    let fill = match self.get_float_arg(node, 1) {
        Ok(as_float) => as_float as f32,
        Err(_) => match self.get_bool_arg(node, 1) {
            Ok(true) => 1.0,
            Ok(false) => 0.0,
            Err(_) => anyhow::bail!(
                "full: unsupported fill value type: {:?}",
                node.inputs.get(1)
            ),
        },
    };
    Ok(self.graph.constant_float(fill).expand_rhs(dims))
}
/// Lower a `where(cond, x, y)` node as the arithmetic blend
/// `cond * x + (1 - cond) * y`, computed in F32.
///
/// `x` and `y` are first brought to a common dtype, then all three
/// operands are broadcast pairwise to one shape.
pub(crate) fn translate_where(&mut self, node: &Node) -> Result<GraphTensor> {
    let condition = self.get_input_tensor(node, 0)?;
    let on_true = self.get_input_tensor(node, 1)?;
    let on_false = self.get_input_tensor(node, 2)?;

    let (on_true, on_false) = ensure_same_dtype(on_true, on_false);

    // Three pairwise broadcasts leave every operand with the common shape.
    let (condition, on_true) = broadcast_binary(condition, on_true);
    let (condition, on_false) = broadcast_binary(condition, on_false);
    let (on_true, on_false) = broadcast_binary(on_true, on_false);

    let mask = condition.cast(DType::F32);
    let true_f32 = on_true.cast(DType::F32);
    let false_f32 = on_false.cast(DType::F32);
    let ones = self.graph.constant_float(1.0).expand_rhs(mask.shape);
    Ok(mask * true_f32 + (ones - mask) * false_f32)
}
pub(crate) fn translate_topk(&mut self, node: &Node) -> Result<()> {
let a = self.get_input_tensor(node, 0)?;
let k = self.get_int_arg(node, 1)? as usize;
let dim = if node.inputs.len() > 2 {
self.get_int_arg(node, 2).unwrap_or(-1)
} else {
-1
};
let dim = normalize_dim(dim, a.shape.len());
// Determine output names
let values_name = node
.outputs
.first()
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()));
let indices_name =
if let Some(ts) = node.outputs.first().and_then(|o| o.as_tensors.as_ref()) {
ts.get(1).map(|t| t.name.clone())
} else if node.outputs.len() > 1 {
node.outputs[1].as_tensor.as_ref().map(|t| t.name.clone())
} else {
None
};
// Use full argsort then slice, rather than topk_indexes/topk_values directly.
// This avoids a CUDA gather kernel bug when data and index shapes differ
// along the gather axis (topk_indexes returns a sliced tensor).
let full_argsort = a.argsort(dim, true);
// Only build each branch when its output is consumed.
// Dead nodes in the graph can confuse the CUDA optimizer.
if let Some(val_name) = values_name
&& !val_name.is_empty()
{
let values = a.gather_elements(full_argsort, dim).slice_along(..k, dim);
self.tensors.insert(val_name, values);
}
if let Some(idx_name) = indices_name {
// Materialize Int indices as F32 with `* 1.0` to force a contiguous copy.
// Without this, CUDA can't correctly read the sliced Int view.
let indices = full_argsort.slice_along(..k, dim) * 1.0;
self.tensors.insert(idx_name, indices);
}
Ok(())
}
/// Inline a `wrap_with_set_grad` higher-order node.
///
/// Input 1 carries the subgraph; inputs 2.. are the arguments forwarded to
/// it. The subgraph's inputs are aliased to the forwarded tensors, its
/// tensor metadata is copied into `extra_tensor_values`, its nodes are
/// translated in order, and finally its outputs are aliased to this node's
/// output names where the names differ.
///
/// # Errors
/// Fails (instead of panicking) when the subgraph argument is absent, and
/// propagates any error from translating a subgraph node, annotated with
/// that node's position and target.
pub(crate) fn translate_wrap_set_grad(&mut self, node: &Node) -> Result<()> {
    let subgraph = node
        .inputs
        .get(1)
        .context("wrap_with_set_grad: missing subgraph")?
        .arg
        .as_graph()
        .context("wrap_with_set_grad: missing subgraph")?
        .clone();

    // Bind each subgraph input to the tensor forwarded by the caller.
    let sg_inputs = &subgraph.graph.inputs;
    // Cannot panic: `get(1)` above proved there are at least two inputs.
    let forwarded_args = &node.inputs[2..];
    for (sg_input, fwd_arg) in sg_inputs.iter().zip(forwarded_args) {
        if let Some(sg_name) = sg_input.as_tensor.as_ref()
            && let Some(main_name) = fwd_arg.arg.as_tensor_name()
        {
            let tensor = self.get_tensor(main_name)?;
            self.tensors.insert(sg_name.name.clone(), tensor);
        }
    }

    for (k, v) in &subgraph.graph.tensor_values {
        self.extra_tensor_values.insert(k.clone(), v.clone());
    }

    // `subgraph` is a local value, so its nodes can be walked in place
    // while `self` is mutably borrowed; no second clone is required.
    for (i, sg_node) in subgraph.graph.nodes.iter().enumerate() {
        self.translate_node(sg_node)
            .with_context(|| format!("Subgraph node {i}: {}", sg_node.target))?;
    }

    // Expose subgraph results under the outer node's output names.
    for (main_out, sg_out) in node.outputs.iter().zip(subgraph.graph.outputs.iter()) {
        if let (Some(main_name), Some(sg_name)) =
            (main_out.as_tensor.as_ref(), sg_out.as_tensor.as_ref())
            && main_name.name != sg_name.name
        {
            let tensor = self.get_tensor(&sg_name.name)?;
            self.tensors.insert(main_name.name.clone(), tensor);
        }
    }
    Ok(())
}
}

View File

@@ -1,85 +0,0 @@
use anyhow::Result;
use luminal::prelude::*;
use crate::pt2_schema::*;
use crate::pt2_util::{broadcast_binary, torch_dtype_int_to_luminal};
use super::Translator;
impl<'a> Translator<'a> {
/// Apply the element-wise transformation `f` to the node's first input.
pub(crate) fn translate_unary_op(
    &mut self,
    node: &Node,
    f: impl Fn(GraphTensor) -> GraphTensor,
) -> Result<GraphTensor> {
    self.get_input_tensor(node, 0).map(f)
}
/// Lower a `_to_copy` node.
///
/// If the node has an input named `dtype` holding an integer code, the
/// tensor is cast to the matching luminal dtype; otherwise it is passed
/// through unchanged.
pub(crate) fn translate_to_copy(&mut self, node: &Node) -> Result<GraphTensor> {
    let tensor = self.get_input_tensor(node, 0)?;
    let dtype_code = node
        .inputs
        .iter()
        .filter(|input| input.name == "dtype")
        .find_map(|input| input.arg.as_int());
    Ok(match dtype_code {
        Some(code) => tensor.cast(torch_dtype_int_to_luminal(code as u32)),
        None => tensor,
    })
}
pub(crate) fn translate_layer_norm(&mut self, node: &Node) -> Result<GraphTensor> {
let input = self.get_input_tensor(node, 0)?;
let normalized_shape = self.get_ints_arg(node, 1)?;
// Axes to normalize over = last N dims where N = len(normalized_shape)
let ndim = input.shape.len();
let num_norm_dims = normalized_shape.len();
let axes: Vec<usize> = ((ndim - num_norm_dims)..ndim).collect();
// eps is arg 4 (after input, normalized_shape, weight, bias), default 1e-5
let eps = self.get_float_arg(node, 4).unwrap_or(1e-5) as f32;
let mut result = input.layer_norm(axes, eps);
// Apply weight (arg 2) if present and not None
if let Some(weight_name) = node.inputs.get(2).and_then(|i| i.arg.as_tensor_name()) {
let w = self.get_tensor(weight_name)?;
let (r, w) = broadcast_binary(result, w);
result = r * w;
}
// Apply bias (arg 3) if present and not None
if let Some(bias_name) = node.inputs.get(3).and_then(|i| i.arg.as_tensor_name()) {
let b = self.get_tensor(bias_name)?;
let (r, b) = broadcast_binary(result, b);
result = r + b;
}
Ok(result)
}
/// Lower a `clamp` node.
///
/// Input 1 is the optional lower bound and input 2 the optional upper
/// bound. A bound that is absent, or that cannot be read as a float, is
/// skipped. The lower bound is applied before the upper bound.
pub(crate) fn translate_clamp(&mut self, node: &Node) -> Result<GraphTensor> {
    let input = self.get_input_tensor(node, 0)?;

    let lower = match node.inputs.len() {
        0 | 1 => None,
        _ => self.get_float_arg(node, 1).ok().map(|v| v as f32),
    };
    let upper = match node.inputs.len() {
        0..=2 => None,
        _ => self.get_float_arg(node, 2).ok().map(|v| v as f32),
    };

    let floored = match lower {
        Some(lo) => input.maximum_f32(lo),
        None => input,
    };
    let clamped = match upper {
        Some(hi) => floored.minimum_f32(hi),
        None => floored,
    };
    Ok(clamped)
}
}

View File

@@ -1,352 +0,0 @@
//! Dtype-aware buffer type for the luminal_python bridge.
//!
//! `TypedData` wraps raw bytes with a `DType` tag, enabling multi-dtype data flow
//! through the PT2 path without forcing everything to f32.
use luminal::hlir::NativeData;
use luminal::prelude::tracing::warn;
use luminal::prelude::*;
/// A dtype-tagged byte buffer. All weight, constant, and input data flows through this type.
///
/// The buffer owns its bytes; `dtype` says how they are to be interpreted.
/// Multi-byte elements are decoded as little-endian by the readers in this
/// module (`as_f64` and the `NativeData` conversion).
#[derive(Clone, Debug)]
pub struct TypedData {
    /// Raw element storage, laid out back to back with no header or padding.
    pub bytes: Vec<u8>,
    /// Element type that `bytes` encodes.
    pub dtype: DType,
}
impl TypedData {
/// Wrap raw bytes with a dtype tag. Caller must ensure bytes are correctly formatted.
pub fn from_raw(bytes: Vec<u8>, dtype: DType) -> Self {
Self { bytes, dtype }
}
/// Number of bytes in the buffer
pub fn n_bytes(&self) -> usize {
self.bytes.len()
}
/// Number of logical elements held in the buffer.
///
/// For dtypes of at least 8 bits this is the byte count divided by the
/// element width in bytes; for sub-byte dtypes each byte holds
/// `8 / bits` elements.
pub fn n_elements(&self) -> usize {
    let bit_width = self.dtype.bits();
    let byte_count = self.bytes.len();
    if bit_width < 8 {
        // sub-byte types: multiple elements per byte
        byte_count * (8 / bit_width)
    } else {
        byte_count / (bit_width / 8)
    }
}
/// Decode the element at position `idx` and widen it to `f64`.
///
/// Multi-byte elements are read little-endian. Used by the
/// `From<TypedData> for NativeData` fallback path.
///
/// # Panics
/// Panics if the element lies outside the buffer, or if the dtype is not
/// one of F32, F64, F16, Bf16, Int, I8, U8, I16, U16 or Bool.
fn as_f64(&self, idx: usize) -> f64 {
    let raw = &self.bytes;
    match self.dtype {
        DType::F32 => {
            let at = idx * 4;
            f64::from(f32::from_le_bytes([
                raw[at],
                raw[at + 1],
                raw[at + 2],
                raw[at + 3],
            ]))
        }
        DType::F64 => {
            let at = idx * 8;
            f64::from_le_bytes([
                raw[at],
                raw[at + 1],
                raw[at + 2],
                raw[at + 3],
                raw[at + 4],
                raw[at + 5],
                raw[at + 6],
                raw[at + 7],
            ])
        }
        DType::F16 => {
            let at = idx * 2;
            half::f16::from_le_bytes([raw[at], raw[at + 1]]).to_f64()
        }
        DType::Bf16 => {
            let at = idx * 2;
            half::bf16::from_le_bytes([raw[at], raw[at + 1]]).to_f64()
        }
        DType::Int => {
            let at = idx * 4;
            f64::from(i32::from_le_bytes([
                raw[at],
                raw[at + 1],
                raw[at + 2],
                raw[at + 3],
            ]))
        }
        // Single-byte integers: reinterpret the sign for I8, keep U8 as is.
        DType::I8 => f64::from(raw[idx] as i8),
        DType::U8 => f64::from(raw[idx]),
        DType::I16 => {
            let at = idx * 2;
            f64::from(i16::from_le_bytes([raw[at], raw[at + 1]]))
        }
        DType::U16 => {
            let at = idx * 2;
            f64::from(u16::from_le_bytes([raw[at], raw[at + 1]]))
        }
        // Any non-zero byte counts as true.
        DType::Bool => f64::from(u8::from(raw[idx] != 0)),
        _ => panic!("as_f64 not supported for {:?}", self.dtype),
    }
}
// -- Constructors from typed Vecs --
/// Build an `F32` buffer from a vector of `f32` values.
///
/// Each value is serialised little-endian, which is the byte order every
/// reader in this module decodes with (`from_le_bytes`). No `unsafe`
/// pointer reinterpretation is involved, so the result does not depend on
/// the host's endianness.
pub fn from_f32_vec(data: Vec<f32>) -> Self {
    let mut bytes = Vec::with_capacity(data.len() * 4);
    for value in &data {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    Self {
        bytes,
        dtype: DType::F32,
    }
}
/// Build an `F16` buffer from a vector of `half::f16` values.
///
/// Each value is serialised little-endian, which is the byte order every
/// reader in this module decodes with (`from_le_bytes`). No `unsafe`
/// pointer reinterpretation is involved, so the result does not depend on
/// the host's endianness.
pub fn from_f16_vec(data: Vec<half::f16>) -> Self {
    let mut bytes = Vec::with_capacity(data.len() * 2);
    for value in &data {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    Self {
        bytes,
        dtype: DType::F16,
    }
}
/// Build a `Bf16` buffer from a vector of `half::bf16` values.
///
/// Each value is serialised little-endian, which is the byte order every
/// reader in this module decodes with (`from_le_bytes`). No `unsafe`
/// pointer reinterpretation is involved, so the result does not depend on
/// the host's endianness.
pub fn from_bf16_vec(data: Vec<half::bf16>) -> Self {
    let mut bytes = Vec::with_capacity(data.len() * 2);
    for value in &data {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    Self {
        bytes,
        dtype: DType::Bf16,
    }
}
/// Build an `Int` buffer from a vector of `i32` values.
///
/// Each value is serialised little-endian, which is the byte order every
/// reader in this module decodes with (`from_le_bytes`). No `unsafe`
/// pointer reinterpretation is involved, so the result does not depend on
/// the host's endianness.
pub fn from_i32_vec(data: Vec<i32>) -> Self {
    let mut bytes = Vec::with_capacity(data.len() * 4);
    for value in &data {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    Self {
        bytes,
        dtype: DType::Int,
    }
}
/// Build a `Bool` buffer, storing one byte per flag (`1` for true, `0` for false).
pub fn from_bool_vec(data: Vec<bool>) -> Self {
    let encoded: Vec<u8> = data.into_iter().map(u8::from).collect();
    Self {
        dtype: DType::Bool,
        bytes: encoded,
    }
}
/// Re-encode the raw bytes of a PyTorch tensor, identified by its PT2
/// dtype code, into luminal's native layout.
///
/// Codes handled by this function:
/// - 7 (f32), 6 (f16), 13 (bf16), 4 (i32), 12 (bool): bytes kept as-is
/// - 5 (i64): each value cast to i32 with `as` (high bits dropped)
/// - 8 (f64): each value cast to f32
/// - 3 (i16), 1 (u8), 2 (i8): each value widened to i32
///
/// Any other code logs a warning and the bytes are tagged as f32 unchanged.
/// Trailing bytes that do not fill a whole element are ignored by the
/// converting arms.
pub fn from_pytorch_bytes(bytes: Vec<u8>, dtype_code: u32) -> Self {
    match dtype_code {
        // u8 → i32 (widen)
        1 => Self::from_i32_vec(bytes.iter().map(|&byte| i32::from(byte)).collect()),
        // i8 → i32 (widen, signed)
        2 => Self::from_i32_vec(bytes.iter().map(|&byte| i32::from(byte as i8)).collect()),
        // i16 → i32 (widen)
        3 => {
            let widened = bytes
                .chunks_exact(2)
                .map(|pair| i32::from(i16::from_le_bytes([pair[0], pair[1]])))
                .collect();
            Self::from_i32_vec(widened)
        }
        // i32: already in luminal's layout
        4 => Self::from_raw(bytes, DType::Int),
        // i64 → i32 (truncate)
        5 => {
            let narrowed = bytes
                .chunks_exact(8)
                .map(|word| {
                    let wide = i64::from_le_bytes(
                        word.try_into().expect("chunks_exact(8) yields 8-byte chunks"),
                    );
                    wide as i32
                })
                .collect();
            Self::from_i32_vec(narrowed)
        }
        6 => Self::from_raw(bytes, DType::F16),
        7 => Self::from_raw(bytes, DType::F32),
        // f64 → f32 (downcast)
        8 => {
            let narrowed = bytes
                .chunks_exact(8)
                .map(|word| {
                    let wide = f64::from_le_bytes(
                        word.try_into().expect("chunks_exact(8) yields 8-byte chunks"),
                    );
                    wide as f32
                })
                .collect();
            Self::from_f32_vec(narrowed)
        }
        12 => Self::from_raw(bytes, DType::Bool),
        13 => Self::from_raw(bytes, DType::Bf16),
        // Unknown: best-effort pass-through as f32
        _ => {
            warn!("Unrecognized pytorch dtype code {dtype_code}, interpreting as f32");
            Self::from_raw(bytes, DType::F32)
        }
    }
}
/// Create an n-element buffer of "safe" dummy values (1.0 for floats, 1 for ints, true for bool).
/// IMPORTANT: Must use 1, NOT 0. Zero inputs cause NaN in many ops (fmod, recip, log, etc.).
///
/// All multi-byte values are written little-endian using safe code, so the
/// buffer matches what the little-endian readers in this module expect.
///
/// # Panics
/// Panics for dtypes other than F32, TF32, F64, F16, Bf16, Int, I8, U8,
/// I16, U16 and Bool.
pub fn ones(n_elements: usize, dtype: DType) -> Self {
    match dtype {
        DType::F32 | DType::TF32 => Self::from_f32_vec(vec![1.0f32; n_elements]),
        // Repeat the little-endian encoding of a single 1 for the wider
        // and narrower types that have no dedicated constructor.
        DType::F64 => Self {
            bytes: 1.0f64.to_le_bytes().repeat(n_elements),
            dtype: DType::F64,
        },
        DType::F16 => Self::from_f16_vec(vec![half::f16::from_f32(1.0); n_elements]),
        DType::Bf16 => Self::from_bf16_vec(vec![half::bf16::from_f32(1.0); n_elements]),
        DType::Int => Self::from_i32_vec(vec![1i32; n_elements]),
        DType::I8 => Self::from_raw(vec![1u8; n_elements], DType::I8),
        DType::U8 => Self::from_raw(vec![1u8; n_elements], DType::U8),
        DType::I16 => Self {
            bytes: 1i16.to_le_bytes().repeat(n_elements),
            dtype: DType::I16,
        },
        DType::U16 => Self {
            bytes: 1u16.to_le_bytes().repeat(n_elements),
            dtype: DType::U16,
        },
        DType::Bool => Self::from_bool_vec(vec![true; n_elements]),
        _ => panic!("TypedData::ones not supported for {:?}", dtype),
    }
}
}
/// Decode a `TypedData` buffer into the native runtime's `NativeData`.
///
/// F32/TF32, F16, Bf16, Int and Bool map onto the variant of the same
/// kind; F64 is narrowed to `F32`; I8, U8, I16 and U16 are widened into
/// `Int`. Every multi-byte element is read little-endian.
impl From<TypedData> for NativeData {
    fn from(td: TypedData) -> Self {
        let raw = td.bytes.as_slice();
        match td.dtype {
            DType::F32 | DType::TF32 => NativeData::F32(
                raw.chunks_exact(4)
                    .map(|w| f32::from_le_bytes([w[0], w[1], w[2], w[3]]))
                    .collect(),
            ),
            // Downcast f64 -> f32 for native runtime (which only has F32 variant for floats > 32-bit)
            DType::F64 => NativeData::F32(
                raw.chunks_exact(8)
                    .map(|w| {
                        f64::from_le_bytes([w[0], w[1], w[2], w[3], w[4], w[5], w[6], w[7]])
                            as f32
                    })
                    .collect(),
            ),
            DType::F16 => NativeData::F16(
                raw.chunks_exact(2)
                    .map(|w| half::f16::from_le_bytes([w[0], w[1]]))
                    .collect(),
            ),
            DType::Bf16 => NativeData::Bf16(
                raw.chunks_exact(2)
                    .map(|w| half::bf16::from_le_bytes([w[0], w[1]]))
                    .collect(),
            ),
            DType::Int => NativeData::Int(
                raw.chunks_exact(4)
                    .map(|w| i32::from_le_bytes([w[0], w[1], w[2], w[3]]))
                    .collect(),
            ),
            DType::Bool => NativeData::Bool(raw.iter().map(|&byte| byte != 0).collect()),
            // Narrow integer types are all widened into NativeData::Int.
            DType::I8 => NativeData::Int(raw.iter().map(|&byte| i32::from(byte as i8)).collect()),
            DType::U8 => NativeData::Int(raw.iter().map(|&byte| i32::from(byte)).collect()),
            DType::I16 => NativeData::Int(
                raw.chunks_exact(2)
                    .map(|w| i32::from(i16::from_le_bytes([w[0], w[1]])))
                    .collect(),
            ),
            DType::U16 => NativeData::Int(
                raw.chunks_exact(2)
                    .map(|w| i32::from(u16::from_le_bytes([w[0], w[1]])))
                    .collect(),
            ),
            // Sub-byte and F8 types: store as raw f32 for native runtime (best effort).
            // NOTE(review): this goes through `as_f64`, which panics for
            // dtypes it does not list — confirm the exotic types that
            // reach this arm are actually convertible.
            _ => {
                let element_count = td.n_elements();
                NativeData::F32(
                    (0..element_count)
                        .map(|i| td.as_f64(i) as f32)
                        .collect(),
                )
            }
        }
    }
}
/// Convert &TypedData to NativeData (clone the bytes).
impl From<&TypedData> for NativeData {
fn from(td: &TypedData) -> Self {
td.clone().into()
}
}
// CUDA runtime conversion is implemented via ToCudaInput in runtime.rs
// (behind the `cuda` feature gate) since it depends on cudarc types.

View File

@@ -1,21 +0,0 @@
"""Luminal Python bindings - PyTorch backend using Luminal.

Importing this package pulls in the pure-Python helpers and the compiled
``luminal`` extension module, calls ``_register_cache_serialization()`` once,
and re-exports the public names listed in ``__all__``.
"""
# Pure-Python components of the package.
# `_register_cache_serialization` is invoked below, once, at import time
# (registers DynamicCache pytree serialization).
from .cache_utils import _register_cache_serialization
from .compiled_model import CompiledModel
# Rust extension components (built by maturin).
# These are re-exported so they are available directly in the package namespace.
from .luminal import CompiledGraph, process_pt2
from .main import luminal_backend
_register_cache_serialization()
# Public interface of the package.
__all__ = [
"CompiledModel",
"luminal_backend",
"CompiledGraph",
"process_pt2",
]

Some files were not shown because too many files have changed in this diff Show More